hexagon : Add GEGLU op

parent f8bdccd967
commit 2f6c19c39d

@@ -2450,6 +2450,9 @@ static inline size_t init_unary_req(htp_general_req * req, dspqueue_buffer * buf
     } else if (ggml_get_glu_op(t) == GGML_GLU_OP_SWIGLU_OAI) {
         req->op = HTP_OP_GLU_SWIGLU_OAI;
         supported = true;
+    } else if (ggml_get_glu_op(t) == GGML_GLU_OP_GEGLU) {
+        req->op = HTP_OP_GLU_GEGLU;
+        supported = true;
     }
     break;

@@ -2618,7 +2621,8 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
             break;
         case GGML_OP_GLU:
             if ((ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU) ||
-                (ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU_OAI)) {
+                (ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU_OAI) ||
+                (ggml_get_glu_op(node) == GGML_GLU_OP_GEGLU)) {
                 ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);
             }
             break;

@@ -3039,7 +3043,7 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
        case GGML_OP_GLU:
            {
                const auto glu_op = ggml_get_glu_op(op);
-               if ((glu_op == GGML_GLU_OP_SWIGLU) || (glu_op == GGML_GLU_OP_SWIGLU_OAI)) {
+               if ((glu_op == GGML_GLU_OP_SWIGLU) || (glu_op == GGML_GLU_OP_SWIGLU_OAI) || (glu_op == GGML_GLU_OP_GEGLU)) {
                    supp = ggml_hexagon_supported_activations(sess, op);
                }
                break;

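Note: GEGLU here is the GELU-gated linear unit, out = gelu(x) * g, where x and g come either from src0/src1 or from the two halves of each src0 row (see the kernel below). A minimal scalar sketch of the semantics, for reference only — gelu_tanh and geglu_row_ref are hypothetical names, but the constants match the kernel:

#include <math.h>

/* Tanh-approximation GELU, same constants as the HTP kernel below. */
static float gelu_tanh(float x) {
    const float GELU_COEF_A    = 0.044715f;
    const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
    return 0.5f * x * (1.0f + tanhf(SQRT_2_OVER_PI * x * (1.0f + GELU_COEF_A * x * x)));
}

/* geglu(x, g) = gelu(x) * g, applied element-wise over a row of nc floats */
static void geglu_row_ref(float * dst, const float * x, const float * g, int nc) {
    for (int i = 0; i < nc; i++) {
        dst[i] = gelu_tanh(x[i]) * g[i];
    }
}
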
@@ -541,6 +541,143 @@ static void unary_silu_f32_per_thread(const struct htp_tensor * src0,
          ne03, src0_start_row, src0_end_row, ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
 }
 
+static const float GELU_COEF_A = 0.044715f;
+static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
+
+static void glu_geglu_f32_per_thread(const struct htp_tensor * src0,
+                                     const struct htp_tensor * src1,
+                                     struct htp_tensor * dst,
+                                     const int32_t * op_params,
+                                     struct htp_spad * src0_spad,
+                                     struct htp_spad * src1_spad,
+                                     struct htp_spad * dst_spad,
+                                     uint32_t nth,
+                                     uint32_t ith,
+                                     uint32_t src0_nrows_per_thread,
+                                     dma_queue * dma_queue) {
+    htp_act_preamble3;
+
+    size_t src0_row_size = nb01;
+    size_t src1_row_size = nb11;
+    size_t dst_row_size  = nb1;
+
+    uint64_t t1, t2;
+    t1 = HAP_perf_get_qtimer_count();
+
+    const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
+
+    const uint32_t src0_start_row = src0_nrows_per_thread * ith;
+    const uint32_t src0_end_row   = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
+
+    // no work for this thread
+    if (src0_start_row >= src0_end_row) {
+        return;
+    }
+
+    const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
+    const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
+    uint8_t * restrict data_dst = (uint8_t *) dst->data;
+
+    const bool src1_valid = src1->ne[0];
+    const int nc = (src1_valid) ? ne00 : ne00 / 2;
+    if (!src1_valid) {
+        const int32_t swapped = op_params[1];
+        data_src1 = data_src0;
+        src1_row_size = src0_row_size;
+
+        const size_t nc_in_bytes = nc * SIZEOF_FP32;
+        data_src0 += swapped ? nc_in_bytes : 0;
+        data_src1 += swapped ? 0 : nc_in_bytes;
+    }
+
+    const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
+    const size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN);
+    const size_t dst_row_size_aligned  = hex_round_up(dst_row_size, VLEN);
+
+    uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread);
+    uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_spad->size_per_thread);
+    uint8_t * restrict dst_spad_data  = dst_spad->data + (ith * dst_spad->size_per_thread);
+
+    // Split this thread's scratchpad quota (src0_spad->size_per_thread) into two halves, used as ping-pong buffers
+    size_t src0_spad_half_size = src0_spad->size_per_thread / 2;
+    size_t src1_spad_half_size = src1_spad->size_per_thread / 2;
+    size_t dst_spad_half_size  = dst_spad->size_per_thread / 2;
+
+    const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // how many rows we can process in one block
+    if (BLOCK == 0) {
+        FARF(ERROR,
+             "geglu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
+             src0_spad->size_per_thread, src0_row_size_aligned);
+        return;
+    }
+
+    // See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379
+    for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
+        const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
+
+        // Dummy DMA transaction for sequencing (interleaving dst,src,dst,...)
+        dma_queue_push_vtcm_to_ddr(dma_queue,
+                                   dma_make_ptr(data_dst, dst_spad_data + (spad_idx * dst_spad_half_size)),
+                                   dst_row_size, dst_row_size_aligned, 0);
+
+        dma_queue_push_ddr_to_vtcm(dma_queue,
+                                   dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src0 + (ir * src0_row_size)),
+                                   src0_row_size_aligned, src0_row_size, block_size);
+        dma_queue_push_ddr_to_vtcm(dma_queue,
+                                   dma_make_ptr(src1_spad_data + (spad_idx * src1_spad_half_size), data_src1 + (ir * src1_row_size)),
+                                   src1_row_size_aligned, src1_row_size, block_size);
+    }
+
+    for (uint32_t ir = src0_start_row; ir < src0_end_row; ir += BLOCK) {
+        const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
+
+        float * dst_spad  = (float *) dma_queue_pop(dma_queue).src;
+        float * src0_spad = (float *) dma_queue_pop(dma_queue).dst;
+        float * src1_spad = (float *) dma_queue_pop(dma_queue).dst;
+
+        for (uint32_t ib = 0; ib < block_size; ib++) {
+            const uint8_t * src0_spad_ptr = (const uint8_t *)(src0_spad + ib * (src0_row_size_aligned / sizeof(float)));
+            const uint8_t * src1_spad_ptr = (const uint8_t *)(src1_spad + ib * (src1_row_size_aligned / sizeof(float)));
+            uint8_t * dst_spad_ptr = (uint8_t *)(dst_spad + ib * (dst_row_size_aligned / sizeof(float)));
+
+            // geglu tanh implementation
+            // geglu(x, g) = gelu(x) * g
+            // gelu(x) = 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)))
+            hvx_mul_f32_aa(dst_spad_ptr, src0_spad_ptr, src0_spad_ptr, nc);                           // res = x*x
+            hvx_mul_scalar_f32_aa(dst_spad_ptr, (const uint8_t *) dst_spad_ptr, GELU_COEF_A, nc);     // res = res * GELU_COEF_A
+            hvx_add_scalar_f32_aa(dst_spad_ptr, (const uint8_t *) dst_spad_ptr, 1.0f, nc);            // res = res + 1.0f
+            hvx_mul_f32_aa(dst_spad_ptr, src0_spad_ptr, (const uint8_t *) dst_spad_ptr, nc);          // res = res * x
+            hvx_mul_scalar_f32_aa(dst_spad_ptr, (const uint8_t *) dst_spad_ptr, SQRT_2_OVER_PI, nc);  // res = res * SQRT_2_OVER_PI
+            hvx_tanh_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) dst_spad_ptr, nc);            // res = tanh(res)
+            hvx_add_scalar_f32_aa(dst_spad_ptr, (const uint8_t *) dst_spad_ptr, 1.0f, nc);            // res = res + 1.0f
+            hvx_mul_f32_aa(dst_spad_ptr, src0_spad_ptr, (const uint8_t *) dst_spad_ptr, nc);          // res = res * x
+            hvx_mul_scalar_f32_aa(dst_spad_ptr, (const uint8_t *) dst_spad_ptr, 0.5f, nc);            // res = res * 0.5f
+            hvx_mul_f32_aa(dst_spad_ptr, (const uint8_t *) dst_spad_ptr, src1_spad_ptr, nc);          // res = res * g
+        }
+
+        dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad), dst_row_size,
+                                   dst_row_size_aligned, block_size);
+
+        // prefetch the N+2 loop iteration, if any
+        const uint32_t pref_block = (ir + BLOCK * 2);
+        if (pref_block < src0_end_row) {
+            const uint32_t pref_block_size = MIN(BLOCK, src0_end_row - pref_block);
+            dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src0_spad, data_src0 + (pref_block * src0_row_size)),
+                                       src0_row_size_aligned, src0_row_size, pref_block_size);
+            dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src1_spad, data_src1 + (pref_block * src1_row_size)),
+                                       src1_row_size_aligned, src1_row_size, pref_block_size);
+        }
+    }
+
+    dma_queue_flush(dma_queue);
+
+    t2 = HAP_perf_get_qtimer_count();
+
+    FARF(HIGH, "geglu-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth,
+         ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3,
+         (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
+}
+
 static void unary_silu_f32(unsigned int n, unsigned int i, void * data) {
     struct htp_ops_context * octx = (struct htp_ops_context *) data;
     unary_silu_f32_per_thread(&octx->src0, &octx->dst, octx->op_params, &octx->src0_spad, &octx->dst_spad, n, i,

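Note: the main loop above is a two-slot ping-pong pipeline: two blocks are pre-loaded, each iteration computes on one slot while DMA fills the other, and the finished slot is refilled for iteration N+2. A synchronous model of that flow, for illustration only — memcpy stands in for the dma_queue_* calls, BLOCK/NC and all names are made up, and gelu_tanh is the helper from the earlier sketch:

#include <string.h>

#define BLOCK 4  /* rows per block (illustrative) */
#define NC    64 /* floats per row (illustrative) */

/* Synchronous stand-in for the kernel's 2-deep DMA pipeline. */
static void geglu_pipeline_sketch(float * dst, const float * x, const float * g, int nrows) {
    static float in_x[2][BLOCK * NC], in_g[2][BLOCK * NC], out[2][BLOCK * NC];

    /* prime both slots, like the first loop in glu_geglu_f32_per_thread */
    for (int ir = 0, slot = 0; ir < nrows && slot < 2; ir += BLOCK, slot++) {
        const int bs = (nrows - ir < BLOCK) ? nrows - ir : BLOCK;
        memcpy(in_x[slot], x + ir * NC, bs * NC * sizeof(float)); /* "src0 load" */
        memcpy(in_g[slot], g + ir * NC, bs * NC * sizeof(float)); /* "src1 load" */
    }

    for (int ir = 0, slot = 0; ir < nrows; ir += BLOCK, slot ^= 1) {
        const int bs = (nrows - ir < BLOCK) ? nrows - ir : BLOCK;

        /* compute on the slot whose load has already completed */
        for (int i = 0; i < bs * NC; i++) {
            out[slot][i] = gelu_tanh(in_x[slot][i]) * in_g[slot][i];
        }
        memcpy(dst + ir * NC, out[slot], bs * NC * sizeof(float)); /* "dst store" */

        /* refill this slot for iteration N+2, like the kernel's prefetch */
        const int pref = ir + 2 * BLOCK;
        if (pref < nrows) {
            const int pbs = (nrows - pref < BLOCK) ? nrows - pref : BLOCK;
            memcpy(in_x[slot], x + pref * NC, pbs * NC * sizeof(float));
            memcpy(in_g[slot], g + pref * NC, pbs * NC * sizeof(float));
        }
    }
}
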
@@ -559,6 +696,12 @@ static void glu_swiglu_oai_f32(unsigned int n, unsigned int i, void * data) {
                                   &octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
 }
 
+static void glu_geglu_f32(unsigned int n, unsigned int i, void * data) {
+    struct htp_ops_context * octx = (struct htp_ops_context *) data;
+    glu_geglu_f32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad,
+                             &octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
+}
+
 static int execute_op_activations_f32(struct htp_ops_context * octx) {
     int err = HTP_STATUS_OK;

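Note: when the GLU op has no separate gate tensor, each src0 row carries both halves and op_params[1] says whether they are swapped — that is the src1_valid/swapped branch at the top of glu_geglu_f32_per_thread. A scalar sketch of the convention (geglu_split_row_ref is a hypothetical name, reusing gelu_tanh from above):

/* Split-row GEGLU: row = [x | g], or [g | x] when swapped. */
static void geglu_split_row_ref(float * dst, const float * src0_row, int ne00, int swapped) {
    const int     nc = ne00 / 2;
    const float * x  = src0_row + (swapped ? nc : 0); /* data_src0 += swapped ? nc_in_bytes : 0 */
    const float * g  = src0_row + (swapped ? 0 : nc); /* data_src1 += swapped ? 0 : nc_in_bytes */
    for (int i = 0; i < nc; i++) {
        dst[i] = gelu_tanh(x[i]) * g[i];
    }
}
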
@@ -593,6 +736,11 @@ static int execute_op_activations_f32(struct htp_ops_context * octx) {
            act_op_func = unary_gelu_f32;
            op_type = "gelu-f32";
            break;
+
+        case HTP_OP_GLU_GEGLU:
+            act_op_func = glu_geglu_f32;
+            op_type = "geglu-f32";
+            break;
        default:
            FARF(ERROR, "Unsupported activations Op %u\n", octx->op);
            return HTP_STATUS_NO_SUPPORT;

@@ -42,36 +42,36 @@ enum htp_data_type {
    HTP_TYPE_COUNT
};

-// These values are manually translated over to HTP
-// !!!! DO NOT ALTER THE ORDER OF THE FIRST FOUR ENUMS !!!!
+// Do not reorder first 4 (used as an index)
 enum htp_op {
     HTP_OP_MUL = 0,
     HTP_OP_ADD = 1,
     HTP_OP_SUB = 2,
     HTP_OP_DIV = 3,
-    HTP_OP_MUL_MAT = 4,
-    HTP_OP_MUL_MAT_ID = 5,
-    HTP_OP_RMS_NORM = 6,
-    HTP_OP_UNARY_SILU = 7,
-    HTP_OP_UNARY_GELU = 8,
-    HTP_OP_GLU_SWIGLU = 9,
-    HTP_OP_GLU_SWIGLU_OAI = 10,
-    HTP_OP_SOFTMAX = 11,
-    HTP_OP_ADD_ID = 12,
-    HTP_OP_ROPE = 13,
-    HTP_OP_FLASH_ATTN_EXT = 14,
-    HTP_OP_SET_ROWS = 15,
-    HTP_OP_SCALE = 16,
-    HTP_OP_GET_ROWS = 17,
-    HTP_OP_CPY = 18,
-    HTP_OP_ARGSORT = 19,
-    HTP_OP_SQR = 20,
-    HTP_OP_SQRT = 21,
-    HTP_OP_SUM_ROWS = 22,
+    HTP_OP_MUL_MAT,
+    HTP_OP_MUL_MAT_ID,
+    HTP_OP_RMS_NORM,
+    HTP_OP_UNARY_SILU,
+    HTP_OP_UNARY_GELU,
+    HTP_OP_GLU_SWIGLU,
+    HTP_OP_GLU_SWIGLU_OAI,
+    HTP_OP_GLU_GEGLU,
+    HTP_OP_SOFTMAX,
+    HTP_OP_ADD_ID,
+    HTP_OP_ROPE,
+    HTP_OP_FLASH_ATTN_EXT,
+    HTP_OP_SET_ROWS,
+    HTP_OP_GET_ROWS,
+    HTP_OP_SCALE,
+    HTP_OP_CPY,
+    HTP_OP_ARGSORT,
+    HTP_OP_SQR,
+    HTP_OP_SQRT,
+    HTP_OP_SUM_ROWS,
     INVALID
 };
 
-static inline size_t htp_type_block_size(uint32_t t) {
+static inline size_t htp_t_block_size(uint32_t t) {
     switch (t) {
         case HTP_TYPE_F32:
             return 1;

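Note: inserting HTP_OP_GLU_GEGLU mid-enum renumbers every entry after it, which is safe only because the CPU-side backend and the HTP-side ops are built from this same header; the first four values stay pinned because they double as an index. A hypothetical compile-time guard (not part of the patch) would make that invariant explicit:

/* Hypothetical guard, not in the patch: the first four ops are used as
 * array indices on both sides of the FastRPC boundary, so pin their values. */
_Static_assert(HTP_OP_MUL == 0 && HTP_OP_ADD == 1 &&
               HTP_OP_SUB == 2 && HTP_OP_DIV == 3,
               "first four htp_op values are used as an index");
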
@@ -91,6 +91,27 @@ static inline HVX_Vector hvx_vec_tanh_f32(HVX_Vector x) {
         } \
     } while(0)
 
+#define hvx_tanh_loop_body(dst_type, src_type, vec_store) \
+    do { \
+        dst_type * restrict vdst = (dst_type *) dst; \
+        src_type * restrict vsrc = (src_type *) src; \
+ \
+        const uint32_t epv  = 128 / sizeof(float); \
+        const uint32_t nvec = n / epv; \
+        const uint32_t nloe = n % epv; \
+ \
+        uint32_t i = 0; \
+ \
+        _Pragma("unroll(4)") \
+        for (; i < nvec; i++) { \
+            vdst[i] = hvx_vec_tanh_f32(vsrc[i]); \
+        } \
+        if (nloe) { \
+            HVX_Vector tmp = hvx_vec_tanh_f32(vsrc[i]); \
+            vec_store((void *) &vdst[i], nloe * sizeof(float), tmp); \
+        } \
+    } while(0)
+
 static inline void hvx_sigmoid_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
     assert((unsigned long) dst % 128 == 0);
     assert((unsigned long) src % 128 == 0);

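Note: the new macro mirrors hvx_sigmoid_loop_body above it — full 128-byte vectors (32 f32 lanes each) go straight through hvx_vec_tanh_f32, and the tail is computed as one extra full vector whose first nloe lanes are written back by the vec_store callback. A scalar model of that control flow, illustration only (tanh_f32_model is a hypothetical name):

#include <math.h>
#include <stdint.h>

/* Scalar model of hvx_tanh_loop_body: n need not be a multiple of the
 * 32-float vector width; the tail stands in for the vec_store partial
 * store of one extra full-vector result. */
static void tanh_f32_model(float * dst, const float * src, uint32_t n) {
    const uint32_t epv  = 128 / sizeof(float); /* elements per 128-byte HVX vector */
    const uint32_t nvec = n / epv;             /* full vectors */
    const uint32_t nloe = n % epv;             /* leftover elements */

    uint32_t i = 0;
    for (; i < nvec * epv; i++) {
        dst[i] = tanhf(src[i]);     /* vdst[i] = hvx_vec_tanh_f32(vsrc[i]) */
    }
    for (uint32_t k = 0; k < nloe; k++) {
        dst[i + k] = tanhf(src[i + k]); /* vec_store writes only nloe*4 bytes */
    }
}
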
@@ -111,4 +132,10 @@ static inline void hvx_sigmoid_f32_uu(uint8_t * restrict dst, const uint8_t * re
     hvx_sigmoid_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
 }
 
+static inline void hvx_tanh_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+    assert((unsigned long) dst % 128 == 0);
+    assert((unsigned long) src % 128 == 0);
+    hvx_tanh_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
+}
+
 #endif /* HVX_SIGMOID_H */

@@ -1078,6 +1078,7 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
             case HTP_OP_GLU_SWIGLU:
             case HTP_OP_GLU_SWIGLU_OAI:
             case HTP_OP_SOFTMAX:
+            case HTP_OP_GLU_GEGLU:
                 if ((n_bufs != 2) && (n_bufs != 3)) {
                     FARF(ERROR, "Bad act-req buffer list");
                     continue;