From f5d1c4179fedf726bec744d3125a55df8d02496a Mon Sep 17 00:00:00 2001 From: Max Krasnyansky Date: Sun, 29 Mar 2026 06:40:13 -0700 Subject: [PATCH] hexagon: dma optimizations (mostly fixing regressions) (#21137) * hex-fa: add simple dma cache for Mask I noticed that we were refetching the mask rows over and over. This simple cache avoids that. * hex-dma: unset in-order desc bit which caused significant perf regression We don't rely on true in order processing of the DMA descriptors anywhere. Turns out this mode caused significant regression of around 3-4 TPS during token gen. * hex-rope: update comment to clarify that we don't need in-order DMA completions --- ggml/src/ggml-hexagon/htp/flash-attn-ops.c | 12 ++-- ggml/src/ggml-hexagon/htp/hex-dma.h | 75 ++++++++++++++++++---- ggml/src/ggml-hexagon/htp/rope-ops.c | 4 +- 3 files changed, 74 insertions(+), 17 deletions(-) diff --git a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c index 6dc978dd68..0c9bc78562 100644 --- a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +++ b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c @@ -346,6 +346,9 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void * const HVX_Vector logit_cap = hvx_vec_splat_f32(factx->logit_softcap); + dma_cache m_cache; + dma_cache_init(&m_cache, spad_m, factx->size_m_block, DMA_CACHE_MAX_SIZE); + for (uint32_t ir = ir0; ir < ir1; ++ir) { const uint32_t iq3 = fastdiv(ir, &factx->src0_div21); const uint32_t iq2 = fastdiv(ir - iq3*neq2*neq1, &factx->src0_div1); @@ -389,9 +392,8 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void * // Mask if (mask) { const uint8_t * m_src = (const uint8_t *) (mp_base + ic_start); - uint8_t * m_dst = spad_m + (ib % 2) * factx->size_m_block; // Mask is 1D contiguous for this row - dma_queue_push(dma, dma_make_ptr(m_dst, m_src), current_block_size * 2, current_block_size * 2, current_block_size * 2, 1); + dma_cache_push(dma, &m_cache, m_src, 
current_block_size * 2, current_block_size * 2, current_block_size * 2, 1); } // FARF(HIGH, "fa %u: prefetch KVM: ir %u ib %u iq1 %u iq2 %u iq3 %u : size_k_row %u size_v_row %u bs %u: usec %u", @@ -554,7 +556,7 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void * // Mask if (mask) { const uint8_t * m_src = (const uint8_t *) (mp_base + next_ic_start); - dma_queue_push(dma, dma_make_ptr(m_base, m_src), next_block_size * 2, next_block_size * 2, next_block_size * 2, 1); + dma_cache_push(dma, &m_cache, m_src, next_block_size * 2, next_block_size * 2, next_block_size * 2, 1); } // FARF(HIGH, "fa %u: prefetch KVM: ir %u ib %u : iq1 %u iq2 %u iq3 %u : size_k_row %u size_v_row %u bs %u: usec %u", @@ -684,7 +686,7 @@ int op_flash_attn_ext(struct htp_ops_context * octx) { octx->src0_spad.size_per_thread = size_q_block * 1; octx->src1_spad.size_per_thread = factx.size_k_block * 2; octx->src2_spad.size_per_thread = factx.size_v_block * 2; - octx->src3_spad.size_per_thread = mask ? factx.size_m_block * 2 : 0; + octx->src3_spad.size_per_thread = mask ? 
factx.size_m_block * DMA_CACHE_MAX_SIZE : 0; octx->dst_spad.size_per_thread = size_vkq_acc; octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; @@ -705,6 +707,8 @@ int op_flash_attn_ext(struct htp_ops_context * octx) { octx->src3_spad.data = octx->src2_spad.data + octx->src2_spad.size; octx->dst_spad.data = octx->src3_spad.data + octx->src3_spad.size; + // FARF(ERROR, "fa: qrows-per-thread %u", factx.qrows_per_thread); + if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { worker_pool_run_func(octx->ctx->worker_pool, flash_attn_ext_f16_thread, &factx, octx->n_threads); } diff --git a/ggml/src/ggml-hexagon/htp/hex-dma.h b/ggml/src/ggml-hexagon/htp/hex-dma.h index ff166cbcc7..7685473f46 100644 --- a/ggml/src/ggml-hexagon/htp/hex-dma.h +++ b/ggml/src/ggml-hexagon/htp/hex-dma.h @@ -143,7 +143,7 @@ static inline bool dma_queue_push_single_1d(dma_queue * q, dma_ptr dptr, size_t desc->desc_size = 0; // 1D mode desc->src_bypass = dma_src_l2_bypass_on; desc->dst_bypass = dma_dst_l2_bypass_on; - desc->order = 1; + desc->order = 0; desc->done = 0; desc->src = (void *) dptr.src; desc->dst = (void *) dptr.dst; @@ -151,8 +151,12 @@ static inline bool dma_queue_push_single_1d(dma_queue * q, dma_ptr dptr, size_t q->dptr[q->push_idx] = dptr; - dmlink(q->tail, desc); - q->tail = (dma_descriptor_2d *) desc; + if (size) { + dmlink(q->tail, desc); + q->tail = (dma_descriptor_2d *) desc; + } else { + desc->done = 1; + } // FARF(ERROR, "dma-push: i %u row-size %u nrows %d dst %p src %p\n", q->push_idx, row_size, nrows, dptr.dst, dptr.src); q->push_idx = (q->push_idx + 1) & q->idx_mask; @@ -175,7 +179,7 @@ static inline bool dma_queue_push_single_2d(dma_queue * q, dma_ptr dptr, size_t desc->dst_bypass = dma_dst_l2_bypass_on; desc->src_comp = 0; desc->dst_comp = 0; - desc->order = 1; + desc->order = 0; desc->done = 0; desc->src_stride = src_stride; desc->dst_stride = dst_stride; @@ -197,8 +201,12 @@ static inline bool dma_queue_push_single_2d(dma_queue * q, dma_ptr dptr, 
size_t q->dptr[q->push_idx] = dptr; - dmlink(q->tail, desc); - q->tail = desc; + if (nrows) { + dmlink(q->tail, desc); + q->tail = desc; + } else { + desc->done = 1; + } // FARF(ERROR, "dma-push: i %u row-size %u nrows %d dst %p src %p\n", q->push_idx, row_size, nrows, dptr.dst, dptr.src); q->push_idx = (q->push_idx + 1) & q->idx_mask; @@ -215,12 +223,9 @@ static inline dma_ptr dma_queue_pop(dma_queue * q) { dma_descriptor_2d * desc = &q->desc[q->pop_idx]; // Wait for desc to complete - while (1) { - dmpoll(); - if (desc->done) { - break; - } + while (!desc->done) { // FARF(ERROR, "dma-pop: waiting for DMA : %u\n", q->pop_idx); + dmpoll(); } dptr = q->dptr[q->pop_idx]; @@ -312,6 +317,54 @@ static inline bool dma_queue_push_vtcm_to_ddr(dma_queue * q, dma_ptr dptr, size_ return dma_queue_push(q, dptr, dst_row_size, src_row_size, dst_row_size, nrows); } +#define DMA_CACHE_MAX_SIZE 64U + +typedef struct { + uint8_t *base; + uint32_t line_size; + uint32_t capacity; + uint32_t src[DMA_CACHE_MAX_SIZE]; + uint16_t age[DMA_CACHE_MAX_SIZE]; +} dma_cache; + +static inline void dma_cache_init(dma_cache *c, uint8_t *base, uint32_t line_size, uint32_t capacity) +{ + c->capacity = (capacity > DMA_CACHE_MAX_SIZE) ? 
DMA_CACHE_MAX_SIZE : capacity; + c->base = base; + c->line_size = line_size; + + for (unsigned i=0; i < c->capacity; i++) { + c->src[i] = 0; + c->age[i] = 0; + } +} + +static inline bool dma_cache_push(dma_queue *q, dma_cache *c, const uint8_t * src, uint32_t dst_stride, uint32_t src_stride, uint32_t row_size, uint32_t nrows) +{ + uint32_t o_idx = 0; + uint16_t o_age = 0; + uint8_t * dst = 0; + + for (unsigned i=0; i < c->capacity; i++) { + if (c->src[i] == (uint32_t) src) { + c->age[i] = 0; + dst = c->base + (i * c->line_size); nrows = 0; // dummy dma + // FARF(ERROR, "dma-cache: found %p", src); + } else { + c->age[i]++; + if (c->age[i] > o_age) { o_age = c->age[i]; o_idx = i; } + } + } + if (!dst) { + // FARF(ERROR, "dma-cache: replacing #%u : age %u %p -> %p", o_idx, c->age[o_idx], (void *) c->src[o_idx], src); + c->age[o_idx] = 0; + c->src[o_idx] = (uint32_t) src; + dst = c->base + o_idx * c->line_size; // normal nrows dma + } + + return dma_queue_push(q, dma_make_ptr(dst, src), dst_stride, src_stride, row_size, nrows); +} + #ifdef __cplusplus } // extern "C" #endif diff --git a/ggml/src/ggml-hexagon/htp/rope-ops.c b/ggml/src/ggml-hexagon/htp/rope-ops.c index be9469538f..ecedadb0fe 100644 --- a/ggml/src/ggml-hexagon/htp/rope-ops.c +++ b/ggml/src/ggml-hexagon/htp/rope-ops.c @@ -333,8 +333,8 @@ static void rope_job_f32(unsigned int nth, unsigned int ith, void * data) { // (unsigned) HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - rctx->t_start)); } - // Skip DMA transactions from prev block (if any) - // No need to wait for these since the DMA is setup for in-order processing + // Skip output DMA transactions from prev block (if any) + // No need to wait for those here since we're explicitly waiting for the latest prefetches below. for (uint32_t d=0; d < dma_depth; d++) { dma_queue_pop_nowait(dma_queue); } // Compute loop