refactor: optimize swiglu

This commit is contained in:
shouyud 2025-12-26 11:43:38 -05:00
parent 768f572178
commit fec5a9e077
1 changed files with 75 additions and 24 deletions

View File

@ -85,13 +85,16 @@ static void glu_swiglu_fp32_per_thread(const struct htp_tensor * src0,
struct htp_spad * dst_spad,
uint32_t nth,
uint32_t ith,
uint32_t src0_nrows_per_thread) {
uint32_t src0_nrows_per_thread,
dma_queue * dma_queue) {
htp_act_preamble3;
size_t src0_row_size = nb01;
size_t src1_row_size = nb11;
size_t dst_row_size = nb1;
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
@ -127,37 +130,86 @@ static void glu_swiglu_fp32_per_thread(const struct htp_tensor * src0,
data_src1 += swapped ? 0 : nc_in_bytes;
}
uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_row_size);
uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_row_size);
uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_row_size);
const bool opt_path = ((1 == is_aligned) && !(nb01 & (VLEN - 1)));
for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) {
const float * restrict src0 = (float *) (data_src0 + (ir * src0_row_size));
const float * restrict src1 = (float *) (data_src1 + (ir * src1_row_size));
float * restrict dst = (float *) (data_dst + (ir * dst_row_size));
const size_t src0_row_size_aligned = htp_round_up(src0_row_size, VLEN);
const size_t src1_row_size_aligned = htp_round_up(src1_row_size, VLEN);
const size_t dst_row_size_aligned = htp_round_up(dst_row_size, VLEN);
if (ir + 1 < src0_end_row) {
htp_l2fetch(src0 + src0_row_size, 1, src0_row_size, src0_row_size);
uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread);
uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_spad->size_per_thread);
uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread);
// While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0
size_t src0_spad_half_size = src0_spad->size_per_thread / 2;
size_t src1_spad_half_size = src1_spad->size_per_thread / 2;
size_t dst_spad_half_size = dst_spad->size_per_thread / 2;
const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block
if (BLOCK == 0) {
FARF(ERROR, "swiglu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
src0_spad->size_per_thread, src0_row_size_aligned);
return;
}
// See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379
for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
// Dummy DMA transation for sequencing (interleaving dst,src,dst,...)
dma_queue_push_vtcm_to_ddr(dma_queue,
dma_make_ptr(data_dst, dst_spad_data + (spad_idx * dst_spad_half_size)),
dst_row_size, dst_row_size_aligned, 0);
dma_queue_push_ddr_to_vtcm(dma_queue,
dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src0 + (ir * src0_row_size)),
src0_row_size_aligned, src0_row_size, block_size);
dma_queue_push_ddr_to_vtcm(dma_queue,
dma_make_ptr(src1_spad_data + (spad_idx * src1_spad_half_size), data_src1 + (ir * src1_row_size)),
src1_row_size_aligned, src1_row_size, block_size);
}
for (uint32_t ir = src0_start_row; ir < src0_end_row; ir += BLOCK) {
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
float* dst_spad = (float *) dma_queue_pop(dma_queue).src;
float* src0_spad = (float *) dma_queue_pop(dma_queue).dst;
float* src1_spad = (float *) dma_queue_pop(dma_queue).dst;
for (uint32_t ib = 0; ib < block_size; ib++) {
const float* src0_spad_ptr = src0_spad + ib * (src0_row_size_aligned / sizeof(float));
const float* src1_spad_ptr = src1_spad + ib * (src1_row_size_aligned / sizeof(float));
float* dst_spad_ptr = dst_spad + ib * (dst_row_size_aligned / sizeof(float));
//swiglu(x) = x1 * sigmoid(x0)
hvx_fast_sigmoid_f32((const uint8_t *) src0_spad_ptr, (uint8_t *) dst_spad_ptr, nc);
hvx_mul_mul_f32_opt((const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, (const uint8_t *) src1_spad_ptr,
(uint8_t *) dst_spad_ptr, nc);
}
if (opt_path) {
hvx_fast_sigmoid_f32((const uint8_t *) src0, (uint8_t *) src0_spad_data, nc);
hvx_mul_mul_f32_opt((const uint8_t *) src0, (const uint8_t *) src0_spad_data, (const uint8_t *) src1,
(uint8_t *) dst, nc);
} else {
hvx_exp_f32((const uint8_t *) src0, src0_spad_data, nc, true);
hvx_add_scalar_f32(src0_spad_data, 1.0, src1_spad_data, nc);
hvx_inverse_f32(src1_spad_data, src0_spad_data, nc);
dma_queue_push_vtcm_to_ddr(dma_queue,
dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad),
dst_row_size, dst_row_size_aligned, block_size);
hvx_mul_f32((const uint8_t *) src0, src0_spad_data, dst_spad_data, nc);
hvx_mul_f32(dst_spad_data, (const uint8_t *) src1, (uint8_t *) dst, nc);
// prefetch N+2 loop iteration if any
const uint32_t pref_block = (ir + BLOCK * 2);
if (pref_block < src0_end_row) {
const uint32_t pref_block_size = MIN(BLOCK, src0_end_row - pref_block);
dma_queue_push_ddr_to_vtcm(dma_queue,
dma_make_ptr(src0_spad, data_src0 + (pref_block * src0_row_size)),
src0_row_size_aligned, src0_row_size, pref_block_size);
dma_queue_push_ddr_to_vtcm(dma_queue,
dma_make_ptr(src1_spad, data_src1 + (pref_block * src1_row_size)),
src1_row_size_aligned, src1_row_size, pref_block_size);
}
}
dma_queue_flush(dma_queue);
t2 = HAP_perf_get_qtimer_count();
FARF(HIGH, "swiglu-f32 %d/%d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, opt_path,
FARF(HIGH, "swiglu-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth,
ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3,
(unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
}
@ -403,7 +455,6 @@ static void unary_silu_fp32_per_thread(const struct htp_tensor * src0,
size_t src0_spad_half_size = src0_spad->size_per_thread / 2;
size_t dst_spad_half_size = dst_spad->size_per_thread / 2;
// In gelu = x*sigmoid(x*1.702)
const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block
if (BLOCK == 0) {
@ -472,7 +523,7 @@ static void unary_silu_fp32(unsigned int n, unsigned int i, void * data) {
static void glu_swiglu_fp32(unsigned int n, unsigned int i, void * data) {
struct htp_ops_context * octx = (struct htp_ops_context *) data;
glu_swiglu_fp32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad,
&octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread);
&octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
}
static void glu_swiglu_oai_fp32(unsigned int n, unsigned int i, void * data) {