From 36bc6f3213c0b808579ddff093f69e293a894caf Mon Sep 17 00:00:00 2001 From: nullname Date: Thu, 16 Oct 2025 23:21:51 +0800 Subject: [PATCH] feat: perf opt gemv phase2 (#58) * Add power management utilities to NPU device context and update DCVS settings * Update DCVS settings in power_utils to use v3 API and enhance power management * wip * Enhance dequantization functions by adding load_dequant_table support and updating signatures for improved performance * use lut * wip * fix test failure * wip * Refactor load_qual_block_generic to improve block handling and optimize vector operations * Enhance load_dual_block_generic and load_qual_block_generic to accept a mask parameter for improved block handling * Refactor flash_attn_impl to optimize mask l2 prefetch * wip * wip * wip * wip * add log * link against shared libraries instead of static ones * fix swiglu * wip * refactor expf_fix to handle overflow for different data types * enhance is_glu_op_supported to validate shapes for multiple sources * wip * refactor logging macros to use hexagon namespace and improve formatting * fix printf format error * wip * refactor: update static_assert messages for block size validation and add HVX_VectorPred_x3 type alias * rename * feat: enhance fa with mask * wip * wip * refactor: replace instances of Q6_V_vzero() with kZeroV for consistency * wip * wip * wip * fix: improve address alignment check in HVX_Vector handling * refactor: streamline vector dot product implementations for improved readability * refactor: q4k add hvx intrinsic impl * refactor: enhance dequantize_row_q4_K for clarity and performance * refactor: optimize scale mask usage in dequantization functions for improved performance * refactor: optimize dequantize_row_q4_K for intrinsic usage and performance improvements * refactor: move GLU operation implementation into separated file * sync after swiglu * wip * wip * wip * feat: increase prc main thread stack size * fix: replace hardcoded stack size with 
NPU_THREAD_STACK_SIZE constant * wip * feat: add optimized vector operations for exponential and division with overflow handling * wip * feat: refactor exponential function to handle overflow and underflow with improved logic * wip * wip * feat: add vector loading and scaling functions for improved performance in block processing * wip * feat: optimize block loading by refactoring scale index handling for improved performance * use Q6_Vb_vlut32_VbVbR_nomatch instead * feat: enhance scale loading by adding static assertion and restructuring block handling * wip * feat: refactor vec_dot_product_mixed_impl for improved clarity and performance * wip * feat: simplify vector loading functions and improve alignment handling * wip * feat: enhance scale loading mask with quantization block size validation * wip * feat: implement make_scale_load_mask function and refactor vector handling in vec_ops * feat: enhance load_dual_block_generic to include scale indices for improved vector loading * revert q8 dequant * wip * feat: optimize dequantization functions by removing unnecessary masking and updating lookup methods * wip * wip * add qurt_mutex * Add DMA transfer class and integrate into thread pool * Enhance DMA transfer functionality by adding support for multiple descriptors and initiating transfers in parallel * fix dma crash * fix failed unit tests * wip * use alignas * Improve DMA transfer error handling and update descriptor completion check * Fix VTCM cache size calculation in element-wise operations * Add cache clean operations before DMA transfers in element-wise operations * reduce cache clean operations * Refactor DMA transfer functions to support 1D operations and rename for clarity * Enhance DMA transfer functionality by adding 2D submission support and improving descriptor initialization * Update read buffer method to support forced invalidation and remove unnecessary invalidation calls in element-wise operations * wip * Improve DMA transfer handling in 
mul_mat_gemv_impl by replacing memcpy with initiate_dma_row_transfer and adding wait_for_dma logic * fix 2d dma * feat: add DMA plane cache * rename * wip * use memcpy for debug * fix cache plane calc * refactor: remove debug logging from mul_mat_impl and optimize cache handling * rename * fix 2d dma type * refactor: enhance DMA transfer handling in mul_mat_gemv_impl and wait functions * refactor: optimize DMA transfer handling in mul_mat_gemv_impl and wait functions * wip * wip * move op impl into sub dir * add log * fix: correct pointer usage in mul_mat_gemv_impl for next plane access * fix: improve DMA transfer error handling in mul_mat_impl and mul_mat_gemv_impl * fix: fix crash by using the entire row bytes * wip * wip * fix: prevent parallelization for scalar src1 in is_mul_mat_supported * fix: add dimension checks for 2D DMA transfers and fallback to 1D if necessary * wip * fix: enable thread barrier for mul multiplication operations * feat: add synchronization checks for tensor operations and update related functions * wip * fix: remove invalidation flag from get_read_buffer calls in element-wise and matrix multiplication operations * Revert "fix: remove invalidation flag from get_read_buffer calls in element-wise and matrix multiplication operations" This reverts commit af3441e67e706b2e5122369dc160353796867dd3. 
* wip * wip * add comment * fix: improve DMA transfer handling in mul_mat_gemv_impl for quantized source tensors * add log * try fix mulmat gemv * wip * fix: enhance DMA transfer handling in mul_mat_gemv_impl for quantized source tensors * fix: optimize cache offset calculation and remove redundant swap in mul_mat_gemv_impl * fix: refactor DMA transfer handling in mul_mat_gemv_impl for improved clarity and maintainability * wip * wip * wip * fix: enhance mul_mat_impl for improved cache handling and clarity * fix: refactor tensor unflattening and DMA transfer initialization for improved clarity and type safety * fix: improve cache handling of quant * wip * fix: improve cache handling in mul_mat_impl and mul_mat_gemv_impl for better memory efficiency * rename * add load_hexa_block_generic * wip * extract dequant block into separated function * refactor: enhance dequantization functions with table parameter * fix load_dual_block_generic * refactor: rename dequantization functions for clarity and enhance block handling * refactor: simplify dequantization logic by consolidating block handling and removing unused parameters * wip * wip * feat: add make_qs_load_mask function and update load_dual_block_generic to use qs_indices * fix load_dual_block_generic * refactor: update load functions to use qs_indices for improved block loading * wip * fix: update loop indices and boundary checks to use size_t for better efficiency * wip * update make_scale_load_mask, to make it available for q8 * feat: add vec_dot_product_quant_impl for quantized dot product computation * refactoring: move some quant func to dedicated file * refactor: rename dequantization functions for clarity and consistency * wip * feat: enhance vec_dot_product_quant_impl with dual dequantization and improved assertions * add vec_dot_product_vqf32_q40_f32 * wip * wip * wip * wip * implement vec_mpy_qf32_qf32_qf32 function and update vec_dot_product_vqf32_q40_f32 to use it * wip * add
src0_plane_write_cache_offset * wip * enhance mul_mat_f32 to handle NPU_DATA_TYPE_Q4_0 for quantized matrix multiplication * wip * wip * update test func * refactor mul_mat_gemv_quant_impl to use get_nb for row stride and remove unused test function in init_f16_f32_table * wip * Add support for 4-block dequantization in vec_quant and update dot product implementation * Refactor vec_dot_product_quant_impl to improve variable handling and enhance readability * Refactor vec_dot_product_quant_impl to replace template function with inline vector operations * use Q6_Vqf32_vmpy_VsfVsf instead of Q6_Vqf32_vmpy_Vqf32Vqf32 * Revert "use Q6_Vqf32_vmpy_VsfVsf instead of Q6_Vqf32_vmpy_Vqf32Vqf32" This reverts commit 54839166fddbe40a0392adee5863c59070ccdbe4. * wip * improve log print in graph * Refactor batched_row_dot to accept additional arguments and remove batched_row_dot_with_table * Refactor synchronization functions to include previous operation and NE type parameters * Refactor synchronization checks in several operations * Update synchronization checks to include NPU_OP_COUNT in required conditions * Add performance tracking to buffer management functions * add memset * add log * fix: update backend device type from ACCEL to IGPU * fix comment --- ggml/src/ggml-qnn/npu/device/graph.cpp | 21 +- .../ggml-qnn/npu/device/op/op_flash_attn.cpp | 10 +- .../ggml-qnn/npu/device/op/op_flash_attn.hpp | 5 +- ggml/src/ggml-qnn/npu/device/op/op_glu.cpp | 44 ++- ggml/src/ggml-qnn/npu/device/op/op_glu.hpp | 5 +- ggml/src/ggml-qnn/npu/device/op/op_impl.cpp | 27 +- ggml/src/ggml-qnn/npu/device/op/op_impl.hpp | 5 +- .../src/ggml-qnn/npu/device/op/op_mul_mat.cpp | 291 +++++++++++++---- .../src/ggml-qnn/npu/device/op/op_mul_mat.hpp | 5 +- ggml/src/ggml-qnn/npu/device/op/op_rope.cpp | 9 +- ggml/src/ggml-qnn/npu/device/op/op_rope.hpp | 5 +- ggml/src/ggml-qnn/npu/device/op/op_types.hpp | 5 +- ggml/src/ggml-qnn/npu/device/type_traits.cpp | 302 +++-------------- 
ggml/src/ggml-qnn/npu/device/vec_ops.hpp | 61 ++-- ggml/src/ggml-qnn/npu/device/vec_ops.inl | 152 +++++++++ ggml/src/ggml-qnn/npu/device/vec_quant.inl | 304 ++++++++++++++++++ ggml/src/ggml-qnn/npu/host/buffer.cpp | 42 ++- ggml/src/ggml-qnn/npu/host/host.cpp | 23 +- ggml/src/ggml-qnn/npu/host/host_device.cpp | 37 +-- 19 files changed, 899 insertions(+), 454 deletions(-) create mode 100644 ggml/src/ggml-qnn/npu/device/vec_quant.inl diff --git a/ggml/src/ggml-qnn/npu/device/graph.cpp b/ggml/src/ggml-qnn/npu/device/graph.cpp index 06d0957479..e303046747 100644 --- a/ggml/src/ggml-qnn/npu/device/graph.cpp +++ b/ggml/src/ggml-qnn/npu/device/graph.cpp @@ -74,27 +74,32 @@ void graph::compute_impl(default_thread_pool * pool, default_thread_pool::thread hexagon::compute_params params = { thread_params, _f16_to_f32_table }; npu_device_tensor_op prev_op = NPU_OP_COUNT; + npu_device_ne_type prev_ne = {}; for (size_t i = 0; i < _tensor_count; ++i) { - auto * dst = _tensors[i]; - auto op = dst->get_op(); - auto * func = get_compute_func(dst); + auto * dst = _tensors[i]; + auto op = dst->get_op(); + const auto & ne = dst->get_info().ne; + const auto * op_name = op_get_name(op); + auto * func = get_compute_func(dst); if (!func) { - DEVICE_LOG_ERROR("graph(%p) tensor[%zu] op %d not supported\n", (void *) this, i, op); + DEVICE_LOG_ERROR("[%p][%s]graph tensor[%zu] op not supported\n", (void *) this, op_name, i); return; } - const bool should_sync = requires_thread_barrier(prev_op, op); + const bool should_sync = requires_thread_barrier(prev_op, prev_ne, op, ne); if (pool && should_sync) { // For the last tensor, the thread pool will handle synchronization - DEVICE_SCOPED_PERFORMANCE_TRACKER("[%p]sync_thread, tidx: %zu, tensor[%zu/%zu]", (void *) this, - params.get_thread_index(), i, _tensor_count); + DEVICE_SCOPED_PERFORMANCE_TRACKER("[%p][%s]sync_thread, tidx: %zu, tensor[%zu/%zu]", (void *) this, op_name, + params.get_thread_index(), i + 1, _tensor_count); pool->sync_thread(); } 
prev_op = op; + memcpy(&prev_ne, &ne, sizeof(prev_ne)); + if (!func(dst, &params)) { - DEVICE_LOG_ERROR("graph(%p) tensor[%zu] op %d compute failed\n", (void *) this, i, op); + DEVICE_LOG_ERROR("[%p][%s]graph tensor[%zu] op %d compute failed\n", (void *) this, op_name, i, op); } } } diff --git a/ggml/src/ggml-qnn/npu/device/op/op_flash_attn.cpp b/ggml/src/ggml-qnn/npu/device/op/op_flash_attn.cpp index 25089d9c44..fdf3fac5f6 100644 --- a/ggml/src/ggml-qnn/npu/device/op/op_flash_attn.cpp +++ b/ggml/src/ggml-qnn/npu/device/op/op_flash_attn.cpp @@ -375,10 +375,14 @@ bool is_flash_attn_supported(const npu_device_tensor_op_spec * op_spec, return true; } -bool is_flash_attn_required_sync(const npu_device_tensor_op op, const npu_device_tensor_op next_op) { +bool is_flash_attn_required_sync(npu_device_tensor_op prev_op, + const npu_device_ne_type & prev_ne, + npu_device_tensor_op op, + const npu_device_ne_type & ne) { + NPU_UNUSED(prev_ne); NPU_UNUSED(op); - NPU_UNUSED(next_op); - return true; + NPU_UNUSED(ne); + return prev_op != NPU_OP_COUNT; } } // namespace hexagon diff --git a/ggml/src/ggml-qnn/npu/device/op/op_flash_attn.hpp b/ggml/src/ggml-qnn/npu/device/op/op_flash_attn.hpp index cb27af5a9d..4ab64ac873 100644 --- a/ggml/src/ggml-qnn/npu/device/op/op_flash_attn.hpp +++ b/ggml/src/ggml-qnn/npu/device/op/op_flash_attn.hpp @@ -9,6 +9,9 @@ bool is_flash_attn_supported(const npu_device_tensor_op_spec * op_spec, const npu_device_tensor_spec * dst, const npu_device_tensor_spec * srcs, size_t src_len); -bool is_flash_attn_required_sync(const npu_device_tensor_op op, const npu_device_tensor_op next_op); +bool is_flash_attn_required_sync(npu_device_tensor_op prev_op, + const npu_device_ne_type & prev_ne, + npu_device_tensor_op op, + const npu_device_ne_type & ne); } // namespace hexagon diff --git a/ggml/src/ggml-qnn/npu/device/op/op_glu.cpp b/ggml/src/ggml-qnn/npu/device/op/op_glu.cpp index e5e20fc548..2cdff3b982 100644 --- a/ggml/src/ggml-qnn/npu/device/op/op_glu.cpp +++
b/ggml/src/ggml-qnn/npu/device/op/op_glu.cpp @@ -48,8 +48,8 @@ inline void glu_vec_op_f32_f32(const float * src0, size_t count, hexagon::HVX_VectorPair_x4 coeff) { using namespace hexagon::vec; - vec_trans_with_param_impl( - src0, src1, dst, count, coeff); + vec_trans_with_param_impl(src0, src1, dst, count, + coeff); } template @@ -73,8 +73,8 @@ bool glu_impl(hexagon::tensor * out, hexagon::compute_params * params) { const auto total_cols = has_src1 ? src0->get_ne(0) : src0->get_ne(0) / 2; if (out->get_ne(0) != total_cols) { - DEVICE_LOG_ERROR( - "[hexagon-npu][GLU]out.ne[0] (%ld) != total_cols (%d)\n", (long) out->get_ne(0), (int) total_cols); + DEVICE_LOG_ERROR("[hexagon-npu][GLU]out.ne[0] (%ld) != total_cols (%d)\n", (long) out->get_ne(0), + (int) total_cols); return false; } @@ -87,8 +87,7 @@ bool glu_impl(hexagon::tensor * out, hexagon::compute_params * params) { uint8_t * dst_ptr = out->get_write_buffer(); if (!dst_ptr) { - DEVICE_LOG_ERROR("[hexagon-npu][GLU]glu_impl: dst_ptr is not writable, tensor: %p, type: %s\n", - (void *) out, + DEVICE_LOG_ERROR("[hexagon-npu][GLU]glu_impl: dst_ptr is not writable, tensor: %p, type: %s\n", (void *) out, hexagon::get_type_name(out->get_type())); return false; } @@ -121,11 +120,8 @@ bool glu_impl(hexagon::tensor * out, hexagon::compute_params * params) { hexagon::l2fetch_row(src1_row + src1->get_nb(1), valid_row_bytes); } - _GluRowFunc(reinterpret_cast(src0_row), - reinterpret_cast(src1_row), - reinterpret_cast(dst_row), - static_cast(total_cols), - coeff); + _GluRowFunc(reinterpret_cast(src0_row), reinterpret_cast(src1_row), + reinterpret_cast(dst_row), static_cast(total_cols), coeff); } out->release_write_buffer(); // mark the output tensor as modified @@ -142,8 +138,7 @@ bool glu_compute(hexagon::tensor * out, hexagon::compute_params * params) { } if (out->get_type() != _DataType) { - DEVICE_LOG_ERROR("GLU op type mismatch: %s vs %s\n", - hexagon::get_type_name(out->get_type()), + DEVICE_LOG_ERROR("GLU op type 
mismatch: %s vs %s\n", hexagon::get_type_name(out->get_type()), hexagon::get_type_name(_DataType)); return false; } @@ -192,16 +187,14 @@ bool is_glu_op_supported(const npu_device_tensor_op_spec * op_spec, const auto & src0 = srcs[0]; if (dst->type != src0.type) { - DEVICE_LOG_DEBUG("[%s]src0.type and dst.type mismatch: %s vs %s\n", - hexagon::op_get_name(op), - hexagon::get_type_name(src0.type), - hexagon::get_type_name(dst->type)); + DEVICE_LOG_DEBUG("[%s]src0.type and dst.type mismatch: %s vs %s\n", hexagon::op_get_name(op), + hexagon::get_type_name(src0.type), hexagon::get_type_name(dst->type)); return false; } if (dst->type != NPU_DATA_TYPE_F32 && dst->type != NPU_DATA_TYPE_F16) { - DEVICE_LOG_DEBUG( - "[%s]unsupported data type: %s\n", hexagon::op_get_name(op), hexagon::get_type_name(dst->type)); + DEVICE_LOG_DEBUG("[%s]unsupported data type: %s\n", hexagon::op_get_name(op), + hexagon::get_type_name(dst->type)); return false; } @@ -215,9 +208,7 @@ bool is_glu_op_supported(const npu_device_tensor_op_spec * op_spec, if (src0.ne[0] / 2 != dst->ne[0] || src0.ne[1] != dst->ne[1] || src0.ne[2] != dst->ne[2] || src0.ne[3] != dst->ne[3]) { DEVICE_LOG_DEBUG("[%s]src0 and dst have different shape: src0.ne[0]: %ld, dst.ne[0]: %ld\n", - hexagon::op_get_name(op), - (long) src0.ne[0], - (long) dst->ne[0]); + hexagon::op_get_name(op), (long) src0.ne[0], (long) dst->ne[0]); return false; } } @@ -225,9 +216,14 @@ bool is_glu_op_supported(const npu_device_tensor_op_spec * op_spec, return true; } -bool is_glu_required_sync(const npu_device_tensor_op op, const npu_device_tensor_op next_op) { +bool is_glu_required_sync(npu_device_tensor_op prev_op, + const npu_device_ne_type & prev_ne, + npu_device_tensor_op op, + const npu_device_ne_type & ne) { + NPU_UNUSED(prev_ne); NPU_UNUSED(op); - return next_op == NPU_OP_MUL_MAT; + NPU_UNUSED(ne); + return prev_op == NPU_OP_MUL_MAT; } } // namespace hexagon diff --git a/ggml/src/ggml-qnn/npu/device/op/op_glu.hpp 
b/ggml/src/ggml-qnn/npu/device/op/op_glu.hpp index d1bf9b45f3..689e50231c 100644 --- a/ggml/src/ggml-qnn/npu/device/op/op_glu.hpp +++ b/ggml/src/ggml-qnn/npu/device/op/op_glu.hpp @@ -11,6 +11,9 @@ bool is_glu_op_supported(const npu_device_tensor_op_spec * op_spec, const npu_device_tensor_spec * dst, const npu_device_tensor_spec * srcs, size_t src_len); -bool is_glu_required_sync(const npu_device_tensor_op op, const npu_device_tensor_op next_op); +bool is_glu_required_sync(npu_device_tensor_op prev_op, + const npu_device_ne_type & prev_ne, + npu_device_tensor_op op, + const npu_device_ne_type & ne); } // namespace hexagon diff --git a/ggml/src/ggml-qnn/npu/device/op/op_impl.cpp b/ggml/src/ggml-qnn/npu/device/op/op_impl.cpp index dc978e180c..634baece92 100644 --- a/ggml/src/ggml-qnn/npu/device/op/op_impl.cpp +++ b/ggml/src/ggml-qnn/npu/device/op/op_impl.cpp @@ -227,9 +227,15 @@ bool is_element_wise_op_supported(const npu_device_tensor_op_spec * op_spec, return true; } -bool is_element_wise_op_required_sync(const npu_device_tensor_op op, const npu_device_tensor_op next_op) { +bool is_element_wise_op_required_sync(npu_device_tensor_op prev_op, + const npu_device_ne_type & prev_ne, + npu_device_tensor_op op, + const npu_device_ne_type & ne) { + NPU_UNUSED(prev_ne); NPU_UNUSED(op); - return next_op == NPU_OP_MUL_MAT; + NPU_UNUSED(ne); + return prev_op != NPU_OP_ADD && prev_op != NPU_OP_SUB && prev_op != NPU_OP_MUL && prev_op != NPU_OP_RMS_NORM && + prev_op != NPU_OP_COUNT; } void rms_norm_vec_f32(const float * src, float * dst, size_t count, float eps) { @@ -361,9 +367,15 @@ bool is_unary_op_supported(const npu_device_tensor_op_spec * op_spec, return true; } -bool is_unary_op_required_sync(const npu_device_tensor_op op, const npu_device_tensor_op next_op) { +bool is_unary_op_required_sync(npu_device_tensor_op prev_op, + const npu_device_ne_type & prev_ne, + npu_device_tensor_op op, + const npu_device_ne_type & ne) { + NPU_UNUSED(prev_ne); NPU_UNUSED(op); - return next_op 
== NPU_OP_MUL_MAT; + NPU_UNUSED(ne); + return prev_op != NPU_OP_ADD && prev_op != NPU_OP_SUB && prev_op != NPU_OP_MUL && prev_op != NPU_OP_RMS_NORM && + prev_op != NPU_OP_COUNT; } struct op_capabilities { @@ -457,13 +469,16 @@ compute_func_type get_compute_func(tensor * dst) { return get_compute_func_impl(dst->get_op(), dst->get_type()); } -bool requires_thread_barrier(npu_device_tensor_op op, npu_device_tensor_op next_op) { +bool requires_thread_barrier(npu_device_tensor_op prev_op, + const npu_device_ne_type & prev_ne, + npu_device_tensor_op op, + const npu_device_ne_type & ne) { if (op >= NPU_OP_COUNT) { return false; } auto requires_thread_barrier_func = kOpCapabilities[op].requires_thread_barrier_func; - return requires_thread_barrier_func && requires_thread_barrier_func(op, next_op); + return requires_thread_barrier_func && requires_thread_barrier_func(prev_op, prev_ne, op, ne); } bool support_op(const npu_device_tensor_op_spec * op_spec, diff --git a/ggml/src/ggml-qnn/npu/device/op/op_impl.hpp b/ggml/src/ggml-qnn/npu/device/op/op_impl.hpp index d8778046f7..09b625a995 100644 --- a/ggml/src/ggml-qnn/npu/device/op/op_impl.hpp +++ b/ggml/src/ggml-qnn/npu/device/op/op_impl.hpp @@ -6,7 +6,10 @@ namespace hexagon { compute_func_type get_compute_func(tensor * dst); -bool requires_thread_barrier(npu_device_tensor_op op, npu_device_tensor_op next_op); +bool requires_thread_barrier(npu_device_tensor_op prev_op, + const npu_device_ne_type & prev_ne, + npu_device_tensor_op op, + const npu_device_ne_type & ne); bool support_op(const npu_device_tensor_op_spec * op_spec, const npu_device_tensor_spec * dst, diff --git a/ggml/src/ggml-qnn/npu/device/op/op_mul_mat.cpp b/ggml/src/ggml-qnn/npu/device/op/op_mul_mat.cpp index 063e8d105d..fa2aa9d55e 100644 --- a/ggml/src/ggml-qnn/npu/device/op/op_mul_mat.cpp +++ b/ggml/src/ggml-qnn/npu/device/op/op_mul_mat.cpp @@ -6,6 +6,12 @@ namespace { +inline std::pair unflatten_i3_i2(size_t idx, const hexagon::tensor * t) { + const auto i3 = 
idx / t->get_ne(2); + const auto i2 = idx - i3 * t->get_ne(2); + return { i3, i2 }; +} + template struct get_data_type {}; template @@ -14,6 +20,12 @@ struct get_data_type { using data_type1 = _TData1; }; +template +struct get_data_type { + using data_type0 = _TData0; + using data_type1 = _TData1; +}; + template struct convert_vector {}; template <> struct convert_vector { @@ -55,29 +67,30 @@ inline bool init_dma_transfer(hexagon::compute_params * params, return true; } -template +template inline void batched_row_dot(const uint8_t * src0_plane, const size_t src0_ne0, const size_t src0_nb1, const uint8_t * src1_row, const size_t src1_nb1, float * dst_row, - const size_t actual_row_count, - const size_t src1_fetch_row_bytes) { + const size_t slice_rows, + const size_t src1_fetch_row_bytes, + _TExtraArgs... args) { using data_type0 = typename get_data_type::data_type0; using data_type1 = typename get_data_type::data_type1; size_t i0 = 0; - for (; i0 + 1 < actual_row_count; i0 += 2) { + for (; i0 + 1 < slice_rows; i0 += 2) { auto * src0_row = src0_plane + i0 * src0_nb1; // TODO: figure dst how to handle a entire row auto res0 = _DotFunc(reinterpret_cast(src0_row), - reinterpret_cast(src1_row), src0_ne0); + reinterpret_cast(src1_row), src0_ne0, args...); // TODO: figure dst how to handle a entire row auto res1 = _DotFunc(reinterpret_cast(src0_row + src0_nb1), - reinterpret_cast(src1_row), src0_ne0); + reinterpret_cast(src1_row), src0_ne0, args...); { dst_row[i0] = convert_vector::convert(res0); @@ -89,10 +102,10 @@ inline void batched_row_dot(const uint8_t * src0_plane, hexagon::l2fetch_row(src1_row + src1_nb1, src1_fetch_row_bytes); } - if (i0 < actual_row_count) { + if (i0 < slice_rows) { auto * src0_row = src0_plane + i0 * src0_nb1; auto res = _DotFunc(reinterpret_cast(src0_row), - reinterpret_cast(src1_row), src0_ne0); + reinterpret_cast(src1_row), src0_ne0, args...); dst_row[i0] = convert_vector::convert(res); } } @@ -105,7 +118,7 @@ inline void 
mul_mat_impl(hexagon::tensor * src0, using data_type0 = typename get_data_type::data_type0; using data_type1 = typename get_data_type::data_type1; - const auto src0_actual_row_stride = hexagon::get_dequantized_row_size(src0); + const auto src0_row_stride = hexagon::get_dequantized_row_size(src0); auto * dequantize_row_func = hexagon::get_type_traits(src0->get_type()).to_float; auto * load_dequant_table_func = hexagon::get_type_traits(src0->get_type()).load_dequant_table; if (_IsSrcQuantized && dequantize_row_func == nullptr) { @@ -130,7 +143,8 @@ inline void mul_mat_impl(hexagon::tensor * src0, } if (start_end_plane.second <= start_end_plane.first || start_end_row.second <= start_end_row.first || - start_end_element.second <= start_end_element.first) { + start_end_element.second <= start_end_element.first || start_end_plane.first < 0 || start_end_row.first < 0 || + start_end_element.first < 0) { DEVICE_LOG_DEBUG( "mul_mat_impl: no work to do, start_end_plane: (%lld, %lld), start_end_row: (%lld, %lld), " "start_end_element: (%lld, %lld)\n", @@ -142,12 +156,12 @@ inline void mul_mat_impl(hexagon::tensor * src0, const uint8_t * src0_ptr = src0->get_read_buffer(true); // TODO: avoid invalidation // cache the src0 plane in VTCM - const size_t valid_src0_row_bytes = _IsSrcQuantized ? src0->get_nb(1) : (src0->get_ne(0) * sizeof(data_type0)); - const size_t src1_actual_row_stride = hexagon::get_aligned_size(src1->get_nb(1)); + const size_t valid_src0_row_bytes = _IsSrcQuantized ? 
src0->get_nb(1) : (src0->get_ne(0) * sizeof(data_type0)); + const size_t src1_row_stride = hexagon::get_aligned_size(src1->get_nb(1)); // TODO: figure out why we have to add padding after src0 plane cache const size_t src0_plane_slice_row_count = - std::min((params->get_vtcm_quota_size() - src1_actual_row_stride) / (src0_actual_row_stride * 2), + std::min((params->get_vtcm_quota_size() - src1_row_stride) / (src0_row_stride * 2), start_end_element.second - start_end_element.first); uint8_t * src0_plane_read_cache_ptr = nullptr; uint8_t * src0_plane_write_cache_ptr = nullptr; @@ -156,13 +170,13 @@ inline void mul_mat_impl(hexagon::tensor * src0, const uint8_t * last_read_cached_plane_ptr = nullptr; { - const size_t src0_plane_cache_size = src0_actual_row_stride * src0_plane_slice_row_count; + const size_t src0_plane_cache_size = src0_row_stride * src0_plane_slice_row_count; src0_plane_read_cache_ptr = params->get_vtcm_cache(src0_plane_cache_size * 2); if (!src0_plane_read_cache_ptr) { DEVICE_LOG_ERROR( "mul_mat_impl: failed to get VTCM cache for src0, size: %zu, src0_plane_slice_row_count: %zu, " - "src0_actual_row_stride: %zu, will fallback to mem cache\n", - src0_plane_cache_size, src0_plane_slice_row_count, src0_actual_row_stride); + "src0_row_stride: %zu, will fallback to mem cache\n", + src0_plane_cache_size, src0_plane_slice_row_count, src0_row_stride); return; } @@ -173,11 +187,11 @@ inline void mul_mat_impl(hexagon::tensor * src0, } DEVICE_LOG_DEBUG( - "[%d]mul_mat_impl, src0_actual_row_stride:%zu, valid_src0_row_bytes:%zu, src_nb0:%zu, " + "[%d]mul_mat_impl, src0_row_stride:%zu, valid_src0_row_bytes:%zu, src_nb0:%zu, " "slice_row_count:%zu, write_cache_offset: %zu, " "total_planes:%lld, planes:[%d,%d), rows:[%d,%d), elems:[%d,%d), is_quant:%d, " "vtcm_mem:%p(%zu)\n", - (int) params->get_thread_index(), src0_actual_row_stride, valid_src0_row_bytes, (size_t) src0->get_nb(1), + (int) params->get_thread_index(), src0_row_stride, valid_src0_row_bytes, (size_t) 
src0->get_nb(1), src0_plane_slice_row_count, src0_plane_write_cache_offset, total_planes, (int) start_end_plane.first, (int) start_end_plane.second, (int) start_end_row.first, (int) start_end_row.second, (int) start_end_element.first, (int) start_end_element.second, _IsSrcQuantized, @@ -205,7 +219,7 @@ inline void mul_mat_impl(hexagon::tensor * src0, last_write_cached_plane_ptr = src0_plane; } - const size_t valid_row1_bytes = + const size_t valid_src1_row_bytes = src0->get_ne(0) * sizeof(data_type1); // src0 and src1 should have the same element count in the 1st dimension DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_WITH_MULTI_SUB_PROC(dst, params->get_thread_index(), mul_mat); @@ -216,19 +230,19 @@ inline void mul_mat_impl(hexagon::tensor * src0, return; } - const auto dequant_table = load_dequant_table_func ? load_dequant_table_func() : HVX_Vector(); const uint8_t * src1_ptr = src1->get_read_buffer(); - for (int64_t ip = start_end_plane.first; ip < start_end_plane.second; ip++) { + const auto dequant_table = load_dequant_table_func ? 
load_dequant_table_func() : HVX_Vector(); + for (size_t ip = start_end_plane.first; ip < size_t(start_end_plane.second); ip++) { const auto [i3, i2] = unflatten_i3_i2(ip, dst); const auto * src1_plane = src1_ptr + i3 * src1->get_nb(3) + i2 * src1->get_nb(2); auto * dst_plane = dst_ptr + i3 * dst->get_nb(3) + i2 * dst->get_nb(2); const uint8_t * src0_plane_base = src0_ptr + i3 / r03 * src0->get_nb(3) + i2 / r02 * src0->get_nb(2); - for (int64_t col_idx = start_end_element.first; col_idx < start_end_element.second; + for (size_t col_idx = start_end_element.first; col_idx < size_t(start_end_element.second); col_idx += src0_plane_slice_row_count) { const uint8_t * src0_plane = src0_plane_base + col_idx * src0->get_nb(1); - const int64_t actual_row_count = - std::min(src0_plane_slice_row_count, - start_end_element.second - col_idx); // number of rows in this slice + const size_t slice_rows = + std::min(src0_plane_slice_row_count, + start_end_element.second - col_idx); // number of rows in this slice { const uint8_t * src0_next_plane = last_write_cached_plane_ptr; @@ -272,9 +286,9 @@ inline void mul_mat_impl(hexagon::tensor * src0, if (last_read_cached_plane_ptr != src0_plane) { DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 0, dequant); const uint8_t * src0_quant_plane = src0_plane_read_cache_ptr + src0_plane_write_cache_offset; - for (int64_t ir = 0; ir < actual_row_count; ir++) { + for (size_t ir = 0; ir < slice_rows; ir++) { auto * src0_row = src0_quant_plane + ir * src0->get_nb(1); - auto * cached_row_ptr = src0_plane_read_cache_ptr + ir * src0_actual_row_stride; + auto * cached_row_ptr = src0_plane_read_cache_ptr + ir * src0_row_stride; dequantize_row_func(src0_row, reinterpret_cast(cached_row_ptr), src0->get_ne(0), dequant_table); } @@ -284,16 +298,16 @@ inline void mul_mat_impl(hexagon::tensor * src0, last_read_cached_plane_ptr = src0_plane; if (start_end_row.second > start_end_row.first) { - hexagon::l2fetch_row(src1_plane + start_end_row.first * 
src1->get_nb(1), valid_row1_bytes); + hexagon::l2fetch_row(src1_plane + start_end_row.first * src1->get_nb(1), valid_src1_row_bytes); } - for (int64_t i1 = start_end_row.first; i1 < start_end_row.second; i1++) { + for (size_t i1 = start_end_row.first; i1 < size_t(start_end_row.second); i1++) { DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 1, dot); auto * src1_row = src1_plane + i1 * src1->get_nb(1); auto * dst_row = reinterpret_cast(dst_plane + i1 * dst->get_nb(1)) + col_idx; - batched_row_dot<_DotFunc>(src0_plane_read_cache_ptr, src0->get_ne(0), src0_actual_row_stride, src1_row, - src1->get_nb(1), dst_row, actual_row_count, - (ip + 1 < start_end_plane.second) ? valid_row1_bytes : 0); + batched_row_dot<_DotFunc>(src0_plane_read_cache_ptr, src0->get_ne(0), src0_row_stride, src1_row, + src1->get_nb(1), dst_row, slice_rows, + (ip + 1 < start_end_plane.second) ? valid_src1_row_bytes : 0); } } } @@ -309,25 +323,22 @@ inline void mul_mat_gemv_impl(hexagon::tensor * src0, using data_type0 = typename get_data_type::data_type0; using data_type1 = typename get_data_type::data_type1; - const auto src0_actual_row_stride = hexagon::get_dequantized_row_size(src0); - auto * dequantize_row_func = hexagon::get_type_traits(src0->get_type()).to_float; - auto * load_dequant_table_func = hexagon::get_type_traits(src0->get_type()).load_dequant_table; + auto * dequantize_row_func = hexagon::get_type_traits(src0->get_type()).to_float; + auto * load_dequant_table_func = hexagon::get_type_traits(src0->get_type()).load_dequant_table; if (_IsSrcQuantized && dequantize_row_func == nullptr) { DEVICE_LOG_ERROR("Unsupported quantized src0 type: %d, dequantize_row_func is null\n", src0->get_type()); return; } - auto start_end_element = std::pair{ 0, dst->get_ne(0) }; - if (dst->get_ne(0) >= params->get_thread_count()) { - start_end_element = params->get_work_slice(dst->get_ne(0)); - } else { + if (dst->get_ne(0) < params->get_thread_count()) { DEVICE_LOG_ERROR("Unsupported src1 
tensor shape for gemv: %s, ne: %lldx%lldx%lldx%lld\n", hexagon::get_type_name(src1->get_type()), src1->get_ne(0), src1->get_ne(1), src1->get_ne(2), src1->get_ne(3)); return; } - if (start_end_element.second <= start_end_element.first) { + const auto start_end_element = params->get_work_slice(dst->get_ne(0)); + if (start_end_element.second <= start_end_element.first || start_end_element.first < 0) { DEVICE_LOG_DEBUG( "mul_mat_gemv_impl: no work to do, start_end_plane: [0, 1), start_end_row: [0, 1), " "start_end_element: [%lld, %lld)\n", @@ -335,13 +346,14 @@ inline void mul_mat_gemv_impl(hexagon::tensor * src0, return; } + const auto src0_row_stride = hexagon::get_dequantized_row_size(src0); const uint8_t * src0_ptr = src0->get_read_buffer(true); // TODO: avoid invalidation const size_t valid_src0_row_bytes = _IsSrcQuantized ? src0->get_nb(1) : (src0->get_ne(0) * sizeof(data_type0)); // cache the src0 plane in VTCM - const size_t src1_actual_row_stride = hexagon::get_aligned_size(src1->get_nb(1)); + const size_t src1_row_stride = hexagon::get_aligned_size(src1->get_nb(1)); const size_t src0_plane_slice_row_count = - std::min((params->get_vtcm_quota_size() - src1_actual_row_stride) / (src0_actual_row_stride * 2), + std::min((params->get_vtcm_quota_size() - src1_row_stride) / (src0_row_stride * 2), start_end_element.second - start_end_element.first); uint8_t * src0_plane_read_cache_ptr = nullptr; @@ -351,13 +363,13 @@ inline void mul_mat_gemv_impl(hexagon::tensor * src0, DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_WITH_MULTI_SUB_PROC(dst, params->get_thread_index(), mul_mat); { - const size_t src0_plane_cache_size = src0_actual_row_stride * src0_plane_slice_row_count; - src0_plane_read_cache_ptr = params->get_vtcm_cache(src0_plane_cache_size * 2 + src1_actual_row_stride); + const size_t src0_plane_cache_size = src0_row_stride * src0_plane_slice_row_count; + src0_plane_read_cache_ptr = params->get_vtcm_cache(src0_plane_cache_size * 2 + src1_row_stride); if 
(!src0_plane_read_cache_ptr) { DEVICE_LOG_ERROR( "mul_mat_gemv_impl: failed to get VTCM cache for src0, size: %zu, src0_plane_slice_row_count: %zu, " - "src0_actual_row_stride: %zu, will fallback to mem cache\n", - src0_plane_cache_size, src0_plane_slice_row_count, src0_actual_row_stride); + "src0_row_stride: %zu, will fallback to mem cache\n", + src0_plane_cache_size, src0_plane_slice_row_count, src0_row_stride); return; } @@ -369,9 +381,9 @@ inline void mul_mat_gemv_impl(hexagon::tensor * src0, } DEVICE_LOG_DEBUG( - "mul_mat_gemv_impl: src0_actual_row_stride: %zu, src0_plane_slice_row_count: %zu, " + "mul_mat_gemv_impl: src0_row_stride: %zu, src0_plane_slice_row_count: %zu, " "src0_plane_write_cache_offset: %zu, src0.nb[1]: %d, is_quantized: %d, vtcm_mem: %p(%zu)\n", - src0_actual_row_stride, src0_plane_slice_row_count, src0_plane_write_cache_offset, int(src0->get_nb(1)), + src0_row_stride, src0_plane_slice_row_count, src0_plane_write_cache_offset, int(src0->get_nb(1)), _IsSrcQuantized, (void *) src0_plane_read_cache_ptr, src0_plane_cache_size); } @@ -382,8 +394,7 @@ inline void mul_mat_gemv_impl(hexagon::tensor * src0, return; } - auto dequant_table = load_dequant_table_func ? 
load_dequant_table_func() : HVX_Vector(); - const uint8_t * src1_ptr = src1->get_read_buffer(); + const uint8_t * src1_ptr = src1->get_read_buffer(); { if (!params->initiate_dma_row_transfer(src1_ptr, src1_row_cache_ptr, src1->get_ne(0) * sizeof(data_type1))) { @@ -392,9 +403,9 @@ inline void mul_mat_gemv_impl(hexagon::tensor * src0, } const uint8_t * src0_plane = src0_ptr + start_end_element.first * src0->get_nb(1); - const int64_t next_row_count = - std::min(src0_plane_slice_row_count, - start_end_element.second - start_end_element.first); // number of rows in this slice + const size_t next_row_count = + std::min(src0_plane_slice_row_count, + start_end_element.second - start_end_element.first); // number of rows in this slice params->wait_for_dma(); if (!init_dma_transfer<_IsSrcQuantized>( @@ -406,22 +417,23 @@ inline void mul_mat_gemv_impl(hexagon::tensor * src0, } } + const auto dequant_table = load_dequant_table_func ? load_dequant_table_func() : HVX_Vector(); { - for (int64_t col_idx = start_end_element.first; col_idx < start_end_element.second; + for (size_t col_idx = start_end_element.first; col_idx < size_t(start_end_element.second); col_idx += src0_plane_slice_row_count) { - const int64_t actual_row_count = - std::min(src0_plane_slice_row_count, - start_end_element.second - col_idx); // number of rows in this slice - const auto next_col_idx = col_idx + src0_plane_slice_row_count; + const size_t slice_rows = + std::min(src0_plane_slice_row_count, + start_end_element.second - col_idx); // number of rows in this slice + const size_t next_col_idx = col_idx + src0_plane_slice_row_count; std::swap(src0_plane_read_cache_ptr, src0_plane_write_cache_ptr); params->wait_for_dma(); if (next_col_idx < start_end_element.second) { DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 2, dma); const uint8_t * src0_next_plane = src0_ptr + next_col_idx * src0->get_nb(1); - const int64_t next_row_count = - std::min(src0_plane_slice_row_count, - 
start_end_element.second - next_col_idx); // number of rows in this slice + const size_t next_row_count = + std::min(src0_plane_slice_row_count, + start_end_element.second - next_col_idx); // number of rows in this slice if (!init_dma_transfer<_IsSrcQuantized>( params, src0_next_plane, src0_plane_write_cache_ptr + src0_plane_write_cache_offset, valid_src0_row_bytes, next_row_count, src0->get_nb(1), src0->get_nb(1))) { @@ -435,9 +447,9 @@ inline void mul_mat_gemv_impl(hexagon::tensor * src0, if constexpr (_IsSrcQuantized) { DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 0, dequant); const uint8_t * src0_quant_plane = src0_plane_read_cache_ptr + src0_plane_write_cache_offset; - for (int64_t ir = 0; ir < actual_row_count; ir++) { + for (size_t ir = 0; ir < slice_rows; ir++) { auto * src0_row = src0_quant_plane + ir * src0->get_nb(1); - auto * cached_row_ptr = src0_plane_read_cache_ptr + ir * src0_actual_row_stride; + auto * cached_row_ptr = src0_plane_read_cache_ptr + ir * src0_row_stride; dequantize_row_func(src0_row, reinterpret_cast(cached_row_ptr), src0->get_ne(0), dequant_table); } @@ -446,8 +458,135 @@ inline void mul_mat_gemv_impl(hexagon::tensor * src0, { DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 1, dot); auto * dst_row = reinterpret_cast(dst_ptr) + col_idx; - batched_row_dot<_DotFunc>(src0_plane_read_cache_ptr, src0->get_ne(0), src0_actual_row_stride, - src1_row_cache_ptr, src1->get_nb(1), dst_row, actual_row_count, 0); + batched_row_dot<_DotFunc>(src0_plane_read_cache_ptr, src0->get_ne(0), src0_row_stride, + src1_row_cache_ptr, src1->get_nb(1), dst_row, slice_rows, 0); + } + } + } + + dst->release_write_buffer(); // mark the output tensor as modified +} + +template +inline void mul_mat_gemv_quant_impl(hexagon::tensor * src0, + hexagon::tensor * src1, + hexagon::tensor * dst, + hexagon::compute_params * params) { + // TODO: merge with mul_mat_gemv_impl? 
+ + using data_type1 = typename get_data_type::data_type1; + + if (dst->get_ne(0) < params->get_thread_count()) { + DEVICE_LOG_ERROR("Unsupported src1 tensor shape for gemv: %s, ne: %lldx%lldx%lldx%lld\n", + hexagon::get_type_name(src1->get_type()), src1->get_ne(0), src1->get_ne(1), src1->get_ne(2), + src1->get_ne(3)); + return; + } + + const auto src0_row_stride = src0->get_nb(1); + const auto start_end_element = params->get_work_slice(dst->get_ne(0)); + if (start_end_element.second <= start_end_element.first || start_end_element.first < 0) { + DEVICE_LOG_DEBUG( + "mul_mat_gemv_quant_impl: no work to do, start_end_plane: [0, 1), start_end_row: [0, 1), " + "start_end_element: [%lld, %lld)\n", + start_end_element.first, start_end_element.second); + return; + } + + const uint8_t * src0_ptr = src0->get_read_buffer(true); // TODO: avoid invalidation + + // cache the src0 plane in VTCM + const size_t src1_row_stride = hexagon::get_aligned_size(src1->get_nb(1)); + const size_t src0_plane_slice_row_count = + std::min((params->get_vtcm_quota_size() - src1_row_stride) / (src0_row_stride * 2), + start_end_element.second - start_end_element.first); + + uint8_t * src0_plane_read_cache_ptr = nullptr; + uint8_t * src0_plane_write_cache_ptr = nullptr; + uint8_t * src1_row_cache_ptr = nullptr; + + DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_WITH_MULTI_SUB_PROC(dst, params->get_thread_index(), mul_mat); + { + const size_t src0_plane_cache_size = src0_row_stride * src0_plane_slice_row_count; + src0_plane_read_cache_ptr = params->get_vtcm_cache(src0_plane_cache_size * 2 + src1_row_stride); + if (!src0_plane_read_cache_ptr) { + DEVICE_LOG_ERROR( + "mul_mat_gemv_quant_impl: failed to get VTCM cache for src0, size: %zu, src0_plane_slice_row_count: " + "%zu, " + "src0_row_stride: %zu, will fallback to mem cache\n", + src0_plane_cache_size, src0_plane_slice_row_count, src0_row_stride); + return; + } + + src0_plane_write_cache_ptr = src0_plane_read_cache_ptr + src0_plane_cache_size; + 
src1_row_cache_ptr = src0_plane_write_cache_ptr + src0_plane_cache_size; + + DEVICE_LOG_DEBUG( + "mul_mat_gemv_quant_impl: src0_row_stride: %zu, src0_plane_slice_row_count: %zu, src0.nb[1]: %d, vtcm_mem: " + "%p(%zu)\n", + src0_row_stride, src0_plane_slice_row_count, int(src0->get_nb(1)), (void *) src0_plane_read_cache_ptr, + src0_plane_cache_size); + } + + uint8_t * dst_ptr = dst->get_write_buffer(); + if (!dst_ptr) { + DEVICE_LOG_ERROR("mul_mat_gemv_quant_impl: dst_ptr is not writable, tensor: %p, type: %s\n", (void *) dst, + hexagon::get_type_name(dst->get_type())); + return; + } + + const uint8_t * src1_ptr = src1->get_read_buffer(); + + { + if (!params->initiate_dma_row_transfer(src1_ptr, src1_row_cache_ptr, src1->get_ne(0) * sizeof(data_type1))) { + DEVICE_LOG_ERROR("mul_mat_gemv_quant_impl: failed to initiate dma transfer for src1\n"); + return; + } + + const uint8_t * src0_plane = src0_ptr + start_end_element.first * src0_row_stride; + const size_t next_row_count = + std::min(src0_plane_slice_row_count, + start_end_element.second - start_end_element.first); // number of rows in this slice + params->wait_for_dma(); + + if (!init_dma_transfer(params, src0_plane, src0_plane_write_cache_ptr, src0_row_stride, next_row_count, + src0_row_stride, src0_row_stride)) { + DEVICE_LOG_ERROR("mul_mat_gemv_quant_impl: failed to initiate dma plane transfer for src0 plane\n"); + return; + } + } + + auto * load_dequant_table_func = hexagon::get_type_traits(src0->get_type()).load_dequant_table; + const auto dequant_table = load_dequant_table_func ? 
load_dequant_table_func() : HVX_Vector(); + { + for (size_t col_idx = start_end_element.first; col_idx < size_t(start_end_element.second); + col_idx += src0_plane_slice_row_count) { + const size_t slice_rows = + std::min(src0_plane_slice_row_count, + start_end_element.second - col_idx); // number of rows in this slice + const size_t next_col_idx = col_idx + src0_plane_slice_row_count; + std::swap(src0_plane_read_cache_ptr, src0_plane_write_cache_ptr); + params->wait_for_dma(); + + if (next_col_idx < start_end_element.second) { + DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 1, dma); + const uint8_t * src0_next_plane = src0_ptr + next_col_idx * src0_row_stride; + const size_t next_row_count = + std::min(src0_plane_slice_row_count, + start_end_element.second - next_col_idx); // number of rows in this slice + if (!init_dma_transfer(params, src0_next_plane, src0_plane_write_cache_ptr, src0_row_stride, + next_row_count, src0_row_stride, src0_row_stride)) { + DEVICE_LOG_ERROR("mul_mat_gemv_quant_impl: failed to continue dma plane transfer for src0 plane\n"); + return; + } + } + + { + DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 0, dot); + auto * dst_row = reinterpret_cast(dst_ptr) + col_idx; + batched_row_dot<_DotFunc, const HVX_Vector>(src0_plane_read_cache_ptr, src0->get_ne(0), src0_row_stride, + src1_row_cache_ptr, src1->get_nb(1), dst_row, slice_rows, 0, + dequant_table); } } } @@ -641,8 +780,13 @@ bool mul_mat_f32(hexagon::tensor * out, compute_params * params) { switch (src1->get_type()) { case NPU_DATA_TYPE_F32: if (is_src0_quantized) { - kMulMatF16F32QuantizedFuncs[is_mul_mat_f16_f32_src_tensors_aligned(src0, src1, true, is_gemv) + - base_index](src0, src1, out, params); + if (is_gemv && src0->get_type() == NPU_DATA_TYPE_Q4_0) { + // TODO: move to array + mul_mat_gemv_quant_impl(src0, src1, out, params); + } else { + kMulMatF16F32QuantizedFuncs[is_mul_mat_f16_f32_src_tensors_aligned(src0, src1, true, is_gemv) + + 
base_index](src0, src1, out, params); + } } else if (src0->get_type() == NPU_DATA_TYPE_F16) { kMulMatF16F32Funcs[is_mul_mat_f16_f32_src_tensors_aligned(src0, src1, true, is_gemv) + base_index]( src0, src1, out, params); @@ -664,7 +808,7 @@ bool mul_mat_f32(hexagon::tensor * out, compute_params * params) { break; } - DEVICE_LOG_ERROR("Unsupported src1 tensor type: %s\n", get_type_name(src1->get_type())); + DEVICE_LOG_ERROR("[MUL_MAT]Unsupported src1 tensor type: %s\n", get_type_name(src1->get_type())); return false; } @@ -741,10 +885,15 @@ bool is_mul_mat_supported(const npu_device_tensor_op_spec * op_spec, return true; } -bool is_mul_mat_required_sync(const npu_device_tensor_op op, const npu_device_tensor_op next_op) { +bool is_mul_mat_required_sync(npu_device_tensor_op prev_op, + const npu_device_ne_type & prev_ne, + npu_device_tensor_op op, + const npu_device_ne_type & ne) { + NPU_UNUSED(prev_op); + NPU_UNUSED(prev_ne); NPU_UNUSED(op); - NPU_UNUSED(next_op); - return true; + NPU_UNUSED(ne); + return prev_op != NPU_OP_MUL_MAT || !is_same_shape(prev_ne, ne); } } // namespace hexagon diff --git a/ggml/src/ggml-qnn/npu/device/op/op_mul_mat.hpp b/ggml/src/ggml-qnn/npu/device/op/op_mul_mat.hpp index ef6391994c..e5ba79c23e 100644 --- a/ggml/src/ggml-qnn/npu/device/op/op_mul_mat.hpp +++ b/ggml/src/ggml-qnn/npu/device/op/op_mul_mat.hpp @@ -12,6 +12,9 @@ bool is_mul_mat_supported(const npu_device_tensor_op_spec * op_spec, const npu_device_tensor_spec * dst, const npu_device_tensor_spec * srcs, size_t src_len); -bool is_mul_mat_required_sync(const npu_device_tensor_op op, const npu_device_tensor_op next_op); +bool is_mul_mat_required_sync(npu_device_tensor_op prev_op, + const npu_device_ne_type & prev_ne, + npu_device_tensor_op op, + const npu_device_ne_type & ne); } // namespace hexagon diff --git a/ggml/src/ggml-qnn/npu/device/op/op_rope.cpp b/ggml/src/ggml-qnn/npu/device/op/op_rope.cpp index 1edcf16480..76401b1e0c 100644 --- a/ggml/src/ggml-qnn/npu/device/op/op_rope.cpp 
+++ b/ggml/src/ggml-qnn/npu/device/op/op_rope.cpp @@ -392,9 +392,14 @@ bool is_rope_supported(const npu_device_tensor_op_spec * op_spec, return true; // ROPE operation is not supported yet } -bool is_rope_required_sync(const npu_device_tensor_op op, const npu_device_tensor_op next_op) { +bool is_rope_required_sync(npu_device_tensor_op prev_op, + const npu_device_ne_type & prev_ne, + npu_device_tensor_op op, + const npu_device_ne_type & ne) { + NPU_UNUSED(prev_op); + NPU_UNUSED(prev_ne); NPU_UNUSED(op); - NPU_UNUSED(next_op); + NPU_UNUSED(ne); return false; } diff --git a/ggml/src/ggml-qnn/npu/device/op/op_rope.hpp b/ggml/src/ggml-qnn/npu/device/op/op_rope.hpp index af67ea46d5..9482facc27 100644 --- a/ggml/src/ggml-qnn/npu/device/op/op_rope.hpp +++ b/ggml/src/ggml-qnn/npu/device/op/op_rope.hpp @@ -9,6 +9,9 @@ bool is_rope_supported(const npu_device_tensor_op_spec * op_spec, const npu_device_tensor_spec * dst, const npu_device_tensor_spec * srcs, size_t src_len); -bool is_rope_required_sync(const npu_device_tensor_op op, const npu_device_tensor_op next_op); +bool is_rope_required_sync(npu_device_tensor_op prev_op, + const npu_device_ne_type & prev_ne, + npu_device_tensor_op op, + const npu_device_ne_type & ne); } // namespace hexagon diff --git a/ggml/src/ggml-qnn/npu/device/op/op_types.hpp b/ggml/src/ggml-qnn/npu/device/op/op_types.hpp index cb8aac26f5..d15e381ebd 100644 --- a/ggml/src/ggml-qnn/npu/device/op/op_types.hpp +++ b/ggml/src/ggml-qnn/npu/device/op/op_types.hpp @@ -83,6 +83,9 @@ typedef bool (*op_is_supported_func_type)(const npu_device_tensor_op_spec * op_s const npu_device_tensor_spec * dst, const npu_device_tensor_spec * srcs, size_t src_len); -typedef bool (*op_required_sync_func_type)(const npu_device_tensor_op op, const npu_device_tensor_op next_op); +typedef bool (*op_required_sync_func_type)(npu_device_tensor_op prev_op, + const npu_device_ne_type & prev_ne, + npu_device_tensor_op op, + const npu_device_ne_type & ne); } // namespace hexagon diff 
--git a/ggml/src/ggml-qnn/npu/device/type_traits.cpp b/ggml/src/ggml-qnn/npu/device/type_traits.cpp index cc0ca77b38..1dced1d610 100644 --- a/ggml/src/ggml-qnn/npu/device/type_traits.cpp +++ b/ggml/src/ggml-qnn/npu/device/type_traits.cpp @@ -3,8 +3,6 @@ #include "op_types.hpp" // TODO: remove this include #include "vec_ops.hpp" -#include - static_assert(sizeof(npu_device_block_q4_k) == 2 * sizeof(npu_device_fp16_t) + QUANT_K_SCALE_SIZE + QUANT_K_BLOCK_SIZE / 2, "wrong q4_K block size/padding"); @@ -26,168 +24,6 @@ inline npu_device_fp16_t to_fp16(const float src) { return reinterpret_cast(f16_value); } -template inline HVX_Vector load_into_vector(const _TStruct * src) { - static_assert(hexagon::kBytesPerVector >= sizeof(_TStruct) * _Count, "_TStruct too large for vector load"); - - return *reinterpret_cast(&(src->*_MemberPtr)); -} - -template inline HVX_Vector load_struct_into_vector(const _TStruct * src) { - static_assert(hexagon::kBytesPerVector >= sizeof(_TStruct) * _Count, "_TStruct too large for vector load"); - - return *reinterpret_cast(src); -} - -template inline HVX_Vector load_block_generic(const _TBlock & src) { - static_assert(hexagon::kBytesPerVector >= sizeof(_TBlock), "wrong block size/padding"); - return load_into_vector<_TBlock, 1, &_TBlock::qs>(&src); -} - -template inline HVX_Vector make_scale_load_mask() { - static_assert(sizeof(_TBlock) + sizeof(npu_device_fp16_t) < 32, "wrong block size/padding"); - static_assert(sizeof(_TBlock::qs) == 16 || sizeof(_TBlock::qs) == 32, "wrong quantization block size"); - - constexpr const size_t kScaleBlockSize = QUANT_BLOCK_SIZE * sizeof(hexagon::dequant_output_type); - - // TODO: handle the case that scale not at the start of struct - hexagon::HVX_VectorAlias ret; - for (size_t i = 0; i < QUANT_BLOCK_SIZE; ++i) { - size_t base = i * 2; - ret.u8[base] = 0; - ret.u8[base + 1] = 1; - - ret.u8[base + kScaleBlockSize] = sizeof(_TBlock); - ret.u8[base + kScaleBlockSize + 1] = sizeof(_TBlock) + 1; - } - - return 
ret.v; -} - -template inline HVX_Vector load_dual_block_generic(const _TBlock * srcs, HVX_VectorPred mask) { - static_assert(hexagon::kBytesPerVector >= sizeof(_TBlock) * 2, "wrong block size/padding"); - constexpr const uint32_t kSizeOfQs = sizeof(_TBlock::qs); - constexpr const uint32_t kSizeOfScale = sizeof(_TBlock) - kSizeOfQs; - - HVX_Vector blocks = load_into_vector<_TBlock, 2, &_TBlock::qs>(srcs); - HVX_Vector block1 = Q6_V_vror_VR(blocks, kSizeOfScale); - return Q6_V_vmux_QVV(mask, blocks, block1); -} - -template -inline hexagon::HVX_Vector_x2 load_dual_block_generic(const _TBlock * srcs, - HVX_VectorPred mask, - const HVX_Vector scale_indices) { - static_assert(hexagon::kBytesPerVector >= sizeof(_TBlock) * 2, "wrong block size/padding"); - constexpr const uint32_t kSizeOfQs = sizeof(_TBlock::qs); - constexpr const uint32_t kSizeOfScale = sizeof(_TBlock) - kSizeOfQs; - - hexagon::HVX_Vector_x2 result; - - const HVX_Vector blocks = load_struct_into_vector<_TBlock, 2>(srcs); - - HVX_Vector block0 = Q6_V_vror_VR(blocks, kSizeOfScale); - HVX_Vector block1 = Q6_V_vror_VR(blocks, kSizeOfScale * 2); - HVX_Vector scale01 = Q6_Vb_vshuff_Vb(blocks); - - result.val[0] = Q6_V_vmux_QVV(mask, block0, block1); - result.val[1] = Q6_Vb_vlut32_VbVbR_nomatch(scale_indices, scale01, 0); - - return result; -} - -template inline hexagon::HVX_VectorPred_x3 make_quad_block_mask() { - static_assert(hexagon::kBytesPerVector >= sizeof(_TBlock) * 4, "wrong block size/padding"); - constexpr const uint32_t kSizeOfQs = sizeof(_TBlock::qs); - - hexagon::HVX_VectorPred_x3 mask; - mask.val[0] = Q6_Q_vsetq_R(kSizeOfQs); - mask.val[1] = Q6_Q_vsetq_R(kSizeOfQs * 3); - mask.val[2] = Q6_Q_vsetq_R(kSizeOfQs * 2); - return mask; -} - -template -inline hexagon::HVX_Vector_x3 load_qual_block_generic(const _TBlock * srcs, - const hexagon::HVX_VectorPred_x3 mask, - const HVX_Vector scale_indices) { - static_assert(hexagon::kBytesPerVector >= sizeof(_TBlock) * 4, "wrong block size/padding"); - 
constexpr const uint32_t kSizeOfQs = sizeof(_TBlock::qs); - constexpr const uint32_t kSizeOfScale = sizeof(_TBlock) - kSizeOfQs; - - hexagon::HVX_Vector_x3 result; - - const HVX_Vector blocks = load_struct_into_vector<_TBlock, 4>(srcs); - - { - HVX_Vector block0 = Q6_V_vror_VR(blocks, kSizeOfScale); - HVX_Vector block1 = Q6_V_vror_VR(blocks, kSizeOfScale * 2); - - HVX_Vector block2 = Q6_V_vror_VR(blocks, kSizeOfScale * 3); - HVX_Vector block3 = Q6_V_vror_VR(blocks, kSizeOfScale * 4); - - HVX_Vector block01 = Q6_V_vmux_QVV(mask.val[0], block0, block1); - HVX_Vector block23 = Q6_V_vmux_QVV(mask.val[1], block2, block3); - - result.val[0] = Q6_V_vmux_QVV(mask.val[2], block01, block23); - } - - { - HVX_Vector scale23 = Q6_V_vror_VR(blocks, sizeof(_TBlock) * 2); - - HVX_Vector scale01 = Q6_Vb_vshuff_Vb(blocks); - scale23 = Q6_Vb_vshuff_Vb(scale23); - - result.val[1] = Q6_Vb_vlut32_VbVbR_nomatch(scale_indices, scale01, 0); - result.val[2] = Q6_Vb_vlut32_VbVbR_nomatch(scale_indices, scale23, 0); - } - - return result; -} - -template -inline hexagon::HVX_Vector_x5 load_hexa_block_generic(const _TBlock * srcs, - const hexagon::HVX_VectorPred_x3 mask, - const HVX_Vector scale_indices) { - static_assert(hexagon::kBytesPerVector >= sizeof(_TBlock) * 6, "wrong block size/padding"); - constexpr const uint32_t kSizeOfQs = sizeof(_TBlock::qs); - constexpr const uint32_t kSizeOfScale = sizeof(_TBlock) - kSizeOfQs; - - const HVX_Vector blocks = load_struct_into_vector<_TBlock, 6>(srcs); - - hexagon::HVX_Vector_x5 result; - { - HVX_Vector block0 = Q6_V_vror_VR(blocks, kSizeOfScale); - HVX_Vector block1 = Q6_V_vror_VR(blocks, kSizeOfScale * 2); - - HVX_Vector block2 = Q6_V_vror_VR(blocks, kSizeOfScale * 3); - HVX_Vector block3 = Q6_V_vror_VR(blocks, kSizeOfScale * 4); - - HVX_Vector block4 = Q6_V_vror_VR(blocks, kSizeOfScale + sizeof(_TBlock) * 4); - HVX_Vector block5 = Q6_V_vror_VR(blocks, kSizeOfScale * 2 + sizeof(_TBlock) * 4); - - HVX_Vector block01 = Q6_V_vmux_QVV(mask.val[0], 
block0, block1); - HVX_Vector block23 = Q6_V_vmux_QVV(mask.val[1], block2, block3); - - result.val[0] = Q6_V_vmux_QVV(mask.val[2], block01, block23); - result.val[3] = Q6_V_vmux_QVV(mask.val[0], block4, block5); // block45 - } - - { - HVX_Vector scale23 = Q6_V_vror_VR(blocks, sizeof(_TBlock) * 2); - HVX_Vector scale45 = Q6_V_vror_VR(blocks, sizeof(_TBlock) * 4); - - HVX_Vector scale01 = Q6_Vb_vshuff_Vb(blocks); - scale23 = Q6_Vb_vshuff_Vb(scale23); - scale45 = Q6_Vb_vshuff_Vb(scale45); - - result.val[1] = Q6_Vb_vlut32_VbVbR_nomatch(scale_indices, scale01, 0); - result.val[2] = Q6_Vb_vlut32_VbVbR_nomatch(scale_indices, scale23, 0); - result.val[4] = Q6_Vb_vlut32_VbVbR_nomatch(scale_indices, scale45, 0); - } - - return result; -} - inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t * d, uint8_t * m) { // TODO: use intrinsics if (j < 4) { @@ -448,25 +284,25 @@ void quantize_row_q4_K(const float * src, void * dst, size_t count) { } void dequantize_row_q8_0(const void * src, hexagon::dequant_output_type * dst, size_t count, HVX_Vector) { + using namespace hexagon::vec::quant; + constexpr const int qk = QUANT_BLOCK_SIZE; static_assert(QUANT_BLOCK_SIZE == hexagon::kBytesPerVector / sizeof(float)); - const int nb = count / qk; - const auto * src_ptr = reinterpret_cast(src); - auto * dst_ptr = ((hexagon::dequant_output_type *) dst); // TODO: opt for aligned access - const HVX_VectorPred mask = Q6_Q_vsetq_R(sizeof(npu_device_block_q8_0::qs)); - const HVX_VectorPred scale_mask = Q6_Q_vsetq_R(hexagon::kBytesPerVector / 2); + alignas(hexagon::kBytesPerVector) static const HVX_Vector qs_indices = make_qs_load_mask(); + alignas(hexagon::kBytesPerVector) static const HVX_Vector scale_indices = + make_scale_load_mask(); + + const int nb = count / qk; + const auto * src_ptr = reinterpret_cast(src); + auto * dst_ptr = ((hexagon::dequant_output_type *) dst); // TODO: opt for aligned access int i = 0; for (; i + 1 < nb; i += 2) { - const auto & src0 = src_ptr[i]; - const 
auto & src1 = src_ptr[i + 1]; + auto qs = load_dual_block_generic(src_ptr + i, qs_indices, scale_indices); - HVX_Vector scales01 = Q6_V_vmux_QVV(scale_mask, Q6_Vh_vsplat_R(src0.d), Q6_Vh_vsplat_R(src1.d)); - - HVX_Vector qs = load_dual_block_generic(src_ptr + i, mask); - HVX_Vector q_lo = Q6_Vhf_equals_Vh(Q6_V_lo_W(Q6_Wh_vunpack_Vb(qs))); - HVX_Vector result = Q6_Vqf16_vmpy_VhfVhf(q_lo, scales01); + HVX_Vector q_lo = Q6_Vhf_equals_Vh(Q6_V_lo_W(Q6_Wh_vunpack_Vb(qs.val[0]))); + HVX_Vector result = Q6_Vqf16_vmpy_VhfVhf(q_lo, qs.val[1]); *reinterpret_cast(dst_ptr) = Q6_Vhf_equals_Vqf16(result); dst_ptr += qk * 2; @@ -486,75 +322,17 @@ void dequantize_row_q8_0(const void * src, hexagon::dequant_output_type * dst, s } } -template -inline void dequantize_row_q4_0_2blocks(HVX_Vector qs, - HVX_Vector scale01, - HVX_Vector table, - hexagon::dequant_output_type * dst_ptr) { - constexpr const uint32_t kSizeOfQs = sizeof(npu_device_block_q4_0::qs); - - HVX_Vector q_lo = qs; - HVX_Vector q_hi = Q6_Vub_vlsr_VubR(qs, 4); - HVX_VectorPair qp0 = Q6_W_vshuff_VVR(q_hi, q_lo, kSizeOfQs * (1 + 2)); - - q_lo = Q6_V_lo_W(qp0); - q_lo = Q6_Vb_vshuff_Vb(q_lo); - qp0 = Q6_Wh_vlut16_VbVhR_nomatch(q_lo, table, 0); - - q_lo = Q6_Vqf16_vmpy_VhfVhf(Q6_V_lo_W(qp0), scale01); - q_lo = Q6_Vhf_equals_Vqf16(q_lo); - - if constexpr (_IsDstAligned) { - *reinterpret_cast(dst_ptr) = q_lo; - } else { - *reinterpret_cast(dst_ptr) = q_lo; - } -} - -template -inline void dequantize_row_q4_0_4blocks(HVX_Vector qs, - HVX_Vector scale01, - HVX_Vector scale23, - HVX_Vector table, - hexagon::dequant_output_type * dst_ptr) { - constexpr const uint32_t kSizeOfQs = sizeof(npu_device_block_q4_0::qs); - - HVX_Vector q_lo = qs; - HVX_Vector q_hi = Q6_Vub_vlsr_VubR(qs, 4); - - HVX_VectorPair qp0 = Q6_W_vshuff_VVR(q_hi, q_lo, kSizeOfQs * (1 + 2 + 4)); - - q_lo = Q6_V_lo_W(qp0); - q_lo = Q6_Vb_vshuff_Vb(q_lo); - qp0 = Q6_Wh_vlut16_VbVhR_nomatch(q_lo, table, 0); - - q_lo = Q6_V_lo_W(qp0); - q_hi = Q6_V_hi_W(qp0); - - q_lo = 
Q6_Vqf16_vmpy_VhfVhf(q_lo, scale01); - q_hi = Q6_Vqf16_vmpy_VhfVhf(q_hi, scale23); - - q_lo = Q6_Vhf_equals_Vqf16(q_lo); - q_hi = Q6_Vhf_equals_Vqf16(q_hi); - - if constexpr (_IsDstAligned) { - reinterpret_cast(dst_ptr)[0] = q_lo; - reinterpret_cast(dst_ptr)[1] = q_hi; - } else { - reinterpret_cast(dst_ptr)[0] = q_lo; - reinterpret_cast(dst_ptr)[1] = q_hi; - } -} - template void dequantize_row_q4_0_impl(const void * src, hexagon::dequant_output_type * dst, size_t count, HVX_Vector table) { + using namespace hexagon::vec::quant; + constexpr const size_t kElemsPerVec = hexagon::kBytesPerVector / sizeof(hexagon::dequant_output_type); constexpr const uint32_t kSizeOfQs = sizeof(npu_device_block_q4_0::qs); constexpr const int qk = QUANT_BLOCK_SIZE; static_assert(qk % 2 == 0, "qk must be even"); static_assert(QUANT_BLOCK_SIZE == hexagon::kBytesPerVector / sizeof(float)); - static const auto load_masks = make_quad_block_mask(); + alignas(hexagon::kBytesPerVector) static const HVX_Vector qs_indices = make_qs_load_mask(); alignas(hexagon::kBytesPerVector) static const HVX_Vector scale_indices = make_scale_load_mask(); @@ -565,21 +343,41 @@ void dequantize_row_q4_0_impl(const void * src, hexagon::dequant_output_type * d int i = 0; for (; i + 5 < nb; i += 6) { - auto qs = load_hexa_block_generic(src_ptr + i, load_masks, scale_indices); - dequantize_row_q4_0_4blocks<_IsDstAligned>(qs.val[0], qs.val[1], qs.val[2], table, dst_ptr); - dequantize_row_q4_0_2blocks<_IsDstAligned>(qs.val[3], qs.val[4], table, dst_ptr + kElemsPerVec * 2); + auto qs = load_hexa_block_generic(src_ptr + i, qs_indices, scale_indices); + auto res01 = dequantize_vec_q40_qf16_4blocks(qs.val[0], qs.val[1], qs.val[2], table); + auto res2 = dequantize_vec_q40_qf16_2blocks(qs.val[3], qs.val[4], table); + if constexpr (_IsDstAligned) { + reinterpret_cast(dst_ptr)[0] = Q6_Vhf_equals_Vqf16(res01.val[0]); + reinterpret_cast(dst_ptr)[1] = Q6_Vhf_equals_Vqf16(res01.val[1]); + reinterpret_cast(dst_ptr)[2] = 
Q6_Vhf_equals_Vqf16(res2); + } else { + reinterpret_cast(dst_ptr)[0] = Q6_Vhf_equals_Vqf16(res01.val[0]); + reinterpret_cast(dst_ptr)[1] = Q6_Vhf_equals_Vqf16(res01.val[1]); + reinterpret_cast(dst_ptr)[2] = Q6_Vhf_equals_Vqf16(res2); + } dst_ptr += kElemsPerVec * 3; } for (; i + 3 < nb; i += 4) { - auto qs = load_qual_block_generic(src_ptr + i, load_masks, scale_indices); - dequantize_row_q4_0_4blocks<_IsDstAligned>(qs.val[0], qs.val[1], qs.val[2], table, dst_ptr); + auto qs = load_qual_block_generic(src_ptr + i, qs_indices, scale_indices); + auto res01 = dequantize_vec_q40_qf16_4blocks(qs.val[0], qs.val[1], qs.val[2], table); + if constexpr (_IsDstAligned) { + reinterpret_cast(dst_ptr)[0] = Q6_Vhf_equals_Vqf16(res01.val[0]); + reinterpret_cast(dst_ptr)[1] = Q6_Vhf_equals_Vqf16(res01.val[1]); + } else { + reinterpret_cast(dst_ptr)[0] = Q6_Vhf_equals_Vqf16(res01.val[0]); + reinterpret_cast(dst_ptr)[1] = Q6_Vhf_equals_Vqf16(res01.val[1]); + } dst_ptr += kElemsPerVec * 2; } for (; i + 1 < nb; i += 2) { - auto qs = load_dual_block_generic(src_ptr + i, load_masks.val[0], scale_indices); - dequantize_row_q4_0_2blocks<_IsDstAligned>(qs.val[0], qs.val[1], table, dst_ptr); + auto res = load_dequant_vec_q40_qf16_2blocks(src_ptr + i, qs_indices, scale_indices, table); + if constexpr (_IsDstAligned) { + *reinterpret_cast(dst_ptr) = Q6_Vhf_equals_Vqf16(res); + } else { + *reinterpret_cast(dst_ptr) = Q6_Vhf_equals_Vqf16(res); + } dst_ptr += kElemsPerVec; } @@ -611,12 +409,8 @@ HVX_Vector load_dequant_table_q4_0() { constexpr const int kQ4ZeroPoint = 8; // zero point for q4_0 quantization static_assert(kTableSize <= hexagon::kBytesPerVector / sizeof(__fp16), "table too large"); - static const HVX_Vector result = []() -> HVX_Vector { - alignas(hexagon::kBytesPerVector) union { - HVX_Vector v; - __fp16 f16[sizeof(HVX_Vector) / sizeof(__fp16)]; - } table; - + alignas(hexagon::kBytesPerVector) static const HVX_Vector result = []() -> HVX_Vector { + alignas(hexagon::kBytesPerVector) 
hexagon::HVX_VectorAlias table; table.v = Q6_V_vzero(); for (int i = 0; i < kTableSize; ++i) { table.f16[i * 2] = i - kQ4ZeroPoint; // TODO: vectorize this? @@ -640,12 +434,8 @@ HVX_Vector load_dequant_table_q4_k() { constexpr const int kTableSize = 1 << 4; // 4 bits per value, 16 values static_assert(kTableSize <= hexagon::kBytesPerVector / sizeof(__fp16), "table too large"); - const static HVX_Vector result = []() -> HVX_Vector { - alignas(hexagon::kBytesPerVector) union { - HVX_Vector v; - __fp16 f16[sizeof(HVX_Vector) / sizeof(__fp16)]; - } table; - + alignas(hexagon::kBytesPerVector) static const HVX_Vector result = []() -> HVX_Vector { + alignas(hexagon::kBytesPerVector) hexagon::HVX_VectorAlias table; table.v = Q6_V_vzero(); for (int i = 0; i < kTableSize; ++i) { table.f16[i * 2] = i; // TODO: vectorize this? diff --git a/ggml/src/ggml-qnn/npu/device/vec_ops.hpp b/ggml/src/ggml-qnn/npu/device/vec_ops.hpp index f72e1c37c0..f636189502 100644 --- a/ggml/src/ggml-qnn/npu/device/vec_ops.hpp +++ b/ggml/src/ggml-qnn/npu/device/vec_ops.hpp @@ -19,6 +19,7 @@ using HVX_Vector_x2 = HEXAGON_pack; using HVX_Vector_x3 = HEXAGON_pack; using HVX_Vector_x4 = HEXAGON_pack; using HVX_Vector_x5 = HEXAGON_pack; +using HVX_VectorPair_x2 = HEXAGON_pack; using HVX_VectorPair_x4 = HEXAGON_pack; using HVX_VectorPred_x3 = HEXAGON_pack; @@ -203,6 +204,7 @@ inline HVX_Vector hvx_passthru(HVX_Vector src, HVX_UVector *, HVX_Vector) { #include "vec_math.inl" #include "vec_ops.inl" +#include "vec_quant.inl" namespace hexagon { @@ -268,8 +270,8 @@ inline HVX_Vector vec_dot_product_vqf32_f32_f32(const float * src0, const float inline HVX_Vector vec_dot_product_aligned_vqf32_f32_f32(const float * src0, const float * src1, size_t count) { using namespace hexagon::vec; - return vec_dot_product_aligned_impl( - src0, src1, count); + return vec_dot_product_aligned_impl(src0, src1, + count); } inline float vec_dot_product_f32_f32(const float * src0, const float * src1, size_t count) { @@ -279,8 
+281,8 @@ inline float vec_dot_product_f32_f32(const float * src0, const float * src1, siz inline float vec_dot_product_aligned_f32_f32(const float * src0, const float * src1, size_t count) { using namespace hexagon::vec; - return vec_dot_product_aligned_impl( - src0, src1, count); + return vec_dot_product_aligned_impl(src0, src1, + count); } inline bool is_f32_f32_dot_product_aligned(const float * src0, const float * src1, size_t count) { @@ -326,13 +328,8 @@ inline bool is_f16_f16_dot_product_aligned(const npu_device_fp16_t * src0, inline HVX_Vector vec_dot_product_vqf32_f16_f32(const npu_device_fp16_t * src0, const float * src1, size_t count) { using namespace hexagon::vec; using namespace hexagon::vec::math; - return vec_dot_product_mixed_impl(src0, src1, count); + return vec_dot_product_mixed_impl(src0, src1, count); } inline HVX_Vector vec_dot_product_aligned_vqf32_f16_f32(const npu_device_fp16_t * src0, @@ -340,37 +337,39 @@ inline HVX_Vector vec_dot_product_aligned_vqf32_f16_f32(const npu_device_fp16_t size_t count) { using namespace hexagon::vec; using namespace hexagon::vec::math; - return vec_dot_product_mix_aligned_impl(src0, src1, count); + return vec_dot_product_mix_aligned_impl(src0, src1, count); } inline float vec_dot_product_f16_f32(const npu_device_fp16_t * src0, const float * src1, size_t count) { using namespace hexagon::vec; using namespace hexagon::vec::math; - return vec_dot_product_mixed_impl(src0, src1, count); } inline float vec_dot_product_aligned_f16_f32(const npu_device_fp16_t * src0, const float * src1, size_t count) { using namespace hexagon::vec; using namespace hexagon::vec::math; - return vec_dot_product_mix_aligned_impl(src0, src1, count); + return vec_dot_product_mix_aligned_impl(src0, src1, count); +} + +inline HVX_Vector vec_dot_product_vqf32_q40_f32(const npu_device_block_q4_0 * src0, + const float * src1, + size_t count, + const HVX_Vector table) { + using namespace hexagon::vec; + using namespace hexagon::vec::math; + using 
namespace hexagon::vec::quant; + + alignas(hexagon::kBytesPerVector) static const HVX_Vector qs_indices = make_qs_load_mask(); + alignas(hexagon::kBytesPerVector) static const HVX_Vector scale_indices = + make_scale_load_mask(); + + return vec_dot_product_quant_impl(src0, src1, count, qs_indices, scale_indices, table); } inline bool is_f16_f32_dot_product_aligned(const npu_device_fp16_t * src0, const float * src1, size_t count) { diff --git a/ggml/src/ggml-qnn/npu/device/vec_ops.inl b/ggml/src/ggml-qnn/npu/device/vec_ops.inl index f2bb174499..4d01387108 100644 --- a/ggml/src/ggml-qnn/npu/device/vec_ops.inl +++ b/ggml/src/ggml-qnn/npu/device/vec_ops.inl @@ -4,7 +4,9 @@ #include +#include #include +#include namespace hexagon::vec { @@ -378,6 +380,156 @@ inline _TRet vec_dot_product_mix_aligned_impl(const _TElem0 * src0, const _TElem return _ReduceFunc(_AddFunc(sum0, sum1)); } +template +inline _TRet vec_dot_product_quant_impl(const _TQuantElem0 * src0, + const _TElem1 * src1, + size_t count, + const HVX_Vector qs_indices, + const HVX_Vector scale_indices, + const HVX_Vector table) { + constexpr const size_t kElementsPerVector = hexagon::kBytesPerVector / sizeof(_TElem1); + + static_assert(std::is_same_v<_TQuantElem0, npu_device_block_q4_0> || + std::is_same_v<_TQuantElem0, npu_device_block_q4_k> || + std::is_same_v<_TQuantElem0, npu_device_block_q8_0>, + "Element type mismatch: _TQuantElem0 must be a supported quantization block type"); + static_assert(QUANT_BLOCK_SIZE == kElementsPerVector, + "Quant block size mismatch: QUANT_BLOCK_SIZE must be equal to kElementsPerVector"); + + assert(count % kElementsPerVector == 0 && "Count must be a multiple of kElementsPerVector"); + + const HVX_Vector kZeroV = Q6_V_vzero(); + + const _TQuantElem0 * src0_ptr = src0; + HVX_Vector * src1_vec_ptr = ((HVX_Vector *) src1); + HVX_Vector * const src1_vec_ptr_end = ((HVX_Vector *) src1) + count / kElementsPerVector; + HVX_Vector prev1 = *src1_vec_ptr++; + HVX_Vector sum = kZeroV; + + 
if (src1_vec_ptr_end - src1_vec_ptr > 1) { + HVX_Vector sum0 = kZeroV; + HVX_Vector sum1 = kZeroV; + + while (src1_vec_ptr_end - src1_vec_ptr > 3) { + HVX_VectorPair_x2 s01 = _DequantQuadFunc(src0_ptr, qs_indices, scale_indices, table); + HVX_Vector curr100 = src1_vec_ptr[0]; + HVX_Vector curr101 = src1_vec_ptr[1]; + HVX_Vector curr110 = src1_vec_ptr[2]; + HVX_Vector curr111 = src1_vec_ptr[3]; + + HVX_Vector l00 = Q6_V_lo_W(s01.val[0]); + HVX_Vector l10 = Q6_V_valign_VVR(curr100, prev1, (size_t) src1); + + HVX_Vector l01 = Q6_V_lo_W(s01.val[1]); + HVX_Vector l11 = Q6_V_valign_VVR(curr110, curr101, (size_t) src1); + + HVX_Vector h00 = Q6_V_hi_W(s01.val[0]); + HVX_Vector h10 = Q6_V_valign_VVR(curr101, curr100, (size_t) src1); + + HVX_Vector h01 = Q6_V_hi_W(s01.val[1]); + HVX_Vector h11 = Q6_V_valign_VVR(curr111, curr110, (size_t) src1); + + l10 = Q6_Vqf32_vadd_VsfVsf(kZeroV, l10); + l11 = Q6_Vqf32_vadd_VsfVsf(kZeroV, l11); + + HVX_Vector mpy0 = Q6_Vqf32_vmpy_Vqf32Vqf32(l00, l10); + HVX_Vector mpy1 = Q6_Vqf32_vmpy_Vqf32Vqf32(l01, l11); + + h10 = Q6_Vqf32_vadd_VsfVsf(kZeroV, h10); + h11 = Q6_Vqf32_vadd_VsfVsf(kZeroV, h11); + + HVX_Vector mpy2 = Q6_Vqf32_vmpy_Vqf32Vqf32(h00, h10); + HVX_Vector mpy3 = Q6_Vqf32_vmpy_Vqf32Vqf32(h01, h11); + + prev1 = curr111; + + sum0 = Q6_Vqf32_vadd_Vqf32Vqf32(mpy0, sum0); + sum1 = Q6_Vqf32_vadd_Vqf32Vqf32(mpy1, sum1); + + sum0 = Q6_Vqf32_vadd_Vqf32Vqf32(mpy2, sum0); + sum1 = Q6_Vqf32_vadd_Vqf32Vqf32(mpy3, sum1); + + src0_ptr += 4; + src1_vec_ptr += 4; + } + + while (src1_vec_ptr_end - src1_vec_ptr > 1) { + HVX_VectorPair s0 = _DequantDualFunc(src0_ptr, qs_indices, scale_indices, table); + HVX_Vector curr10 = src1_vec_ptr[0]; + HVX_Vector curr11 = src1_vec_ptr[1]; + + HVX_Vector l0 = Q6_V_lo_W(s0); + HVX_Vector l1 = Q6_V_valign_VVR(curr10, prev1, (size_t) src1); + + HVX_Vector h0 = Q6_V_hi_W(s0); + HVX_Vector h1 = Q6_V_valign_VVR(curr11, curr10, (size_t) src1); + + l1 = Q6_Vqf32_vadd_VsfVsf(kZeroV, l1); + h1 = Q6_Vqf32_vadd_VsfVsf(kZeroV, 
h1); + + HVX_Vector mpy0 = Q6_Vqf32_vmpy_Vqf32Vqf32(l0, l1); + HVX_Vector mpy1 = Q6_Vqf32_vmpy_Vqf32Vqf32(h0, h1); + + prev1 = curr11; + + sum0 = Q6_Vqf32_vadd_Vqf32Vqf32(mpy0, sum0); + sum1 = Q6_Vqf32_vadd_Vqf32Vqf32(mpy1, sum1); + + src0_ptr += 2; + src1_vec_ptr += 2; + } + + sum = Q6_Vqf32_vadd_Vqf32Vqf32(sum0, sum1); + } + + if (src1_vec_ptr_end - src1_vec_ptr > 0) { + HVX_Vector curr1 = *src1_vec_ptr++; + HVX_Vector s1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1); + HVX_Vector s0 = _DequantFunc(src0_ptr++, qs_indices, scale_indices, table); + s1 = Q6_Vqf32_vadd_VsfVsf(kZeroV, s1); + + HVX_Vector mpy0 = Q6_Vqf32_vmpy_Vqf32Vqf32(s0, s1); + prev1 = curr1; + + sum = Q6_Vqf32_vadd_Vqf32Vqf32(mpy0, sum); + } + + if ((src1_vec_ptr_end - ((HVX_Vector *) src1)) > 0) { + // handle the last vector + // see also: + // https://github.com/UbiquitousLearning/mllm/blob/babf4410352ce8730824c87699c025a0d4ce3a6f/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/LLaMAMul.cpp#L147 + // or qualcomm sdk libs\qhl_hvx\src\qhblas_hvx\qhblas_hvx_aw_vector_add_ah.c + bool should_fetch_src1 = !hexagon::is_addr_aligned(src1_vec_ptr); + HVX_Vector curr1 = should_fetch_src1 ? *src1_vec_ptr : prev1; + src1_vec_ptr += should_fetch_src1 ? 
1 : 0; + HVX_Vector s1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1); + HVX_Vector s0 = _DequantFunc(src0_ptr, qs_indices, scale_indices, table); + s1 = Q6_Vqf32_vadd_VsfVsf(kZeroV, s1); + + HVX_Vector mpy0 = Q6_Vqf32_vmpy_Vqf32Vqf32(s0, s1); + prev1 = curr1; + + sum = Q6_Vqf32_vadd_Vqf32Vqf32(mpy0, sum); + } + + return _ReduceFunc(sum); +} + template diff --git a/ggml/src/ggml-qnn/npu/device/vec_quant.inl b/ggml/src/ggml-qnn/npu/device/vec_quant.inl new file mode 100644 index 0000000000..f885242c40 --- /dev/null +++ b/ggml/src/ggml-qnn/npu/device/vec_quant.inl @@ -0,0 +1,304 @@ +#pragma once + +#include "hexagon_npu.h" + +#include + +#include + +namespace hexagon::vec::quant { + +template inline HVX_Vector load_into_vector(const _TStruct * src) { + static_assert(hexagon::kBytesPerVector >= sizeof(_TStruct) * _Count, "_TStruct too large for vector load"); + + return *reinterpret_cast(&(src->*_MemberPtr)); +} + +template inline HVX_Vector load_struct_into_vector(const _TStruct * src) { + static_assert(hexagon::kBytesPerVector >= sizeof(_TStruct) * _Count, "_TStruct too large for vector load"); + + return *reinterpret_cast(src); +} + +template inline HVX_Vector load_block_generic(const _TBlock & src) { + static_assert(hexagon::kBytesPerVector >= sizeof(_TBlock), "wrong block size/padding"); + return load_into_vector<_TBlock, 1, &_TBlock::qs>(&src); +} + +template inline HVX_Vector make_scale_load_mask() { + static_assert(sizeof(_TBlock) < hexagon::kBytesPerVector, "wrong block size/padding"); + static_assert(std::is_same_v, + "scale field d must be of type npu_device_fp16_t"); + + constexpr const size_t kBytesPerScale = QUANT_BLOCK_SIZE * sizeof(_TBlock::d); + const size_t qs_start_offset = offsetof(_TBlock, d); + + hexagon::HVX_VectorAlias ret; + size_t base_i = qs_start_offset; + for (size_t ret_idx = 0; ret_idx < hexagon::kBytesPerVector; ++ret_idx) { + const auto offset = ret_idx % kBytesPerScale; + const auto i = base_i + (offset % sizeof(_TBlock::d)); + 
ret.u8[ret_idx] = (i & 1) ? (i / 2 + 64) : (i / 2); + if (offset == kBytesPerScale - 1) { + base_i += sizeof(_TBlock); + } + } + + return ret.v; +} + +template inline HVX_Vector make_qs_load_mask() { + static_assert(sizeof(_TBlock) < hexagon::kBytesPerVector, "wrong block size/padding"); + + const size_t qs_start_offset = offsetof(_TBlock, qs); + const size_t qs_end_offset = qs_start_offset + sizeof(_TBlock::qs); + + hexagon::HVX_VectorAlias ret; + size_t ret_idx = 0; + for (size_t i = 0; i < hexagon::kBytesPerVector; ++i) { + auto offset = i % sizeof(_TBlock); + if (offset >= qs_start_offset && offset < qs_end_offset) { + ret.u8[ret_idx++] = (i & 1) ? (i / 2 + 64) : (i / 2); + } + } + + return ret.v; +} + +template inline HVX_Vector load_dual_block_generic(const _TBlock * srcs, HVX_VectorPred mask) { + static_assert(hexagon::kBytesPerVector >= sizeof(_TBlock) * 2, "wrong block size/padding"); + constexpr const uint32_t kSizeOfQs = sizeof(_TBlock::qs); + constexpr const uint32_t kSizeOfScale = sizeof(_TBlock) - kSizeOfQs; + + HVX_Vector blocks = load_into_vector<_TBlock, 2, &_TBlock::qs>(srcs); + HVX_Vector block1 = Q6_V_vror_VR(blocks, kSizeOfScale); + return Q6_V_vmux_QVV(mask, blocks, block1); +} + +template +inline hexagon::HVX_Vector_x2 load_dual_block_generic(const _TBlock * srcs, + const HVX_Vector qs_indices, + const HVX_Vector scale_indices) { + static_assert(hexagon::kBytesPerVector >= sizeof(_TBlock) * 2, "wrong block size/padding"); + + const HVX_Vector blocks = load_struct_into_vector<_TBlock, 2>(srcs); + + HVX_Vector block01 = Q6_Vb_vlut32_VbVbI(qs_indices, blocks, 0); + block01 = Q6_Vb_vlut32or_VbVbVbI(block01, qs_indices, blocks, 2); + + HVX_Vector scale01 = Q6_Vb_vlut32_VbVbI(scale_indices, blocks, 0); + scale01 = Q6_Vb_vlut32or_VbVbVbI(scale01, scale_indices, blocks, 2); + + if constexpr (sizeof(_TBlock) * 4 > hexagon::kBytesPerVector) { + block01 = Q6_Vb_vlut32or_VbVbVbI(block01, qs_indices, blocks, 1); + block01 = Q6_Vb_vlut32or_VbVbVbI(block01, 
qs_indices, blocks, 3); + + scale01 = Q6_Vb_vlut32or_VbVbVbI(scale01, scale_indices, blocks, 1); + scale01 = Q6_Vb_vlut32or_VbVbVbI(scale01, scale_indices, blocks, 3); + } + + hexagon::HVX_Vector_x2 result; + result.val[0] = block01; + result.val[1] = scale01; + return result; +} + +template +inline hexagon::HVX_Vector_x3 load_qual_block_generic(const _TBlock * srcs, + const HVX_Vector qs_indices, + const HVX_Vector scale_indices) { + static_assert(hexagon::kBytesPerVector >= sizeof(_TBlock) * 4, "wrong block size/padding"); + + hexagon::HVX_Vector_x3 result; + + const HVX_Vector blocks = load_struct_into_vector<_TBlock, 4>(srcs); + + { + HVX_Vector block0123 = Q6_Vb_vlut32_VbVbI(qs_indices, blocks, 0); + block0123 = Q6_Vb_vlut32or_VbVbVbI(block0123, qs_indices, blocks, 1); + block0123 = Q6_Vb_vlut32or_VbVbVbI(block0123, qs_indices, blocks, 2); + block0123 = Q6_Vb_vlut32or_VbVbVbI(block0123, qs_indices, blocks, 3); + + result.val[0] = block0123; + } + + { + HVX_Vector blocks23 = Q6_V_vror_VR(blocks, sizeof(_TBlock) * 2); + + HVX_Vector scale01 = Q6_Vb_vlut32_VbVbI(scale_indices, blocks, 0); + scale01 = Q6_Vb_vlut32or_VbVbVbI(scale01, scale_indices, blocks, 2); + + HVX_Vector scale23 = Q6_Vb_vlut32_VbVbI(scale_indices, blocks23, 0); + scale23 = Q6_Vb_vlut32or_VbVbVbI(scale23, scale_indices, blocks23, 2); + + result.val[1] = scale01; + result.val[2] = scale23; + } + + return result; +} + +template +inline hexagon::HVX_Vector_x5 load_hexa_block_generic(const _TBlock * srcs, + const HVX_Vector qs_indices, + const HVX_Vector scale_indices) { + static_assert(hexagon::kBytesPerVector >= sizeof(_TBlock) * 6, "wrong block size/padding"); + constexpr const uint32_t kSizeOfQs = sizeof(_TBlock::qs); + + const HVX_Vector blocks = load_struct_into_vector<_TBlock, 6>(srcs); + + hexagon::HVX_Vector_x5 result; + { + HVX_Vector block012345 = Q6_Vb_vlut32_VbVbI(qs_indices, blocks, 0); + block012345 = Q6_Vb_vlut32or_VbVbVbI(block012345, qs_indices, blocks, 1); + block012345 = 
Q6_Vb_vlut32or_VbVbVbI(block012345, qs_indices, blocks, 2); + block012345 = Q6_Vb_vlut32or_VbVbVbI(block012345, qs_indices, blocks, 3); + + result.val[0] = block012345; + result.val[3] = Q6_V_vror_VR(block012345, kSizeOfQs * 4); // block45 + } + + { + HVX_Vector blocks23 = Q6_V_vror_VR(blocks, sizeof(_TBlock) * 2); + HVX_Vector blocks45 = Q6_V_vror_VR(blocks, sizeof(_TBlock) * 4); + + HVX_Vector scale01 = Q6_Vb_vlut32_VbVbI(scale_indices, blocks, 0); + scale01 = Q6_Vb_vlut32or_VbVbVbI(scale01, scale_indices, blocks, 2); + + HVX_Vector scale23 = Q6_Vb_vlut32_VbVbI(scale_indices, blocks23, 0); + scale23 = Q6_Vb_vlut32or_VbVbVbI(scale23, scale_indices, blocks23, 2); + + HVX_Vector scale45 = Q6_Vb_vlut32_VbVbI(scale_indices, blocks45, 0); + scale45 = Q6_Vb_vlut32or_VbVbVbI(scale45, scale_indices, blocks45, 2); + + result.val[1] = scale01; + result.val[2] = scale23; + result.val[4] = scale45; + } + + return result; +} + +inline HVX_Vector dequantize_vec_q40_qf16_2blocks(HVX_Vector qs, HVX_Vector scale01, HVX_Vector table) { + constexpr const uint32_t kSizeOfQs = sizeof(npu_device_block_q4_0::qs); + + HVX_Vector q_lo = qs; + HVX_Vector q_hi = Q6_Vub_vlsr_VubR(qs, 4); + HVX_VectorPair qp0 = Q6_W_vshuff_VVR(q_hi, q_lo, kSizeOfQs * (1 + 2)); + + q_lo = Q6_V_lo_W(qp0); + q_lo = Q6_Vb_vshuff_Vb(q_lo); + qp0 = Q6_Wh_vlut16_VbVhR_nomatch(q_lo, table, 0); + + return Q6_Vqf16_vmpy_VhfVhf(Q6_V_lo_W(qp0), scale01); +} + +inline HVX_VectorPair dequantize_vec_q40_qf32_2blocks(HVX_Vector qs, HVX_Vector scale01, HVX_Vector table) { + constexpr const uint32_t kSizeOfQs = sizeof(npu_device_block_q4_0::qs); + + HVX_Vector q_lo = qs; + HVX_Vector q_hi = Q6_Vub_vlsr_VubR(qs, 4); + HVX_VectorPair qp0 = Q6_W_vshuff_VVR(q_hi, q_lo, kSizeOfQs * (1 + 2)); + + q_lo = Q6_V_lo_W(qp0); + q_lo = Q6_Vb_vshuff_Vb(q_lo); + qp0 = Q6_Wh_vlut16_VbVhR_nomatch(q_lo, table, 0); + + q_lo = Q6_V_lo_W(qp0); + scale01 = Q6_Vh_vshuff_Vh(scale01); + q_lo = Q6_Vh_vshuff_Vh(q_lo); // TODO: avoid vshuff here + + 
return Q6_Wqf32_vmpy_VhfVhf(q_lo, scale01); +} + +inline HVX_Vector_x2 dequantize_vec_q40_qf16_4blocks(HVX_Vector qs, + HVX_Vector scale01, + HVX_Vector scale23, + HVX_Vector table) { + constexpr const uint32_t kSizeOfQs = sizeof(npu_device_block_q4_0::qs); + + HVX_Vector q_lo = qs; + HVX_Vector q_hi = Q6_Vub_vlsr_VubR(qs, 4); + + HVX_VectorPair qp0 = Q6_W_vshuff_VVR(q_hi, q_lo, kSizeOfQs * (1 + 2 + 4)); + + q_lo = Q6_V_lo_W(qp0); + q_lo = Q6_Vb_vshuff_Vb(q_lo); + qp0 = Q6_Wh_vlut16_VbVhR_nomatch(q_lo, table, 0); + + q_lo = Q6_V_lo_W(qp0); + q_hi = Q6_V_hi_W(qp0); + + q_lo = Q6_Vqf16_vmpy_VhfVhf(q_lo, scale01); + q_hi = Q6_Vqf16_vmpy_VhfVhf(q_hi, scale23); + + hexagon::HVX_Vector_x2 result; + result.val[0] = q_lo; + result.val[1] = q_hi; + return result; +} + +inline HVX_VectorPair_x2 dequantize_vec_q40_qf32_4blocks(HVX_Vector qs, + HVX_Vector scale01, + HVX_Vector scale23, + HVX_Vector table) { + constexpr const uint32_t kSizeOfQs = sizeof(npu_device_block_q4_0::qs); + + HVX_Vector q_lo = qs; + HVX_Vector q_hi = Q6_Vub_vlsr_VubR(qs, 4); + + HVX_VectorPair qp0 = Q6_W_vshuff_VVR(q_hi, q_lo, kSizeOfQs * (1 + 2 + 4)); + + q_lo = Q6_V_lo_W(qp0); + q_lo = Q6_Vb_vshuff_Vb(q_lo); + qp0 = Q6_Wh_vlut16_VbVhR_nomatch(q_lo, table, 0); + + q_lo = Q6_V_lo_W(qp0); + q_hi = Q6_V_hi_W(qp0); + + q_lo = Q6_Vh_vshuff_Vh(q_lo); + scale01 = Q6_Vh_vshuff_Vh(scale01); + + q_hi = Q6_Vh_vshuff_Vh(q_hi); + scale23 = Q6_Vh_vshuff_Vh(scale23); // TODO: avoid vshuff here + + hexagon::HVX_VectorPair_x2 result; + result.val[0] = Q6_Wqf32_vmpy_VhfVhf(q_lo, scale01); + result.val[1] = Q6_Wqf32_vmpy_VhfVhf(q_hi, scale23); + return result; +} + +inline HVX_Vector load_dequant_vec_q40_qf32_1block(const npu_device_block_q4_0 * src, + const HVX_Vector qs_indices, + const HVX_Vector scale_indices, + const HVX_Vector table) { + // TODO: can we have a single-block version of load and dequantize? 
+ auto qs = load_dual_block_generic(src, qs_indices, scale_indices); + return Q6_V_lo_W(dequantize_vec_q40_qf32_2blocks(qs.val[0], qs.val[1], table)); +} + +inline HVX_Vector load_dequant_vec_q40_qf16_2blocks(const npu_device_block_q4_0 * src, + const HVX_Vector qs_indices, + const HVX_Vector scale_indices, + const HVX_Vector table) { + auto qs = load_dual_block_generic(src, qs_indices, scale_indices); + return dequantize_vec_q40_qf16_2blocks(qs.val[0], qs.val[1], table); +} + +inline HVX_VectorPair load_dequant_vec_q40_qf32_2blocks(const npu_device_block_q4_0 * src, + const HVX_Vector qs_indices, + const HVX_Vector scale_indices, + const HVX_Vector table) { + auto qs = load_dual_block_generic(src, qs_indices, scale_indices); + return dequantize_vec_q40_qf32_2blocks(qs.val[0], qs.val[1], table); +} + +inline HVX_VectorPair_x2 load_dequant_vec_q40_qf32_4blocks(const npu_device_block_q4_0 * src, + const HVX_Vector qs_indices, + const HVX_Vector scale_indices, + const HVX_Vector table) { + auto qs = load_qual_block_generic(src, qs_indices, scale_indices); + return dequantize_vec_q40_qf32_4blocks(qs.val[0], qs.val[1], qs.val[2], table); +} + +} // namespace hexagon::vec::quant diff --git a/ggml/src/ggml-qnn/npu/host/buffer.cpp b/ggml/src/ggml-qnn/npu/host/buffer.cpp index ed35216d7a..bc0e394534 100644 --- a/ggml/src/ggml-qnn/npu/host/buffer.cpp +++ b/ggml/src/ggml-qnn/npu/host/buffer.cpp @@ -20,6 +20,7 @@ static hexagon::host_buffer_type * get_buffer_type_object(ggml_backend_buffer_ty } void backend_buffer_free_buffer(ggml_backend_buffer_t buffer) { + SCOPED_PERFORMANCE_TRACKER("[hexagon-npu][%p]backend_buffer_free_buffer", (void *) get_buffer_object(buffer)); delete get_buffer_object(buffer); } @@ -39,6 +40,7 @@ ggml_status backend_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor auto * buffer_obj = get_buffer_object(buffer); GGML_ASSERT(buffer_obj != nullptr); + SCOPED_PERFORMANCE_TRACKER("[hexagon-npu][%p]backend_buffer_init_tensor", (void *) 
buffer_obj); auto tensor_object = buffer_obj->init_tensor(tensor, device_object->get_device_handle()); if (!tensor_object) { LOG_ERROR("Failed to init tensor\n"); @@ -48,12 +50,23 @@ ggml_status backend_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor return GGML_STATUS_SUCCESS; } +void backend_buffer_memset_tensor(ggml_backend_buffer_t buffer, + ggml_tensor * tensor, + uint8_t value, + size_t offset, + size_t size) { + SCOPED_PERFORMANCE_TRACKER("[hexagon-npu][%p]backend_buffer_memset_tensor.size.%zu", + (void *) get_buffer_object(buffer), size); + memset((char *) tensor->data + offset, value, size); +} + void backend_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { - GGML_UNUSED(buffer); + SCOPED_PERFORMANCE_TRACKER("[hexagon-npu][%p]backend_buffer_set_tensor.size.%zu", + (void *) get_buffer_object(buffer), size); memcpy((char *) tensor->data + offset, data, size); } @@ -62,23 +75,27 @@ void backend_buffer_get_tensor(ggml_backend_buffer_t buffer, void * data, size_t offset, size_t size) { - GGML_UNUSED(buffer); + SCOPED_PERFORMANCE_TRACKER("[hexagon-npu][%p]backend_buffer_get_tensor", (void *) get_buffer_object(buffer)); memcpy(data, (const char *) tensor->data + offset, size); } bool backend_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) { - GGML_UNUSED(buffer); + SCOPED_PERFORMANCE_TRACKER("[hexagon-npu][%p]backend_buffer_cpy_tensor", (void *) get_buffer_object(buffer)); if (ggml_backend_buffer_is_host(src->buffer)) { memcpy(dst->data, src->data, ggml_nbytes(src)); return true; } + LOG_DEBUG("[hexagon-npu][%p]backend_buffer_cpy_tensor: copy from non-host buffer not supported\n", + (void *) get_buffer_object(buffer)); return false; } void backend_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { auto * buffer_obj = get_buffer_object(buffer); GGML_ASSERT(buffer_obj != nullptr); + + 
SCOPED_PERFORMANCE_TRACKER("[hexagon-npu][%p]backend_buffer_clear", (void *) buffer_obj); memset(buffer_obj->get_buffer(), value, buffer_obj->get_size()); } @@ -94,7 +111,7 @@ constexpr const ggml_backend_buffer_i backend_buffer_interface = { /* .free_buffer = */ backend_buffer_free_buffer, /* .get_base = */ backend_buffer_get_base, /* .init_tensor = */ backend_buffer_init_tensor, - /* .memset_tensor = */ nullptr, + /* .memset_tensor = */ backend_buffer_memset_tensor, /* .set_tensor = */ backend_buffer_set_tensor, /* .get_tensor = */ backend_buffer_get_tensor, /* .cpy_tensor = */ backend_buffer_cpy_tensor, @@ -146,8 +163,8 @@ host_buffer::host_buffer(common::rpc_mem_ptr allocator, size_t size, uint32_t do } if (size > _allocator->get_max_alloc_size()) { - LOG_ERROR( - "[hexagon-npu]rpc memory size %zu exceeds max alloc size %zu\n", size, _allocator->get_max_alloc_size()); + LOG_ERROR("[hexagon-npu]rpc memory size %zu exceeds max alloc size %zu\n", size, + _allocator->get_max_alloc_size()); return; } @@ -161,8 +178,8 @@ host_buffer::host_buffer(common::rpc_mem_ptr allocator, size_t size, uint32_t do } host_buffer::~host_buffer() { - LOG_DEBUG( - "[hexagon-npu]destroy host_buffer(%p), size: %zu, domain_id: %d\n", (void *) _data, _size, (int) _domain_id); + LOG_DEBUG("[hexagon-npu]destroy host_buffer(%p), size: %zu, domain_id: %d\n", (void *) _data, _size, + (int) _domain_id); _tensors.clear(); if (_buffer_fd != -1) { auto ret = _allocator->fastrpc_munmap((int) _domain_id, _buffer_fd, nullptr, 0); @@ -194,17 +211,12 @@ std::shared_ptr host_buffer::init_tensor(ggml_tensor * tensor, remo return std::shared_ptr(); } - LOG_DEBUG("[hexagon-npu]mmap rpc memory(%p), fd: %d, addr: %p, size: %zu\n", - (void *) _data, - _buffer_fd, - _data, + LOG_DEBUG("[hexagon-npu]mmap rpc memory(%p), fd: %d, addr: %p, size: %zu\n", (void *) _data, _buffer_fd, _data, _size); } auto tensor_object = std::make_shared( - tensor, - _buffer_fd, - (uint64_t) (reinterpret_cast(tensor->data) - 
reinterpret_cast(_data)), + tensor, _buffer_fd, (uint64_t) (reinterpret_cast(tensor->data) - reinterpret_cast(_data)), device_handle); if (!tensor_object->is_valid()) { LOG_ERROR("[hexagon-npu]failed to init tensor, device handle: %p\n", (void *) device_handle); diff --git a/ggml/src/ggml-qnn/npu/host/host.cpp b/ggml/src/ggml-qnn/npu/host/host.cpp index 28c561a49f..647f7cda39 100644 --- a/ggml/src/ggml-qnn/npu/host/host.cpp +++ b/ggml/src/ggml-qnn/npu/host/host.cpp @@ -1,12 +1,13 @@ -#include -#include - #include "buffer.hpp" #include "common.hpp" #include "ggml-backend-impl.h" #include "ggml-impl.h" #include "host_device.hpp" +#include "profiler.hpp" + +#include +#include namespace { @@ -42,7 +43,9 @@ void backend_dev_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * tota enum ggml_backend_dev_type backend_dev_get_type(ggml_backend_dev_t dev) { GGML_UNUSED(dev); - return GGML_BACKEND_DEVICE_TYPE_ACCEL; + + // TODO: figure out why the GGML_BACKEND_DEVICE_TYPE_ACCEL type will miss some ops + return GGML_BACKEND_DEVICE_TYPE_IGPU; } void backend_dev_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) { @@ -62,6 +65,7 @@ ggml_backend_t backend_dev_init_backend(ggml_backend_dev_t dev, const char * par return nullptr; } + SCOPED_PERFORMANCE_TRACKER("[hexagon-npu][%p]backend_dev_init_backend", (void *) dev_obj); return new hexagon::npu_backend(dev); } @@ -71,8 +75,10 @@ ggml_backend_buffer_type_t backend_dev_get_buffer_type(ggml_backend_dev_t dev) { return dev_obj->get_default_buffer_type(dev); } -ggml_backend_buffer_t backend_dev_buffer_from_host_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, - size_t max_tensor_size) { +ggml_backend_buffer_t backend_dev_buffer_from_host_ptr(ggml_backend_dev_t dev, + void * ptr, + size_t size, + size_t max_tensor_size) { // TODO: should we use the device memory here? 
GGML_UNUSED(dev); GGML_UNUSED(max_tensor_size); @@ -86,6 +92,8 @@ bool backend_dev_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * auto * dev_obj = get_device_object(dev); GGML_ASSERT(dev_obj != nullptr); + + SCOPED_PERFORMANCE_TRACKER("[hexagon-npu][%p]backend_dev_supports_op", (void *) dev_obj); return dev_obj->supports_op(op); } @@ -96,6 +104,7 @@ bool backend_dev_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_ auto * dev_obj = get_device_object(dev); GGML_ASSERT(dev_obj != nullptr); + return dev_obj->supports_buft(buft); } @@ -106,6 +115,8 @@ bool backend_dev_offload_op(ggml_backend_dev_t dev, const struct ggml_tensor * o auto * dev_obj = get_device_object(dev); GGML_ASSERT(dev_obj != nullptr); + + SCOPED_PERFORMANCE_TRACKER("[hexagon-npu][%p]backend_dev_offload_op", (void *) dev_obj); return dev_obj->offload_op(op); } diff --git a/ggml/src/ggml-qnn/npu/host/host_device.cpp b/ggml/src/ggml-qnn/npu/host/host_device.cpp index 5e9f518879..e9ddfd4c43 100644 --- a/ggml/src/ggml-qnn/npu/host/host_device.cpp +++ b/ggml/src/ggml-qnn/npu/host/host_device.cpp @@ -143,17 +143,18 @@ bool npu_device::supports_op_impl(const ggml_tensor * op) { } if (op->op == GGML_OP_VIEW || op->op == GGML_OP_RESHAPE || op->op == GGML_OP_PERMUTE) { + LOG_DEBUG("[%s]view/reshape/permute op is always supported\n", get_name()); return true; } if (type_to_npu_type(op->type) == NPU_DATA_TYPE_COUNT) { - LOG_DEBUG("[%s]Unsupported op tensor type: %s\n", get_name(), ggml_type_name(op->type)); + LOG_DEBUG("[%s][%s]Unsupported op tensor type: %s\n", get_name(), ggml_get_name(op), ggml_type_name(op->type)); return false; } auto npu_op = op_to_npu_op(op->op); if (npu_op == NPU_OP_COUNT) { - LOG_DEBUG("[%s]Unsupported op: %s\n", get_name(), ggml_op_desc(op)); + LOG_DEBUG("[%s][%s]Unsupported op: %s\n", get_name(), ggml_get_name(op), ggml_op_desc(op)); return false; } @@ -206,14 +207,8 @@ bool npu_device::supports_op_impl(const ggml_tensor * op) { #ifndef NDEBUG auto * 
src0_type = i ? ggml_type_name(op->src[0]->type) : "null"; auto * src1_type = (i > 1) ? ggml_type_name(op->src[1]->type) : "null"; - LOG_DEBUG("[%s][%s]unsupported %s(%s,%s), ret: 0x%x, supported: %d\n", - get_name(), - ggml_op_desc(op), - ggml_type_name(op->type), - src0_type, - src1_type, - ret, - supported); + LOG_DEBUG("[%s][%s]unsupported %s(%s,%s), ret: 0x%x, supported: %d\n", get_name(), ggml_op_desc(op), + ggml_type_name(op->type), src0_type, src1_type, ret, supported); #endif return false; } @@ -258,8 +253,8 @@ bool npu_device::init_device_lib() { } if (err != AEE_SUCCESS) { - LOG_ERROR( - "[%s]Unable to open NPU device, err: 0x%x, uri %s\n", get_name(), err, device_lib_uri.c_str()); + LOG_ERROR("[%s]Unable to open NPU device, err: 0x%x, uri %s\n", get_name(), err, + device_lib_uri.c_str()); _device_handle = 0; return false; } @@ -289,26 +284,16 @@ bool npu_device::supports_op(const ggml_tensor * op) { if (op->op != GGML_OP_NONE && op->op != GGML_OP_VIEW && op->op != GGML_OP_RESHAPE && op->op != GGML_OP_PERMUTE) { _supported_op++; - LOG_DEBUG("[%s][%s][%s]supported, %s, supported/unsupported: %u/%u\n", - get_name(), - ggml_op_desc(op), - ggml_get_name(op), - op_desc, - _supported_op.load(), - _unsupported_op.load()); + LOG_DEBUG("[%s][%s][%s]supported, %s, supported/unsupported: %u/%u\n", get_name(), ggml_op_desc(op), + ggml_get_name(op), op_desc, _supported_op.load(), _unsupported_op.load()); } return true; } _unsupported_op++; - LOG_DEBUG("[%s][%s][%s]unsupported, %s, supported/unsupported: %u/%u\n", - get_name(), - ggml_op_desc(op), - ggml_get_name(op), - op_desc, - _supported_op.load(), - _unsupported_op.load()); + LOG_DEBUG("[%s][%s][%s]unsupported, %s, supported/unsupported: %u/%u\n", get_name(), ggml_op_desc(op), + ggml_get_name(op), op_desc, _supported_op.load(), _unsupported_op.load()); return false; } #else