hexagon refactor all Ops to use local context struct (#19819)

* hexagon: refactor set/get/sum-rows ops to use local context

* hexagon: refactor ROPE and Softmax Ops to use local context

Improves performance a bit by precomputing things and saving in the context.

* hexagon: refactor activation ops to use local context struct

* hexagon: refactor unary ops to use local context struct and DMA/VTCM

* hexagon: use aligned hvx_scale function

* hexagon: remove unused fields from op_context

* hexagon: rewrite ROPE to use DMA and VTCM scratchpad

* hex-rope: keep N rows in scratchpad (instead of just two)

* hex-rope: introduce rowidx cache

* hex-rope: remove unused fields

* hex-rope: rewrite dma prefetch logic to allow for multi-row fetch/compute

also removes the need for fastdiv.

* hex-rope: minor formatting

* hex-rope: use indices and unroll the loops

* hex-rope: more updates to cleanup rope-block handling

* hexagon: cleanup supported type/dims checks

* hexagon: all reduce funcs replicated across lanes

There is no need to explicitly replicate the first value.

* snapdragon: update adb and windows scripts to use ubatch-size 256

Updated Ops support handles larger ubatches.
This commit is contained in:
Max Krasnyansky 2026-02-23 16:32:14 -08:00 committed by GitHub
parent 5eb0ea32f0
commit 39fb81f875
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 977 additions and 1000 deletions

View File

@ -1749,23 +1749,6 @@ static inline bool ggml_backend_buffer_is_hexagon_repack(const struct ggml_backe
return b->buft->iface.alloc_buffer == ggml_backend_hexagon_repack_buffer_type_alloc_buffer;
}
static bool hex_supported_dims2(const struct ggml_tensor * x, const struct ggml_tensor * y) {
if (x->ne[0] != y->ne[0]) {
return false;
}
if (x->ne[1] != y->ne[1]) {
return false;
}
if (x->ne[2] != y->ne[2]) {
return false;
}
if (x->ne[3] != y->ne[3]) {
return false;
}
return true;
}
static bool ggml_hexagon_supported_flash_attn_ext(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
const struct ggml_tensor * src0 = op->src[0];
const struct ggml_tensor * src1 = op->src[1];
@ -1797,43 +1780,6 @@ static bool ggml_hexagon_supported_flash_attn_ext(const struct ggml_hexagon_sess
return opt_experimental;
}
static bool hex_supported_src0_type(ggml_type t) {
return t == GGML_TYPE_F32;
}
static bool hex_supported_src1_type(ggml_type t) {
return t == GGML_TYPE_F32;
}
static bool hex_supported_src2_type(ggml_type t) {
return t == GGML_TYPE_F32;
}
static bool hex_supported_src1_type2(ggml_type t) {
return t == GGML_TYPE_F16;
}
static bool hex_supported_src1_type3(ggml_type t) {
return t == GGML_TYPE_I32;
}
static bool hex_supported_dst_type(ggml_type t) {
return t == GGML_TYPE_F32;
}
static bool hex_supported_dims(const struct ggml_tensor * x, const struct ggml_tensor * y) {
// TODO: support broadcast for ne[2 and 3]
if (x->ne[0] != y->ne[0]) {
return false;
}
if (x->ne[2] != y->ne[2]) {
return false;
}
if (x->ne[3] != y->ne[3]) {
return false;
}
return true;
}
static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * sess, const struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0];
@ -1919,19 +1865,19 @@ static bool ggml_hexagon_supported_binary(const struct ggml_hexagon_session * se
const struct ggml_tensor * src1 = op->src[1];
const struct ggml_tensor * dst = op;
if (!hex_supported_src0_type(src0->type)) {
if (src0->type != GGML_TYPE_F32) {
return false;
}
if (!hex_supported_src1_type(src1->type)) {
if (src1->type != GGML_TYPE_F32) {
return false;
}
if (!hex_supported_dst_type(dst->type)) {
if (dst->type != GGML_TYPE_F32) {
return false;
}
if (!hex_supported_dims2(src0, dst)) {
if (!ggml_are_same_shape(src0, dst)) {
return false;
}
if (!ggml_can_repeat(src1, src0)) {
if (!ggml_can_repeat(src1, src0) || ggml_is_permuted(src1)) {
return false;
}
@ -1943,16 +1889,16 @@ static bool ggml_hexagon_supported_add_id(const struct ggml_hexagon_session * se
const struct ggml_tensor * src1 = op->src[1];
const struct ggml_tensor * dst = op;
if (!hex_supported_src0_type(src0->type)) {
if (src0->type != GGML_TYPE_F32) {
return false;
}
if (!hex_supported_src1_type(src1->type)) {
if (src1->type != GGML_TYPE_F32) {
return false;
}
if (!hex_supported_dst_type(dst->type)) {
if (dst->type != GGML_TYPE_F32) {
return false;
}
if (!hex_supported_dims2(src0, dst)) {
if (!ggml_are_same_shape(src0, dst)) {
return false;
}
@ -1968,13 +1914,13 @@ static bool ggml_hexagon_supported_unary(const struct ggml_hexagon_session * ses
const struct ggml_tensor * src0 = op->src[0];
const struct ggml_tensor * dst = op;
if (!hex_supported_src0_type(src0->type)) {
if (src0->type != GGML_TYPE_F32) {
return false;
}
if (!hex_supported_dst_type(dst->type)) {
if (dst->type != GGML_TYPE_F32) {
return false;
}
if (!hex_supported_dims2(src0, dst)) {
if (!ggml_are_same_shape(src0, dst)) {
return false;
}
@ -1990,10 +1936,10 @@ static bool ggml_hexagon_supported_sum_rows(const struct ggml_hexagon_session *
const struct ggml_tensor * src0 = op->src[0];
const struct ggml_tensor * dst = op;
if (!hex_supported_src0_type(src0->type)) {
if (src0->type != GGML_TYPE_F32) {
return false;
}
if (!hex_supported_dst_type(dst->type)) {
if (dst->type != GGML_TYPE_F32) {
return false;
}
@ -2011,10 +1957,10 @@ static bool ggml_hexagon_supported_activations(const struct ggml_hexagon_session
const struct ggml_tensor * src1 = op->src[1];
const struct ggml_tensor * dst = op;
if (!hex_supported_src0_type(src0->type)) {
if (src0->type != GGML_TYPE_F32) {
return false;
}
if (!hex_supported_dst_type(dst->type)) {
if (dst->type != GGML_TYPE_F32) {
return false;
}
@ -2023,10 +1969,10 @@ static bool ggml_hexagon_supported_activations(const struct ggml_hexagon_session
}
if (src1) {
if (!hex_supported_src1_type(src1->type)) {
if (src1->type != GGML_TYPE_F32) {
return false;
}
if (!hex_supported_dims2(src0, src1)) {
if (!ggml_are_same_shape(src0, src1)) {
return false;
}
if (!ggml_is_contiguous(src1)) {
@ -2047,15 +1993,15 @@ static bool ggml_hexagon_supported_softmax(const struct ggml_hexagon_session * s
return false; // FIXME: add support for sinks
}
if (!hex_supported_src0_type(src0->type)) {
if (src0->type != GGML_TYPE_F32) {
return false;
}
if (!hex_supported_dst_type(dst->type)) {
if (dst->type != GGML_TYPE_F32) {
return false;
}
if (src1) {
if (!hex_supported_src1_type(src1->type) && !hex_supported_src1_type2(src1->type)) {
if (src1->type != GGML_TYPE_F32 && src1->type != GGML_TYPE_F16) {
return false;
}
if (src0->ne[0] != src1->ne[0]) {
@ -2162,17 +2108,17 @@ static bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess
const struct ggml_tensor * src2 = op->src[2];
const struct ggml_tensor * dst = op;
if (!hex_supported_src0_type(src0->type)) {
if (src0->type != GGML_TYPE_F32) {
return false; // FIXME: add support for GGML_TYPE_F16 for src0
}
if (!hex_supported_dst_type(dst->type)) {
if (dst->type != GGML_TYPE_F32) {
return false;
}
if (!hex_supported_src1_type3(src1->type)) {
if (src1->type != GGML_TYPE_I32) {
return false;
}
if (src2) {
if (!hex_supported_src2_type(src2->type)) {
if (src2->type != GGML_TYPE_F32) {
return false;
}
int n_dims = op_params[1];

View File

@ -69,27 +69,45 @@
const uint32_t nb2 = dst->nb[2]; \
const uint32_t nb3 = dst->nb[3];
static void glu_swiglu_f32_per_thread(const struct htp_tensor * src0,
const struct htp_tensor * src1,
struct htp_tensor * dst,
const int32_t * op_params,
struct htp_spad * src0_spad,
struct htp_spad * src1_spad,
struct htp_spad * dst_spad,
uint32_t nth,
uint32_t ith,
uint32_t src0_nrows_per_thread,
dma_queue * dma_queue) {
struct htp_act_context {
struct htp_ops_context * octx;
// Precomputed values
const uint8_t * data_src0;
const uint8_t * data_src1;
uint8_t * data_dst;
size_t src0_row_size;
size_t src1_row_size;
size_t dst_row_size;
size_t src0_row_size_aligned;
size_t src1_row_size_aligned;
size_t dst_row_size_aligned;
size_t src0_spad_half_size;
size_t src1_spad_half_size;
size_t dst_spad_half_size;
uint32_t block;
uint32_t src0_nrows;
uint32_t src0_nrows_per_thread;
int nc;
};
static void glu_swiglu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
struct htp_act_context * actx = (struct htp_act_context *) data;
const struct htp_tensor * src0 = &actx->octx->src0;
const struct htp_tensor * src1 = &actx->octx->src1;
const struct htp_tensor * dst = &actx->octx->dst;
htp_act_preamble3;
size_t src0_row_size = nb01;
size_t src1_row_size = nb11;
size_t dst_row_size = nb1;
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
size_t src0_row_size = actx->src0_row_size;
size_t src1_row_size = actx->src1_row_size;
size_t dst_row_size = actx->dst_row_size;
const uint32_t src0_nrows = actx->src0_nrows;
const uint32_t src0_nrows_per_thread = actx->src0_nrows_per_thread;
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
@ -101,43 +119,34 @@ static void glu_swiglu_f32_per_thread(const struct htp_tensor * src0,
uint64_t t1, t2;
t1 = HAP_perf_get_qtimer_count();
const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
uint8_t * restrict data_dst = (uint8_t *) dst->data;
const uint8_t * restrict data_src0 = actx->data_src0;
const uint8_t * restrict data_src1 = actx->data_src1;
uint8_t * restrict data_dst = actx->data_dst;
const bool src1_valid = src1->ne[0];
const int nc = (src1_valid) ? ne00 : ne00 / 2;
if (!src1_valid) {
const int32_t swapped = op_params[1];
data_src1 = data_src0;
src1_row_size = src0_row_size;
const int nc = actx->nc;
const size_t nc_in_bytes = nc * SIZEOF_FP32;
data_src0 += swapped ? nc_in_bytes : 0;
data_src1 += swapped ? 0 : nc_in_bytes;
}
const size_t src0_row_size_aligned = actx->src0_row_size_aligned;
const size_t src1_row_size_aligned = actx->src1_row_size_aligned;
const size_t dst_row_size_aligned = actx->dst_row_size_aligned;
const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
const size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN);
const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
uint8_t * restrict src0_spad_data = actx->octx->src0_spad.data + (ith * actx->octx->src0_spad.size_per_thread);
uint8_t * restrict src1_spad_data = actx->octx->src1_spad.data + (ith * actx->octx->src1_spad.size_per_thread);
uint8_t * restrict dst_spad_data = actx->octx->dst_spad.data + (ith * actx->octx->dst_spad.size_per_thread);
uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread);
uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_spad->size_per_thread);
uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread);
size_t src0_spad_half_size = actx->src0_spad_half_size;
size_t src1_spad_half_size = actx->src1_spad_half_size;
size_t dst_spad_half_size = actx->dst_spad_half_size;
// While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0
size_t src0_spad_half_size = src0_spad->size_per_thread / 2;
size_t src1_spad_half_size = src1_spad->size_per_thread / 2;
size_t dst_spad_half_size = dst_spad->size_per_thread / 2;
const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block
const int BLOCK = actx->block;
if (BLOCK == 0) {
FARF(ERROR,
"swiglu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
src0_spad->size_per_thread, src0_row_size_aligned);
actx->octx->src0_spad.size_per_thread, src0_row_size_aligned);
return;
}
dma_queue * dma_queue = actx->octx->ctx->dma[ith];
// See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379
for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
@ -196,27 +205,22 @@ static void glu_swiglu_f32_per_thread(const struct htp_tensor * src0,
(unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
}
static void glu_swiglu_oai_f32_per_thread(const struct htp_tensor * src0,
const struct htp_tensor * src1,
struct htp_tensor * dst,
const int32_t * op_params,
struct htp_spad * src0_spad,
struct htp_spad * src1_spad,
struct htp_spad * dst_spad,
uint32_t nth,
uint32_t ith,
uint32_t src0_nrows_per_thread,
dma_queue * dma_queue) {
static void glu_swiglu_oai_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
struct htp_act_context * actx = (struct htp_act_context *) data;
const struct htp_tensor * src0 = &actx->octx->src0;
const struct htp_tensor * src1 = &actx->octx->src1;
const struct htp_tensor * dst = &actx->octx->dst;
htp_act_preamble3;
uint64_t t1, t2;
t1 = HAP_perf_get_qtimer_count();
size_t src0_row_size = nb01;
size_t src1_row_size = nb11;
size_t dst_row_size = nb1;
size_t src0_row_size = actx->src0_row_size;
size_t src1_row_size = actx->src1_row_size;
size_t dst_row_size = actx->dst_row_size;
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
const uint32_t src0_nrows = actx->src0_nrows;
const uint32_t src0_nrows_per_thread = actx->src0_nrows_per_thread;
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
@ -226,45 +230,36 @@ static void glu_swiglu_oai_f32_per_thread(const struct htp_tensor * src0,
return;
}
const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
uint8_t * restrict data_dst = (uint8_t *) dst->data;
const uint8_t * restrict data_src0 = actx->data_src0;
const uint8_t * restrict data_src1 = actx->data_src1;
uint8_t * restrict data_dst = actx->data_dst;
const bool src1_valid = src1->ne[0];
const int nc = (src1_valid) ? ne00 : ne00 / 2;
if (!src1_valid) {
const int32_t swapped = op_params[1];
data_src1 = data_src0;
src1_row_size = src0_row_size;
const int nc = actx->nc;
const size_t nc_in_bytes = nc * SIZEOF_FP32;
data_src0 += swapped ? nc_in_bytes : 0;
data_src1 += swapped ? 0 : nc_in_bytes;
}
const size_t src0_row_size_aligned = actx->src0_row_size_aligned;
const size_t src1_row_size_aligned = actx->src1_row_size_aligned;
const size_t dst_row_size_aligned = actx->dst_row_size_aligned;
const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
const size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN);
const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
uint8_t * restrict src0_spad_data = actx->octx->src0_spad.data + (ith * actx->octx->src0_spad.size_per_thread);
uint8_t * restrict src1_spad_data = actx->octx->src1_spad.data + (ith * actx->octx->src1_spad.size_per_thread);
uint8_t * restrict dst_spad_data = actx->octx->dst_spad.data + (ith * actx->octx->dst_spad.size_per_thread);
uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread);
uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_spad->size_per_thread);
uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread);
size_t src0_spad_half_size = actx->src0_spad_half_size;
size_t src1_spad_half_size = actx->src1_spad_half_size;
size_t dst_spad_half_size = actx->dst_spad_half_size;
// While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0
size_t src0_spad_half_size = src0_spad->size_per_thread / 2;
size_t src1_spad_half_size = src1_spad->size_per_thread / 2;
size_t dst_spad_half_size = dst_spad->size_per_thread / 2;
const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block
const int BLOCK = actx->block;
if (BLOCK == 0) {
FARF(ERROR,
"swiglu-oai-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least "
"%zu\n",
src0_spad->size_per_thread, src0_row_size_aligned);
actx->octx->src0_spad.size_per_thread, src0_row_size_aligned);
return;
}
const float alpha = ((const float *) (op_params))[2];
const float limit = ((const float *) (op_params))[3];
const float alpha = ((const float *) (actx->octx->op_params))[2];
const float limit = ((const float *) (actx->octx->op_params))[3];
dma_queue * dma_queue = actx->octx->ctx->dma[ith];
// See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379
for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
@ -335,26 +330,22 @@ static void glu_swiglu_oai_f32_per_thread(const struct htp_tensor * src0,
}
static void unary_gelu_f32_per_thread(const struct htp_tensor * src0,
struct htp_tensor * dst,
const int32_t * op_params,
struct htp_spad * src0_spad,
struct htp_spad * dst_spad,
uint32_t nth,
uint32_t ith,
uint32_t src0_nrows_per_thread,
dma_queue * dma_queue) {
static void unary_gelu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
struct htp_act_context * actx = (struct htp_act_context *) data;
const struct htp_tensor * src0 = &actx->octx->src0;
const struct htp_tensor * dst = &actx->octx->dst;
htp_act_preamble2;
uint64_t t1, t2;
t1 = HAP_perf_get_qtimer_count();
const size_t src0_row_size = nb01;
const size_t dst_row_size = nb1;
const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
const size_t src0_row_size = actx->src0_row_size;
const size_t dst_row_size = actx->dst_row_size;
const size_t src0_row_size_aligned = actx->src0_row_size_aligned;
const size_t dst_row_size_aligned = actx->dst_row_size_aligned;
const uint32_t src0_nrows = ne01 * ne02 * ne03;
const uint32_t src0_nrows = actx->src0_nrows;
const uint32_t src0_nrows_per_thread = actx->src0_nrows_per_thread;
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
@ -364,25 +355,29 @@ static void unary_gelu_f32_per_thread(const struct htp_tensor * src0,
return;
}
const uint8_t * data_src0 = (const uint8_t *) src0->data;
uint8_t * data_dst = (uint8_t *) dst->data;
const uint8_t * data_src0 = actx->data_src0;
uint8_t * data_dst = actx->data_dst;
uint8_t * src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread);
uint8_t * dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread);
// nc/ne0 matches.
const int ne0_val = actx->nc; // == dst->ne[0]
// While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0
size_t src0_spad_half_size = src0_spad->size_per_thread / 2;
size_t dst_spad_half_size = dst_spad->size_per_thread / 2;
uint8_t * src0_spad_data = actx->octx->src0_spad.data + (ith * actx->octx->src0_spad.size_per_thread);
uint8_t * dst_spad_data = actx->octx->dst_spad.data + (ith * actx->octx->dst_spad.size_per_thread);
size_t src0_spad_half_size = actx->src0_spad_half_size;
size_t dst_spad_half_size = actx->dst_spad_half_size;
// In gelu = x*sigmoid(x*1.702)
const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block
const int BLOCK = actx->block;
if (BLOCK == 0) {
FARF(ERROR, "gelu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
src0_spad->size_per_thread, src0_row_size_aligned);
actx->octx->src0_spad.size_per_thread, src0_row_size_aligned);
return;
}
dma_queue * dma_queue = actx->octx->ctx->dma[ith];
// See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379
for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
@ -408,9 +403,9 @@ static void unary_gelu_f32_per_thread(const struct htp_tensor * src0,
float* dst_spad_ptr = dst_spad + ib * (dst_row_size_aligned / sizeof(float));
// gelu = x * sigmoid(1.702 * x) // current implementation
hvx_mul_scalar_f32((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (float) 1.702, ne0);
hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0);
hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0);
hvx_mul_scalar_f32((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (float) 1.702, ne0_val);
hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0_val);
hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0_val);
}
dma_queue_push_vtcm_to_ddr(dma_queue,
@ -435,34 +430,23 @@ static void unary_gelu_f32_per_thread(const struct htp_tensor * src0,
ne03, src0_start_row, src0_end_row, ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
}
static void unary_gelu_f32(unsigned int n, unsigned int i, void * data) {
struct htp_ops_context * octx = (struct htp_ops_context *) data;
unary_gelu_f32_per_thread(&octx->src0, &octx->dst, octx->op_params, &octx->src0_spad, &octx->dst_spad, n, i,
octx->src0_nrows_per_thread, octx->ctx->dma[i]);
}
static void unary_silu_f32_per_thread(const struct htp_tensor * src0,
struct htp_tensor * dst,
const int32_t * op_params,
struct htp_spad * src0_spad,
struct htp_spad * dst_spad,
uint32_t nth,
uint32_t ith,
uint32_t src0_nrows_per_thread,
dma_queue * dma_queue) {
static void unary_silu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
struct htp_act_context * actx = (struct htp_act_context *) data;
const struct htp_tensor * src0 = &actx->octx->src0;
const struct htp_tensor * dst = &actx->octx->dst;
htp_act_preamble2;
uint64_t t1, t2;
t1 = HAP_perf_get_qtimer_count();
const size_t src0_row_size = nb01;
const size_t dst_row_size = nb1;
const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
const size_t src0_row_size = actx->src0_row_size;
const size_t dst_row_size = actx->dst_row_size;
const size_t src0_row_size_aligned = actx->src0_row_size_aligned;
const size_t dst_row_size_aligned = actx->dst_row_size_aligned;
const uint32_t src0_nrows = ne01 * ne02 * ne03;
const uint32_t src0_nrows = actx->src0_nrows;
const uint32_t src0_nrows_per_thread = actx->src0_nrows_per_thread;
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
@ -472,24 +456,27 @@ static void unary_silu_f32_per_thread(const struct htp_tensor * src0,
return;
}
const uint8_t * data_src0 = (const uint8_t *) src0->data;
uint8_t * data_dst = (uint8_t *) dst->data;
const uint8_t * data_src0 = actx->data_src0;
uint8_t * data_dst = actx->data_dst;
uint8_t * src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread);
uint8_t * dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread);
const int ne0_val = actx->nc; // == dst->ne[0]
// While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0
size_t src0_spad_half_size = src0_spad->size_per_thread / 2;
size_t dst_spad_half_size = dst_spad->size_per_thread / 2;
uint8_t * src0_spad_data = actx->octx->src0_spad.data + (ith * actx->octx->src0_spad.size_per_thread);
uint8_t * dst_spad_data = actx->octx->dst_spad.data + (ith * actx->octx->dst_spad.size_per_thread);
const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block
size_t src0_spad_half_size = actx->src0_spad_half_size;
size_t dst_spad_half_size = actx->dst_spad_half_size;
const int BLOCK = actx->block;
if (BLOCK == 0) {
FARF(ERROR, "silu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
src0_spad->size_per_thread, src0_row_size_aligned);
actx->octx->src0_spad.size_per_thread, src0_row_size_aligned);
return;
}
dma_queue * dma_queue = actx->octx->ctx->dma[ith];
// See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379
for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
@ -515,8 +502,8 @@ static void unary_silu_f32_per_thread(const struct htp_tensor * src0,
float* dst_spad_ptr = dst_spad + ib * (dst_row_size_aligned / sizeof(float));
// silu = x * sigmoid(x)
hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, ne0);
hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0);
hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, ne0_val);
hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0_val);
}
dma_queue_push_vtcm_to_ddr(dma_queue,
@ -544,27 +531,22 @@ static void unary_silu_f32_per_thread(const struct htp_tensor * src0,
static const float GELU_COEF_A = 0.044715f;
static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
static void glu_geglu_f32_per_thread(const struct htp_tensor * src0,
const struct htp_tensor * src1,
struct htp_tensor * dst,
const int32_t * op_params,
struct htp_spad * src0_spad,
struct htp_spad * src1_spad,
struct htp_spad * dst_spad,
uint32_t nth,
uint32_t ith,
uint32_t src0_nrows_per_thread,
dma_queue * dma_queue) {
static void glu_geglu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
struct htp_act_context * actx = (struct htp_act_context *) data;
const struct htp_tensor * src0 = &actx->octx->src0;
const struct htp_tensor * src1 = &actx->octx->src1;
const struct htp_tensor * dst = &actx->octx->dst;
htp_act_preamble3;
size_t src0_row_size = nb01;
size_t src1_row_size = nb11;
size_t dst_row_size = nb1;
size_t src0_row_size = actx->src0_row_size;
size_t src1_row_size = actx->src1_row_size;
size_t dst_row_size = actx->dst_row_size;
uint64_t t1, t2;
t1 = HAP_perf_get_qtimer_count();
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
const uint32_t src0_nrows = actx->src0_nrows;
const uint32_t src0_nrows_per_thread = actx->src0_nrows_per_thread;
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
@ -574,43 +556,34 @@ static void glu_geglu_f32_per_thread(const struct htp_tensor * src0,
return;
}
const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
uint8_t * restrict data_dst = (uint8_t *) dst->data;
const uint8_t * restrict data_src0 = actx->data_src0;
const uint8_t * restrict data_src1 = actx->data_src1;
uint8_t * restrict data_dst = actx->data_dst;
const bool src1_valid = src1->ne[0];
const int nc = (src1_valid) ? ne00 : ne00 / 2;
if (!src1_valid) {
const int32_t swapped = op_params[1];
data_src1 = data_src0;
src1_row_size = src0_row_size;
const int nc = actx->nc;
const size_t nc_in_bytes = nc * SIZEOF_FP32;
data_src0 += swapped ? nc_in_bytes : 0;
data_src1 += swapped ? 0 : nc_in_bytes;
}
const size_t src0_row_size_aligned = actx->src0_row_size_aligned;
const size_t src1_row_size_aligned = actx->src1_row_size_aligned;
const size_t dst_row_size_aligned = actx->dst_row_size_aligned;
const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
const size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN);
const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
uint8_t * restrict src0_spad_data = actx->octx->src0_spad.data + (ith * actx->octx->src0_spad.size_per_thread);
uint8_t * restrict src1_spad_data = actx->octx->src1_spad.data + (ith * actx->octx->src1_spad.size_per_thread);
uint8_t * restrict dst_spad_data = actx->octx->dst_spad.data + (ith * actx->octx->dst_spad.size_per_thread);
uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread);
uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_spad->size_per_thread);
uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread);
size_t src0_spad_half_size = actx->src0_spad_half_size;
size_t src1_spad_half_size = actx->src1_spad_half_size;
size_t dst_spad_half_size = actx->dst_spad_half_size;
// While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0
size_t src0_spad_half_size = src0_spad->size_per_thread / 2;
size_t src1_spad_half_size = src1_spad->size_per_thread / 2;
size_t dst_spad_half_size = dst_spad->size_per_thread / 2;
const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block
const int BLOCK = actx->block;
if (BLOCK == 0) {
FARF(ERROR,
"geglu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
src0_spad->size_per_thread, src0_row_size_aligned);
actx->octx->src0_spad.size_per_thread, src0_row_size_aligned);
return;
}
dma_queue * dma_queue = actx->octx->ctx->dma[ith];
// See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379
for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
@ -678,33 +651,7 @@ static void glu_geglu_f32_per_thread(const struct htp_tensor * src0,
(unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
}
static void unary_silu_f32(unsigned int n, unsigned int i, void * data) {
struct htp_ops_context * octx = (struct htp_ops_context *) data;
unary_silu_f32_per_thread(&octx->src0, &octx->dst, octx->op_params, &octx->src0_spad, &octx->dst_spad, n, i,
octx->src0_nrows_per_thread, octx->ctx->dma[i]);
}
static void glu_swiglu_f32(unsigned int n, unsigned int i, void * data) {
struct htp_ops_context * octx = (struct htp_ops_context *) data;
glu_swiglu_f32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad,
&octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
}
static void glu_swiglu_oai_f32(unsigned int n, unsigned int i, void * data) {
struct htp_ops_context * octx = (struct htp_ops_context *) data;
glu_swiglu_oai_f32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad,
&octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
}
static void glu_geglu_f32(unsigned int n, unsigned int i, void * data) {
struct htp_ops_context * octx = (struct htp_ops_context *) data;
glu_geglu_f32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad,
&octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
}
static int execute_op_activations_f32(struct htp_ops_context * octx) {
int err = HTP_STATUS_OK;
const struct htp_tensor * src0 = &octx->src0;
const struct htp_tensor * src1 = &octx->src1;
struct htp_tensor * dst = &octx->dst;
@ -719,26 +666,26 @@ static int execute_op_activations_f32(struct htp_ops_context * octx) {
switch (octx->op) {
case HTP_OP_UNARY_SILU:
act_op_func = unary_silu_f32;
act_op_func = (worker_callback_t)unary_silu_f32_per_thread;
op_type = "silu-f32";
break;
case HTP_OP_GLU_SWIGLU:
act_op_func = glu_swiglu_f32;
act_op_func = (worker_callback_t)glu_swiglu_f32_per_thread;
op_type = "swiglu-f32";
break;
case HTP_OP_GLU_SWIGLU_OAI:
act_op_func = glu_swiglu_oai_f32;
act_op_func = (worker_callback_t)glu_swiglu_oai_f32_per_thread;
op_type = "swiglu-oai-f32";
break;
case HTP_OP_UNARY_GELU:
act_op_func = unary_gelu_f32;
act_op_func = (worker_callback_t)unary_gelu_f32_per_thread;
op_type = "gelu-f32";
break;
case HTP_OP_GLU_GEGLU:
act_op_func = glu_geglu_f32;
act_op_func = (worker_callback_t)glu_geglu_f32_per_thread;
op_type = "geglu-f32";
break;
default:
@ -797,13 +744,58 @@ static int execute_op_activations_f32(struct htp_ops_context * octx) {
octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size);
}
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
uint32_t n_jobs = MIN(n_threads, src0_nrows);
octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
worker_pool_run_func(octx->ctx->worker_pool, act_op_func, octx, n_jobs);
if ((octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
return HTP_STATUS_OK;
}
return err;
uint32_t n_jobs = MIN(n_threads, src0_nrows);
// Prepare context
struct htp_act_context actx;
actx.octx = octx;
actx.src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
actx.src0_row_size = src0_row_size;
actx.src1_row_size = src1_row_size;
actx.dst_row_size = dst_row_size;
actx.src0_row_size_aligned = src0_row_size_aligned;
actx.src1_row_size_aligned = src1_row_size_aligned;
actx.dst_row_size_aligned = dst_row_size_aligned;
actx.src0_spad_half_size = octx->src0_spad.size_per_thread / 2;
actx.src1_spad_half_size = octx->src1_spad.size_per_thread / 2;
actx.dst_spad_half_size = octx->dst_spad.size_per_thread / 2;
actx.block = actx.src0_spad_half_size / actx.src0_row_size_aligned;
actx.src0_nrows = src0_nrows;
actx.nc = dst->ne[0];
// Pointers and GLU logic
const uint8_t * data_src0 = (const uint8_t *) src0->data;
const uint8_t * data_src1 = (const uint8_t *) src1->data;
if (!src1_valid && (octx->op == HTP_OP_GLU_SWIGLU || octx->op == HTP_OP_GLU_SWIGLU_OAI || octx->op == HTP_OP_GLU_GEGLU)) {
const int32_t swapped = octx->op_params[1];
data_src1 = data_src0;
actx.src1_row_size = actx.src0_row_size;
size_t nc_in_bytes = actx.nc * SIZEOF_FP32;
if (swapped) {
data_src0 += nc_in_bytes;
} else {
data_src1 += nc_in_bytes;
}
}
actx.data_src0 = data_src0;
actx.data_src1 = data_src1;
actx.data_dst = (uint8_t *) dst->data;
worker_pool_run_func(octx->ctx->worker_pool, act_op_func, &actx, n_jobs);
return HTP_STATUS_OK;
}
int op_activations(struct htp_ops_context * octx) {

View File

@ -15,6 +15,13 @@
#include "htp-ops.h"
#include "hvx-utils.h"
struct get_rows_context {
struct htp_ops_context * octx;
uint32_t src1_nrows_per_thread;
struct fastdiv_values get_rows_div_ne10;
struct fastdiv_values get_rows_div_ne10_ne11;
};
#define get_rows_preamble \
const uint32_t ne00 = octx->src0.ne[0]; \
const uint32_t ne01 = octx->src0.ne[1]; \
@ -39,20 +46,22 @@
\
const uint32_t nr = ne10 * ne11 * ne12;
static int get_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth, const int ith) {
static void get_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *data) {
struct get_rows_context * grctx = (struct get_rows_context *)data;
struct htp_ops_context * octx = grctx->octx;
get_rows_preamble;
// parallelize by src1 elements (which correspond to dst rows)
const uint32_t dr = octx->src1_nrows_per_thread;
const uint32_t dr = grctx->src1_nrows_per_thread;
const uint32_t ir0 = dr * ith;
const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr;
const bool is_i32 = (octx->src1.type == HTP_TYPE_I32);
for (uint32_t i = ir0; i < ir1; ++i) {
const uint32_t i12 = fastdiv(i, &octx->get_rows_div_ne10_ne11);
const uint32_t i12 = fastdiv(i, &grctx->get_rows_div_ne10_ne11);
const uint32_t rem = i - i12 * ne11 * ne10;
const uint32_t i11 = fastdiv(rem, &octx->get_rows_div_ne10);
const uint32_t i11 = fastdiv(rem, &grctx->get_rows_div_ne10);
const uint32_t i10 = rem - i11 * ne10;
const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12;
@ -68,12 +77,6 @@ static int get_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth,
const uintptr_t dst_ptr = octx->dst.data + i10*nb1 + i11*nb2 + i12*nb3;
hvx_copy_f32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, ne00);
}
return HTP_STATUS_OK;
}
static void get_rows_work_f32_f32(unsigned int n, unsigned int i, void *data) {
get_rows_thread_f32_f32((struct htp_ops_context *) data, n, i);
}
int op_get_rows(struct htp_ops_context * octx) {
@ -95,12 +98,14 @@ int op_get_rows(struct htp_ops_context * octx) {
return HTP_STATUS_OK;
}
octx->get_rows_div_ne10 = init_fastdiv_values(octx->src1.ne[0]);
octx->get_rows_div_ne10_ne11 = init_fastdiv_values(octx->src1.ne[0] * octx->src1.ne[1]);
struct get_rows_context grctx;
grctx.octx = octx;
grctx.get_rows_div_ne10 = init_fastdiv_values(octx->src1.ne[0]);
grctx.get_rows_div_ne10_ne11 = init_fastdiv_values(octx->src1.ne[0] * octx->src1.ne[1]);
const uint32_t n_jobs = MIN(nr, octx->n_threads);
octx->src1_nrows_per_thread = (nr + n_jobs - 1) / n_jobs;
grctx.src1_nrows_per_thread = (nr + n_jobs - 1) / n_jobs;
worker_pool_run_func(octx->ctx->worker_pool, get_rows_work_f32_f32, octx, n_jobs);
worker_pool_run_func(octx->ctx->worker_pool, get_rows_thread_f32_f32, &grctx, n_jobs);
return HTP_STATUS_OK;
}

View File

@ -102,7 +102,7 @@ static inline bool dma_queue_push(dma_queue * q,
dmlink(q->tail, desc);
q->tail = desc;
// FARF(ERROR, "dma-push: i %u len %u dst %p src %p\n", q->push_idx, len, dst, src);
// FARF(ERROR, "dma-push: i %u width %u nrows %d dst %p src %p\n", q->push_idx, width, nrows, dptr.dst, dptr.src);
q->push_idx = (q->push_idx + 1) & q->idx_mask;
return true;
}
@ -144,11 +144,37 @@ static inline dma_ptr dma_queue_pop(dma_queue * q) {
dptr = q->dptr[q->pop_idx];
// FARF(ERROR, "dma-pop: i %u dst %p\n", q->pop_idx, dst);
// FARF(ERROR, "dma-pop: i %u dst %p src %p\n", q->pop_idx, dptr.dst, dptr.src);
q->pop_idx = (q->pop_idx + 1) & q->idx_mask;
return dptr;
}
static inline dma_ptr dma_queue_pop_nowait(dma_queue * q) {
dma_ptr dptr = { NULL };
if (q->push_idx == q->pop_idx) {
return dptr;
}
dptr = q->dptr[q->pop_idx];
// FARF(ERROR, "dma-pop-nowait: i %u dst %p src %p\n", q->pop_idx, dptr.dst, dptr.src);
q->pop_idx = (q->pop_idx + 1) & q->idx_mask;
return dptr;
}
static inline bool dma_queue_empty(dma_queue * q) {
return q->push_idx == q->pop_idx;
}
static inline uint32_t dma_queue_depth(dma_queue * q) {
return (q->push_idx - q->pop_idx) & q->idx_mask;
}
static inline uint32_t dma_queue_capacity(dma_queue * q) {
return q->capacity;
}
#ifdef __cplusplus
} // extern "C"
#endif

View File

@ -44,32 +44,6 @@ struct htp_ops_context {
uint32_t src0_nrows_per_thread;
uint32_t src1_nrows_per_thread;
struct fastdiv_values src0_div1; // fastdiv values for ne1
struct fastdiv_values src0_div2; // fastdiv values for ne2
struct fastdiv_values src0_div3; // fastdiv values for ne3
struct fastdiv_values src0_div21; // fastdiv values for ne2 * ne1
struct fastdiv_values src1_div1; // fastdiv values for ne1
struct fastdiv_values src1_div2; // fastdiv values for ne2
struct fastdiv_values src1_div3; // fastdiv values for ne3
struct fastdiv_values src1_div21; // fastdiv values for ne2 * ne1
struct fastdiv_values src3_div1; // fastdiv values for ne1
struct fastdiv_values src3_div2; // fastdiv values for ne2
struct fastdiv_values src3_div3; // fastdiv values for ne3
struct fastdiv_values src3_div21; // fastdiv values for ne2 * ne1
struct fastdiv_values broadcast_rk2;
struct fastdiv_values broadcast_rk3;
struct fastdiv_values broadcast_rv2;
struct fastdiv_values broadcast_rv3;
struct fastdiv_values set_rows_div_ne12; // fastdiv values for ne12
struct fastdiv_values set_rows_div_ne11; // fastdiv values for ne11
struct fastdiv_values get_rows_div_ne10; // fastdiv values for ne10
struct fastdiv_values get_rows_div_ne10_ne11; // fastdiv values for ne10 * ne11
uint32_t flags;
};

View File

@ -49,62 +49,6 @@ struct htp_matmul_context {
struct fastdiv_values mm_div_r3;
};
// vdelta control to replicate first 4x fp32 values across lanes
static const uint8_t __attribute__((aligned(128))) repl_4x_f32[128] = {
0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10,
0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20,
0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04,
0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40,
0x44, 0x44, 0x44, 0x44, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04,
0x04, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04,
0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10,
};
// vdelta control to replicate and interleave first 8x fp32 values across lanes
static const uint8_t __attribute__((aligned(128))) repl_interleave_8x_f32[128] = {
0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x00, 0x00, 0x00,
0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20,
0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04,
0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40,
0x44, 0x44, 0x44, 0x44, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40, 0x44, 0x44, 0x44,
0x44, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04,
0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20,
};
// vdelta control to replicate first fp32 value across all elements
static const uint8_t __attribute__((aligned(128))) repl_1x_f32[128] = {
0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10,
0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04,
0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08,
0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08,
0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04,
0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10,
0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
};
// vdelta control to replicate first fp16 value across all elements
static const uint8_t __attribute__((aligned(128))) repl_1x_f16[128] = {
0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x10, 0x10, 0x02,
0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x20, 0x20, 0x02, 0x02, 0x04, 0x04,
0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08,
0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x40, 0x40, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02,
0x04, 0x04, 0x02, 0x02, 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02,
0x02, 0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x10, 0x10,
0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
};
// vdelta control to replicate first fp16 value across all elements
static const uint8_t __attribute__((aligned(128))) repl_2x_f16[128] = {
0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
};
// vdelta control to expand first 32 e8m0 values into 32 uint32 elements
static const uint8_t __attribute__((aligned(128))) expand_x32_e8m0[128] = {
0x00, 0x00, 0x00, 0x00, 0x01, 0x04, 0x00, 0x00, 0x02, 0x00, 0x08, 0x08, 0x01, 0x02, 0x00, 0x04, 0x04, 0x00, 0x00,
@ -2067,10 +2011,10 @@ static inline void quantize_block_f32_q8x1(float * restrict x, uint8_t * restric
HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero); // 32 elements
// Convert to QF32
HVX_Vector vmax0_qf = Q6_Vqf32_vsub_VsfVsf(vmax0_sf, zero);
HVX_Vector vmax1_qf = Q6_Vqf32_vsub_VsfVsf(vmax1_sf, zero);
HVX_Vector vmax2_qf = Q6_Vqf32_vsub_VsfVsf(vmax2_sf, zero);
HVX_Vector vmax3_qf = Q6_Vqf32_vsub_VsfVsf(vmax3_sf, zero);
HVX_Vector vmax0_qf = Q6_Vqf32_vsub_VsfVsf(vmax0_sf, zero); // replicated over all lanes
HVX_Vector vmax1_qf = Q6_Vqf32_vsub_VsfVsf(vmax1_sf, zero); // replicated over all lanes
HVX_Vector vmax2_qf = Q6_Vqf32_vsub_VsfVsf(vmax2_sf, zero); // replicated over all lanes
HVX_Vector vmax3_qf = Q6_Vqf32_vsub_VsfVsf(vmax3_sf, zero); // replicated over all lanes
// Combine and convert to fp16
HVX_Vector vmax01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax1_qf, vmax0_qf)));
@ -2080,11 +2024,6 @@ static inline void quantize_block_f32_q8x1(float * restrict x, uint8_t * restric
HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf)));
HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf)));
// Replicate first fp16 scale across all lanes
HVX_Vector ctrl = *(const HVX_Vector *) repl_2x_f16;
vmax01_hf = Q6_V_vdelta_VV(vmax01_hf, ctrl);
vmax23_hf = Q6_V_vdelta_VV(vmax23_hf, ctrl);
HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
HVX_Vector vd01_hf = Q6_Vhf_equals_Vqf16(vd01_qf16);
@ -2130,13 +2069,8 @@ static inline void quantize_block_f32_q8x2(float * restrict x, uint8_t * restric
HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf)));
// Compute max and scale
HVX_Vector vmax01_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx01_hf));
HVX_Vector vmax23_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx23_hf));
// Replicate first fp16 scale across all lanes
HVX_Vector ctrl = *(const HVX_Vector *) repl_1x_f16;
vmax01_hf = Q6_V_vdelta_VV(vmax01_hf, ctrl);
vmax23_hf = Q6_V_vdelta_VV(vmax23_hf, ctrl);
HVX_Vector vmax01_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx01_hf)); // replicated over all lanes
HVX_Vector vmax23_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx23_hf)); // replicated over all lanes
HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
@ -2179,11 +2113,7 @@ static inline void quantize_block_f32_q8x4(float * restrict x, uint8_t * restric
// Compute max and scale
HVX_Vector vmax_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx01_hf));
vmax_hf = hvx_vec_reduce_max2_f16(hvx_vec_abs_f16(vx23_hf), vmax_hf);
// Replicate first fp16 scale across all lanes
HVX_Vector ctrl = *(const HVX_Vector *) repl_1x_f16;
vmax_hf = Q6_V_vdelta_VV(vmax_hf, ctrl);
vmax_hf = hvx_vec_reduce_max2_f16(hvx_vec_abs_f16(vx23_hf), vmax_hf); // replicated over all lanes
HVX_Vector vd_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
HVX_Vector vd_hf = Q6_Vhf_equals_Vqf16(vd_qf16);

View File

@ -10,6 +10,7 @@
#include "hex-dma.h"
#include "hvx-utils.h"
#include "hex-fastdiv.h"
#define GGML_COMMON_DECL_C
#include "ggml-common.h"
@ -21,6 +22,9 @@
#define HTP_ROPE_TYPE_NORMAL 0
#define HTP_ROPE_TYPE_NEOX 2
#define HTP_ROPE_SPAD_NROWS 16
#define HTP_ROPE_SPAD_BLOCK (HTP_ROPE_SPAD_NROWS/2)
#define htp_rope_preamble \
const uint32_t ne00 = src0->ne[0]; \
const uint32_t ne01 = src0->ne[1]; \
@ -42,7 +46,7 @@
const uint32_t nb2 = dst->nb[2]; \
const uint32_t nb3 = dst->nb[3];
struct rope_th_ctx {
struct htp_rope_context {
int32_t n_dims;
int32_t mode;
int32_t n_ctx_orig;
@ -57,7 +61,19 @@ struct rope_th_ctx {
float theta_scale;
float corr_dims[2];
uint32_t src0_nrows_per_thread;
size_t spad_stride;
struct htp_ops_context * octx;
size_t src0_row_size;
size_t dst_row_size;
size_t src0_row_size_aligned;
size_t dst_row_size_aligned;
size_t theta_cache_offset;
uint32_t src0_nrows;
uint64_t t_start;
};
static float rope_yarn_ramp(const float low, const float high, const int i0) {
@ -117,64 +133,23 @@ static void rope_corr_dims(int n_dims,
dims[1] = MIN(n_dims - 1, end);
}
static void init_rope_ctx(struct rope_th_ctx * rope_ctx, struct htp_ops_context * octx) {
memset(rope_ctx, 0, sizeof(struct rope_th_ctx));
static inline void hvx_rope_neox_f32_aa(float * restrict dst, const float * restrict src0, uint32_t ne, const float * restrict theta_cache) {
const HVX_Vector * restrict vsrc = (const HVX_Vector *) src0;
const HVX_Vector * restrict vtheta = (const HVX_Vector *) theta_cache;
HVX_Vector * restrict vdst = (HVX_Vector *) dst;
const int32_t * op_params = &octx->op_params[0];
uint32_t nvec = (ne / (VLEN_FP32 * 2) * 2); // 2 vecs per loop, step of 2
rope_ctx->n_dims = ((const int32_t *) op_params)[1];
rope_ctx->mode = ((const int32_t *) op_params)[2];
rope_ctx->n_ctx_orig = ((const int32_t *) op_params)[4];
uint32_t he = ne / 2; // half_dims offset in elements
uint32_t hv = he / VLEN_FP32; // half_dims offset in vectors
memcpy(&rope_ctx->freq_base, (int32_t *) op_params + 5, sizeof(float));
memcpy(&rope_ctx->freq_scale, (int32_t *) op_params + 6, sizeof(float));
memcpy(&rope_ctx->ext_factor, (int32_t *) op_params + 7, sizeof(float));
memcpy(&rope_ctx->attn_factor, (int32_t *) op_params + 8, sizeof(float));
memcpy(&rope_ctx->beta_fast, (int32_t *) op_params + 9, sizeof(float));
memcpy(&rope_ctx->beta_slow, (int32_t *) op_params + 10, sizeof(float));
memcpy(&rope_ctx->sections, (int32_t *) op_params + 11, sizeof(int) * 4);
#pragma unroll(2)
for (uint32_t i = 0; i < nvec; i += 2) {
HVX_Vector v0 = vsrc[i/2+0];
HVX_Vector v1 = vsrc[i/2+hv];
rope_ctx->theta_scale = powf(rope_ctx->freq_base, -2.0f / rope_ctx->n_dims);
rope_corr_dims(rope_ctx->n_dims, rope_ctx->n_ctx_orig, rope_ctx->freq_base, rope_ctx->beta_fast,
rope_ctx->beta_slow, rope_ctx->corr_dims);
rope_ctx->octx = octx;
FARF(HIGH, "rope-f32 n_dims:%d, ext_factor:%.6f, theta_scale:%.6f, attn_factor:%.6f\n", rope_ctx->n_dims,
rope_ctx->ext_factor, rope_ctx->theta_scale, rope_ctx->attn_factor);
}
static void hvx_calc_rope_neox_f32(const float * restrict src0,
float * restrict dst,
const int num_elems,
const float * restrict theta_cache) {
// for (int i = 0; i < num_elems; i += 2) {
//const float cos_theta = theta_cache[i + 0];
//const float sin_theta = theta_cache[i + 1];
//const float x0 = src[0];
//const float x1 = src[num_elems/2];
//dst[0] = x0*cos_theta - x1*sin_theta;
//dst[num_elems/2] = x0*sin_theta + x1*cos_theta;
//src += 1;
//dst += 1;
// }
const uint8_t * restrict src0_curr = (const uint8_t *) src0;
const uint8_t * restrict theta_curr = (const uint8_t *) theta_cache;
uint8_t * restrict dst_curr = (uint8_t *) dst;
int step_of_1 = num_elems >> 6; // 6 because we process two vectors at once
int half_size = (sizeof(float) * (num_elems / 2));
for (int i = 0; i < step_of_1; i++) {
HVX_Vector v0 = *(HVX_Vector *) src0_curr;
HVX_Vector v1 = *(HVX_Vector *) (src0_curr + half_size);
HVX_Vector v2 = *(HVX_Vector *) theta_curr;
HVX_Vector v3 = *(HVX_Vector *) (theta_curr + VLEN);
HVX_Vector v2 = vtheta[i+0];
HVX_Vector v3 = vtheta[i+1];
HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4); // vcos_sin[0] = cos_theta, vcos_sin[1] = sin_theta
@ -186,45 +161,34 @@ static void hvx_calc_rope_neox_f32(const float * restrict src0,
HVX_Vector v4 = Q6_Vqf32_vsub_Vqf32Vqf32(vx0_c, vx1_s);
HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32(vx0_s, vx1_c);
*(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v4);
*(HVX_Vector *) (dst_curr + half_size) = Q6_Vsf_equals_Vqf32(v5);
vdst[i/2+0] = Q6_Vsf_equals_Vqf32(v4);
vdst[i/2+hv] = Q6_Vsf_equals_Vqf32(v5);
}
src0_curr += VLEN;
theta_curr += 2 * VLEN;
dst_curr += VLEN;
for (uint32_t i = nvec * VLEN_FP32; i < ne; i += 2) {
const float cos_theta = theta_cache[i+0];
const float sin_theta = theta_cache[i+1];
float x0 = src0[i/2];
float x1 = src0[i/2 + he];
dst[i/2] = x0 * cos_theta - x1 * sin_theta;
dst[i/2 + he] = x0 * sin_theta + x1 * cos_theta;
}
}
static void hvx_calc_rope_f32(const float * restrict src0,
float * restrict dst,
const int num_elems,
const float * restrict theta_cache) {
// for (int i = 0; i < num_elems; i += 2) {
//const float cos_theta = theta_cache[i + 0];
//const float sin_theta = theta_cache[i + 1];
static inline void hvx_rope_f32_aa(float * restrict dst, const float * restrict src0, uint32_t ne, const float * restrict theta_cache) {
const HVX_Vector * restrict vsrc = (const HVX_Vector *) src0;
const HVX_Vector * restrict vtheta = (const HVX_Vector *) theta_cache;
HVX_Vector * restrict vdst = (HVX_Vector *) dst;
//const float x0 = src[0];
//const float x1 = src[1];
uint32_t nvec = (ne / (VLEN_FP32 * 2)) * 2; // 2 vecs per loop, step of two
//dst[0] = x0*cos_theta - x1*sin_theta;
//dst[1] = x0*sin_theta + x1*cos_theta;
#pragma unroll(2)
for (uint32_t i = 0; i < nvec; i+=2) {
HVX_Vector v0 = vsrc[i+0];
HVX_Vector v1 = vsrc[i+1];
//src += 2;
//dst += 2;
// }
const uint8_t * restrict src0_curr = (const uint8_t *) src0;
const uint8_t * restrict theta_curr = (const uint8_t *) theta_cache;
uint8_t * restrict dst_curr = (uint8_t *) dst;
int step_of_1 = num_elems >> 6; // 6 because we process two vectors at once
for (int i = 0; i < step_of_1; i++) {
HVX_Vector v0 = *(HVX_Vector *) src0_curr;
HVX_Vector v1 = *(HVX_Vector *) (src0_curr + VLEN);
HVX_Vector v2 = *(HVX_Vector *) theta_curr;
HVX_Vector v3 = *(HVX_Vector *) (theta_curr + VLEN);
HVX_Vector v2 = vtheta[i+0];
HVX_Vector v3 = vtheta[i+1];
HVX_VectorPair vx0_x1 = Q6_W_vdeal_VVR(v1, v0, -4); // vx0_x1[0] = x0, vx0_x1[1] = x1
HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4); // vcos_sin[0] = cos_theta, vcos_sin[1] = sin_theta
@ -239,116 +203,65 @@ static void hvx_calc_rope_f32(const float * restrict src0,
HVX_VectorPair vstore = Q6_W_vshuff_VVR(Q6_Vsf_equals_Vqf32(v5), Q6_Vsf_equals_Vqf32(v4), -4);
*(HVX_Vector *) dst_curr = Q6_V_lo_W(vstore);
*(HVX_Vector *) (dst_curr + VLEN) = Q6_V_hi_W(vstore);
vdst[i+0] = Q6_V_lo_W(vstore);
vdst[i+1] = Q6_V_hi_W(vstore);
}
src0_curr += 2 * VLEN;
theta_curr += 2 * VLEN;
dst_curr += 2 * VLEN;
for (uint32_t i = nvec * VLEN_FP32; i < ne; i += 2) {
const float cos_theta = theta_cache[i+0];
const float sin_theta = theta_cache[i+1];
float x0 = src0[i+0];
float x1 = src0[i+1];
dst[i+0] = x0 * cos_theta - x1 * sin_theta;
dst[i+1] = x0 * sin_theta + x1 * cos_theta;
}
}
static void rope_hex_f32(struct rope_th_ctx * rope_ctx,
const uint32_t ir0,
const uint32_t ir1,
int nth,
int ith,
const int opt_path) {
struct htp_ops_context * octx = rope_ctx->octx;
static void inline rope_basic_f32(struct htp_rope_context * rctx, uint8_t * restrict dst, uint8_t * restrict src,
uint32_t nr, uint32_t ne0, const float * restrict theta_cache) {
#pragma unroll(4)
for (uint32_t i = 0; i < nr; i++) {
float * d = (float *) (dst + i * rctx->dst_row_size_aligned);
float * s = (float *) (src + i * rctx->src0_row_size_aligned);
hvx_rope_f32_aa(d, s, rctx->n_dims, theta_cache);
// fill the remain channels with data from src tensor
if (rctx->n_dims < ne0) {
hvx_copy_f32_uu((uint8_t *)(d + rctx->n_dims), (uint8_t *)(s + rctx->n_dims), ne0 - rctx->n_dims);
}
}
}
static void inline rope_neox_f32(struct htp_rope_context * rctx, uint8_t * restrict dst, uint8_t * restrict src,
uint32_t nr, uint32_t ne0, const float * restrict theta_cache) {
#pragma unroll(4)
for (uint32_t i = 0; i < nr; i++) {
float * d = (float *) (dst + i * rctx->dst_row_size_aligned);
float * s = (float *) (src + i * rctx->src0_row_size_aligned);
hvx_rope_neox_f32_aa(d, s, rctx->n_dims, theta_cache);
// fill the remain channels with data from src tensor
if (rctx->n_dims < ne0) {
hvx_copy_f32_uu((uint8_t *)(d + rctx->n_dims), (uint8_t *)(s + rctx->n_dims), ne0 - rctx->n_dims);
}
}
}
static void rope_job_f32(unsigned int nth, unsigned int ith, void * data) {
struct htp_rope_context * rctx = (struct htp_rope_context *) data;
struct htp_ops_context * octx = rctx->octx;
const struct htp_tensor * src0 = &octx->src0;
const struct htp_tensor * src1 = &octx->src1;
const struct htp_tensor * src2 = &octx->src2;
struct htp_tensor * dst = &octx->dst;
const int32_t mode = rope_ctx->mode;
const bool is_neox = mode & HTP_ROPE_TYPE_NEOX;
htp_rope_preamble;
const int32_t * pos = (const int32_t *) src1->data;
float * wp0 = (float *) (octx->src0_spad.data + (ith * nb01));
const float * freq_factors = NULL;
if (src2 != NULL) {
freq_factors = (const float *) src2->data;
}
const uint32_t i1_end = MIN(ir1, ne1);
const int32_t half_dims = rope_ctx->n_dims / 2;
const size_t remain_bytes = (ne0 - rope_ctx->n_dims) * sizeof(float);
for (uint32_t i3 = 0; i3 < ne3; i3++) { // batch
for (uint32_t i2 = 0; i2 < ne2; i2++) { // seq-len
const int32_t p = pos[i2];
rope_cache_init(p, rope_ctx->freq_scale, freq_factors, rope_ctx->corr_dims, ne0, rope_ctx->ext_factor,
rope_ctx->attn_factor, wp0, rope_ctx->theta_scale);
for (uint32_t i1 = ir0; i1 < i1_end; i1++) { // attn-heads
const float * src = (float *) ((char *) src0->data + i3 * nb03 + i2 * nb02 + i1 * nb01);
float * dst_data = (float *) ((char *) dst->data + i3 * nb3 + i2 * nb2 + i1 * nb1);
const float * src_loc = src;
float * dst_data_loc = dst_data;
if (1 == opt_path) {
if (is_neox) {
hvx_calc_rope_neox_f32(src_loc, dst_data_loc, rope_ctx->n_dims, wp0);
} else {
hvx_calc_rope_f32(src_loc, dst_data_loc, rope_ctx->n_dims, wp0);
}
src_loc += rope_ctx->n_dims;
dst_data_loc += rope_ctx->n_dims;
} else {
for (uint32_t i0 = 0; i0 < rope_ctx->n_dims; i0 += 2) {
const float cos_theta = wp0[i0 + 0];
const float sin_theta = wp0[i0 + 1];
if (is_neox) {
const float x0 = src_loc[0];
const float x1 = src_loc[half_dims];
dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta;
dst_data_loc[half_dims] = x0 * sin_theta + x1 * cos_theta;
src_loc += 1;
dst_data_loc += 1;
} else {
const float x0 = src_loc[0];
const float x1 = src_loc[1];
dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta;
dst_data_loc[1] = x0 * sin_theta + x1 * cos_theta;
src_loc += 2;
dst_data_loc += 2;
}
}
src_loc += (is_neox ? half_dims : 0);
dst_data_loc += (is_neox ? half_dims : 0);
}
// TODO: use simd to speed up the remaining elements copy
memcpy(dst_data_loc, src_loc, remain_bytes);
}
}
}
}
static void rope_job_f32_per_thread(struct rope_th_ctx * rope_ctx, int nth, int ith) {
struct htp_ops_context * octx = rope_ctx->octx;
const struct htp_tensor * src0 = &octx->src0;
const struct htp_tensor * src1 = &octx->src1;
struct htp_tensor * dst = &octx->dst;
htp_rope_preamble;
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread;
const uint32_t src0_nrows = rctx->src0_nrows;
const uint32_t src0_nrows_per_thread = rctx->src0_nrows_per_thread;
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
@ -358,32 +271,114 @@ static void rope_job_f32_per_thread(struct rope_th_ctx * rope_ctx, int nth, int
return;
}
uint64_t t1, t2;
t1 = HAP_perf_get_qtimer_count();
uint64_t tt = HAP_perf_get_qtimer_count();
int is_aligned = 1;
int opt_path = 0;
if ((0 == hex_is_aligned((void *) src0->data, VLEN)) || (0 == hex_is_aligned((void *) src1->data, VLEN)) ||
(0 == hex_is_aligned((void *) dst->data, VLEN))) {
FARF(HIGH, "rope-f32: unaligned addresses in rope op, possibly slower execution\n");
is_aligned = 0;
}
if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) {
opt_path = 1;
const int32_t mode = rctx->mode;
const bool is_neox = mode & HTP_ROPE_TYPE_NEOX;
// VTCM setup
uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
float * theta_cache = (float *) (src0_spad_base);
src0_spad_base = src0_spad_base + rctx->theta_cache_offset;
uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread);
dma_queue * dma_queue = octx->ctx->dma[ith];
const int32_t * pos = (const int32_t *) src1->data;
const float * freq_factors = src2->data ? (const float *) src2->data : NULL;
uint32_t ir = 0;
uint32_t prev_i2 = (uint32_t) -1;
for (uint32_t i3 = 0; i3 < ne3; i3++) { // batch
for (uint32_t i2 = 0; i2 < ne2; i2++) { // seq-len
for (uint32_t i1 = 0; i1 < ne1; ) { // attn-heads
if (ir < src0_start_row) { ir++; i1++; continue; }
if (ir >= src0_end_row) goto done;
// Rows in this block
const uint32_t nrows = MIN(src0_end_row - ir, ne1 - i1);
// Depth before prefetch
uint32_t dma_depth = dma_queue_depth(dma_queue);
// FARF(HIGH, "rope-block %u: ir %u n-rows %u dma-depth %u : usec %u", ith, ir, nrows, dma_depth,
// (unsigned) HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - rctx->t_start));
// Prefetch loop
for (uint32_t pnr = 0, pr = 0; pr < nrows && pr < HTP_ROPE_SPAD_NROWS; pr += pnr) {
pnr = MIN(nrows - pr, HTP_ROPE_SPAD_BLOCK);
uint32_t pi1 = i1 + pr;
uint32_t pir = ir + pr;
// Dummy DMA transaction for sequencing (interleaving dst,src,dst,...)
dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr((void *) dst->data, dst_spad_base + pr * rctx->dst_row_size_aligned), 0, 0, 0);
const uint8_t * src_addr = (const uint8_t *) src0->data + i3 * nb03 + i2 * nb02 + pi1 * nb01;
uint8_t * src_spad = src0_spad_base + pr * rctx->src0_row_size_aligned;
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src_spad, src_addr),
rctx->src0_row_size_aligned, rctx->src0_row_size, pnr);
// FARF(HIGH, "rope-prefetch %u: pr %u i1 %u i2 %u i3 %u src-spad %p src-addr %p pnr %u", ith, pir, pi1, i2, i3, src_spad, src_addr, pnr);
}
// Update theta cache
if (i2 != prev_i2) {
prev_i2 = i2;
const int32_t p = pos[i2];
rope_cache_init(p, rctx->freq_scale, freq_factors, rctx->corr_dims, ne0, rctx->ext_factor, rctx->attn_factor, theta_cache, rctx->theta_scale);
// FARF(HIGH, "rope-theta %u: ir %u i1 %u i2 %u i3 %u cache %p : usec %u", ith, ir, i1, i2, i3, theta_cache,
// (unsigned) HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - rctx->t_start));
}
// Skip DMA transactions from prev block (if any)
// No need to wait for these since the DMA is setup for in-order processing
for (uint32_t d=0; d < dma_depth; d++) { dma_queue_pop_nowait(dma_queue); }
// Compute loop
for (uint32_t cnr = 0, cr = 0; cr < nrows; cr += cnr, ir += cnr, i1 += cnr) {
// Number of rows to compute
cnr = MIN(nrows - cr, HTP_ROPE_SPAD_BLOCK);
uint8_t * dst_spad = (uint8_t *) dma_queue_pop(dma_queue).src;
uint8_t * src_spad = (uint8_t *) dma_queue_pop(dma_queue).dst;
// FARF(HIGH, "rope-compute %u: ir %u i1 %u i2 %u i3 %u src-spad %p cnr %u : usec %u", ith, ir, i1, i2, i3, src_spad, cnr,
// (unsigned) HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - rctx->t_start));
if (is_neox) {
rope_neox_f32(rctx, dst_spad, src_spad, cnr, ne0, theta_cache);
} else {
rope_basic_f32(rctx, dst_spad, src_spad, cnr, ne0, theta_cache);
}
uint8_t * dst_addr = (uint8_t *) dst->data + i3 * nb3 + i2 * nb2 + i1 * nb1;
dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr(dst_addr, dst_spad), rctx->dst_row_size, rctx->dst_row_size_aligned, cnr);
// Prefetch more rows (if any)
if ((cr + HTP_ROPE_SPAD_NROWS) < nrows) {
uint32_t pnr = MIN(nrows - (cr + HTP_ROPE_SPAD_NROWS), HTP_ROPE_SPAD_BLOCK);
uint32_t pi1 = i1 + HTP_ROPE_SPAD_NROWS;
uint32_t pir = ir + HTP_ROPE_SPAD_NROWS;
const uint8_t * src_addr = (const uint8_t *) src0->data + i3 * nb03 + i2 * nb02 + pi1 * nb01;
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src_spad, src_addr),
rctx->src0_row_size_aligned, rctx->src0_row_size, pnr);
// FARF(HIGH, "rope-prefetch %u: pr %u i1 %u i2 %u i3 %u src-spad %p src-addr %p pnr %u", ith, pir, pi1, i2, i3, src_spad, src_addr, pnr);
}
}
}
}
}
rope_hex_f32(rope_ctx, src0_start_row, src0_end_row, nth, ith, opt_path);
done:
dma_queue_flush(dma_queue);
tt = HAP_perf_get_qtimer_count() - tt;
t2 = HAP_perf_get_qtimer_count();
FARF(HIGH, "rope-f32: %d/%d/%d: (%u:%u) usec %u\n", ith, nth, opt_path, src0_start_row, src0_end_row,
(unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
}
static void rope_job_dispatcher_f32(unsigned int n, unsigned int i, void * data) {
struct rope_th_ctx * rope_ctx = (struct rope_th_ctx *) data;
rope_job_f32_per_thread(rope_ctx, n, i);
FARF(HIGH, "rope-f32: %d/%d: (%u:%u) usec %u\n", ith, nth, src0_start_row, src0_end_row, (unsigned) HAP_perf_qtimer_count_to_us(tt));
}
static int execute_op_rope_f32(struct htp_ops_context * octx) {
@ -394,17 +389,10 @@ static int execute_op_rope_f32(struct htp_ops_context * octx) {
const struct htp_tensor * src2 = &octx->src2;
struct htp_tensor * dst = &octx->dst;
worker_callback_t op_func;
const char * op_type = NULL;
struct rope_th_ctx rope_ctx;
const char * op_type = "rope-f32";
switch (octx->op) {
case HTP_OP_ROPE:
op_func = rope_job_dispatcher_f32;
op_type = "rope-f32";
init_rope_ctx(&rope_ctx, octx);
break;
default:
@ -415,49 +403,79 @@ static int execute_op_rope_f32(struct htp_ops_context * octx) {
const uint32_t n_threads = octx->n_threads;
const size_t src0_row_size = src0->nb[1];
const size_t src1_row_size = src0_row_size;
const size_t dst_row_size = dst->nb[1];
// VTCM scratchpads for all tensors
// N rows per thread, padded to HVX vector size
octx->dst_spad.size = hex_round_up(dst_row_size, 128) * n_threads;
octx->src0_spad.size = hex_round_up(src0_row_size, 128) * n_threads;
octx->src1_spad.size = hex_round_up(src1_row_size, 128) * n_threads;
// Aligned row sizes for VTCM
const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
const size_t theta_cache_size_aligned = hex_round_up(src0->ne[0] * sizeof(float), 128);
size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size;
// Calculate spad sizes per thread
size_t src0_spad_per_thread = theta_cache_size_aligned + HTP_ROPE_SPAD_NROWS * src0_row_size_aligned;
size_t dst_spad_per_thread = HTP_ROPE_SPAD_NROWS * dst_row_size_aligned;
size_t spad_per_thread = src0_spad_per_thread + dst_spad_per_thread;
if (src2->ne[0]) {
FARF(HIGH,
"%s: %ux%ux%ux%u (x %ux%ux%ux%u x %ux%ux%ux%u) -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u "
"dst-spad-size %u\n",
op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2],
src1->ne[3], src2->ne[0], src2->ne[1], src2->ne[2], src2->ne[3], dst->ne[0], dst->ne[1], dst->ne[2],
dst->ne[3], octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size);
} else {
FARF(HIGH,
"%s: %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n",
op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2],
src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], octx->src0_spad.size, octx->src1_spad.size,
octx->dst_spad.size);
}
// Make sure the reserved vtcm size is sufficient
if (octx->ctx->vtcm_size < spad_size) {
FARF(ERROR, "%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size,
spad_size);
// Check if we fit in VTCM
size_t total_vtcm_needed = spad_per_thread * n_threads;
if (octx->ctx->vtcm_size < total_vtcm_needed) {
FARF(ERROR, "%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size, total_vtcm_needed);
return HTP_STATUS_VTCM_TOO_SMALL;
}
octx->src0_spad.data = octx->ctx->vtcm_base;
octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size;
// Assign sizes
octx->src0_spad.size_per_thread = src0_spad_per_thread;
octx->dst_spad.size_per_thread = dst_spad_per_thread;
octx->src0_spad.size = n_threads * src0_spad_per_thread;
octx->dst_spad.size = n_threads * dst_spad_per_thread;
octx->src1_spad.size = 0;
// Assign pointers
octx->src0_spad.data = octx->ctx->vtcm_base;
octx->src1_spad.data = NULL;
octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size;
// Fill context
struct htp_rope_context rctx;
memset(&rctx, 0, sizeof(struct htp_rope_context));
rctx.t_start = HAP_perf_get_qtimer_count();
rctx.octx = octx;
const int32_t * op_params = &octx->op_params[0];
rctx.n_dims = ((const int32_t *) op_params)[1];
rctx.mode = ((const int32_t *) op_params)[2];
rctx.n_ctx_orig = ((const int32_t *) op_params)[4];
memcpy(&rctx.freq_base, (int32_t *) op_params + 5, sizeof(float));
memcpy(&rctx.freq_scale, (int32_t *) op_params + 6, sizeof(float));
memcpy(&rctx.ext_factor, (int32_t *) op_params + 7, sizeof(float));
memcpy(&rctx.attn_factor, (int32_t *) op_params + 8, sizeof(float));
memcpy(&rctx.beta_fast, (int32_t *) op_params + 9, sizeof(float));
memcpy(&rctx.beta_slow, (int32_t *) op_params + 10, sizeof(float));
memcpy(&rctx.sections, (int32_t *) op_params + 11, sizeof(int) * 4);
rctx.theta_scale = powf(rctx.freq_base, -2.0f / rctx.n_dims);
rope_corr_dims(rctx.n_dims, rctx.n_ctx_orig, rctx.freq_base, rctx.beta_fast, rctx.beta_slow, rctx.corr_dims);
rctx.src0_row_size = src0_row_size;
rctx.dst_row_size = dst_row_size;
rctx.src0_row_size_aligned = src0_row_size_aligned;
rctx.dst_row_size_aligned = dst_row_size_aligned;
rctx.theta_cache_offset = theta_cache_size_aligned;
uint32_t ne0 = dst->ne[0];
uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
rctx.src0_nrows = src0_nrows;
FARF(HIGH, "rope-f32 n-rows %u n-dims %d ne0 %u ext-factor %.6f theta-scale %.6f attn-factor %.6f\n", rctx.src0_nrows, rctx.n_dims, ne0,
rctx.ext_factor, rctx.theta_scale, rctx.attn_factor);
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
uint32_t n_jobs = MIN(n_threads, src0_nrows);
octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
worker_pool_run_func(octx->ctx->worker_pool, op_func, &rope_ctx, n_jobs);
uint32_t n_jobs = MIN(n_threads, src0_nrows);
rctx.src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
worker_pool_run_func(octx->ctx->worker_pool, rope_job_f32, &rctx, n_jobs);
}
return err;

View File

@ -43,11 +43,21 @@
\
const uint32_t nr = ne01;
static int set_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth, const int ith) {
struct htp_set_rows_context {
struct htp_ops_context * octx;
struct fastdiv_values div_ne12;
struct fastdiv_values div_ne11;
uint32_t src0_nrows_per_thread;
};
static void set_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *data) {
struct htp_set_rows_context * srctx = (struct htp_set_rows_context *)data;
struct htp_ops_context * octx = srctx->octx;
set_rows_preamble;
// parallelize by rows of src0
const uint32_t dr = octx->src0_nrows_per_thread;
const uint32_t dr = srctx->src0_nrows_per_thread;
const uint32_t ir0 = dr * ith;
const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr;
@ -56,8 +66,8 @@ static int set_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth,
for (uint32_t i03 = 0; i03 < ne03; ++i03) {
for (uint32_t i02 = 0; i02 < ne02; ++i02) {
for (uint32_t i = ir0; i < ir1; ++i) {
const uint32_t i12 = fastmodulo(i03, ne12, &octx->set_rows_div_ne12);
const uint32_t i11 = fastmodulo(i02, ne11, &octx->set_rows_div_ne11);
const uint32_t i12 = fastmodulo(i03, ne12, &srctx->div_ne12);
const uint32_t i11 = fastmodulo(i02, ne11, &srctx->div_ne11);
const uint32_t i10 = i;
const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12;
@ -76,15 +86,16 @@ static int set_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth,
}
}
}
return HTP_STATUS_OK;
}
static int set_rows_thread_f16_f32(struct htp_ops_context * octx, const int nth, const int ith) {
static void set_rows_thread_f16_f32(unsigned int nth, unsigned int ith, void *data) {
struct htp_set_rows_context * srctx = (struct htp_set_rows_context *)data;
struct htp_ops_context * octx = srctx->octx;
set_rows_preamble;
// parallelize by rows of src0
const uint32_t dr = octx->src0_nrows_per_thread;
const uint32_t dr = srctx->src0_nrows_per_thread;
const uint32_t ir0 = dr * ith;
const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr;
@ -93,8 +104,8 @@ static int set_rows_thread_f16_f32(struct htp_ops_context * octx, const int nth,
for (uint32_t i03 = 0; i03 < ne03; ++i03) {
for (uint32_t i02 = 0; i02 < ne02; ++i02) {
for (uint32_t i = ir0; i < ir1; ++i) {
const uint32_t i12 = fastmodulo(i03, ne12, &octx->set_rows_div_ne12);
const uint32_t i11 = fastmodulo(i02, ne11, &octx->set_rows_div_ne11);
const uint32_t i12 = fastmodulo(i03, ne12, &srctx->div_ne12);
const uint32_t i11 = fastmodulo(i02, ne11, &srctx->div_ne11);
const uint32_t i10 = i;
const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12;
@ -112,16 +123,6 @@ static int set_rows_thread_f16_f32(struct htp_ops_context * octx, const int nth,
}
}
}
return HTP_STATUS_OK;
}
static void set_rows_work_f16_f32(unsigned int n, unsigned int i, void *data) {
set_rows_thread_f16_f32((struct htp_ops_context *) data, n, i);
}
static void set_rows_work_f32_f32(unsigned int n, unsigned int i, void *data) {
set_rows_thread_f32_f32((struct htp_ops_context *) data, n, i);
}
int op_set_rows(struct htp_ops_context * octx) {
@ -143,18 +144,20 @@ int op_set_rows(struct htp_ops_context * octx) {
return HTP_STATUS_OK;
}
octx->set_rows_div_ne12 = init_fastdiv_values(ne12);
octx->set_rows_div_ne11 = init_fastdiv_values(ne11);
struct htp_set_rows_context srctx;
srctx.octx = octx;
srctx.div_ne12 = init_fastdiv_values(ne12);
srctx.div_ne11 = init_fastdiv_values(ne11);
const uint32_t n_jobs = MIN(nr, octx->n_threads);
octx->src0_nrows_per_thread = (nr + n_jobs - 1) / n_jobs;
srctx.src0_nrows_per_thread = (nr + n_jobs - 1) / n_jobs;
switch(octx->dst.type) {
case HTP_TYPE_F32:
worker_pool_run_func(octx->ctx->worker_pool, set_rows_work_f32_f32, octx, n_jobs);
worker_pool_run_func(octx->ctx->worker_pool, set_rows_thread_f32_f32, &srctx, n_jobs);
break;
case HTP_TYPE_F16:
worker_pool_run_func(octx->ctx->worker_pool, set_rows_work_f16_f32, octx, n_jobs);
worker_pool_run_func(octx->ctx->worker_pool, set_rows_thread_f16_f32, &srctx, n_jobs);
break;
default:
return HTP_STATUS_NO_SUPPORT;

View File

@ -10,6 +10,7 @@
#include "hex-dma.h"
#include "hvx-utils.h"
#include "hex-fastdiv.h"
#define GGML_COMMON_DECL_C
#include "ggml-common.h"
@ -48,7 +49,7 @@
const uint32_t nb2 = dst->nb[2]; \
const uint32_t nb3 = dst->nb[3];
struct softmax_th_ctx {
struct htp_softmax_context {
bool use_f16;
bool use_src1;
uint32_t n_head;
@ -59,28 +60,48 @@ struct softmax_th_ctx {
float m0;
float m1;
uint32_t src0_nrows_per_thread;
struct fastdiv_values fastdiv_ne01;
struct fastdiv_values fastdiv_ne02;
struct fastdiv_values fastdiv_ne12; // For mask broadcasting
struct fastdiv_values fastdiv_ne13; // For mask broadcasting
size_t spad_stride;
struct htp_ops_context * octx;
};
static void init_softmax_ctx(struct softmax_th_ctx * softmax_ctx, struct htp_ops_context * octx) {
static void init_softmax_ctx(struct htp_softmax_context * smctx, struct htp_ops_context * octx) {
const struct htp_tensor * src0 = &octx->src0;
const struct htp_tensor * src1 = &octx->src1;
memset(softmax_ctx, 0, sizeof(struct softmax_th_ctx));
memset(smctx, 0, sizeof(struct htp_softmax_context));
memcpy(&softmax_ctx->scale, (float *) octx->op_params, sizeof(float));
memcpy(&softmax_ctx->max_bias, (float *) octx->op_params + 1, sizeof(float));
memcpy(&smctx->scale, (float *) octx->op_params, sizeof(float));
memcpy(&smctx->max_bias, (float *) octx->op_params + 1, sizeof(float));
softmax_ctx->n_head = src0->ne[2];
softmax_ctx->n_head_log2 = 1u << (uint32_t) floor(log2(softmax_ctx->n_head));
smctx->n_head = src0->ne[2];
smctx->n_head_log2 = 1u << (uint32_t) floor(log2(smctx->n_head));
softmax_ctx->m0 = powf(2.0f, -(softmax_ctx->max_bias) / softmax_ctx->n_head_log2);
softmax_ctx->m1 = powf(2.0f, -(softmax_ctx->max_bias / 2.0f) / softmax_ctx->n_head_log2);
smctx->m0 = powf(2.0f, -(smctx->max_bias) / smctx->n_head_log2);
smctx->m1 = powf(2.0f, -(smctx->max_bias / 2.0f) / smctx->n_head_log2);
softmax_ctx->use_src1 = (src1->ne[0] != 0);
softmax_ctx->use_f16 = (src1->ne[0] != 0) && (src1->type == HTP_TYPE_F16);
smctx->use_src1 = (src1->ne[0] != 0);
smctx->use_f16 = (src1->ne[0] != 0) && (src1->type == HTP_TYPE_F16);
softmax_ctx->octx = octx;
smctx->octx = octx;
// Initialize fastdiv values
const uint32_t ne01 = src0->ne[1];
const uint32_t ne02 = src0->ne[2];
if (ne01 > 0) smctx->fastdiv_ne01 = init_fastdiv_values(ne01);
if (ne02 > 0) smctx->fastdiv_ne02 = init_fastdiv_values(ne02);
const uint32_t ne12 = (src1->ne[0]) ? src1->ne[2] : 1;
const uint32_t ne13 = (src1->ne[0]) ? src1->ne[3] : 1;
if (ne12 > 0) smctx->fastdiv_ne12 = init_fastdiv_values(ne12);
if (ne13 > 0) smctx->fastdiv_ne13 = init_fastdiv_values(ne13);
}
static void hvx_fast_softmax_prep_f32(const uint8_t * restrict src,
@ -139,8 +160,7 @@ static void hvx_fast_softmax_f32(const uint8_t * restrict src,
max_vec = Q6_Vsf_vmax_VsfVsf(max_vec, v1);
}
HVX_Vector v = hvx_vec_reduce_max_f32(max_vec);
max_vec = hvx_vec_repl4(v);
max_vec = hvx_vec_reduce_max_f32(max_vec); // replicated over all lanes
#pragma unroll(4)
for (int i = 0; i < step_of_1; i++) {
@ -154,8 +174,7 @@ static void hvx_fast_softmax_f32(const uint8_t * restrict src,
v_pad[i] = v3;
}
v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_vec));
sum_vec = hvx_vec_repl4(v);
sum_vec = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_vec)); // replicated over all lanes
HVX_VectorPred pos_sum = Q6_Q_vcmp_gt_VwVw(sum_vec, zero_v);
HVX_Vector v4 = hvx_vec_inverse_f32(sum_vec);
@ -183,83 +202,9 @@ static float hvx_softmax_f32(const uint8_t * restrict src,
return sum;
}
static void softmax_htp_f32(int nth, int ith, struct softmax_th_ctx * softmax_ctx, int opt_path) {
struct htp_ops_context * octx = softmax_ctx->octx;
const struct htp_tensor * src0 = &octx->src0;
const struct htp_tensor * src1 = &octx->src1;
const struct htp_tensor * dst = &octx->dst;
htp_softmax_preamble3;
uint8_t * src0_spad_data = octx->src0_spad.data + (ith * nb01);
uint8_t * src1_spad_data = octx->src1_spad.data + (ith * nb01);
uint8_t * dst_spad_data = octx->dst_spad.data + (ith * nb1);
float * wp0 = (float *) src0_spad_data;
float * wp1 = (float *) src1_spad_data;
float * wp2 = (float *) dst_spad_data;
for (uint32_t i03 = 0; i03 < ne03; i03++) {
for (uint32_t i02 = 0; i02 < ne02; i02++) {
for (uint32_t i01 = ith; i01 < ne01; i01 += nth) {
const uint32_t i11 = i01;
const uint32_t i12 = i02 % ne12;
const uint32_t i13 = i03 % ne13;
// ALiBi
const uint32_t h = i02; // head
const float slope = (softmax_ctx->max_bias > 0.0f) ?
h < softmax_ctx->n_head_log2 ?
powf(softmax_ctx->m0, h + 1) :
powf(softmax_ctx->m1, 2 * (h - softmax_ctx->n_head_log2) + 1) :
1.0f;
float * sp = (float *) ((char *) octx->src0.data + i01 * nb01 + i02 * nb02 + i03 * nb03);
float * dp = (float *) ((char *) octx->dst.data + i01 * nb1 + i02 * nb2 + i03 * nb3);
// broadcast the mask across rows
__fp16 * mp_f16 = (softmax_ctx->use_src1) ?
(__fp16 *) ((char *) octx->src1.data + i11 * nb11 + i12 * nb12 + i13 * nb13) :
NULL;
float * mp_f32 = (softmax_ctx->use_src1) ?
(float *) ((char *) octx->src1.data + i11 * nb11 + i12 * nb12 + i13 * nb13) :
NULL;
if ((1 == opt_path) && (mp_f32) && !(softmax_ctx->use_f16)) {
hvx_fast_softmax_prep_f32((const uint8_t *) sp, (uint8_t *) wp0, ne00, softmax_ctx->scale,
(const uint8_t *) mp_f32, slope);
} else {
hvx_scale_f32((uint8_t *) wp0, (const uint8_t *) sp, ne00, softmax_ctx->scale);
if (mp_f32) {
if (softmax_ctx->use_f16) {
for (int i = 0; i < ne00; ++i) {
wp0[i] += slope * (float) mp_f16[i];
}
} else {
for (int i = 0; i < ne00; ++i) {
wp0[i] += slope * mp_f32[i];
}
}
}
}
if (1 == opt_path) {
hvx_fast_softmax_f32((const uint8_t *) wp0, (uint8_t *) dp, (uint8_t *) wp1, ne00);
} else {
float max = hvx_reduce_max_f32((const uint8_t *) wp0, ne00);
float sum = hvx_softmax_f32((const uint8_t *) wp0, (uint8_t *) wp2, (uint8_t *) wp1, ne00, max);
sum = sum > 0.0 ? (1.0 / sum) : 1;
hvx_scale_f32((uint8_t *) dp, (const uint8_t *) wp2, ne00, sum);
}
}
}
}
}
static void softmax_job_f32_per_thread(struct softmax_th_ctx * softmax_ctx, int nth, int ith) {
struct htp_ops_context * octx = softmax_ctx->octx;
static void softmax_job_f32(unsigned int nth, unsigned int ith, void * data) {
struct htp_softmax_context * smctx = (struct htp_softmax_context *) data;
struct htp_ops_context * octx = smctx->octx;
const struct htp_tensor * src0 = &octx->src0;
const struct htp_tensor * src1 = &octx->src1;
@ -268,7 +213,7 @@ static void softmax_job_f32_per_thread(struct softmax_th_ctx * softmax_ctx, int
htp_softmax_preamble3;
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread;
const uint32_t src0_nrows_per_thread = smctx->src0_nrows_per_thread;
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
@ -291,20 +236,103 @@ static void softmax_job_f32_per_thread(struct softmax_th_ctx * softmax_ctx, int
opt_path = 1;
}
softmax_htp_f32(nth, ith, softmax_ctx, opt_path);
uint8_t * src0_spad_data = octx->src0_spad.data + (ith * smctx->spad_stride);
uint8_t * src1_spad_data = octx->src1_spad.data + (ith * smctx->spad_stride);
uint8_t * dst_spad_data = octx->dst_spad.data + (ith * smctx->spad_stride);
float * wp0 = (float *) src0_spad_data;
float * wp1 = (float *) src1_spad_data;
float * wp2 = (float *) dst_spad_data;
uint32_t prev_i2 = (uint32_t)-1;
float slope = 1.0f;
for (uint32_t r = src0_start_row; r < src0_end_row; ++r) {
uint32_t i1 = fastmodulo(r, ne01, &smctx->fastdiv_ne01);
uint32_t r_div_ne01 = fastdiv(r, &smctx->fastdiv_ne01);
uint32_t i2 = fastmodulo(r_div_ne01, ne02, &smctx->fastdiv_ne02);
uint32_t i3 = fastdiv(r_div_ne01, &smctx->fastdiv_ne02);
// Map to original logic indices
// i01 = i1
// i02 = i2
// i03 = i3
const uint32_t i11 = i1;
// const uint32_t i12 = i2 % ne12;
// const uint32_t i13 = i3 % ne13;
uint32_t i12, i13;
if (ne12 == ne02) {
i12 = i2;
} else {
i12 = fastmodulo(i2, ne12, &smctx->fastdiv_ne12);
}
if (ne13 == ne03) {
i13 = i3;
} else {
i13 = fastmodulo(i3, ne13, &smctx->fastdiv_ne13);
}
// ALiBi
if (i2 != prev_i2) {
const uint32_t h = i2; // head
slope = (smctx->max_bias > 0.0f) ?
h < smctx->n_head_log2 ?
powf(smctx->m0, h + 1) :
powf(smctx->m1, 2 * (h - smctx->n_head_log2) + 1) :
1.0f;
prev_i2 = i2;
}
float * sp = (float *) ((char *) octx->src0.data + i1 * nb01 + i2 * nb02 + i3 * nb03);
float * dp = (float *) ((char *) octx->dst.data + i1 * nb1 + i2 * nb2 + i3 * nb3);
// broadcast the mask across rows
__fp16 * mp_f16 = (smctx->use_src1) ?
(__fp16 *) ((char *) octx->src1.data + i11 * nb11 + i12 * nb12 + i13 * nb13) :
NULL;
float * mp_f32 = (smctx->use_src1) ?
(float *) ((char *) octx->src1.data + i11 * nb11 + i12 * nb12 + i13 * nb13) :
NULL;
if ((1 == opt_path) && (mp_f32) && !(smctx->use_f16)) {
hvx_fast_softmax_prep_f32((const uint8_t *) sp, (uint8_t *) wp0, ne00, smctx->scale,
(const uint8_t *) mp_f32, slope);
} else {
hvx_scale_f32((uint8_t *) wp0, (const uint8_t *) sp, ne00, smctx->scale);
if (mp_f32) {
if (smctx->use_f16) {
for (int i = 0; i < ne00; ++i) {
wp0[i] += slope * (float) mp_f16[i];
}
} else {
for (int i = 0; i < ne00; ++i) {
wp0[i] += slope * mp_f32[i];
}
}
}
}
if (1 == opt_path) {
hvx_fast_softmax_f32((const uint8_t *) wp0, (uint8_t *) dp, (uint8_t *) wp1, ne00);
} else {
float max = hvx_reduce_max_f32((const uint8_t *) wp0, ne00);
float sum = hvx_softmax_f32((const uint8_t *) wp0, (uint8_t *) wp2, (uint8_t *) wp1, ne00, max);
sum = sum > 0.0 ? (1.0 / sum) : 1;
hvx_scale_f32((uint8_t *) dp, (const uint8_t *) wp2, ne00, sum);
}
}
t2 = HAP_perf_get_qtimer_count();
FARF(HIGH, "softmax-f32 %d/%d/%d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth,
softmax_ctx->use_f16, opt_path, ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13,
smctx->use_f16, opt_path, ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13,
ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
}
static void softmax_job_dispatcher_f32(unsigned int n, unsigned int i, void * p_data) {
struct softmax_th_ctx * p_softmax_ctx = (struct softmax_th_ctx *) p_data;
softmax_job_f32_per_thread(p_softmax_ctx, n, i);
}
static int execute_op_softmax_f32(struct htp_ops_context * octx) {
int err = HTP_STATUS_OK;
@ -312,17 +340,12 @@ static int execute_op_softmax_f32(struct htp_ops_context * octx) {
const struct htp_tensor * src1 = &octx->src1;
struct htp_tensor * dst = &octx->dst;
worker_callback_t op_func;
const char * op_type = NULL;
struct softmax_th_ctx softmax_ctx;
struct htp_softmax_context smctx;
const char * op_type = "softmax-f32";
switch (octx->op) {
case HTP_OP_SOFTMAX:
op_func = softmax_job_dispatcher_f32;
op_type = "softmax-f32";
init_softmax_ctx(&softmax_ctx, octx);
init_softmax_ctx(&smctx, octx);
break;
default:
@ -342,6 +365,9 @@ static int execute_op_softmax_f32(struct htp_ops_context * octx) {
octx->src0_spad.size = hex_round_up(src0_row_size, 128) * n_threads;
octx->src1_spad.size = hex_round_up(src1_row_size, 128) * n_threads;
// Use stride for calculating offset
smctx.spad_stride = hex_round_up(src0_row_size, 128);
size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size;
if (src1->ne[0]) {
@ -371,8 +397,8 @@ static int execute_op_softmax_f32(struct htp_ops_context * octx) {
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
uint32_t n_jobs = MIN(n_threads, src0_nrows);
octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
worker_pool_run_func(octx->ctx->worker_pool, op_func, &softmax_ctx, n_jobs);
smctx.src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
worker_pool_run_func(octx->ctx->worker_pool, softmax_job_f32, &smctx, n_jobs);
}
return err;

View File

@ -17,7 +17,6 @@
#include "htp-msg.h"
#include "htp-ops.h"
#define sum_rows_preamble \
struct htp_tensor *src0 = &octx->src0;\
struct htp_tensor *dst = &octx->dst; \
@ -42,53 +41,54 @@
const uint32_t nb2 = dst->nb[2]; \
const uint32_t nb3 = dst->nb[3]; \
static int sum_rows_thread_f32(struct htp_ops_context * octx, const int nth, const int ith) {
sum_rows_preamble;
struct sum_rows_context {
const uint8_t * src_data;
uint8_t * dst_data;
uint32_t ne00;
size_t src_stride;
size_t dst_stride;
uint32_t rows_per_thread;
uint32_t total_rows;
bool opt_path;
};
const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread;
const size_t src0_row_size = nb01;
const size_t dst_row_size = nb1;
static void sum_rows_thread_f32(unsigned int nth, unsigned int ith, void *data) {
const struct sum_rows_context * smctx = (const struct sum_rows_context *) data;
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
const uint32_t rows_per_thread = smctx->rows_per_thread;
const uint32_t total_rows = smctx->total_rows;
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
const uint32_t start_row = rows_per_thread * ith;
const uint32_t end_row = MIN(start_row + rows_per_thread, total_rows);
// no work for this thread
if (src0_start_row >= src0_end_row) {
return HTP_STATUS_OK;
if (start_row >= end_row) {
return;
}
int opt_path = 0;
if ((0 == hex_is_aligned((void *) src0->data, VLEN)) && !(nb01 & (VLEN - 1))) {
opt_path = 1;
}
const size_t src_stride = smctx->src_stride;
const size_t dst_stride = smctx->dst_stride;
const uint32_t ne00 = smctx->ne00;
const bool opt_path = smctx->opt_path;
const uint8_t * restrict data_src = (const uint8_t *) src0->data;
uint8_t * restrict data_dst = (uint8_t *) dst->data;
const float * restrict src_th = (const float *) (smctx->src_data + (start_row * src_stride));
float * restrict dst_th = (float *) (smctx->dst_data + (start_row * dst_stride));
const float * restrict src_th = (float *) (data_src + (src0_start_row * src0_row_size));
float * restrict dst_th = (float *) (data_dst + (src0_start_row * dst_row_size));
// Calculate actual number of rows for this thread
const uint32_t n_rows = end_row - start_row;
for (uint32_t ir = 0; ir < src0_nrows_per_thread; ir++) {
const float * restrict src_local = src_th + (ir * ne00);
for (uint32_t ir = 0; ir < n_rows; ir++) {
const float * restrict src_local = src_th + (ir * (src_stride / sizeof(float)));
if (ir + 1 < src0_nrows_per_thread) {
hex_l2fetch(src_local + ne00, src0_row_size, src0_row_size, 1);
if (ir + 1 < n_rows) {
hex_l2fetch(src_local + (src_stride / sizeof(float)), src_stride, src_stride, 1);
}
if (1 == opt_path) {
if (opt_path) {
dst_th[ir] = hvx_reduce_sum_f32_a((const uint8_t *) src_local, ne00);
} else {
dst_th[ir] = hvx_reduce_sum_f32((const uint8_t *) src_local, ne00);
}
}
return HTP_STATUS_OK;
}
static void sum_rows_work_f32(unsigned int n, unsigned int i, void *data) {
sum_rows_thread_f32((struct htp_ops_context *) data, n, i);
}
int op_sum_rows(struct htp_ops_context * octx) {
@ -106,10 +106,25 @@ int op_sum_rows(struct htp_ops_context * octx) {
const uint32_t src0_nrows = ne01 * ne02 * ne03;
uint32_t n_jobs = MIN(n_threads, src0_nrows);
octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
uint32_t rows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
worker_pool_run_func(octx->ctx->worker_pool, sum_rows_work_f32, octx, n_jobs);
bool opt_path = false;
if ((0 == hex_is_aligned((void *) src0->data, VLEN)) && !(nb01 & (VLEN - 1))) {
opt_path = true;
}
struct sum_rows_context smctx = {
.src_data = (const uint8_t *) src0->data,
.dst_data = (uint8_t *) dst->data,
.ne00 = ne00,
.src_stride = nb01,
.dst_stride = nb1,
.rows_per_thread = rows_per_thread,
.total_rows = src0_nrows,
.opt_path = opt_path,
};
worker_pool_run_func(octx->ctx->worker_pool, sum_rows_thread_f32, &smctx, n_jobs);
return HTP_STATUS_OK;
}

View File

@ -17,6 +17,28 @@
#include "htp-msg.h"
#include "htp-ops.h"
struct htp_unary_context {
struct htp_ops_context * octx;
// Precomputed values
const uint8_t * data_src0;
uint8_t * data_dst;
size_t src0_row_size;
size_t dst_row_size;
size_t src0_row_size_aligned;
size_t dst_row_size_aligned;
size_t src0_spad_half_size;
size_t dst_spad_half_size;
uint32_t block;
uint32_t src0_nrows;
uint32_t src0_nrows_per_thread;
uint32_t nc;
};
#define htp_unary_preamble \
const uint32_t ne00 = src->ne[0]; \
const uint32_t ne01 = src->ne[1]; \
@ -57,8 +79,7 @@ static void hvx_fast_rms_norm_f32(const uint8_t * restrict src,
sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2);
}
HVX_Vector reduced_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_v));
sum_v = hvx_vec_repl4(reduced_sum);
sum_v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_v)); // replicated over all lanes
HVX_Vector t_v = hvx_vec_splat_f32((float) num_elems);
HVX_Vector denom_v = hvx_vec_inverse_f32(t_v);
@ -75,128 +96,95 @@ static void hvx_fast_rms_norm_f32(const uint8_t * restrict src,
}
}
static void scale_htp_f32(const float * restrict src,
float * restrict dst,
uint8_t * restrict spad,
const uint32_t num_rows,
const uint32_t row_elems,
const size_t row_size,
int32_t * op_params,
int opt_path) {
static void scale_f32(const float * restrict src,
float * restrict dst,
uint8_t * restrict spad,
const uint32_t num_rows,
const uint32_t row_elems,
const size_t row_size,
int32_t * op_params) {
float scale = 0.f;
float bias = 0.f;
memcpy(&scale, &op_params[0], sizeof(float));
memcpy(&bias, &op_params[1], sizeof(float));
for (uint32_t ir = 0; ir < num_rows; ir++) {
const float * restrict src_local = src + (ir * row_elems);
float * restrict dst_local = dst + (ir * row_elems);
const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size);
uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size);
if (ir + 1 < num_rows) {
hex_l2fetch(src_local + row_elems, row_size, row_size, 1);
}
hvx_scale_offset_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale, bias);
hvx_scale_offset_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale, bias);
}
}
static void rms_norm_htp_f32(const float * restrict src,
float * restrict dst,
uint8_t * restrict spad,
const uint32_t num_rows,
const uint32_t row_elems,
const size_t row_size,
int32_t * op_params,
int opt_path) {
static void rms_norm_f32(const float * restrict src,
float * restrict dst,
uint8_t * restrict spad,
const uint32_t num_rows,
const uint32_t row_elems,
const size_t row_size,
int32_t * op_params) {
float epsilon = 0.f;
memcpy(&epsilon, op_params, sizeof(float));
for (uint32_t ir = 0; ir < num_rows; ir++) {
const float * restrict src_local = src + (ir * row_elems);
float * restrict dst_local = dst + (ir * row_elems);
const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size);
uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size);
if (ir + 1 < num_rows) {
hex_l2fetch(src_local + row_elems, row_size, row_size, 1);
}
if (1 == opt_path) {
hvx_fast_rms_norm_f32((const uint8_t *) src_local, (uint8_t *) dst_local, spad, row_elems, epsilon);
} else {
float sum = hvx_sum_of_squares_f32((const uint8_t *) src_local, row_elems);
const float mean = sum / row_elems;
const float scale = 1.0f / sqrtf(mean + epsilon);
hvx_scale_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale);
}
hvx_fast_rms_norm_f32((const uint8_t *) src_local, (uint8_t *) dst_local, spad, row_elems, epsilon);
}
}
static void sqr_htp_f32(const float * restrict src,
float * restrict dst,
uint8_t * restrict spad,
const uint32_t num_rows,
const uint32_t row_elems,
const size_t row_size,
int32_t * op_params,
int opt_path) {
static void sqr_f32(const float * restrict src,
float * restrict dst,
uint8_t * restrict spad,
const uint32_t num_rows,
const uint32_t row_elems,
const size_t row_size,
int32_t * op_params) {
for (uint32_t ir = 0; ir < num_rows; ir++) {
const float * restrict src_local = src + (ir * row_elems);
float * restrict dst_local = dst + (ir * row_elems);
const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size);
uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size);
if (ir + 1 < num_rows) {
hex_l2fetch(src_local + row_elems, row_size, row_size, 1);
}
if (1 == opt_path) {
hvx_sqr_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
} else {
hvx_sqr_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
}
hvx_sqr_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
}
}
static void sqrt_htp_f32(const float * restrict src,
float * restrict dst,
uint8_t * restrict spad,
const uint32_t num_rows,
const uint32_t row_elems,
const size_t row_size,
int32_t * op_params,
int opt_path) {
static void sqrt_f32(const float * restrict src,
float * restrict dst,
uint8_t * restrict spad,
const uint32_t num_rows,
const uint32_t row_elems,
const size_t row_size,
int32_t * op_params) {
for (uint32_t ir = 0; ir < num_rows; ir++) {
const float * restrict src_local = src + (ir * row_elems);
float * restrict dst_local = dst + (ir * row_elems);
const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size);
uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size);
if (ir + 1 < num_rows) {
hex_l2fetch(src_local + row_elems, row_size, row_size, 1);
}
if (1 == opt_path) {
hvx_sqrt_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
} else {
hvx_sqrt_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
}
hvx_sqrt_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
}
}
static void unary_job_f32_per_thread(const struct htp_tensor * src,
struct htp_tensor * dst,
uint8_t * spad,
int htp_op,
int32_t * op_params,
uint32_t nth,
uint32_t ith,
uint32_t src0_nrows_per_thread) {
static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
const struct htp_unary_context * uctx = (const struct htp_unary_context *) data;
struct htp_ops_context * octx = uctx->octx;
const struct htp_tensor * src = &octx->src0;
const struct htp_tensor * dst = &octx->dst;
htp_unary_preamble;
const size_t src0_row_size = nb01;
const size_t dst_row_size = nb1;
int htp_op = octx->op;
int32_t * op_params = octx->op_params;
uint32_t src0_nrows_per_thread = uctx->src0_nrows_per_thread;
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
const size_t src0_row_size = uctx->src0_row_size;
const size_t dst_row_size = uctx->dst_row_size;
const size_t src0_row_size_aligned = uctx->src0_row_size_aligned;
const size_t dst_row_size_aligned = uctx->dst_row_size_aligned;
const uint32_t src0_nrows = uctx->src0_nrows;
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
@ -208,79 +196,104 @@ static void unary_job_f32_per_thread(const struct htp_tensor * src,
uint64_t t1, t2;
t1 = HAP_perf_get_qtimer_count();
int is_aligned = 1;
int opt_path = 0;
if ((0 == hex_is_aligned((void *) src->data, VLEN)) || (0 == hex_is_aligned((void *) dst->data, VLEN))) {
is_aligned = 0;
}
if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) {
opt_path = 1;
const uint8_t * restrict data_src = uctx->data_src0;
uint8_t * restrict data_dst = uctx->data_dst;
uint8_t * src0_spad_data = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
uint8_t * dst_spad_data = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread);
size_t src0_spad_half_size = uctx->src0_spad_half_size;
size_t dst_spad_half_size = uctx->dst_spad_half_size;
const int BLOCK = uctx->block;
if (BLOCK == 0) {
FARF(ERROR, "unary-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
octx->src0_spad.size_per_thread, src0_row_size_aligned);
return;
}
const uint8_t * restrict data_src = (const uint8_t *) src->data;
uint8_t * restrict data_dst = (uint8_t *) dst->data;
dma_queue * dma_queue = octx->ctx->dma[ith];
const float * restrict src_th = (float *) (data_src + (src0_start_row * src0_row_size));
float * restrict dst_th = (float *) (data_dst + (src0_start_row * dst_row_size));
uint8_t * restrict spad_th = (uint8_t *) spad + (ith * nb01);
for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
switch (htp_op) {
case HTP_OP_RMS_NORM:
rms_norm_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
break;
case HTP_OP_SCALE:
scale_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
break;
case HTP_OP_SQR:
sqr_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
break;
case HTP_OP_SQRT:
sqrt_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
break;
// Dummy DMA transation for sequencing (interleaving dst,src,dst,...)
dma_queue_push_vtcm_to_ddr(dma_queue,
dma_make_ptr(data_dst, dst_spad_data + (spad_idx * dst_spad_half_size)),
dst_row_size, dst_row_size_aligned, 0);
default:
break;
dma_queue_push_ddr_to_vtcm(dma_queue,
dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src + (ir * src0_row_size)),
src0_row_size_aligned, src0_row_size, block_size);
}
for (uint32_t ir = src0_start_row; ir < src0_end_row; ir += BLOCK) {
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
float * dst_spad = (float *) dma_queue_pop(dma_queue).src;
float * src0_spad = (float *) dma_queue_pop(dma_queue).dst;
// Process block in VTCM
switch (htp_op) {
case HTP_OP_RMS_NORM:
rms_norm_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
break;
case HTP_OP_SCALE:
scale_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
break;
case HTP_OP_SQR:
sqr_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
break;
case HTP_OP_SQRT:
sqrt_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
break;
default:
break;
}
dma_queue_push_vtcm_to_ddr(dma_queue,
dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad),
dst_row_size, dst_row_size_aligned, block_size);
// prefetch N+2 loop iteration if any
const uint32_t pref_block = (ir + BLOCK * 2);
if (pref_block < src0_end_row) {
const uint32_t pref_block_size = MIN(BLOCK, src0_end_row - pref_block);
dma_queue_push_ddr_to_vtcm(dma_queue,
dma_make_ptr(src0_spad, data_src + (pref_block * src0_row_size)),
src0_row_size_aligned, src0_row_size, pref_block_size);
}
}
dma_queue_flush(dma_queue);
t2 = HAP_perf_get_qtimer_count();
FARF(HIGH, "unary-f32 %d/%d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", ith, nth, opt_path, src->ne[0],
FARF(HIGH, "unary-f32 %d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", ith, nth, src->ne[0],
src->ne[1], src->ne[2], src->ne[3], src0_start_row, src0_end_row, dst->ne[0], dst->ne[1], dst->ne[2],
dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
}
static void unary_job_dispatcher_f32(unsigned int n, unsigned int i, void * data) {
struct htp_ops_context * octx = (struct htp_ops_context *) data;
unary_job_f32_per_thread(&octx->src0, &octx->dst, octx->src0_spad.data, octx->op, octx->op_params, n, i,
octx->src0_nrows_per_thread);
}
static int execute_op_unary_f32(struct htp_ops_context * octx) {
int err = HTP_STATUS_OK;
const struct htp_tensor * src0 = &octx->src0;
struct htp_tensor * dst = &octx->dst;
worker_callback_t unary_op_func;
const char * op_type = NULL;
const char * op_type = NULL;
switch (octx->op) {
case HTP_OP_RMS_NORM:
unary_op_func = unary_job_dispatcher_f32;
op_type = "rmsnorm-f32";
op_type = "rmsnorm-f32";
break;
case HTP_OP_SCALE:
unary_op_func = unary_job_dispatcher_f32;
op_type = "scale-f32";
op_type = "scale-f32";
break;
case HTP_OP_SQR:
unary_op_func = unary_job_dispatcher_f32;
op_type = "sqr-f32";
op_type = "sqr-f32";
break;
case HTP_OP_SQRT:
unary_op_func = unary_job_dispatcher_f32;
op_type = "sqrt-f32";
op_type = "sqrt-f32";
break;
default:
@ -294,32 +307,61 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
const size_t src0_row_size = src0->nb[1];
const size_t dst_row_size = dst->nb[1];
// VTCM scratchpads for all tensors
octx->dst_spad.size = hex_round_up(dst_row_size, 128) * n_threads;
octx->src0_spad.size = hex_round_up(src0_row_size, 128) * n_threads;
const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
size_t spad_size = octx->src0_spad.size + octx->dst_spad.size;
// VTCM scratchpads for all tensors
// N rows per thread, padded to HVX vector size
// Double buffering requires 2x size per buffer
size_t spad_size_per_row = 2 * (src0_row_size_aligned + dst_row_size_aligned);
size_t vtcm_row_per_thread = (octx->ctx->vtcm_size)/ (n_threads * spad_size_per_row);
// Make sure the reserved vtcm size is sufficient
if (vtcm_row_per_thread == 0) {
FARF(ERROR, "unary-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size,
spad_size_per_row * n_threads);
return HTP_STATUS_VTCM_TOO_SMALL;
}
octx->src0_spad.size_per_thread = src0_row_size_aligned * vtcm_row_per_thread * 2;
octx->dst_spad.size_per_thread = dst_row_size_aligned * vtcm_row_per_thread * 2;
octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread;
octx->dst_spad.size = n_threads * octx->dst_spad.size_per_thread;
octx->src0_spad.data = octx->ctx->vtcm_base;
octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size;
FARF(HIGH, "%s: (%ux%ux%ux%u) -> (%ux%ux%ux%u) : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", op_type,
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size);
// Make sure the reserved vtcm size is sufficient
if (octx->ctx->vtcm_size < spad_size) {
FARF(ERROR, "unary-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size,
spad_size);
return HTP_STATUS_VTCM_TOO_SMALL;
}
octx->src0_spad.data = octx->ctx->vtcm_base;
octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size;
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
uint32_t n_jobs = MIN(n_threads, src0_nrows);
octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
struct htp_unary_context uctx = {
.octx = octx,
.src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs,
.src0_nrows = src0_nrows,
worker_pool_run_func(octx->ctx->worker_pool, unary_op_func, octx, n_jobs);
.data_src0 = (const uint8_t *)src0->data,
.data_dst = (uint8_t *)dst->data,
.src0_row_size = src0_row_size,
.dst_row_size = dst_row_size,
.src0_row_size_aligned = src0_row_size_aligned,
.dst_row_size_aligned = dst_row_size_aligned,
.src0_spad_half_size = octx->src0_spad.size_per_thread / 2,
.dst_spad_half_size = octx->dst_spad.size_per_thread / 2,
.block = (octx->src0_spad.size_per_thread / 2) / src0_row_size_aligned,
.nc = src0->ne[0],
};
worker_pool_run_func(octx->ctx->worker_pool, unary_job_f32_per_thread, &uctx, n_jobs);
}
return err;

View File

@ -54,6 +54,6 @@ adb $adbserial $adbhost shell " \
$verbose $experimental $sched $opmask $profile $nhvx $ndev $hb \
./$branch/bin/llama-cli --no-mmap -m $basedir/../gguf/$model \
--poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \
--ctx-size 8192 --batch-size 128 -fa on \
-ngl 99 --device $device $cli_opts $@ \
--ctx-size 8192 --ubatch-size 256 -fa on \
-ngl 99 --device $device $cli_opts $@ \
"

View File

@ -54,6 +54,6 @@ adb $adbserial $adbhost shell " \
$verbose $experimental $sched $opmask $profile $nhvx $ndev $hb \
./$branch/bin/llama-completion --no-mmap -m $basedir/../gguf/$model \
--poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \
--ctx-size 8192 --batch-size 128 -fa on \
-ngl 99 -no-cnv --device $device $cli_opts $@ \
--ctx-size 8192 --ubatch-size 256 -fa on \
-ngl 99 -no-cnv --device $device $cli_opts $@ \
"

View File

@ -58,11 +58,11 @@ adb $adbserial $adbhost shell " \
cd $basedir; ulimit -c unlimited; \
LD_LIBRARY_PATH=$basedir/$branch/lib \
ADSP_LIBRARY_PATH=$basedir/$branch/lib \
$verbose $experimental $sched $opmask $profile $nhvx $ndev $mtmd_backend \
./$branch/bin/llama-mtmd-cli --no-mmap -m $basedir/../gguf/$model \
--mmproj $basedir/../gguf/$mmproj \
--image $basedir/../gguf/$image \
--poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \
--ctx-size 8192 --batch-size 128 -ctk q8_0 -ctv q8_0 -fa on \
-ngl 99 --device $device -v $cli_opts $@ \
$verbose $experimental $sched $opmask $profile $nhvx $ndev $mtmd_backend \
./$branch/bin/llama-mtmd-cli --no-mmap -m $basedir/../gguf/$model \
--mmproj $basedir/../gguf/$mmproj \
--image $basedir/../gguf/$image \
--poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \
--ctx-size 8192 --ubatch-size 256 -fa on \
-ngl 99 --device $device -v $cli_opts $@ \
"

View File

@ -49,5 +49,5 @@ $env:ADSP_LIBRARY_PATH="$basedir\lib"
& "$basedir\bin\llama-completion.exe" `
--no-mmap -no-cnv -m $basedir\..\..\gguf\$model `
--poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 `
--ctx-size 8192 --batch-size 128 -ctk q8_0 -ctv q8_0 -fa on `
--ctx-size 8192 --ubatch-size 128 -fa on `
-ngl 99 --device $device $cli_opts