hexagon: support for OP_CPY, host buffers now optional, hvx-utils refactoring and optimizations (#18822)

* hexagon: disable repack buffers if host buffers are disabled, improved handling of env vars
* hexagon: add support for OP_CPY fp16/fp32 -> fp16/fp32
  Factored out all hvx_copy functions into the hvx-copy.h header and reduced code duplication. Updated the HTP ops infra to support OP_CPY.
* hexagon: cleanup and refactor hex/hvx/htp headers and helper libs
  hex is basically all scalar/core platform stuff (L2, DMA, basic utils), hvx is all HVX-related utils and helpers, and htp is the higher-level stuff like Ops. The hvx-utils library got a round of cleanup and refactoring to reduce duplication; use hvx_vec_store_a where possible.
* hexagon: refactor HVX sigmoid functions to hvx-sigmoid.h
  Moved the sigmoid and tanh vector functions from hvx-utils.h to a new header hvx-sigmoid.h. Implemented aligned and unaligned variants for sigmoid array processing using a macro pattern similar to hvx-copy.h. Updated act-ops.c to use the new aligned variant hvx_sigmoid_f32_aa. Removed the unused hvx-sigmoid.c.
* hexagon: factor out hvx-sqrt.h
* hexagon: minor update to hvx-utils.h
* hexagon: remove spurious log
* hexagon: factor out and optimize hvx_add/sub/mul
* hexagon: remove _opt variants of add/sub/mul as they are simply the fully aligned versions
* hexagon: refactor reduction functions to hvx-reduce.h
  Moved `hvx_self_max_f32` and `hvx_self_sum_f32` from `hvx-utils.h`/`.c` to `hvx-reduce.h` and renamed them to `hvx_reduce_max_f32` and `hvx_reduce_sum_f32`. Added aligned (`_a`) and unaligned (`_u`) variants and used macros to unify the logic. Updated `softmax-ops.c` to use the new functions.
* hexagon: refactor the rest of the arithmetic functions to hvx-arith.h
  Moved `hvx_min_scalar_f32` and `hvx_clamp_scalar_f32` from `hvx-utils.c/h` to `hvx-arith.h`. Implemented aligned/unaligned variants (`_aa`, `_au`, etc.) and used macros to reduce code duplication. Updated these functions to use the `dst, src, ..., n` argument order and updated call sites in `act-ops.c`. `hvx_sum_of_squares_f32` remains in `hvx-utils.c` as requested.
* hexagon: refactor hvx_sum_of_squares_f32
  - Modify `hvx_sum_of_squares_f32` in `ggml/src/ggml-hexagon/htp/hvx-reduce.h` to use the `dst, src` signature.
  - Implement `_a` (aligned) and `_u` (unaligned) variants for `hvx_sum_of_squares_f32`.
  - Update the `hvx_reduce_loop_body` macro to support both returning and storing results via `finalize_op`.
  - Update the existing reduction functions in `hvx-reduce.h` to use the updated macro.
  - Update `rms_norm_htp_f32` in `ggml/src/ggml-hexagon/htp/unary-ops.c` to match the new signature.
* hexagon: use hvx_splat instead of memset
* hexagon: consistent use of f32/f16 in all function names to match the rest of GGML
* hexagon: fix hvx_copy_f16_f32 on v75 and older
* hexagon: update readme to include GGML_HEXAGON_EXPERIMENTAL
* scripts: update snapdragon/adb scripts to enable host param
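As a rough illustration of the aligned/unaligned macro pattern and the `dst, src, ..., n` argument order described above: the sketch below is scalar and uses made-up names (`demo_*`, `store_a`, `store_u`); it is not the actual hvx-copy.h / hvx-arith.h code, only the shape of it.

```c
#include <stddef.h>

// Shared loop body, expanded once per store flavor.
#define demo_loop_body(store_op)              \
    do {                                      \
        for (size_t i = 0; i < n; i++) {      \
            store_op(&dst[i], 2.0f * src[i]); \
        }                                     \
    } while (0)

// Stand-ins for the aligned / unaligned HVX vector stores.
static inline void store_a(float * p, float v) { *p = v; }
static inline void store_u(float * p, float v) { *p = v; }

// Fully aligned variant: dst and src are both assumed vector-aligned in the real code.
static inline void demo_scale2_f32_aa(float * dst, const float * src, size_t n) {
    demo_loop_body(store_a);
}

// Fully unaligned variant.
static inline void demo_scale2_f32_uu(float * dst, const float * src, size_t n) {
    demo_loop_body(store_u);
}
```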
This commit is contained in:
parent 36f0132464
commit cff777f226
@@ -210,6 +210,10 @@ build: 6a8cf8914 (6733)
  Controls whether the Hexagon backend allocates host buffers. By default, all buffers except for REPACK are host buffers.
  This option is required for testing Ops that require REPACK buffers (MUL_MAT and MUL_MAT_ID).

- `GGML_HEXAGON_EXPERIMENTAL=1`
  Controls whether the Hexagon backend enables experimental features.
  This option is required for enabling/testing experimental Ops (FLASH_ATTN_EXT).

- `GGML_HEXAGON_VERBOSE=1`
  Enables verbose logging of Ops from the backend. Example output:
@@ -42,12 +42,12 @@
#include "htp_iface.h"

static size_t opt_ndev = 1;
static size_t opt_nhvx = 0; // use all
static int opt_arch = 0; // autodetect
static size_t opt_nhvx = 0; // use all
static int opt_arch = 0; // autodetect
static int opt_etm = 0;
static int opt_verbose = 0;
static int opt_profile = 0;
static int opt_hostbuf = 1;
static int opt_hostbuf = 1; // hostbuf ON by default
static int opt_experimental = 0;

// Enable all stages by default

@@ -1753,6 +1753,9 @@ static bool ggml_backend_buffer_is_hexagon(const struct ggml_backend_buffer * b)
}

static inline bool ggml_backend_buffer_is_hexagon_repack(const struct ggml_backend_buffer * b) {
    if (!opt_hostbuf) {
        return ggml_backend_buffer_is_hexagon(b);
    }
    return b->buft->iface.alloc_buffer == ggml_backend_hexagon_repack_buffer_type_alloc_buffer;
}

@@ -2302,6 +2305,16 @@ static inline size_t init_binary_req(htp_general_req * req, dspqueue_buffer * bu
    return n_bufs;
}

static inline size_t init_cpy_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
    req->op = HTP_OP_CPY;

    size_t n_bufs = 0;
    n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
    n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);

    return n_bufs;
}

static inline size_t init_get_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
    req->op = HTP_OP_GET_ROWS;

@@ -2557,6 +2570,10 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
            ggml_hexagon_dispatch_op<init_get_rows_req>(sess, node, flags);
            break;

        case GGML_OP_CPY:
            ggml_hexagon_dispatch_op<init_cpy_req>(sess, node, flags);
            break;

        default:
            GGML_ABORT("\nggml-hex: graph-compute %s is not supported\n", ggml_op_desc(node));
    }
@@ -2858,6 +2875,27 @@ static bool ggml_hexagon_supported_buffers(ggml_hexagon_session *sess, const str
    return true;
}

static bool ggml_hexagon_supported_cpy(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
    const struct ggml_tensor * src0 = op->src[0];
    const struct ggml_tensor * dst  = op;

    // for now we can do f32 -> f16 and f16 -> f32 (without reshaping)
    if (src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) return false;
    if ( dst->type != GGML_TYPE_F32 &&  dst->type != GGML_TYPE_F16) return false;

    const bool sametype   = (src0->type == dst->type);
    const bool transposed = ggml_is_transposed(src0) || ggml_is_transposed(dst);
    const bool sameshape  = !transposed && ggml_are_same_shape(src0, dst);

    // can handle any shape and any same-type (pretty slow if reshaping is required)
    if (sametype) return true;

    // cannot handle re-shaping and type conversion at the same time
    if (!sameshape) return false;

    return true;
}

static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
    auto sess = static_cast<ggml_hexagon_session *>(dev->context);
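For reference, this is how the checks above map onto concrete `GGML_OP_CPY` cases, expressed with the public ggml API; the context setup and tensor shapes are hypothetical and only illustrate the accept/reject logic.

```c
#include "ggml.h"

// Illustrative only: which CPY cases ggml_hexagon_supported_cpy accepts.
static void cpy_support_examples(struct ggml_context * ctx) {
    struct ggml_tensor * a32 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 128, 64);
    struct ggml_tensor * b16 = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, 128, 64);  // same shape as a32
    struct ggml_tensor * c16 = ggml_new_tensor_2d(ctx, GGML_TYPE_F16,  64, 128); // same element count, different shape

    ggml_cpy(ctx, a32, b16); // f32 -> f16, same shape, not transposed: accepted
    ggml_cpy(ctx, a32, c16); // f32 -> f16 with reshape: rejected (type conversion + reshape)
    // f32 -> f32 (or f16 -> f16) is accepted for any shape, though the comment
    // above notes it is "pretty slow if reshaping is required".
}
```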
@@ -2936,6 +2974,10 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
            supp = ggml_hexagon_supported_get_rows(sess, op);
            break;

        case GGML_OP_CPY:
            supp = ggml_hexagon_supported_cpy(sess, op);
            break;

        default:
            break;
    }

@@ -3061,7 +3103,7 @@ static ggml_backend_dev_t ggml_backend_hexagon_reg_get_device(ggml_backend_reg_t
}

static void * ggml_backend_hexagon_get_proc_address(ggml_backend_reg_t reg, const char * name) {
    if (strcmp(name, "ggml_backend_dev_get_extra_bufts") == 0) {
    if (strcmp(name, "ggml_backend_dev_get_extra_bufts") == 0 && opt_hostbuf) {
        ggml_backend_dev_get_extra_bufts_t fct = ggml_backend_hexagon_device_get_extra_buffers_type;
        return (void *) fct;
    }
@@ -3078,34 +3120,31 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) {
    static_assert((unsigned int) HTP_TYPE_MXFP4 == (unsigned int) GGML_TYPE_MXFP4,
                  "please update hexagon_type to match ggml_type");

    const char * str_experimental = getenv("GGML_HEXAGON_EXPERIMENTAL");
    const char * str_verbose = getenv("GGML_HEXAGON_VERBOSE");
    const char * str_hostbuf = getenv("GGML_HEXAGON_HOSTBUF");
    const char * str_opmask = getenv("GGML_HEXAGON_OPMASK");
    const char * str_opsync = getenv("GGML_HEXAGON_OPSYNC");
    const char * str_profile = getenv("GGML_HEXAGON_PROFILE");
    const char * str_etm = getenv("GGML_HEXAGON_ETM");
    const char * str_nhvx = getenv("GGML_HEXAGON_NHVX");
    const char * str_ndev = getenv("GGML_HEXAGON_NDEV");
    const char * str_arch = getenv("GGML_HEXAGON_ARCH");

    opt_experimental = str_experimental ? atoi(str_experimental) : 0;
    opt_verbose = str_verbose ? atoi(str_verbose) : 0;
    opt_profile = getenv("GGML_HEXAGON_PROFILE") != nullptr;
    opt_etm = getenv("GGML_HEXAGON_ETM") != nullptr;
    opt_experimental = getenv("GGML_HEXAGON_EXPERIMENTAL") != nullptr;
    opt_hostbuf = str_hostbuf ? atoi(str_hostbuf) : opt_hostbuf;
    opt_opmask = str_opmask ? strtoul(str_opmask, NULL, 0) : opt_opmask;
    opt_opsync = str_opsync ? atoi(str_opsync) : 0;
    opt_profile = str_profile ? atoi(str_profile) : 0;
    opt_etm = str_etm ? atoi(str_etm) : 0;
    opt_nhvx = str_nhvx ? strtoul(str_nhvx, NULL, 0) : opt_nhvx;
    opt_ndev = str_ndev ? strtoul(str_ndev, NULL, 0) : opt_ndev;

    const char * str_opmask = getenv("GGML_HEXAGON_OPMASK");
    if (str_opmask != nullptr) {
        opt_opmask = strtoul(str_opmask, NULL, 0);
    }
    opt_opsync = getenv("GGML_HEXAGON_OPSYNC") != nullptr;

    const char * str_ndev = getenv("GGML_HEXAGON_NDEV");
    if (str_ndev) {
        opt_ndev = strtoul(str_ndev, NULL, 0);
        if (opt_ndev > GGML_HEXAGON_MAX_SESSIONS) {
            opt_ndev = GGML_HEXAGON_MAX_SESSIONS;
        }
    if (opt_ndev > GGML_HEXAGON_MAX_SESSIONS) {
        opt_ndev = GGML_HEXAGON_MAX_SESSIONS;
    }

    const char * str_nhvx = getenv("GGML_HEXAGON_NHVX");
    if (str_nhvx) {
        opt_nhvx = strtoul(str_nhvx, NULL, 0);
    }

    const char * str_arch = getenv("GGML_HEXAGON_ARCH");
    if (str_arch) {
        if (str_arch[0] == 'v') {
            str_arch++;

@@ -3113,8 +3152,6 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) {
        opt_arch = strtoul(str_arch, NULL, 0);
    }

    opt_hostbuf = str_hostbuf ? atoi(str_hostbuf) : 1;

    reg->context = new ggml_hexagon_registry(reg);

    HEX_VERBOSE("ggml-hex: size-of-general-req %zu size-of-general-rsp %zu\n", sizeof(struct htp_general_req),
@@ -17,11 +17,7 @@ add_library(${HTP_LIB} SHARED
    main.c
    htp_iface_skel.c
    worker-pool.c
    htp-dma.c
    hvx-sigmoid.c
    hvx-inverse.c
    hvx-exp.c
    hvx-utils.c
    hex-dma.c
    matmul-ops.c
    binary-ops.c
    unary-ops.c

@@ -31,10 +27,12 @@ add_library(${HTP_LIB} SHARED
    flash-attn-ops.c
    set-rows-ops.c
    get-rows-ops.c
    cpy-ops.c
)

target_compile_definitions(${HTP_LIB} PRIVATE
    $<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,HTP_DEBUG=1,NDEBUG=1>
    $<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,FARF_HIGH=1,>
    FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE})

build_idl(htp_iface.idl ${HTP_LIB})
|
|||
|
|
@ -2,27 +2,20 @@
|
|||
#pragma clang diagnostic ignored "-Wunused-function"
|
||||
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
|
||||
|
||||
#ifdef HTP_DEBUG
|
||||
# define FARF_HIGH 1
|
||||
#endif
|
||||
#include <HAP_farf.h>
|
||||
#include <HAP_mem.h>
|
||||
#include <HAP_perf.h>
|
||||
#include <HAP_ps.h>
|
||||
#include <hexagon_protos.h>
|
||||
#include <hexagon_types.h>
|
||||
|
||||
#include <math.h>
|
||||
#include <qurt_thread.h>
|
||||
#include <string.h>
|
||||
|
||||
#include "hex-dma.h"
|
||||
#include "hvx-utils.h"
|
||||
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-dma.h"
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
#include "hvx-utils.h"
|
||||
#include "ops-utils.h"
|
||||
|
||||
#define htp_act_preamble3 \
|
||||
const uint32_t ne00 = src0->ne[0]; \
|
||||
|
|
@ -76,7 +69,7 @@
|
|||
const uint32_t nb2 = dst->nb[2]; \
|
||||
const uint32_t nb3 = dst->nb[3];
|
||||
|
||||
static void glu_swiglu_fp32_per_thread(const struct htp_tensor * src0,
|
||||
static void glu_swiglu_f32_per_thread(const struct htp_tensor * src0,
|
||||
const struct htp_tensor * src1,
|
||||
struct htp_tensor * dst,
|
||||
const int32_t * op_params,
|
||||
|
|
@ -124,9 +117,9 @@ static void glu_swiglu_fp32_per_thread(const struct htp_tensor * src0,
|
|||
data_src1 += swapped ? 0 : nc_in_bytes;
|
||||
}
|
||||
|
||||
const size_t src0_row_size_aligned = htp_round_up(src0_row_size, VLEN);
|
||||
const size_t src1_row_size_aligned = htp_round_up(src1_row_size, VLEN);
|
||||
const size_t dst_row_size_aligned = htp_round_up(dst_row_size, VLEN);
|
||||
const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
|
||||
const size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN);
|
||||
const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
|
||||
|
||||
uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread);
|
||||
uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_spad->size_per_thread);
|
||||
|
|
@ -175,9 +168,9 @@ static void glu_swiglu_fp32_per_thread(const struct htp_tensor * src0,
|
|||
float * dst_spad_ptr = dst_spad + ib * (dst_row_size_aligned / sizeof(float));
|
||||
|
||||
//swiglu(x) = x1 * sigmoid(x0)
|
||||
hvx_fast_sigmoid_f32((const uint8_t *) src0_spad_ptr, (uint8_t *) dst_spad_ptr, nc);
|
||||
hvx_mul_mul_f32_opt((const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr,
|
||||
(const uint8_t *) src1_spad_ptr, (uint8_t *) dst_spad_ptr, nc);
|
||||
hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, nc);
|
||||
hvx_mul_mul_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr,
|
||||
(const uint8_t *) src1_spad_ptr, nc);
|
||||
}
|
||||
|
||||
dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad), dst_row_size,
|
||||
|
|
@ -203,7 +196,7 @@ static void glu_swiglu_fp32_per_thread(const struct htp_tensor * src0,
|
|||
(unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||
}
|
||||
|
||||
static void glu_swiglu_oai_fp32_per_thread(const struct htp_tensor * src0,
|
||||
static void glu_swiglu_oai_f32_per_thread(const struct htp_tensor * src0,
|
||||
const struct htp_tensor * src1,
|
||||
struct htp_tensor * dst,
|
||||
const int32_t * op_params,
|
||||
|
|
@ -249,9 +242,9 @@ static void glu_swiglu_oai_fp32_per_thread(const struct htp_tensor * src0,
|
|||
data_src1 += swapped ? 0 : nc_in_bytes;
|
||||
}
|
||||
|
||||
const size_t src0_row_size_aligned = htp_round_up(src0_row_size, VLEN);
|
||||
const size_t src1_row_size_aligned = htp_round_up(src1_row_size, VLEN);
|
||||
const size_t dst_row_size_aligned = htp_round_up(dst_row_size, VLEN);
|
||||
const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
|
||||
const size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN);
|
||||
const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
|
||||
|
||||
uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread);
|
||||
uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_spad->size_per_thread);
|
||||
|
|
@ -304,18 +297,18 @@ static void glu_swiglu_oai_fp32_per_thread(const struct htp_tensor * src0,
|
|||
float * dst_spad_ptr = dst_spad + ib * (dst_row_size_aligned / sizeof(float));
|
||||
|
||||
// x (src0_spad_data) = std::min(src0_p[k], limit);
|
||||
hvx_min_scalar_f32((const uint8_t *) src0_spad_ptr, limit, (uint8_t *) src0_spad_ptr, nc);
|
||||
hvx_min_scalar_f32((uint8_t *) src0_spad_ptr, (const uint8_t *) src0_spad_ptr, limit, nc);
|
||||
// y1 (src1_spad_data) = std::clamp(src1_p[k], -limit, limit);
|
||||
hvx_clamp_scalar_f32((const uint8_t *) src1_spad_ptr, -limit, limit, (uint8_t *) src1_spad_ptr, nc);
|
||||
hvx_clamp_scalar_f32((uint8_t *) src1_spad_ptr, (const uint8_t *) src1_spad_ptr, -limit, limit, nc);
|
||||
// y (src1_spad_data) = y1 + 1.f
|
||||
hvx_add_scalar_f32((const uint8_t *) src1_spad_ptr, 1.0, (uint8_t *) src1_spad_ptr, nc);
|
||||
hvx_add_scalar_f32((uint8_t *) src1_spad_ptr, (const uint8_t *) src1_spad_ptr, 1.0, nc);
|
||||
// x1 (dst_spad_data) = alpha * (x)
|
||||
hvx_mul_scalar_f32((const uint8_t *) src0_spad_ptr, alpha, (uint8_t *) dst_spad_ptr, nc);
|
||||
hvx_mul_scalar_f32((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, alpha, nc);
|
||||
// x2 (dst_spad_data) = sigmoid(x1) = 1/(1+exp(-x1))
|
||||
hvx_fast_sigmoid_f32((const uint8_t *) dst_spad_ptr, (uint8_t *) dst_spad_ptr, nc);
|
||||
hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) dst_spad_ptr, nc);
|
||||
// out = x * sigmoid(alpha * x) * (y + 1.f)
|
||||
hvx_mul_mul_f32_opt((const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr,
|
||||
(const uint8_t *) src1_spad_ptr, (uint8_t *) dst_spad_ptr, nc);
|
||||
hvx_mul_mul_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr,
|
||||
(const uint8_t *) src1_spad_ptr, nc);
|
||||
}
|
||||
|
||||
dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad), dst_row_size,
|
||||
|
|
@ -342,7 +335,7 @@ static void glu_swiglu_oai_fp32_per_thread(const struct htp_tensor * src0,
|
|||
}
|
||||
|
||||
|
||||
static void unary_gelu_fp32_per_thread(const struct htp_tensor * src0,
|
||||
static void unary_gelu_f32_per_thread(const struct htp_tensor * src0,
|
||||
struct htp_tensor * dst,
|
||||
const int32_t * op_params,
|
||||
struct htp_spad * src0_spad,
|
||||
|
|
@ -358,8 +351,8 @@ static void unary_gelu_fp32_per_thread(const struct htp_tensor * src0,
|
|||
|
||||
const size_t src0_row_size = nb01;
|
||||
const size_t dst_row_size = nb1;
|
||||
const size_t src0_row_size_aligned = htp_round_up(src0_row_size, VLEN);
|
||||
const size_t dst_row_size_aligned = htp_round_up(dst_row_size, VLEN);
|
||||
const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
|
||||
const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
|
||||
|
||||
const uint32_t src0_nrows = ne01 * ne02 * ne03;
|
||||
|
||||
|
|
@ -415,9 +408,9 @@ static void unary_gelu_fp32_per_thread(const struct htp_tensor * src0,
|
|||
float* dst_spad_ptr = dst_spad + ib * (dst_row_size_aligned / sizeof(float));
|
||||
|
||||
// gelu = x * sigmoid(1.702 * x) // current implementation
|
||||
hvx_mul_scalar_f32((const uint8_t *) src0_spad_ptr, (float) 1.702, (uint8_t *) dst_spad_ptr, ne0);
|
||||
hvx_fast_sigmoid_f32((const uint8_t *) dst_spad_ptr, (uint8_t *) dst_spad_ptr, ne0);
|
||||
hvx_mul_f32_opt((const uint8_t *) src0_spad_ptr, (uint8_t *) dst_spad_ptr, (uint8_t *) dst_spad_ptr, ne0);
|
||||
hvx_mul_scalar_f32((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (float) 1.702, ne0);
|
||||
hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0);
|
||||
hvx_mul_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0);
|
||||
}
|
||||
|
||||
dma_queue_push_vtcm_to_ddr(dma_queue,
|
||||
|
|
@ -442,15 +435,15 @@ static void unary_gelu_fp32_per_thread(const struct htp_tensor * src0,
|
|||
ne03, src0_start_row, src0_end_row, ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||
}
|
||||
|
||||
static void unary_gelu_fp32(unsigned int n, unsigned int i, void * data) {
|
||||
static void unary_gelu_f32(unsigned int n, unsigned int i, void * data) {
|
||||
struct htp_ops_context * octx = (struct htp_ops_context *) data;
|
||||
unary_gelu_fp32_per_thread(&octx->src0, &octx->dst, octx->op_params, &octx->src0_spad, &octx->dst_spad, n, i,
|
||||
unary_gelu_f32_per_thread(&octx->src0, &octx->dst, octx->op_params, &octx->src0_spad, &octx->dst_spad, n, i,
|
||||
octx->src0_nrows_per_thread, octx->ctx->dma[i]);
|
||||
}
|
||||
|
||||
|
||||
|
||||
static void unary_silu_fp32_per_thread(const struct htp_tensor * src0,
|
||||
static void unary_silu_f32_per_thread(const struct htp_tensor * src0,
|
||||
struct htp_tensor * dst,
|
||||
const int32_t * op_params,
|
||||
struct htp_spad * src0_spad,
|
||||
|
|
@ -466,8 +459,8 @@ static void unary_silu_fp32_per_thread(const struct htp_tensor * src0,
|
|||
|
||||
const size_t src0_row_size = nb01;
|
||||
const size_t dst_row_size = nb1;
|
||||
const size_t src0_row_size_aligned = htp_round_up(src0_row_size, VLEN);
|
||||
const size_t dst_row_size_aligned = htp_round_up(dst_row_size, VLEN);
|
||||
const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
|
||||
const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
|
||||
|
||||
const uint32_t src0_nrows = ne01 * ne02 * ne03;
|
||||
|
||||
|
|
@ -522,8 +515,8 @@ static void unary_silu_fp32_per_thread(const struct htp_tensor * src0,
|
|||
float* dst_spad_ptr = dst_spad + ib * (dst_row_size_aligned / sizeof(float));
|
||||
|
||||
// silu = x * sigmoid(x)
|
||||
hvx_fast_sigmoid_f32((const uint8_t *) src0_spad_ptr, (uint8_t *) dst_spad_ptr, ne0);
|
||||
hvx_mul_f32_opt((const uint8_t *) src0_spad_ptr, (uint8_t *) dst_spad_ptr, (uint8_t *) dst_spad_ptr, ne0);
|
||||
hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, ne0);
|
||||
hvx_mul_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0);
|
||||
}
|
||||
|
||||
dma_queue_push_vtcm_to_ddr(dma_queue,
|
||||
|
|
@ -548,25 +541,25 @@ static void unary_silu_fp32_per_thread(const struct htp_tensor * src0,
|
|||
ne03, src0_start_row, src0_end_row, ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||
}
|
||||
|
||||
static void unary_silu_fp32(unsigned int n, unsigned int i, void * data) {
|
||||
static void unary_silu_f32(unsigned int n, unsigned int i, void * data) {
|
||||
struct htp_ops_context * octx = (struct htp_ops_context *) data;
|
||||
unary_silu_fp32_per_thread(&octx->src0, &octx->dst, octx->op_params, &octx->src0_spad, &octx->dst_spad, n, i,
|
||||
unary_silu_f32_per_thread(&octx->src0, &octx->dst, octx->op_params, &octx->src0_spad, &octx->dst_spad, n, i,
|
||||
octx->src0_nrows_per_thread, octx->ctx->dma[i]);
|
||||
}
|
||||
|
||||
static void glu_swiglu_fp32(unsigned int n, unsigned int i, void * data) {
|
||||
static void glu_swiglu_f32(unsigned int n, unsigned int i, void * data) {
|
||||
struct htp_ops_context * octx = (struct htp_ops_context *) data;
|
||||
glu_swiglu_fp32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad,
|
||||
glu_swiglu_f32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad,
|
||||
&octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
|
||||
}
|
||||
|
||||
static void glu_swiglu_oai_fp32(unsigned int n, unsigned int i, void * data) {
|
||||
static void glu_swiglu_oai_f32(unsigned int n, unsigned int i, void * data) {
|
||||
struct htp_ops_context * octx = (struct htp_ops_context *) data;
|
||||
glu_swiglu_oai_fp32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad,
|
||||
glu_swiglu_oai_f32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad,
|
||||
&octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
|
||||
}
|
||||
|
||||
static int execute_op_activations_fp32(struct htp_ops_context * octx) {
|
||||
static int execute_op_activations_f32(struct htp_ops_context * octx) {
|
||||
int err = HTP_STATUS_OK;
|
||||
|
||||
const struct htp_tensor * src0 = &octx->src0;
|
||||
|
|
@ -583,21 +576,21 @@ static int execute_op_activations_fp32(struct htp_ops_context * octx) {
|
|||
|
||||
switch (octx->op) {
|
||||
case HTP_OP_UNARY_SILU:
|
||||
act_op_func = unary_silu_fp32;
|
||||
act_op_func = unary_silu_f32;
|
||||
op_type = "silu-f32";
|
||||
break;
|
||||
|
||||
case HTP_OP_GLU_SWIGLU:
|
||||
act_op_func = glu_swiglu_fp32;
|
||||
act_op_func = glu_swiglu_f32;
|
||||
op_type = "swiglu-f32";
|
||||
break;
|
||||
|
||||
case HTP_OP_GLU_SWIGLU_OAI:
|
||||
act_op_func = glu_swiglu_oai_fp32;
|
||||
act_op_func = glu_swiglu_oai_f32;
|
||||
op_type = "swiglu-oai-f32";
|
||||
break;
|
||||
case HTP_OP_UNARY_GELU:
|
||||
act_op_func = unary_gelu_fp32;
|
||||
act_op_func = unary_gelu_f32;
|
||||
op_type = "gelu-f32";
|
||||
break;
|
||||
default:
|
||||
|
|
@ -617,9 +610,9 @@ static int execute_op_activations_fp32(struct htp_ops_context * octx) {
|
|||
src1_row_size = src0_row_size;
|
||||
}
|
||||
|
||||
const size_t src0_row_size_aligned = htp_round_up(src0_row_size, VLEN);
|
||||
const size_t src1_row_size_aligned = htp_round_up(src1_row_size, VLEN);
|
||||
const size_t dst_row_size_aligned = htp_round_up(dst_row_size, VLEN);
|
||||
const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
|
||||
const size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN);
|
||||
const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
|
||||
// VTCM scratchpads for all tensors
|
||||
// N rows per thread, padded to HVX vector size
|
||||
|
||||
|
|
@ -670,7 +663,7 @@ int op_activations(struct htp_ops_context * octx) {
|
|||
|
||||
switch (octx->src0.type) {
|
||||
case HTP_TYPE_F32:
|
||||
err = execute_op_activations_fp32(octx);
|
||||
err = execute_op_activations_f32(octx);
|
||||
break;
|
||||
|
||||
default:
|
||||
|
|
|
|||
|
|
@ -2,36 +2,25 @@
|
|||
#pragma clang diagnostic ignored "-Wunused-function"
|
||||
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
|
||||
|
||||
#ifdef HTP_DEBUG
|
||||
# define FARF_HIGH 1
|
||||
#endif
|
||||
|
||||
#include <HAP_farf.h>
|
||||
#include <HAP_mem.h>
|
||||
#include <HAP_perf.h>
|
||||
#include <HAP_ps.h>
|
||||
#include <hexagon_protos.h>
|
||||
#include <hexagon_types.h>
|
||||
|
||||
#include <math.h>
|
||||
#include <qurt_thread.h>
|
||||
#include <string.h>
|
||||
|
||||
#include "hex-dma.h"
|
||||
#include "hvx-utils.h"
|
||||
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-dma.h"
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
#include "hvx-utils.h"
|
||||
#include "ops-utils.h"
|
||||
|
||||
typedef void (*hvx_elemwise_f32_func)(const uint8_t * src0,
|
||||
const uint8_t * src1,
|
||||
uint8_t * data_dst,
|
||||
const int num_elems);
|
||||
typedef void (*hvx_elemwise_f32_func)(uint8_t * data_dst, const uint8_t * src0, const uint8_t * src1, const uint32_t num_elems);
|
||||
|
||||
static hvx_elemwise_f32_func func_table_HVX[] = { hvx_mul_f32, hvx_add_f32, hvx_sub_f32 };
|
||||
static hvx_elemwise_f32_func func_table_HVX_opt[] = { hvx_mul_f32_opt, hvx_add_f32_opt, hvx_sub_f32_opt };
|
||||
static hvx_elemwise_f32_func func_table_HVX_opt[] = { hvx_mul_f32_aa, hvx_add_f32_aa, hvx_sub_f32_aa };
|
||||
|
||||
#define htp_binary_preamble \
|
||||
const struct htp_tensor * src0 = &octx->src0; \
|
||||
|
|
@ -98,9 +87,8 @@ static void binary_job_f32_per_thread(struct htp_ops_context * octx,
|
|||
|
||||
int is_aligned = 1;
|
||||
int opt_path = 0;
|
||||
if ((0 == htp_is_aligned((void *) src0->data, VLEN)) || (0 == htp_is_aligned((void *) src1->data, VLEN)) ||
|
||||
(0 == htp_is_aligned((void *) dst->data, VLEN))) {
|
||||
FARF(HIGH, "binary-f32: unaligned addresses in elementwise op, possibly slower execution\n");
|
||||
if ((0 == hex_is_aligned((void *) src0->data, VLEN)) || (0 == hex_is_aligned((void *) src1->data, VLEN)) ||
|
||||
(0 == hex_is_aligned((void *) dst->data, VLEN))) {
|
||||
is_aligned = 0;
|
||||
}
|
||||
if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) {
|
||||
|
|
@ -130,24 +118,24 @@ static void binary_job_f32_per_thread(struct htp_ops_context * octx,
|
|||
const uint8_t * restrict src1_ptr = data_src1 + i13 * nb13 + i12 * nb12 + i11 * src1_row_size;
|
||||
|
||||
if (ir + 1 < src0_end_row) {
|
||||
htp_l2fetch(src0_ptr + ne00, 1, src0_row_size, src0_row_size);
|
||||
hex_l2fetch(src0_ptr + ne00, src0_row_size, src0_row_size, 1);
|
||||
if (src1_row_size == src0_row_size) {
|
||||
htp_l2fetch(src1_ptr, 1, src1_row_size, src1_row_size);
|
||||
hex_l2fetch(src1_ptr, src1_row_size, src1_row_size, 1);
|
||||
}
|
||||
}
|
||||
|
||||
const uint32_t nr0 = ne00 / ne10;
|
||||
if (nr0 > 1) {
|
||||
if ((1 == is_aligned) && (nr0 == ne00)) {
|
||||
hvx_bcast_fp32_a(spad_data_th, *(float *) src1_ptr, nr0);
|
||||
hvx_splat_f32_a(spad_data_th, *(float *) src1_ptr, nr0);
|
||||
} else {
|
||||
for (uint32_t r = 0; r < nr0; r++) {
|
||||
memcpy(spad_data_th + r * nb11, (const uint8_t *) src1_ptr, nb11);
|
||||
}
|
||||
}
|
||||
func_HVX((const uint8_t *) src0_ptr, (const uint8_t *) spad_data_th, (uint8_t *) dst_ptr, ne00);
|
||||
func_HVX((uint8_t *) dst_ptr, (const uint8_t *) src0_ptr, (const uint8_t *) spad_data_th, ne00);
|
||||
} else {
|
||||
func_HVX((const uint8_t *) src0_ptr, (const uint8_t *) src1_ptr, (uint8_t *) dst_ptr, ne00);
|
||||
func_HVX((uint8_t *) dst_ptr, (const uint8_t *) src0_ptr, (const uint8_t *) src1_ptr, ne00);
|
||||
}
|
||||
|
||||
src0_ptr += src0_row_size;
|
||||
|
|
@ -185,11 +173,6 @@ static void binary_add_id_job_f32_per_thread(struct htp_ops_context * octx,
|
|||
uint64_t t1, t2;
|
||||
t1 = HAP_perf_get_qtimer_count();
|
||||
|
||||
if ((0 == htp_is_aligned((void *) src0->data, VLEN)) || (0 == htp_is_aligned((void *) src1->data, VLEN)) ||
|
||||
(0 == htp_is_aligned((void *) dst->data, VLEN))) {
|
||||
FARF(HIGH, "add-id-f32: unaligned addresses, possibly slower execution\n");
|
||||
}
|
||||
|
||||
const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
|
||||
const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
|
||||
uint8_t * restrict data_dst = (uint8_t *) dst->data;
|
||||
|
|
@ -210,9 +193,9 @@ static void binary_add_id_job_f32_per_thread(struct htp_ops_context * octx,
|
|||
const float * restrict src1_ptr = (const float *) (data_src1 + 0 + 0 + i11 * nb11);
|
||||
|
||||
if (ir + 1 < src0_end_row) {
|
||||
htp_l2fetch(src0_ptr + ne00, 1, src0_row_size, src0_row_size);
|
||||
hex_l2fetch(src0_ptr + ne00, src0_row_size, src0_row_size, 1);
|
||||
if (src1_row_size == src0_row_size) {
|
||||
htp_l2fetch(src1_ptr + ne10, 1, src1_row_size, src1_row_size);
|
||||
hex_l2fetch(src1_ptr + ne10, src1_row_size, src1_row_size, 1);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -221,9 +204,9 @@ static void binary_add_id_job_f32_per_thread(struct htp_ops_context * octx,
|
|||
for (uint32_t r = 0; r < nr0; r++) {
|
||||
memcpy(spad_data + r * nb10, (const uint8_t *) src1_ptr, nb10);
|
||||
}
|
||||
func_HVX((const uint8_t *) src0_ptr, (const uint8_t *) spad_data, (uint8_t *) dst_ptr, ne00);
|
||||
func_HVX((uint8_t *) dst_ptr, (const uint8_t *) src0_ptr, (const uint8_t *) spad_data, ne00);
|
||||
} else {
|
||||
func_HVX((const uint8_t *) src0_ptr, (const uint8_t *) src1_ptr, (uint8_t *) dst_ptr, ne00);
|
||||
func_HVX((uint8_t *) dst_ptr, (const uint8_t *) src0_ptr, (const uint8_t *) src1_ptr, ne00);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -299,9 +282,9 @@ static int execute_op_binary_f32(struct htp_ops_context * octx) {
|
|||
const size_t dst_row_size = dst->nb[1];
|
||||
|
||||
// VTCM scratchpads for all tensors
|
||||
octx->dst_spad.size = htp_round_up(dst_row_size, 128) * n_threads;
|
||||
octx->src0_spad.size = htp_round_up(src0_row_size, 128) * n_threads;
|
||||
octx->src1_spad.size = htp_round_up(src1_row_size, 128) * n_threads;
|
||||
octx->dst_spad.size = hex_round_up(dst_row_size, 128) * n_threads;
|
||||
octx->src0_spad.size = hex_round_up(src0_row_size, 128) * n_threads;
|
||||
octx->src1_spad.size = hex_round_up(src1_row_size, 128) * n_threads;
|
||||
|
||||
size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size;
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,251 @@
|
|||
#pragma clang diagnostic ignored "-Wunused-variable"
|
||||
#pragma clang diagnostic ignored "-Wunused-function"
|
||||
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
|
||||
|
||||
#include <HAP_farf.h>
|
||||
#include <HAP_perf.h>
|
||||
|
||||
#include <math.h>
|
||||
#include <string.h>
|
||||
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
#include "hvx-utils.h"
|
||||
|
||||
struct htp_copy_context {
|
||||
struct htp_ops_context * octx;
|
||||
|
||||
uint32_t src0_type_size;
|
||||
uint32_t src0_block_size;
|
||||
|
||||
uint32_t dst_type_size;
|
||||
uint32_t dst_block_size;
|
||||
|
||||
uint32_t src0_blocks_per_row;
|
||||
uint32_t dst_blocks_per_row;
|
||||
|
||||
uint32_t src0_nrows_per_thread;
|
||||
|
||||
void (*copy)(struct htp_copy_context * ct, struct htp_ops_context * octx, int nth, int ith);
|
||||
};
|
||||
|
||||
#define cpy_preamble \
|
||||
struct htp_tensor *src0 = &octx->src0; \
|
||||
struct htp_tensor *dst = &octx->dst; \
|
||||
\
|
||||
const uint32_t ne00 = src0->ne[0]; \
|
||||
const uint32_t ne01 = src0->ne[1]; \
|
||||
const uint32_t ne02 = src0->ne[2]; \
|
||||
const uint32_t ne03 = src0->ne[3]; \
|
||||
\
|
||||
const uint32_t nb00 = src0->nb[0]; \
|
||||
const uint32_t nb01 = src0->nb[1]; \
|
||||
const uint32_t nb02 = src0->nb[2]; \
|
||||
const uint32_t nb03 = src0->nb[3]; \
|
||||
\
|
||||
const uint32_t ne0 = dst->ne[0]; \
|
||||
const uint32_t ne1 = dst->ne[1]; \
|
||||
const uint32_t ne2 = dst->ne[2]; \
|
||||
const uint32_t ne3 = dst->ne[3]; \
|
||||
\
|
||||
const uint32_t nb0 = dst->nb[0]; \
|
||||
const uint32_t nb1 = dst->nb[1]; \
|
||||
const uint32_t nb2 = dst->nb[2]; \
|
||||
const uint32_t nb3 = dst->nb[3]; \
|
||||
\
|
||||
const uint32_t nr = ne01;
|
||||
|
||||
static void cpy_thread_sametype_sameshape(struct htp_copy_context * ct, struct htp_ops_context * octx, const int nth, const int ith) {
|
||||
cpy_preamble;
|
||||
|
||||
// parallelize by src0 rows
|
||||
const uint32_t dr = ct->src0_nrows_per_thread;
|
||||
const uint32_t ir0 = dr * ith;
|
||||
const uint32_t ir1 = (ir0 + dr) < nr ? (ir0 + dr) : nr;
|
||||
|
||||
// copy by rows
|
||||
for (uint32_t i03 = 0; i03 < ne03; i03++) {
|
||||
for (uint32_t i02 = 0; i02 < ne02; i02++) {
|
||||
#pragma unroll(2)
|
||||
for (uint32_t i01 = ir0; i01 < ir1; i01++) {
|
||||
uint8_t* dst_ptr = (uint8_t*) dst->data + i01*nb1 + i02*nb2 + i03*nb3;
|
||||
uint8_t* src0_ptr = (uint8_t*) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
|
||||
hex_l2fetch(src0_ptr, ne00 * ct->src0_type_size, nb01, 2);
|
||||
hvx_copy_uu(dst_ptr, src0_ptr, ne00, ct->src0_type_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void cpy_thread_sametype_reshape(struct htp_copy_context * ct, struct htp_ops_context * octx, int nth, int ith) {
|
||||
cpy_preamble;
|
||||
|
||||
// parallelize by src0 rows
|
||||
const uint32_t dr = ct->src0_nrows_per_thread;
|
||||
const uint32_t ir0 = dr * ith;
|
||||
const uint32_t ir1 = (ir0 + dr) < nr ? (ir0 + dr) : nr;
|
||||
|
||||
// dst counters
|
||||
int64_t k10 = 0;
|
||||
int64_t i11 = 0;
|
||||
int64_t i12 = 0;
|
||||
int64_t i13 = 0;
|
||||
|
||||
// number of blocks in a row
|
||||
const int64_t nk00 = ct->src0_blocks_per_row;
|
||||
const int64_t nk0 = ct->dst_blocks_per_row;
|
||||
|
||||
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
||||
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
||||
k10 += nk00 * ir0;
|
||||
while (k10 >= nk0) {
|
||||
k10 -= nk0;
|
||||
if (++i11 == ne1) {
|
||||
i11 = 0;
|
||||
if (++i12 == ne2) {
|
||||
i12 = 0;
|
||||
if (++i13 == ne3) {
|
||||
i13 = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int64_t i01 = ir0; i01 < ir1; i01++) {
|
||||
for (int64_t k00 = 0; k00 < nk00; k00++) {
|
||||
const char * src0_ptr = ((char *) src0->data + k00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
char * dst_ptr = ((char *) dst->data + k10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
|
||||
memcpy(dst_ptr, src0_ptr, ct->dst_type_size);
|
||||
|
||||
if (++k10 == nk0) {
|
||||
k10 = 0;
|
||||
if (++i11 == ne1) {
|
||||
i11 = 0;
|
||||
if (++i12 == ne2) {
|
||||
i12 = 0;
|
||||
if (++i13 == ne3) {
|
||||
i13 = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
k10 += nk00 * (ne01 - ir1);
|
||||
while (k10 >= nk0) {
|
||||
k10 -= nk0;
|
||||
if (++i11 == ne1) {
|
||||
i11 = 0;
|
||||
if (++i12 == ne2) {
|
||||
i12 = 0;
|
||||
if (++i13 == ne3) {
|
||||
i13 = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void cpy_thread_f16_f32_sameshape(struct htp_copy_context * ct, struct htp_ops_context * octx, const int nth, const int ith) {
|
||||
cpy_preamble;
|
||||
|
||||
// parallelize by src0 rows
|
||||
const uint32_t dr = ct->src0_nrows_per_thread;
|
||||
const uint32_t ir0 = dr * ith;
|
||||
const uint32_t ir1 = (ir0 + dr) < nr ? (ir0 + dr) : nr;
|
||||
|
||||
// copy by rows
|
||||
for (uint32_t i03 = 0; i03 < ne03; i03++) {
|
||||
for (uint32_t i02 = 0; i02 < ne02; i02++) {
|
||||
#pragma unroll(2)
|
||||
for (uint32_t i01 = ir0; i01 < ir1; i01++) {
|
||||
uint8_t* dst_ptr = (uint8_t*) dst->data + i01*nb1 + i02*nb2 + i03*nb3;
|
||||
uint8_t* src0_ptr = (uint8_t*) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
|
||||
hex_l2fetch(src0_ptr, ne00 * sizeof(float), nb01, 2);
|
||||
hvx_copy_f16_f32_uu(dst_ptr, src0_ptr, ne00);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void cpy_thread_f32_f16_sameshape(struct htp_copy_context * ct, struct htp_ops_context * octx, const int nth, const int ith) {
|
||||
cpy_preamble;
|
||||
|
||||
// parallelize by src0 rows
|
||||
const uint32_t dr = ct->src0_nrows_per_thread;
|
||||
const uint32_t ir0 = dr * ith;
|
||||
const uint32_t ir1 = (ir0 + dr) < nr ? (ir0 + dr) : nr;
|
||||
|
||||
// copy by rows
|
||||
for (uint32_t i03 = 0; i03 < ne03; i03++) {
|
||||
for (uint32_t i02 = 0; i02 < ne02; i02++) {
|
||||
#pragma unroll(2)
|
||||
for (uint32_t i01 = ir0; i01 < ir1; i01++) {
|
||||
uint8_t* dst_ptr = (uint8_t*) dst->data + i01*nb1 + i02*nb2 + i03*nb3;
|
||||
uint8_t* src0_ptr = (uint8_t*) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
|
||||
hex_l2fetch(src0_ptr, ne00 * sizeof(__fp16), nb01, 2);
|
||||
hvx_copy_f32_f16_uu(dst_ptr, src0_ptr, ne00);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void cpy_work_func(unsigned int n, unsigned int i, void *data) {
|
||||
struct htp_copy_context *ct = (struct htp_copy_context *) data;
|
||||
ct->copy(ct, ct->octx, n, i);
|
||||
}
|
||||
|
||||
int op_cpy(struct htp_ops_context * octx) {
|
||||
cpy_preamble;
|
||||
|
||||
struct htp_copy_context ct;
|
||||
ct.octx = octx;
|
||||
|
||||
switch (src0->type) {
|
||||
case HTP_TYPE_F32: ct.src0_type_size = 4; ct.src0_block_size = 1; ct.src0_blocks_per_row = ne00 / 1; break;
|
||||
case HTP_TYPE_F16: ct.src0_type_size = 2; ct.src0_block_size = 1; ct.src0_blocks_per_row = ne00 / 1; break;
|
||||
default:
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
switch (dst->type) {
|
||||
case HTP_TYPE_F32: ct.dst_type_size = 4; ct.dst_block_size = 1; ct.dst_blocks_per_row = ne0 / 1; break;
|
||||
case HTP_TYPE_F16: ct.dst_type_size = 2; ct.dst_block_size = 1; ct.dst_blocks_per_row = ne0 / 1; break;
|
||||
default:
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) {
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
||||
const bool sametype = (src0->type == dst->type);
|
||||
const bool transposed = (nb00 > nb01) || (nb0 > nb1);
|
||||
const bool sameshape = !transposed && (ne00 == ne0 && ne01 == ne1 && ne02 == ne2 && ne03 == ne3);
|
||||
|
||||
const uint32_t n_jobs = MIN(nr, octx->n_threads);
|
||||
ct.src0_nrows_per_thread = (nr + n_jobs - 1) / n_jobs;
|
||||
|
||||
if (sametype && sameshape) {
|
||||
ct.copy = cpy_thread_sametype_sameshape;
|
||||
} else if (sameshape) {
|
||||
/**/ if (dst->type == HTP_TYPE_F16 && src0->type == HTP_TYPE_F32)
|
||||
ct.copy = cpy_thread_f16_f32_sameshape;
|
||||
else if (dst->type == HTP_TYPE_F32 && src0->type == HTP_TYPE_F16)
|
||||
ct.copy = cpy_thread_f32_f16_sameshape;
|
||||
else
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
} else if (sametype) {
|
||||
ct.copy = cpy_thread_sametype_reshape;
|
||||
} else {
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
worker_pool_run_func(octx->ctx->worker_pool, cpy_work_func, &ct, n_jobs);
|
||||
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
|
@ -2,25 +2,20 @@
|
|||
#pragma clang diagnostic ignored "-Wunused-function"
|
||||
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
|
||||
|
||||
#ifdef HTP_DEBUG
|
||||
# define FARF_HIGH 1
|
||||
#endif
|
||||
#include <HAP_farf.h>
|
||||
#include <HAP_mem.h>
|
||||
#include <HAP_perf.h>
|
||||
#include <hexagon_protos.h>
|
||||
#include <hexagon_types.h>
|
||||
|
||||
#include <math.h>
|
||||
#include <string.h>
|
||||
|
||||
#include "hex-dma.h"
|
||||
#include "hvx-utils.h"
|
||||
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-dma.h"
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
#include "hvx-utils.h"
|
||||
#include "ops-utils.h"
|
||||
|
||||
// Dot product of FP32 and FP16 vectors, accumulating to float
|
||||
static inline void hvx_dot_f32_f16_aa(float * restrict r, const void * restrict y, const void * restrict x, unsigned int n, float s) {
|
||||
|
|
@ -70,8 +65,8 @@ static inline void hvx_dot_f32_f16_aa(float * restrict r, const void * restrict
|
|||
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
|
||||
}
|
||||
|
||||
rsum = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(rsum), hvx_vec_splat_fp32(s));
|
||||
rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum));
|
||||
rsum = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(rsum), hvx_vec_splat_f32(s));
|
||||
rsum = Q6_Vsf_equals_Vqf32(hvx_vec_reduce_sum_qf32(rsum));
|
||||
|
||||
hvx_vec_store_u(r, 4, rsum);
|
||||
}
|
||||
|
|
@ -111,8 +106,8 @@ static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict
|
|||
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
|
||||
}
|
||||
|
||||
rsum = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(rsum), hvx_vec_splat_fp32(s));
|
||||
rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum));
|
||||
rsum = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(rsum), hvx_vec_splat_f32(s));
|
||||
rsum = Q6_Vsf_equals_Vqf32(hvx_vec_reduce_sum_qf32(rsum));
|
||||
hvx_vec_store_u(r, 4, rsum);
|
||||
}
|
||||
|
||||
|
|
@ -124,7 +119,7 @@ static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict
|
|||
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
|
||||
uint32_t nloe = n % VLEN_FP16; // leftover elements
|
||||
|
||||
HVX_Vector S = hvx_vec_splat_fp16(s);
|
||||
HVX_Vector S = hvx_vec_splat_f16(s);
|
||||
|
||||
uint32_t i = 0;
|
||||
#pragma unroll(4)
|
||||
|
|
@ -148,7 +143,7 @@ static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict
|
|||
|
||||
if (nloe) {
|
||||
HVX_Vector xy = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i]));
|
||||
hvx_vec_store_u(&ptr_y[i], nloe * 4, xy);
|
||||
hvx_vec_store_a(&ptr_y[i], nloe * 4, xy);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -225,18 +220,18 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
|
|||
const uint32_t DV = nev0;
|
||||
|
||||
const size_t size_q_row = DK * ((q->type == HTP_TYPE_F32) ? 4 : 2);
|
||||
const size_t size_q_row_padded = htp_round_up(size_q_row, 128);
|
||||
const size_t size_q_row_padded = hex_round_up(size_q_row, 128);
|
||||
|
||||
const size_t size_k_row = DK * sizeof(__fp16);
|
||||
const size_t size_v_row = DV * sizeof(__fp16);
|
||||
const size_t size_m_row = FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16); // Treat block as one row for mask
|
||||
|
||||
const size_t size_k_row_padded = htp_round_up(size_k_row, 128);
|
||||
const size_t size_v_row_padded = htp_round_up(size_v_row, 128);
|
||||
const size_t size_k_row_padded = hex_round_up(size_k_row, 128);
|
||||
const size_t size_v_row_padded = hex_round_up(size_v_row, 128);
|
||||
|
||||
const size_t size_k_block = size_k_row_padded * FLASH_ATTN_BLOCK_SIZE;
|
||||
const size_t size_v_block = size_v_row_padded * FLASH_ATTN_BLOCK_SIZE;
|
||||
const size_t size_m_block = htp_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128);
|
||||
const size_t size_m_block = hex_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128);
|
||||
|
||||
// Scratchpad buffers for Q, K, V, Mask, and VKQ32 accumulator
|
||||
uint8_t * spad_q = octx->src0_spad.data + octx->src0_spad.size_per_thread * ith;
|
||||
|
|
@ -272,8 +267,8 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
|
|||
float M = -INFINITY; // maximum KQ value
|
||||
|
||||
// Clear accumulator
|
||||
hvx_splat_f32_a(spad_a, 0, DV);
|
||||
float * VKQ32 = (float *) spad_a;
|
||||
memset(VKQ32, 0, DV * sizeof(float));
|
||||
|
||||
const __fp16 * mp_base = NULL;
|
||||
if (mask) {
|
||||
|
|
@ -340,30 +335,30 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
|
|||
|
||||
// 2. Softcap
|
||||
if (logit_softcap != 0.0f) {
|
||||
scores = hvx_vec_tanh_fp32(scores);
|
||||
scores = Q6_Vqf32_vmpy_VsfVsf(scores, hvx_vec_splat_fp32(logit_softcap));
|
||||
scores = hvx_vec_tanh_f32(scores);
|
||||
scores = Q6_Vqf32_vmpy_VsfVsf(scores, hvx_vec_splat_f32(logit_softcap));
|
||||
scores = Q6_Vsf_equals_Vqf32(scores);
|
||||
}
|
||||
|
||||
// 3. Mask
|
||||
if (mask) {
|
||||
const __fp16 * mp = m_base + ic;
|
||||
HVX_Vector m_vals_fp16 = *(const HVX_UVector *) mp;
|
||||
HVX_Vector m_vals_f16 = *(const HVX_UVector *) mp;
|
||||
|
||||
HVX_Vector one_fp16 = Q6_Vh_vsplat_R(0x3c00);
|
||||
HVX_VectorPair m_vals_fp32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_fp16), one_fp16);
|
||||
HVX_Vector one_f16 = Q6_Vh_vsplat_R(0x3c00);
|
||||
HVX_VectorPair m_vals_f32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_f16), one_f16);
|
||||
|
||||
HVX_Vector m_vals_fp32 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(m_vals_fp32_pair));
|
||||
HVX_Vector m_vals_f32 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(m_vals_f32_pair));
|
||||
|
||||
HVX_Vector slope_vec = hvx_vec_splat_fp32(slope);
|
||||
HVX_Vector add_val = Q6_Vqf32_vmpy_VsfVsf(m_vals_fp32, slope_vec);
|
||||
HVX_Vector slope_vec = hvx_vec_splat_f32(slope);
|
||||
HVX_Vector add_val = Q6_Vqf32_vmpy_VsfVsf(m_vals_f32, slope_vec);
|
||||
scores = Q6_Vqf32_vadd_VsfVsf(scores, Q6_Vsf_equals_Vqf32(add_val));
|
||||
scores = Q6_Vsf_equals_Vqf32(scores);
|
||||
}
|
||||
|
||||
// 4. Online Softmax Update
|
||||
HVX_Vector v_max = hvx_vec_reduce_max_fp32(scores);
|
||||
float m_block = hvx_vec_get_fp32(v_max);
|
||||
HVX_Vector v_max = hvx_vec_reduce_max_f32(scores);
|
||||
float m_block = hvx_vec_get_f32(v_max);
|
||||
|
||||
float M_old = M;
|
||||
float M_new = (m_block > M) ? m_block : M;
|
||||
|
|
@ -374,12 +369,12 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
|
|||
hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);
|
||||
S = S * ms;
|
||||
|
||||
HVX_Vector M_new_vec = hvx_vec_splat_fp32(M_new);
|
||||
HVX_Vector M_new_vec = hvx_vec_splat_f32(M_new);
|
||||
HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_new_vec);
|
||||
HVX_Vector P = hvx_vec_exp_fp32(Q6_Vsf_equals_Vqf32(scores_shifted));
|
||||
HVX_Vector P = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(scores_shifted));
|
||||
|
||||
HVX_Vector p_sum_vec = hvx_vec_fp32_reduce_sum(P);
|
||||
float p_sum = hvx_vec_get_fp32(p_sum_vec);
|
||||
HVX_Vector p_sum_vec = hvx_vec_reduce_sum_f32(P);
|
||||
float p_sum = hvx_vec_get_f32(p_sum_vec);
|
||||
S += p_sum;
|
||||
|
||||
// 5. Accumulate V
|
||||
|
|
@ -484,9 +479,9 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
|
|||
uint8_t * dst_ptr = (uint8_t *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1) * nb1;
|
||||
|
||||
if (dst->type == HTP_TYPE_F32) {
|
||||
hvx_copy_fp32_ua(dst_ptr, (uint8_t *) VKQ32, DV);
|
||||
hvx_copy_f32_ua(dst_ptr, (uint8_t *) VKQ32, DV);
|
||||
} else if (dst->type == HTP_TYPE_F16) {
|
||||
hvx_copy_fp16_fp32_ua(dst_ptr, (uint8_t *) VKQ32, DV);
|
||||
hvx_copy_f16_f32_ua(dst_ptr, (uint8_t *) VKQ32, DV);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -523,16 +518,16 @@ int op_flash_attn_ext(struct htp_ops_context * octx) {
|
|||
octx->src3_div3 = init_fastdiv_values(mask->ne[3]);
|
||||
}
|
||||
|
||||
size_t size_q_row_padded = htp_round_up(q->ne[0] * (q->type == HTP_TYPE_F32 ? 4 : 2), 128);
|
||||
size_t size_k_row_padded = htp_round_up(k->ne[0] * sizeof(__fp16), 128);
|
||||
size_t size_v_row_padded = htp_round_up(v->ne[0] * sizeof(__fp16), 128);
|
||||
size_t size_q_row_padded = hex_round_up(q->ne[0] * (q->type == HTP_TYPE_F32 ? 4 : 2), 128);
|
||||
size_t size_k_row_padded = hex_round_up(k->ne[0] * sizeof(__fp16), 128);
|
||||
size_t size_v_row_padded = hex_round_up(v->ne[0] * sizeof(__fp16), 128);
|
||||
|
||||
size_t size_q_block = size_q_row_padded * 1; // single row for now
|
||||
size_t size_k_block = size_k_row_padded * FLASH_ATTN_BLOCK_SIZE;
|
||||
size_t size_v_block = size_v_row_padded * FLASH_ATTN_BLOCK_SIZE;
|
||||
size_t size_m_block = htp_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128);
|
||||
size_t size_m_block = hex_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128);
|
||||
|
||||
size_t size_vkq_acc = htp_round_up(v->ne[0] * sizeof(float), 128); // VKQ32
|
||||
size_t size_vkq_acc = hex_round_up(v->ne[0] * sizeof(float), 128); // VKQ32
|
||||
|
||||
octx->src0_spad.size_per_thread = size_q_block * 1;
|
||||
octx->src1_spad.size_per_thread = size_k_block * 2;
|
||||
|
|
|
|||
|
|
@ -2,14 +2,9 @@
|
|||
#pragma clang diagnostic ignored "-Wunused-function"
|
||||
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
|
||||
|
||||
#ifdef HTP_DEBUG
|
||||
# define FARF_HIGH 1
|
||||
#endif
|
||||
#include <HAP_farf.h>
|
||||
#include <HAP_mem.h>
|
||||
#include <HAP_perf.h>
|
||||
#include <hexagon_protos.h>
|
||||
#include <hexagon_types.h>
|
||||
|
||||
#include <math.h>
|
||||
#include <string.h>
|
||||
|
||||
|
|
@ -19,7 +14,6 @@
|
|||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
#include "hvx-utils.h"
|
||||
#include "ops-utils.h"
|
||||
|
||||
#define get_rows_preamble \
|
||||
const uint32_t ne00 = octx->src0.ne[0]; \
|
||||
|
|
@ -72,7 +66,7 @@ static int get_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth,
|
|||
|
||||
const uintptr_t src0_ptr = octx->src0.data + i01*nb01 + i11*nb02 + i12*nb03;
|
||||
const uintptr_t dst_ptr = octx->dst.data + i10*nb1 + i11*nb2 + i12*nb3;
|
||||
hvx_copy_fp32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, ne00);
|
||||
hvx_copy_f32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, ne00);
|
||||
}
|
||||
|
||||
return HTP_STATUS_OK;
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
#include "htp-dma.h"
|
||||
#include "hex-dma.h"
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stdlib.h>
|
||||
|
|
@ -2,7 +2,6 @@
|
|||
#define HTP_DMA_H
|
||||
|
||||
#include <HAP_farf.h>
|
||||
#include <hexagon_protos.h>
|
||||
#include <hexagon_types.h>
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
|
|
@ -0,0 +1,77 @@
|
|||
#ifndef HEX_DUMP_H
|
||||
#define HEX_DUMP_H
|
||||
|
||||
#include <HAP_farf.h>
|
||||
|
||||
static inline void hex_dump_int8_line(char * pref, const int8_t * x, int n) {
|
||||
char str[1024], *p = str, *p_end = str + sizeof(str);
|
||||
p += snprintf(p, p_end - p, "%s: ", pref);
|
||||
for (int i = 0; i < n && p < p_end; i++) {
|
||||
p += snprintf(p, p_end - p, "%d, ", x[i]);
|
||||
}
|
||||
FARF(HIGH, "%s\n", str);
|
||||
}
|
||||
|
||||
static inline void hex_dump_uint8_line(char * pref, const uint8_t * x, uint32_t n) {
|
||||
char str[1024], *p = str, *p_end = str + sizeof(str);
|
||||
p += snprintf(p, p_end - p, "%s: ", pref);
|
||||
for (int i = 0; i < n && p < p_end; i++) {
|
||||
p += snprintf(p, p_end - p, "%d, ", x[i]);
|
||||
}
|
||||
FARF(HIGH, "%s\n", str);
|
||||
}
|
||||
|
||||
static inline void hex_dump_int32_line(char * pref, const int32_t * x, uint32_t n) {
|
||||
char str[1024], *p = str, *p_end = str + sizeof(str);
|
||||
p += snprintf(p, p_end - p, "%s: ", pref);
|
||||
for (int i = 0; i < n; i++) {
|
||||
p += snprintf(p, p_end - p, "%d, ", (int) x[i]);
|
||||
}
|
||||
FARF(HIGH, "%s\n", str);
|
||||
}
|
||||
|
||||
static inline void hex_dump_f16_line(char * pref, const __fp16 * x, uint32_t n) {
|
||||
char str[1024], *p = str, *p_end = str + sizeof(str);
|
||||
p += snprintf(p, p_end - p, "%s: ", pref);
|
||||
for (int i = 0; i < n; i++) {
|
||||
p += snprintf(p, p_end - p, "%.6f, ", (float) x[i]);
|
||||
}
|
||||
FARF(HIGH, "%s\n", str);
|
||||
}
|
||||
|
||||
static inline void hex_dump_f32_line(char * pref, const float * x, uint32_t n) {
|
||||
char str[1024], *p = str, *p_end = str + sizeof(str);
|
||||
p += snprintf(p, p_end - p, "%s: ", pref);
|
||||
for (int i = 0; i < n; i++) {
|
||||
p += snprintf(p, p_end - p, "%.6f, ", x[i]);
|
||||
}
|
||||
FARF(HIGH, "%s\n", str);
|
||||
}
|
||||
|
||||
static inline void hex_dump_f32(char * pref, const float * x, uint32_t n) {
|
||||
uint32_t n0 = n / 16;
|
||||
uint32_t n1 = n % 16;
|
||||
|
||||
uint32_t i = 0;
|
||||
for (; i < n0; i++) {
|
||||
hex_dump_f32_line(pref, x + (16 * i), 16);
|
||||
}
|
||||
if (n1) {
|
||||
hex_dump_f32_line(pref, x + (16 * i), n1);
|
||||
}
|
||||
}
|
||||
|
||||
static inline void hex_dump_f16(char * pref, const __fp16 * x, uint32_t n) {
|
||||
uint32_t n0 = n / 16;
|
||||
uint32_t n1 = n % 16;
|
||||
|
||||
uint32_t i = 0;
|
||||
for (; i < n0; i++) {
|
||||
hex_dump_f16_line(pref, x + (16 * i), 16);
|
||||
}
|
||||
if (n1) {
|
||||
hex_dump_f16_line(pref, x + (16 * i), n1);
|
||||
}
|
||||
}
|
||||
|
||||
#endif /* HEX_DUMP_H */
|
||||
|
|
@@ -0,0 +1,37 @@
#ifndef HEX_FASTDIV_H
#define HEX_FASTDIV_H

// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
// Precompute mp (m' in the paper) and L such that division
// can be computed using a multiply (high 32b of 64b result)
// and a shift:
//
//    n/d = (mulhi(n, mp) + n) >> L;
struct fastdiv_values {
    uint32_t mp;
    uint32_t l;
};

static inline struct fastdiv_values init_fastdiv_values(uint32_t d) {
    struct fastdiv_values result = { 0, 0 };
    // compute L = ceil(log2(d));
    while (result.l < 32 && ((uint32_t) 1 << result.l) < d) {
        ++(result.l);
    }

    result.mp = (uint32_t) (((uint64_t) 1 << 32) * (((uint64_t) 1 << result.l) - d) / d + 1);
    return result;
}

static inline uint32_t fastdiv(uint32_t n, const struct fastdiv_values * vals) {
    // Compute high 32 bits of n * mp
    const uint32_t hi = (uint32_t) (((uint64_t) n * vals->mp) >> 32); // mulhi(n, mp)
    // add n, apply bit shift
    return (hi + n) >> vals->l;
}

static inline uint32_t fastmodulo(uint32_t n, uint32_t d, const struct fastdiv_values * vals) {
    return n - fastdiv(n, vals) * d;
}

#endif /* HEX_FASTDIV_H */
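The header above is self-contained apart from `<stdint.h>`, so its behavior is easy to sanity-check off-target. A small host-side check (an illustration, not part of the commit) might look like:

```c
#include <assert.h>
#include <stdint.h>
#include "hex-fastdiv.h"

int main(void) {
    const uint32_t d = 12;                                    // divisor known up front
    const struct fastdiv_values fd = init_fastdiv_values(d);  // precompute mp and L once

    for (uint32_t n = 0; n < 100000; n++) {
        assert(fastdiv(n, &fd) == n / d);       // (mulhi(n, mp) + n) >> L
        assert(fastmodulo(n, d, &fd) == n % d); // n - fastdiv(n) * d
    }
    return 0;
}
```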
@ -0,0 +1,51 @@
#ifndef HEX_UTILS_H
#define HEX_UTILS_H

#include <stdbool.h>
#include <stdint.h>

#include "hexagon_types.h"

#include "hex-fastdiv.h"
#include "hex-dump.h"

#ifndef MAX
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#endif

#ifndef MIN
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#endif

static inline uint64_t hex_get_cycles() {
    uint64_t cycles = 0;
    asm volatile(" %0 = c15:14\n" : "=r"(cycles));
    return cycles;
}

static inline uint64_t hex_get_pktcnt() {
    uint64_t pktcnt;
    asm volatile(" %0 = c19:18\n" : "=r"(pktcnt));
    return pktcnt;
}

static inline int32_t hex_is_aligned(void * addr, uint32_t align) {
    return ((size_t) addr & (align - 1)) == 0;
}

static inline int32_t hex_is_one_chunk(void * addr, uint32_t n, uint32_t chunk_size) {
    uint32_t left_off  = (size_t) addr & (chunk_size - 1);
    uint32_t right_off = left_off + n;
    return right_off <= chunk_size;
}

static inline uint32_t hex_round_up(uint32_t n, uint32_t m) {
    return m * ((n + m - 1) / m);
}

static inline void hex_l2fetch(const void * p, uint32_t width, uint32_t stride, uint32_t height) {
    const uint64_t control = Q6_P_combine_RR(stride, Q6_R_combine_RlRl(width, height));
    Q6_l2fetch_AP((void *) p, control);
}

#endif /* HEX_UTILS_H */
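A small usage sketch (illustrative only, not part of the patch; the prefetch shape chosen here is an assumption): hex_round_up and hex_l2fetch are typically combined to prefetch a byte range as 128-byte lines.

// Hypothetical helper: prefetch nb bytes starting at p into L2 as
// ceil(nb / 128) lines of 128 bytes each.
static inline void prefetch_bytes(const void * p, uint32_t nb) {
    hex_l2fetch(p, 128, 128, hex_round_up(nb, 128) / 128);
}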
@ -1,7 +1,7 @@
#ifndef HTP_CTX_H
#define HTP_CTX_H

#include "htp-dma.h"
#include "hex-dma.h"
#include "worker-pool.h"

#include <assert.h>
@ -63,6 +63,7 @@ enum htp_op {
    HTP_OP_SET_ROWS = 15,
    HTP_OP_SCALE    = 16,
    HTP_OP_GET_ROWS = 17,
    HTP_OP_CPY      = 18,
    INVALID
};
@ -4,11 +4,12 @@
#include "htp-ctx.h"
#include "htp-msg.h"
#include "worker-pool.h"
#include "ops-utils.h"

#include <assert.h>
#include <stdint.h>

#include <hex-fastdiv.h>

// ggml-common.h must be included prior to this header

struct htp_spad {
@ -74,6 +75,14 @@ struct htp_ops_context {
    struct fastdiv_values get_rows_div_ne10;       // fastdiv values for ne10
    struct fastdiv_values get_rows_div_ne10_ne11;  // fastdiv values for ne10 * ne11

    struct fastdiv_values cpy_div_ne01;            // fastdiv values for ne01
    struct fastdiv_values cpy_div_ne02;            // fastdiv values for ne02
    struct fastdiv_values cpy_div_ne03;            // fastdiv values for ne03

    struct fastdiv_values cpy_rshp_div_n0;         // fastdiv values for ne00
    struct fastdiv_values cpy_rshp_div_n1n0;       // fastdiv values for ne00*ne01
    struct fastdiv_values cpy_rshp_div_n2n1n0;     // fastdiv values for ne00*ne01*ne02

    uint32_t flags;
};
@ -88,5 +97,6 @@ int op_rope(struct htp_ops_context * octx);
int op_flash_attn_ext(struct htp_ops_context * octx);
int op_set_rows(struct htp_ops_context * octx);
int op_get_rows(struct htp_ops_context * octx);
int op_cpy(struct htp_ops_context * octx);

#endif /* HTP_OPS_H */
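The fastdiv values added to htp_ops_context exist so the op kernels can unflatten row indices without hardware division. A generic sketch of that pattern (illustrative, not the actual op_cpy code; the parameter names are hypothetical):

// Recover (i1, i2, i3) from a flat row index ir = i1 + i2*ne01 + i3*ne01*ne02
// using precomputed fastdiv values for ne01 and ne01*ne02.
static inline void unflatten_row_index(uint32_t ir, uint32_t ne01, uint32_t ne02,
                                       const struct fastdiv_values * div_ne01,
                                       const struct fastdiv_values * div_ne01_ne02,
                                       uint32_t * i1, uint32_t * i2, uint32_t * i3) {
    *i3 = fastdiv(ir, div_ne01_ne02);                                     // ir / (ne01*ne02)
    *i2 = fastdiv(fastmodulo(ir, ne01 * ne02, div_ne01_ne02), div_ne01);  // (ir % (ne01*ne02)) / ne01
    *i1 = fastmodulo(ir, ne01, div_ne01);                                 // ir % ne01
}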
@ -0,0 +1,457 @@
|
|||
#ifndef HVX_ARITH_H
|
||||
#define HVX_ARITH_H
|
||||
|
||||
#include <assert.h>
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
#include <math.h>
|
||||
|
||||
#include "hvx-base.h"
|
||||
#include "hex-utils.h"
|
||||
|
||||
//
|
||||
// Binary operations (add, mul, sub)
|
||||
//
|
||||
|
||||
#define hvx_arith_loop_body(dst_type, src0_type, src1_type, vec_store, vec_op) \
|
||||
do { \
|
||||
dst_type * restrict vdst = (dst_type *) dst; \
|
||||
src0_type * restrict vsrc0 = (src0_type *) src0; \
|
||||
src1_type * restrict vsrc1 = (src1_type *) src1; \
|
||||
\
|
||||
const uint32_t elem_size = sizeof(float); \
|
||||
const uint32_t epv = 128 / elem_size; \
|
||||
const uint32_t nvec = n / epv; \
|
||||
const uint32_t nloe = n % epv; \
|
||||
\
|
||||
uint32_t i = 0; \
|
||||
\
|
||||
_Pragma("unroll(4)") \
|
||||
for (; i < nvec; i++) { \
|
||||
vdst[i] = vec_op(vsrc0[i], vsrc1[i]); \
|
||||
} \
|
||||
if (nloe) { \
|
||||
HVX_Vector v = vec_op(vsrc0[i], vsrc1[i]); \
|
||||
vec_store((void *) &vdst[i], nloe * elem_size, v); \
|
||||
} \
|
||||
} while(0)
|
||||
|
||||
#if __HVX_ARCH__ < 79
|
||||
#define HVX_OP_ADD(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b))
|
||||
#define HVX_OP_SUB(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(a, b))
|
||||
#define HVX_OP_MUL(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))
|
||||
#else
|
||||
#define HVX_OP_ADD(a, b) Q6_Vsf_vadd_VsfVsf(a, b)
|
||||
#define HVX_OP_SUB(a, b) Q6_Vsf_vsub_VsfVsf(a, b)
|
||||
#define HVX_OP_MUL(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
|
||||
#endif
|
||||
|
||||
// ADD variants
|
||||
|
||||
static inline void hvx_add_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
assert((unsigned long) src0 % 128 == 0);
|
||||
assert((unsigned long) src1 % 128 == 0);
|
||||
hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_ADD);
|
||||
}
|
||||
|
||||
static inline void hvx_add_f32_au(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
assert((unsigned long) src0 % 128 == 0);
|
||||
hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_ADD);
|
||||
}
|
||||
|
||||
static inline void hvx_add_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
assert((unsigned long) src0 % 128 == 0);
|
||||
assert((unsigned long) src1 % 128 == 0);
|
||||
hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u, HVX_OP_ADD);
|
||||
}
|
||||
|
||||
static inline void hvx_add_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_ADD);
|
||||
}
|
||||
|
||||
// SUB variants
|
||||
|
||||
static inline void hvx_sub_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
assert((unsigned long) src0 % 128 == 0);
|
||||
assert((unsigned long) src1 % 128 == 0);
|
||||
hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_SUB);
|
||||
}
|
||||
|
||||
static inline void hvx_sub_f32_au(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
assert((unsigned long) src0 % 128 == 0);
|
||||
hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_SUB);
|
||||
}
|
||||
|
||||
static inline void hvx_sub_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
assert((unsigned long) src0 % 128 == 0);
|
||||
assert((unsigned long) src1 % 128 == 0);
|
||||
hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u, HVX_OP_SUB);
|
||||
}
|
||||
|
||||
static inline void hvx_sub_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_SUB);
|
||||
}
|
||||
|
||||
// MUL variants
|
||||
|
||||
static inline void hvx_mul_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
assert((unsigned long) src0 % 128 == 0);
|
||||
assert((unsigned long) src1 % 128 == 0);
|
||||
hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_MUL);
|
||||
}
|
||||
|
||||
static inline void hvx_mul_f32_au(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
assert((unsigned long) src0 % 128 == 0);
|
||||
hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_MUL);
|
||||
}
|
||||
|
||||
static inline void hvx_mul_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
assert((unsigned long) src0 % 128 == 0);
|
||||
assert((unsigned long) src1 % 128 == 0);
|
||||
hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u, HVX_OP_MUL);
|
||||
}
|
||||
|
||||
static inline void hvx_mul_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_MUL);
|
||||
}
|
||||
|
||||
// Dispatchers
|
||||
|
||||
static inline void hvx_add_f32(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) {
|
||||
if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src0, 128)) {
|
||||
if (hex_is_aligned((void *) src1, 128)) {
|
||||
hvx_add_f32_aa(dst, src0, src1, num_elems);
|
||||
} else {
|
||||
hvx_add_f32_au(dst, src0, src1, num_elems);
|
||||
}
|
||||
} else if (hex_is_aligned((void *) src0, 128) && hex_is_aligned((void *) src1, 128)) {
|
||||
hvx_add_f32_ua(dst, src0, src1, num_elems);
|
||||
} else {
|
||||
hvx_add_f32_uu(dst, src0, src1, num_elems);
|
||||
}
|
||||
}
|
||||
|
||||
static inline void hvx_sub_f32(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) {
|
||||
if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src0, 128)) {
|
||||
if (hex_is_aligned((void *) src1, 128)) {
|
||||
hvx_sub_f32_aa(dst, src0, src1, num_elems);
|
||||
} else {
|
||||
hvx_sub_f32_au(dst, src0, src1, num_elems);
|
||||
}
|
||||
} else if (hex_is_aligned((void *) src0, 128) && hex_is_aligned((void *) src1, 128)) {
|
||||
hvx_sub_f32_ua(dst, src0, src1, num_elems);
|
||||
} else {
|
||||
hvx_sub_f32_uu(dst, src0, src1, num_elems);
|
||||
}
|
||||
}
|
||||
|
||||
static inline void hvx_mul_f32(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) {
|
||||
if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src0, 128)) {
|
||||
if (hex_is_aligned((void *) src1, 128)) {
|
||||
hvx_mul_f32_aa(dst, src0, src1, num_elems);
|
||||
} else {
|
||||
hvx_mul_f32_au(dst, src0, src1, num_elems);
|
||||
}
|
||||
} else if (hex_is_aligned((void *) src0, 128) && hex_is_aligned((void *) src1, 128)) {
|
||||
hvx_mul_f32_ua(dst, src0, src1, num_elems);
|
||||
} else {
|
||||
hvx_mul_f32_uu(dst, src0, src1, num_elems);
|
||||
}
|
||||
}
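All three dispatchers pick the most-aligned specialization the pointers permit and fall back to the fully unaligned variant. A minimal caller-side sketch (illustrative, not part of the patch):

// With 128-byte aligned buffers the dispatcher resolves to hvx_mul_f32_aa.
static void mul_rows_example(void) {
    static float __attribute__((aligned(128))) a[256], b[256], d[256];
    hvx_mul_f32((uint8_t *) d, (const uint8_t *) a, (const uint8_t *) b, 256);
}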
|
||||
|
||||
// Mul-Mul Optimized
|
||||
|
||||
static inline void hvx_mul_mul_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint8_t * restrict src2, const uint32_t num_elems) {
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
assert((unsigned long) src0 % 128 == 0);
|
||||
assert((unsigned long) src1 % 128 == 0);
|
||||
assert((unsigned long) src2 % 128 == 0);
|
||||
|
||||
HVX_Vector * restrict vdst = (HVX_Vector *) dst;
|
||||
HVX_Vector * restrict vsrc0 = (HVX_Vector *) src0;
|
||||
HVX_Vector * restrict vsrc1 = (HVX_Vector *) src1;
|
||||
HVX_Vector * restrict vsrc2 = (HVX_Vector *) src2;
|
||||
|
||||
const uint32_t elem_size = sizeof(float);
|
||||
const uint32_t epv = 128 / elem_size;
|
||||
const uint32_t nvec = num_elems / epv;
|
||||
const uint32_t nloe = num_elems % epv;
|
||||
|
||||
uint32_t i = 0;
|
||||
|
||||
_Pragma("unroll(4)")
|
||||
for (; i < nvec; i++) {
|
||||
HVX_Vector v1 = HVX_OP_MUL(vsrc0[i], vsrc1[i]);
|
||||
vdst[i] = HVX_OP_MUL(v1, vsrc2[i]);
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
HVX_Vector v1 = HVX_OP_MUL(vsrc0[i], vsrc1[i]);
|
||||
HVX_Vector v2 = HVX_OP_MUL(v1, vsrc2[i]);
|
||||
hvx_vec_store_a((void *) &vdst[i], nloe * elem_size, v2);
|
||||
}
|
||||
}
|
||||
|
||||
// Scalar Operations
|
||||
|
||||
#define hvx_scalar_loop_body(dst_type, src_type, vec_store, scalar_op_macro) \
|
||||
do { \
|
||||
dst_type * restrict vdst = (dst_type *) dst; \
|
||||
src_type * restrict vsrc = (src_type *) src; \
|
||||
\
|
||||
const uint32_t elem_size = sizeof(float); \
|
||||
const uint32_t epv = 128 / elem_size; \
|
||||
const uint32_t nvec = n / epv; \
|
||||
const uint32_t nloe = n % epv; \
|
||||
\
|
||||
uint32_t i = 0; \
|
||||
\
|
||||
_Pragma("unroll(4)") \
|
||||
for (; i < nvec; i++) { \
|
||||
HVX_Vector v = vsrc[i]; \
|
||||
vdst[i] = scalar_op_macro(v); \
|
||||
} \
|
||||
if (nloe) { \
|
||||
HVX_Vector v = vsrc[i]; \
|
||||
v = scalar_op_macro(v); \
|
||||
vec_store((void *) &vdst[i], nloe * elem_size, v); \
|
||||
} \
|
||||
} while(0)
|
||||
|
||||
#define HVX_OP_ADD_SCALAR(v) \
|
||||
({ \
|
||||
const HVX_VectorPred pred_inf = Q6_Q_vcmp_eq_VwVw(inf, v); \
|
||||
HVX_Vector out = HVX_OP_ADD(v, val_vec); \
|
||||
Q6_V_vmux_QVV(pred_inf, inf, out); \
|
||||
})
|
||||
|
||||
#define HVX_OP_MUL_SCALAR(v) HVX_OP_MUL(v, val_vec)
|
||||
#define HVX_OP_SUB_SCALAR(v) HVX_OP_SUB(v, val_vec)
|
||||
|
||||
// Add Scalar Variants
|
||||
|
||||
static inline void hvx_add_scalar_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
|
||||
const HVX_Vector val_vec = hvx_vec_splat_f32(val);
|
||||
const HVX_Vector inf = hvx_vec_splat_f32(INFINITY);
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
assert((unsigned long) src % 128 == 0);
|
||||
hvx_scalar_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_ADD_SCALAR);
|
||||
}
|
||||
|
||||
static inline void hvx_add_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
|
||||
const HVX_Vector val_vec = hvx_vec_splat_f32(val);
|
||||
const HVX_Vector inf = hvx_vec_splat_f32(INFINITY);
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
hvx_scalar_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_ADD_SCALAR);
|
||||
}
|
||||
|
||||
static inline void hvx_add_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
|
||||
const HVX_Vector val_vec = hvx_vec_splat_f32(val);
|
||||
const HVX_Vector inf = hvx_vec_splat_f32(INFINITY);
|
||||
assert((unsigned long) src % 128 == 0);
|
||||
hvx_scalar_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u, HVX_OP_ADD_SCALAR);
|
||||
}
|
||||
|
||||
static inline void hvx_add_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
|
||||
const HVX_Vector val_vec = hvx_vec_splat_f32(val);
|
||||
static const float kInf = INFINITY;
|
||||
const HVX_Vector inf = hvx_vec_splat_f32(kInf);
|
||||
hvx_scalar_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_ADD_SCALAR);
|
||||
}
|
||||
|
||||
// Sub Scalar Variants
|
||||
|
||||
static inline void hvx_sub_scalar_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
|
||||
const HVX_Vector val_vec = hvx_vec_splat_f32(val);
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
assert((unsigned long) src % 128 == 0);
|
||||
hvx_scalar_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_SUB_SCALAR);
|
||||
}
|
||||
|
||||
static inline void hvx_sub_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
|
||||
const HVX_Vector val_vec = hvx_vec_splat_f32(val);
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
hvx_scalar_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_SUB_SCALAR);
|
||||
}
|
||||
|
||||
static inline void hvx_sub_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
|
||||
const HVX_Vector val_vec = hvx_vec_splat_f32(val);
|
||||
assert((unsigned long) src % 128 == 0);
|
||||
hvx_scalar_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u, HVX_OP_SUB_SCALAR);
|
||||
}
|
||||
|
||||
static inline void hvx_sub_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
|
||||
const HVX_Vector val_vec = hvx_vec_splat_f32(val);
|
||||
hvx_scalar_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_SUB_SCALAR);
|
||||
}
|
||||
|
||||
// Mul Scalar Variants
|
||||
|
||||
static inline void hvx_mul_scalar_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
|
||||
const HVX_Vector val_vec = hvx_vec_splat_f32(val);
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
assert((unsigned long) src % 128 == 0);
|
||||
hvx_scalar_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_MUL_SCALAR);
|
||||
}
|
||||
|
||||
static inline void hvx_mul_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
|
||||
const HVX_Vector val_vec = hvx_vec_splat_f32(val);
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
hvx_scalar_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_MUL_SCALAR);
|
||||
}
|
||||
|
||||
static inline void hvx_mul_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
|
||||
const HVX_Vector val_vec = hvx_vec_splat_f32(val);
|
||||
assert((unsigned long) src % 128 == 0);
|
||||
hvx_scalar_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u, HVX_OP_MUL_SCALAR);
|
||||
}
|
||||
|
||||
static inline void hvx_mul_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
|
||||
const HVX_Vector val_vec = hvx_vec_splat_f32(val);
|
||||
hvx_scalar_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_MUL_SCALAR);
|
||||
}
|
||||
|
||||
static inline void hvx_add_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float val, const int num_elems) {
|
||||
if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) {
|
||||
hvx_add_scalar_f32_aa(dst, src, val, num_elems);
|
||||
} else if (hex_is_aligned((void *) dst, 128)) {
|
||||
hvx_add_scalar_f32_au(dst, src, val, num_elems);
|
||||
} else if (hex_is_aligned((void *) src, 128)) {
|
||||
hvx_add_scalar_f32_ua(dst, src, val, num_elems);
|
||||
} else {
|
||||
hvx_add_scalar_f32_uu(dst, src, val, num_elems);
|
||||
}
|
||||
}
|
||||
|
||||
static inline void hvx_mul_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float val, const int num_elems) {
|
||||
if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) {
|
||||
hvx_mul_scalar_f32_aa(dst, src, val, num_elems);
|
||||
} else if (hex_is_aligned((void *) dst, 128)) {
|
||||
hvx_mul_scalar_f32_au(dst, src, val, num_elems);
|
||||
} else if (hex_is_aligned((void *) src, 128)) {
|
||||
hvx_mul_scalar_f32_ua(dst, src, val, num_elems);
|
||||
} else {
|
||||
hvx_mul_scalar_f32_uu(dst, src, val, num_elems);
|
||||
}
|
||||
}
|
||||
|
||||
static inline void hvx_sub_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float val, const int num_elems) {
|
||||
if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) {
|
||||
hvx_sub_scalar_f32_aa(dst, src, val, num_elems);
|
||||
} else if (hex_is_aligned((void *) dst, 128)) {
|
||||
hvx_sub_scalar_f32_au(dst, src, val, num_elems);
|
||||
} else if (hex_is_aligned((void *) src, 128)) {
|
||||
hvx_sub_scalar_f32_ua(dst, src, val, num_elems);
|
||||
} else {
|
||||
hvx_sub_scalar_f32_uu(dst, src, val, num_elems);
|
||||
}
|
||||
}
|
||||
|
||||
// MIN Scalar variants
|
||||
|
||||
#define HVX_OP_MIN_SCALAR(v) Q6_Vsf_vmin_VsfVsf(val_vec, v)
|
||||
|
||||
static inline void hvx_min_scalar_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
|
||||
const HVX_Vector val_vec = hvx_vec_splat_f32(val);
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
assert((unsigned long) src % 128 == 0);
|
||||
hvx_scalar_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_MIN_SCALAR);
|
||||
}
|
||||
|
||||
static inline void hvx_min_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
|
||||
const HVX_Vector val_vec = hvx_vec_splat_f32(val);
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
hvx_scalar_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_MIN_SCALAR);
|
||||
}
|
||||
|
||||
static inline void hvx_min_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
|
||||
const HVX_Vector val_vec = hvx_vec_splat_f32(val);
|
||||
assert((unsigned long) src % 128 == 0);
|
||||
hvx_scalar_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u, HVX_OP_MIN_SCALAR);
|
||||
}
|
||||
|
||||
static inline void hvx_min_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
|
||||
const HVX_Vector val_vec = hvx_vec_splat_f32(val);
|
||||
hvx_scalar_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_MIN_SCALAR);
|
||||
}
|
||||
|
||||
static inline void hvx_min_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float val, const int num_elems) {
|
||||
if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) {
|
||||
hvx_min_scalar_f32_aa(dst, src, val, num_elems);
|
||||
} else if (hex_is_aligned((void *) dst, 128)) {
|
||||
hvx_min_scalar_f32_au(dst, src, val, num_elems);
|
||||
} else if (hex_is_aligned((void *) src, 128)) {
|
||||
hvx_min_scalar_f32_ua(dst, src, val, num_elems);
|
||||
} else {
|
||||
hvx_min_scalar_f32_uu(dst, src, val, num_elems);
|
||||
}
|
||||
}
|
||||
|
||||
// CLAMP Scalar variants
|
||||
|
||||
#define HVX_OP_CLAMP_SCALAR(v) \
|
||||
({ \
|
||||
HVX_VectorPred pred_cap_right = Q6_Q_vcmp_gt_VsfVsf(v, max_vec); \
|
||||
HVX_VectorPred pred_cap_left = Q6_Q_vcmp_gt_VsfVsf(min_vec, v); \
|
||||
HVX_Vector tmp = Q6_V_vmux_QVV(pred_cap_right, max_vec, v); \
|
||||
Q6_V_vmux_QVV(pred_cap_left, min_vec, tmp); \
|
||||
})
|
||||
|
||||
static inline void hvx_clamp_scalar_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, uint32_t n) {
|
||||
const HVX_Vector min_vec = hvx_vec_splat_f32(min);
|
||||
const HVX_Vector max_vec = hvx_vec_splat_f32(max);
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
assert((unsigned long) src % 128 == 0);
|
||||
hvx_scalar_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_CLAMP_SCALAR);
|
||||
}
|
||||
|
||||
static inline void hvx_clamp_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, uint32_t n) {
|
||||
const HVX_Vector min_vec = hvx_vec_splat_f32(min);
|
||||
const HVX_Vector max_vec = hvx_vec_splat_f32(max);
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
hvx_scalar_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_CLAMP_SCALAR);
|
||||
}
|
||||
|
||||
static inline void hvx_clamp_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, uint32_t n) {
|
||||
const HVX_Vector min_vec = hvx_vec_splat_f32(min);
|
||||
const HVX_Vector max_vec = hvx_vec_splat_f32(max);
|
||||
assert((unsigned long) src % 128 == 0);
|
||||
hvx_scalar_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u, HVX_OP_CLAMP_SCALAR);
|
||||
}
|
||||
|
||||
static inline void hvx_clamp_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, uint32_t n) {
|
||||
const HVX_Vector min_vec = hvx_vec_splat_f32(min);
|
||||
const HVX_Vector max_vec = hvx_vec_splat_f32(max);
|
||||
hvx_scalar_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_CLAMP_SCALAR);
|
||||
}
|
||||
|
||||
static inline void hvx_clamp_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, const int num_elems) {
|
||||
if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) {
|
||||
hvx_clamp_scalar_f32_aa(dst, src, min, max, num_elems);
|
||||
} else if (hex_is_aligned((void *) dst, 128)) {
|
||||
hvx_clamp_scalar_f32_au(dst, src, min, max, num_elems);
|
||||
} else if (hex_is_aligned((void *) src, 128)) {
|
||||
hvx_clamp_scalar_f32_ua(dst, src, min, max, num_elems);
|
||||
} else {
|
||||
hvx_clamp_scalar_f32_uu(dst, src, min, max, num_elems);
|
||||
}
|
||||
}
|
||||
|
||||
#undef HVX_OP_ADD
|
||||
#undef HVX_OP_SUB
|
||||
#undef HVX_OP_MUL
|
||||
#undef hvx_arith_loop_body
|
||||
#undef HVX_OP_ADD_SCALAR
|
||||
#undef HVX_OP_SUB_SCALAR
|
||||
#undef HVX_OP_MUL_SCALAR
|
||||
#undef hvx_scalar_loop_body
|
||||
#undef HVX_OP_MIN_SCALAR
|
||||
#undef HVX_OP_CLAMP_SCALAR
|
||||
|
||||
#endif // HVX_ARITH_H
|
||||
|
|
@ -0,0 +1,167 @@
|
|||
#ifndef HVX_BASE_H
|
||||
#define HVX_BASE_H
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include "hex-utils.h"
|
||||
#include "hvx-types.h"
|
||||
|
||||
static inline void hvx_vec_store_u(void * restrict dst, uint32_t n, HVX_Vector v) {
|
||||
// Rotate as needed.
|
||||
v = Q6_V_vlalign_VVR(v, v, (size_t) dst);
|
||||
|
||||
uint32_t left_off = (size_t) dst & 127;
|
||||
uint32_t right_off = left_off + n;
|
||||
|
||||
HVX_VectorPred ql_not = Q6_Q_vsetq_R((size_t) dst);
|
||||
HVX_VectorPred qr = Q6_Q_vsetq2_R(right_off);
|
||||
|
||||
if (right_off > 128) {
|
||||
Q6_vmem_QRIV(qr, (HVX_Vector *) dst + 1, v);
|
||||
// all 1's
|
||||
qr = Q6_Q_vcmp_eq_VbVb(v, v);
|
||||
}
|
||||
|
||||
ql_not = Q6_Q_or_QQn(ql_not, qr);
|
||||
Q6_vmem_QnRIV(ql_not, (HVX_Vector *) dst, v);
|
||||
}
|
||||
|
||||
static inline void hvx_vec_store_a(void * restrict dst, uint32_t n, HVX_Vector v) {
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
HVX_VectorPred m = Q6_Q_or_QQn(Q6_Q_vsetq_R((unsigned long) dst), Q6_Q_vsetq2_R(n));
|
||||
Q6_vmem_QnRIV(m, (HVX_Vector *) dst, v);
|
||||
}
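A usage note (illustrative, not part of the patch): hvx_vec_store_a writes only the first n bytes of the vector, which is how the aligned loop bodies in hvx-arith.h and hvx-copy.h handle a partial trailing vector.

// Store the first nloe floats of v to an aligned destination, leaving the
// rest of the 128-byte line untouched.
static inline void store_f32_head(float * dst_aligned, HVX_Vector v, uint32_t nloe) {
    hvx_vec_store_a((void *) dst_aligned, nloe * sizeof(float), v);
}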
|
||||
|
||||
static inline HVX_Vector hvx_vec_splat_f32(float v) {
|
||||
union { float f; uint32_t i; } u = { .f = v };
|
||||
return Q6_V_vsplat_R(u.i);
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_splat_f16(float v) {
|
||||
union { __fp16 f; uint16_t i; } u = { .f = v };
|
||||
return Q6_Vh_vsplat_R(u.i);
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_repl4(HVX_Vector v) {
|
||||
// vdelta control to replicate first 4 bytes across all elements
|
||||
static const uint8_t __attribute__((aligned(128))) repl[128] = {
|
||||
0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
|
||||
0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
|
||||
0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
|
||||
0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
|
||||
0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
|
||||
0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
|
||||
0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
|
||||
0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
|
||||
};
|
||||
|
||||
HVX_Vector ctrl = *(HVX_Vector *) repl;
|
||||
return Q6_V_vdelta_VV(v, ctrl);
|
||||
}
|
||||
|
||||
static inline float hvx_vec_get_f32(HVX_Vector v) {
|
||||
float __attribute__((aligned(128))) x;
|
||||
hvx_vec_store_a(&x, 4, v);
|
||||
return x;
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_abs_f16(HVX_Vector v) {
|
||||
// abs by clearing the fp16 sign bit
|
||||
HVX_Vector mask = Q6_Vh_vsplat_R(0x7fff);
|
||||
return Q6_V_vand_VV(v, mask);
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_neg_f16(HVX_Vector v) {
|
||||
// neg by setting the fp16 sign bit
|
||||
HVX_Vector mask = Q6_Vh_vsplat_R(0x8000);
|
||||
return Q6_V_vxor_VV(v, mask);
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_abs_f32(HVX_Vector v) {
|
||||
// abs by clearing the fp32 sign bit
|
||||
HVX_Vector mask = Q6_V_vsplat_R(0x7fffffff);
|
||||
return Q6_V_vand_VV(v, mask);
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_neg_f32(HVX_Vector v) {
|
||||
#if __HVX_ARCH__ > 75
|
||||
return Q6_Vsf_vfneg_Vsf(v);
|
||||
#else
|
||||
// neg by setting the fp32 sign bit
|
||||
HVX_Vector mask = Q6_V_vsplat_R(0x80000000);
|
||||
return Q6_V_vxor_VV(v, mask);
|
||||
#endif // __HVX_ARCH__ > 75
|
||||
}
|
||||
|
||||
static inline HVX_VectorPred hvx_vec_is_nan_f16(HVX_Vector v) {
|
||||
const HVX_Vector vnan_exp = Q6_Vh_vsplat_R(0x7C00);
|
||||
const HVX_Vector vnan_frac = Q6_Vh_vsplat_R(0x7FFF);
|
||||
|
||||
// get pred of which are NaN, i.e., exponent bits all 1s and fraction bits non 0s
|
||||
HVX_VectorPred p_exp = Q6_Q_vcmp_eq_VhVh(Q6_V_vand_VV(v, vnan_exp), vnan_exp);
|
||||
HVX_VectorPred p_frac = Q6_Q_not_Q(Q6_Q_vcmp_eq_VhVh(Q6_V_vand_VV(v, vnan_frac), vnan_exp));
|
||||
return Q6_Q_and_QQ(p_exp, p_frac);
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_f32_to_f16(HVX_Vector v0, HVX_Vector v1) {
|
||||
const HVX_Vector zero = Q6_V_vsplat_R(0);
|
||||
HVX_Vector q0 = Q6_Vqf32_vadd_VsfVsf(v0, zero);
|
||||
HVX_Vector q1 = Q6_Vqf32_vadd_VsfVsf(v1, zero);
|
||||
HVX_Vector v = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(q1, q0)));
|
||||
|
||||
#if __HVX_ARCH__ < 79
|
||||
// replace NaNs with -INF, older arches produce NaNs for (-INF + 0.0)
|
||||
const HVX_Vector neg_inf = hvx_vec_splat_f16(-INFINITY);
|
||||
HVX_VectorPred nan = hvx_vec_is_nan_f16(v);
|
||||
v = Q6_V_vmux_QVV(nan, neg_inf, v);
|
||||
#endif
|
||||
|
||||
return v;
|
||||
}
|
||||
|
||||
/* Q6_Vsf_equals_Vw is only available on v73+.*/
|
||||
#if __HVX_ARCH__ < 73
|
||||
static inline HVX_Vector hvx_vec_i32_to_qf32(HVX_Vector const in)
|
||||
{
|
||||
HVX_Vector const vzero = Q6_V_vzero();
|
||||
HVX_VectorPred is_zero = Q6_Q_vcmp_eq_VwVw(in, vzero);
|
||||
HVX_Vector lshift = Q6_Vw_vnormamt_Vw(in);
|
||||
HVX_Vector normalized = Q6_Vw_vasl_VwVw(in, lshift);
|
||||
HVX_Vector vexp = Q6_Vw_vsub_VwVw(Q6_V_vsplat_R(0x7f + 30), lshift);
|
||||
HVX_Vector mant = Q6_V_vand_VV(Q6_V_vsplat_R(0xFFFFFF00), normalized);
|
||||
HVX_Vector ret = Q6_V_vmux_QVV(is_zero, vzero, Q6_Vw_vadd_VwVw(mant, vexp));
|
||||
return ret;
|
||||
}
|
||||
|
||||
static inline HVX_Vector Q6_Vsf_equals_Vw(HVX_Vector const in)
|
||||
{
|
||||
return Q6_Vsf_equals_Vqf32(hvx_vec_i32_to_qf32(in));
|
||||
}
|
||||
#endif
|
||||
|
||||
static inline HVX_Vector hvx_vec_i16_from_hf_rnd_sat(HVX_Vector vin) {
|
||||
// This looks complicated.
|
||||
// Ideally should just be Q6_Vh_equals_Vhf(vin)
|
||||
// but that instruction does not do proper rounding.
|
||||
|
||||
// convert to qf32, multiplying by 1.0 in the process.
|
||||
HVX_VectorPair v32 = Q6_Wqf32_vmpy_VhfVhf(vin, Q6_Vh_vsplat_R(0x3C00));
|
||||
|
||||
// 'in-range' values are +/-32752.
|
||||
// add 192K to it, convert to sf
|
||||
HVX_Vector v192K = Q6_V_vsplat_R(0x48400000);
|
||||
HVX_Vector vsf_0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_lo_W(v32), v192K));
|
||||
HVX_Vector vsf_1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_hi_W(v32), v192K));
|
||||
|
||||
// for in-range cases, result is {163858... 229360} so the exponent is always 144.
|
||||
// if we extract bits 21..0 as a signed quantity, and round 6 bits off, that will be the answer.
|
||||
// Start by <<10 to get the final 'sign' bit in bit 15...
|
||||
vsf_0 = Q6_Vw_vasl_VwR(vsf_0, 10);
|
||||
vsf_1 = Q6_Vw_vasl_VwR(vsf_1, 10);
|
||||
|
||||
// now round down to 16
|
||||
return Q6_Vh_vround_VwVw_sat(vsf_1, vsf_0);
|
||||
}
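For clarity, a scalar equivalent of the conversion above (illustrative only; the helper name is hypothetical and <math.h> is assumed): round to nearest and saturate to int16.

#include <math.h>

// Scalar reference: what hvx_vec_i16_from_hf_rnd_sat computes per lane.
static int16_t i16_from_f16_ref(float x) {
    long v = lrintf(x);              // round to nearest
    if (v >  32767) { v =  32767; }  // saturate
    if (v < -32768) { v = -32768; }
    return (int16_t) v;
}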
|
||||
|
||||
#endif /* HVX_BASE_H */
|
||||
|
|
@ -0,0 +1,247 @@
|
|||
#ifndef HVX_COPY_H
|
||||
#define HVX_COPY_H
|
||||
|
||||
#include <assert.h>
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include "hvx-base.h"
|
||||
|
||||
#define hvx_splat_loop_body(dst_type, vec_store) \
|
||||
do { \
|
||||
dst_type * restrict vdst = (dst_type *) dst; \
|
||||
\
|
||||
uint32_t nvec = n / (128 / elem_size); \
|
||||
uint32_t nloe = n % (128 / elem_size); \
|
||||
\
|
||||
uint32_t i = 0; \
|
||||
\
|
||||
_Pragma("unroll(4)") \
|
||||
for (; i < nvec; i++) { \
|
||||
vdst[i] = src; \
|
||||
} \
|
||||
if (nloe) { \
|
||||
vec_store((void *) &vdst[i], nloe * elem_size, src); \
|
||||
} \
|
||||
} while(0)
|
||||
|
||||
static inline void hvx_splat_a(uint8_t * restrict dst, HVX_Vector src, uint32_t n, uint32_t elem_size) {
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
hvx_splat_loop_body(HVX_Vector, hvx_vec_store_a);
|
||||
}
|
||||
|
||||
static inline void hvx_splat_u(uint8_t * restrict dst, HVX_Vector src, uint32_t n, uint32_t elem_size) {
|
||||
hvx_splat_loop_body(HVX_UVector, hvx_vec_store_u);
|
||||
}
|
||||
|
||||
static inline void hvx_splat_f32_a(uint8_t * restrict dst, float v, uint32_t n) {
|
||||
hvx_splat_a(dst, hvx_vec_splat_f32(v), n, sizeof(float));
|
||||
}
|
||||
|
||||
static inline void hvx_splat_f32_u(uint8_t * restrict dst, float v, uint32_t n) {
|
||||
hvx_splat_u(dst, hvx_vec_splat_f32(v), n, sizeof(float));
|
||||
}
|
||||
|
||||
static inline void hvx_splat_f16_a(uint8_t * restrict dst, float v, uint32_t n) {
hvx_splat_a(dst, hvx_vec_splat_f16(v), n, sizeof(__fp16));
}
|
||||
|
||||
static inline void hvx_splat_f16_u(uint8_t * restrict dst, float v, uint32_t n) {
|
||||
hvx_splat_u(dst, hvx_vec_splat_f16(v), n, sizeof(__fp16));
|
||||
}
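Per the commit notes, these splat helpers replace memset for typed fills. A minimal sketch (illustrative, not part of the patch):

// Zero-fill n floats in a 128-byte aligned buffer; equivalent in effect to
// memset(dst, 0, n * sizeof(float)), but also works for non-zero fill values.
static inline void zero_f32(float * dst_aligned, uint32_t n) {
    hvx_splat_f32_a((uint8_t *) dst_aligned, 0.0f, n);
}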
|
||||
|
||||
#define hvx_copy_loop_body(dst_type, src_type, vec_store) \
|
||||
do { \
|
||||
dst_type * restrict vdst = (dst_type *) dst; \
|
||||
src_type * restrict vsrc = (src_type *) src; \
|
||||
\
|
||||
const uint32_t epv = 128 / elem_size; \
|
||||
const uint32_t nvec = n / epv; \
|
||||
const uint32_t nloe = n % epv; \
|
||||
\
|
||||
uint32_t i = 0; \
|
||||
\
|
||||
_Pragma("unroll(4)") \
|
||||
for (; i < nvec; i++) { vdst[i] = vsrc[i]; } \
|
||||
if (nloe) { \
|
||||
vec_store((void *) &vdst[i], nloe * elem_size, vsrc[i]); \
|
||||
} \
|
||||
} while(0)
|
||||
|
||||
// Generic copy routines
|
||||
static inline void hvx_copy_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n, uint32_t elem_size) {
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
assert((unsigned long) src % 128 == 0);
|
||||
hvx_copy_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
|
||||
}
|
||||
|
||||
static inline void hvx_copy_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n, uint32_t elem_size) {
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
hvx_copy_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);
|
||||
}
|
||||
|
||||
static inline void hvx_copy_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n, uint32_t elem_size) {
|
||||
assert((unsigned long) src % 128 == 0);
|
||||
hvx_copy_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
|
||||
}
|
||||
|
||||
static inline void hvx_copy_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n, uint32_t elem_size) {
|
||||
hvx_copy_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
|
||||
}
|
||||
|
||||
// copy n fp16 elements : source and destination are aligned to HVX Vector (128)
|
||||
static inline void hvx_copy_f16_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
hvx_copy_aa(dst, src, n, sizeof(__fp16));
|
||||
}
|
||||
|
||||
// copy n fp16 elements : destination is aligned, source is potentially unaligned
|
||||
static inline void hvx_copy_f16_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
hvx_copy_au(dst, src, n, sizeof(__fp16));
|
||||
}
|
||||
|
||||
// copy n fp16 elements : source is aligned, destination is potentially unaligned
|
||||
static inline void hvx_copy_f16_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
hvx_copy_ua(dst, src, n, sizeof(__fp16));
|
||||
}
|
||||
|
||||
// copy n fp16 elements : both source and destination are potentially unaligned
|
||||
static inline void hvx_copy_f16_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
hvx_copy_uu(dst, src, n, sizeof(__fp16));
|
||||
}
|
||||
|
||||
// copy n fp32 elements : source and destination are aligned to HVX Vector (128)
|
||||
static inline void hvx_copy_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
hvx_copy_aa(dst, src, n, sizeof(float));
|
||||
}
|
||||
|
||||
// copy n fp32 elements : source is aligned, destination is unaligned
|
||||
static inline void hvx_copy_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
hvx_copy_ua(dst, src, n, sizeof(float));
|
||||
}
|
||||
|
||||
// copy n fp32 elements : source is unaligned, destination is aligned
|
||||
static inline void hvx_copy_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
hvx_copy_au(dst, src, n, sizeof(float));
|
||||
}
|
||||
|
||||
// copy n fp32 elements : source is unaligned, destination unaligned
|
||||
static inline void hvx_copy_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
hvx_copy_uu(dst, src, n, sizeof(float));
|
||||
}
|
||||
|
||||
//// fp32 -> fp16
|
||||
|
||||
#define hvx_copy_f16_f32_loop_body(dst_type, src_type, vec_store) \
|
||||
do { \
|
||||
dst_type * restrict vdst = (dst_type *) dst; \
|
||||
src_type * restrict vsrc = (src_type *) src; \
|
||||
\
|
||||
const HVX_Vector zero = Q6_V_vsplat_R(0); \
|
||||
\
|
||||
const uint32_t elem_size = sizeof(__fp16); \
|
||||
const uint32_t epv = 128 / elem_size; \
|
||||
const uint32_t nvec = n / epv; \
|
||||
const uint32_t nloe = n % epv; \
|
||||
\
|
||||
uint32_t i = 0; \
|
||||
\
|
||||
_Pragma("unroll(4)") \
|
||||
for (; i < nvec; i++) { \
|
||||
vdst[i] = hvx_vec_f32_to_f16(vsrc[i*2+0], vsrc[i*2+1]); \
|
||||
} \
|
||||
if (nloe) { \
|
||||
HVX_Vector v = hvx_vec_f32_to_f16(vsrc[i*2+0], vsrc[i*2+1]); \
|
||||
vec_store((void *) &vdst[i], nloe * elem_size, v); \
|
||||
} \
|
||||
} while(0)
|
||||
|
||||
// copy/convert n fp32 elements into n fp16 elements : source is aligned, destination is aligned
|
||||
static inline void hvx_copy_f16_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
assert((unsigned long) src % 128 == 0);
|
||||
hvx_copy_f16_f32_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
|
||||
}
|
||||
|
||||
// copy/convert n fp32 elements into n fp16 elements : source is unaligned, destination is aligned
|
||||
static inline void hvx_copy_f16_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
hvx_copy_f16_f32_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);
|
||||
}
|
||||
|
||||
// copy/convert n fp32 elements into n fp16 elements : source is aligned, destination is unaligned
|
||||
static inline void hvx_copy_f16_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
assert((unsigned long) src % 128 == 0);
|
||||
hvx_copy_f16_f32_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
|
||||
}
|
||||
|
||||
// copy/convert n fp32 elements into n fp16 elements : source is unaligned, destination is unaligned
|
||||
static inline void hvx_copy_f16_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
hvx_copy_f16_f32_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
|
||||
}
|
||||
|
||||
//// fp16 -> fp32
|
||||
|
||||
#define hvx_copy_f32_f16_loop_body(dst_type, src_type, vec_store) \
|
||||
do { \
|
||||
dst_type * restrict vdst = (dst_type *) dst; \
|
||||
src_type * restrict vsrc = (src_type *) src; \
|
||||
\
|
||||
const HVX_Vector one = hvx_vec_splat_f16(1.0); \
|
||||
\
|
||||
const uint32_t elem_size = sizeof(__fp16); \
|
||||
const uint32_t epv = 128 / elem_size; \
|
||||
const uint32_t nvec = n / epv; \
|
||||
uint32_t nloe = n % epv; \
|
||||
\
|
||||
uint32_t i = 0; \
|
||||
\
|
||||
_Pragma("unroll(4)") \
|
||||
for (i = 0; i < nvec; ++i) { \
|
||||
HVX_VectorPair p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vsrc[i]), one); \
|
||||
vdst[i*2] = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(p)); \
|
||||
vdst[i*2+1] = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(p)); \
|
||||
} \
|
||||
\
|
||||
if (nloe) { \
|
||||
HVX_VectorPair p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vsrc[i]), one); \
|
||||
\
|
||||
HVX_Vector vd = Q6_V_lo_W(p); \
|
||||
i = 2 * i; \
|
||||
\
|
||||
if (nloe >= 32) { \
|
||||
vdst[i] = Q6_Vsf_equals_Vqf32(vd); \
|
||||
nloe -= 32; ++i; vd = Q6_V_hi_W(p); \
|
||||
} \
|
||||
\
|
||||
if (nloe) { \
|
||||
vd = Q6_Vsf_equals_Vqf32(vd); \
|
||||
hvx_vec_store_u(&vdst[i], nloe * sizeof(float), vd); \
|
||||
} \
|
||||
} \
|
||||
} while(0)
|
||||
|
||||
// copy/convert n fp16 elements into n fp32 elements : source is aligned, destination is aligned
|
||||
static inline void hvx_copy_f32_f16_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
assert((unsigned long) src % 128 == 0);
|
||||
hvx_copy_f32_f16_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
|
||||
}
|
||||
|
||||
// copy/convert n fp16 elements into n fp32 elements : source is unaligned, destination is aligned
|
||||
static inline void hvx_copy_f32_f16_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
hvx_copy_f32_f16_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);
|
||||
}
|
||||
|
||||
// copy/convert n fp16 elements into n fp32 elements : source is aligned, destination is unaligned
|
||||
static inline void hvx_copy_f32_f16_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
assert((unsigned long) src % 128 == 0);
|
||||
hvx_copy_f32_f16_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
|
||||
}
|
||||
|
||||
// copy/convert n fp16 elements into n fp32 elements : source is unaligned, destination is unaligned
|
||||
static inline void hvx_copy_f32_f16_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
hvx_copy_f32_f16_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
|
||||
}
|
||||
|
||||
#endif // HVX_COPY_H
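A sketch of how these helpers back the new OP_CPY support (illustrative only, not the actual op_cpy implementation):

// Copy one row of ne00 fp32 elements, converting to fp16 when the destination
// tensor is fp16; the _uu variants make no alignment assumptions.
static void cpy_row_from_f32(uint8_t * dst, const uint8_t * src, uint32_t ne00, bool dst_is_f16) {
    if (dst_is_f16) {
        hvx_copy_f16_f32_uu(dst, src, ne00);  // fp32 -> fp16
    } else {
        hvx_copy_f32_uu(dst, src, ne00);      // fp32 -> fp32
    }
}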
|
||||
|
|
@ -0,0 +1,132 @@
|
|||
#ifndef HVX_DUMP_H
|
||||
#define HVX_DUMP_H
|
||||
|
||||
#include <HAP_farf.h>
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include "hex-utils.h"
|
||||
#include "hvx-types.h"
|
||||
|
||||
static void hvx_vec_dump_f16_n(char * pref, HVX_Vector v, uint32_t n) {
|
||||
HVX_VectorAlias u = { .v = v };
|
||||
|
||||
const uint32_t n0 = n / 16;
|
||||
const uint32_t n1 = n % 16;
|
||||
int i = 0;
|
||||
for (; i < n0; i++) {
|
||||
hex_dump_f16_line(pref, u.fp16 + (16 * i), 16);
|
||||
}
|
||||
if (n1) {
|
||||
hex_dump_f16_line(pref, u.fp16 + (16 * i), n1);
|
||||
}
|
||||
}
|
||||
|
||||
static void hvx_vec_dump_f16(char * pref, HVX_Vector v) {
|
||||
hvx_vec_dump_f16_n(pref, v, 64);
|
||||
}
|
||||
|
||||
static void hvx_vec_dump_f32_n(char * pref, HVX_Vector v, uint32_t n) {
|
||||
union {
|
||||
HVX_Vector v;
|
||||
float d[32];
|
||||
} u = { .v = v };
|
||||
|
||||
const uint32_t n0 = n / 16;
|
||||
const uint32_t n1 = n % 16;
|
||||
int i = 0;
|
||||
for (; i < n0; i++) {
|
||||
hex_dump_f32_line(pref, u.d + (16 * i), 16);
|
||||
}
|
||||
if (n1) {
|
||||
hex_dump_f32_line(pref, u.d + (16 * i), n1);
|
||||
}
|
||||
}
|
||||
|
||||
static void hvx_vec_dump_f32_hmt(char * pref, HVX_Vector v) {
|
||||
union {
|
||||
HVX_Vector v;
|
||||
float d[32];
|
||||
} u = { .v = v };
|
||||
|
||||
FARF(HIGH, "%s: %.6f %.6f %.6f %.6f ... %.6f %.6f %.6f %.6f ... %.6f %.6f %.6f %.6f\n", pref, u.d[0], u.d[1],
|
||||
u.d[2], u.d[3], u.d[12], u.d[13], u.d[14], u.d[15], u.d[28], u.d[29], u.d[30], u.d[31]);
|
||||
}
|
||||
|
||||
static void hvx_vec_dump_f32(char * pref, HVX_Vector v) {
|
||||
hvx_vec_dump_f32_n(pref, v, 32);
|
||||
}
|
||||
|
||||
static void hvx_vec_dump_int32(char * pref, HVX_Vector v) {
|
||||
union {
|
||||
HVX_Vector v;
|
||||
int32_t d[32];
|
||||
} u = { .v = v };
|
||||
|
||||
for (int i = 0; i < 32 / 16; i++) {
|
||||
hex_dump_int32_line(pref, u.d + (16 * i), 16);
|
||||
}
|
||||
}
|
||||
|
||||
static void hvx_vec_dump_int32_hmt(char * pref, HVX_Vector v) {
|
||||
union {
|
||||
HVX_Vector v;
|
||||
int32_t d[32];
|
||||
} u = { .v = v };
|
||||
|
||||
FARF(HIGH, "%s: %d %d %d %d ... %d %d %d %d ... %d %d %d %d\n", pref, u.d[0], u.d[1], u.d[2], u.d[3], u.d[12],
|
||||
u.d[13], u.d[14], u.d[15], u.d[28], u.d[29], u.d[30], u.d[31]);
|
||||
}
|
||||
|
||||
static void hvx_vec_dump_int8_hmt(char * pref, HVX_Vector v) {
|
||||
union {
|
||||
HVX_Vector v;
|
||||
int8_t d[128];
|
||||
} u = { .v = v };
|
||||
|
||||
FARF(HIGH, "%s: %d %d %d %d ... %d %d %d %d ... %d %d %d %d\n", pref, u.d[0], u.d[1], u.d[2], u.d[3], u.d[60],
|
||||
u.d[61], u.d[62], u.d[63], u.d[124], u.d[125], u.d[126], u.d[127]);
|
||||
}
|
||||
|
||||
static void hvx_vec_dump_int8(char * pref, HVX_Vector v) {
|
||||
union {
|
||||
HVX_Vector v;
|
||||
int8_t d[128];
|
||||
} u = { .v = v };
|
||||
|
||||
for (int i = 0; i < 128 / 16; i++) {
|
||||
hex_dump_int8_line(pref, u.d + (16 * i), 16);
|
||||
}
|
||||
}
|
||||
|
||||
static void hvx_vec_dump_uint8(char * pref, HVX_Vector v) {
|
||||
union {
|
||||
HVX_Vector v;
|
||||
uint8_t d[128];
|
||||
} u = { .v = v };
|
||||
|
||||
for (int i = 0; i < 128 / 16; i++) {
|
||||
hex_dump_uint8_line(pref, u.d + (16 * i), 16);
|
||||
}
|
||||
}
|
||||
|
||||
static bool hvx_vec_eq(HVX_Vector v0, HVX_Vector v1, size_t n) {
|
||||
typedef union {
|
||||
HVX_Vector v;
|
||||
int8_t d[128];
|
||||
} U;
|
||||
|
||||
U u0 = { .v = v0 };
|
||||
U u1 = { .v = v1 };
|
||||
|
||||
for (int i = 0; i < n; i++) {
|
||||
if (u0.d[i] != u1.d[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
#endif /* HVX_DUMP_H */
|
||||
|
|
@ -1,94 +0,0 @@
|
|||
#pragma clang diagnostic ignored "-Wunused-variable"
|
||||
#pragma clang diagnostic ignored "-Wunused-function"
|
||||
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
|
||||
|
||||
#include <hexagon_protos.h>
|
||||
#include <hexagon_types.h>
|
||||
#include <math.h>
|
||||
#include <string.h>
|
||||
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-dma.h"
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
#include "hvx-utils.h"
|
||||
#include "ops-utils.h"
|
||||
|
||||
static inline HVX_Vector hvx_vec_exp_fp32_guard(HVX_Vector in_vec, HVX_Vector max_exp, HVX_Vector inf) {
|
||||
const HVX_VectorPred pred0 = Q6_Q_vcmp_gt_VsfVsf(in_vec, max_exp);
|
||||
|
||||
HVX_Vector out = hvx_vec_exp_fp32(in_vec);
|
||||
|
||||
return Q6_V_vmux_QVV(pred0, inf, out);
|
||||
}
|
||||
|
||||
void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, bool negate) {
|
||||
int left_over = num_elems & (VLEN_FP32 - 1);
|
||||
int num_elems_whole = num_elems - left_over;
|
||||
|
||||
int unaligned_addr = 0;
|
||||
int unaligned_loop = 0;
|
||||
if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) {
|
||||
FARF(HIGH, "hvx_exp_f32: unaligned address in hvx op, possibly slower execution\n");
|
||||
unaligned_addr = 1;
|
||||
}
|
||||
// assert((0 == unaligned_addr) || (0 == num_elems_whole));
|
||||
if ((1 == unaligned_addr) && (num_elems_whole != 0)) {
|
||||
unaligned_loop = 1;
|
||||
FARF(HIGH, "hvx_exp_f32: unaligned loop in hvx op, possibly slower execution\n");
|
||||
}
|
||||
|
||||
HVX_Vector vec_out = Q6_V_vzero();
|
||||
|
||||
static const float kInf = INFINITY;
|
||||
static const float kMaxExp = 88.02f; // log(INF)
|
||||
|
||||
const HVX_Vector max_exp = hvx_vec_splat_fp32(kMaxExp);
|
||||
const HVX_Vector inf = hvx_vec_splat_fp32(kInf);
|
||||
|
||||
if (0 == unaligned_loop) {
|
||||
HVX_Vector * p_vec_in1 = (HVX_Vector *) src;
|
||||
HVX_Vector * p_vec_out = (HVX_Vector *) dst;
|
||||
|
||||
#pragma unroll(4)
|
||||
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
|
||||
if (true == negate) {
|
||||
HVX_Vector neg_vec_in = hvx_vec_neg_fp32(*p_vec_in1++);
|
||||
*p_vec_out++ = hvx_vec_exp_fp32_guard(neg_vec_in, max_exp, inf);
|
||||
} else {
|
||||
*p_vec_out++ = hvx_vec_exp_fp32_guard(*p_vec_in1++, max_exp, inf);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
#pragma unroll(4)
|
||||
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
|
||||
HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32);
|
||||
|
||||
if (true == negate) {
|
||||
HVX_Vector neg_vec_in = hvx_vec_neg_fp32(in);
|
||||
*(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_exp_fp32_guard(neg_vec_in, max_exp, inf);
|
||||
} else {
|
||||
*(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_exp_fp32_guard(in, max_exp, inf);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (left_over > 0) {
|
||||
const float * srcf = (float *) src + num_elems_whole;
|
||||
float * dstf = (float *) dst + num_elems_whole;
|
||||
|
||||
HVX_Vector in = *(HVX_UVector *) srcf;
|
||||
|
||||
if (true == negate) {
|
||||
HVX_Vector neg_vec_in = hvx_vec_neg_fp32(in);
|
||||
|
||||
vec_out = hvx_vec_exp_fp32_guard(neg_vec_in, max_exp, inf);
|
||||
} else {
|
||||
vec_out = hvx_vec_exp_fp32_guard(in, max_exp, inf);
|
||||
}
|
||||
|
||||
hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, vec_out);
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,215 @@
|
|||
#ifndef HVX_EXP_H
|
||||
#define HVX_EXP_H
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include "hvx-base.h"
|
||||
#include "hvx-floor.h"
|
||||
|
||||
#define EXP_COEFF_5 (0x39506967) // 0.000198757 = 1/(7!)
|
||||
#define EXP_COEFF_4 (0x3AB743CE) // 0.0013982 = 1/(6!)
|
||||
#define EXP_COEFF_3 (0x3C088908) // 0.00833345 = 1/(5!)
|
||||
#define EXP_COEFF_2 (0x3D2AA9C1) // 0.0416658 = 1/(4!)
|
||||
#define EXP_COEFF_1 (0x3E2AAAAA) // 0.16666667 = 1/(3!)
|
||||
#define EXP_COEFF_0 (0x3F000000) // 0.5 = 1/(2!)
|
||||
#define EXP_LOGN2 (0x3F317218) // ln(2) = 0.6931471805
|
||||
#define EXP_LOG2E (0x3FB8AA3B) // log2(e) = 1/ln(2) = 1.4426950408
|
||||
#define EXP_ONE (0x3f800000) // 1.0
|
||||
#define EXP_RANGE_R (0x41a00000) // 20.0
|
||||
#define EXP_RANGE_L (0xc1a00000) // -20.0
|
||||
|
||||
static inline HVX_Vector hvx_vec_exp_f32(HVX_Vector in_vec) {
|
||||
HVX_Vector z_qf32_v;
|
||||
HVX_Vector x_v;
|
||||
HVX_Vector x_qf32_v;
|
||||
HVX_Vector y_v;
|
||||
HVX_Vector k_v;
|
||||
HVX_Vector f_v;
|
||||
HVX_Vector epsilon_v;
|
||||
HVX_Vector log2e = Q6_V_vsplat_R(EXP_LOG2E);
|
||||
HVX_Vector logn2 = Q6_V_vsplat_R(EXP_LOGN2);
|
||||
HVX_Vector E_const;
|
||||
HVX_Vector zero_v = Q6_V_vzero();
|
||||
|
||||
// exp(x) is approximated as follows:
|
||||
// f = floor(x/ln(2)) = floor(x*log2(e))
|
||||
// epsilon = x - f*ln(2)
|
||||
// exp(x) = exp(epsilon+f*ln(2))
|
||||
// = exp(epsilon)*exp(f*ln(2))
|
||||
// = exp(epsilon)*2^f
|
||||
//
|
||||
// Since epsilon is close to zero, it can be approximated with its Taylor series:
|
||||
// exp(x) ~= 1+x+x^2/2!+x^3/3!+...+x^n/n!+...
|
||||
// Preserving the first eight elements, we get:
|
||||
// exp(x) ~= 1+x+e0*x^2+e1*x^3+e2*x^4+e3*x^5+e4*x^6+e5*x^7
|
||||
// = 1+x+(E0+(E1+(E2+(E3+(E4+E5*x)*x)*x)*x)*x)*x^2
|
||||
|
||||
HVX_Vector temp_v = in_vec;
|
||||
|
||||
// Clamp inputs to (-20.0, 20.0)
|
||||
HVX_VectorPred pred_cap_right = Q6_Q_vcmp_gt_VsfVsf(in_vec, Q6_V_vsplat_R(EXP_RANGE_R));
|
||||
HVX_VectorPred pred_cap_left = Q6_Q_vcmp_gt_VsfVsf(Q6_V_vsplat_R(EXP_RANGE_L), in_vec);
|
||||
|
||||
in_vec = Q6_V_vmux_QVV(pred_cap_right, Q6_V_vsplat_R(EXP_RANGE_R), temp_v);
in_vec = Q6_V_vmux_QVV(pred_cap_left, Q6_V_vsplat_R(EXP_RANGE_L), in_vec);
|
||||
|
||||
epsilon_v = Q6_Vqf32_vmpy_VsfVsf(log2e, in_vec);
|
||||
epsilon_v = Q6_Vsf_equals_Vqf32(epsilon_v);
|
||||
|
||||
// f_v is the floating point result and k_v is the integer result
|
||||
f_v = hvx_vec_floor_f32(epsilon_v);
|
||||
k_v = hvx_vec_truncate_f32(f_v);
|
||||
|
||||
x_qf32_v = Q6_Vqf32_vadd_VsfVsf(in_vec, zero_v);
|
||||
|
||||
// x = x - f_v * logn2;
|
||||
epsilon_v = Q6_Vqf32_vmpy_VsfVsf(f_v, logn2);
|
||||
x_qf32_v = Q6_Vqf32_vsub_Vqf32Vqf32(x_qf32_v, epsilon_v);
|
||||
// normalize before every QFloat's vmpy
|
||||
x_qf32_v = Q6_Vqf32_vadd_Vqf32Vsf(x_qf32_v, zero_v);
|
||||
|
||||
// z = x * x;
|
||||
z_qf32_v = Q6_Vqf32_vmpy_Vqf32Vqf32(x_qf32_v, x_qf32_v);
|
||||
z_qf32_v = Q6_Vqf32_vadd_Vqf32Vsf(z_qf32_v, zero_v);
|
||||
|
||||
x_v = Q6_Vsf_equals_Vqf32(x_qf32_v);
|
||||
|
||||
// y = E4 + E5 * x;
|
||||
E_const = Q6_V_vsplat_R(EXP_COEFF_5);
|
||||
y_v = Q6_Vqf32_vmpy_VsfVsf(E_const, x_v);
|
||||
E_const = Q6_V_vsplat_R(EXP_COEFF_4);
|
||||
y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const);
|
||||
y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v);
|
||||
|
||||
// y = E3 + y * x;
|
||||
E_const = Q6_V_vsplat_R(EXP_COEFF_3);
|
||||
y_v = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, x_qf32_v);
|
||||
y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const);
|
||||
y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v);
|
||||
|
||||
// y = E2 + y * x;
|
||||
E_const = Q6_V_vsplat_R(EXP_COEFF_2);
|
||||
y_v = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, x_qf32_v);
|
||||
y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const);
|
||||
y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v);
|
||||
|
||||
// y = E1 + y * x;
|
||||
E_const = Q6_V_vsplat_R(EXP_COEFF_1);
|
||||
y_v = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, x_qf32_v);
|
||||
y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const);
|
||||
y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v);
|
||||
|
||||
// y = E0 + y * x;
|
||||
E_const = Q6_V_vsplat_R(EXP_COEFF_0);
|
||||
y_v = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, x_qf32_v);
|
||||
y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const);
|
||||
y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v);
|
||||
|
||||
// y = x + y * z;
|
||||
y_v = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, z_qf32_v);
|
||||
y_v = Q6_Vqf32_vadd_Vqf32Vqf32(y_v, x_qf32_v);
|
||||
y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v);
|
||||
|
||||
// y = y + 1.0;
|
||||
y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, Q6_V_vsplat_R(EXP_ONE));
|
||||
|
||||
// insert exponents
|
||||
// y = ldexpf(y, k);
|
||||
// y_v += k_v; // qf32
|
||||
// modify exponent
|
||||
|
||||
y_v = Q6_Vsf_equals_Vqf32(y_v);
|
||||
|
||||
// add k_v to the exponent of y_v
|
||||
HVX_Vector y_v_exponent = Q6_Vw_vasl_VwR(y_v, 1);
|
||||
|
||||
y_v_exponent = Q6_Vuw_vlsr_VuwR(y_v_exponent, IEEE_VSF_MANTLEN + 1);
|
||||
y_v_exponent = Q6_Vw_vadd_VwVw(k_v, y_v_exponent);
|
||||
|
||||
// exponent cannot be negative; if overflow is detected, result is set to zero
|
||||
HVX_VectorPred qy_v_negative_exponent = Q6_Q_vcmp_gt_VwVw(zero_v, y_v_exponent);
|
||||
|
||||
y_v = Q6_Vw_vaslacc_VwVwR(y_v, k_v, IEEE_VSF_MANTLEN);
|
||||
|
||||
y_v = Q6_V_vmux_QVV(qy_v_negative_exponent, zero_v, y_v);
|
||||
|
||||
return y_v;
|
||||
}
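A scalar reference of the same range reduction and polynomial (illustrative only, not part of the patch; <math.h> assumed). The HVX routine above vectorizes exactly this, then folds k into the exponent field instead of calling ldexpf:

#include <math.h>

static float exp_f32_ref(float x) {
    x = fminf(fmaxf(x, -20.0f), 20.0f);            // clamp to the supported range
    const float f = floorf(x * 1.4426950408f);     // f = floor(x * log2(e))
    const int   k = (int) f;
    const float e = x - f * 0.6931471805f;         // epsilon = x - f*ln(2)
    // exp(epsilon) ~= 1 + e + e^2 * (E0 + E1*e + E2*e^2 + E3*e^3 + E4*e^4 + E5*e^5)
    float y = 0.000198757f;                        // E5
    y = y * e + 0.0013982f;                        // E4
    y = y * e + 0.00833345f;                       // E3
    y = y * e + 0.0416658f;                        // E2
    y = y * e + 0.16666667f;                       // E1
    y = y * e + 0.5f;                              // E0
    y = y * e * e + e + 1.0f;
    return ldexpf(y, k);                           // exp(x) = exp(epsilon) * 2^k
}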
|
||||
|
||||
static inline HVX_Vector hvx_vec_exp_f32_guard(HVX_Vector in_vec, HVX_Vector max_exp, HVX_Vector inf) {
|
||||
const HVX_VectorPred pred0 = Q6_Q_vcmp_gt_VsfVsf(in_vec, max_exp);
|
||||
|
||||
HVX_Vector out = hvx_vec_exp_f32(in_vec);
|
||||
|
||||
return Q6_V_vmux_QVV(pred0, inf, out);
|
||||
}
|
||||
|
||||
static inline void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, bool negate) {
|
||||
int left_over = num_elems & (VLEN_FP32 - 1);
|
||||
int num_elems_whole = num_elems - left_over;
|
||||
|
||||
int unaligned_addr = 0;
|
||||
int unaligned_loop = 0;
|
||||
if ((0 == hex_is_aligned((void *) src, VLEN)) || (0 == hex_is_aligned((void *) dst, VLEN))) {
|
||||
unaligned_addr = 1;
|
||||
}
|
||||
// assert((0 == unaligned_addr) || (0 == num_elems_whole));
|
||||
if ((1 == unaligned_addr) && (num_elems_whole != 0)) {
|
||||
unaligned_loop = 1;
|
||||
}
|
||||
|
||||
HVX_Vector vec_out = Q6_V_vzero();
|
||||
|
||||
static const float kInf = INFINITY;
|
||||
static const float kMaxExp = 88.02f; // log(INF)
|
||||
|
||||
const HVX_Vector max_exp = hvx_vec_splat_f32(kMaxExp);
|
||||
const HVX_Vector inf = hvx_vec_splat_f32(kInf);
|
||||
|
||||
if (0 == unaligned_loop) {
|
||||
HVX_Vector * p_vec_in1 = (HVX_Vector *) src;
|
||||
HVX_Vector * p_vec_out = (HVX_Vector *) dst;
|
||||
|
||||
#pragma unroll(4)
|
||||
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
|
||||
if (true == negate) {
|
||||
HVX_Vector neg_vec_in = hvx_vec_neg_f32(*p_vec_in1++);
|
||||
*p_vec_out++ = hvx_vec_exp_f32_guard(neg_vec_in, max_exp, inf);
|
||||
} else {
|
||||
*p_vec_out++ = hvx_vec_exp_f32_guard(*p_vec_in1++, max_exp, inf);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
#pragma unroll(4)
|
||||
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
|
||||
HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32);
|
||||
|
||||
if (true == negate) {
|
||||
HVX_Vector neg_vec_in = hvx_vec_neg_f32(in);
|
||||
*(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_exp_f32_guard(neg_vec_in, max_exp, inf);
|
||||
} else {
|
||||
*(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_exp_f32_guard(in, max_exp, inf);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (left_over > 0) {
|
||||
const float * srcf = (float *) src + num_elems_whole;
|
||||
float * dstf = (float *) dst + num_elems_whole;
|
||||
|
||||
HVX_Vector in = *(HVX_UVector *) srcf;
|
||||
|
||||
if (true == negate) {
|
||||
HVX_Vector neg_vec_in = hvx_vec_neg_f32(in);
|
||||
|
||||
vec_out = hvx_vec_exp_f32_guard(neg_vec_in, max_exp, inf);
|
||||
} else {
|
||||
vec_out = hvx_vec_exp_f32_guard(in, max_exp, inf);
|
||||
}
|
||||
|
||||
hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, vec_out);
|
||||
}
|
||||
}
|
||||
|
||||
#endif /* HVX_EXP_H */
|
||||
|
|
@ -0,0 +1,100 @@
|
|||
#ifndef HVX_FLOOR_H
|
||||
#define HVX_FLOOR_H
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include "hvx-base.h"
|
||||
|
||||
#define IEEE_VSF_EXPLEN (8)
|
||||
#define IEEE_VSF_EXPBIAS (127)
|
||||
#define IEEE_VSF_EXPMASK (0xFF)
|
||||
#define IEEE_VSF_MANTLEN (23)
|
||||
#define IEEE_VSF_MANTMASK (0x7FFFFF)
|
||||
#define IEEE_VSF_MIMPMASK (0x800000)
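// Truncate toward zero using the raw IEEE-754 fields: restore the implicit 1,
// shift the mantissa right by (IEEE_VSF_MANTLEN - exp) to drop the fraction,
// zero out |x| < 1 and re-apply the sign. The result is returned as signed
// 32-bit integer lanes.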
|
||||
|
||||
static inline HVX_Vector hvx_vec_truncate_f32(HVX_Vector in_vec) {
|
||||
HVX_Vector mask_mant_v = Q6_V_vsplat_R(IEEE_VSF_MANTMASK);
|
||||
HVX_Vector mask_impl_v = Q6_V_vsplat_R(IEEE_VSF_MIMPMASK);
|
||||
HVX_Vector const_zero_v = Q6_V_vzero();
|
||||
|
||||
HVX_VectorPred q_negative = Q6_Q_vcmp_gt_VwVw(const_zero_v, in_vec);
|
||||
|
||||
HVX_Vector expval_v = in_vec >> IEEE_VSF_MANTLEN;
|
||||
expval_v &= IEEE_VSF_EXPMASK;
|
||||
expval_v -= IEEE_VSF_EXPBIAS;
|
||||
|
||||
// negative exp == fractional value
|
||||
HVX_VectorPred q_negexp = Q6_Q_vcmp_gt_VwVw(const_zero_v, expval_v);
|
||||
|
||||
HVX_Vector rshift_v = IEEE_VSF_MANTLEN - expval_v; // fractional bits - exp shift
|
||||
|
||||
HVX_Vector mant_v = in_vec & mask_mant_v; // obtain mantissa
|
||||
HVX_Vector vout = Q6_Vw_vadd_VwVw(mant_v, mask_impl_v); // add implicit 1.0
|
||||
|
||||
vout = Q6_Vw_vasr_VwVw(vout, rshift_v); // shift to obtain truncated integer
|
||||
vout = Q6_V_vmux_QVV(q_negexp, const_zero_v, vout); // expval<0 -> 0
|
||||
|
||||
HVX_Vector neg_vout = -vout;
|
||||
|
||||
vout = Q6_V_vmux_QVV(q_negative, neg_vout, vout); // handle negatives
|
||||
|
||||
return (vout);
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_floor_f32(HVX_Vector in_vec) {
|
||||
HVX_Vector mask_mant_v = Q6_V_vsplat_R(IEEE_VSF_MANTMASK);
|
||||
HVX_Vector mask_impl_v = Q6_V_vsplat_R(IEEE_VSF_MIMPMASK);
|
||||
HVX_Vector const_mnlen_v = Q6_V_vsplat_R(IEEE_VSF_MANTLEN);
|
||||
HVX_Vector const_zero_v = Q6_V_vzero();
|
||||
HVX_Vector const_negone_v = Q6_V_vsplat_R(0xbf800000); // -1 IEEE vsf
|
||||
|
||||
HVX_VectorPred q_negative = Q6_Q_vcmp_gt_VwVw(const_zero_v, in_vec);
|
||||
|
||||
HVX_Vector expval_v = in_vec >> IEEE_VSF_MANTLEN;
|
||||
expval_v &= IEEE_VSF_EXPMASK;
|
||||
expval_v -= IEEE_VSF_EXPBIAS;
|
||||
|
||||
HVX_VectorPred q_negexp = Q6_Q_vcmp_gt_VwVw(const_zero_v, expval_v);
|
||||
HVX_VectorPred q_expltmn = Q6_Q_vcmp_gt_VwVw(const_mnlen_v, expval_v);
|
||||
HVX_VectorPred q_negexp_pos = Q6_Q_vcmp_gtand_QVwVw(q_negexp, in_vec, const_zero_v);
|
||||
HVX_VectorPred q_negexp_neg = Q6_Q_vcmp_gtand_QVwVw(q_negexp, const_zero_v, in_vec);
|
||||
|
||||
// if expval < 0 (q_negexp) // <0, floor is 0
|
||||
// if vin > 0
|
||||
// floor = 0
|
||||
// if vin < 0
|
||||
// floor = -1
|
||||
// if expval < mant_len (q_expltmn) // >0, but fraction may exist
|
||||
// get sign (q_negative)
|
||||
// mask >> expval // fraction bits to mask off
|
||||
// vout = ~(mask) // apply mask to remove fraction
|
||||
// if (qneg) // negative floor is one less (more, sign bit for neg)
|
||||
// vout += ((impl_mask) >> expval)
|
||||
// if (mask && vin)
|
||||
// vout = vin
|
||||
// else // already an integer
|
||||
// ; // no change
|
||||
|
||||
// compute floor
|
||||
mask_mant_v >>= expval_v;
|
||||
HVX_Vector neg_addin_v = mask_impl_v >> expval_v;
|
||||
HVX_Vector vout_neg_addin = Q6_Vw_vadd_VwVw(in_vec, neg_addin_v);
|
||||
HVX_Vector vout = Q6_V_vmux_QVV(q_negative, vout_neg_addin, in_vec);
|
||||
|
||||
HVX_Vector mask_chk_v = Q6_V_vand_VV(in_vec, mask_mant_v); // chk if bits set
|
||||
HVX_VectorPred q_integral = Q6_Q_vcmp_eq_VwVw(const_zero_v, mask_chk_v);
|
||||
|
||||
HVX_Vector not_mask_v = Q6_V_vnot_V(mask_mant_v); // frac bits to clear
|
||||
HVX_Vector vfrfloor_v = Q6_V_vand_VV(vout, not_mask_v); // clear frac bits
|
||||
|
||||
vout = in_vec;
|
||||
vout = Q6_V_vmux_QVV(q_expltmn, vfrfloor_v, vout); // expval<mant
|
||||
vout = Q6_V_vmux_QVV(q_integral, in_vec, vout); // integral values
|
||||
vout = Q6_V_vmux_QVV(q_negexp_pos, const_zero_v, vout); // expval<0 x>0 -> 0
|
||||
vout = Q6_V_vmux_QVV(q_negexp_neg, const_negone_v, vout); // expval<0 x<0 -> -1
|
||||
|
||||
return vout;
|
||||
}
|
||||
|
||||
#endif /* HVX_FLOOR_H */
|
||||
|
|
@ -1,72 +0,0 @@
|
|||
#pragma clang diagnostic ignored "-Wunused-variable"
|
||||
#pragma clang diagnostic ignored "-Wunused-function"
|
||||
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
|
||||
|
||||
#include <hexagon_protos.h>
|
||||
#include <hexagon_types.h>
|
||||
#include <math.h>
|
||||
#include <string.h>
|
||||
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-dma.h"
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
#include "hvx-utils.h"
|
||||
#include "ops-utils.h"
|
||||
|
||||
static inline HVX_Vector hvx_vec_inverse_fp32_guard(HVX_Vector v_sf, HVX_Vector nan_inf_mask) {
|
||||
HVX_Vector out = hvx_vec_inverse_fp32(v_sf);
|
||||
|
||||
HVX_Vector masked_out = Q6_V_vand_VV(out, nan_inf_mask);
|
||||
const HVX_VectorPred pred = Q6_Q_vcmp_eq_VwVw(nan_inf_mask, masked_out);
|
||||
|
||||
return Q6_V_vmux_QVV(pred, Q6_V_vzero(), out);
|
||||
}
|
||||
|
||||
void hvx_inverse_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems) {
|
||||
int left_over = num_elems & (VLEN_FP32 - 1);
|
||||
int num_elems_whole = num_elems - left_over;
|
||||
|
||||
int unaligned_addr = 0;
|
||||
int unaligned_loop = 0;
|
||||
if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) {
|
||||
FARF(HIGH, "hvx_inverse_f32: unaligned address in hvx op, possibly slower execution\n");
|
||||
unaligned_addr = 1;
|
||||
}
|
||||
// assert((0 == unaligned_addr) || (0 == num_elems_whole));
|
||||
if ((1 == unaligned_addr) && (num_elems_whole != 0)) {
|
||||
unaligned_loop = 1;
|
||||
FARF(HIGH, "hvx_inverse_f32: unaligned loop in hvx op, possibly slower execution\n");
|
||||
}
|
||||
|
||||
static const uint32_t kNanInfMask = 0x7f800000;
|
||||
const HVX_Vector nan_inf_mask = Q6_V_vsplat_R(kNanInfMask);
|
||||
|
||||
if (0 == unaligned_loop) {
|
||||
HVX_Vector * p_vec_in = (HVX_Vector *) src;
|
||||
HVX_Vector * p_vec_out = (HVX_Vector *) dst;
|
||||
|
||||
#pragma unroll(4)
|
||||
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
|
||||
*p_vec_out++ = hvx_vec_inverse_fp32_guard(*p_vec_in++, nan_inf_mask);
|
||||
}
|
||||
} else {
|
||||
#pragma unroll(4)
|
||||
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
|
||||
HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32);
|
||||
*(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_inverse_fp32_guard(in, nan_inf_mask);
|
||||
}
|
||||
}
|
||||
|
||||
if (left_over > 0) {
|
||||
const float * srcf = (float *) src + num_elems_whole;
|
||||
float * dstf = (float *) dst + num_elems_whole;
|
||||
|
||||
HVX_Vector in = *(HVX_UVector *) srcf;
|
||||
HVX_Vector out = hvx_vec_inverse_fp32_guard(in, nan_inf_mask);
|
||||
|
||||
hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, out);
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,176 @@
|
|||
#ifndef HVX_INVERSE_H
|
||||
#define HVX_INVERSE_H
|
||||
|
||||
#include <HAP_farf.h>
|
||||
|
||||
#include <math.h>
|
||||
#include <string.h>
|
||||
#include <assert.h>
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include "hvx-base.h"
|
||||
|
||||
// ====================================================
|
||||
// FUNCTION: 1/(x+1) y(0) = 1, y(0.5) = 0.6667, y(1) = 0.5
|
||||
// Order:3; continuity: True; Ends forced: True
|
||||
// Mode: unsigned; Result fractional bits: 14
|
||||
// Peak Error: 1.1295e-04 Rms Error: 2.8410e-05 Mean Error: 1.1370e-05
|
||||
// 32769 -32706 31252 -10589
|
||||
// 32590 -30635 22793 -4493
|
||||
// 32066 -27505 16481 -2348
|
||||
// 31205 -24054 11849 -1306
|
||||
|
||||
static inline HVX_Vector hvx_vec_recip_xp1_O3_unsigned(HVX_Vector vx) {
|
||||
// input is 0..0xffff representing 0.0 .. 1.0
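// Horner evaluation of the order-3 spline: each 64-bit immediate packs the four
// per-segment 16-bit coefficients from the table above (one coefficient column
// per immediate); vlut4/vmpa/vmps select the segment for each lane from vx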
|
||||
HVX_Vector p;
|
||||
p = Q6_Vh_vlut4_VuhPh(vx, 0xFAE6F6D4EE73D6A3ull);
|
||||
p = Q6_Vh_vmpa_VhVhVuhPuh_sat(p, vx, 0x2E49406159097A14ull);
|
||||
p = Q6_Vh_vmps_VhVhVuhPuh_sat(p, vx, 0x5DF66B7177AB7FC2ull);
|
||||
p = Q6_Vh_vmpa_VhVhVuhPuh_sat(p, vx, 0x79E57D427F4E8001ull);
|
||||
return p; // signed result, 14 fractional bits
|
||||
}
|
||||
|
||||
// Find reciprocal of fp16.
|
||||
// (1) first, convert to fp32, multiplying by 1.0; this is done to
|
||||
// handle denormals. Ignoring sign and zero, result should be at
|
||||
// least 5.9604645e-08 (32-bit code 0x33800000) and at most 131008 (0x47ffe000)
|
||||
// (exponent in range [103,143])
|
||||
// (2) extract the mantissa into 16-bit unsigned; find reciprocal using a fitted poly
|
||||
// (3) put this, along with '253-exp' (exp from (1)) together to make an qf32
|
||||
// (4) convert that to fp16
|
||||
// (5) put sign back in. Also, if the original value (w/o sign) was <0x81, replace
|
||||
// the result with the max value.
|
||||
static inline HVX_Vector hvx_vec_inverse_f16(HVX_Vector vals) {
|
||||
HVX_Vector em_mask = Q6_Vh_vsplat_R(0x7FFF);
|
||||
HVX_Vector avals = Q6_V_vand_VV(vals, em_mask);
|
||||
HVX_VectorPred is_neg = Q6_Q_vcmp_gt_VhVh(avals, vals);
|
||||
// too small to take 1/x of? for 'standard' fp16, the threshold would be 0x101
|
||||
HVX_VectorPred is_small = Q6_Q_vcmp_gt_VhVh(Q6_Vh_vsplat_R(0x101), avals);
|
||||
|
||||
HVX_VectorPair to_qf32 = Q6_Wqf32_vmpy_VhfVhf(avals, Q6_Vh_vsplat_R(0x3C00)); // *1.0
|
||||
HVX_Vector to_f32_0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(to_qf32));
|
||||
HVX_Vector to_f32_1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(to_qf32));
|
||||
|
||||
// bits 22..13 contain the mantissa now (w/o hidden bit); move to bit 14..5 of a 16-bit vector
|
||||
HVX_Vector mant_u16 = Q6_Vh_vshuffo_VhVh(Q6_Vw_vasl_VwR(to_f32_1, 9), Q6_Vw_vasl_VwR(to_f32_0, 9));
|
||||
// likewise extract the upper 16 from each, containing the exponents in range 103..142
|
||||
HVX_Vector exp_u16 = Q6_Vh_vshuffo_VhVh(to_f32_1, to_f32_0);
|
||||
//Get exponent in IEEE 32-bit representation
|
||||
exp_u16 = Q6_Vuh_vlsr_VuhR(exp_u16, 7);
|
||||
|
||||
// so, mant_u16 contains an unbiased mantissa in upper 10 bits of each u16 lane
|
||||
// We can consider it to be x-1.0, with 16 fractional bits, where 'x' is in range [1.0,2.0)
|
||||
// Use poly to transform to 1/x, with 14 fractional bits
|
||||
//
|
||||
HVX_Vector rm = hvx_vec_recip_xp1_O3_unsigned(mant_u16);
|
||||
|
||||
HVX_Vector vcl0 = Q6_Vuh_vcl0_Vuh(rm); //count leading zeros
|
||||
|
||||
// Get mantissa for 16-bit representation
|
||||
HVX_Vector mant_recip = Q6_V_vand_VV(Q6_Vh_vasr_VhR(Q6_Vh_vasl_VhVh(rm, vcl0), 5), Q6_Vh_vsplat_R(0x03FF));
|
||||
|
||||
//Compute Reciprocal Exponent
|
||||
HVX_Vector exp_recip =
|
||||
Q6_Vh_vsub_VhVh(Q6_Vh_vsub_VhVh(Q6_Vh_vsplat_R(254), exp_u16), Q6_Vh_vsub_VhVh(vcl0, Q6_Vh_vsplat_R(1)));
|
||||
//Convert it for 16-bit representation
|
||||
exp_recip = Q6_Vh_vadd_VhVh_sat(Q6_Vh_vsub_VhVh(exp_recip, Q6_Vh_vsplat_R(127)), Q6_Vh_vsplat_R(15));
|
||||
exp_recip = Q6_Vh_vasl_VhR(exp_recip, 10);
|
||||
|
||||
//Merge exponent and mantissa for reciprocal
|
||||
HVX_Vector recip = Q6_V_vor_VV(exp_recip, mant_recip);
|
||||
// map 'small' inputs to standard largest value 0x7bff
|
||||
recip = Q6_V_vmux_QVV(is_small, Q6_Vh_vsplat_R(0x7bff), recip);
|
||||
// add sign back
|
||||
recip = Q6_V_vandor_VQR(recip, is_neg, 0x80008000);
|
||||
return recip;
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_inverse_f32(HVX_Vector v_sf) {
|
||||
HVX_Vector inv_aprox_sf = Q6_V_vsplat_R(0x7EEEEBB3);
|
||||
HVX_Vector two_sf = hvx_vec_splat_f32(2.0);
|
||||
|
||||
// First approximation
|
||||
HVX_Vector i_sf = Q6_Vw_vsub_VwVw(inv_aprox_sf, v_sf);
|
||||
|
||||
HVX_Vector r_qf;
|
||||
|
||||
// Refine
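// Newton-Raphson for the reciprocal: r <- r * (2 - x * r); each step roughly
// doubles the number of correct bits, so three passes are enough for fp32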
|
||||
r_qf = Q6_Vqf32_vmpy_VsfVsf(
|
||||
i_sf, Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(two_sf, Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(i_sf, v_sf)))));
|
||||
r_qf = Q6_Vqf32_vmpy_Vqf32Vqf32(
|
||||
r_qf, Q6_Vqf32_vsub_VsfVsf(two_sf, Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(r_qf), v_sf))));
|
||||
r_qf = Q6_Vqf32_vmpy_Vqf32Vqf32(
|
||||
r_qf, Q6_Vqf32_vsub_VsfVsf(two_sf, Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(r_qf), v_sf))));
|
||||
|
||||
return Q6_Vsf_equals_Vqf32(r_qf);
|
||||
}
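// Guard: if the computed reciprocal has all exponent bits set (Inf/NaN, e.g.
// x == 0 or an overflowing input), the lane is replaced with 0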
|
||||
|
||||
static inline HVX_Vector hvx_vec_inverse_f32_guard(HVX_Vector v_sf, HVX_Vector nan_inf_mask) {
|
||||
HVX_Vector out = hvx_vec_inverse_f32(v_sf);
|
||||
|
||||
HVX_Vector masked_out = Q6_V_vand_VV(out, nan_inf_mask);
|
||||
const HVX_VectorPred pred = Q6_Q_vcmp_eq_VwVw(nan_inf_mask, masked_out);
|
||||
|
||||
return Q6_V_vmux_QVV(pred, Q6_V_vzero(), out);
|
||||
}
|
||||
|
||||
#define hvx_inverse_f32_loop_body(dst_type, src_type, vec_store) \
|
||||
do { \
|
||||
dst_type * restrict vdst = (dst_type *) dst; \
|
||||
src_type * restrict vsrc = (src_type *) src; \
|
||||
\
|
||||
const HVX_Vector nan_inf_mask = Q6_V_vsplat_R(0x7f800000); \
|
||||
\
|
||||
const uint32_t nvec = n / VLEN_FP32; \
|
||||
const uint32_t nloe = n % VLEN_FP32; \
|
||||
\
|
||||
uint32_t i = 0; \
|
||||
\
|
||||
_Pragma("unroll(4)") \
|
||||
for (; i < nvec; i++) { \
|
||||
vdst[i] = hvx_vec_inverse_f32_guard(vsrc[i], nan_inf_mask); \
|
||||
} \
|
||||
if (nloe) { \
|
||||
HVX_Vector v = hvx_vec_inverse_f32_guard(vsrc[i], nan_inf_mask); \
|
||||
vec_store((void *) &vdst[i], nloe * SIZEOF_FP32, v); \
|
||||
} \
|
||||
} while(0)
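// Suffix convention for the variants below: first letter is the dst alignment,
// second is the src alignment (a = 128-byte aligned, u = unaligned)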
|
||||
|
||||
static inline void hvx_inverse_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
assert((unsigned long) src % 128 == 0);
|
||||
hvx_inverse_f32_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
|
||||
}
|
||||
|
||||
static inline void hvx_inverse_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
hvx_inverse_f32_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);
|
||||
}
|
||||
|
||||
static inline void hvx_inverse_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
assert((unsigned long) src % 128 == 0);
|
||||
hvx_inverse_f32_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
|
||||
}
|
||||
|
||||
static inline void hvx_inverse_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
hvx_inverse_f32_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
|
||||
}
|
||||
|
||||
static inline void hvx_inverse_f32(uint8_t * restrict dst, uint8_t * restrict src, const int num_elems) {
|
||||
if ((unsigned long) dst % 128 == 0) {
|
||||
if ((unsigned long) src % 128 == 0) {
|
||||
hvx_inverse_f32_aa(dst, src, num_elems);
|
||||
} else {
|
||||
hvx_inverse_f32_au(dst, src, num_elems);
|
||||
}
|
||||
} else {
|
||||
if ((unsigned long) src % 128 == 0) {
|
||||
hvx_inverse_f32_ua(dst, src, num_elems);
|
||||
} else {
|
||||
hvx_inverse_f32_uu(dst, src, num_elems);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif // HVX_INVERSE_H
|
||||
|
|
@ -0,0 +1,225 @@
|
|||
#ifndef HVX_REDUCE_H
|
||||
#define HVX_REDUCE_H
|
||||
|
||||
#include <math.h>
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#include <assert.h>
|
||||
|
||||
#include "hex-utils.h"
|
||||
#include "hvx-base.h"
|
||||
#include "hvx-types.h"
|
||||
|
||||
static inline HVX_Vector hvx_vec_reduce_sum_n_i32(HVX_Vector in, unsigned int n) {
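// log2(n) rotate-and-add steps; for a full vector every lane ends up holding
// the total, so the caller can extract any lane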
|
||||
unsigned int total = n * 4; // total vec nbytes
|
||||
unsigned int width = 4; // int32
|
||||
|
||||
HVX_Vector sum = in, sum_t;
|
||||
while (width < total) {
|
||||
sum_t = Q6_V_vror_VR(sum, width); // rotate right
|
||||
sum = Q6_Vw_vadd_VwVw(sum_t, sum); // elementwise sum
|
||||
width = width << 1;
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_reduce_sum_i32(HVX_Vector in) {
|
||||
return hvx_vec_reduce_sum_n_i32(in, 32);
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_reduce_sum_n_qf32(HVX_Vector in, unsigned int n) {
|
||||
unsigned int total = n * 4; // total vec nbytes
|
||||
unsigned int width = 4; // fp32 nbytes
|
||||
|
||||
HVX_Vector sum = in, sum_t;
|
||||
while (width < total) {
|
||||
sum_t = Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum), width); // rotate right
|
||||
sum = Q6_Vqf32_vadd_Vqf32Vsf(sum, sum_t); // elementwise sum
|
||||
width = width << 1;
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_reduce_sum_qf32(HVX_Vector in) {
|
||||
return hvx_vec_reduce_sum_n_qf32(in, 32);
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_reduce_sum_n_f32(HVX_Vector in, unsigned int n) {
|
||||
unsigned int total = n * 4; // total vec nbytes
|
||||
unsigned int width = 4; // fp32 nbytes
|
||||
|
||||
HVX_Vector sum = in, sum_t;
|
||||
while (width < total) {
|
||||
sum_t = Q6_V_vror_VR(sum, width); // rotate right
|
||||
sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(sum, sum_t)); // elementwise sum
|
||||
width = width << 1;
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_reduce_sum_f32(HVX_Vector in) {
|
||||
return hvx_vec_reduce_sum_n_f32(in, 32);
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_reduce_max_f16(HVX_Vector in) {
|
||||
unsigned total = 128; // total vec nbytes
|
||||
unsigned width = 2; // fp16 nbytes
|
||||
|
||||
HVX_Vector _max = in, _max_t;
|
||||
while (width < total) {
|
||||
_max_t = Q6_V_vror_VR(_max, width); // rotate right
|
||||
_max = Q6_Vhf_vmax_VhfVhf(_max_t, _max); // elementwise max
|
||||
width = width << 1;
|
||||
}
|
||||
|
||||
return _max;
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_reduce_max2_f16(HVX_Vector in, HVX_Vector _max) {
|
||||
unsigned total = 128; // total vec nbytes
|
||||
unsigned width = 2; // fp16 nbytes
|
||||
|
||||
HVX_Vector _max_t;
|
||||
|
||||
_max = Q6_Vhf_vmax_VhfVhf(in, _max);
|
||||
while (width < total) {
|
||||
_max_t = Q6_V_vror_VR(_max, width); // rotate right
|
||||
_max = Q6_Vhf_vmax_VhfVhf(_max_t, _max); // elementwise max
|
||||
width = width << 1;
|
||||
}
|
||||
|
||||
return _max;
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_reduce_max_f32(HVX_Vector in) {
|
||||
unsigned total = 128; // total vec nbytes
|
||||
unsigned width = 4; // fp32 nbytes
|
||||
|
||||
HVX_Vector _max = in, _max_t;
|
||||
while (width < total) {
|
||||
_max_t = Q6_V_vror_VR(_max, width); // rotate right
|
||||
_max = Q6_Vsf_vmax_VsfVsf(_max_t, _max); // elementwise max
|
||||
width = width << 1;
|
||||
}
|
||||
|
||||
return _max;
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_reduce_max2_f32(HVX_Vector in, HVX_Vector _max) {
|
||||
unsigned total = 128; // total vec nbytes
|
||||
unsigned width = 4; // fp32 nbytes
|
||||
|
||||
HVX_Vector _max_t;
|
||||
|
||||
_max = Q6_Vsf_vmax_VsfVsf(in, _max);
|
||||
while (width < total) {
|
||||
_max_t = Q6_V_vror_VR(_max, width); // rotate right
|
||||
_max = Q6_Vsf_vmax_VsfVsf(_max_t, _max); // elementwise max
|
||||
width = width << 1;
|
||||
}
|
||||
|
||||
return _max;
|
||||
}
|
||||
|
||||
#define hvx_reduce_loop_body(src_type, init_vec, pad_vec, vec_op, reduce_op, scalar_reduce) \
|
||||
do { \
|
||||
src_type * restrict vsrc = (src_type *) src; \
|
||||
HVX_Vector acc = init_vec; \
|
||||
\
|
||||
const uint32_t elem_size = sizeof(float); \
|
||||
const uint32_t epv = 128 / elem_size; \
|
||||
const uint32_t nvec = num_elems / epv; \
|
||||
const uint32_t nloe = num_elems % epv; \
|
||||
\
|
||||
uint32_t i = 0; \
|
||||
_Pragma("unroll(4)") \
|
||||
for (; i < nvec; i++) { \
|
||||
acc = vec_op(acc, vsrc[i]); \
|
||||
} \
|
||||
if (nloe) { \
|
||||
const float * srcf = (const float *) src + i * epv; \
|
||||
HVX_Vector in = *(HVX_UVector *) srcf; \
|
||||
HVX_Vector temp = Q6_V_valign_VVR(in, pad_vec, nloe * elem_size); \
|
||||
acc = vec_op(acc, temp); \
|
||||
} \
|
||||
HVX_Vector v = reduce_op(acc); \
|
||||
return scalar_reduce(v); \
|
||||
} while(0)
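// Tail handling above: valign keeps the nloe valid leftover elements and pads the
// rest of the vector from pad_vec (the op's identity: src[0] for max, zero for the
// sums), so the padding never changes the reduction result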
|
||||
|
||||
#define HVX_REDUCE_MAX_OP(acc, val) Q6_Vsf_vmax_VsfVsf(acc, val)
|
||||
#define HVX_REDUCE_SUM_OP(acc, val) Q6_Vqf32_vadd_VsfVsf(Q6_Vsf_equals_Vqf32(acc), val)
|
||||
#define HVX_SUM_SQ_OP(acc, val) Q6_Vqf32_vadd_Vqf32Vqf32(acc, Q6_Vqf32_vmpy_VsfVsf(val, val))
|
||||
#define HVX_REDUCE_MAX_SCALAR(v) hvx_vec_get_f32(v)
|
||||
#define HVX_REDUCE_SUM_SCALAR(v) hvx_vec_get_f32(Q6_Vsf_equals_Vqf32(v))
|
||||
|
||||
// Max variants
|
||||
|
||||
static inline float hvx_reduce_max_f32_a(const uint8_t * restrict src, const int num_elems) {
|
||||
HVX_Vector init_vec = hvx_vec_splat_f32(((const float *) src)[0]);
|
||||
assert((unsigned long) src % 128 == 0);
|
||||
hvx_reduce_loop_body(HVX_Vector, init_vec, init_vec, HVX_REDUCE_MAX_OP, hvx_vec_reduce_max_f32, HVX_REDUCE_MAX_SCALAR);
|
||||
}
|
||||
|
||||
static inline float hvx_reduce_max_f32_u(const uint8_t * restrict src, const int num_elems) {
|
||||
HVX_Vector init_vec = hvx_vec_splat_f32(((const float *) src)[0]);
|
||||
hvx_reduce_loop_body(HVX_UVector, init_vec, init_vec, HVX_REDUCE_MAX_OP, hvx_vec_reduce_max_f32, HVX_REDUCE_MAX_SCALAR);
|
||||
}
|
||||
|
||||
static inline float hvx_reduce_max_f32(const uint8_t * restrict src, const int num_elems) {
|
||||
if (hex_is_aligned((void *) src, 128)) {
|
||||
return hvx_reduce_max_f32_a(src, num_elems);
|
||||
} else {
|
||||
return hvx_reduce_max_f32_u(src, num_elems);
|
||||
}
|
||||
}
|
||||
|
||||
// Sum variants
|
||||
|
||||
static inline float hvx_reduce_sum_f32_a(const uint8_t * restrict src, const int num_elems) {
|
||||
HVX_Vector init_vec = Q6_V_vsplat_R(0);
|
||||
assert((unsigned long) src % 128 == 0);
|
||||
hvx_reduce_loop_body(HVX_Vector, init_vec, init_vec, HVX_REDUCE_SUM_OP, hvx_vec_reduce_sum_qf32, HVX_REDUCE_SUM_SCALAR);
|
||||
}
|
||||
|
||||
static inline float hvx_reduce_sum_f32_u(const uint8_t * restrict src, const int num_elems) {
|
||||
HVX_Vector init_vec = Q6_V_vsplat_R(0);
|
||||
hvx_reduce_loop_body(HVX_UVector, init_vec, init_vec, HVX_REDUCE_SUM_OP, hvx_vec_reduce_sum_qf32, HVX_REDUCE_SUM_SCALAR);
|
||||
}
|
||||
|
||||
static inline float hvx_reduce_sum_f32(const uint8_t * restrict src, const int num_elems) {
|
||||
if (hex_is_aligned((void *) src, 128)) {
|
||||
return hvx_reduce_sum_f32_a(src, num_elems);
|
||||
} else {
|
||||
return hvx_reduce_sum_f32_u(src, num_elems);
|
||||
}
|
||||
}
|
||||
|
||||
// Sum of squares variants
|
||||
|
||||
static inline float hvx_sum_of_squares_f32_a(const uint8_t * restrict src, const int num_elems) {
|
||||
HVX_Vector init_vec = Q6_V_vsplat_R(0);
|
||||
assert((uintptr_t) src % 128 == 0);
|
||||
hvx_reduce_loop_body(HVX_Vector, init_vec, init_vec, HVX_SUM_SQ_OP, hvx_vec_reduce_sum_qf32, HVX_REDUCE_SUM_SCALAR);
|
||||
}
|
||||
|
||||
static inline float hvx_sum_of_squares_f32_u(const uint8_t * restrict src, const int num_elems) {
|
||||
HVX_Vector init_vec = Q6_V_vsplat_R(0);
|
||||
hvx_reduce_loop_body(HVX_UVector, init_vec, init_vec, HVX_SUM_SQ_OP, hvx_vec_reduce_sum_qf32, HVX_REDUCE_SUM_SCALAR);
|
||||
}
|
||||
|
||||
static inline float hvx_sum_of_squares_f32(const uint8_t * restrict src, const int num_elems) {
|
||||
if (hex_is_aligned((void *) src, 128)) {
|
||||
return hvx_sum_of_squares_f32_a(src, num_elems);
|
||||
} else {
|
||||
return hvx_sum_of_squares_f32_u(src, num_elems);
|
||||
}
|
||||
}
|
||||
|
||||
#undef hvx_reduce_loop_body
|
||||
#undef HVX_REDUCE_MAX_OP
|
||||
#undef HVX_REDUCE_SUM_OP
|
||||
#undef HVX_REDUCE_MAX_SCALAR
|
||||
#undef HVX_REDUCE_SUM_SCALAR
|
||||
#undef HVX_SUM_SQ_OP
|
||||
|
||||
#endif /* HVX_REDUCE_H */
|
||||
|
|
@ -0,0 +1,133 @@
|
|||
#ifndef HVX_SCALE_H
|
||||
#define HVX_SCALE_H
|
||||
|
||||
#include <assert.h>
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include "hvx-base.h"
|
||||
|
||||
#define hvx_scale_f32_loop_body(dst_type, src_type, vec_store) \
|
||||
do { \
|
||||
dst_type * restrict vdst = (dst_type *) dst; \
|
||||
src_type * restrict vsrc = (src_type *) src; \
|
||||
\
|
||||
HVX_Vector vs = hvx_vec_splat_f32(scale); \
|
||||
\
|
||||
const uint32_t elem_size = sizeof(float); \
|
||||
const uint32_t epv = 128 / elem_size; \
|
||||
const uint32_t nvec = n / epv; \
|
||||
const uint32_t nloe = n % epv; \
|
||||
\
|
||||
uint32_t i = 0; \
|
||||
\
|
||||
_Pragma("unroll(4)") \
|
||||
for (; i < nvec; ++i) { \
|
||||
HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs); \
|
||||
vdst[i] = Q6_Vsf_equals_Vqf32(v); \
|
||||
} \
|
||||
if (nloe) { \
|
||||
HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs); \
|
||||
vec_store((void *) &vdst[i], nloe * elem_size, Q6_Vsf_equals_Vqf32(v)); \
|
||||
} \
|
||||
} while(0)
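// dst[i] = src[i] * scale, computed with the qf32 multiply and converted back to
// IEEE sf on store; the tail is written with a partial store of nloe elements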
|
||||
|
||||
static inline void hvx_scale_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) {
|
||||
assert((size_t) dst % 128 == 0);
|
||||
assert((size_t) src % 128 == 0);
|
||||
hvx_scale_f32_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
|
||||
}
|
||||
|
||||
static inline void hvx_scale_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) {
|
||||
assert((size_t) dst % 128 == 0);
|
||||
hvx_scale_f32_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);
|
||||
}
|
||||
|
||||
static inline void hvx_scale_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) {
|
||||
assert((size_t) src % 128 == 0);
|
||||
hvx_scale_f32_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
|
||||
}
|
||||
|
||||
static inline void hvx_scale_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) {
|
||||
hvx_scale_f32_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
|
||||
}
|
||||
|
||||
static inline void hvx_scale_f32(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) {
|
||||
if (((size_t) dst & 127) == 0) {
|
||||
if (((size_t) src & 127) == 0) {
|
||||
hvx_scale_f32_aa(dst, src, n, scale);
|
||||
} else {
|
||||
hvx_scale_f32_au(dst, src, n, scale);
|
||||
}
|
||||
} else {
|
||||
if (((size_t) src & 127) == 0) {
|
||||
hvx_scale_f32_ua(dst, src, n, scale);
|
||||
} else {
|
||||
hvx_scale_f32_uu(dst, src, n, scale);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#define hvx_scale_offset_f32_loop_body(dst_type, src_type, vec_store) \
|
||||
do { \
|
||||
dst_type * restrict vdst = (dst_type *) dst; \
|
||||
src_type * restrict vsrc = (src_type *) src; \
|
||||
\
|
||||
HVX_Vector vs = hvx_vec_splat_f32(scale); \
|
||||
HVX_Vector vo = hvx_vec_splat_f32(offset); \
|
||||
\
|
||||
const uint32_t elem_size = sizeof(float); \
|
||||
const uint32_t epv = 128 / elem_size; \
|
||||
const uint32_t nvec = n / epv; \
|
||||
const uint32_t nloe = n % epv; \
|
||||
\
|
||||
uint32_t i = 0; \
|
||||
\
|
||||
_Pragma("unroll(4)") \
|
||||
for (; i < nvec; ++i) { \
|
||||
HVX_Vector v = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs), vo); \
|
||||
vdst[i] = Q6_Vsf_equals_Vqf32(v); \
|
||||
} \
|
||||
if (nloe) { \
|
||||
HVX_Vector v = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs), vo); \
|
||||
vec_store((void *) &vdst[i], nloe * elem_size, Q6_Vsf_equals_Vqf32(v)); \
|
||||
} \
|
||||
} while(0)
|
||||
|
||||
static inline void hvx_scale_offset_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) {
|
||||
assert((size_t) dst % 128 == 0);
|
||||
assert((size_t) src % 128 == 0);
|
||||
hvx_scale_offset_f32_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
|
||||
}
|
||||
|
||||
static inline void hvx_scale_offset_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) {
|
||||
assert((size_t) dst % 128 == 0);
|
||||
hvx_scale_offset_f32_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);
|
||||
}
|
||||
|
||||
static inline void hvx_scale_offset_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) {
|
||||
assert((size_t) src % 128 == 0);
|
||||
hvx_scale_offset_f32_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
|
||||
}
|
||||
|
||||
static inline void hvx_scale_offset_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) {
|
||||
hvx_scale_offset_f32_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
|
||||
}
|
||||
|
||||
static inline void hvx_scale_offset_f32(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) {
|
||||
if (((size_t) dst & 127) == 0) {
|
||||
if (((size_t) src & 127) == 0) {
|
||||
hvx_scale_offset_f32_aa(dst, src, n, scale, offset);
|
||||
} else {
|
||||
hvx_scale_offset_f32_au(dst, src, n, scale, offset);
|
||||
}
|
||||
} else {
|
||||
if (((size_t) src & 127) == 0) {
|
||||
hvx_scale_offset_f32_ua(dst, src, n, scale, offset);
|
||||
} else {
|
||||
hvx_scale_offset_f32_uu(dst, src, n, scale, offset);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif // HVX_SCALE_H
|
||||
|
|
@ -1,49 +0,0 @@
|
|||
#pragma clang diagnostic ignored "-Wunused-variable"
|
||||
#pragma clang diagnostic ignored "-Wunused-function"
|
||||
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
|
||||
|
||||
#include <hexagon_protos.h>
|
||||
#include <hexagon_types.h>
|
||||
#include <math.h>
|
||||
#include <string.h>
|
||||
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-dma.h"
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
#include "hvx-utils.h"
|
||||
#include "ops-utils.h"
|
||||
|
||||
#if 0
|
||||
// Reference algo used in hvx-utils
|
||||
static void fast_sigmoid_f32(const float* restrict src, float* restrict dst, const int num_elems)
|
||||
{
|
||||
const float c1 = 0.03138777;
|
||||
const float c2 = 0.276281267;
|
||||
const float c_log2f = 1.442695022;
|
||||
|
||||
int32_t store_ints[32];
|
||||
float store_floats[3][32];
|
||||
|
||||
for (int i = 0; i < num_elems; i++)
|
||||
{
|
||||
float v = src0[i];
|
||||
|
||||
v *= c_log2f*0.5;
|
||||
int intPart = (int)v;
|
||||
float x = (v - intPart);
|
||||
float xx = x * x;
|
||||
float v1 = c_log2f + c2 * xx;
|
||||
float v2 = x + xx * c1 * x;
|
||||
float v3 = (v2 + v1);
|
||||
*((int*)&v3) += intPart << 24;
|
||||
float v4 = v2 - v1;
|
||||
float v5 = v3 - v4;
|
||||
float res = v3 / v5;
|
||||
|
||||
dst[i] = res;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
|
@ -0,0 +1,114 @@
|
|||
#ifndef HVX_SIGMOID_H
|
||||
#define HVX_SIGMOID_H
|
||||
|
||||
#include "hvx-base.h"
|
||||
|
||||
#define FAST_SIGMOID_LOG2F (0x3fb8aa3b) // 1.442695022
|
||||
#define FAST_SIGMOID_C1 (0x3d009076) // 0.03138777
|
||||
#define FAST_SIGMOID_C2 (0x3e8d74bd) // 0.276281267
|
||||
#define FAST_SIGMOID_C3 (0x3f000000) // 0.5
|
||||
|
||||
static inline HVX_Vector hvx_vec_fast_sigmoid_f32(HVX_Vector v) {
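// HVX port of the scalar fast_sigmoid_f32 reference removed elsewhere in this
// change: scale x by log2(e)/2, split into integer and fractional parts, evaluate
// the two small polynomials v1/v2, fold the integer part back into the exponent
// of v3 = v1 + v2, then return v3 / (v3 - v4) via hvx_vec_inverse_f32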
|
||||
v = Q6_Vqf32_vmpy_VsfVsf(v, Q6_V_vsplat_R(FAST_SIGMOID_LOG2F));
|
||||
v = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(v), Q6_V_vsplat_R(FAST_SIGMOID_C3));
|
||||
|
||||
HVX_Vector in_int = hvx_vec_truncate_f32(Q6_Vsf_equals_Vqf32(v));
|
||||
HVX_Vector x = Q6_Vqf32_vsub_Vqf32Vsf(v, Q6_Vsf_equals_Vw(in_int));
|
||||
HVX_Vector xx = Q6_Vqf32_vmpy_Vqf32Vqf32(x, x);
|
||||
|
||||
HVX_Vector v1 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(xx), Q6_V_vsplat_R(FAST_SIGMOID_C2));
|
||||
v1 = Q6_Vqf32_vadd_Vqf32Vsf(v1, Q6_V_vsplat_R(FAST_SIGMOID_LOG2F));
|
||||
|
||||
HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(x), Q6_V_vsplat_R(FAST_SIGMOID_C1));
|
||||
v2 = Q6_Vqf32_vmpy_Vqf32Vqf32(v2, xx);
|
||||
v2 = Q6_Vqf32_vadd_Vqf32Vqf32(v2, x);
|
||||
|
||||
HVX_Vector v3 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vqf32(v2, v1));
|
||||
HVX_Vector v3_exponent = Q6_Vw_vasl_VwR(v3, 1);
|
||||
v3_exponent = Q6_Vuw_vlsr_VuwR(v3_exponent, 24);
|
||||
v3_exponent = Q6_Vw_vadd_VwVw(in_int, v3_exponent);
|
||||
v3 = Q6_Vw_vaslacc_VwVwR(v3, in_int, 24);
|
||||
|
||||
HVX_Vector v4 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_Vqf32Vqf32(v2, v1));
|
||||
HVX_Vector v5 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(v3, v4));
|
||||
|
||||
HVX_Vector res = hvx_vec_inverse_f32(v5);
|
||||
res = Q6_Vqf32_vmpy_VsfVsf(v3, res);
|
||||
|
||||
return Q6_Vsf_equals_Vqf32(res);
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_fast_sigmoid_f32_guard(HVX_Vector v,
|
||||
HVX_Vector one,
|
||||
HVX_Vector max_exp,
|
||||
HVX_Vector min_exp) {
|
||||
const HVX_VectorPred pred_max = Q6_Q_vcmp_gt_VsfVsf(max_exp, v);
|
||||
const HVX_VectorPred pred_min = Q6_Q_vcmp_gt_VsfVsf(v, min_exp);
|
||||
|
||||
HVX_Vector out = hvx_vec_fast_sigmoid_f32(v);
|
||||
out = Q6_V_vmux_QVV(pred_max, out, one);
|
||||
return Q6_V_vmux_QVV(pred_min, out, Q6_V_vzero());
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_tanh_f32(HVX_Vector x) {
|
||||
// tanh(x) = 2 * sigmoid(2x) - 1
|
||||
HVX_Vector two = hvx_vec_splat_f32(2.0f);
|
||||
HVX_Vector one = hvx_vec_splat_f32(1.0f);
|
||||
HVX_Vector x2 = Q6_Vqf32_vmpy_VsfVsf(x, two);
|
||||
|
||||
HVX_Vector max_exp = hvx_vec_splat_f32(87.f);
|
||||
HVX_Vector min_exp = hvx_vec_splat_f32(-87.f);
|
||||
|
||||
HVX_Vector sig2x = hvx_vec_fast_sigmoid_f32_guard(Q6_Vsf_equals_Vqf32(x2), one, max_exp, min_exp);
|
||||
|
||||
HVX_Vector res = Q6_Vqf32_vmpy_VsfVsf(sig2x, two);
|
||||
res = Q6_Vqf32_vsub_Vqf32Vsf(res, one);
|
||||
return Q6_Vsf_equals_Vqf32(res);
|
||||
}
|
||||
|
||||
#define hvx_sigmoid_loop_body(dst_type, src_type, vec_store) \
|
||||
do { \
|
||||
dst_type * restrict vdst = (dst_type *) dst; \
|
||||
src_type * restrict vsrc = (src_type *) src; \
|
||||
\
|
||||
const HVX_Vector one = hvx_vec_splat_f32(1.f); \
|
||||
const HVX_Vector max_exp = hvx_vec_splat_f32(87.f); \
|
||||
const HVX_Vector min_exp = hvx_vec_splat_f32(-87.f); \
|
||||
\
|
||||
const uint32_t epv = 128 / sizeof(float); \
|
||||
const uint32_t nvec = n / epv; \
|
||||
const uint32_t nloe = n % epv; \
|
||||
\
|
||||
uint32_t i = 0; \
|
||||
\
|
||||
_Pragma("unroll(4)") \
|
||||
for (; i < nvec; i++) { \
|
||||
vdst[i] = hvx_vec_fast_sigmoid_f32_guard(vsrc[i], one, max_exp, min_exp); \
|
||||
} \
|
||||
if (nloe) { \
|
||||
HVX_Vector tmp = hvx_vec_fast_sigmoid_f32_guard(vsrc[i], one, max_exp, min_exp); \
|
||||
vec_store((void *) &vdst[i], nloe * sizeof(float), tmp); \
|
||||
} \
|
||||
} while(0)
|
||||
|
||||
static inline void hvx_sigmoid_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
assert((unsigned long) src % 128 == 0);
|
||||
hvx_sigmoid_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
|
||||
}
|
||||
|
||||
static inline void hvx_sigmoid_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
hvx_sigmoid_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);
|
||||
}
|
||||
|
||||
static inline void hvx_sigmoid_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
assert((unsigned long) src % 128 == 0);
|
||||
hvx_sigmoid_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
|
||||
}
|
||||
|
||||
static inline void hvx_sigmoid_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
hvx_sigmoid_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
|
||||
}
|
||||
|
||||
#endif /* HVX_SIGMOID_H */
|
||||
|
|
@ -0,0 +1,60 @@
|
|||
#ifndef HVX_SQRT_H
|
||||
#define HVX_SQRT_H
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include "hex-utils.h"
|
||||
|
||||
#include "hvx-base.h"
|
||||
|
||||
#define RSQRT_CONST 0x5f3759df // Constant for fast inverse square root calculation
|
||||
#define RSQRT_ONE_HALF 0x3f000000 // 0.5
|
||||
#define RSQRT_THREE_HALVES 0x3fc00000 // 1.5
|
||||
|
||||
static inline HVX_Vector hvx_vec_rsqrt_f32(HVX_Vector in_vec) {
|
||||
//Algorithm :
|
||||
// x2 = input*0.5
|
||||
// y = * (long *) &input
|
||||
// y = 0x5f3759df - (y>>1)
|
||||
// y = y*(threehalfs - x2*y*y)
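// The three unrolled iterations below are Newton-Raphson refinements,
// y <- y * (1.5 - 0.5 * x * y * y), converging toward 1/sqrt(x)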
|
||||
|
||||
HVX_Vector rsqrtconst = Q6_V_vsplat_R(RSQRT_CONST);
|
||||
HVX_Vector onehalf = Q6_V_vsplat_R(RSQRT_ONE_HALF);
|
||||
HVX_Vector threehalfs = Q6_V_vsplat_R(RSQRT_THREE_HALVES);
|
||||
|
||||
HVX_Vector x2, y, ypower2, temp;
|
||||
|
||||
x2 = Q6_Vqf32_vmpy_VsfVsf(in_vec, onehalf);
|
||||
x2 = Q6_Vqf32_vadd_Vqf32Vsf(x2, Q6_V_vzero());
|
||||
|
||||
y = Q6_Vw_vasr_VwR(in_vec, 1);
|
||||
y = Q6_Vw_vsub_VwVw(rsqrtconst, y);
|
||||
|
||||
// 1st iteration
|
||||
ypower2 = Q6_Vqf32_vmpy_VsfVsf(y, y);
|
||||
ypower2 = Q6_Vqf32_vadd_Vqf32Vsf(ypower2, Q6_V_vzero());
|
||||
temp = Q6_Vqf32_vmpy_Vqf32Vqf32(x2, ypower2);
|
||||
temp = Q6_Vqf32_vsub_VsfVsf(threehalfs, Q6_Vsf_equals_Vqf32(temp));
|
||||
temp = Q6_Vqf32_vmpy_VsfVsf(y, Q6_Vsf_equals_Vqf32(temp));
|
||||
|
||||
// 2nd iteration
|
||||
y = Q6_Vqf32_vadd_Vqf32Vsf(temp, Q6_V_vzero());
|
||||
ypower2 = Q6_Vqf32_vmpy_Vqf32Vqf32(y, y);
|
||||
ypower2 = Q6_Vqf32_vadd_Vqf32Vsf(ypower2, Q6_V_vzero());
|
||||
temp = Q6_Vqf32_vmpy_Vqf32Vqf32(x2, ypower2);
|
||||
temp = Q6_Vqf32_vsub_VsfVsf(threehalfs, Q6_Vsf_equals_Vqf32(temp));
|
||||
temp = Q6_Vqf32_vmpy_Vqf32Vqf32(y, temp);
|
||||
|
||||
// 3rd iteration
|
||||
y = Q6_Vqf32_vadd_Vqf32Vsf(temp, Q6_V_vzero());
|
||||
ypower2 = Q6_Vqf32_vmpy_Vqf32Vqf32(y, y);
|
||||
ypower2 = Q6_Vqf32_vadd_Vqf32Vsf(ypower2, Q6_V_vzero());
|
||||
temp = Q6_Vqf32_vmpy_Vqf32Vqf32(x2, ypower2);
|
||||
temp = Q6_Vqf32_vsub_VsfVsf(threehalfs, Q6_Vsf_equals_Vqf32(temp));
|
||||
temp = Q6_Vqf32_vmpy_Vqf32Vqf32(y, temp);
|
||||
|
||||
return Q6_Vsf_equals_Vqf32(temp);
|
||||
}
|
||||
|
||||
#endif /* HVX_SQRT_H */
|
||||
|
|
@ -0,0 +1,36 @@
|
|||
#ifndef HVX_TYPES_H
|
||||
#define HVX_TYPES_H
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include <hexagon_types.h>
|
||||
|
||||
#define SIZEOF_FP32 (4)
|
||||
#define SIZEOF_FP16 (2)
|
||||
#define VLEN (128)
|
||||
#define VLEN_FP32 (VLEN / SIZEOF_FP32)
|
||||
#define VLEN_FP16 (VLEN / SIZEOF_FP16)
|
||||
|
||||
typedef union {
|
||||
HVX_Vector v;
|
||||
uint8_t b[VLEN];
|
||||
uint16_t h[VLEN_FP16];
|
||||
uint32_t w[VLEN_FP32];
|
||||
__fp16 fp16[VLEN_FP16];
|
||||
float fp32[VLEN_FP32];
|
||||
} __attribute__((aligned(VLEN), packed)) HVX_VectorAlias;
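// Lane-level view of a single HVX vector; useful for debugging and for scalar
// access to individual elements without vector extract ops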
|
||||
|
||||
typedef struct {
|
||||
HVX_Vector v[2];
|
||||
} HVX_Vector_x2;
|
||||
|
||||
typedef struct {
|
||||
HVX_Vector v[4];
|
||||
} HVX_Vector_x4;
|
||||
|
||||
typedef struct {
|
||||
HVX_Vector v[8];
|
||||
} HVX_Vector_x8;
|
||||
|
||||
#endif /* HVX_TYPES_H */
|
||||
|
|
@ -1,17 +1,13 @@
|
|||
#pragma clang diagnostic ignored "-Wgnu-zero-variadic-macro-arguments"
|
||||
#pragma clang diagnostic ignored "-Wunused-function"
|
||||
|
||||
#define FARF_ERROR 1
|
||||
#define FARF_HIGH 1
|
||||
#define FARF_MEDIUM 0
|
||||
#define FARF_LOW 0
|
||||
#include <HAP_farf.h>
|
||||
#include <HAP_perf.h>
|
||||
#include <AEEStdErr.h>
|
||||
#include <dspqueue.h>
|
||||
#include <HAP_compute_res.h>
|
||||
#include <HAP_etm_config.h>
|
||||
#include <HAP_farf.h>
|
||||
#include <HAP_mem.h>
|
||||
#include <HAP_perf.h>
|
||||
#include <HAP_power.h>
|
||||
#include <HAP_ps.h>
|
||||
#include <qurt.h>
|
||||
|
|
@ -19,13 +15,14 @@
|
|||
#include <remote.h>
|
||||
#include <string.h>
|
||||
|
||||
#include "hex-dma.h"
|
||||
#include "hex-utils.h"
|
||||
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-dma.h"
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
#include "ops-utils.h"
|
||||
#include "worker-pool.h"
|
||||
|
||||
AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) {
|
||||
|
|
@ -362,14 +359,14 @@ struct profile_data {
|
|||
|
||||
static inline void profile_start(struct profile_data * d) {
|
||||
d->usecs = HAP_perf_get_qtimer_count();
|
||||
d->cycles = htp_get_cycles();
|
||||
d->pkts = htp_get_pktcnt();
|
||||
d->cycles = hex_get_cycles();
|
||||
d->pkts = hex_get_pktcnt();
|
||||
}
|
||||
|
||||
static inline void profile_stop(struct profile_data * d) {
|
||||
d->usecs = HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - d->usecs);
|
||||
d->cycles = htp_get_cycles() - d->cycles;
|
||||
d->pkts = htp_get_pktcnt() - d->pkts;
|
||||
d->cycles = hex_get_cycles() - d->cycles;
|
||||
d->pkts = hex_get_pktcnt() - d->pkts;
|
||||
}
|
||||
|
||||
static int send_htp_rsp(struct htp_context * c,
|
||||
|
|
@ -443,6 +440,43 @@ static void proc_matmul_req(struct htp_context * ctx,
|
|||
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
|
||||
}
|
||||
|
||||
static void proc_cpy_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
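// Handles an OP_CPY request: prepare the response buffer with the cache
// maintenance flags, build an htp_ops_context from the request, run op_cpy
// while holding VTCM, then send the response along with profiling data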
|
||||
struct dspqueue_buffer rsp_bufs[1];
|
||||
|
||||
// We have written to the output buffer, so we also need to flush it
|
||||
rsp_bufs[0].fd = bufs[1].fd;
|
||||
rsp_bufs[0].ptr = bufs[1].ptr;
|
||||
rsp_bufs[0].offset = bufs[1].offset;
|
||||
rsp_bufs[0].size = bufs[1].size;
|
||||
rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
|
||||
|
||||
// Setup Op context
|
||||
struct htp_ops_context octx = { 0 };
|
||||
octx.ctx = ctx;
|
||||
octx.src0 = req->src0;
|
||||
octx.dst = req->dst;
|
||||
octx.flags = req->flags;
|
||||
octx.op = req->op;
|
||||
|
||||
// Update data pointers
|
||||
octx.src0.data = (uint32_t) bufs[0].ptr;
|
||||
octx.dst.data = (uint32_t) bufs[1].ptr;
|
||||
octx.n_threads = ctx->n_threads;
|
||||
|
||||
struct profile_data prof;
|
||||
profile_start(&prof);
|
||||
|
||||
uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
|
||||
if (vtcm_acquire(ctx) == AEE_SUCCESS) {
|
||||
rsp_status = op_cpy(&octx);
|
||||
vtcm_release(ctx);
|
||||
}
|
||||
|
||||
profile_stop(&prof);
|
||||
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
|
||||
}
|
||||
|
||||
static void proc_get_rows_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
|
||||
struct dspqueue_buffer rsp_bufs[1];
|
||||
|
||||
|
|
@ -993,6 +1027,14 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
|
|||
proc_get_rows_req(ctx, &req, bufs);
|
||||
break;
|
||||
|
||||
case HTP_OP_CPY:
|
||||
if (n_bufs != 2) {
|
||||
FARF(ERROR, "Bad cpy-req buffer list");
|
||||
continue;
|
||||
}
|
||||
proc_cpy_req(ctx, &req, bufs);
|
||||
break;
|
||||
|
||||
default:
|
||||
FARF(ERROR, "Unknown Op %u", req.op);
|
||||
break;
|
||||
|
|
|
|||
|
|
@ -3,28 +3,20 @@
|
|||
#pragma clang diagnostic ignored "-Wunused-variable"
|
||||
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
|
||||
|
||||
#ifdef HTP_DEBUG
|
||||
# define FARF_HIGH 1
|
||||
#endif
|
||||
|
||||
#include <HAP_farf.h>
|
||||
#include <HAP_mem.h>
|
||||
#include <HAP_perf.h>
|
||||
#include <HAP_ps.h>
|
||||
#include <hexagon_protos.h>
|
||||
#include <hexagon_types.h>
|
||||
|
||||
#include <math.h>
|
||||
#include <qurt_thread.h>
|
||||
#include <string.h>
|
||||
|
||||
#include "hex-dma.h"
|
||||
#include "hvx-utils.h"
|
||||
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-dma.h"
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
#include "hvx-utils.h"
|
||||
#include "ops-utils.h"
|
||||
|
||||
#define MM_SPAD_SRC0_NROWS 16
|
||||
#define MM_SPAD_SRC1_NROWS 16
|
||||
|
|
@ -36,20 +28,8 @@ struct htp_matmul_type {
|
|||
void (*vec_dot_rx2)(const int n, float * restrict s, const void * restrict vx, uint32_t vx_row_size, const void * restrict vy);
|
||||
};
|
||||
|
||||
typedef struct {
|
||||
HVX_Vector v[2];
|
||||
} HVX_Vector_x2;
|
||||
|
||||
typedef struct {
|
||||
HVX_Vector v[4];
|
||||
} HVX_Vector_x4;
|
||||
|
||||
typedef struct {
|
||||
HVX_Vector v[8];
|
||||
} HVX_Vector_x8;
|
||||
|
||||
// vdelta control to replicate first 4x fp32 values across lanes
|
||||
static const uint8_t __attribute__((aligned(128))) repl_4x_fp32[128] = {
|
||||
static const uint8_t __attribute__((aligned(128))) repl_4x_f32[128] = {
|
||||
0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10,
|
||||
0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20,
|
||||
0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04,
|
||||
|
|
@ -60,7 +40,7 @@ static const uint8_t __attribute__((aligned(128))) repl_4x_fp32[128] = {
|
|||
};
|
||||
|
||||
// vdelta control to replicate and interleave first 8x fp32 values across lanes
|
||||
static const uint8_t __attribute__((aligned(128))) repl_interleave_8x_fp32[128] = {
|
||||
static const uint8_t __attribute__((aligned(128))) repl_interleave_8x_f32[128] = {
|
||||
0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x00, 0x00, 0x00,
|
||||
0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20,
|
||||
0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04,
|
||||
|
|
@ -71,7 +51,7 @@ static const uint8_t __attribute__((aligned(128))) repl_interleave_8x_fp32[128]
|
|||
};
|
||||
|
||||
// vdelta control to replicate first fp32 value across all elements
|
||||
static const uint8_t __attribute__((aligned(128))) repl_1x_fp32[128] = {
|
||||
static const uint8_t __attribute__((aligned(128))) repl_1x_f32[128] = {
|
||||
0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10,
|
||||
0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04,
|
||||
0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08,
|
||||
|
|
@ -82,7 +62,7 @@ static const uint8_t __attribute__((aligned(128))) repl_1x_fp32[128] = {
|
|||
};
|
||||
|
||||
// vdelta control to replicate first fp16 value across all elements
|
||||
static const uint8_t __attribute__((aligned(128))) repl_1x_fp16[128] = {
|
||||
static const uint8_t __attribute__((aligned(128))) repl_1x_f16[128] = {
|
||||
0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x10, 0x10, 0x02,
|
||||
0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x20, 0x20, 0x02, 0x02, 0x04, 0x04,
|
||||
0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08,
|
||||
|
|
@ -93,7 +73,7 @@ static const uint8_t __attribute__((aligned(128))) repl_1x_fp16[128] = {
|
|||
};
|
||||
|
||||
// vdelta control to replicate first 2x fp16 values across all elements
|
||||
static const uint8_t __attribute__((aligned(128))) repl_2x_fp16[128] = {
|
||||
static const uint8_t __attribute__((aligned(128))) repl_2x_f16[128] = {
|
||||
0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
|
||||
0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
|
||||
0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
|
||||
|
|
@ -129,7 +109,7 @@ static inline size_t q8x4x2_row_size(uint32_t ne) {
|
|||
// ensures perfect alignment of quants and full row
|
||||
const uint32_t qk = QK_Q8_0x4x2;
|
||||
const uint32_t nb = (ne + qk - 1) / qk;
|
||||
return htp_round_up(ne + nb * 8 * sizeof(__fp16), 128);
|
||||
return hex_round_up(ne + nb * 8 * sizeof(__fp16), 128);
|
||||
}
|
||||
|
||||
static inline HVX_Vector_x8 hvx_vec_load_q4x4x8(const uint8_t * restrict ptr) {
|
||||
|
|
@ -389,7 +369,7 @@ static void vec_dot_q4x4x2_q8x4x2(const int n, float * restrict s, const void *
|
|||
}
|
||||
|
||||
// Reduce and convert into fp32
|
||||
r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum));
|
||||
r0_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r0_sum));
|
||||
|
||||
hvx_vec_store_u(&s[0], 4, r0_sum);
|
||||
}
|
||||
|
|
@ -485,8 +465,8 @@ static void vec_dot_q4x4x2_q8x4x2_rx2(const int n,
|
|||
}
|
||||
|
||||
// Convert into fp32 and reduce
|
||||
r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum));
|
||||
r1_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r1_sum));
|
||||
r0_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r0_sum));
|
||||
r1_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r1_sum));
|
||||
HVX_VectorPair p0 = Q6_W_vshuff_VVR(r1_sum, r0_sum, 4);
|
||||
|
||||
hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0));
|
||||
|
|
@ -562,7 +542,7 @@ static void vec_dot_q8x4x2_q8x4x2(const int n, float * restrict s, const void *
|
|||
}
|
||||
|
||||
// Reduce and convert into fp32
|
||||
r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum));
|
||||
r0_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r0_sum));
|
||||
|
||||
hvx_vec_store_u(&s[0], 4, r0_sum);
|
||||
}
|
||||
|
|
@ -658,8 +638,8 @@ static void vec_dot_q8x4x2_q8x4x2_rx2(const int n,
|
|||
}
|
||||
|
||||
// Convert into fp32 and reduce
|
||||
r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum));
|
||||
r1_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r1_sum));
|
||||
r0_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r0_sum));
|
||||
r1_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r1_sum));
|
||||
HVX_VectorPair p0 = Q6_W_vshuff_VVR(r1_sum, r0_sum, 4);
|
||||
|
||||
hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0));
|
||||
|
|
@ -768,7 +748,7 @@ static void vec_dot_mxfp4x4x2_q8x4x2(const int n,
|
|||
}
|
||||
|
||||
// Reduce and convert into fp32
|
||||
r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum));
|
||||
r0_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r0_sum));
|
||||
|
||||
hvx_vec_store_u(&s[0], 4, r0_sum);
|
||||
}
|
||||
|
|
@ -900,8 +880,8 @@ static void vec_dot_mxfp4x4x2_q8x4x2_rx2(const int n,
|
|||
}
|
||||
|
||||
// Convert into fp32 and reduce
|
||||
r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum));
|
||||
r1_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r1_sum));
|
||||
r0_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r0_sum));
|
||||
r1_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r1_sum));
|
||||
HVX_VectorPair p0 = Q6_W_vshuff_VVR(r1_sum, r0_sum, 4);
|
||||
|
||||
hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0));
|
||||
|
|
@ -933,7 +913,7 @@ static void vec_dot_f16_f16_aa(const int n, float * restrict s, const void * res
|
|||
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
|
||||
}
|
||||
|
||||
rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum));
|
||||
rsum = Q6_Vsf_equals_Vqf32(hvx_vec_reduce_sum_qf32(rsum));
|
||||
hvx_vec_store_u(&s[0], 4, rsum);
|
||||
}
|
||||
|
||||
|
|
@ -977,8 +957,8 @@ static void vec_dot_f16_f16_aa_rx2(const int n,
|
|||
rsum1 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum1, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)));
|
||||
}
|
||||
|
||||
rsum0 = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum0));
|
||||
rsum1 = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum1));
|
||||
rsum0 = Q6_Vsf_equals_Vqf32(hvx_vec_reduce_sum_qf32(rsum0));
|
||||
rsum1 = Q6_Vsf_equals_Vqf32(hvx_vec_reduce_sum_qf32(rsum1));
|
||||
HVX_VectorPair p0 = Q6_W_vshuff_VVR(rsum1, rsum0, 4);
|
||||
|
||||
hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0));
|
||||
|
|
@ -1010,7 +990,7 @@ static void vec_dot_f16_f16_uu(const int n, float * restrict s, const void * res
|
|||
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
|
||||
}
|
||||
|
||||
rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum));
|
||||
rsum = Q6_Vsf_equals_Vqf32(hvx_vec_reduce_sum_qf32(rsum));
|
||||
hvx_vec_store_u(&s[0], 4, rsum);
|
||||
}
|
||||
|
||||
|
|
@ -1062,7 +1042,7 @@ static void vec_dot_f16_f32_uu(const int n, float * restrict s, const void * res
|
|||
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
|
||||
}
|
||||
|
||||
rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum));
|
||||
rsum = Q6_Vsf_equals_Vqf32(hvx_vec_reduce_sum_qf32(rsum));
|
||||
hvx_vec_store_u(&s[0], 4, rsum);
|
||||
}
|
||||
|
||||
|
|
@ -1359,7 +1339,7 @@ static void matvec_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx
|
|||
mt->vec_dot(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col);
|
||||
}
|
||||
|
||||
hvx_copy_fp32_ua((uint8_t *) &dst_col[src0_start_row], (uint8_t *) tmp, src0_end_row - src0_start_row);
|
||||
hvx_copy_f32_ua((uint8_t *) &dst_col[src0_start_row], (uint8_t *) tmp, src0_end_row - src0_start_row);
|
||||
|
||||
t2 = HAP_perf_get_qtimer_count();
|
||||
|
||||
|
|
@ -1411,7 +1391,7 @@ static void matmul_id(struct htp_matmul_type * mt, struct htp_ops_context * octx
|
|||
const size_t src0_row_size = nb01;
|
||||
const size_t src1_row_size = q8x4x2_row_size(ne10);
|
||||
|
||||
const size_t src0_row_size_padded = htp_round_up(src0_row_size, 128);
|
||||
const size_t src0_row_size_padded = hex_round_up(src0_row_size, 128);
|
||||
|
||||
// Per-thread VTCM scratchpads for all tensors
|
||||
// Note that the entire src1 tensor is already in VTCM
|
||||
|
|
@ -1524,7 +1504,7 @@ static void matvec_id(struct htp_matmul_type * mt, struct htp_ops_context * octx
|
|||
const size_t src0_row_size = nb01;
|
||||
const size_t src1_row_size = q8x4x2_row_size(ne10);
|
||||
|
||||
const size_t src0_row_size_padded = htp_round_up(src0_row_size, 128);
|
||||
const size_t src0_row_size_padded = hex_round_up(src0_row_size, 128);
|
||||
|
||||
const uint32_t n_aids = src2->ne[0]; // num activated experts
|
||||
const uint32_t n_ids = ne02; // num experts
|
||||
|
|
@ -1590,7 +1570,7 @@ static void matvec_id(struct htp_matmul_type * mt, struct htp_ops_context * octx
|
|||
|
||||
// *** dynamic quant
|
||||
|
||||
static inline void quantize_block_fp32_q8x1(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) {
|
||||
static inline void quantize_block_f32_q8x1(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) {
|
||||
assert((unsigned long) x % 128 == 0);
|
||||
assert((unsigned long) y_q % 128 == 0);
|
||||
|
||||
|
|
@ -1598,10 +1578,10 @@ static inline void quantize_block_fp32_q8x1(float * restrict x, uint8_t * restri
|
|||
HVX_Vector zero = Q6_V_vsplat_R(0);
|
||||
|
||||
// Use reduce max fp32 to find max(abs(e)) first
|
||||
HVX_Vector vmax0_sf = hvx_vec_reduce_max_fp32(hvx_vec_abs_fp32(vx[0]));
|
||||
HVX_Vector vmax1_sf = hvx_vec_reduce_max_fp32(hvx_vec_abs_fp32(vx[1]));
|
||||
HVX_Vector vmax2_sf = hvx_vec_reduce_max_fp32(hvx_vec_abs_fp32(vx[2]));
|
||||
HVX_Vector vmax3_sf = hvx_vec_reduce_max_fp32(hvx_vec_abs_fp32(vx[3]));
|
||||
HVX_Vector vmax0_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[0]));
|
||||
HVX_Vector vmax1_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[1]));
|
||||
HVX_Vector vmax2_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[2]));
|
||||
HVX_Vector vmax3_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[3]));
|
||||
// Load and convert into QF32
|
||||
HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); // 32 elements
|
||||
HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); // 32 elements
|
||||
|
|
@ -1623,7 +1603,7 @@ static inline void quantize_block_fp32_q8x1(float * restrict x, uint8_t * restri
|
|||
HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf)));
|
||||
|
||||
// Replicate first fp16 scale across all lanes
|
||||
HVX_Vector ctrl = *(const HVX_Vector *) repl_2x_fp16;
|
||||
HVX_Vector ctrl = *(const HVX_Vector *) repl_2x_f16;
|
||||
vmax01_hf = Q6_V_vdelta_VV(vmax01_hf, ctrl);
|
||||
vmax23_hf = Q6_V_vdelta_VV(vmax23_hf, ctrl);
|
||||
|
||||
|
|
@ -1641,8 +1621,8 @@ static inline void quantize_block_fp32_q8x1(float * restrict x, uint8_t * restri
|
|||
hvx_vec_store_u(y_d + 6, 2, rotated_vd_hf);
|
||||
|
||||
// Divide input by the scale
|
||||
HVX_Vector vd01_inv_hf = hvx_vec_inverse_fp16(vd01_hf);
|
||||
HVX_Vector vd23_inv_hf = hvx_vec_inverse_fp16(vd23_hf);
|
||||
HVX_Vector vd01_inv_hf = hvx_vec_inverse_f16(vd01_hf);
|
||||
HVX_Vector vd23_inv_hf = hvx_vec_inverse_f16(vd23_hf);
|
||||
vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd01_inv_hf));
|
||||
vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd23_inv_hf));
|
||||
|
||||
|
|
@ -1654,7 +1634,7 @@ static inline void quantize_block_fp32_q8x1(float * restrict x, uint8_t * restri
|
|||
*(HVX_Vector *) y_q = vx_i8;
|
||||
}
|
||||
|
||||
static inline void quantize_block_fp32_q8x2(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) {
|
||||
static inline void quantize_block_f32_q8x2(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) {
|
||||
assert((unsigned long) x % 128 == 0);
|
||||
assert((unsigned long) y_q % 128 == 0);
|
||||
|
||||
|
|
@ -1672,11 +1652,11 @@ static inline void quantize_block_fp32_q8x2(float * restrict x, uint8_t * restri
|
|||
HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf)));
|
||||
|
||||
// Compute max and scale
|
||||
HVX_Vector vmax01_hf = hvx_vec_reduce_max_fp16(hvx_vec_abs_fp16(vx01_hf));
|
||||
HVX_Vector vmax23_hf = hvx_vec_reduce_max_fp16(hvx_vec_abs_fp16(vx23_hf));
|
||||
HVX_Vector vmax01_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx01_hf));
|
||||
HVX_Vector vmax23_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx23_hf));
|
||||
|
||||
// Replicate first fp16 scale across all lanes
|
||||
HVX_Vector ctrl = *(const HVX_Vector *) repl_1x_fp16;
|
||||
HVX_Vector ctrl = *(const HVX_Vector *) repl_1x_f16;
|
||||
vmax01_hf = Q6_V_vdelta_VV(vmax01_hf, ctrl);
|
||||
vmax23_hf = Q6_V_vdelta_VV(vmax23_hf, ctrl);
|
||||
|
||||
|
|
@ -1689,8 +1669,8 @@ static inline void quantize_block_fp32_q8x2(float * restrict x, uint8_t * restri
|
|||
hvx_vec_store_u(y_d + 4, 4, vd23_hf);
|
||||
|
||||
// Divide input by the scale
|
||||
HVX_Vector vd01_inv_hf = hvx_vec_inverse_fp16(vd01_hf);
|
||||
HVX_Vector vd23_inv_hf = hvx_vec_inverse_fp16(vd23_hf);
|
||||
HVX_Vector vd01_inv_hf = hvx_vec_inverse_f16(vd01_hf);
|
||||
HVX_Vector vd23_inv_hf = hvx_vec_inverse_f16(vd23_hf);
|
||||
vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd01_inv_hf));
|
||||
vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd23_inv_hf));
|
||||
|
||||
|
|
@ -1702,7 +1682,7 @@ static inline void quantize_block_fp32_q8x2(float * restrict x, uint8_t * restri
|
|||
*(HVX_Vector *) y_q = vx_i8;
|
||||
}
|
||||
|
||||
static inline void quantize_block_fp32_q8x4(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) {
|
||||
static inline void quantize_block_f32_q8x4(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) {
|
||||
assert((unsigned long) x % 128 == 0);
|
||||
assert((unsigned long) y_q % 128 == 0);
|
||||
|
||||
|
|
@ -1720,11 +1700,11 @@ static inline void quantize_block_fp32_q8x4(float * restrict x, uint8_t * restri
|
|||
HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf)));
|
||||
|
||||
// Compute max and scale
|
||||
HVX_Vector vmax_hf = hvx_vec_reduce_max_fp16(hvx_vec_abs_fp16(vx01_hf));
|
||||
vmax_hf = hvx_vec_reduce_max2_fp16(hvx_vec_abs_fp16(vx23_hf), vmax_hf);
|
||||
HVX_Vector vmax_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx01_hf));
|
||||
vmax_hf = hvx_vec_reduce_max2_f16(hvx_vec_abs_f16(vx23_hf), vmax_hf);
|
||||
|
||||
// Replicate first fp16 scale across all lanes
|
||||
HVX_Vector ctrl = *(const HVX_Vector *) repl_1x_fp16;
|
||||
HVX_Vector ctrl = *(const HVX_Vector *) repl_1x_f16;
|
||||
vmax_hf = Q6_V_vdelta_VV(vmax_hf, ctrl);
|
||||
|
||||
HVX_Vector vd_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
|
||||
|
|
@ -1733,7 +1713,7 @@ static inline void quantize_block_fp32_q8x4(float * restrict x, uint8_t * restri
|
|||
*(HVX_UVector *) y_d = vd_hf;
|
||||
|
||||
// Divide input by the scale
|
||||
HVX_Vector vd_inv_hf = hvx_vec_inverse_fp16(vd_hf);
|
||||
HVX_Vector vd_inv_hf = hvx_vec_inverse_f16(vd_hf);
|
||||
vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd_inv_hf));
|
||||
vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd_inv_hf));
|
||||
|
||||
|
|
@@ -1746,7 +1726,7 @@ static inline void quantize_block_fp32_q8x4(float * restrict x, uint8_t * restri
}

// Overrides input x
static void quantize_row_fp32_q8x4x2(float * restrict x, uint8_t * restrict y, uint32_t k) {
static void quantize_row_f32_q8x4x2(float * restrict x, uint8_t * restrict y, uint32_t k) {
assert(k % 32 == 0);
const uint32_t qk = QK_Q8_0x4x2;
const uint32_t nb = (k + qk - 1) / qk;

@@ -1764,24 +1744,24 @@ static void quantize_row_fp32_q8x4x2(float * restrict x, uint8_t * restrict y, u

for (uint32_t i = 0; i < nb; i++) {
#if FP32_QUANTIZE_GROUP_SIZE == 32
quantize_block_fp32_q8x1(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2);
quantize_block_fp32_q8x1(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2);
quantize_block_f32_q8x1(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2);
quantize_block_f32_q8x1(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2);
#elif FP32_QUANTIZE_GROUP_SIZE == 64
quantize_block_fp32_q8x2(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2);
quantize_block_fp32_q8x2(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2);
quantize_block_f32_q8x2(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2);
quantize_block_f32_q8x2(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2);
#elif FP32_QUANTIZE_GROUP_SIZE == 128
quantize_block_fp32_q8x4(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2);
quantize_block_fp32_q8x4(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2);
quantize_block_f32_q8x4(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2);
quantize_block_f32_q8x4(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2);
#else
#error "FP32_QUANTIZE_GROUP_SIZE must be 32, 64, or 128"
#endif
}

// now copy the scales into final location
hvx_copy_fp16_ua(y_d, t_d, nb * 8);
hvx_copy_f16_ua(y_d, t_d, nb * 8);
}

static void quantize_fp32_q8x4x2(const struct htp_tensor * src,
static void quantize_f32_q8x4x2(const struct htp_tensor * src,
uint8_t * restrict dst,
struct htp_spad * spad,
uint32_t nth,
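Reviewer note (not part of the patch): the quantize_block_f32_q8x1/x2/x4 variants above differ only in group size (32/64/128 floats per fp16 scale); each one finds max(abs(e)), derives a scale of amax/127 (the 0x2008 splat is the fp16 encoding of 1/127) and divides the inputs by it. A rough scalar sketch of that scheme, with a hypothetical helper name and assuming the toolchain's __fp16 type:

#include <math.h>
#include <stdint.h>

// Illustrative scalar reference only; the real code keeps everything in HVX registers.
static void quantize_group_ref_f32_q8(const float * x, int8_t * q, __fp16 * d, int group_size) {
    float amax = 0.0f;
    for (int i = 0; i < group_size; i++) {
        float a = fabsf(x[i]);
        amax = (a > amax) ? a : amax;            // max(abs(e)), as in the reduce-max step
    }
    float scale = amax / 127.0f;                 // per-group scale, stored as fp16
    *d = (__fp16) scale;
    float inv = (scale != 0.0f) ? 1.0f / scale : 0.0f;
    for (int i = 0; i < group_size; i++) {
        q[i] = (int8_t) roundf(x[i] * inv);      // "divide input by the scale"
    }
}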
@ -1807,26 +1787,26 @@ static void quantize_fp32_q8x4x2(const struct htp_tensor * src,
|
|||
uint8_t * restrict dst_data = (uint8_t *) dst + (dst_row_size * ir_first);
|
||||
uint8_t * restrict tmp_data = (uint8_t *) spad->data + (spad->size_per_thread * ith);
|
||||
|
||||
const size_t src_row_size_padded = htp_round_up(src_row_size, QK_Q8_0x4x2 * sizeof(float));
|
||||
const size_t src_row_size_padded = hex_round_up(src_row_size, QK_Q8_0x4x2 * sizeof(float));
|
||||
memset(tmp_data, 0, src_row_size_padded); // zero-out temp row data for padding
|
||||
|
||||
for (uint32_t i = ir_first; i < ir_last; ++i) {
|
||||
htp_l2fetch(src_data, 2, src_row_size, src_row_size);
|
||||
hvx_copy_fp32_aa(tmp_data, src_data, ne0);
|
||||
hex_l2fetch(src_data, src_row_size, src_row_size, 2);
|
||||
hvx_copy_f32_aa(tmp_data, src_data, ne0);
|
||||
|
||||
// FARF(HIGH, "quantize-q8x4-row: %u\n", i);
|
||||
quantize_row_fp32_q8x4x2((float *) tmp_data, dst_data, ne0);
|
||||
quantize_row_f32_q8x4x2((float *) tmp_data, dst_data, ne0);
|
||||
dst_data += dst_row_size;
|
||||
src_data += src_row_size;
|
||||
}
|
||||
|
||||
uint64_t t2 = HAP_perf_get_qtimer_count();
|
||||
|
||||
FARF(HIGH, "quantize-fp32-q8x4: %u/%u : n-rows %u (%u:%u) row-size %u -> %u usec %u\n", ith, nth, nrows, ir_first,
|
||||
FARF(HIGH, "quantize-f32-q8x4: %u/%u : n-rows %u (%u:%u) row-size %u -> %u usec %u\n", ith, nth, nrows, ir_first,
|
||||
ir_last, src_row_size, dst_row_size, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||
}
|
||||
|
||||
static void quantize_fp32_fp16(const struct htp_tensor * src, uint8_t * restrict dst, uint32_t nth, uint32_t ith,
|
||||
static void quantize_f32_f16(const struct htp_tensor * src, uint8_t * restrict dst, uint32_t nth, uint32_t ith,
|
||||
uint32_t nrows_per_thread, uint32_t dst_stride) {
|
||||
|
||||
uint64_t t1 = HAP_perf_get_qtimer_count();
|
||||
|
|
@ -1848,8 +1828,8 @@ static void quantize_fp32_fp16(const struct htp_tensor * src, uint8_t * restrict
|
|||
uint8_t * restrict dst_data = (uint8_t *) dst + (dst_stride * ir_first);
|
||||
|
||||
for (uint32_t i = ir_first; i < ir_last; ++i) {
|
||||
htp_l2fetch(src_data, 2, src_row_size, src_stride);
|
||||
hvx_copy_fp16_fp32_au(dst_data, src_data, ne0);
|
||||
hex_l2fetch(src_data, src_row_size, src_stride, 2);
|
||||
hvx_copy_f16_f32_au(dst_data, src_data, ne0);
|
||||
|
||||
dst_data += dst_stride;
|
||||
src_data += src_stride;
|
||||
|
|
@ -1857,12 +1837,12 @@ static void quantize_fp32_fp16(const struct htp_tensor * src, uint8_t * restrict
|
|||
|
||||
uint64_t t2 = HAP_perf_get_qtimer_count();
|
||||
|
||||
FARF(HIGH, "quantize-fp32-fp16: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first,
|
||||
FARF(HIGH, "quantize-f32-f16: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first,
|
||||
ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||
}
|
||||
|
||||
// TODO just a plain copy that should be done via the DMA during the Op setup
|
||||
static void quantize_fp16_fp16(const struct htp_tensor * src, uint8_t * restrict dst, uint32_t nth, uint32_t ith,
|
||||
static void quantize_f16_f16(const struct htp_tensor * src, uint8_t * restrict dst, uint32_t nth, uint32_t ith,
|
||||
uint32_t nrows_per_thread, uint32_t dst_stride) {
|
||||
|
||||
uint64_t t1 = HAP_perf_get_qtimer_count();
|
||||
|
|
@ -1884,8 +1864,8 @@ static void quantize_fp16_fp16(const struct htp_tensor * src, uint8_t * restrict
|
|||
uint8_t * restrict dst_data = (uint8_t *) dst + (dst_stride * ir_first);
|
||||
|
||||
for (uint32_t i = ir_first; i < ir_last; ++i) {
|
||||
htp_l2fetch(src_data, 2, src_row_size, src_stride);
|
||||
hvx_copy_fp16_au(dst_data, src_data, ne0);
|
||||
hex_l2fetch(src_data, src_row_size, src_stride, 2);
|
||||
hvx_copy_f16_au(dst_data, src_data, ne0);
|
||||
|
||||
dst_data += dst_stride;
|
||||
src_data += src_stride;
|
||||
|
|
@@ -1893,23 +1873,23 @@ static void quantize_fp16_fp16(const struct htp_tensor * src, uint8_t * restrict

uint64_t t2 = HAP_perf_get_qtimer_count();

FARF(HIGH, "quantize-fp16-fp16: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first,
FARF(HIGH, "quantize-f16-f16: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first,
ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
}

static void htp_quantize_fp32_q8x4x2(unsigned int n, unsigned int i, void * data) {
static void htp_quantize_f32_q8x4x2(unsigned int n, unsigned int i, void * data) {
struct htp_ops_context * octx = data;
quantize_fp32_q8x4x2(&octx->src1, octx->src1_spad.data, &octx->src0_spad, n, i, octx->src1_nrows_per_thread);
quantize_f32_q8x4x2(&octx->src1, octx->src1_spad.data, &octx->src0_spad, n, i, octx->src1_nrows_per_thread);
}

static void htp_quantize_fp32_fp16(unsigned int n, unsigned int i, void * data) {
static void htp_quantize_f32_f16(unsigned int n, unsigned int i, void * data) {
struct htp_ops_context * octx = data;
quantize_fp32_fp16(&octx->src1, octx->src1_spad.data, n, i, octx->src1_nrows_per_thread, octx->src1_spad.stride);
quantize_f32_f16(&octx->src1, octx->src1_spad.data, n, i, octx->src1_nrows_per_thread, octx->src1_spad.stride);
}

static void htp_quantize_fp16_fp16(unsigned int n, unsigned int i, void * data) {
static void htp_quantize_f16_f16(unsigned int n, unsigned int i, void * data) {
struct htp_ops_context * octx = data;
quantize_fp16_fp16(&octx->src1, octx->src1_spad.data, n, i, octx->src1_nrows_per_thread, octx->src1_spad.stride);
quantize_f16_f16(&octx->src1, octx->src1_spad.data, n, i, octx->src1_nrows_per_thread, octx->src1_spad.stride);
}

// ** matmul/matvec callbacks for worker_pool
@ -2108,7 +2088,7 @@ int op_matmul(struct htp_ops_context * octx) {
|
|||
const size_t dst_row_size = nb1;
|
||||
size_t src1_row_size = nb11;
|
||||
|
||||
const size_t src0_row_size_padded = htp_round_up(src0_row_size, 128);
|
||||
const size_t src0_row_size_padded = hex_round_up(src0_row_size, 128);
|
||||
size_t src1_row_size_padded;
|
||||
|
||||
worker_callback_t quant_job_func;
|
||||
|
|
@ -2118,8 +2098,8 @@ int op_matmul(struct htp_ops_context * octx) {
|
|||
|
||||
switch (src0->type) {
|
||||
case HTP_TYPE_Q4_0:
|
||||
op_type = "q4x4x2-fp32";
|
||||
quant_job_func = htp_quantize_fp32_q8x4x2;
|
||||
op_type = "q4x4x2-f32";
|
||||
quant_job_func = htp_quantize_f32_q8x4x2;
|
||||
if (src1_nrows > 1) {
|
||||
matmul_job_func = htp_matmul_2d_q4x4x2_q8x4x2;
|
||||
} else {
|
||||
|
|
@ -2131,12 +2111,12 @@ int op_matmul(struct htp_ops_context * octx) {
|
|||
// Entire src1 tensor is placed into the VTCM
|
||||
// For other tensors we allocate N rows per thread, padded to HVX vector size
|
||||
|
||||
octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
|
||||
octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
|
||||
octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256);
|
||||
octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
|
||||
octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
|
||||
octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256);
|
||||
|
||||
// src0 spad is also used in dynamic quantizer to store padded src1 rows
|
||||
src1_row_size_padded = htp_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float));
|
||||
src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float));
|
||||
if (octx->src0_spad.size_per_thread < src1_row_size_padded) {
|
||||
octx->src0_spad.size_per_thread = src1_row_size_padded;
|
||||
}
|
||||
|
|
@ -2147,8 +2127,8 @@ int op_matmul(struct htp_ops_context * octx) {
|
|||
break;
|
||||
|
||||
case HTP_TYPE_Q8_0:
|
||||
op_type = "q8x4x2-fp32";
|
||||
quant_job_func = htp_quantize_fp32_q8x4x2;
|
||||
op_type = "q8x4x2-f32";
|
||||
quant_job_func = htp_quantize_f32_q8x4x2;
|
||||
if (src1_nrows > 1) {
|
||||
matmul_job_func = htp_matmul_2d_q8x4x2_q8x4x2;
|
||||
} else {
|
||||
|
|
@ -2160,12 +2140,12 @@ int op_matmul(struct htp_ops_context * octx) {
|
|||
// Entire src1 tensor is placed into the VTCM
|
||||
// For other tensors we allocate N rows per thread, padded to HVX vector size
|
||||
|
||||
octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
|
||||
octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
|
||||
octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256);
|
||||
octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
|
||||
octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
|
||||
octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256);
|
||||
|
||||
// src0 spad is also used in dynamic quantizer to store padded src1 rows
|
||||
src1_row_size_padded = htp_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float));
|
||||
src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float));
|
||||
if (octx->src0_spad.size_per_thread < src1_row_size_padded) {
|
||||
octx->src0_spad.size_per_thread = src1_row_size_padded;
|
||||
}
|
||||
|
|
@ -2177,7 +2157,7 @@ int op_matmul(struct htp_ops_context * octx) {
|
|||
|
||||
case HTP_TYPE_MXFP4:
|
||||
op_type = "mxfp4x4x2-f32";
|
||||
quant_job_func = htp_quantize_fp32_q8x4x2;
|
||||
quant_job_func = htp_quantize_f32_q8x4x2;
|
||||
if (src1_nrows > 1) {
|
||||
matmul_job_func = htp_matmul_2d_mxfp4x4x2_q8x4x2;
|
||||
} else {
|
||||
|
|
@ -2189,12 +2169,12 @@ int op_matmul(struct htp_ops_context * octx) {
|
|||
// Entire src1 tensor is placed into the VTCM
|
||||
// For other tensors we allocate N rows per thread, padded to HVX vector size
|
||||
|
||||
octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
|
||||
octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
|
||||
octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256);
|
||||
octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
|
||||
octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
|
||||
octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256);
|
||||
|
||||
// src0 spad is also used in dynamic quantizer to store padded src1 rows
|
||||
src1_row_size_padded = htp_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float));
|
||||
src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float));
|
||||
if (octx->src0_spad.size_per_thread < src1_row_size_padded) {
|
||||
octx->src0_spad.size_per_thread = src1_row_size_padded;
|
||||
}
|
||||
|
|
@ -2207,10 +2187,10 @@ int op_matmul(struct htp_ops_context * octx) {
|
|||
case HTP_TYPE_F16:
|
||||
{
|
||||
// Try optimized f16-f16 path first (src1 in VTCM)
|
||||
const size_t f16_src1_row_size = htp_round_up(ne10 * 2, 128);
|
||||
const size_t f16_src1_spad_size = htp_round_up(f16_src1_row_size * src1_nrows, 256);
|
||||
const size_t f16_src0_spad_size = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256) * octx->n_threads;
|
||||
const size_t f16_dst_spad_size = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256) * octx->n_threads;
|
||||
const size_t f16_src1_row_size = hex_round_up(ne10 * 2, 128);
|
||||
const size_t f16_src1_spad_size = hex_round_up(f16_src1_row_size * src1_nrows, 256);
|
||||
const size_t f16_src0_spad_size = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256) * octx->n_threads;
|
||||
const size_t f16_dst_spad_size = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256) * octx->n_threads;
|
||||
|
||||
const size_t f16_total_size = f16_src1_spad_size + f16_src0_spad_size + f16_dst_spad_size;
|
||||
|
||||
|
|
@ -2222,7 +2202,7 @@ int op_matmul(struct htp_ops_context * octx) {
|
|||
if (!is_batched && !is_permuted && f16_total_size <= octx->ctx->vtcm_size) {
|
||||
// Optimized path
|
||||
op_type = "f16-f16";
|
||||
quant_job_func = (src1->type == HTP_TYPE_F32) ? htp_quantize_fp32_fp16 : htp_quantize_fp16_fp16;
|
||||
quant_job_func = (src1->type == HTP_TYPE_F32) ? htp_quantize_f32_f16 : htp_quantize_f16_f16;
|
||||
if (src1_nrows > 1) {
|
||||
matmul_job_func = htp_matmul_2d_f16_f16;
|
||||
} else {
|
||||
|
|
@ -2231,9 +2211,9 @@ int op_matmul(struct htp_ops_context * octx) {
|
|||
|
||||
src1_row_size = f16_src1_row_size; // row size post quantization
|
||||
|
||||
octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
|
||||
octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
|
||||
octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256);
|
||||
octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
|
||||
octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
|
||||
octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256);
|
||||
|
||||
octx->src1_spad.size = octx->src1_spad.size_per_thread;
|
||||
octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
|
||||
|
|
@ -2251,9 +2231,9 @@ int op_matmul(struct htp_ops_context * octx) {
|
|||
|
||||
src1_row_size = nb11; // original row size in DDR
|
||||
|
||||
octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
|
||||
octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size, 256);
|
||||
octx->src1_spad.size_per_thread = htp_round_up(MM_SPAD_SRC1_NROWS * src1_row_size, 256);
|
||||
octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
|
||||
octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size, 256);
|
||||
octx->src1_spad.size_per_thread = hex_round_up(MM_SPAD_SRC1_NROWS * src1_row_size, 256);
|
||||
|
||||
octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
|
||||
octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads;
|
||||
|
|
@ -2332,7 +2312,7 @@ int op_matmul_id(struct htp_ops_context * octx) {
|
|||
const size_t src0_row_size = nb01;
|
||||
const size_t dst_row_size = nb1;
|
||||
|
||||
const size_t src0_row_size_padded = htp_round_up(src0_row_size, 128);
|
||||
const size_t src0_row_size_padded = hex_round_up(src0_row_size, 128);
|
||||
|
||||
const uint32_t src0_nrows = ne01; // per expert
|
||||
const uint32_t src1_nrows = ne11 * ne12 * ne13;
|
||||
|
|
@ -2350,7 +2330,7 @@ int op_matmul_id(struct htp_ops_context * octx) {
|
|||
switch (src0->type) {
|
||||
case HTP_TYPE_Q4_0:
|
||||
op_type = "q4x2x2-f32";
|
||||
quant_job_func = htp_quantize_fp32_q8x4x2;
|
||||
quant_job_func = htp_quantize_f32_q8x4x2;
|
||||
src1_row_size = q8x4x2_row_size(ne10); // row size post quantization
|
||||
if (src1_nrows > 1) {
|
||||
matmul_id_job_func = htp_matmul_id_q4x4x2_q8x4x2;
|
||||
|
|
@ -2360,13 +2340,13 @@ int op_matmul_id(struct htp_ops_context * octx) {
|
|||
|
||||
// Entire src1 tensor is placed into the VTCM
|
||||
// For other tensors we allocate N rows per thread, padded to HVX vector size
|
||||
octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
|
||||
octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
|
||||
octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256);
|
||||
octx->src2_spad.size_per_thread = htp_round_up(matrix_row_counts_size + matrix_row_map_size, 256);
|
||||
octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
|
||||
octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
|
||||
octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256);
|
||||
octx->src2_spad.size_per_thread = hex_round_up(matrix_row_counts_size + matrix_row_map_size, 256);
|
||||
|
||||
// src0 spad is also used in dynamic quantizer to store padded src1 rows
|
||||
src1_row_size_padded = htp_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float));
|
||||
src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float));
|
||||
if (octx->src0_spad.size_per_thread < src1_row_size_padded) {
|
||||
octx->src0_spad.size_per_thread = src1_row_size_padded;
|
||||
}
|
||||
|
|
@ -2379,7 +2359,7 @@ int op_matmul_id(struct htp_ops_context * octx) {
|
|||
|
||||
case HTP_TYPE_Q8_0:
|
||||
op_type = "q8x2x2-f32";
|
||||
quant_job_func = htp_quantize_fp32_q8x4x2;
|
||||
quant_job_func = htp_quantize_f32_q8x4x2;
|
||||
src1_row_size = q8x4x2_row_size(ne10); // row size post quantization
|
||||
if (src1_nrows > 1) {
|
||||
matmul_id_job_func = htp_matmul_id_q8x4x2_q8x4x2;
|
||||
|
|
@ -2389,13 +2369,13 @@ int op_matmul_id(struct htp_ops_context * octx) {
|
|||
|
||||
// Entire src1 tensor is placed into the VTCM
|
||||
// For other tensors we allocate N rows per thread, padded to HVX vector size
|
||||
octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
|
||||
octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
|
||||
octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256);
|
||||
octx->src2_spad.size_per_thread = htp_round_up(matrix_row_counts_size + matrix_row_map_size, 256);
|
||||
octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
|
||||
octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
|
||||
octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256);
|
||||
octx->src2_spad.size_per_thread = hex_round_up(matrix_row_counts_size + matrix_row_map_size, 256);
|
||||
|
||||
// src0 spad is also used in dynamic quantizer to store padded src1 rows
|
||||
src1_row_size_padded = htp_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float));
|
||||
src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float));
|
||||
if (octx->src0_spad.size_per_thread < src1_row_size_padded) {
|
||||
octx->src0_spad.size_per_thread = src1_row_size_padded;
|
||||
}
|
||||
|
|
@ -2408,7 +2388,7 @@ int op_matmul_id(struct htp_ops_context * octx) {
|
|||
|
||||
case HTP_TYPE_MXFP4:
|
||||
op_type = "mxfp4x2x2-f32";
|
||||
quant_job_func = htp_quantize_fp32_q8x4x2;
|
||||
quant_job_func = htp_quantize_f32_q8x4x2;
|
||||
src1_row_size = q8x4x2_row_size(ne10); // row size post quantization
|
||||
if (src1_nrows > 1) {
|
||||
matmul_id_job_func = htp_matmul_id_mxfp4x4x2_q8x4x2;
|
||||
|
|
@ -2418,13 +2398,13 @@ int op_matmul_id(struct htp_ops_context * octx) {
|
|||
|
||||
// Entire src1 tensor is placed into the VTCM
|
||||
// For other tensors we allocate N rows per thread, padded to HVX vector size
|
||||
octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
|
||||
octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
|
||||
octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256);
|
||||
octx->src2_spad.size_per_thread = htp_round_up(matrix_row_counts_size + matrix_row_map_size, 256);
|
||||
octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
|
||||
octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
|
||||
octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256);
|
||||
octx->src2_spad.size_per_thread = hex_round_up(matrix_row_counts_size + matrix_row_map_size, 256);
|
||||
|
||||
// src0 spad is also used in dynamic quantizer to store padded src1 rows
|
||||
src1_row_size_padded = htp_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float));
|
||||
src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float));
|
||||
if (octx->src0_spad.size_per_thread < src1_row_size_padded) {
|
||||
octx->src0_spad.size_per_thread = src1_row_size_padded;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,149 +0,0 @@
|
|||
#ifndef OPS_UTILS_H
|
||||
#define OPS_UTILS_H
|
||||
|
||||
#include "htp-msg.h"
|
||||
|
||||
#ifndef MAX
|
||||
# define MAX(a, b) ((a) > (b) ? (a) : (b))
|
||||
#endif
|
||||
|
||||
#ifndef MIN
|
||||
# define MIN(a, b) ((a) < (b) ? (a) : (b))
|
||||
#endif
|
||||
|
||||
static inline uint64_t htp_get_cycles() {
|
||||
uint64_t cycles = 0;
|
||||
asm volatile(" %0 = c15:14\n" : "=r"(cycles));
|
||||
return cycles;
|
||||
}
|
||||
|
||||
static inline uint64_t htp_get_pktcnt() {
|
||||
uint64_t pktcnt;
|
||||
asm volatile(" %0 = c19:18\n" : "=r"(pktcnt));
|
||||
return pktcnt;
|
||||
}
|
||||
|
||||
static inline int32_t htp_is_aligned(void * addr, uint32_t align) {
|
||||
return ((size_t) addr & (align - 1)) == 0;
|
||||
}
|
||||
|
||||
static inline uint32_t htp_round_up(uint32_t n, uint32_t m) {
|
||||
return m * ((n + m - 1) / m);
|
||||
}
|
||||
|
||||
// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
// Precompute mp (m' in the paper) and L such that division
// can be computed using a multiply (high 32b of 64b result)
// and a shift:
//
// n/d = (mulhi(n, mp) + n) >> L;
struct fastdiv_values {
uint32_t mp;
uint32_t l;
};

static inline struct fastdiv_values init_fastdiv_values(uint32_t d) {
struct fastdiv_values result = { 0, 0 };
// compute L = ceil(log2(d));
while (result.l < 32 && ((uint32_t) 1 << result.l) < d) {
++(result.l);
}

result.mp = (uint32_t) (((uint64_t) 1 << 32) * (((uint64_t) 1 << result.l) - d) / d + 1);
return result;
}

static inline uint32_t fastdiv(uint32_t n, const struct fastdiv_values * vals) {
// Compute high 32 bits of n * mp
const uint32_t hi = (uint32_t) (((uint64_t) n * vals->mp) >> 32); // mulhi(n, mp)
// add n, apply bit shift
return (hi + n) >> vals->l;
}

static inline uint32_t fastmodulo(uint32_t n, uint32_t d, const struct fastdiv_values * vals) {
return n - fastdiv(n, vals) * d;
}
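Reviewer note (not part of the patch): the fastdiv helpers above follow the Granlund–Montgomery constant-division recipe cited in the comment. A small standalone check (illustrative only) that the precomputed (mp, L) pair reproduces ordinary division and modulo for d = 7:

#include <assert.h>
#include <stdint.h>

int main(void) {
    // d = 7: L = ceil(log2(7)) = 3, mp = floor(2^32 * (2^3 - 7) / 7) + 1
    const uint32_t d  = 7;
    const uint32_t l  = 3;
    const uint32_t mp = (uint32_t) (((uint64_t) 1 << 32) * (((uint64_t) 1 << l) - d) / d + 1);

    const uint32_t n  = 1000;
    const uint32_t hi = (uint32_t) (((uint64_t) n * mp) >> 32);  // mulhi(n, mp)
    assert(((hi + n) >> l) == n / d);                            // 142
    assert((n - (((hi + n) >> l) * d)) == n % d);                // 6
    return 0;
}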
static inline void htp_l2fetch(const void * p, uint32_t height, uint32_t width, uint32_t stride) {
const uint64_t control = Q6_P_combine_RR(stride, Q6_R_combine_RlRl(width, height));
asm volatile(" l2fetch(%0,%1) " : : "r"(p), "r"(control));
}

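Reviewer note (not part of the patch): the l2fetch control word built above packs the stride into the upper 32 bits and the 16-bit width/height pair into the lower 32 bits. A plain-C sketch of that packing, written without the Hexagon intrinsics:

#include <stdint.h>

static inline uint64_t l2fetch_control(uint32_t stride, uint32_t width, uint32_t height) {
    // Q6_R_combine_RlRl(width, height): width.l in the high halfword, height.l in the low halfword
    const uint32_t wh = ((width & 0xffff) << 16) | (height & 0xffff);
    // Q6_P_combine_RR(stride, wh): stride occupies the upper 32 bits of the pair
    return ((uint64_t) stride << 32) | wh;
}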
static inline int32_t htp_is_one_chunk(void * addr, uint32_t n, uint32_t chunk_size) {
|
||||
uint32_t left_off = (size_t) addr & (chunk_size - 1);
|
||||
uint32_t right_off = left_off + n;
|
||||
return right_off <= chunk_size;
|
||||
}
|
||||
|
||||
static inline void htp_dump_int8_line(char * pref, const int8_t * x, int n) {
|
||||
char str[1024], *p = str, *p_end = str + sizeof(str);
|
||||
p += snprintf(p, p_end - p, "%s: ", pref);
|
||||
for (int i = 0; i < n && p < p_end; i++) {
|
||||
p += snprintf(p, p_end - p, "%d, ", x[i]);
|
||||
}
|
||||
FARF(HIGH, "%s\n", str);
|
||||
}
|
||||
|
||||
static inline void htp_dump_uint8_line(char * pref, const uint8_t * x, uint32_t n) {
|
||||
char str[1024], *p = str, *p_end = str + sizeof(str);
|
||||
p += snprintf(p, p_end - p, "%s: ", pref);
|
||||
for (int i = 0; i < n && p < p_end; i++) {
|
||||
p += snprintf(p, p_end - p, "%d, ", x[i]);
|
||||
}
|
||||
FARF(HIGH, "%s\n", str);
|
||||
}
|
||||
|
||||
static inline void htp_dump_int32_line(char * pref, const int32_t * x, uint32_t n) {
|
||||
char str[1024], *p = str, *p_end = str + sizeof(str);
|
||||
p += snprintf(p, p_end - p, "%s: ", pref);
|
||||
for (int i = 0; i < n; i++) {
|
||||
p += snprintf(p, p_end - p, "%d, ", (int) x[i]);
|
||||
}
|
||||
FARF(HIGH, "%s\n", str);
|
||||
}
|
||||
|
||||
static inline void htp_dump_fp16_line(char * pref, const __fp16 * x, uint32_t n) {
|
||||
char str[1024], *p = str, *p_end = str + sizeof(str);
|
||||
p += snprintf(p, p_end - p, "%s: ", pref);
|
||||
for (int i = 0; i < n; i++) {
|
||||
p += snprintf(p, p_end - p, "%.6f, ", (float) x[i]);
|
||||
}
|
||||
FARF(HIGH, "%s\n", str);
|
||||
}
|
||||
|
||||
static inline void htp_dump_fp32_line(char * pref, const float * x, uint32_t n) {
|
||||
char str[1024], *p = str, *p_end = str + sizeof(str);
|
||||
p += snprintf(p, p_end - p, "%s: ", pref);
|
||||
for (int i = 0; i < n; i++) {
|
||||
p += snprintf(p, p_end - p, "%.6f, ", x[i]);
|
||||
}
|
||||
FARF(HIGH, "%s\n", str);
|
||||
}
|
||||
|
||||
static inline void htp_dump_f32(char * pref, const float * x, uint32_t n) {
|
||||
uint32_t n0 = n / 16;
|
||||
uint32_t n1 = n % 16;
|
||||
|
||||
uint32_t i = 0;
|
||||
for (; i < n0; i++) {
|
||||
htp_dump_fp32_line(pref, x + (16 * i), 16);
|
||||
}
|
||||
if (n1) {
|
||||
htp_dump_fp32_line(pref, x + (16 * i), n1);
|
||||
}
|
||||
}
|
||||
|
||||
static inline void htp_dump_f16(char * pref, const __fp16 * x, uint32_t n) {
|
||||
uint32_t n0 = n / 16;
|
||||
uint32_t n1 = n % 16;
|
||||
|
||||
uint32_t i = 0;
|
||||
for (; i < n0; i++) {
|
||||
htp_dump_fp16_line(pref, x + (16 * i), 16);
|
||||
}
|
||||
if (n1) {
|
||||
htp_dump_fp16_line(pref, x + (16 * i), n1);
|
||||
}
|
||||
}
|
||||
|
||||
#endif /* OPS_UTILS_H */
|
||||
|
|
@ -2,27 +2,20 @@
|
|||
#pragma clang diagnostic ignored "-Wunused-function"
|
||||
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
|
||||
|
||||
#ifdef HTP_DEBUG
|
||||
# define FARF_HIGH 1
|
||||
#endif
|
||||
#include <HAP_farf.h>
|
||||
#include <HAP_mem.h>
|
||||
#include <HAP_perf.h>
|
||||
#include <HAP_ps.h>
|
||||
#include <hexagon_protos.h>
|
||||
#include <hexagon_types.h>
|
||||
|
||||
#include <math.h>
|
||||
#include <qurt_thread.h>
|
||||
#include <string.h>
|
||||
|
||||
#include "hex-dma.h"
|
||||
#include "hvx-utils.h"
|
||||
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-dma.h"
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
#include "hvx-utils.h"
|
||||
#include "ops-utils.h"
|
||||
|
||||
// Redefined the types GGML_ROPE_TYPE_NORMAL & GGML_ROPE_TYPE_NEOX as we cant include ggml.h
|
||||
#define HTP_ROPE_TYPE_NORMAL 0
|
||||
|
|
@ -370,8 +363,8 @@ static void rope_job_f32_per_thread(struct rope_th_ctx * rope_ctx, int nth, int
|
|||
|
||||
int is_aligned = 1;
|
||||
int opt_path = 0;
|
||||
if ((0 == htp_is_aligned((void *) src0->data, VLEN)) || (0 == htp_is_aligned((void *) src1->data, VLEN)) ||
|
||||
(0 == htp_is_aligned((void *) dst->data, VLEN))) {
|
||||
if ((0 == hex_is_aligned((void *) src0->data, VLEN)) || (0 == hex_is_aligned((void *) src1->data, VLEN)) ||
|
||||
(0 == hex_is_aligned((void *) dst->data, VLEN))) {
|
||||
FARF(HIGH, "rope-f32: unaligned addresses in rope op, possibly slower execution\n");
|
||||
is_aligned = 0;
|
||||
}
|
||||
|
|
@ -427,9 +420,9 @@ static int execute_op_rope_f32(struct htp_ops_context * octx) {
|
|||
|
||||
// VTCM scratchpads for all tensors
|
||||
// N rows per thread, padded to HVX vector size
|
||||
octx->dst_spad.size = htp_round_up(dst_row_size, 128) * n_threads;
|
||||
octx->src0_spad.size = htp_round_up(src0_row_size, 128) * n_threads;
|
||||
octx->src1_spad.size = htp_round_up(src1_row_size, 128) * n_threads;
|
||||
octx->dst_spad.size = hex_round_up(dst_row_size, 128) * n_threads;
|
||||
octx->src0_spad.size = hex_round_up(src0_row_size, 128) * n_threads;
|
||||
octx->src1_spad.size = hex_round_up(src1_row_size, 128) * n_threads;
|
||||
|
||||
size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size;
|
||||
|
||||
|
|
|
|||
|
|
@ -2,24 +2,20 @@
|
|||
#pragma clang diagnostic ignored "-Wunused-function"
|
||||
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
|
||||
|
||||
#ifdef HTP_DEBUG
|
||||
# define FARF_HIGH 1
|
||||
#endif
|
||||
#include <HAP_farf.h>
|
||||
#include <HAP_mem.h>
|
||||
#include <HAP_perf.h>
|
||||
#include <hexagon_protos.h>
|
||||
#include <hexagon_types.h>
|
||||
|
||||
#include <math.h>
|
||||
#include <string.h>
|
||||
|
||||
#include "hex-dma.h"
|
||||
#include "hvx-utils.h"
|
||||
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
#include "hvx-utils.h"
|
||||
#include "ops-utils.h"
|
||||
|
||||
#define set_rows_preamble \
|
||||
const uint32_t ne00 = octx->src0.ne[0]; \
|
||||
|
|
@ -76,7 +72,7 @@ static int set_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth,
|
|||
const uintptr_t dst_ptr = octx->dst.data + i1*nb1 + i02*nb2 + i03*nb3;
|
||||
|
||||
// copy row
|
||||
hvx_copy_fp32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, ne00);
|
||||
hvx_copy_f32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, ne00);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -112,7 +108,7 @@ static int set_rows_thread_f16_f32(struct htp_ops_context * octx, const int nth,
|
|||
const uint8_t* src0_ptr = (const uint8_t *) octx->src0.data + i*nb01 + i02*nb02 + i03*nb03;
|
||||
uint8_t* dst_ptr = (uint8_t *) octx->dst.data + i1*nb1 + i02*nb2 + i03*nb3;
|
||||
|
||||
hvx_copy_fp16_fp32_uu(dst_ptr, src0_ptr, ne00);
|
||||
hvx_copy_f16_f32_uu(dst_ptr, src0_ptr, ne00);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,27 +2,20 @@
|
|||
#pragma clang diagnostic ignored "-Wunused-function"
|
||||
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
|
||||
|
||||
#ifdef HTP_DEBUG
|
||||
# define FARF_HIGH 1
|
||||
#endif
|
||||
#include <HAP_farf.h>
|
||||
#include <HAP_mem.h>
|
||||
#include <HAP_perf.h>
|
||||
#include <HAP_ps.h>
|
||||
#include <hexagon_protos.h>
|
||||
#include <hexagon_types.h>
|
||||
|
||||
#include <math.h>
|
||||
#include <qurt_thread.h>
|
||||
#include <string.h>
|
||||
|
||||
#include "hex-dma.h"
|
||||
#include "hvx-utils.h"
|
||||
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-dma.h"
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
#include "hvx-utils.h"
|
||||
#include "ops-utils.h"
|
||||
|
||||
#define htp_softmax_preamble3 \
|
||||
const uint32_t ne00 = src0->ne[0]; \
|
||||
|
|
@ -100,8 +93,8 @@ static void hvx_fast_softmax_prep_f32(const uint8_t * restrict src,
|
|||
uint8_t * restrict dst_curr = dst;
|
||||
const uint8_t * restrict mask_curr = mask;
|
||||
|
||||
HVX_Vector scale_vec = hvx_vec_splat_fp32(scale);
|
||||
HVX_Vector slope_vec = hvx_vec_splat_fp32(slope);
|
||||
HVX_Vector scale_vec = hvx_vec_splat_f32(scale);
|
||||
HVX_Vector slope_vec = hvx_vec_splat_f32(slope);
|
||||
|
||||
int step_of_1 = num_elems >> 5;
|
||||
|
||||
|
|
@ -134,9 +127,9 @@ static void hvx_fast_softmax_f32(const uint8_t * restrict src,
|
|||
HVX_Vector * restrict v_dst = (HVX_Vector *) dst;
|
||||
|
||||
HVX_Vector sum_vec = Q6_V_vsplat_R(0x00000000);
|
||||
HVX_Vector max_vec = hvx_vec_splat_fp32(((const float *) src)[0]);
|
||||
HVX_Vector max_vec = hvx_vec_splat_f32(((const float *) src)[0]);
|
||||
HVX_Vector zero_v = Q6_V_vzero();
|
||||
HVX_Vector one_v = hvx_vec_splat_fp32(1.0);
|
||||
HVX_Vector one_v = hvx_vec_splat_f32(1.0);
|
||||
|
||||
int step_of_1 = num_elems >> 5;
|
||||
|
||||
|
|
@ -146,7 +139,7 @@ static void hvx_fast_softmax_f32(const uint8_t * restrict src,
|
|||
max_vec = Q6_Vsf_vmax_VsfVsf(max_vec, v1);
|
||||
}
|
||||
|
||||
HVX_Vector v = hvx_vec_reduce_max_fp32(max_vec);
|
||||
HVX_Vector v = hvx_vec_reduce_max_f32(max_vec);
|
||||
max_vec = hvx_vec_repl4(v);
|
||||
|
||||
#pragma unroll(4)
|
||||
|
|
@ -154,18 +147,18 @@ static void hvx_fast_softmax_f32(const uint8_t * restrict src,
|
|||
HVX_Vector v1 = v_src[i];
|
||||
HVX_Vector v2 = Q6_Vqf32_vsub_VsfVsf(v1, max_vec);
|
||||
|
||||
HVX_Vector v3 = hvx_vec_exp_fp32(Q6_Vsf_equals_Vqf32(v2));
|
||||
HVX_Vector v3 = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(v2));
|
||||
|
||||
sum_vec = Q6_Vqf32_vadd_VsfVsf(Q6_Vsf_equals_Vqf32(sum_vec), v3);
|
||||
|
||||
v_pad[i] = v3;
|
||||
}
|
||||
|
||||
v = hvx_vec_qf32_reduce_sum(sum_vec);
|
||||
v = hvx_vec_reduce_sum_qf32(sum_vec);
|
||||
sum_vec = hvx_vec_repl4(Q6_Vsf_equals_Vqf32(v));
|
||||
|
||||
HVX_VectorPred pos_sum = Q6_Q_vcmp_gt_VwVw(sum_vec, zero_v);
|
||||
HVX_Vector v4 = hvx_vec_inverse_fp32(sum_vec);
|
||||
HVX_Vector v4 = hvx_vec_inverse_f32(sum_vec);
|
||||
HVX_Vector scale_vec = Q6_V_vmux_QVV(pos_sum, v4, one_v);
|
||||
|
||||
#pragma unroll(4)
|
||||
|
|
@@ -181,11 +174,11 @@ static float hvx_softmax_f32(const uint8_t * restrict src,
uint8_t * restrict spad,
const int num_elems,
const float max) {
hvx_sub_scalar_f32(src, max, spad, num_elems);
hvx_sub_scalar_f32(spad, src, max, num_elems);

hvx_exp_f32(spad, dst, num_elems, false);

float sum = hvx_self_sum_f32(dst, num_elems);
float sum = hvx_reduce_sum_f32(dst, num_elems);

return sum;
}
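Reviewer note (not part of the patch): the argument reorder above moves the destination buffer to the first position, and the renamed reduce helper replaces hvx_self_sum_f32; the surrounding routines implement the standard max-subtract/exp/normalize softmax. A scalar sketch of what the HVX path computes:

#include <math.h>

// Hypothetical scalar reference of hvx_softmax_f32 / softmax_htp_f32 above.
static void softmax_ref_f32(float * dst, const float * src, int n) {
    float max = src[0];
    for (int i = 1; i < n; i++) {
        if (src[i] > max) {
            max = src[i];                        // running max for numerical stability
        }
    }
    float sum = 0.0f;
    for (int i = 0; i < n; i++) {
        dst[i] = expf(src[i] - max);             // subtract max, exponentiate
        sum += dst[i];
    }
    float scale = (sum > 0.0f) ? 1.0f / sum : 1.0f;  // same positive-sum guard as the HVX path
    for (int i = 0; i < n; i++) {
        dst[i] *= scale;
    }
}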
@ -255,7 +248,7 @@ static void softmax_htp_f32(int nth, int ith, struct softmax_th_ctx * softmax_ct
|
|||
if (1 == opt_path) {
|
||||
hvx_fast_softmax_f32((const uint8_t *) wp0, (uint8_t *) dp, (uint8_t *) wp1, ne00);
|
||||
} else {
|
||||
float max = hvx_self_max_f32((const uint8_t *) wp0, ne00);
|
||||
float max = hvx_reduce_max_f32((const uint8_t *) wp0, ne00);
|
||||
float sum = hvx_softmax_f32((const uint8_t *) wp0, (uint8_t *) wp2, (uint8_t *) wp1, ne00, max);
|
||||
sum = sum > 0.0 ? (1.0 / sum) : 1;
|
||||
hvx_scale_f32((uint8_t *) dp, (const uint8_t *) wp2, ne00, sum);
|
||||
|
|
@ -290,7 +283,7 @@ static void softmax_job_f32_per_thread(struct softmax_th_ctx * softmax_ctx, int
|
|||
|
||||
int is_aligned = 1;
|
||||
int opt_path = 0;
|
||||
if (!htp_is_aligned((void *) src0->data, VLEN) || !htp_is_aligned((void *) dst->data, VLEN)) {
|
||||
if (!hex_is_aligned((void *) src0->data, VLEN) || !hex_is_aligned((void *) dst->data, VLEN)) {
|
||||
is_aligned = 0;
|
||||
FARF(HIGH, "softmax-f32: unaligned addresses in elementwise op, possibly slower execution\n");
|
||||
}
|
||||
|
|
@ -345,9 +338,9 @@ static int execute_op_softmax_f32(struct htp_ops_context * octx) {
|
|||
|
||||
// VTCM scratchpads for all tensors
|
||||
// N rows per thread, padded to HVX vector size
|
||||
octx->dst_spad.size = htp_round_up(dst_row_size, 128) * n_threads;
|
||||
octx->src0_spad.size = htp_round_up(src0_row_size, 128) * n_threads;
|
||||
octx->src1_spad.size = htp_round_up(src1_row_size, 128) * n_threads;
|
||||
octx->dst_spad.size = hex_round_up(dst_row_size, 128) * n_threads;
|
||||
octx->src0_spad.size = hex_round_up(src0_row_size, 128) * n_threads;
|
||||
octx->src1_spad.size = hex_round_up(src1_row_size, 128) * n_threads;
|
||||
|
||||
size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size;
|
||||
|
||||
|
|
|
|||
|
|
@ -2,28 +2,20 @@
|
|||
#pragma clang diagnostic ignored "-Wunused-function"
|
||||
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
|
||||
|
||||
#ifdef HTP_DEBUG
|
||||
# define FARF_HIGH 1
|
||||
#endif
|
||||
|
||||
#include <HAP_farf.h>
|
||||
#include <HAP_mem.h>
|
||||
#include <HAP_perf.h>
|
||||
#include <HAP_ps.h>
|
||||
#include <hexagon_protos.h>
|
||||
#include <hexagon_types.h>
|
||||
|
||||
#include <math.h>
|
||||
#include <qurt_thread.h>
|
||||
#include <string.h>
|
||||
|
||||
#include "hex-dma.h"
|
||||
#include "hvx-utils.h"
|
||||
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-dma.h"
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
#include "hvx-utils.h"
|
||||
#include "ops-utils.h"
|
||||
|
||||
#define htp_unary_preamble \
|
||||
const uint32_t ne00 = src->ne[0]; \
|
||||
|
|
@@ -55,7 +47,7 @@ static void hvx_fast_rms_norm_f32(const uint8_t * restrict src,
HVX_Vector * restrict v_dst = (HVX_Vector *) dst;

HVX_Vector sum_v = Q6_V_vsplat_R(0x00000000);
HVX_Vector epsilon_v = hvx_vec_splat_fp32(epsilon);
HVX_Vector epsilon_v = hvx_vec_splat_f32(epsilon);

int step_of_1 = num_elems >> 5;
#pragma unroll(4)
@@ -65,15 +57,15 @@ static void hvx_fast_rms_norm_f32(const uint8_t * restrict src,
sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2);
}

HVX_Vector reduced_sum = hvx_vec_qf32_reduce_sum(sum_v);
HVX_Vector reduced_sum = hvx_vec_reduce_sum_qf32(sum_v);
sum_v = hvx_vec_repl4(Q6_Vsf_equals_Vqf32(reduced_sum));

HVX_Vector t_v = hvx_vec_splat_fp32((float) num_elems);
HVX_Vector denom_v = hvx_vec_inverse_fp32(t_v);
HVX_Vector t_v = hvx_vec_splat_f32((float) num_elems);
HVX_Vector denom_v = hvx_vec_inverse_f32(t_v);
HVX_Vector mean_v = Q6_Vqf32_vmpy_VsfVsf(sum_v, denom_v);
HVX_Vector mean_epsilon_v = Q6_Vqf32_vadd_Vqf32Vsf(mean_v, epsilon_v);

HVX_Vector scale_v = hvx_vec_rsqrt_fp32(Q6_Vsf_equals_Vqf32(mean_epsilon_v));
HVX_Vector scale_v = hvx_vec_rsqrt_f32(Q6_Vsf_equals_Vqf32(mean_epsilon_v));

#pragma unroll(4)
for (int i = 0; i < step_of_1; i++) {
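Reviewer note (not part of the patch): hvx_fast_rms_norm_f32 accumulates what is presumably each squared element (the squaring happens in the part of the loop not shown in this hunk), divides by the element count, adds epsilon and takes a reciprocal square root, i.e. the usual RMS-norm scale. A scalar sketch under that assumption:

#include <math.h>

// Hypothetical scalar reference, not part of the patch.
static void rms_norm_ref_f32(float * dst, const float * src, int n, float epsilon) {
    float sum_sq = 0.0f;
    for (int i = 0; i < n; i++) {
        sum_sq += src[i] * src[i];               // sum of squares
    }
    float mean  = sum_sq / (float) n;            // mean square (denom_v above is 1/n)
    float scale = 1.0f / sqrtf(mean + epsilon);  // rsqrt(mean + epsilon)
    for (int i = 0; i < n; i++) {
        dst[i] = src[i] * scale;
    }
}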
@ -101,7 +93,7 @@ static void scale_htp_f32(const float * restrict src,
|
|||
float * restrict dst_local = dst + (ir * row_elems);
|
||||
|
||||
if (ir + 1 < num_rows) {
|
||||
htp_l2fetch(src_local + row_elems, 1, row_size, row_size);
|
||||
hex_l2fetch(src_local + row_elems, row_size, row_size, 1);
|
||||
}
|
||||
|
||||
hvx_scale_offset_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale, bias);
|
||||
|
|
@ -124,7 +116,7 @@ static void rms_norm_htp_f32(const float * restrict src,
|
|||
float * restrict dst_local = dst + (ir * row_elems);
|
||||
|
||||
if (ir + 1 < num_rows) {
|
||||
htp_l2fetch(src_local + row_elems, 1, row_size, row_size);
|
||||
hex_l2fetch(src_local + row_elems, row_size, row_size, 1);
|
||||
}
|
||||
|
||||
if (1 == opt_path) {
|
||||
|
|
@ -168,9 +160,8 @@ static void unary_job_f32_per_thread(const struct htp_tensor * src,
|
|||
|
||||
int is_aligned = 1;
|
||||
int opt_path = 0;
|
||||
if ((0 == htp_is_aligned((void *) src->data, VLEN)) || (0 == htp_is_aligned((void *) dst->data, VLEN))) {
|
||||
if ((0 == hex_is_aligned((void *) src->data, VLEN)) || (0 == hex_is_aligned((void *) dst->data, VLEN))) {
|
||||
is_aligned = 0;
|
||||
FARF(HIGH, "unary-f32: unaligned addresses in unary op, possibly slower execution\n");
|
||||
}
|
||||
if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) {
|
||||
opt_path = 1;
|
||||
|
|
@ -240,8 +231,8 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
|
|||
const size_t dst_row_size = dst->nb[1];
|
||||
|
||||
// VTCM scratchpads for all tensors
|
||||
octx->dst_spad.size = htp_round_up(dst_row_size, 128) * n_threads;
|
||||
octx->src0_spad.size = htp_round_up(src0_row_size, 128) * n_threads;
|
||||
octx->dst_spad.size = hex_round_up(dst_row_size, 128) * n_threads;
|
||||
octx->src0_spad.size = hex_round_up(src0_row_size, 128) * n_threads;
|
||||
|
||||
size_t spad_size = octx->src0_spad.size + octx->dst_spad.size;
|
||||
|
||||
|
|
|
|||
|
|
@ -7,10 +7,6 @@
|
|||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
|
||||
#ifdef HTP_DEBUG
|
||||
# define FARF_HIGH 1
|
||||
#endif
|
||||
|
||||
#include "HAP_farf.h"
|
||||
|
||||
#define WORKER_THREAD_STACK_SZ (2 * 16384)
|
||||
|
|
|
|||
|
|
@@ -10,6 +10,9 @@ branch=.
adbserial=
[ "$S" != "" ] && adbserial="-s $S"

adbhost=
[ "$H" != "" ] && adbhost="-H $H"

model="Llama-3.2-3B-Instruct-Q4_0.gguf"
[ "$M" != "" ] && model="$M"

@@ -34,13 +37,16 @@ nhvx=
ndev=
[ "$NDEV" != "" ] && ndev="GGML_HEXAGON_NDEV=$NDEV"

hb=
[ "$HB" != "" ] && hb="GGML_HEXAGON_HOSTBUF=$HB"

set -x

adb $adbserial shell " \
adb $adbserial $adbhost shell " \
cd $basedir; \
LD_LIBRARY_PATH=$basedir/$branch/lib \
ADSP_LIBRARY_PATH=$basedir/$branch/lib \
$ndev $nhvx $opmask $verbose $experimental $profile ./$branch/bin/llama-bench --device $device --mmap 0 -m $basedir/../gguf/$model \
$ndev $nhvx $opmask $verbose $experimental $profile $hb ./$branch/bin/llama-bench --device $device --mmap 0 -m $basedir/../gguf/$model \
--poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \
--batch-size 128 -ngl 99 $cli_opts $@ \
"
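Reviewer note (not part of the patch): with the hunks above this bench script picks up two extra knobs. H selects a remote adb server via adb -H, and HB is exported to the device as GGML_HEXAGON_HOSTBUF, which by its name toggles the host-buffer behavior per run (e.g. invoking the script with HB=0 in the environment). The exact script filename is not visible in this view.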
@ -12,6 +12,9 @@ branch=.
|
|||
adbserial=
|
||||
[ "$S" != "" ] && adbserial="-s $S"
|
||||
|
||||
adbhost=
|
||||
[ "$H" != "" ] && adbhost="-H $H"
|
||||
|
||||
model="Llama-3.2-3B-Instruct-Q4_0.gguf"
|
||||
[ "$M" != "" ] && model="$M"
|
||||
|
||||
|
|
@ -39,13 +42,16 @@ nhvx=
|
|||
ndev=
|
||||
[ "$NDEV" != "" ] && ndev="GGML_HEXAGON_NDEV=$NDEV"
|
||||
|
||||
hb=
|
||||
[ "$HB" != "" ] && hb="GGML_HEXAGON_HOSTBUF=$HB"
|
||||
|
||||
set -x
|
||||
|
||||
adb $adbserial shell " \
|
||||
adb $adbserial $adbhost shell " \
|
||||
cd $basedir; ulimit -c unlimited; \
|
||||
LD_LIBRARY_PATH=$basedir/$branch/lib \
|
||||
ADSP_LIBRARY_PATH=$basedir/$branch/lib \
|
||||
$verbose $experimental $sched $opmask $profile $nhvx $ndev \
|
||||
$verbose $experimental $sched $opmask $profile $nhvx $ndev $hb \
|
||||
./$branch/bin/llama-cli --no-mmap -m $basedir/../gguf/$model \
|
||||
--poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \
|
||||
--ctx-size 8192 --batch-size 128 -fa on \
|
||||
|
|
|
|||
|
|
@ -12,6 +12,9 @@ branch=.
|
|||
adbserial=
|
||||
[ "$S" != "" ] && adbserial="-s $S"
|
||||
|
||||
adbhost=
|
||||
[ "$H" != "" ] && adbhost="-H $H"
|
||||
|
||||
model="Llama-3.2-3B-Instruct-Q4_0.gguf"
|
||||
[ "$M" != "" ] && model="$M"
|
||||
|
||||
|
|
@ -39,13 +42,16 @@ nhvx=
|
|||
ndev=
|
||||
[ "$NDEV" != "" ] && ndev="GGML_HEXAGON_NDEV=$NDEV"
|
||||
|
||||
hb=
|
||||
[ "$HB" != "" ] && hb="GGML_HEXAGON_HOSTBUF=$HB"
|
||||
|
||||
set -x
|
||||
|
||||
adb $adbserial shell " \
|
||||
adb $adbserial $adbhost shell " \
|
||||
cd $basedir; ulimit -c unlimited; \
|
||||
LD_LIBRARY_PATH=$basedir/$branch/lib \
|
||||
ADSP_LIBRARY_PATH=$basedir/$branch/lib \
|
||||
$verbose $experimental $sched $opmask $profile $nhvx $ndev \
|
||||
$verbose $experimental $sched $opmask $profile $nhvx $ndev $hb \
|
||||
./$branch/bin/llama-completion --no-mmap -m $basedir/../gguf/$model \
|
||||
--poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \
|
||||
--ctx-size 8192 --batch-size 128 -fa on \
|
||||
|
|
|
|||
|
|
@ -12,6 +12,9 @@ branch=.
|
|||
adbserial=
|
||||
[ "$S" != "" ] && adbserial="-s $S"
|
||||
|
||||
adbhost=
|
||||
[ "$H" != "" ] && adbhost="-H $H"
|
||||
|
||||
model="gemma-3-4b-it-Q4_0.gguf"
|
||||
[ "$M" != "" ] && model="$M"
|
||||
|
||||
|
|
@ -51,7 +54,7 @@ mtmd_backend=
|
|||
|
||||
set -x
|
||||
|
||||
adb $adbserial shell " \
|
||||
adb $adbserial $adbhost shell " \
|
||||
cd $basedir; ulimit -c unlimited; \
|
||||
LD_LIBRARY_PATH=$basedir/$branch/lib \
|
||||
ADSP_LIBRARY_PATH=$basedir/$branch/lib \
|
||||
|
|
|
|||
|
|
@ -12,6 +12,9 @@ branch=.
|
|||
adbserial=
|
||||
[ "$S" != "" ] && adbserial="-s $S"
|
||||
|
||||
adbhost=
|
||||
[ "$H" != "" ] && adbhost="-H $H"
|
||||
|
||||
device="HTP0"
|
||||
[ "$D" != "" ] && device="$D"
|
||||
|
||||
|
|
@@ -19,7 +22,7 @@ verbose=
[ "$V" != "" ] && verbose="GGML_HEXAGON_VERBOSE=$V"

experimental=
[ "$E" != "" ] && experimental="GGML_HEXAGON_EXPERIMENTAL=$V"
[ "$E" != "" ] && experimental="GGML_HEXAGON_EXPERIMENTAL=$E"

sched=
[ "$SCHED" != "" ] && sched="GGML_SCHED_DEBUG=2" cli_opts="$cli_opts -v"

@@ -43,7 +46,7 @@ set -x

tool=$1; shift

adb $adbserial shell " \
adb $adbserial $adbhost shell " \
cd $basedir; ulimit -c unlimited; \
LD_LIBRARY_PATH=$basedir/$branch/lib \
ADSP_LIBRARY_PATH=$basedir/$branch/lib \