diff --git a/ggml/src/ggml-qnn/npu/device/device.cpp b/ggml/src/ggml-qnn/npu/device/device.cpp index db987217fa..f073917112 100644 --- a/ggml/src/ggml-qnn/npu/device/device.cpp +++ b/ggml/src/ggml-qnn/npu/device/device.cpp @@ -1,10 +1,4 @@ -#include -#include -#include - -#include - #include "graph.hpp" #include "hexagon_npu.h" #include "op_impl.hpp" @@ -14,6 +8,12 @@ #include "type_traits.hpp" #include "util.hpp" +#include +#include +#include + +#include + namespace { struct npu_device_context { @@ -130,8 +130,12 @@ AEEResult npu_device_device_get_alignment(remote_handle64 _h, uint32_t * alignme return AEE_SUCCESS; } -AEEResult npu_device_device_support_op(remote_handle64 _h, npu_device_tensor_op op, const npu_device_tensor_spec * dst, - const npu_device_tensor_spec * srcs, int srcsLen, boolean * is_supported) { +AEEResult npu_device_device_support_op(remote_handle64 _h, + const npu_device_tensor_op_spec * op_spec, + const npu_device_tensor_spec * dst, + const npu_device_tensor_spec * srcs, + int srcsLen, + boolean * is_supported) { NPU_UNUSED(_h); if (!srcs || srcsLen <= 0 || !dst || !is_supported) { @@ -139,19 +143,21 @@ AEEResult npu_device_device_support_op(remote_handle64 _h, npu_device_tensor_op return AEE_EINVARGS; } - *is_supported = hexagon::support_op(op, dst, srcs, srcsLen); + *is_supported = hexagon::support_op(op_spec, dst, srcs, srcsLen); return AEE_SUCCESS; } -AEEResult npu_device_tensor_init(remote_handle64 _h, const npu_device_tensor_config * info, - npu_device_tensor_handle_t * tensor_handle) { +AEEResult npu_device_tensor_init(remote_handle64 _h, + const npu_device_tensor_config * info, + npu_device_tensor_handle_t * tensor_handle) { NPU_UNUSED(_h); auto * tensor = new hexagon::tensor(*info); *tensor_handle = tensor_to_handle(tensor); return AEE_SUCCESS; } -AEEResult npu_device_tensor_update_params(remote_handle64 _h, npu_device_tensor_handle_t tensor_handle, +AEEResult npu_device_tensor_update_params(remote_handle64 _h, + npu_device_tensor_handle_t tensor_handle, const npu_device_tensor_update_config * config) { NPU_UNUSED(_h); auto * tensor = tensor_from_handle(tensor_handle); @@ -174,8 +180,9 @@ AEEResult npu_device_tensor_free(remote_handle64 _h, npu_device_tensor_handle_t return AEE_SUCCESS; } -AEEResult npu_device_tensors_free(remote_handle64 _h, const npu_device_tensor_handle_t * tensor_handles, - int tensor_handlesLen) { +AEEResult npu_device_tensors_free(remote_handle64 _h, + const npu_device_tensor_handle_t * tensor_handles, + int tensor_handlesLen) { NPU_UNUSED(_h); if (!tensor_handles || tensor_handlesLen < 0) { DEVICE_LOG_ERROR("npu_device_tensors_free: Invalid arguments"); @@ -201,8 +208,10 @@ AEEResult npu_device_graph_init(remote_handle64 _h, npu_device_graph_handle_t * return AEE_SUCCESS; } -AEEResult npu_device_graph_set_tensor(remote_handle64 _h, npu_device_graph_handle_t graph_handle, - const npu_device_tensor_handle_t * tensor_handles, int tensor_handlesLen) { +AEEResult npu_device_graph_set_tensor(remote_handle64 _h, + npu_device_graph_handle_t graph_handle, + const npu_device_tensor_handle_t * tensor_handles, + int tensor_handlesLen) { NPU_UNUSED(_h); auto * graph = graph_from_handle(graph_handle); if (!graph || !tensor_handles || tensor_handlesLen <= 0) { @@ -213,7 +222,8 @@ AEEResult npu_device_graph_set_tensor(remote_handle64 _h, npu_device_graph_handl return AEE_SUCCESS; } -AEEResult npu_device_graph_set_tensor_with_param(remote_handle64 _h, npu_device_graph_handle_t graph_handle, +AEEResult 
npu_device_graph_set_tensor_with_param(remote_handle64 _h, + npu_device_graph_handle_t graph_handle, const npu_device_tensor_handle_t * tensor_handles, int tensor_handlesLen, const npu_device_tensor_update_config * tensor_params, diff --git a/ggml/src/ggml-qnn/npu/device/graph.cpp b/ggml/src/ggml-qnn/npu/device/graph.cpp index 5bc14a0aca..c963ef966e 100644 --- a/ggml/src/ggml-qnn/npu/device/graph.cpp +++ b/ggml/src/ggml-qnn/npu/device/graph.cpp @@ -1,12 +1,12 @@ #include "graph.hpp" -#include - #include "op_impl.hpp" #include "util.hpp" #include "vtcm_mem.hpp" +#include + namespace hexagon { graph::graph() noexcept { @@ -30,8 +30,12 @@ void graph::set_tensor(const npu_device_tensor_handle_t * tensors, int tensor_co for (int i = 0; i < tensor_count; ++i) { auto * tensor_obj = reinterpret_cast(tensors[i]); _tensors[i] = tensor_obj; - DEVICE_LOG_DEBUG("graph(%p) set_tensor[%d]: %p(%p,%p), op: %s\n", (void *) this, i, (void *) tensor_obj, - (void *) tensor_obj->get_src(0), (void *) tensor_obj->get_src(1), + DEVICE_LOG_DEBUG("graph(%p) set_tensor[%d]: %p(%p,%p), op: %s\n", + (void *) this, + i, + (void *) tensor_obj, + (void *) tensor_obj->get_src(0), + (void *) tensor_obj->get_src(1), op_get_name(tensor_obj->get_op())); } @@ -64,8 +68,9 @@ bool graph::compute(default_thread_pool * thread_pool, const float * f16_to_f32_ return true; } -void graph::thread_pool_task(default_thread_pool * pool, default_thread_pool::thread_params * thread_params, - void * graph) { +void graph::thread_pool_task(default_thread_pool * pool, + default_thread_pool::thread_params * thread_params, + void * graph) { reinterpret_cast(graph)->compute_impl(pool, thread_params); } @@ -86,8 +91,11 @@ void graph::compute_impl(default_thread_pool * pool, default_thread_pool::thread const bool should_sync = requires_thread_barrier(op); if (pool && should_sync && i < _tensor_count - 1) { - DEVICE_SCOPED_PERFORMANCE_TRACKER("[%p]sync_thread, tidx: %zu, tensor[%zu/%zu]", (void *) this, - params.get_thread_index(), i, _tensor_count); + DEVICE_SCOPED_PERFORMANCE_TRACKER("[%p]sync_thread, tidx: %zu, tensor[%zu/%zu]", + (void *) this, + params.get_thread_index(), + i, + _tensor_count); pool->sync_thread(); } } diff --git a/ggml/src/ggml-qnn/npu/device/graph.hpp b/ggml/src/ggml-qnn/npu/device/graph.hpp index 36cf5bfc5c..05e6852557 100644 --- a/ggml/src/ggml-qnn/npu/device/graph.hpp +++ b/ggml/src/ggml-qnn/npu/device/graph.hpp @@ -1,11 +1,11 @@ #pragma once -#include - #include "hexagon_npu.h" #include "tensor.hpp" #include "thread_pool.hpp" +#include + namespace hexagon { class graph { @@ -20,8 +20,9 @@ class graph { bool compute(default_thread_pool * thread_pool, const float * f16_to_f32_table); private: - static void thread_pool_task(default_thread_pool * pool, default_thread_pool::thread_params * thread_params, - void * graph); + static void thread_pool_task(default_thread_pool * pool, + default_thread_pool::thread_params * thread_params, + void * graph); void compute_impl(default_thread_pool * pool, default_thread_pool::thread_params * thread_params); std::unique_ptr _tensors; diff --git a/ggml/src/ggml-qnn/npu/device/op_flash_attn.cpp b/ggml/src/ggml-qnn/npu/device/op_flash_attn.cpp index 5beea614a3..b721d06f37 100644 --- a/ggml/src/ggml-qnn/npu/device/op_flash_attn.cpp +++ b/ggml/src/ggml-qnn/npu/device/op_flash_attn.cpp @@ -14,15 +14,20 @@ inline float f16_to_f32(const npu_device_fp16_t src) { // From: ggml/src/ggml-cpu/ops.cpp template -void flash_attn_impl(hexagon::tensor * out, const hexagon::tensor * q, const hexagon::tensor * 
k, - const hexagon::tensor * v, const hexagon::tensor * mask, hexagon::compute_params * params) { +void flash_attn_impl(hexagon::tensor * out, + const hexagon::tensor * q, + const hexagon::tensor * k, + const hexagon::tensor * v, + const hexagon::tensor * mask, + hexagon::compute_params * params) { static_assert(3 <= hexagon::kMaxParamsCount, "flash_attn op params count exceeds max params count"); constexpr const npu_device_tensor_data_type kKvDataType = _IsKvF16 ? NPU_DATA_TYPE_F16 : NPU_DATA_TYPE_F32; if (k->get_type() != kKvDataType || v->get_type() != k->get_type()) { DEVICE_LOG_ERROR("flash_attn_impl: k and v must have same type, got k: %s, v: %s\n", - hexagon::get_type_name(k->get_type()), hexagon::get_type_name(v->get_type())); + hexagon::get_type_name(k->get_type()), + hexagon::get_type_name(v->get_type())); return; } @@ -80,7 +85,8 @@ void flash_attn_impl(hexagon::tensor * out, const hexagon::tensor * q, const hex const auto out_rows_per_batch = out->get_ne(2) * out->get_ne(1); uint8_t * dst_ptr = out->get_write_buffer(); if (!dst_ptr) { - DEVICE_LOG_ERROR("flash_attn_impl: dst_ptr is not writable, tensor: %p, type: %s\n", (void *) out, + DEVICE_LOG_ERROR("flash_attn_impl: dst_ptr is not writable, tensor: %p, type: %s\n", + (void *) out, hexagon::get_type_name(out->get_type())); return; } @@ -118,7 +124,8 @@ void flash_attn_impl(hexagon::tensor * out, const hexagon::tensor * q, const hex const npu_device_fp16_t * mp = mask_ptr ? reinterpret_cast(mask_ptr + iq1 * mask->get_nb(1) + - (iq3 % mask->get_ne(2)) * mask->get_nb(2)) : + (iq2 % mask->get_ne(2)) * mask->get_nb(2) + + (iq3 % mask->get_ne(3)) * mask->get_nb(3)) : nullptr; // k indices @@ -251,8 +258,8 @@ bool flash_attn_f32(tensor * out, compute_params * params) { const auto * v = out->get_src(2); const auto * mask = out->get_src(3); if (!q || !k || !v || !mask) { - DEVICE_LOG_DEBUG("invalid src tensors: q: %p, k: %p, v: %p, mask: %p\n", (void *) q, (void *) k, (void *) v, - (void *) mask); + DEVICE_LOG_DEBUG( + "invalid src tensors: q: %p, k: %p, v: %p, mask: %p\n", (void *) q, (void *) k, (void *) v, (void *) mask); return false; } @@ -264,8 +271,11 @@ bool flash_attn_f32(tensor * out, compute_params * params) { return true; } -bool is_flash_attn_supported(npu_device_tensor_op op, const npu_device_tensor_spec * dst, - const npu_device_tensor_spec * srcs, size_t src_len) { +bool is_flash_attn_supported(const npu_device_tensor_op_spec * op_spec, + const npu_device_tensor_spec * dst, + const npu_device_tensor_spec * srcs, + size_t src_len) { + const auto op = op_spec->op; if (op != NPU_OP_FLASH_ATTN) { DEVICE_LOG_DEBUG("op is not NPU_OP_FLASH_ATTN: %d\n", op); return false; @@ -295,7 +305,9 @@ bool is_flash_attn_supported(npu_device_tensor_op op, const npu_device_tensor_sp const auto * v = &srcs[2]; if (v->type != k->type) { // TODO: support more v types - DEVICE_LOG_DEBUG("[%s]v type is not the same as k: %s vs %s\n", op_get_name(op), get_type_name(v->type), + DEVICE_LOG_DEBUG("[%s]v type is not the same as k: %s vs %s\n", + op_get_name(op), + get_type_name(v->type), get_type_name(k->type)); return false; } @@ -310,28 +322,42 @@ bool is_flash_attn_supported(npu_device_tensor_op op, const npu_device_tensor_sp DEVICE_LOG_DEBUG( "[%s]dst shape does not match q and v: dst ne: %ld, %ld, %ld, %ld, q ne: %ld, %ld, %ld, %ld, " "v ne: %ld, %ld, %ld, %ld\n", - op_get_name(op), dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], q->ne[0], q->ne[1], q->ne[2], q->ne[3], - v->ne[0], v->ne[1], v->ne[2], v->ne[3]); + op_get_name(op), + 
dst->ne[0], + dst->ne[1], + dst->ne[2], + dst->ne[3], + q->ne[0], + q->ne[1], + q->ne[2], + q->ne[3], + v->ne[0], + v->ne[1], + v->ne[2], + v->ne[3]); return false; } if (is_transposed_or_permuted(dst->nb)) { - DEVICE_LOG_DEBUG("[%s]dst cannot be transposed or permuted, nb: %zu, %zu, %zu, %zu\n", op_get_name(op), - dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3]); + DEVICE_LOG_DEBUG("[%s]dst cannot be transposed or permuted, nb: %zu, %zu, %zu, %zu\n", + op_get_name(op), + dst->nb[0], + dst->nb[1], + dst->nb[2], + dst->nb[3]); return false; } if (q->ne[0] != k->ne[0]) { DEVICE_LOG_DEBUG("[%s]q and k shapes do not match: q ne: %ld, %ld, %ld, %ld, k ne: %ld, %ld, %ld, %ld\n", - op_get_name(op), q->ne[0], q->ne[1], q->ne[2], q->ne[3], k->ne[0], k->ne[1], k->ne[2], - k->ne[3]); - return false; - } - - if (q->ne[2] != k->ne[2] || q->ne[3] != k->ne[3] || q->ne[3] != 1) { - // TODO: add broadcast support - DEVICE_LOG_DEBUG("[%s]q and k shapes do not match: q ne: %ld, %ld, %ld, %ld, k ne: %ld, %ld, %ld, %ld\n", - op_get_name(op), q->ne[0], q->ne[1], q->ne[2], q->ne[3], k->ne[0], k->ne[1], k->ne[2], + op_get_name(op), + q->ne[0], + q->ne[1], + q->ne[2], + q->ne[3], + k->ne[0], + k->ne[1], + k->ne[2], k->ne[3]); return false; } diff --git a/ggml/src/ggml-qnn/npu/device/op_flash_attn.hpp b/ggml/src/ggml-qnn/npu/device/op_flash_attn.hpp index 63d6d09d54..071904228c 100644 --- a/ggml/src/ggml-qnn/npu/device/op_flash_attn.hpp +++ b/ggml/src/ggml-qnn/npu/device/op_flash_attn.hpp @@ -5,7 +5,9 @@ namespace hexagon { bool flash_attn_f32(tensor * out, compute_params * params); -bool is_flash_attn_supported(npu_device_tensor_op op, const npu_device_tensor_spec * dst, - const npu_device_tensor_spec * srcs, size_t src_len); +bool is_flash_attn_supported(const npu_device_tensor_op_spec * op_spec, + const npu_device_tensor_spec * dst, + const npu_device_tensor_spec * srcs, + size_t src_len); } // namespace hexagon diff --git a/ggml/src/ggml-qnn/npu/device/op_impl.cpp b/ggml/src/ggml-qnn/npu/device/op_impl.cpp index a794a8b750..647fdd347f 100644 --- a/ggml/src/ggml-qnn/npu/device/op_impl.cpp +++ b/ggml/src/ggml-qnn/npu/device/op_impl.cpp @@ -2,20 +2,20 @@ #include "op_impl.hpp" -#include - #include "op_flash_attn.hpp" #include "op_mul_mat.hpp" #include "op_rope.hpp" #include "type_traits.hpp" #include "vec_ops.hpp" +#include + namespace { template -inline void vec_op_f32_f32(const float * src0, const float * src1, size_t count, float * dst) { +inline void vec_op_f32_f32(const float * src0, const float * src1, float * dst, size_t count) { using namespace hexagon::vec; - vec_trans_op_impl<_OpBinaryTransform, float>(src0, src1, count, dst); + vec_trans_impl<_OpBinaryTransform, float>(src0, src1, dst, count); } inline HVX_Vector vadd_f32_f32(HVX_Vector a, HVX_Vector b) { @@ -31,10 +31,12 @@ inline HVX_Vector vmul_f32_f32(HVX_Vector a, HVX_Vector b) { } template -inline void vec_op_f16_f16(const npu_device_fp16_t * src0, const npu_device_fp16_t * src1, size_t count, - npu_device_fp16_t * dst) { +inline void vec_op_f16_f16(const npu_device_fp16_t * src0, + const npu_device_fp16_t * src1, + npu_device_fp16_t * dst, + size_t count) { using namespace hexagon::vec; - vec_trans_op_impl<_OpBinaryTransform, npu_device_fp16_t>(src0, src1, count, dst); + vec_trans_impl<_OpBinaryTransform, npu_device_fp16_t>(src0, src1, dst, count); } inline HVX_Vector vadd_f16_f16(HVX_Vector a, HVX_Vector b) { @@ -53,12 +55,17 @@ inline HVX_Vector vmul_f16_f16(HVX_Vector a, HVX_Vector b) { template struct get_data_type {}; -template struct 
get_data_type { +template struct get_data_type { + using type = _TyData; +}; + +template +struct get_data_type { using type = _TyData; }; template -struct get_data_type { +struct get_data_type { using type = _TyData; using param_type = typename std::remove_cv::type>::type; }; @@ -85,7 +92,8 @@ template bool element_wise_op(hexagon::tensor * out, hexagon::co uint8_t * dst_ptr = out->get_write_buffer(); if (!dst_ptr) { - DEVICE_LOG_ERROR("element_wise_op: dst_ptr is not writable, tensor: %p, type: %s\n", (void *) out, + DEVICE_LOG_ERROR("element_wise_op: dst_ptr is not writable, tensor: %p, type: %s\n", + (void *) out, hexagon::get_type_name(out->get_type())); return false; } @@ -119,16 +127,21 @@ template bool element_wise_op(hexagon::tensor * out, hexagon::co hexagon::l2fetch_row(src1_row + src1->get_nb(1), valid_row_bytes); } - _RowFunc(reinterpret_cast(src0_row), reinterpret_cast(src1_row), - static_cast(out->get_ne(0)), reinterpret_cast(dst_row)); + _RowFunc(reinterpret_cast(src0_row), + reinterpret_cast(src1_row), + reinterpret_cast(dst_row), + static_cast(out->get_ne(0))); } out->release_write_buffer(); // mark the output tensor as modified return true; } -bool is_element_wise_op_supported(npu_device_tensor_op op, const npu_device_tensor_spec * dst, - const npu_device_tensor_spec * srcs, size_t src_len) { +bool is_element_wise_op_supported(const npu_device_tensor_op_spec * op_spec, + const npu_device_tensor_spec * dst, + const npu_device_tensor_spec * srcs, + size_t src_len) { + const auto op = op_spec->op; if (op != NPU_OP_ADD && op != NPU_OP_SUB && op != NPU_OP_MUL) { DEVICE_LOG_DEBUG("[%s]unsupported\n", hexagon::op_get_name(op)); return false; @@ -142,27 +155,31 @@ bool is_element_wise_op_supported(npu_device_tensor_op op, const npu_device_tens const auto & src0 = srcs[0]; const auto & src1 = srcs[1]; if (dst->type != src0.type || dst->type != src1.type) { - DEVICE_LOG_DEBUG("[%s]src0.type and dst.type mismatch: %s vs %s\n", hexagon::op_get_name(op), - hexagon::get_type_name(src0.type), hexagon::get_type_name(dst->type)); + DEVICE_LOG_DEBUG("[%s]src0.type and dst.type mismatch: %s vs %s\n", + hexagon::op_get_name(op), + hexagon::get_type_name(src0.type), + hexagon::get_type_name(dst->type)); return false; } if (dst->type != NPU_DATA_TYPE_F32 && dst->type != NPU_DATA_TYPE_F16) { - DEVICE_LOG_DEBUG("[%s]unsupported data type: %s\n", hexagon::op_get_name(op), - hexagon::get_type_name(dst->type)); + DEVICE_LOG_DEBUG( + "[%s]unsupported data type: %s\n", hexagon::op_get_name(op), hexagon::get_type_name(dst->type)); return false; } // TODO: fix FP16 add/sub if (dst->type == NPU_DATA_TYPE_F16 && op != NPU_OP_MUL) { - DEVICE_LOG_DEBUG("[%s]unsupported data type: %s\n", hexagon::op_get_name(op), - hexagon::get_type_name(dst->type)); + DEVICE_LOG_DEBUG( + "[%s]unsupported data type: %s\n", hexagon::op_get_name(op), hexagon::get_type_name(dst->type)); return false; } if (src0.ne[0] != src1.ne[0]) { - DEVICE_LOG_DEBUG("[%s]src0.ne[0] and src1.ne[0] not match: %ld vs %ld\n", hexagon::op_get_name(op), - (long) src0.ne[0], (long) src1.ne[0]); + DEVICE_LOG_DEBUG("[%s]src0.ne[0] and src1.ne[0] not match: %ld vs %ld\n", + hexagon::op_get_name(op), + (long) src0.ne[0], + (long) src1.ne[0]); return false; } @@ -174,7 +191,7 @@ bool is_element_wise_op_supported(npu_device_tensor_op op, const npu_device_tens return true; } -void rms_norm_vec_f32(const float * src, size_t count, float eps, float * dst) { +void rms_norm_vec_f32(const float * src, float * dst, size_t count, float eps) { constexpr const 
size_t kElementsPerVector = hexagon::kBytesPerVector / sizeof(float); HVX_Vector * src_vec_ptr = ((HVX_Vector *) src); @@ -206,7 +223,7 @@ void rms_norm_vec_f32(const float * src, size_t count, float eps, float * dst) { (leftover_bytes + hexagon::unaligned_bytes(src_vec_ptr) > hexagon::kBytesPerVector) ? *src_vec_ptr : prev; curr = Q6_V_valign_VVR(curr, prev, (size_t) src); sum = Q6_Vqf32_vadd_Vqf32Vqf32(sum, - Q6_V_valign_VVR(Q6_Vqf32_vmpy_VsfVsf(curr, curr), Q6_V_vzero(), leftover_bytes)); + Q6_V_valign_VVR(Q6_Vqf32_vmpy_VsfVsf(curr, curr), Q6_V_vzero(), leftover_bytes)); } const float mean = hexagon::vec_reduction_f32_qf32(sum) / count; // TODO: figure out how to do division in vector @@ -231,7 +248,8 @@ template bool unary_op(hexagon::tensor * out, hexagon::compute_p auto * dst_ptr = out->get_write_buffer(); if (!dst_ptr) { - DEVICE_LOG_ERROR("unary_op: dst_ptr is not writable, tensor: %p, type: %s\n", (void *) out, + DEVICE_LOG_ERROR("unary_op: dst_ptr is not writable, tensor: %p, type: %s\n", + (void *) out, hexagon::get_type_name(out->get_type())); return false; } @@ -259,16 +277,21 @@ template bool unary_op(hexagon::tensor * out, hexagon::compute_p hexagon::l2fetch_row(src0_row + src0->get_nb(1), valid_row_bytes); } - _RowFunc(reinterpret_cast(src0_row), static_cast(out->get_ne(0)), param, - reinterpret_cast(dst_row)); + _RowFunc(reinterpret_cast(src0_row), + reinterpret_cast(dst_row), + static_cast(out->get_ne(0)), + param); } out->release_write_buffer(); // mark the output tensor as modified return true; } -bool is_unary_op_supported(npu_device_tensor_op op, const npu_device_tensor_spec * dst, - const npu_device_tensor_spec * srcs, size_t src_len) { +bool is_unary_op_supported(const npu_device_tensor_op_spec * op_spec, + const npu_device_tensor_spec * dst, + const npu_device_tensor_spec * srcs, + size_t src_len) { + const auto op = op_spec->op; if (op != NPU_OP_RMS_NORM) { DEVICE_LOG_DEBUG("[%s]unsupported\n", hexagon::op_get_name(op)); return false; @@ -281,14 +304,16 @@ bool is_unary_op_supported(npu_device_tensor_op op, const npu_device_tensor_spec const auto & src0 = srcs[0]; if (dst->type != src0.type) { - DEVICE_LOG_DEBUG("[%s]src0.type and dst.type mismatch: %s vs %s\n", hexagon::op_get_name(op), - hexagon::get_type_name(src0.type), hexagon::get_type_name(dst->type)); + DEVICE_LOG_DEBUG("[%s]src0.type and dst.type mismatch: %s vs %s\n", + hexagon::op_get_name(op), + hexagon::get_type_name(src0.type), + hexagon::get_type_name(dst->type)); return false; } if (dst->type != NPU_DATA_TYPE_F32) { - DEVICE_LOG_DEBUG("[%s]unsupported data type: %s\n", hexagon::op_get_name(op), - hexagon::get_type_name(dst->type)); + DEVICE_LOG_DEBUG( + "[%s]unsupported data type: %s\n", hexagon::op_get_name(op), hexagon::get_type_name(dst->type)); return false; } @@ -300,6 +325,171 @@ bool is_unary_op_supported(npu_device_tensor_op op, const npu_device_tensor_spec return true; } +inline void glu_vec_op_f32_f32(const float * src0, + const float * src1, + float * dst, + size_t count, + hexagon::HVX_VectorPair_x4 coeff) { + using namespace hexagon::vec; + vec_trans_with_param_impl( + src0, src1, dst, count, coeff); +} + +inline void glu_vec_op_f16_f16(const npu_device_fp16_t * src0, + const npu_device_fp16_t * src1, + npu_device_fp16_t * dst, + size_t count, + hexagon::HVX_VectorPair_x4 coeff) { + using namespace hexagon::vec; + vec_trans_with_param_impl( + src0, src1, dst, count, coeff); +} + +template +bool glu_impl(hexagon::tensor * out, hexagon::compute_params * params) { + using data_type = 
typename get_data_type::type; + static_assert(DEVICE_TENSOR_MAX_DIMS == 4, "element_wise_op requires max dims 4"); + + if (!out) { + return false; + } + + const bool has_src1 = out->get_src(1) != nullptr; + auto * src0 = out->get_src(0); + auto * src1 = has_src1 ? out->get_src(1) : src0; + if (!src0 || !src1) { + return true; // skip if no src + } + + const auto total_cols = has_src1 ? src0->get_ne(0) : src0->get_ne(0) / 2; + if (out->get_ne(0) != total_cols) { + DEVICE_LOG_ERROR("out.ne[0] (%ld) != total_cols (%d)\n", (long) out->get_ne(0), (int) total_cols); + return false; + } + + auto total_rows = out->get_ne(3) * out->get_ne(2) * out->get_ne(1); + const auto rows_per_cube = out->get_ne(2) * out->get_ne(1); + const auto start_end = params->get_work_slice(total_rows); + if (start_end.first >= start_end.second) { + return true; + } + + uint8_t * dst_ptr = out->get_write_buffer(); + if (!dst_ptr) { + DEVICE_LOG_ERROR("element_wise_op: dst_ptr is not writable, tensor: %p, type: %s\n", + (void *) out, + hexagon::get_type_name(out->get_type())); + return false; + } + + const int32_t swapped = out->get_op_param(1); + const uint8_t * src0_ptr = src0->get_read_buffer(); + const uint8_t * src1_ptr = has_src1 ? src1->get_read_buffer() : (src0_ptr + total_cols * sizeof(data_type)); + if (swapped) { + std::swap(src0_ptr, src1_ptr); + } + + DEVICE_SCOPED_OP_PERFORMANCE_TRACKER(out, params->get_thread_index()); + + auto coeff = _CoeffLoadFunc(); + const size_t valid_row_bytes = src0->get_ne(0) * sizeof(data_type); + for (int64_t ir = start_end.first; ir < start_end.second; ++ir) { + const auto i03 = ir / rows_per_cube; + const auto i02 = ir / out->get_ne(1) - i03 * out->get_ne(2); + const auto i01 = ir % out->get_ne(1); // TODO: should we use divide instead of mod? 
+ const auto i13 = i03 % src1->get_ne(3); + const auto i12 = i02 % src1->get_ne(2); + const auto i11 = i01 % src1->get_ne(1); + + auto * src1_plane = src1_ptr + i13 * src1->get_nb(3) + i12 * src1->get_nb(2); + auto * src0_row = src0_ptr + i03 * src0->get_nb(3) + i02 * src0->get_nb(2) + i01 * src0->get_nb(1); + auto * src1_row = src1_plane + i11 * src1->get_nb(1); + auto * dst_row = dst_ptr + i03 * out->get_nb(3) + i02 * out->get_nb(2) + i01 * out->get_nb(1); + if (ir + 1 < start_end.second) { + hexagon::l2fetch_row(src0_row + src0->get_nb(1), valid_row_bytes); + hexagon::l2fetch_row(src1_row + src1->get_nb(1), valid_row_bytes); + } + + _GluRowFunc(reinterpret_cast(src0_row), + reinterpret_cast(src1_row), + reinterpret_cast(dst_row), + static_cast(total_cols), + coeff); + } + + out->release_write_buffer(); // mark the output tensor as modified + return true; +} + +template +bool glu_compute(hexagon::tensor * out, hexagon::compute_params * params) { + using namespace hexagon::vec::math; + + if (out->get_op_param(0) != NPU_GLU_OP_SWIGLU) { + DEVICE_LOG_ERROR("Invalid GLU op type: %d\n", out->get_op_param(0)); + return false; + } + + if (out->get_type() != _DataType) { + DEVICE_LOG_ERROR("GLU op type mismatch: %s vs %s\n", + hexagon::get_type_name(out->get_type()), + hexagon::get_type_name(_DataType)); + return false; + } + + if constexpr (_DataType == NPU_DATA_TYPE_F32) { + return glu_impl(out, params); + } else if constexpr (_DataType == NPU_DATA_TYPE_F16) { + return glu_impl(out, params); + } + + DEVICE_LOG_ERROR("Unsupported GLU data type: %s\n", hexagon::get_type_name(out->get_type())); + return true; +} + +bool is_glu_op_supported(const npu_device_tensor_op_spec * op_spec, + const npu_device_tensor_spec * dst, + const npu_device_tensor_spec * srcs, + size_t src_len) { + const auto op = op_spec->op; + if (op != NPU_OP_GLU) { + DEVICE_LOG_DEBUG("[%s]unsupported\n", hexagon::op_get_name(op)); + return false; + } + + if (op_spec->params[0] != NPU_GLU_OP_SWIGLU) { + DEVICE_LOG_DEBUG("[%s]unsupported GLU op type: %d\n", hexagon::op_get_name(op), op_spec->params[0]); + return false; + } + + if (!dst || !srcs || src_len < 1) { + DEVICE_LOG_DEBUG("[%s]invalid dst or srcs\n", hexagon::op_get_name(op)); + return false; + } + + const auto & src0 = srcs[0]; + if (dst->type != src0.type) { + DEVICE_LOG_DEBUG("[%s]src0.type and dst.type mismatch: %s vs %s\n", + hexagon::op_get_name(op), + hexagon::get_type_name(src0.type), + hexagon::get_type_name(dst->type)); + return false; + } + + if (dst->type != NPU_DATA_TYPE_F32 && dst->type != NPU_DATA_TYPE_F16) { + DEVICE_LOG_DEBUG( + "[%s]unsupported data type: %s\n", hexagon::op_get_name(op), hexagon::get_type_name(dst->type)); + return false; + } + + if (!hexagon::is_same_shape(src0, *dst)) { + DEVICE_LOG_DEBUG("[%s]src0 and dst have different shape\n", hexagon::op_get_name(op)); + return false; + } + + return false; // TODO: fix: for some input hexagon intrinsics will generate nan instead of inf. 
+} + struct op_capabilities { npu_device_tensor_op op; hexagon::op_is_supported_func_type is_supported; @@ -341,7 +531,7 @@ constexpr const op_capabilities kOpCapabilities[] = { { unary_op, // NPU_DATA_TYPE_F32 nullptr, // NPU_DATA_TYPE_F16 - }, false, // requires_thread_barrier + }, false, // requires_thread_barrier }, { NPU_OP_FLASH_ATTN,hexagon::is_flash_attn_supported, @@ -357,6 +547,13 @@ constexpr const op_capabilities kOpCapabilities[] = { nullptr, // NPU_DATA_TYPE_F16 }, false, // requires_thread_barrier }, + { + NPU_OP_GLU, is_glu_op_supported, + { + glu_compute, // NPU_DATA_TYPE_F32 + glu_compute, // NPU_DATA_TYPE_F16 + }, false, // requires_thread_barrier + }, }; static_assert(kOpCapabilities[NPU_OP_MUL_MAT].compute_funcs[NPU_DATA_TYPE_F32] == hexagon::mul_mat_f32, @@ -370,6 +567,7 @@ static_assert(kOpCapabilities[NPU_OP_RMS_NORM].op == NPU_OP_RMS_NORM, static_assert(kOpCapabilities[NPU_OP_FLASH_ATTN].op == NPU_OP_FLASH_ATTN, "kOpArray[NPU_OP_FLASH_ATTN].op != NPU_OP_FLASH_ATTN"); static_assert(kOpCapabilities[NPU_OP_ROPE].op == NPU_OP_ROPE, "kOpArray[NPU_OP_ROPE].op != NPU_OP_ROPE"); +static_assert(kOpCapabilities[NPU_OP_GLU].op == NPU_OP_GLU, "kOpArray[NPU_OP_GLU].op != NPU_OP_GLU"); hexagon::compute_func_type get_compute_func_impl(npu_device_tensor_op op, npu_device_tensor_data_type type) { if (op >= NPU_OP_COUNT) { @@ -395,17 +593,25 @@ bool requires_thread_barrier(npu_device_tensor_op op) { return kOpCapabilities[op].requires_thread_barrier; } -bool support_op(npu_device_tensor_op op, const npu_device_tensor_spec * dst, const npu_device_tensor_spec * srcs, - size_t src_len) { - auto is_supported_func = kOpCapabilities[op].is_supported; - if (!is_supported_func || !is_supported_func(op, dst, srcs, src_len)) { +bool support_op(const npu_device_tensor_op_spec * op_spec, + const npu_device_tensor_spec * dst, + const npu_device_tensor_spec * srcs, + size_t src_len) { + if (!op_spec) { + DEVICE_LOG_ERROR("[hexagon-npu]invalid op_spec\n"); + return false; + } + + const auto op = op_spec->op; + auto is_supported_func = kOpCapabilities[op].is_supported; + if (!is_supported_func || !is_supported_func(op_spec, dst, srcs, src_len)) { DEVICE_LOG_DEBUG("[%s]unsupported, is_supported_func return false\n", op_get_name(op)); return false; } if (get_compute_func_impl(op, dst->type) == nullptr) { - DEVICE_LOG_DEBUG("[%s]unsupported, get_compute_func failed, type: %s\n", op_get_name(op), - get_type_name(dst->type)); + DEVICE_LOG_DEBUG( + "[%s]unsupported, get_compute_func failed, type: %s\n", op_get_name(op), get_type_name(dst->type)); return false; } diff --git a/ggml/src/ggml-qnn/npu/device/op_impl.hpp b/ggml/src/ggml-qnn/npu/device/op_impl.hpp index 709d493428..7491ea975b 100644 --- a/ggml/src/ggml-qnn/npu/device/op_impl.hpp +++ b/ggml/src/ggml-qnn/npu/device/op_impl.hpp @@ -8,7 +8,9 @@ compute_func_type get_compute_func(tensor * dst); bool requires_thread_barrier(npu_device_tensor_op op); -bool support_op(npu_device_tensor_op op, const npu_device_tensor_spec * dst, const npu_device_tensor_spec * srcs, - size_t src_len); +bool support_op(const npu_device_tensor_op_spec * op_spec, + const npu_device_tensor_spec * dst, + const npu_device_tensor_spec * srcs, + size_t src_len); } // namespace hexagon diff --git a/ggml/src/ggml-qnn/npu/device/op_mul_mat.cpp b/ggml/src/ggml-qnn/npu/device/op_mul_mat.cpp index e7ca2ea440..41bf2c7838 100644 --- a/ggml/src/ggml-qnn/npu/device/op_mul_mat.cpp +++ b/ggml/src/ggml-qnn/npu/device/op_mul_mat.cpp @@ -29,7 +29,9 @@ template <> struct convert_vector { }; 
template -void mul_mat_impl(hexagon::tensor * src0, hexagon::tensor * src1, hexagon::tensor * dst, +void mul_mat_impl(hexagon::tensor * src0, + hexagon::tensor * src1, + hexagon::tensor * dst, hexagon::compute_params * params) { using data_type0 = typename get_data_type::data_type0; using data_type1 = typename get_data_type::data_type1; @@ -62,8 +64,12 @@ void mul_mat_impl(hexagon::tensor * src0, hexagon::tensor * src1, hexagon::tenso DEVICE_LOG_DEBUG( "mul_mat_impl: no work to do, start_end_plane: (%ld, %ld), start_end_row: (%ld, %ld), " "start_end_element: (%ld, %ld)\n", - start_end_plane.first, start_end_plane.second, start_end_row.first, start_end_row.second, - start_end_element.first, start_end_element.second); + start_end_plane.first, + start_end_plane.second, + start_end_row.first, + start_end_row.second, + start_end_element.first, + start_end_element.second); return; } @@ -81,7 +87,9 @@ void mul_mat_impl(hexagon::tensor * src0, hexagon::tensor * src1, hexagon::tenso DEVICE_LOG_ERROR( "mul_mat_impl: failed to get VTCM cache for src0, size: %zu, src0_plane_slice_row_count: %zu, " "src0_actual_row_size: %zu, will fallback to mem cache\n", - src0_plane_cache_size, src0_plane_slice_row_count, src0_actual_row_size); + src0_plane_cache_size, + src0_plane_slice_row_count, + src0_actual_row_size); return; } } @@ -89,7 +97,10 @@ void mul_mat_impl(hexagon::tensor * src0, hexagon::tensor * src1, hexagon::tenso DEVICE_LOG_DEBUG( "mul_mat_impl src0_actual_row_size: %zu, src0_plane_slice_row_count: %zu, is_quantized: %d, vtcm_mem: " "%p(%zu)\n", - src0_actual_row_size, src0_plane_slice_row_count, _ShouldCacheSrc0, (void *) src0_plane_cache_ptr, + src0_actual_row_size, + src0_plane_slice_row_count, + _ShouldCacheSrc0, + (void *) src0_plane_cache_ptr, src0_plane_cache_size); const size_t valid_row0_bytes = src0->get_ne(0) * sizeof(data_type0); @@ -99,7 +110,8 @@ void mul_mat_impl(hexagon::tensor * src0, hexagon::tensor * src1, hexagon::tenso uint8_t * dst_ptr = dst->get_write_buffer(); if (!dst_ptr) { - DEVICE_LOG_ERROR("mul_mat_impl: dst_ptr is not writable, tensor: %p, type: %s\n", (void *) dst, + DEVICE_LOG_ERROR("mul_mat_impl: dst_ptr is not writable, tensor: %p, type: %s\n", + (void *) dst, hexagon::get_type_name(dst->get_type())); return; } @@ -114,16 +126,17 @@ void mul_mat_impl(hexagon::tensor * src0, hexagon::tensor * src1, hexagon::tenso auto * dst_plane = dst_ptr + i3 * dst->get_nb(3) + i2 * dst->get_nb(2); for (int64_t col_idx = start_end_element.first; col_idx < start_end_element.second; col_idx += src0_plane_slice_row_count) { + const uint8_t * src0_plane = + src0_ptr + i3 / r03 * src0->get_nb(3) + i2 / r02 * src0->get_nb(2) + col_idx * src0->get_nb(1); + hexagon::l2fetch_row(src0_plane, src0->get_nb(1)); + const int64_t actual_row_count = std::min(src0_plane_slice_row_count, start_end_element.second - col_idx); // number of rows in this slice - const uint8_t * src0_plane = - src0_ptr + i3 / r03 * src0->get_nb(3) + i2 / r02 * src0->get_nb(2) + col_idx * src0->get_nb(1); if constexpr (_ShouldCacheSrc0) { if (last_cached_plane_ptr != src0_plane) { DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 0, dequant); - hexagon::l2fetch_row(src0_plane, src0->get_nb(1)); for (int64_t ir = 0; ir < actual_row_count; ir++) { auto * src0_row = src0_plane + ir * src0->get_nb(1); if (ir + 1 < actual_row_count) { @@ -131,7 +144,8 @@ void mul_mat_impl(hexagon::tensor * src0, hexagon::tensor * src1, hexagon::tenso } auto * cached_row_ptr = src0_plane_cache_ptr + ir * src0_actual_row_size; - 
dequantize_row_func(src0_row, reinterpret_cast(cached_row_ptr), + dequantize_row_func(src0_row, + reinterpret_cast(cached_row_ptr), src0->get_ne(0)); } @@ -158,12 +172,8 @@ void mul_mat_impl(hexagon::tensor * src0, hexagon::tensor * src1, hexagon::tenso // TODO: figure dst how to handle a entire row auto res0 = _DotFunc(reinterpret_cast(src0_row), - reinterpret_cast(src1_row), (size_t) src0->get_ne(0)); - - { - DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 2, store); - dst_row[i0] = convert_vector::convert(res0); - } + reinterpret_cast(src1_row), + (size_t) src0->get_ne(0)); if (should_fetch_src0_row && i0 + 2 < actual_row_count) { hexagon::l2fetch_row(src0_row + src0_actual_row_size + src0_actual_row_size, valid_row0_bytes); @@ -171,10 +181,12 @@ void mul_mat_impl(hexagon::tensor * src0, hexagon::tensor * src1, hexagon::tenso // TODO: figure dst how to handle a entire row auto res1 = _DotFunc(reinterpret_cast(src0_row + src0_actual_row_size), - reinterpret_cast(src1_row), (size_t) src0->get_ne(0)); + reinterpret_cast(src1_row), + (size_t) src0->get_ne(0)); { DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 2, store); + dst_row[i0] = convert_vector::convert(res0); dst_row[i0 + 1] = convert_vector::convert(res1); } } @@ -186,7 +198,8 @@ void mul_mat_impl(hexagon::tensor * src0, hexagon::tensor * src1, hexagon::tenso if (i0 < actual_row_count) { auto * src0_row = src0_plane + i0 * src0_actual_row_size; auto res = _DotFunc(reinterpret_cast(src0_row), - reinterpret_cast(src1_row), (size_t) src0->get_ne(0)); + reinterpret_cast(src1_row), + (size_t) src0->get_ne(0)); DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 2, store); dst_row[i0] = convert_vector::convert(res); } @@ -197,19 +210,194 @@ void mul_mat_impl(hexagon::tensor * src0, hexagon::tensor * src1, hexagon::tenso dst->release_write_buffer(); // mark the output tensor as modified } -bool is_row_size_cacheable(const npu_device_tensor_spec & src) { - const auto & type_traits = hexagon::get_type_traits(src.type); - if (type_traits.to_float == nullptr) { - DEVICE_LOG_DEBUG("[MUL_MAT]src.type(%s) cannot be cached, to_float is null\n", - hexagon::get_type_name(src.type)); +template +void mul_mat_gemv_impl(hexagon::tensor * src0, + hexagon::tensor * src1, + hexagon::tensor * dst, + hexagon::compute_params * params) { + using data_type0 = typename get_data_type::data_type0; + using data_type1 = typename get_data_type::data_type1; + + const auto src0_actual_row_size = hexagon::get_dequantized_row_size(src0); + auto * dequantize_row_func = hexagon::get_type_traits(src0->get_type()).to_float; + if (_ShouldCacheSrc0 && dequantize_row_func == nullptr) { + DEVICE_LOG_ERROR("Unsupported quantized src0 type: %d, dequantize_row_func is null\n", src0->get_type()); + return; + } + + auto start_end_element = std::pair{ 0, dst->get_ne(0) }; + if (dst->get_ne(0) >= params->get_thread_count()) { + start_end_element = params->get_work_slice(dst->get_ne(0)); + } else { + DEVICE_LOG_ERROR("Unsupported src1 tensor shape for gemv: %s, ne: %ldx%ldx%ldx%ld\n", + hexagon::get_type_name(src1->get_type()), + src1->get_ne(0), + src1->get_ne(1), + src1->get_ne(2), + src1->get_ne(3)); + return; + } + + if (start_end_element.second <= start_end_element.first) { + DEVICE_LOG_DEBUG( + "mul_mat_impl: no work to do, start_end_plane: [0, 1), start_end_row: [0, 1), " + "start_end_element: [%ld, %ld)\n", + start_end_element.first, + start_end_element.second); + return; + } + + // cache the src0 plane in VTCM + size_t 
src0_plane_slice_row_count = start_end_element.second - start_end_element.first; + size_t src0_plane_cache_size = 0; + uint8_t * src0_plane_cache_ptr = nullptr; + const auto src1_actual_row_size = hexagon::get_aligned_size(src1->get_nb(1)); + uint8_t * src1_row_cache_ptr = nullptr; + if constexpr (_ShouldCacheSrc0) { + src0_plane_slice_row_count = std::min( + (params->get_vtcm_quota_size() - src1_actual_row_size) / src0_actual_row_size, src0_plane_slice_row_count); + src0_plane_cache_size = src0_actual_row_size * src0_plane_slice_row_count; + src0_plane_cache_ptr = params->get_vtcm_cache(src0_plane_cache_size + src1_actual_row_size); + if (src0_plane_cache_ptr == nullptr) { + DEVICE_LOG_ERROR( + "mul_mat_impl: failed to get VTCM cache for src0, size: %zu, src0_plane_slice_row_count: %zu, " + "src0_actual_row_size: %zu, will fallback to mem cache\n", + src0_plane_cache_size, + src0_plane_slice_row_count, + src0_actual_row_size); + return; + } + + src1_row_cache_ptr = src0_plane_cache_ptr + src0_plane_cache_size; + } else { + src1_row_cache_ptr = params->get_vtcm_cache(src1_actual_row_size); + if (src1_row_cache_ptr == nullptr) { + DEVICE_LOG_ERROR("mul_mat_impl: failed to get VTCM cache for src1, size: %zu\n", src1_actual_row_size); + return; + } + } + + DEVICE_LOG_DEBUG( + "mul_mat_impl src0_actual_row_size: %zu, src0_plane_slice_row_count: %zu, is_quantized: %d, vtcm_mem: " + "%p(%zu)\n", + src0_actual_row_size, + src0_plane_slice_row_count, + _ShouldCacheSrc0, + (void *) src0_plane_cache_ptr, + src0_plane_cache_size); + + const size_t valid_row0_bytes = src0->get_ne(0) * sizeof(data_type0); + DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_WITH_MULTI_SUB_PROC(dst, params->get_thread_index(), mul_mat); + + uint8_t * dst_ptr = dst->get_write_buffer(); + if (!dst_ptr) { + DEVICE_LOG_ERROR("mul_mat_impl: dst_ptr is not writable, tensor: %p, type: %s\n", + (void *) dst, + hexagon::get_type_name(dst->get_type())); + return; + } + + constexpr bool should_fetch_src0_row = !_ShouldCacheSrc0; + const uint8_t * src0_ptr = src0->get_read_buffer(); + const uint8_t * src1_ptr = src1->get_read_buffer(); + + { + memcpy(src1_row_cache_ptr, src1_ptr, src1->get_ne(0) * sizeof(data_type1)); + src1_ptr = src1_row_cache_ptr; + } + + { + for (int64_t col_idx = start_end_element.first; col_idx < start_end_element.second; + col_idx += src0_plane_slice_row_count) { + const uint8_t * src0_plane = src0_ptr + col_idx * src0->get_nb(1); + hexagon::l2fetch_row(src0_plane, src0->get_nb(1)); + + const int64_t actual_row_count = + std::min(src0_plane_slice_row_count, + start_end_element.second - col_idx); // number of rows in this slice + if constexpr (_ShouldCacheSrc0) { + DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 0, dequant); + + for (int64_t ir = 0; ir < actual_row_count; ir++) { + auto * src0_row = src0_plane + ir * src0->get_nb(1); + if (ir + 1 < actual_row_count) { + hexagon::l2fetch_row(src0_row + src0->get_nb(1), src0->get_nb(1)); + } + + auto * cached_row_ptr = src0_plane_cache_ptr + ir * src0_actual_row_size; + dequantize_row_func( + src0_row, reinterpret_cast(cached_row_ptr), src0->get_ne(0)); + } + + src0_plane = src0_plane_cache_ptr; + } + + { + DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 1, vec_dot); + auto * dst_row = reinterpret_cast(dst_ptr) + col_idx; + int64_t i0 = 0; + for (; i0 + 1 < actual_row_count; i0 += 2) { + auto * src0_row = src0_plane + i0 * src0_actual_row_size; + if constexpr (should_fetch_src0_row) { + hexagon::l2fetch_row(src0_row + src0_actual_row_size, 
valid_row0_bytes); + } + + // TODO: figure dst how to handle a entire row + auto res0 = _DotFunc(reinterpret_cast(src0_row), + reinterpret_cast(src1_ptr), + (size_t) src0->get_ne(0)); + + if (should_fetch_src0_row && i0 + 2 < actual_row_count) { + hexagon::l2fetch_row(src0_row + src0_actual_row_size + src0_actual_row_size, valid_row0_bytes); + } + + // TODO: figure dst how to handle a entire row + auto res1 = _DotFunc(reinterpret_cast(src0_row + src0_actual_row_size), + reinterpret_cast(src1_ptr), + (size_t) src0->get_ne(0)); + + { + DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 2, store); + dst_row[i0] = convert_vector::convert(res0); + dst_row[i0 + 1] = convert_vector::convert(res1); + } + } + + if (i0 < actual_row_count) { + auto * src0_row = src0_plane + i0 * src0_actual_row_size; + auto res = _DotFunc(reinterpret_cast(src0_row), + reinterpret_cast(src1_ptr), + (size_t) src0->get_ne(0)); + DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 2, store); + dst_row[i0] = convert_vector::convert(res); + } + } + } + } + + dst->release_write_buffer(); // mark the output tensor as modified +} + +bool is_src_cacheable(const npu_device_tensor_spec & src0, const npu_device_tensor_spec & src1) { + const auto & src0_type_traits = hexagon::get_type_traits(src0.type); + if (src0_type_traits.to_float == nullptr) { + DEVICE_LOG_DEBUG("[MUL_MAT]src0.type(%s) cannot be cached, to_float is null\n", + hexagon::get_type_name(src0.type)); return false; } - const size_t type_size = type_traits.is_quantized ? sizeof(hexagon::dequant_target_type) : type_traits.type_size; const auto vtcm_thread_quota_size = hexagon::default_thread_pool::get_per_thread_vtcm_quota(); - if (src.ne[0] * type_size > vtcm_thread_quota_size) { - DEVICE_LOG_DEBUG("[MUL_MAT]src.type(%s) ne[0] is too large: %ld, vtcm_thread_quota_size: %zu\n", - hexagon::get_type_name(src.type), (long) src.ne[0], vtcm_thread_quota_size); + const size_t src0_type_size = + src0_type_traits.is_quantized ? sizeof(hexagon::dequant_output_type) : src0_type_traits.type_size; + const auto & src1_type_traits = hexagon::get_type_traits(src1.type); + const bool is_gemv = src1.ne[1] == 1 && src1.ne[2] == 1 && src1.ne[3] == 1; + size_t min_cache_size = is_gemv ? 
(src1.ne[0] * src1_type_traits.type_size) : 0; + min_cache_size += src0.ne[0] * src0_type_size; + if (min_cache_size > vtcm_thread_quota_size) { + DEVICE_LOG_DEBUG("[MUL_MAT]src0.type(%s) min_cache_size is too large: %ld, vtcm_thread_quota_size: %zu\n", + hexagon::get_type_name(src0.type), + (long) min_cache_size, + vtcm_thread_quota_size); return false; } @@ -219,36 +407,41 @@ bool is_row_size_cacheable(const npu_device_tensor_spec & src) { bool is_quantized_mul_mat_supported(const npu_device_tensor_spec & src0, const npu_device_tensor_spec & src1) { if (src1.type != NPU_DATA_TYPE_F32 && src1.type != NPU_DATA_TYPE_F16) { DEVICE_LOG_DEBUG("[MUL_MAT]src0.type(%s) and src1.type(%s) mismatch and src1 is not F32\n", - hexagon::get_type_name(src0.type), hexagon::get_type_name(src1.type)); + hexagon::get_type_name(src0.type), + hexagon::get_type_name(src1.type)); return false; } const auto type_traits = hexagon::get_type_traits(src0.type); if (!type_traits.is_quantized || type_traits.to_float == nullptr) { DEVICE_LOG_DEBUG("[MUL_MAT]src0.type(%s) and src1.type(%s) mismatch and src0 is not quantized\n", - hexagon::get_type_name(src0.type), hexagon::get_type_name(src1.type)); + hexagon::get_type_name(src0.type), + hexagon::get_type_name(src1.type)); return false; } if (src0.ne[0] % type_traits.blck_size) { - DEVICE_LOG_DEBUG("[MUL_MAT]src0.type(%s) ne[0] is not aligned: %ld\n", hexagon::get_type_name(src0.type), - (long) src0.ne[0]); + DEVICE_LOG_DEBUG( + "[MUL_MAT]src0.type(%s) ne[0] is not aligned: %ld\n", hexagon::get_type_name(src0.type), (long) src0.ne[0]); return false; } - if (!is_row_size_cacheable(src0)) { + if (!is_src_cacheable(src0, src1)) { return false; } DEVICE_LOG_DEBUG("[MUL_MAT]supported quantized src0.type(%s) and src1.type(%s)\n", - hexagon::get_type_name(src0.type), hexagon::get_type_name(src1.type)); + hexagon::get_type_name(src0.type), + hexagon::get_type_name(src1.type)); return true; } -bool is_mul_mat_f16_f32_src_tensors_aligned(hexagon::tensor * src0, hexagon::tensor * src1, bool is_src0_quantized) { - const auto * src1_ptr = src1->get_read_buffer_as(); - const auto * src0_ptr = - is_src0_quantized ? nullptr : src0->get_read_buffer_as(); // skip src0 for quantized tensors +bool is_mul_mat_f16_f32_src_tensors_aligned(hexagon::tensor * src0, + hexagon::tensor * src1, + bool is_src0_cached, + bool is_src1_cached) { + const auto * src1_ptr = is_src1_cached ? nullptr : src1->get_read_buffer_as(); + const auto * src0_ptr = is_src0_cached ? 
nullptr : src0->get_read_buffer_as(); if (!hexagon::is_f16_f32_dot_product_aligned(src0_ptr, src1_ptr, src0->get_ne(0))) { DEVICE_LOG_DEBUG("[MUL_MAT]src_tensors_unaligned: ne[0]: %ld\n", (long) src0->get_ne(0)); @@ -305,35 +498,49 @@ bool is_mul_mat_f32_f32_src_tensors_aligned(hexagon::tensor * src0, hexagon::ten return true; } -typedef void (*mul_mat_func_type)(hexagon::tensor * src0, hexagon::tensor * src1, hexagon::tensor * dst, +typedef void (*mul_mat_func_type)(hexagon::tensor * src0, + hexagon::tensor * src1, + hexagon::tensor * dst, hexagon::compute_params * params); -constexpr const mul_mat_func_type kMulMatF32F32CachedFuncs[2] = { +constexpr const size_t kMulMatGemvBaseIndex = 2; + +constexpr const mul_mat_func_type kMulMatF32F32CachedFuncs[4] = { // quantized and non-quantized - mul_mat_impl, // F32 * F32 quantized unaligned - mul_mat_impl, // F32 * F32 quantized aligned + mul_mat_impl, // F32 * F32 quantized unaligned + mul_mat_impl, // F32 * F32 quantized aligned + mul_mat_gemv_impl, // F32 * F32 quantized gemv + mul_mat_gemv_impl, // F32 * F32 quantized gemv }; -constexpr const mul_mat_func_type kMulMatF32F32Funcs[2] = { +constexpr const mul_mat_func_type kMulMatF32F32Funcs[4] = { // quantized and non-quantized - mul_mat_impl, // F32 * F32 quantized unaligned - mul_mat_impl, // F32 * F32 quantized aligned + mul_mat_impl, // F32 * F32 quantized unaligned + mul_mat_impl, // F32 * F32 quantized aligned + mul_mat_gemv_impl, // F32 * F32 quantized gemv + mul_mat_gemv_impl, // F32 * F32 quantized gemv }; -constexpr const mul_mat_func_type kMulMatF16CachedFuncs[2] = { - mul_mat_impl, // F16 * F16 quantized unaligned - mul_mat_impl, // F16 * F16 quantized aligned -}; - -constexpr const mul_mat_func_type kMulMatF16Funcs[2] = { - mul_mat_impl, // F16 * F16 quantized unaligned - mul_mat_impl, // F16 * F16 quantized aligned -}; - -constexpr const mul_mat_func_type kMulMatF16F32Funcs[2] = { +constexpr const mul_mat_func_type kMulMatF16F32Funcs[4] = { // quantized and non-quantized - mul_mat_impl, // F32 * F32 quantized unaligned - mul_mat_impl, // F32 * F32 quantized aligned + mul_mat_impl, // F32 * F32 quantized unaligned + mul_mat_impl, // F32 * F32 quantized aligned + mul_mat_gemv_impl, // F32 * F32 quantized unaligned + mul_mat_gemv_impl, // F32 * F32 quantized aligned +}; + +constexpr const mul_mat_func_type kMulMatF16CachedFuncs[4] = { + mul_mat_impl, // F16 * F16 quantized unaligned + mul_mat_impl, // F16 * F16 quantized aligned + mul_mat_gemv_impl, // F16 * F16 quantized gemv + mul_mat_gemv_impl, // F16 * F16 quantized gemv +}; + +constexpr const mul_mat_func_type kMulMatF16Funcs[4] = { + mul_mat_impl, // F16 * F16 quantized unaligned + mul_mat_impl, // F16 * F16 quantized aligned + mul_mat_gemv_impl, // F16 * F16 quantized gemv + mul_mat_gemv_impl, // F16 * F16 quantized gemv }; } // namespace @@ -342,9 +549,9 @@ namespace hexagon { bool mul_mat_f32(hexagon::tensor * out, compute_params * params) { static_assert(DEVICE_TENSOR_MAX_DIMS == 4, "mul_mat_f32 requires max dims 4"); - static_assert(std::is_same::value || - std::is_same::value, - "dequant_target_type must be float or npu_device_fp16_t"); + static_assert(std::is_same::value || + std::is_same::value, + "dequant_output_type must be float or npu_device_fp16_t"); if (!out) { return false; @@ -358,24 +565,28 @@ bool mul_mat_f32(hexagon::tensor * out, compute_params * params) { const bool is_src0_quantized = is_quantized_type(src0->get_type()); const bool should_cache_src0 = is_src0_quantized || src1->get_ne(1) > 1; + const 
bool is_gemv = src1->get_ne(1) == 1 && src1->get_ne(2) == 1 && src1->get_ne(3) == 1; + const auto base_index = is_gemv ? kMulMatGemvBaseIndex : 0; switch (src1->get_type()) { case NPU_DATA_TYPE_F32: if (is_src0_quantized || src0->get_type() == NPU_DATA_TYPE_F16) { - kMulMatF16F32Funcs[is_mul_mat_f16_f32_src_tensors_aligned(src0, src1, is_src0_quantized)](src0, src1, - out, params); + kMulMatF16F32Funcs[is_mul_mat_f16_f32_src_tensors_aligned(src0, src1, is_src0_quantized, is_gemv) + + base_index](src0, src1, out, params); } else if (should_cache_src0) { - kMulMatF32F32CachedFuncs[is_mul_mat_f32_f32_src_tensors_aligned(src0, src1)](src0, src1, out, params); + kMulMatF32F32CachedFuncs[is_mul_mat_f32_f32_src_tensors_aligned(src0, src1) + base_index]( + src0, src1, out, params); } else { - kMulMatF32F32Funcs[is_mul_mat_f32_f32_src_tensors_aligned(src0, src1)](src0, src1, out, params); + kMulMatF32F32Funcs[is_mul_mat_f32_f32_src_tensors_aligned(src0, src1) + base_index]( + src0, src1, out, params); } return true; case NPU_DATA_TYPE_F16: if (should_cache_src0) { - kMulMatF16CachedFuncs[is_mul_mat_f16_f16_src_tensors_aligned(src0, src1, is_src0_quantized)]( - src0, src1, out, params); + kMulMatF16CachedFuncs[is_mul_mat_f16_f16_src_tensors_aligned(src0, src1, is_src0_quantized) + + base_index](src0, src1, out, params); } else { - kMulMatF16Funcs[is_mul_mat_f16_f16_src_tensors_aligned(src0, src1, is_src0_quantized)](src0, src1, out, - params); + kMulMatF16Funcs[is_mul_mat_f16_f16_src_tensors_aligned(src0, src1, is_src0_quantized) + base_index]( + src0, src1, out, params); } return true; default: @@ -386,8 +597,11 @@ bool mul_mat_f32(hexagon::tensor * out, compute_params * params) { return false; } -bool is_mul_mat_supported(npu_device_tensor_op op, const npu_device_tensor_spec * dst, - const npu_device_tensor_spec * srcs, size_t src_len) { +bool is_mul_mat_supported(const npu_device_tensor_op_spec * op_spec, + const npu_device_tensor_spec * dst, + const npu_device_tensor_spec * srcs, + size_t src_len) { + const auto op = op_spec->op; if (op != NPU_OP_MUL_MAT) { DEVICE_LOG_DEBUG("op is not MUL_MAT: %d\n", op); return false; @@ -408,7 +622,9 @@ bool is_mul_mat_supported(npu_device_tensor_op op, const npu_device_tensor_spec if (src0.type != src1.type) { if (src1.type == NPU_DATA_TYPE_F32 && src0.type == NPU_DATA_TYPE_F16) { DEVICE_LOG_DEBUG("[%s]src0.type(%s) and src1.type(%s) mismatch, but src0 is F16 and src1 is F32\n", - op_get_name(op), get_type_name(src0.type), get_type_name(src1.type)); + op_get_name(op), + get_type_name(src0.type), + get_type_name(src1.type)); return true; // F16 * F32 is supported } @@ -418,26 +634,40 @@ bool is_mul_mat_supported(npu_device_tensor_op op, const npu_device_tensor_spec } #else DEVICE_LOG_DEBUG("[%s]src0.type(%s) and src1.type(%s) mismatch and quantized tensors are not supported\n", - op_get_name(op), get_type_name(src0.type), get_type_name(src1.type)); + op_get_name(op), + get_type_name(src0.type), + get_type_name(src1.type)); return false; #endif } if (src0.ne[0] != src1.ne[0] || src0.ne[1] != dst->ne[0]) { - DEVICE_LOG_DEBUG("[%s]src0 and src1 cannot multiply: %ldx%ld vs %ldx%ld\n", op_get_name(op), (long) src0.ne[0], - (long) src0.ne[1], (long) src1.ne[0], (long) src1.ne[1]); + DEVICE_LOG_DEBUG("[%s]src0 and src1 cannot multiply: %ldx%ld vs %ldx%ld\n", + op_get_name(op), + (long) src0.ne[0], + (long) src0.ne[1], + (long) src1.ne[0], + (long) src1.ne[1]); return false; } if (src1.ne[1] != dst->ne[1] || src1.ne[2] != dst->ne[2] || src1.ne[3] != dst->ne[3]) { - 
DEVICE_LOG_DEBUG("[%s]src1 and dst dimensions not match: %ldx%ld vs %ldx%ld\n", op_get_name(op), - (long) src1.ne[2], (long) src1.ne[3], (long) dst->ne[2], (long) dst->ne[3]); + DEVICE_LOG_DEBUG("[%s]src1 and dst dimensions not match: %ldx%ld vs %ldx%ld\n", + op_get_name(op), + (long) src1.ne[2], + (long) src1.ne[3], + (long) dst->ne[2], + (long) dst->ne[3]); return false; } if (src1.ne[2] % src0.ne[2] || src1.ne[3] % src0.ne[3]) { - DEVICE_LOG_DEBUG("[%s]src0 cannot broadcast to src1: %ldx%ld vs %ldx%ld\n", op_get_name(op), (long) src0.ne[2], - (long) src0.ne[3], (long) src1.ne[2], (long) src1.ne[3]); + DEVICE_LOG_DEBUG("[%s]src0 cannot broadcast to src1: %ldx%ld vs %ldx%ld\n", + op_get_name(op), + (long) src0.ne[2], + (long) src0.ne[3], + (long) src1.ne[2], + (long) src1.ne[3]); return false; } diff --git a/ggml/src/ggml-qnn/npu/device/op_mul_mat.hpp b/ggml/src/ggml-qnn/npu/device/op_mul_mat.hpp index 434406f930..1b8350627d 100644 --- a/ggml/src/ggml-qnn/npu/device/op_mul_mat.hpp +++ b/ggml/src/ggml-qnn/npu/device/op_mul_mat.hpp @@ -1,14 +1,16 @@ #pragma once -#include - #include "op_types.hpp" #include "tensor.hpp" +#include + namespace hexagon { bool mul_mat_f32(tensor * out, compute_params * params); -bool is_mul_mat_supported(npu_device_tensor_op op, const npu_device_tensor_spec * dst, - const npu_device_tensor_spec * srcs, size_t src_len); +bool is_mul_mat_supported(const npu_device_tensor_op_spec * op_spec, + const npu_device_tensor_spec * dst, + const npu_device_tensor_spec * srcs, + size_t src_len); } // namespace hexagon diff --git a/ggml/src/ggml-qnn/npu/device/op_rope.cpp b/ggml/src/ggml-qnn/npu/device/op_rope.cpp index 34bd0409db..27a35394c5 100644 --- a/ggml/src/ggml-qnn/npu/device/op_rope.cpp +++ b/ggml/src/ggml-qnn/npu/device/op_rope.cpp @@ -29,8 +29,14 @@ float rope_yarn_ramp(const float low, const float high, const int i0) { // YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. 
-void rope_yarn(float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale, - float * cos_theta, float * sin_theta) { +void rope_yarn(float theta_extrap, + float freq_scale, + float corr_dims[2], + int64_t i0, + float ext_factor, + float mscale, + float * cos_theta, + float * sin_theta) { // Get n-d rotational scaling corrected for extrapolation float theta_interp = freq_scale * theta_extrap; float theta = theta_interp; @@ -45,8 +51,16 @@ void rope_yarn(float theta_extrap, float freq_scale, float corr_dims[2], int64_t *sin_theta = sinf(theta) * mscale; } -void rope_cache_init(float theta_base, float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, - float ext_factor, float mscale, float * cache, float sin_sign, float theta_scale) { +void rope_cache_init(float theta_base, + float freq_scale, + const float * freq_factors, + float corr_dims[2], + int64_t ne0, + float ext_factor, + float mscale, + float * cache, + float sin_sign, + float theta_scale) { // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py float theta = theta_base; for (int64_t i0 = 0; i0 < ne0; i0 += 2) { @@ -58,10 +72,21 @@ void rope_cache_init(float theta_base, float freq_scale, const float * freq_fact } } -void mrope_cache_init(float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, - const int sections[4], bool indep_sects, float freq_scale, const float * freq_factors, - float corr_dims[2], int64_t ne0, float ext_factor, float mscale, float * cache, float sin_sign, - float theta_scale) { +void mrope_cache_init(float theta_base_t, + float theta_base_h, + float theta_base_w, + float theta_base_e, + const int sections[4], + bool indep_sects, + float freq_scale, + const float * freq_factors, + float corr_dims[2], + int64_t ne0, + float ext_factor, + float mscale, + float * cache, + float sin_sign, + float theta_scale) { // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py float theta_t = theta_base_t; float theta_h = theta_base_h; @@ -181,16 +206,37 @@ bool rope_impl(hexagon::tensor * out, hexagon::compute_params * params) { if constexpr (!_IsMrope) { DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(rope, 0, cache); const int64_t p = pos[i2]; - rope_cache_init(p, freq_scale, freq_factors, corr_dims, out->get_ne(0), ext_factor, attn_factor, cache, - sin_sign, theta_scale); + rope_cache_init(p, + freq_scale, + freq_factors, + corr_dims, + out->get_ne(0), + ext_factor, + attn_factor, + cache, + sin_sign, + theta_scale); } else { DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(rope, 0, cache); const int64_t p_t = pos[i2]; const int64_t p_h = pos[i2 + out->get_ne(2)]; const int64_t p_w = pos[i2 + out->get_ne(2) * 2]; const int64_t p_e = pos[i2 + out->get_ne(2) * 3]; - mrope_cache_init(p_t, p_h, p_w, p_e, sections, _IsVision, freq_scale, freq_factors, corr_dims, - out->get_ne(0), ext_factor, attn_factor, cache, sin_sign, theta_scale); + mrope_cache_init(p_t, + p_h, + p_w, + p_e, + sections, + _IsVision, + freq_scale, + freq_factors, + corr_dims, + out->get_ne(0), + ext_factor, + attn_factor, + cache, + sin_sign, + theta_scale); } DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(rope, 1, loop); @@ -316,8 +362,11 @@ bool rope_f32(tensor * out, compute_params * params) { return kRopeImplFuncs[impl_index](out, params); } -bool is_rope_supported(npu_device_tensor_op op, const npu_device_tensor_spec * dst, const npu_device_tensor_spec * 
srcs, - size_t src_len) { +bool is_rope_supported(const npu_device_tensor_op_spec * op_spec, + const npu_device_tensor_spec * dst, + const npu_device_tensor_spec * srcs, + size_t src_len) { + const auto op = op_spec->op; if (op != NPU_OP_ROPE) { DEVICE_LOG_DEBUG("[%s]op is not ROPE\n", op_get_name(op)); return false; @@ -336,8 +385,10 @@ bool is_rope_supported(npu_device_tensor_op op, const npu_device_tensor_spec * d const auto & src0 = srcs[0]; if (src0.type != dst->type) { - DEVICE_LOG_DEBUG("[%s]src0 type is not the same as dst type: %s vs %s\n", op_get_name(op), - get_type_name(src0.type), get_type_name(dst->type)); + DEVICE_LOG_DEBUG("[%s]src0 type is not the same as dst type: %s vs %s\n", + op_get_name(op), + get_type_name(src0.type), + get_type_name(dst->type)); return false; // unsupported src0 type } diff --git a/ggml/src/ggml-qnn/npu/device/op_rope.hpp b/ggml/src/ggml-qnn/npu/device/op_rope.hpp index f2be465ae1..bba3034088 100644 --- a/ggml/src/ggml-qnn/npu/device/op_rope.hpp +++ b/ggml/src/ggml-qnn/npu/device/op_rope.hpp @@ -5,7 +5,9 @@ namespace hexagon { bool rope_f32(tensor * out, compute_params * params); -bool is_rope_supported(npu_device_tensor_op op, const npu_device_tensor_spec * dst, const npu_device_tensor_spec * srcs, - size_t src_len); +bool is_rope_supported(const npu_device_tensor_op_spec * op_spec, + const npu_device_tensor_spec * dst, + const npu_device_tensor_spec * srcs, + size_t src_len); } // namespace hexagon diff --git a/ggml/src/ggml-qnn/npu/device/op_types.hpp b/ggml/src/ggml-qnn/npu/device/op_types.hpp index bad83ad95e..6d2ee4c4b7 100644 --- a/ggml/src/ggml-qnn/npu/device/op_types.hpp +++ b/ggml/src/ggml-qnn/npu/device/op_types.hpp @@ -1,5 +1,11 @@ #pragma once +#include "hexagon_npu.h" +#include "tensor.hpp" +#include "thread_pool.hpp" +#include "util.hpp" +#include "vec_ops.hpp" + #include #include @@ -7,12 +13,6 @@ #include #include -#include "hexagon_npu.h" -#include "tensor.hpp" -#include "thread_pool.hpp" -#include "util.hpp" -#include "vec_ops.hpp" - namespace hexagon { inline constexpr std::pair get_thread_work_slice(int64_t total, size_t tidx, size_t tcnt) { @@ -44,8 +44,6 @@ struct compute_params { uint8_t * get_vtcm_cache(size_t size) { return thread_params->get_vtcm_cache(size); } - uint8_t * get_mem_cache(size_t size) { return thread_params->get_mem_cache(size); } - std::pair get_work_slice(int64_t total) const { return get_thread_work_slice(total, thread_params->tidx, thread_params->tcnt); } @@ -58,7 +56,9 @@ struct compute_params { }; typedef bool (*compute_func_type)(tensor * dst, compute_params * params); -typedef bool (*op_is_supported_func_type)(npu_device_tensor_op op, const npu_device_tensor_spec * dst, - const npu_device_tensor_spec * srcs, size_t src_len); +typedef bool (*op_is_supported_func_type)(const npu_device_tensor_op_spec * op_spec, + const npu_device_tensor_spec * dst, + const npu_device_tensor_spec * srcs, + size_t src_len); } // namespace hexagon diff --git a/ggml/src/ggml-qnn/npu/device/tensor.hpp b/ggml/src/ggml-qnn/npu/device/tensor.hpp index c6a7fb1077..a6feefe2ec 100644 --- a/ggml/src/ggml-qnn/npu/device/tensor.hpp +++ b/ggml/src/ggml-qnn/npu/device/tensor.hpp @@ -1,13 +1,13 @@ #pragma once +#include "hexagon_npu.h" +#include "util.hpp" + #include #include #include -#include "hexagon_npu.h" -#include "util.hpp" - namespace hexagon { constexpr const size_t kMaxTensorSrc = DEVICE_TENSOR_MAX_SRC; @@ -26,8 +26,15 @@ class tensor { _data = static_cast(mmap_address); DEVICE_LOG_INFO("tensor(%p[%ldx%ldx%ldx%ld]), fd: %d, 
offset: %zu, mmap_addr: %p, phy_addr: 0x%lx\n", - (void *) this, (long) _info.ne[0], (long) _info.ne[1], (long) _info.ne[2], (long) _info.ne[3], - _info.buffer_fd, _info.offset, (void *) mmap_address, phy_address); + (void *) this, + (long) _info.ne[0], + (long) _info.ne[1], + (long) _info.ne[2], + (long) _info.ne[3], + _info.buffer_fd, + _info.offset, + (void *) mmap_address, + phy_address); } ~tensor() noexcept { @@ -41,15 +48,17 @@ class tensor { void flush() const { if (_data) { - qurt_mem_cache_clean((qurt_addr_t) (_data + _info.offset), (qurt_size_t) _info.size, QURT_MEM_CACHE_FLUSH, - QURT_MEM_DCACHE); + qurt_mem_cache_clean( + (qurt_addr_t) (_data + _info.offset), (qurt_size_t) _info.size, QURT_MEM_CACHE_FLUSH, QURT_MEM_DCACHE); } } void invalidate() const { if (_data) { - qurt_mem_cache_clean((qurt_addr_t) (_data + _info.offset), (qurt_size_t) _info.size, - QURT_MEM_CACHE_FLUSH_INVALIDATE_ALL, QURT_MEM_DCACHE); + qurt_mem_cache_clean((qurt_addr_t) (_data + _info.offset), + (qurt_size_t) _info.size, + QURT_MEM_CACHE_FLUSH_INVALIDATE_ALL, + QURT_MEM_DCACHE); } } diff --git a/ggml/src/ggml-qnn/npu/device/thread_pool.hpp b/ggml/src/ggml-qnn/npu/device/thread_pool.hpp index 455d4eec30..902bdcfc56 100644 --- a/ggml/src/ggml-qnn/npu/device/thread_pool.hpp +++ b/ggml/src/ggml-qnn/npu/device/thread_pool.hpp @@ -1,5 +1,8 @@ #pragma once +#include "util.hpp" +#include "vtcm_mem.hpp" + #include #include @@ -8,27 +11,26 @@ #include #include -#include "util.hpp" -#include "vtcm_mem.hpp" - namespace hexagon { -constexpr const size_t kMaxThreadCount = 4; -constexpr const size_t kDefaultStackSize = 1024 * 32; // 32KB -constexpr const unsigned long long kThreadTaskPendingBit = 1; +constexpr const size_t kMaxThreadCount = 4; +constexpr const size_t kDefaultStackSize = 1024 * 64; // 64KB template class qurt_thread { public: typedef void (*qurt_thread_func_type)(qurt_thread * thread, void * arg); - explicit qurt_thread(const std::string & thread_name, qurt_thread_func_type thread_func, void * arg, - unsigned short priority) { + explicit qurt_thread(const std::string & thread_name, + qurt_thread_func_type thread_func, + void * arg, + unsigned short priority) { DEVICE_LOG_DEBUG("qurt_thread.create: %s", thread_name.c_str()); qurt_thread_attr_init(&_attributes); qurt_thread_attr_set_name(&_attributes, (char *) thread_name.c_str()); qurt_thread_attr_set_stack_addr(&_attributes, _stack); qurt_thread_attr_set_stack_size(&_attributes, _stack_size); qurt_thread_attr_set_priority(&_attributes, priority); + qurt_thread_attr_set_bus_priority(&_attributes, QURT_THREAD_BUS_PRIO_ENABLED); _func = thread_func; _arg = arg; @@ -94,9 +96,9 @@ template class thread_pool { thread_pool * pool = nullptr; size_t vtcm_quota_size; - std::unique_ptr vtcm_cache; - std::unique_ptr mem_cache; - size_t mem_cache_size = 0; + std::unique_ptr vtcm_cache; + + void init_vtcm_cache() { vtcm_cache = std::make_unique(vtcm_quota_size, false); } uint8_t * get_vtcm_cache(size_t size) { if (!vtcm_cache || vtcm_cache->get_size() < size) { @@ -111,25 +113,18 @@ template class thread_pool { return vtcm_cache->get_mem(); } - - uint8_t * get_mem_cache(size_t size) { - if (!mem_cache || mem_cache_size < size) { - mem_cache.reset(); // reset the cache to create a new one - mem_cache = std::make_unique(size + 256); - mem_cache_size = mem_cache ? 
size : 0; - } - - return mem_cache.get(); - } }; typedef void (*task_type)(thread_pool * pool, thread_params * param, void * arg); thread_pool() { for (size_t i = 0; i < kMaxThreadCount; ++i) { - _thread_params[i].tidx = i; - _thread_params[i].vtcm_quota_size = hexagon::vtcm_mem::get_avail_block_size() / kMaxThreadCount; - _thread_params[i].pool = this; + auto & thread_param = _thread_params[i]; + thread_param.tidx = i; + thread_param.vtcm_quota_size = hexagon::vtcm_mem::get_avail_block_size() / kMaxThreadCount; + thread_param.pool = this; + + thread_param.init_vtcm_cache(); } qurt_barrier_init(&_pending, kMaxSubThreadCount + 1); @@ -215,7 +210,8 @@ template class thread_pool { #ifdef GGML_HEXAGON_ENABLE_PERFORMANCE_TRACKING auto task_begin_cycles = pool._task_begin_cycles.load(); - DEVICE_LOG_WARN("[profiler]worker_thread, tidx: %zu, prepare: %lluus", param->tidx, + DEVICE_LOG_WARN("[profiler]worker_thread, tidx: %zu, prepare: %lluus", + param->tidx, static_cast( HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - task_begin_cycles))); #endif @@ -229,7 +225,8 @@ template class thread_pool { qurt_barrier_wait(&pool._completed); #ifdef GGML_HEXAGON_ENABLE_PERFORMANCE_TRACKING - DEVICE_LOG_WARN("[profiler]worker_thread, tidx: %zu, task_end: %lluus", param->tidx, + DEVICE_LOG_WARN("[profiler]worker_thread, tidx: %zu, task_end: %lluus", + param->tidx, static_cast( HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - task_begin_cycles))); #endif diff --git a/ggml/src/ggml-qnn/npu/device/type_traits.cpp b/ggml/src/ggml-qnn/npu/device/type_traits.cpp index 31377f6e55..3350167749 100644 --- a/ggml/src/ggml-qnn/npu/device/type_traits.cpp +++ b/ggml/src/ggml-qnn/npu/device/type_traits.cpp @@ -1,12 +1,12 @@ #include "type_traits.hpp" +#include "op_types.hpp" // TODO: remove this include +#include "vec_ops.hpp" + #include #include -#include "op_types.hpp" // TODO: remove this include -#include "vec_ops.hpp" - static_assert(sizeof(npu_device_block_q4_k) == 2 * sizeof(npu_device_fp16_t) + QUANT_K_SCALE_SIZE + QUANT_K_BLOCK_SIZE / 2, "wrong q4_K block size/padding"); @@ -82,8 +82,17 @@ inline int nearest_int(float fval) { return (i & 0x007fffff) - 0x00400000; } -float make_qkx2_quants(int n, int nmax, const float * x, const float * weights, uint8_t * L, float * the_min, - uint8_t * Laux, float rmin, float rdelta, int nstep, bool use_mad) { +float make_qkx2_quants(int n, + int nmax, + const float * x, + const float * weights, + uint8_t * L, + float * the_min, + uint8_t * Laux, + float rmin, + float rdelta, + int nstep, + bool use_mad) { float min = x[0]; float max = x[0]; float sum_w = weights[0]; @@ -315,13 +324,13 @@ void quantize_row_q4_K(const float * src, void * dst, size_t count) { } } -void dequantize_row_q8_0(const void * src, hexagon::dequant_target_type * dst, size_t count) { +void dequantize_row_q8_0(const void * src, hexagon::dequant_output_type * dst, size_t count) { constexpr const int qk = QUANT_BLOCK_SIZE; static_assert(QUANT_BLOCK_SIZE == hexagon::kBytesPerVector / sizeof(float)); const int nb = count / qk; const auto * src_ptr = reinterpret_cast(src); - auto * dst_ptr = ((hexagon::dequant_target_type *) dst); // TODO: opt for aligned access + auto * dst_ptr = ((hexagon::dequant_output_type *) dst); // TODO: opt for aligned access int i = 0; for (; i + 1 < nb; i += 2) { @@ -354,7 +363,7 @@ void dequantize_row_q8_0(const void * src, hexagon::dequant_target_type * dst, s } template -void dequantize_row_q4_0_impl(const void * src, hexagon::dequant_target_type * dst, size_t count) 
{ +void dequantize_row_q4_0_impl(const void * src, hexagon::dequant_output_type * dst, size_t count) { constexpr const int qk = QUANT_BLOCK_SIZE; static_assert(qk % 2 == 0, "qk must be even"); static_assert(QUANT_BLOCK_SIZE == hexagon::kBytesPerVector / sizeof(float)); @@ -364,7 +373,7 @@ void dequantize_row_q4_0_impl(const void * src, hexagon::dequant_target_type * d const auto * src_ptr = reinterpret_cast(src); const HVX_Vector mask = Q6_Vb_vsplat_R(0x0F); const HVX_Vector minus = Q6_Vb_vsplat_R(8); - hexagon::dequant_target_type * dst_ptr = dst; // TODO: opt for aligned access + hexagon::dequant_output_type * dst_ptr = dst; // TODO: opt for aligned access int i = 0; for (; i + 3 < nb; i += 4) { @@ -402,7 +411,7 @@ void dequantize_row_q4_0_impl(const void * src, hexagon::dequant_target_type * d reinterpret_cast(dst_ptr)[1] = q_hi; } - dst_ptr += hexagon::kBytesPerVector / sizeof(hexagon::dequant_target_type) * 2; + dst_ptr += hexagon::kBytesPerVector / sizeof(hexagon::dequant_output_type) * 2; } for (; i + 1 < nb; i += 2) { @@ -428,7 +437,7 @@ void dequantize_row_q4_0_impl(const void * src, hexagon::dequant_target_type * d *reinterpret_cast(dst_ptr) = q_lo; } - dst_ptr += hexagon::kBytesPerVector / sizeof(hexagon::dequant_target_type); + dst_ptr += hexagon::kBytesPerVector / sizeof(hexagon::dequant_output_type); } if (i < nb) { @@ -453,7 +462,7 @@ void dequantize_row_q4_0_impl(const void * src, hexagon::dequant_target_type * d } } -void dequantize_row_q4_0(const void * src, hexagon::dequant_target_type * dst, size_t count) { +void dequantize_row_q4_0(const void * src, hexagon::dequant_output_type * dst, size_t count) { const bool dst_aligned = hexagon::is_addr_aligned(dst); if (dst_aligned) { dequantize_row_q4_0_impl(src, dst, count); @@ -462,7 +471,7 @@ void dequantize_row_q4_0(const void * src, hexagon::dequant_target_type * dst, s } } -void dequantize_row_q4_K(const void * src, hexagon::dequant_target_type * dst, size_t count) { +void dequantize_row_q4_K(const void * src, hexagon::dequant_output_type * dst, size_t count) { const int nb = count / QUANT_K_BLOCK_SIZE; const auto * src_ptr = reinterpret_cast(src); auto * dst_ptr = reinterpret_cast<__fp16 *>(dst); @@ -497,29 +506,44 @@ void dequantize_row_q4_K(const void * src, hexagon::dequant_target_type * dst, s } } -void copy_row_f16(const void * src, hexagon::dequant_target_type * dst, size_t count) { +void copy_row_f16(const void * src, hexagon::dequant_output_type * dst, size_t count) { hexagon::vec_cpy_f16(reinterpret_cast(src), dst, count); } -void copy_row_f32(const void * src, hexagon::dequant_target_type * dst, size_t count) { +void copy_row_f32(const void * src, hexagon::dequant_output_type * dst, size_t count) { hexagon::vec_cpy_f32(reinterpret_cast(src), reinterpret_cast(dst), count); } constexpr const hexagon::device_type_traits kDeviceTypeTraits[] = { - { NPU_DATA_TYPE_F32, "F32", 1, sizeof(float), false, copy_row_f32, nullptr, - hexagon::type_erase_dot_func, + { NPU_DATA_TYPE_F32, + "F32", 1, + sizeof(float), + false, copy_row_f32, + nullptr, hexagon::type_erase_dot_func, hexagon::type_erase_dot_func, hexagon::type_erase_dot_func }, - { NPU_DATA_TYPE_F16, "F16", 1, sizeof(npu_device_fp16_t), false, copy_row_f16, quantize_row_fp16, - hexagon::type_erase_dot_func, + { NPU_DATA_TYPE_F16, + "F16", 1, + sizeof(npu_device_fp16_t), + false, copy_row_f16, + quantize_row_fp16, hexagon::type_erase_dot_func, hexagon::type_erase_dot_func, hexagon::type_erase_dot_func }, { NPU_DATA_TYPE_I32, "I32", 1, sizeof(int32_t), false }, - { 
NPU_DATA_TYPE_Q8_0, "Q8_0", QUANT_BLOCK_SIZE, sizeof(npu_device_block_q8_0), true, dequantize_row_q8_0, + { NPU_DATA_TYPE_Q8_0, + "Q8_0", QUANT_BLOCK_SIZE, + sizeof(npu_device_block_q8_0), + true, dequantize_row_q8_0, quantize_row_q8_0 }, - { NPU_DATA_TYPE_Q4_0, "Q4_0", QUANT_BLOCK_SIZE, sizeof(npu_device_block_q4_0), true, dequantize_row_q4_0, + { NPU_DATA_TYPE_Q4_0, + "Q4_0", QUANT_BLOCK_SIZE, + sizeof(npu_device_block_q4_0), + true, dequantize_row_q4_0, quantize_row_q4_0 }, - { NPU_DATA_TYPE_Q4_K, "Q4_K", QUANT_K_BLOCK_SIZE, sizeof(npu_device_block_q4_k), true, dequantize_row_q4_K, + { NPU_DATA_TYPE_Q4_K, + "Q4_K", QUANT_K_BLOCK_SIZE, + sizeof(npu_device_block_q4_k), + true, dequantize_row_q4_K, quantize_row_q4_K }, }; @@ -566,7 +590,7 @@ size_t get_dequantized_row_size(const tensor * tensor) { auto row_elems_count = tensor->get_ne(0); return hexagon::get_aligned_size( - row_elems_count * sizeof(dequant_target_type)); // dequant_target_type is currently restricted to f32 + row_elems_count * sizeof(dequant_output_type)); // dequant_output_type is currently restricted to f32 } } // namespace hexagon diff --git a/ggml/src/ggml-qnn/npu/device/type_traits.hpp b/ggml/src/ggml-qnn/npu/device/type_traits.hpp index 645101a676..363827de0a 100644 --- a/ggml/src/ggml-qnn/npu/device/type_traits.hpp +++ b/ggml/src/ggml-qnn/npu/device/type_traits.hpp @@ -5,12 +5,12 @@ namespace hexagon { -using dequant_target_type = npu_device_fp16_t; +using dequant_output_type = npu_device_fp16_t; bool init_f16_f32_table(float * table, size_t count); typedef void (*quantize_row_type)(const float * src, void * dst, size_t count); -typedef void (*dequantize_row_type)(const void * src, dequant_target_type * dst, size_t count); +typedef void (*dequantize_row_type)(const void * src, dequant_output_type * dst, size_t count); typedef float (*vec_dot_type)(const void * src0, const void * src1, size_t count); typedef bool (*can_use_aligned_vec_dot_type)(const void * src0, const void * src1, size_t count); @@ -51,14 +51,32 @@ inline auto make_scoped_op_perf_timer(tensor * op, size_t tidx) { auto * src1 = op->get_src(1); char buffer[1024]; if (src1 == nullptr) { - snprintf(buffer, sizeof(buffer), "[%s][%lldx%lldx%lldx%lld%s], tidx: %zu", op_get_name(op->get_op()), - src0->get_ne(0), src0->get_ne(1), src0->get_ne(2), src0->get_ne(3), get_type_name(src0->get_type()), + snprintf(buffer, + sizeof(buffer), + "[%s][%lldx%lldx%lldx%lld%s], tidx: %zu", + op_get_name(op->get_op()), + src0->get_ne(0), + src0->get_ne(1), + src0->get_ne(2), + src0->get_ne(3), + get_type_name(src0->get_type()), tidx); } else { - snprintf(buffer, sizeof(buffer), "[%s][%lldx%lldx%lldx%lld%s],[%lldx%lldx%lldx%lld%s], tidx: %zu", - op_get_name(op->get_op()), src0->get_ne(0), src0->get_ne(1), src0->get_ne(2), src0->get_ne(3), - get_type_name(src0->get_type()), src1->get_ne(0), src1->get_ne(1), src1->get_ne(2), src1->get_ne(3), - get_type_name(src1->get_type()), tidx); + snprintf(buffer, + sizeof(buffer), + "[%s][%lldx%lldx%lldx%lld%s],[%lldx%lldx%lldx%lld%s], tidx: %zu", + op_get_name(op->get_op()), + src0->get_ne(0), + src0->get_ne(1), + src0->get_ne(2), + src0->get_ne(3), + get_type_name(src0->get_type()), + src1->get_ne(0), + src1->get_ne(1), + src1->get_ne(2), + src1->get_ne(3), + get_type_name(src1->get_type()), + tidx); } return npu_scoped_timer<1024>(buffer); } diff --git a/ggml/src/ggml-qnn/npu/device/util.hpp b/ggml/src/ggml-qnn/npu/device/util.hpp index 4fdcc786ba..d70d740180 100644 --- a/ggml/src/ggml-qnn/npu/device/util.hpp +++ 
b/ggml/src/ggml-qnn/npu/device/util.hpp @@ -1,5 +1,7 @@ #pragma once +#include "hexagon_npu.h" + #include #include #include @@ -9,8 +11,6 @@ #include #include -#include "hexagon_npu.h" - #define DEVICE_LOG_ERROR(...) FARF(FATAL, __VA_ARGS__) #define DEVICE_LOG_WARN(...) FARF(ERROR, __VA_ARGS__) #define DEVICE_LOG_INFO(...) FARF(HIGH, __VA_ARGS__) @@ -56,6 +56,8 @@ inline constexpr const char * op_get_name(npu_device_tensor_op op) { return "FLASH_ATTN_EXT"; case NPU_OP_ROPE: return "ROPE"; + case NPU_OP_GLU: + return "GLU"; default: return "UNKNOWN"; } @@ -252,44 +254,68 @@ template class npu_scoped_timer { "[profiler]%s, pcyc: %llu, dur: %lluus, [%s]cnt: %llu, dur: %lluus, " "[%s]cnt: %llu, dur: %lluus, [%s]cnt: %llu, dur: %lluus, " "[%s]cnt: %llu, dur: %lluus\n", - _log_prefix, (unsigned long long) total_pcycles, (unsigned long long) duration, - _sub_proc_data[0].log_prefix, (unsigned long long) _sub_proc_data[0].proc_count, - (unsigned long long) sub_proc0_duration, _sub_proc_data[1].log_prefix, - (unsigned long long) _sub_proc_data[1].proc_count, (unsigned long long) sub_proc1_duration, - _sub_proc_data[2].log_prefix, (unsigned long long) _sub_proc_data[2].proc_count, - (unsigned long long) sub_proc2_duration, _sub_proc_data[3].log_prefix, - (unsigned long long) _sub_proc_data[3].proc_count, (unsigned long long) sub_proc3_duration); + _log_prefix, + (unsigned long long) total_pcycles, + (unsigned long long) duration, + _sub_proc_data[0].log_prefix, + (unsigned long long) _sub_proc_data[0].proc_count, + (unsigned long long) sub_proc0_duration, + _sub_proc_data[1].log_prefix, + (unsigned long long) _sub_proc_data[1].proc_count, + (unsigned long long) sub_proc1_duration, + _sub_proc_data[2].log_prefix, + (unsigned long long) _sub_proc_data[2].proc_count, + (unsigned long long) sub_proc2_duration, + _sub_proc_data[3].log_prefix, + (unsigned long long) _sub_proc_data[3].proc_count, + (unsigned long long) sub_proc3_duration); break; case 3: DEVICE_LOG_WARN( "[profiler]%s, pcyc: %llu, dur: %lluus, [%s]cnt: %llu, dur: %lluus, " "[%s]cnt: %llu, dur: %lluus, [%s]cnt: %llu, dur: %lluus\n", - _log_prefix, (unsigned long long) total_pcycles, (unsigned long long) duration, - _sub_proc_data[0].log_prefix, (unsigned long long) _sub_proc_data[0].proc_count, - (unsigned long long) sub_proc0_duration, _sub_proc_data[1].log_prefix, - (unsigned long long) _sub_proc_data[1].proc_count, (unsigned long long) sub_proc1_duration, - _sub_proc_data[2].log_prefix, (unsigned long long) _sub_proc_data[2].proc_count, + _log_prefix, + (unsigned long long) total_pcycles, + (unsigned long long) duration, + _sub_proc_data[0].log_prefix, + (unsigned long long) _sub_proc_data[0].proc_count, + (unsigned long long) sub_proc0_duration, + _sub_proc_data[1].log_prefix, + (unsigned long long) _sub_proc_data[1].proc_count, + (unsigned long long) sub_proc1_duration, + _sub_proc_data[2].log_prefix, + (unsigned long long) _sub_proc_data[2].proc_count, (unsigned long long) sub_proc2_duration); break; case 2: DEVICE_LOG_WARN( "[profiler]%s, pcyc: %llu, dur: %lluus, [%s]cnt: %llu, dur: %lluus, " "[%s]cnt: %llu, dur: %lluus\n", - _log_prefix, (unsigned long long) total_pcycles, (unsigned long long) duration, - _sub_proc_data[0].log_prefix, (unsigned long long) _sub_proc_data[0].proc_count, - (unsigned long long) sub_proc0_duration, _sub_proc_data[1].log_prefix, - (unsigned long long) _sub_proc_data[1].proc_count, (unsigned long long) sub_proc1_duration); + _log_prefix, + (unsigned long long) total_pcycles, + (unsigned long long) 
duration, + _sub_proc_data[0].log_prefix, + (unsigned long long) _sub_proc_data[0].proc_count, + (unsigned long long) sub_proc0_duration, + _sub_proc_data[1].log_prefix, + (unsigned long long) _sub_proc_data[1].proc_count, + (unsigned long long) sub_proc1_duration); break; case 1: - DEVICE_LOG_WARN("[profiler]%s, pcyc: %llu, dur: %lluus, [%s]cnt: %llu, dur: %lluus\n", _log_prefix, - (unsigned long long) total_pcycles, (unsigned long long) duration, - _sub_proc_data[0].log_prefix, (unsigned long long) _sub_proc_data[0].proc_count, + DEVICE_LOG_WARN("[profiler]%s, pcyc: %llu, dur: %lluus, [%s]cnt: %llu, dur: %lluus\n", + _log_prefix, + (unsigned long long) total_pcycles, + (unsigned long long) duration, + _sub_proc_data[0].log_prefix, + (unsigned long long) _sub_proc_data[0].proc_count, (unsigned long long) sub_proc0_duration); break; default: case 0: - DEVICE_LOG_WARN("[profiler]%s, pcyc: %llu, dur: %lluus\n", _log_prefix, - (unsigned long long) total_pcycles, (unsigned long long) duration); + DEVICE_LOG_WARN("[profiler]%s, pcyc: %llu, dur: %lluus\n", + _log_prefix, + (unsigned long long) total_pcycles, + (unsigned long long) duration); break; } } @@ -317,8 +343,8 @@ template class npu_sub_process_scoped_ti } ~npu_sub_process_scoped_timer() { - _timer.add_sub_proc_cycles(_sub_idx, _prefix, HAP_perf_get_qtimer_count() - _begin_cycles, - HAP_perf_get_pcycles() - _begin_pcycles); + _timer.add_sub_proc_cycles( + _sub_idx, _prefix, HAP_perf_get_qtimer_count() - _begin_cycles, HAP_perf_get_pcycles() - _begin_pcycles); } private: diff --git a/ggml/src/ggml-qnn/npu/device/vec_math.inl b/ggml/src/ggml-qnn/npu/device/vec_math.inl new file mode 100644 index 0000000000..ab7f01cf1b --- /dev/null +++ b/ggml/src/ggml-qnn/npu/device/vec_math.inl @@ -0,0 +1,1129 @@ +#pragma once + +#include "hexagon_npu.h" + +#include + +#include + +// TODO: move this macros to a common header +#define IEEE_VSF_EXPLEN (8) +#define IEEE_VSF_EXPBIAS (127) +#define IEEE_VSF_EXPMASK (0xFF) +#define IEEE_VSF_MANTLEN (23) +#define IEEE_VSF_MANTMASK (0x7FFFFF) +#define IEEE_VSF_MIMPMASK (0x800000) + +#define IEEE_VHF_EXPLEN (5) +#define IEEE_VHF_EXPBIAS (15) +#define IEEE_VHF_EXPMASK (0x1F) +#define IEEE_VHF_MANTLEN (10) +#define IEEE_VHF_MANTMASK (0x3FF) +#define IEEE_VHF_MIMPMASK (0x400) + +#define COEFF_EXP_5 0x39506967 // 0.000198757 = 1/(7!) +#define COEFF_EXP_4 0x3AB743CE // 0.0013982 = 1/(6!) +#define COEFF_EXP_3 0x3C088908 // 0.00833345 = 1/(5!) +#define COEFF_EXP_2 0x3D2AA9C1 // 0.416658 = 1/(4!) +#define COEFF_EXP_1 0x3E2AAAAA // 0.16666667 = 1/(3!) +#define COEFF_EXP_0 0x3F000000 // 0.5 = 1/(2!) +#define LOGN2 0x3F317218 // ln(2) = 0.6931471805 +#define LOG2E 0x3FB8AA3B // log2(e) = 1/ln(2) = 1.4426950408 + +#define COEFF_EXP_5_HF 0x0A83 // 0.000198757 = 1/(7!) +#define COEFF_EXP_4_HF 0x15BA // 0.0013982 = 1/(6!) +#define COEFF_EXP_3_HF 0x2044 // 0.00833345 = 1/(5!) +#define COEFF_EXP_2_HF 0x36AB // 0.416658 = 1/(4!) +#define COEFF_EXP_1_HF 0x3155 // 0.16666667 = 1/(3!) +#define COEFF_EXP_0_HF 0x3800 // 0.5 = 1/(2!) 
+#define LOGN2_HF 0x398C // ln(2) = 0.693147 +#define LOG2E_HF 0x3DC5 // log2(e) = 1/ln(2) = 1.4427 + +namespace hexagon::vec::math { + +inline HVX_Vector qhmath_hvx_vsf_floor_vsf(HVX_Vector vin) { + HVX_Vector mask_mant_v = Q6_V_vsplat_R(IEEE_VSF_MANTMASK); + HVX_Vector mask_impl_v = Q6_V_vsplat_R(IEEE_VSF_MIMPMASK); + HVX_Vector const_mnlen_v = Q6_V_vsplat_R(IEEE_VSF_MANTLEN); + HVX_Vector const_zero_v = Q6_V_vzero(); + HVX_Vector const_negone_v = Q6_V_vsplat_R(0xbf800000); // -1 IEEE vsf + + // initialization (no changes) + HVX_VectorPred q_negative = Q6_Q_vcmp_gt_VwVw(const_zero_v, vin); + + HVX_Vector expval_v = vin >> IEEE_VSF_MANTLEN; + expval_v &= IEEE_VSF_EXPMASK; + expval_v -= IEEE_VSF_EXPBIAS; + + HVX_VectorPred q_negexp = Q6_Q_vcmp_gt_VwVw(const_zero_v, expval_v); + HVX_VectorPred q_expltmn = Q6_Q_vcmp_gt_VwVw(const_mnlen_v, expval_v); + HVX_VectorPred q_negexp_pos = Q6_Q_vcmp_gtand_QVwVw(q_negexp, vin, const_zero_v); + HVX_VectorPred q_negexp_neg = Q6_Q_vcmp_gtand_QVwVw(q_negexp, const_zero_v, vin); + + // if expval < 0 (q_negexp) // <0, floor is 0 + // if vin > 0 + // floor = 0 + // if vin < 0 + // floor = -1 + // if expval < mant_len (q_expltmn) // >0, but fraction may exist + // get sign (q_negative) + // mask >> expval // fraction bits to mask off + // vout = ~(mask) // apply mask to remove fraction + // if (qneg) // negative floor is one less (more, sign bit for neg) + // vout += ((impl_mask) >> expval) + // if (mask && vin) + // vout = vin + // else // already an integer + // ; // no change + + // compute floor + mask_mant_v >>= expval_v; + HVX_Vector neg_addin_v = mask_impl_v >> expval_v; + HVX_Vector vout_neg_addin = Q6_Vw_vadd_VwVw(vin, neg_addin_v); + HVX_Vector vout = Q6_V_vmux_QVV(q_negative, vout_neg_addin, vin); + + HVX_Vector mask_chk_v = Q6_V_vand_VV(vin, mask_mant_v); // chk if bits set + HVX_VectorPred q_integral = Q6_Q_vcmp_eq_VwVw(const_zero_v, mask_chk_v); + + HVX_Vector not_mask_v = Q6_V_vnot_V(mask_mant_v); // frac bits to clear + HVX_Vector vfrfloor_v = Q6_V_vand_VV(vout, not_mask_v); // clear frac bits + + vout = vin; + vout = Q6_V_vmux_QVV(q_expltmn, vfrfloor_v, vout); // expval0 -> 0 + vout = Q6_V_vmux_QVV(q_negexp_neg, const_negone_v, vout); // expval<0 x<0 -> -1 + return vout; +} + +// truncate(x) +// given a vector of float x, +// return the vector of integers resulting from dropping all fractional bits +// no checking performed for overflow - could be extended to return maxint +// +// truncate float to int +inline HVX_Vector qhmath_hvx_vw_truncate_vsf(HVX_Vector vin) { + HVX_Vector mask_mant_v = Q6_V_vsplat_R(IEEE_VSF_MANTMASK); + HVX_Vector mask_impl_v = Q6_V_vsplat_R(IEEE_VSF_MIMPMASK); + HVX_Vector const_zero_v = Q6_V_vzero(); + + HVX_VectorPred q_negative = Q6_Q_vcmp_gt_VwVw(const_zero_v, vin); + + HVX_Vector expval_v = vin >> IEEE_VSF_MANTLEN; + expval_v &= IEEE_VSF_EXPMASK; + expval_v -= IEEE_VSF_EXPBIAS; + + // negative exp == fractional value + HVX_VectorPred q_negexp = Q6_Q_vcmp_gt_VwVw(const_zero_v, expval_v); + + HVX_Vector rshift_v = IEEE_VSF_MANTLEN - expval_v; // fractional bits - exp shift + + HVX_Vector mant_v = vin & mask_mant_v; // obtain mantissa + HVX_Vector vout = Q6_Vw_vadd_VwVw(mant_v, mask_impl_v); // add implicit 1.0 + vout = Q6_Vw_vasr_VwVw(vout, rshift_v); // shift to obtain truncated integer + vout = Q6_V_vmux_QVV(q_negexp, const_zero_v, vout); // expval<0 -> 0 + + HVX_Vector neg_vout = -vout; + vout = Q6_V_vmux_QVV(q_negative, neg_vout, vout); // handle negatives + return (vout); +} + +// qhmath_hvx_vhf_floor_vhf(x) 
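+// (half-float counterpart of qhmath_hvx_vsf_floor_vsf above)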
+// given a vector of half float x, +// return the vector of largest integer valued half float <= x +// +inline HVX_Vector qhmath_hvx_vhf_floor_vhf(HVX_Vector vin) { + HVX_Vector mask_mant_v = Q6_Vh_vsplat_R(IEEE_VHF_MANTMASK); + HVX_Vector mask_impl_v = Q6_Vh_vsplat_R(IEEE_VHF_MIMPMASK); + HVX_Vector const_mnlen_v = Q6_Vh_vsplat_R(IEEE_VHF_MANTLEN); + HVX_Vector const_emask_v = Q6_Vh_vsplat_R(IEEE_VHF_EXPMASK); + HVX_Vector const_ebias_v = Q6_Vh_vsplat_R(IEEE_VHF_EXPBIAS); + HVX_Vector const_zero_v = Q6_V_vzero(); + HVX_Vector const_negone_v = Q6_Vh_vsplat_R(0xbc00); // -1 IEEE vhf + + // initialization (no changes) + HVX_VectorPred q_negative = Q6_Q_vcmp_gt_VhVh(const_zero_v, vin); + + HVX_Vector expval_v = Q6_Vh_vasr_VhR(vin, IEEE_VHF_MANTLEN); + expval_v = Q6_V_vand_VV(expval_v, const_emask_v); + expval_v = Q6_Vh_vsub_VhVh(expval_v, const_ebias_v); + + HVX_VectorPred q_negexp = Q6_Q_vcmp_gt_VhVh(const_zero_v, expval_v); + HVX_VectorPred q_expltmn = Q6_Q_vcmp_gt_VhVh(const_mnlen_v, expval_v); + HVX_VectorPred q_negexp_pos = Q6_Q_vcmp_gtand_QVhVh(q_negexp, vin, const_zero_v); + HVX_VectorPred q_negexp_neg = Q6_Q_vcmp_gtand_QVhVh(q_negexp, const_zero_v, vin); + + // if expval < 0 (q_negexp) // <0, floor is 0 + // if vin > 0 + // floor = 0 + // if vin < 0 + // floor = -1 + // if expval < mant_len (q_expltmn) // >0, but fraction may exist + // get sign (q_negative) + // mask >> expval // fraction bits to mask off + // vout = ~(mask) // apply mask to remove fraction + // if (qneg) // negative floor is one less (more, sign bit for neg) + // vout += ((impl_mask) >> expval) + // if (mask && vin) + // vout = vin + // else // already an integer + // ; // no change + + // compute floor + mask_mant_v = Q6_Vh_vasr_VhVh(mask_mant_v, expval_v); + HVX_Vector neg_addin_v = Q6_Vh_vasr_VhVh(mask_impl_v, expval_v); + HVX_Vector vout_neg_addin = Q6_Vh_vadd_VhVh(vin, neg_addin_v); + HVX_Vector vout = Q6_V_vmux_QVV(q_negative, vout_neg_addin, vin); + + HVX_Vector mask_chk_v = Q6_V_vand_VV(vin, mask_mant_v); // chk if bits set + HVX_VectorPred q_integral = Q6_Q_vcmp_eq_VhVh(const_zero_v, mask_chk_v); + + HVX_Vector not_mask_v = Q6_V_vnot_V(mask_mant_v); // frac bits to clear + HVX_Vector vfrfloor_v = Q6_V_vand_VV(vout, not_mask_v); // clear frac bits + + vout = vin; + vout = Q6_V_vmux_QVV(q_expltmn, vfrfloor_v, vout); // expval0 -> 0 + vout = Q6_V_vmux_QVV(q_negexp_neg, const_negone_v, vout); // expval<0 x<0 -> -1 + return vout; +} + +// truncate half float to short +inline HVX_Vector qhmath_hvx_vh_truncate_vhf(HVX_Vector vin) { + HVX_Vector const_mnlen_v = Q6_Vh_vsplat_R(IEEE_VHF_MANTLEN); + HVX_Vector mask_mant_v = Q6_Vh_vsplat_R(IEEE_VHF_MANTMASK); + HVX_Vector mask_impl_v = Q6_Vh_vsplat_R(IEEE_VHF_MIMPMASK); + HVX_Vector const_emask_v = Q6_Vh_vsplat_R(IEEE_VHF_EXPMASK); + HVX_Vector const_ebias_v = Q6_Vh_vsplat_R(IEEE_VHF_EXPBIAS); + HVX_Vector const_zero_v = Q6_V_vzero(); + HVX_Vector const_one_v = Q6_Vh_vsplat_R(1); + + HVX_VectorPred q_negative = Q6_Q_vcmp_gt_VhVh(const_zero_v, vin); + + HVX_Vector expval_v = Q6_Vh_vasr_VhVh(vin, const_mnlen_v); + expval_v = Q6_V_vand_VV(expval_v, const_emask_v); + expval_v = Q6_Vh_vsub_VhVh(expval_v, const_ebias_v); + + // negative exp == fractional value + HVX_VectorPred q_negexp = Q6_Q_vcmp_gt_VhVh(const_zero_v, expval_v); + + // fractional bits - exp shift + HVX_Vector rshift_v = Q6_Vh_vsub_VhVh(const_mnlen_v, expval_v); + + HVX_Vector mant_v = vin & mask_mant_v; // obtain mantissa + HVX_Vector vout = Q6_Vh_vadd_VhVh(mant_v, mask_impl_v); // add implicit 1.0 + vout 
= Q6_Vh_vasr_VhVh(vout, rshift_v); // shift to obtain truncated integer + vout = Q6_V_vmux_QVV(q_negexp, const_zero_v, vout); // expval<0 -> 0 + + // HVX_Vector neg_vout = -vout; + HVX_Vector not_vout = Q6_V_vnot_V(vout); + HVX_Vector neg_vout = Q6_Vh_vadd_VhVh(not_vout, const_one_v); + vout = Q6_V_vmux_QVV(q_negative, neg_vout, vout); // handle negatives + return (vout); +} + +/* + * This function computes the exponent on all IEEE 32-bit float elements of an HVX_Vector + * See also: libs\qfe\inc\qhmath_hvx_convert.h + */ +inline HVX_Vector qhmath_hvx_exp_vf(HVX_Vector sline) { + HVX_Vector z_qf32_v; + HVX_Vector x_v; + HVX_Vector x_qf32_v; + HVX_Vector y_v; + HVX_Vector k_v; + HVX_Vector f_v; + HVX_Vector epsilon_v; + HVX_Vector log2e = Q6_V_vsplat_R(LOG2E); + HVX_Vector logn2 = Q6_V_vsplat_R(LOGN2); + HVX_Vector E_const; + HVX_Vector zero_v = Q6_V_vzero(); + + // 1) clipping + uint input + // if (x > MAXLOG) + // return (MAXNUM); + // if (x < MINLOG) + // return (0.0); + // + // 2) exp(x) is approximated as follows: + // f = floor(x/ln(2)) = floor(x*log2(e)) + // epsilon = x - f*ln(2) + // exp(x) = exp(epsilon+f*ln(2)) + // = exp(epsilon)*exp(f*ln(2)) + // = exp(epsilon)*2^f + // Since epsilon is close to zero, it can be approximated with its Taylor series: + // exp(x)~=1+x+x^2/2!+x^3/3!+...+x^n/n!+... + // Preserving the first eight elements, we get: + // exp(x)~=1+x+e0*x^2+e1*x^3+e2*x^4+e3*x^5+e4*x^6+e5*x^7 + // =1+x+(E0+(E1+(E2+(E3+(E4+E5*x)*x)*x)*x)*x)*x^2 + + epsilon_v = Q6_Vqf32_vmpy_VsfVsf(log2e, sline); + epsilon_v = Q6_Vsf_equals_Vqf32(epsilon_v); + + // f_v is the floating point result and k_v is the integer result + f_v = qhmath_hvx_vsf_floor_vsf(epsilon_v); + k_v = qhmath_hvx_vw_truncate_vsf(f_v); + + x_qf32_v = Q6_Vqf32_vadd_VsfVsf(sline, zero_v); + + // x = x - f_v * logn2; + epsilon_v = Q6_Vqf32_vmpy_VsfVsf(f_v, logn2); + x_qf32_v = Q6_Vqf32_vsub_Vqf32Vqf32(x_qf32_v, epsilon_v); + // normalize before every QFloat's vmpy + x_qf32_v = Q6_Vqf32_vadd_Vqf32Vsf(x_qf32_v, zero_v); + + // z = x * x; + z_qf32_v = Q6_Vqf32_vmpy_Vqf32Vqf32(x_qf32_v, x_qf32_v); + z_qf32_v = Q6_Vqf32_vadd_Vqf32Vsf(z_qf32_v, zero_v); + + x_v = Q6_Vsf_equals_Vqf32(x_qf32_v); + + // y = E4 + E5 * x; + E_const = Q6_V_vsplat_R(COEFF_EXP_5); + y_v = Q6_Vqf32_vmpy_VsfVsf(E_const, x_v); + E_const = Q6_V_vsplat_R(COEFF_EXP_4); + y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const); + y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v); + + // y = E3 + y * x; + E_const = Q6_V_vsplat_R(COEFF_EXP_3); + y_v = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, x_qf32_v); + y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const); + y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v); + + // y = E2 + y * x; + E_const = Q6_V_vsplat_R(COEFF_EXP_2); + y_v = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, x_qf32_v); + y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const); + y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v); + + // y = E1 + y * x; + E_const = Q6_V_vsplat_R(COEFF_EXP_1); + y_v = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, x_qf32_v); + y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const); + y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v); + + // y = E0 + y * x; + E_const = Q6_V_vsplat_R(COEFF_EXP_0); + y_v = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, x_qf32_v); + y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const); + y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v); + + // y = x + y * z; + y_v = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, z_qf32_v); + y_v = Q6_Vqf32_vadd_Vqf32Vqf32(y_v, x_qf32_v); + y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v); + + // y = y + 1.0; + E_const = Q6_V_vsplat_R(0x3f800000); + y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const); + + //insert exponents + // y = 
ldexpf(y, k); + // y_v += k_v; // qf32 + // modify exponent + y_v = Q6_Vsf_equals_Vqf32(y_v); + + // add k_v to the exponent of y_v + HVX_Vector y_v_exponent = Q6_Vw_vasl_VwR(y_v, 1); + + y_v_exponent = Q6_Vuw_vlsr_VuwR(y_v_exponent, 24); + + y_v_exponent = Q6_Vw_vadd_VwVw(k_v, y_v_exponent); + + // exponent cannot be negative; if overflow is detected, result is set to zero + HVX_VectorPred qy_v_negative_exponent = Q6_Q_vcmp_gt_VwVw(zero_v, y_v_exponent); + + y_v = Q6_Vw_vaslacc_VwVwR(y_v, k_v, 23); + + y_v = Q6_V_vmux_QVV(qy_v_negative_exponent, zero_v, y_v); + + return y_v; +} + +/* + * This function computes the exponent on all IEEE 16-bit float elements of an HVX_Vector + * See also: libs\qfe\inc\qhmath_hvx_convert.h + */ +inline HVX_Vector qhmath_hvx_exp_vhf(HVX_Vector sline) { + HVX_Vector z_qf16_v; + HVX_Vector x_qf16_v; + HVX_Vector y_v; + HVX_Vector k_v; + HVX_Vector f_v; + HVX_Vector tmp_v; + HVX_Vector log2e = Q6_Vh_vsplat_R(LOG2E_HF); + HVX_Vector logn2 = Q6_Vh_vsplat_R(LOGN2_HF); + HVX_Vector E_const; + HVX_Vector zero_v = Q6_V_vzero(); + + // 1) clipping + uint input + // if (x > MAXLOG) + // return (MAXNUM); + // if (x < MINLOG) + // return (0.0); + + // 2) round to int + // k = (int) (x * log2e); + // f = (float) k; + // k = Q6_R_convert_sf2w_R(log2e * x); //f = floorf( log2e * x + 0.5); + // f = Q6_R_convert_w2sf_R(k); //k = (int)f; + + tmp_v = Q6_Vqf16_vmpy_VhfVhf(log2e, sline); + // float16's 0.5 is 0x3800 + HVX_Vector cp5_v = Q6_Vh_vsplat_R(0x3800); + tmp_v = Q6_Vqf16_vadd_Vqf16Vhf(tmp_v, cp5_v); + tmp_v = Q6_Vhf_equals_Vqf16(tmp_v); + + // f_v is the floating point result and k_v is the integer result + f_v = qhmath_hvx_vhf_floor_vhf(tmp_v); + k_v = qhmath_hvx_vh_truncate_vhf(f_v); + + x_qf16_v = Q6_Vqf16_vadd_VhfVhf(sline, zero_v); + + // x = x - f * logn2; + tmp_v = Q6_Vqf16_vmpy_VhfVhf(f_v, logn2); + x_qf16_v = Q6_Vqf16_vsub_Vqf16Vqf16(x_qf16_v, tmp_v); + + // normalize before every QFloat's vmpy + x_qf16_v = Q6_Vqf16_vadd_Vqf16Vhf(x_qf16_v, zero_v); + + // z = x * x; + z_qf16_v = Q6_Vqf16_vmpy_Vqf16Vqf16(x_qf16_v, x_qf16_v); + z_qf16_v = Q6_Vqf16_vadd_Vqf16Vhf(z_qf16_v, zero_v); + + // y = E4 + E5 * x; + E_const = Q6_Vh_vsplat_R(COEFF_EXP_5_HF); + y_v = Q6_Vqf16_vmpy_Vqf16Vhf(x_qf16_v, E_const); + E_const = Q6_Vh_vsplat_R(COEFF_EXP_4_HF); + y_v = Q6_Vqf16_vadd_Vqf16Vhf(y_v, E_const); + y_v = Q6_Vqf16_vadd_Vqf16Vhf(y_v, zero_v); + + // y = E3 + y * x; + E_const = Q6_Vh_vsplat_R(COEFF_EXP_3_HF); + y_v = Q6_Vqf16_vmpy_Vqf16Vqf16(y_v, x_qf16_v); + y_v = Q6_Vqf16_vadd_Vqf16Vhf(y_v, E_const); + y_v = Q6_Vqf16_vadd_Vqf16Vhf(y_v, zero_v); + + // y = E2 + y * x; + E_const = Q6_Vh_vsplat_R(COEFF_EXP_2_HF); + y_v = Q6_Vqf16_vmpy_Vqf16Vqf16(y_v, x_qf16_v); + y_v = Q6_Vqf16_vadd_Vqf16Vhf(y_v, E_const); + y_v = Q6_Vqf16_vadd_Vqf16Vhf(y_v, zero_v); + + // y = E1 + y * x; + E_const = Q6_Vh_vsplat_R(COEFF_EXP_1_HF); + y_v = Q6_Vqf16_vmpy_Vqf16Vqf16(y_v, x_qf16_v); + y_v = Q6_Vqf16_vadd_Vqf16Vhf(y_v, E_const); + y_v = Q6_Vqf16_vadd_Vqf16Vhf(y_v, zero_v); + + // y = E0 + y * x; + E_const = Q6_Vh_vsplat_R(COEFF_EXP_0_HF); + y_v = Q6_Vqf16_vmpy_Vqf16Vqf16(y_v, x_qf16_v); + y_v = Q6_Vqf16_vadd_Vqf16Vhf(y_v, E_const); + y_v = Q6_Vqf16_vadd_Vqf16Vhf(y_v, zero_v); + + // y = x + y * z; + y_v = Q6_Vqf16_vmpy_Vqf16Vqf16(y_v, z_qf16_v); + y_v = Q6_Vqf16_vadd_Vqf16Vqf16(y_v, x_qf16_v); + y_v = Q6_Vqf16_vadd_Vqf16Vhf(y_v, zero_v); + + // y = y + 1.0; + E_const = Q6_Vh_vsplat_R(0x3C00); + y_v = Q6_Vqf16_vadd_Vqf16Vhf(y_v, E_const); + + // insert exponents + // y = ldexpf(y, k); + // y_v += k_v; 
// qf32 + // modify exponent + y_v = Q6_Vhf_equals_Vqf16(y_v); + + // add k_v to the exponent of y_v + // shift away sign bit + HVX_Vector y_v_exponent = Q6_Vh_vasl_VhR(y_v, 1); + + // shift back by sign bit + 10-bit mantissa + y_v_exponent = Q6_Vuh_vlsr_VuhR(y_v_exponent, 11); + + y_v_exponent = Q6_Vh_vadd_VhVh(k_v, y_v_exponent); + + // exponent cannot be negative; if overflow is detected, result is set to zero + HVX_VectorPred qy_v_negative_exponent = Q6_Q_vcmp_gt_VhVh(zero_v, y_v_exponent); + + // max IEEE hf exponent; if overflow detected, result is set to infinity + HVX_Vector exp_max_v = Q6_Vh_vsplat_R(0x1e); + // INF in 16-bit float is 0x7C00 + HVX_Vector inf_v = Q6_Vh_vsplat_R(0x7C00); + HVX_VectorPred qy_v_overflow_exponent = Q6_Q_vcmp_gt_VhVh(y_v_exponent, exp_max_v); + + // update exponent + y_v = Q6_Vh_vaslacc_VhVhR(y_v, k_v, 10); + + // clip to min/max values + y_v = Q6_V_vmux_QVV(qy_v_negative_exponent, zero_v, y_v); + y_v = Q6_V_vmux_QVV(qy_v_overflow_exponent, inf_v, y_v); + + return y_v; +} + +inline HVX_VectorPair_x4 qhmath_load_div_sf_ltu() { + /* Coefficients in float representation */ + constexpr const float c0_coeffs[32] __attribute__((aligned(hexagon::kBytesPerVector))) = { + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 3.882601794814435, + 3.6625422144222575, + 3.464451548227971, + 3.2869700047974098, + 3.126105117815294, + 2.9797652947122333, + 2.846287833147896, + 2.7247270166228237, + 2.614282526778659, + 2.5119448279766914, + 2.4168240690138916, + 2.3287715099556494, + 2.2470044371606255, + 2.1705097010458525, + 2.0993232550771013, + 2.032425103348979, + }; + constexpr const float c1_coeffs[32] __attribute__((aligned(hexagon::kBytesPerVector))) = { + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + -5.65213466274883, + -5.029649818173625, + -4.500359068222728, + -4.051125252469975, + -3.6643282495304743, + -3.3293252513210945, + -3.0377500909629918, + -2.78384542029156, + -2.562751394984757, + -2.3660481944625364, + -2.1902579830702398, + -2.033579850063907, + -1.8932880190031018, + -1.7665817851802996, + -1.6526109646324616, + -1.5489652830974667, + }; + constexpr const float c2_coeffs[32] __attribute__((aligned(hexagon::kBytesPerVector))) = { + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 3.6564123863772062, + 3.0693863078484034, + 2.5979108429264546, + 2.2188401136904137, + 1.90879196515026, + 1.6531365145318937, + 1.4408072849395228, + 1.2640160009581791, + 1.1164726565567085, + 0.9904366133906549, + 0.8821387892416702, + 0.7892039810345458, + 0.7089644931002874, + 0.6390020714403465, + 0.5781761255999769, + 0.5246475096790261, + }; + constexpr const float c3_coeffs[32] __attribute__((aligned(hexagon::kBytesPerVector))) = { + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + -0.8868796162009371, + -0.7023245532864408, + -0.5623148115716742, + -0.45568061400557225, + -0.3728293181808119, + -0.30778916969628956, + -0.25624427383670373, + -0.21520836864975557, + -0.18238585316003267, + -0.1554651987039696, + -0.133224398745864, + -0.11484835534787588, + -0.09954996553138899, + -0.08667244996867919, + -0.07585106425203664, + -0.06663557250850614, + }; + + /* Load coefficients */ + HVX_Vector c0_coeff_v = *((HVX_Vector *) (c0_coeffs)); + HVX_Vector c1_coeff_v = *((HVX_Vector *) (c1_coeffs)); + HVX_Vector 
c2_coeff_v = *((HVX_Vector *) (c2_coeffs)); + HVX_Vector c3_coeff_v = *((HVX_Vector *) (c3_coeffs)); + + /* Split 32-bit coefficients to lower and upper part in order to obtain them later with VLUT16. */ + hexagon::HVX_VectorPair_x4 result; + result.val[0] = Q6_Wuw_vzxt_Vuh(c0_coeff_v); + result.val[1] = Q6_Wuw_vzxt_Vuh(c1_coeff_v); + result.val[2] = Q6_Wuw_vzxt_Vuh(c2_coeff_v); + result.val[3] = Q6_Wuw_vzxt_Vuh(c3_coeff_v); + + return result; +} + +inline HVX_Vector qhmath_hvx_div_vf(HVX_Vector num, HVX_Vector denom, HVX_VectorPair_x4 coeffs) { + HVX_Vector sline1; + HVX_Vector sline2; + HVX_Vector norm_factor; + HVX_Vector tmp_v; + HVX_Vector idx1_v; + HVX_Vector idx2_v; + HVX_Vector output_v; + HVX_Vector input_shifted_v_qf32; + HVX_Vector input_scaled_v_qf32; + HVX_VectorPair c0_coeff_vp; + HVX_VectorPair c1_coeff_vp; + HVX_VectorPair c2_coeff_vp; + HVX_VectorPair c3_coeff_vp; + + /* + * Splat scale factor in order to be used later for finding indexes of coefficients. + * Scale factor is represented in IEEE 16-bit floating-point format and it is + * calculated using the following formula: + * scale_factor = (16.0 / (b0 - a0)) + * NOTE: Calculated value is slightly decreased in order to avoid out of bound + * indexes during VLUT lookup. + */ + HVX_Vector scale_v = Q6_V_vsplat_R(0x417ffffe); + + /* + * Vector of zeroes used as neutral element in sf to qf32 conversions. + * NOTE: Some of conversions (i.e conversion of scale factor and coefficients) + * can be avoided in real-time, but this is not done in order to don't + * sacrify code readibility in expense of insignificant performance improvement. + */ + HVX_Vector zero_v_sf = Q6_V_vzero(); + + /* Set sign = 0, exp = 254, mant = 0 */ + HVX_Vector exp = Q6_V_vsplat_R(0x7F000000); + + /* Set mask for sign and exponent */ + HVX_Vector signexp_mask = Q6_V_vsplat_R(0xFF800000); + + /* Mask for extracting only 4 bits of mantissa */ + HVX_Vector mask_idx1_v = Q6_V_vsplat_R(0x0000000F); + HVX_Vector mask_idx2_v = Q6_V_vsplat_R(0x00000010); + + /* 16.0 in IEEE 16-bit floating-point representation */ + HVX_Vector const16_0_v_sf = Q6_V_vsplat_R(0x41800000); + + /* + * Prepare vector of input_min values, that is used later in shifting input range. + * input_min is low boundary of specified input range. + */ + HVX_Vector input_min_v_f = Q6_V_vsplat_R(0x3f800000); + + /* Convert scale factor from sf to q32. Use the same vector for both formats */ + scale_v = Q6_Vqf32_vadd_VsfVsf(scale_v, zero_v_sf); + + /* Calculate normalization factor */ + norm_factor = Q6_V_vand_VV(denom, signexp_mask); + norm_factor = Q6_Vw_vsub_VwVw(exp, norm_factor); + + /* Normalize denominators */ + sline2 = Q6_Vqf32_vmpy_VsfVsf(denom, norm_factor); + sline2 = Q6_Vsf_equals_Vqf32(sline2); + + /* Convert normalization factor and numerator to qf32 */ + norm_factor = Q6_Vqf32_vadd_VsfVsf(norm_factor, zero_v_sf); + sline1 = Q6_Vqf32_vadd_VsfVsf(num, zero_v_sf); + + /* Shift input range from [input_min, input_max] to [0, input_max - input_min] */ + input_shifted_v_qf32 = Q6_Vqf32_vsub_VsfVsf(sline2, input_min_v_f); + + /* + * Scale shifted input range from [0, input_max - input_min] to [0,16.0) + * in order to get corresponding coefficient indexes + */ + input_scaled_v_qf32 = Q6_Vqf32_vmpy_Vqf32Vqf32(input_shifted_v_qf32, scale_v); + + /* + * VLUT 16 requires integer indexes. Shift scaled input range from [0,16.0) + * to [16.0,32.0) in order to convert float indexes to integer values. 
+ * Float values, represented in IEEE 754, in range [16.0,32.0] have the + * same exponent, which means 4 MSB of mantissa carry information about + * integer index. + */ + input_scaled_v_qf32 = Q6_Vqf32_vadd_Vqf32Vsf(input_scaled_v_qf32, const16_0_v_sf); + + /* Convert back from qf32 to sf in order to extract integer index */ + tmp_v = Q6_Vsf_equals_Vqf32(input_scaled_v_qf32); + + /* Only 4 MSB bits of mantissa represent segment index */ + idx1_v = Q6_Vuw_vlsr_VuwR(tmp_v, 19); + + idx1_v = Q6_V_vand_VV(idx1_v, mask_idx1_v); + idx1_v = Q6_V_vor_VV(idx1_v, mask_idx2_v); + idx2_v = Q6_Vw_vasl_VwR(idx1_v, 16); + + /* Obtain the polynomial coefficients from lookup table */ + c0_coeff_vp = Q6_Wh_vlut16_VbVhR(idx1_v, Q6_V_lo_W(coeffs.val[0]), 1); + c0_coeff_vp = Q6_Wh_vlut16or_WhVbVhR(c0_coeff_vp, idx2_v, Q6_V_hi_W(coeffs.val[0]), 1); + c1_coeff_vp = Q6_Wh_vlut16_VbVhR(idx1_v, Q6_V_lo_W(coeffs.val[1]), 1); + c1_coeff_vp = Q6_Wh_vlut16or_WhVbVhR(c1_coeff_vp, idx2_v, Q6_V_hi_W(coeffs.val[1]), 1); + c2_coeff_vp = Q6_Wh_vlut16_VbVhR(idx1_v, Q6_V_lo_W(coeffs.val[2]), 1); + c2_coeff_vp = Q6_Wh_vlut16or_WhVbVhR(c2_coeff_vp, idx2_v, Q6_V_hi_W(coeffs.val[2]), 1); + c3_coeff_vp = Q6_Wh_vlut16_VbVhR(idx1_v, Q6_V_lo_W(coeffs.val[3]), 1); + c3_coeff_vp = Q6_Wh_vlut16or_WhVbVhR(c3_coeff_vp, idx2_v, Q6_V_hi_W(coeffs.val[3]), 1); + + /* Perform evaluation of polynomial using Horner's method */ + output_v = Q6_Vqf32_vmpy_VsfVsf(Q6_V_lo_W(c3_coeff_vp), sline2); + output_v = Q6_Vqf32_vadd_Vqf32Vsf(output_v, Q6_V_lo_W(c2_coeff_vp)); + output_v = Q6_Vsf_equals_Vqf32(output_v); + + output_v = Q6_Vqf32_vmpy_VsfVsf(output_v, sline2); + output_v = Q6_Vqf32_vadd_Vqf32Vsf(output_v, Q6_V_lo_W(c1_coeff_vp)); + output_v = Q6_Vsf_equals_Vqf32(output_v); + + output_v = Q6_Vqf32_vmpy_VsfVsf(output_v, sline2); + output_v = Q6_Vqf32_vadd_Vqf32Vsf(output_v, Q6_V_lo_W(c0_coeff_vp)); + + /* Multiply result by same normalization factor applied to input earlier */ + output_v = Q6_Vqf32_vmpy_Vqf32Vqf32(output_v, norm_factor); + + /* Calculate num * 1/den */ + output_v = Q6_Vqf32_vmpy_Vqf32Vqf32(output_v, sline1); + + return Q6_Vsf_equals_Vqf32(output_v); +} + +inline HVX_VectorPair_x4 qhmath_load_div_hf_ltu() { + /* Coefficients in float representation */ + constexpr const float c0_coeffs[32] __attribute__((aligned(hexagon::kBytesPerVector))) = { + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 3.8807721943716516, + 3.6618209528616856, + 3.4657742282097708, + 3.2853461610022414, + 3.1229570908314015, + 2.976379865829892, + 2.8438614274889833, + 2.723793061029549, + 2.613859154046634, + 2.5119508509784287, + 2.4167270706641473, + 2.3286721812015188, + 2.2462659531748064, + 2.1692490555028736, + 2.0981551828382417, + 2.0319234960945, + }; + constexpr const float c1_coeffs[32] __attribute__((aligned(hexagon::kBytesPerVector))) = { + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + -5.646783581176797, + -5.027704168781284, + -4.5037889029173535, + -4.0470997487793445, + -3.6569569537789364, + -3.3217563552211695, + -3.03258650196419, + -2.781935505534812, + -2.5619261358961922, + -2.3660577978107398, + -2.190083163030879, + -2.033405493468989, + -1.8920413948588666, + -1.7645298754188785, + -1.6507730169513504, + -1.5482028127706613, + }; + constexpr const float c2_coeffs[32] __attribute__((aligned(hexagon::kBytesPerVector))) = { + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 
0.0, + 0.0, + 0.0, + 0.0, + 3.6511964773849632, + 3.0676375988553106, + 2.6008750952258324, + 2.215514199159397, + 1.9030391013295935, + 1.6474963735373633, + 1.4371447652517673, + 1.2627141904289978, + 1.11593649827749, + 0.9904415490260164, + 0.882033772823834, + 0.7891019704346331, + 0.7082630629776306, + 0.6378888508693012, + 0.5772121720355701, + 0.524261196551401, + }; + constexpr const float c3_coeffs[32] __attribute__((aligned(hexagon::kBytesPerVector))) = { + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + -0.8851851956149304, + -0.7018008948429424, + -0.5631686602024177, + -0.4547647803673564, + -0.37133287830029976, + -0.3063883382130307, + -0.255378412302572, + -0.2149126167280633, + -0.18226975346347984, + -0.15546600267845986, + -0.13320337246909697, + -0.11482846255803722, + -0.0994184164975366, + -0.08647114157420362, + -0.07568254923048714, + -0.06657033258736733, + }; + + /* Load coefficients */ + HVX_Vector c0_coeff_v = *((HVX_Vector *) (c0_coeffs)); + HVX_Vector c1_coeff_v = *((HVX_Vector *) (c1_coeffs)); + HVX_Vector c2_coeff_v = *((HVX_Vector *) (c2_coeffs)); + HVX_Vector c3_coeff_v = *((HVX_Vector *) (c3_coeffs)); + + /* Convert coefficients from hf to qf32 format. Use the same vector for both representations */ + HVX_Vector zero_v_hf = Q6_V_vzero(); + c0_coeff_v = Q6_Vqf32_vadd_VsfVsf(c0_coeff_v, zero_v_hf); + c1_coeff_v = Q6_Vqf32_vadd_VsfVsf(c1_coeff_v, zero_v_hf); + c2_coeff_v = Q6_Vqf32_vadd_VsfVsf(c2_coeff_v, zero_v_hf); + c3_coeff_v = Q6_Vqf32_vadd_VsfVsf(c3_coeff_v, zero_v_hf); + + /* Split 32-bit coefficients to lower and upper part in order to obtain them later with VLUT16. */ + hexagon::HVX_VectorPair_x4 result; + result.val[0] = Q6_Wuw_vzxt_Vuh(c0_coeff_v); + result.val[1] = Q6_Wuw_vzxt_Vuh(c1_coeff_v); + result.val[2] = Q6_Wuw_vzxt_Vuh(c2_coeff_v); + result.val[3] = Q6_Wuw_vzxt_Vuh(c3_coeff_v); + + return result; +} + +inline HVX_Vector qhmath_hvx_div_vhf(HVX_Vector num, HVX_Vector denom, HVX_VectorPair_x4 coeffs) { + HVX_Vector sline2; + HVX_Vector norm_factor; + HVX_VectorPair norm_factor_qf32; + HVX_Vector tmp_v; + HVX_Vector idx1_v; + HVX_Vector idx2_v; + HVX_DV output_dv; + HVX_Vector input_shifted_v_hf; + HVX_Vector input_scaled_v; + HVX_VectorPair input_vp_qf32; + HVX_VectorPair input_n_vp_qf32; + HVX_VectorPair c0_coeff_vp; + HVX_VectorPair c1_coeff_vp; + HVX_VectorPair c2_coeff_vp; + HVX_VectorPair c3_coeff_vp; + + /* + * Splat scale factor in order to be used later for finding indexes of coefficients. + * Scale factor is represented in IEEE 16-bit floating-point format and it is + * calculated using the following formula: + * scale_factor = (convert_sf_to_hf) (16.0 / (b0 - a0)) + * NOTE: Calculated value is slightly decreased in order to avoid out of bound + * indexes during VLUT lookup. + */ + HVX_Vector scale_v = Q6_Vh_vsplat_R(0x4bfb); + + /* Vector of ones used as mpy neutral element in conversions from hf vector to qf32 vector pair */ + HVX_Vector one_v_hf = Q6_Vh_vsplat_R(0x3c00); + + /* + * Vector of zeroes used as neutral element in hf to qf16 conversions. + * NOTE: Some of conversions (i.e conversion of scale factor and coefficients) + * can be avoided in real-time, but this is not done in order to don't + * sacrify code readibility in expense of insignificant performance improvement. 
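+ * (This add-zero conversion is the same one wrapped by qhmath_hvx_vqf16_convert_vhf and
+ * qhmath_hvx_vqf32_convert_vsf later in this file.)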
+ */ + HVX_Vector zero_v_hf = Q6_V_vzero(); + + /* Set sign = 0, exp = 30, mant = 0 */ + HVX_Vector exp = Q6_Vh_vsplat_R(0x7800); + + /* Set mask for sign and exponent */ + HVX_Vector signexp_mask = Q6_Vh_vsplat_R(0xFC00); + + /* Mask for extracting only 4 bits of mantissa */ + HVX_Vector mask_idx1_v = Q6_Vh_vsplat_R(0x000F); + HVX_Vector mask_idx2_v = Q6_V_vsplat_R(0x00001010); + + /* 16.0 in IEEE 16-bit floating-point representation */ + HVX_Vector const16_0_v_hf = Q6_Vh_vsplat_R(0x4c00); + + /* + * Prepare vector of input_min values, that is used later in shifting input range. + * input_min is low boundary of specified input range. + */ + HVX_Vector input_min_v_hf = Q6_Vh_vsplat_R(0x3c00); + + /* Convert scale factor from hf to q16. Use the same vector for both formats */ + scale_v = Q6_Vqf16_vadd_VhfVhf(scale_v, zero_v_hf); + + /* Calculate normalization factor */ + norm_factor = Q6_V_vand_VV(denom, signexp_mask); + norm_factor = Q6_Vh_vsub_VhVh(exp, norm_factor); + + /* Normalize denominators */ + sline2 = Q6_Vqf16_vmpy_VhfVhf(denom, norm_factor); + + /* Convert normalization factor to qf32 */ + norm_factor_qf32 = Q6_Wqf32_vmpy_VhfVhf(norm_factor, one_v_hf); + + /* Shift input range from [input_min, input_max] to [0, input_max - input_min] */ + tmp_v = Q6_Vh_vdeal_Vh(sline2); + input_shifted_v_hf = Q6_Vqf16_vsub_Vqf16Vhf(tmp_v, input_min_v_hf); + + /* + * Scale shifted input range from [0, input_max - input_min] to [0,16.0) + * in order to get corresponding coefficient indexes + */ + input_scaled_v = Q6_Vqf16_vmpy_Vqf16Vqf16(input_shifted_v_hf, scale_v); + + /* + * VLUT 16 requires integer indexes. Shift scaled input range from [0,16.0) + * to [16.0,32.0) in order to convert float indexes to integer values. + * Float values, represented in IEEE 754, in range [16.0,32.0] have the + * same exponent, which means 4 MSB of mantissa carry information about + * integer index. 
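+ * For example, with the tables above a normalized denominator of 1.0 uses entries 16 of
+ * c0_coeffs..c3_coeffs, whose cubic evaluates to ~1.0 (= 1/1.0), while values near 2.0 use
+ * entries 31, whose cubic evaluates to ~0.5 (= 1/2.0); entries 0..15 are zero padding, since
+ * the lookup indexes always land in 16..31.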
+ */ + /* Use the same input_scaled_v vector for hf and qf16 representation */ + input_scaled_v = Q6_Vqf16_vadd_Vqf16Vhf(input_scaled_v, const16_0_v_hf); + + /* Convert back from qf16 to hf in order to extract integer index */ + tmp_v = Q6_Vhf_equals_Vqf16(input_scaled_v); + + /* Only 4 MSB bits of mantissa represent segment index */ + idx1_v = Q6_Vuh_vlsr_VuhR(tmp_v, 6); + + /* Ensure only 4 MSB bits of mantissa are used as indexes */ + idx1_v = Q6_V_vand_VV(idx1_v, mask_idx1_v); + + idx1_v = Q6_Vb_vshuff_Vb(idx1_v); + idx1_v = Q6_V_vor_VV(idx1_v, mask_idx2_v); + idx2_v = Q6_Vw_vasl_VwR(idx1_v, 16); + + /* Obtain the polynomial coefficients from lookup table */ + c0_coeff_vp = Q6_Wh_vlut16_VbVhR(idx1_v, Q6_V_lo_W(coeffs.val[0]), 1); + c0_coeff_vp = Q6_Wh_vlut16or_WhVbVhR(c0_coeff_vp, idx2_v, Q6_V_hi_W(coeffs.val[0]), 1); + c1_coeff_vp = Q6_Wh_vlut16_VbVhR(idx1_v, Q6_V_lo_W(coeffs.val[1]), 1); + c1_coeff_vp = Q6_Wh_vlut16or_WhVbVhR(c1_coeff_vp, idx2_v, Q6_V_hi_W(coeffs.val[1]), 1); + c2_coeff_vp = Q6_Wh_vlut16_VbVhR(idx1_v, Q6_V_lo_W(coeffs.val[2]), 1); + c2_coeff_vp = Q6_Wh_vlut16or_WhVbVhR(c2_coeff_vp, idx2_v, Q6_V_hi_W(coeffs.val[2]), 1); + c3_coeff_vp = Q6_Wh_vlut16_VbVhR(idx1_v, Q6_V_lo_W(coeffs.val[3]), 1); + c3_coeff_vp = Q6_Wh_vlut16or_WhVbVhR(c3_coeff_vp, idx2_v, Q6_V_hi_W(coeffs.val[3]), 1); + + /* Convert inputs from hf vector to qf32 vector pair for Horner's method*/ + input_vp_qf32 = Q6_Wqf32_vmpy_Vqf16Vhf(sline2, one_v_hf); + input_n_vp_qf32 = Q6_Wqf32_vmpy_VhfVhf(num, one_v_hf); + + /* Perform evaluation of polynomial using Horner's method */ + output_dv.V.lo = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(c3_coeff_vp), Q6_V_lo_W(input_vp_qf32)); + output_dv.V.lo = Q6_Vqf32_vadd_Vqf32Vqf32(output_dv.V.lo, Q6_V_lo_W(c2_coeff_vp)); + output_dv.V.lo = Q6_Vqf32_vmpy_Vqf32Vqf32(output_dv.V.lo, Q6_V_lo_W(input_vp_qf32)); + output_dv.V.lo = Q6_Vqf32_vadd_Vqf32Vqf32(output_dv.V.lo, Q6_V_lo_W(c1_coeff_vp)); + output_dv.V.lo = Q6_Vqf32_vmpy_Vqf32Vqf32(output_dv.V.lo, Q6_V_lo_W(input_vp_qf32)); + output_dv.V.lo = Q6_Vqf32_vadd_Vqf32Vqf32(output_dv.V.lo, Q6_V_lo_W(c0_coeff_vp)); + + output_dv.V.hi = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_hi_W(c3_coeff_vp), Q6_V_hi_W(input_vp_qf32)); + output_dv.V.hi = Q6_Vqf32_vadd_Vqf32Vqf32(output_dv.V.hi, Q6_V_hi_W(c2_coeff_vp)); + output_dv.V.hi = Q6_Vqf32_vmpy_Vqf32Vqf32(output_dv.V.hi, Q6_V_hi_W(input_vp_qf32)); + output_dv.V.hi = Q6_Vqf32_vadd_Vqf32Vqf32(output_dv.V.hi, Q6_V_hi_W(c1_coeff_vp)); + output_dv.V.hi = Q6_Vqf32_vmpy_Vqf32Vqf32(output_dv.V.hi, Q6_V_hi_W(input_vp_qf32)); + output_dv.V.hi = Q6_Vqf32_vadd_Vqf32Vqf32(output_dv.V.hi, Q6_V_hi_W(c0_coeff_vp)); + + /* Multiply result by same normalization factor applied to input earlier */ + output_dv.V.lo = Q6_Vqf32_vmpy_Vqf32Vqf32(output_dv.V.lo, Q6_V_lo_W(norm_factor_qf32)); + output_dv.V.hi = Q6_Vqf32_vmpy_Vqf32Vqf32(output_dv.V.hi, Q6_V_hi_W(norm_factor_qf32)); + + /* Calculate num * 1/den */ + output_dv.V.lo = Q6_Vqf32_vmpy_Vqf32Vqf32(output_dv.V.lo, Q6_V_lo_W(input_n_vp_qf32)); + output_dv.V.hi = Q6_Vqf32_vmpy_Vqf32Vqf32(output_dv.V.hi, Q6_V_hi_W(input_n_vp_qf32)); + + return Q6_Vhf_equals_Wqf32(output_dv.VV); +} + +/* + * This function converts a vector of IEEE float elements to a vector of qf32 elements + * See also: libs\qfe\inc\qhmath_hvx_convert.h + */ +inline HVX_Vector qhmath_hvx_vqf32_convert_vsf(HVX_Vector vin) { + return Q6_Vqf32_vadd_VsfVsf(vin, Q6_V_vzero()); +} + +/* + * This function converts a vector of IEEE half float elements to a vector of qf16 elements + * See also: 
libs\qfe\inc\qhmath_hvx_convert.h + */ +inline HVX_Vector qhmath_hvx_vqf16_convert_vhf(HVX_Vector vin) { + return Q6_Vqf16_vadd_VhfVhf(vin, Q6_V_vzero()); +} + +/* + * This function converts a pair of vectors of qf32 elements to a vector of IEEE half float elements + * See also: libs\qfe\inc\qhmath_hvx_convert.h + */ +inline HVX_Vector qhmath_hvx_vhf_convert_vqf32(HVX_VectorPair vin_vp) { + return Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(vin_vp)); +} + +/* + * This function converts a vector of qf16 elements to a pair of vectors of qf32 elements + * See also: libs\qfe\inc\qhmath_hvx_convert.h + */ +inline HVX_VectorPair qhmath_hvx_vqf32_convert_vqf16(HVX_Vector vxl) { + HVX_VectorPair vxw_vp, exponent_vp; + HVX_Vector mantissa_mask = Q6_Vh_vsplat_R(0xffe0); + HVX_Vector exp_mask = Q6_Vh_vsplat_R(0x1f); + HVX_Vector exp_offset = Q6_Vh_vsplat_R(0x70); + HVX_Vector mant32_shift = Q6_Vh_vsplat_R(0x10); + HVX_Vector reql, reqh, vxl_w, vxh_w, mantissa; + HVX_Vector el_exponent, eh_exponent; + + el_exponent = Q6_V_vand_VV(exp_mask, vxl); + // Obtain the mantissa part: bits (5-15) + mantissa = Q6_V_vand_VV(mantissa_mask, vxl); + // Convert the qf16 biased exponent to a qf32 biased exponent + // new exp = exp + (127 (qf32 bias) - 15 (qf16 bias)) = exp + 112 + el_exponent = Q6_Vh_vadd_VhVh(exp_offset, el_exponent); + + vxw_vp = Q6_Ww_vunpack_Vh(mantissa); + vxl_w = Q6_V_lo_W(vxw_vp); + vxh_w = Q6_V_hi_W(vxw_vp); + + exponent_vp = Q6_Ww_vunpack_Vh(el_exponent); + el_exponent = Q6_V_lo_W(exponent_vp); + eh_exponent = Q6_V_hi_W(exponent_vp); + // Convert the qf16 mantissa to a qf32 mantissa + reql = Q6_Vw_vasl_VwVw(vxl_w, mant32_shift); + reqh = Q6_Vw_vasl_VwVw(vxh_w, mant32_shift); + // Add the exponent + vxl_w = Q6_Vw_vadd_VwVw(reql, el_exponent); + vxh_w = Q6_Vw_vadd_VwVw(reqh, eh_exponent); + + return Q6_W_vcombine_VV(vxh_w, vxl_w); +} + +inline HVX_VectorPair hvx_vqf32_convert_vhf(HVX_Vector vxl) { + return qhmath_hvx_vqf32_convert_vqf16(qhmath_hvx_vqf16_convert_vhf(vxl)); +} + +inline HVX_Vector_x2 hvx_vsf_convert_vhf(HVX_Vector vxl, HVX_Vector one) { + HVX_VectorPair res = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vxl), one); + return { + Q6_Vsf_equals_Vqf32(Q6_V_lo_W(res)), + Q6_Vsf_equals_Vqf32(Q6_V_hi_W(res)), + }; +} + +} // namespace hexagon::vec::math diff --git a/ggml/src/ggml-qnn/npu/device/vec_ops.hpp b/ggml/src/ggml-qnn/npu/device/vec_ops.hpp index 051255c9b7..92cb8ed999 100644 --- a/ggml/src/ggml-qnn/npu/device/vec_ops.hpp +++ b/ggml/src/ggml-qnn/npu/device/vec_ops.hpp @@ -1,13 +1,29 @@ #pragma once +#include "hexagon_npu.h" + #include #include -#include "hexagon_npu.h" - namespace hexagon { +template struct HEXAGON_pack { + T val[N]; +}; + +using HVX_Vector_x2 = std::pair; +using HVX_VectorPair_x4 = HEXAGON_pack; + +typedef union { + HVX_VectorPair VV; + + struct { + HVX_Vector lo; + HVX_Vector hi; + } V; +} HVX_DV; + constexpr const size_t kBytesPerVector = sizeof(HVX_Vector); // 128 for v73 constexpr const size_t kAlignMask = kBytesPerVector - 1; @@ -63,67 +79,6 @@ inline void l2fetch_row(const uint8_t * row_ptr, size_t bytes) { hexagon::l2fetch(row_ptr, kBytesPerVector, kBytesPerVector, l2fetch_vectors, 0); } -/* - * This function converts a vector of IEEE float elements to a vector of qf32 elements - * See also: libs\qfe\inc\qhmath_hvx_convert.h - */ -inline HVX_Vector qhmath_hvx_vqf32_convert_vsf(HVX_Vector vin) { - return Q6_Vqf32_vadd_VsfVsf(vin, Q6_V_vzero()); -} - -/* - * This function converts a vector of IEEE half float elements to a vector of qf16 elements - * See also:
libs\qfe\inc\qhmath_hvx_convert.h - */ -inline HVX_Vector qhmath_hvx_vqf16_convert_vhf(HVX_Vector vin) { - return Q6_Vqf16_vadd_VhfVhf(vin, Q6_V_vzero()); -} - -/* - * This function converts a pair of vectors of qf32 elements to a vector of IEEE half float elements - * See also: libs\qfe\inc\qhmath_hvx_convert.h - */ -inline HVX_Vector qhmath_hvx_vhf_convert_vqf32(HVX_VectorPair vin_vp) { - return Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(vin_vp)); -} - -/* - * This function converts a vector of qf16 elements to a pair of vectors of qf32 elements - * See also: libs\qfe\inc\qhmath_hvx_convert.h - */ -inline HVX_VectorPair qhmath_hvx_vqf32_convert_vqf16(HVX_Vector vxl) { - HVX_VectorPair vxw_vp, exponent_vp; - HVX_Vector mantissa_mask = Q6_Vh_vsplat_R(0xffe0); - HVX_Vector exp_mask = Q6_Vh_vsplat_R(0x1f); - HVX_Vector exp_offset = Q6_Vh_vsplat_R(0x70); - HVX_Vector mant32_shift = Q6_Vh_vsplat_R(0x10); - HVX_Vector reql, reqh, vxl_w, vxh_w, mantissa; - HVX_Vector el_exponent, eh_exponent; - - el_exponent = Q6_V_vand_VV(exp_mask, vxl); - // Obtain the mantissa part: bits (5-15) - mantissa = Q6_V_vand_VV(mantissa_mask, vxl); - // Convert qf16 biassed exponent to qf32 biased exponent - // new exp = exp + ( 127 (qf32 bias) -15(qf16 biass) ) = 112 - el_exponent = Q6_Vh_vadd_VhVh(exp_offset, el_exponent); - - vxw_vp = Q6_Ww_vunpack_Vh(mantissa); - vxl_w = Q6_V_lo_W(vxw_vp); - vxh_w = Q6_V_hi_W(vxw_vp); - - exponent_vp = Q6_Ww_vunpack_Vh(el_exponent); - el_exponent = Q6_V_lo_W(exponent_vp); - eh_exponent = Q6_V_hi_W(exponent_vp); - // Convert q16 mantiss to q32 mantissa - reql = Q6_Vw_vasl_VwVw(vxl_w, mant32_shift); - reqh = Q6_Vw_vasl_VwVw(vxh_w, mant32_shift); - // Add the exponent - vxl_w = Q6_Vw_vadd_VwVw(reql, el_exponent); - vxh_w = Q6_Vw_vadd_VwVw(reqh, eh_exponent); - - return Q6_W_vcombine_VV(vxh_w, vxl_w); -} - template inline void q6op_vstu_variable_ARV(void * addr, HVX_Vector vin) { vin = Q6_V_vlalign_VVR(vin, vin, (size_t) addr); //rotate as needed. 
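    // The rotate lines the payload up with the unaligned destination: vlalign by the low
    // bits of addr moves byte 0 of vin into the lane matching addr's offset inside its
    // 128-byte aligned vector, so the masked stores below (e.g. the qL_not-predicated
    // store) write only the bytes that actually belong to the destination range.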
uint32_t left_off = unaligned_bytes(addr); @@ -157,20 +112,6 @@ inline void q6op_vstu_variable_ARV(void * addr, int n, HVX_Vector vin) { Q6_vmaskedstorenq_QAV(qL_not, (HVX_Vector *) addr, vin); } -inline HVX_VectorPair hvx_vqf32_convert_vhf(HVX_Vector vxl) { - return qhmath_hvx_vqf32_convert_vqf16(qhmath_hvx_vqf16_convert_vhf(vxl)); -} - -using HVX_Vector_Dual = std::pair; - -inline HVX_Vector_Dual hvx_vsf_convert_vhf(HVX_Vector vxl, HVX_Vector one) { - HVX_VectorPair res = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vxl), one); - return { - Q6_Vsf_equals_Vqf32(Q6_V_lo_W(res)), - Q6_Vsf_equals_Vqf32(Q6_V_hi_W(res)), - }; -} - inline HVX_Vector vec_reduction_qf32(HVX_Vector sums) { constexpr const size_t kFloatsPerVector = hexagon::kBytesPerVector / sizeof(float); static_assert(kFloatsPerVector == 32, "kFloatsPerVector should be 32"); @@ -247,6 +188,7 @@ inline HVX_Vector hvx_passthru(HVX_Vector src, HVX_UVector *, HVX_Vector) { } // namespace hexagon +#include "vec_math.inl" #include "vec_ops.inl" namespace hexagon { @@ -313,8 +255,8 @@ inline HVX_Vector vec_dot_product_vqf32_f32_f32(const float * src0, const float inline HVX_Vector vec_dot_product_aligned_vqf32_f32_f32(const float * src0, const float * src1, size_t count) { using namespace hexagon::vec; - return vec_dot_product_aligned_impl(src0, src1, - count); + return vec_dot_product_aligned_impl( + src0, src1, count); } inline float vec_dot_product_f32_f32(const float * src0, const float * src1, size_t count) { @@ -324,23 +266,25 @@ inline float vec_dot_product_f32_f32(const float * src0, const float * src1, siz inline float vec_dot_product_aligned_f32_f32(const float * src0, const float * src1, size_t count) { using namespace hexagon::vec; - return vec_dot_product_aligned_impl(src0, src1, - count); + return vec_dot_product_aligned_impl( + src0, src1, count); } inline bool is_f32_f32_dot_product_aligned(const float * src0, const float * src1, size_t count) { return is_dot_product_aligned(src0, src1, count); } -inline HVX_Vector vec_dot_product_vqf16_f16_f16(const npu_device_fp16_t * src0, const npu_device_fp16_t * src1, - size_t count) { +inline HVX_Vector vec_dot_product_vqf16_f16_f16(const npu_device_fp16_t * src0, + const npu_device_fp16_t * src1, + size_t count) { using namespace hexagon::vec; return vec_dot_product_impl( src0, src1, count); } -inline HVX_Vector vec_dot_product_aligned_vqf16_f16_f16(const npu_device_fp16_t * src0, const npu_device_fp16_t * src1, - size_t count) { +inline HVX_Vector vec_dot_product_aligned_vqf16_f16_f16(const npu_device_fp16_t * src0, + const npu_device_fp16_t * src1, + size_t count) { using namespace hexagon::vec; return vec_dot_product_aligned_impl( src0, src1, count); @@ -352,41 +296,68 @@ inline float vec_dot_product_f16_f16(const npu_device_fp16_t * src0, const npu_d src0, src1, count); } -inline float vec_dot_product_aligned_f16_f16(const npu_device_fp16_t * src0, const npu_device_fp16_t * src1, - size_t count) { +inline float vec_dot_product_aligned_f16_f16(const npu_device_fp16_t * src0, + const npu_device_fp16_t * src1, + size_t count) { using namespace hexagon::vec; return vec_dot_product_aligned_impl( src0, src1, count); } -inline bool is_f16_f16_dot_product_aligned(const npu_device_fp16_t * src0, const npu_device_fp16_t * src1, - size_t count) { +inline bool is_f16_f16_dot_product_aligned(const npu_device_fp16_t * src0, + const npu_device_fp16_t * src1, + size_t count) { return is_dot_product_aligned(src0, src1, count); } inline HVX_Vector vec_dot_product_vqf32_f16_f32(const npu_device_fp16_t 
* src0, const float * src1, size_t count) { using namespace hexagon::vec; - return vec_dot_product_mixed_impl(src0, src1, count); + using namespace hexagon::vec::math; + return vec_dot_product_mixed_impl(src0, src1, count); } -inline HVX_Vector vec_dot_product_aligned_vqf32_f16_f32(const npu_device_fp16_t * src0, const float * src1, - size_t count) { +inline HVX_Vector vec_dot_product_aligned_vqf32_f16_f32(const npu_device_fp16_t * src0, + const float * src1, + size_t count) { using namespace hexagon::vec; - return vec_dot_product_mix_aligned_impl(src0, src1, count); + using namespace hexagon::vec::math; + return vec_dot_product_mix_aligned_impl(src0, src1, count); } inline float vec_dot_product_f16_f32(const npu_device_fp16_t * src0, const float * src1, size_t count) { using namespace hexagon::vec; - return vec_dot_product_mixed_impl(src0, src1, count); } inline float vec_dot_product_aligned_f16_f32(const npu_device_fp16_t * src0, const float * src1, size_t count) { using namespace hexagon::vec; - return vec_dot_product_mix_aligned_impl(src0, src1, count); + using namespace hexagon::vec::math; + return vec_dot_product_mix_aligned_impl(src0, src1, count); } inline bool is_f16_f32_dot_product_aligned(const npu_device_fp16_t * src0, const float * src1, size_t count) { @@ -409,4 +380,35 @@ _TReturn type_erase_dot_func(const void * src0, const void * src1, size_t count) return _DotFunc(src0_typed, src1_typed, count); } +inline HVX_Vector vec_silu_f32_f32(HVX_Vector x, HVX_VectorPair_x4 coeff) { + using namespace hexagon::vec::math; + + HVX_Vector one = Q6_V_vsplat_R(0x3F800000); + + // x/(1.0f + expf(-x)); + HVX_Vector exp_neg_x = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(Q6_V_vzero(), x)); + HVX_Vector denom = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(qhmath_hvx_exp_vf(exp_neg_x), one)); + return qhmath_hvx_div_vf(x, denom, coeff); +} + +inline HVX_Vector vec_silu_f16_f16(HVX_Vector x, HVX_VectorPair_x4 coeff) { + using namespace hexagon::vec::math; + HVX_Vector one = Q6_Vh_vsplat_R(0x3c00); + + // x/(1.0f + expf(-x)); + HVX_Vector exp_neg_x = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vsub_VhfVhf(Q6_V_vzero(), x)); + HVX_Vector denom = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_VhfVhf(qhmath_hvx_exp_vhf(exp_neg_x), one)); + return qhmath_hvx_div_vhf(x, denom, coeff); +} + +inline HVX_Vector vec_swiglu_f32_f32(HVX_Vector x, HVX_Vector g, HVX_VectorPair_x4 coeff) { + HVX_Vector silu = vec_silu_f32_f32(x, coeff); + return Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(silu, g)); +} + +inline HVX_Vector vec_swiglu_f16_f16(HVX_Vector x, HVX_Vector g, HVX_VectorPair_x4 coeff) { + HVX_Vector silu = vec_silu_f16_f16(x, coeff); + return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(silu, g)); +} + } // namespace hexagon diff --git a/ggml/src/ggml-qnn/npu/device/vec_ops.inl b/ggml/src/ggml-qnn/npu/device/vec_ops.inl index f21d6b06d9..854d975edb 100644 --- a/ggml/src/ggml-qnn/npu/device/vec_ops.inl +++ b/ggml/src/ggml-qnn/npu/device/vec_ops.inl @@ -1,15 +1,18 @@ #pragma once +#include "hexagon_npu.h" + #include #include -#include "hexagon_npu.h" - namespace hexagon::vec { -template +template inline _TRet vec_dot_product_impl(const _TElem * src0, const _TElem * src1, size_t count) { constexpr const size_t kElementsPerVector = hexagon::kBytesPerVector / sizeof(_TElem); @@ -95,8 +98,11 @@ inline _TRet vec_dot_product_impl(const _TElem * src0, const _TElem * src1, size return _ReduceFunc(sum); } -template +template inline _TRet vec_dot_product_aligned_impl(const _TElem * src0, const _TElem * src1, size_t count) { constexpr const size_t 
kElementsPerVector = hexagon::kBytesPerVector / sizeof(_TElem); @@ -170,8 +176,12 @@ inline HVX_Vector vec_add_qf16(HVX_Vector sum, HVX_Vector result) { return Q6_Vqf16_vadd_Vqf16Vqf16(sum, result); } -template inline _TRet vec_dot_product_mixed_impl(const _TElem0 * src0, const _TElem1 * src1, size_t count) { static_assert(sizeof(_TElem0) < sizeof(_TElem1), "Element size mismatch: _TElem0 must be smaller than _TElem1"); @@ -202,8 +212,8 @@ inline _TRet vec_dot_product_mixed_impl(const _TElem0 * src0, const _TElem1 * sr HVX_Vector curr0 = src0_vec_ptr[0]; HVX_VectorPair curr1 = reinterpret_cast(src1_vec_ptr)[0]; - HVX_Vector s0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0); - HVX_Vector_Dual s0_pair = _ExpandFunc(s0, kOneV); + HVX_Vector s0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0); + HVX_Vector_x2 s0_pair = _ExpandFunc(s0, kOneV); HVX_Vector l1 = Q6_V_valign_VVR(Q6_V_lo_W(curr1), prev1, (size_t) src1); sum0 = _AddFunc(_MpyFunc(s0_pair.first, l1), sum0); @@ -228,8 +238,8 @@ inline _TRet vec_dot_product_mixed_impl(const _TElem0 * src0, const _TElem1 * sr HVX_Vector curr0 = should_fetch_src0 ? *src0_vec_ptr : prev0; src0_vec_ptr += should_fetch_src0 ? 1 : 0; - HVX_Vector s0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0); - HVX_Vector_Dual s0_pair = _ExpandFunc(s0, kOneV); + HVX_Vector s0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0); + HVX_Vector_x2 s0_pair = _ExpandFunc(s0, kOneV); const bool has_remaining_src1_vector = src1_vec_ptr_end - src1_vec_ptr > 0; if (has_remaining_src1_vector) { @@ -262,7 +272,7 @@ inline _TRet vec_dot_product_mixed_impl(const _TElem0 * src0, const _TElem1 * sr prev1; curr1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1); - HVX_Vector_Dual curr0_pair = _ExpandFunc(curr0, kOneV); + HVX_Vector_x2 curr0_pair = _ExpandFunc(curr0, kOneV); curr0 = leftover1 == leftover0 ? 
curr0_pair.first : curr0_pair.second; sum = _AddFunc(Q6_V_valign_VVR(_MpyFunc(curr0, curr1), Q6_V_vzero(), leftover_bytes1), sum); @@ -271,8 +281,12 @@ inline _TRet vec_dot_product_mixed_impl(const _TElem0 * src0, const _TElem1 * sr return _ReduceFunc(sum); } -template inline _TRet vec_dot_product_mix_aligned_impl(const _TElem0 * src0, const _TElem1 * src1, size_t count) { static_assert(sizeof(_TElem0) < sizeof(_TElem1), "Element size mismatch: _TElem0 must be smaller than _TElem1"); @@ -297,16 +311,16 @@ inline _TRet vec_dot_product_mix_aligned_impl(const _TElem0 * src0, const _TElem HVX_Vector sum3 = Q6_V_vzero(); do { - HVX_VectorPair curr0 = reinterpret_cast(src0_vec_ptr)[0]; - HVX_Vector_Dual curr00 = _ExpandFunc(Q6_V_lo_W(curr0), kOneV); - HVX_VectorPair curr10 = reinterpret_cast(src1_vec_ptr)[0]; - sum0 = _AddFunc(_MpyFunc(curr00.first, Q6_V_lo_W(curr10)), sum0); - sum1 = _AddFunc(_MpyFunc(curr00.second, Q6_V_hi_W(curr10)), sum1); + HVX_VectorPair curr0 = reinterpret_cast(src0_vec_ptr)[0]; + HVX_Vector_x2 curr00 = _ExpandFunc(Q6_V_lo_W(curr0), kOneV); + HVX_VectorPair curr10 = reinterpret_cast(src1_vec_ptr)[0]; + sum0 = _AddFunc(_MpyFunc(curr00.first, Q6_V_lo_W(curr10)), sum0); + sum1 = _AddFunc(_MpyFunc(curr00.second, Q6_V_hi_W(curr10)), sum1); - HVX_Vector_Dual curr01 = _ExpandFunc(Q6_V_hi_W(curr0), kOneV); - HVX_VectorPair curr11 = reinterpret_cast(src1_vec_ptr)[1]; - sum2 = _AddFunc(_MpyFunc(curr01.first, Q6_V_lo_W(curr11)), sum2); - sum3 = _AddFunc(_MpyFunc(curr01.second, Q6_V_hi_W(curr11)), sum3); + HVX_Vector_x2 curr01 = _ExpandFunc(Q6_V_hi_W(curr0), kOneV); + HVX_VectorPair curr11 = reinterpret_cast(src1_vec_ptr)[1]; + sum2 = _AddFunc(_MpyFunc(curr01.first, Q6_V_lo_W(curr11)), sum2); + sum3 = _AddFunc(_MpyFunc(curr01.second, Q6_V_hi_W(curr11)), sum3); src0_vec_ptr += 2; src1_vec_ptr += 4; @@ -317,8 +331,8 @@ inline _TRet vec_dot_product_mix_aligned_impl(const _TElem0 * src0, const _TElem } if (src1_vec_ptr_end - src1_vec_ptr > 1) { - HVX_Vector curr0 = src0_vec_ptr[0]; - HVX_Vector_Dual s0_pair = _ExpandFunc(curr0, kOneV); + HVX_Vector curr0 = src0_vec_ptr[0]; + HVX_Vector_x2 s0_pair = _ExpandFunc(curr0, kOneV); HVX_VectorPair curr1 = reinterpret_cast(src1_vec_ptr)[0]; sum0 = _AddFunc(_MpyFunc(s0_pair.first, Q6_V_lo_W(curr1)), sum0); @@ -328,7 +342,8 @@ inline _TRet vec_dot_product_mix_aligned_impl(const _TElem0 * src0, const _TElem return _ReduceFunc(_AddFunc(sum0, sum1)); } -template inline void vec_scale_impl(const _TParam * src, float scale, _TParam * dst, size_t count) { constexpr const size_t kElementsPerVector = hexagon::kBytesPerVector / sizeof(_TParam); @@ -410,7 +425,7 @@ template inline void vec_zero_impl(_TData * src, size_t count) } template -inline void vec_trans_op_impl(const _TyData * src0, const _TyData * src1, size_t count, _TyData * dst) { +inline void vec_trans_impl(const _TyData * src0, const _TyData * src1, _TyData * dst, size_t count) { constexpr const size_t kElementsPerVector = hexagon::kBytesPerVector / sizeof(_TyData); HVX_Vector * src0_vec_ptr = ((HVX_Vector *) src0); @@ -496,4 +511,95 @@ inline void vec_trans_op_impl(const _TyData * src0, const _TyData * src1, size_t } } +template +inline void vec_trans_with_param_impl(const _TyData * src0, + const _TyData * src1, + _TyData * dst, + size_t count, + _TyParam param) { + constexpr const size_t kElementsPerVector = hexagon::kBytesPerVector / sizeof(_TyData); + + HVX_Vector * src0_vec_ptr = ((HVX_Vector *) src0); + HVX_Vector * const src0_vec_ptr_end = ((HVX_Vector *) src0) + count / 
kElementsPerVector; + HVX_Vector * src1_vec_ptr = ((HVX_Vector *) src1); + HVX_Vector * dst_vec_ptr = ((HVX_Vector *) dst); // framework will ensure the dst is aligned + HVX_Vector prev0 = *src0_vec_ptr++; + HVX_Vector prev1 = *src1_vec_ptr++; + + { + while (src0_vec_ptr_end - src0_vec_ptr > 1) { + HVX_VectorPair curr0 = reinterpret_cast(src0_vec_ptr)[0]; + HVX_VectorPair curr1 = reinterpret_cast(src1_vec_ptr)[0]; + + HVX_Vector l0 = Q6_V_valign_VVR(Q6_V_lo_W(curr0), prev0, (size_t) src0); + HVX_Vector l1 = Q6_V_valign_VVR(Q6_V_lo_W(curr1), prev1, (size_t) src1); + dst_vec_ptr[0] = _OpBinaryTransform(l0, l1, param); + + HVX_Vector h0 = Q6_V_valign_VVR(Q6_V_hi_W(curr0), Q6_V_lo_W(curr0), (size_t) src0); + HVX_Vector h1 = Q6_V_valign_VVR(Q6_V_hi_W(curr1), Q6_V_lo_W(curr1), (size_t) src1); + dst_vec_ptr[1] = _OpBinaryTransform(h0, h1, param); + + prev0 = Q6_V_hi_W(curr0); + prev1 = Q6_V_hi_W(curr1); + src0_vec_ptr += 2; + src1_vec_ptr += 2; + dst_vec_ptr += 2; + } + } + + if (src0_vec_ptr_end - src0_vec_ptr > 0) { + HVX_Vector curr0 = *src0_vec_ptr++; + HVX_Vector s0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0); + + HVX_Vector curr1 = *src1_vec_ptr++; + HVX_Vector s1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1); + + dst_vec_ptr[0] = _OpBinaryTransform(s0, s1, param); + + prev0 = curr0; + prev1 = curr1; + dst_vec_ptr++; + } + + const size_t leftover = count % kElementsPerVector; + if ((src0_vec_ptr_end - ((HVX_Vector *) src0)) > 0) { + // handle the last vector + // see also: + // https://github.com/UbiquitousLearning/mllm/blob/babf4410352ce8730824c87699c025a0d4ce3a6f/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/LLaMAMul.cpp#L147 + // or qualcomm sdk libs\qhl_hvx\src\qhblas_hvx\qhblas_hvx_aw_vector_add_ah.c + bool should_fetch_src0 = leftover != 0 || !hexagon::is_addr_aligned(src0_vec_ptr); + bool should_fetch_src1 = leftover != 0 || !hexagon::is_addr_aligned(src1_vec_ptr); + + HVX_Vector curr0 = should_fetch_src0 ? *src0_vec_ptr : prev0; + HVX_Vector s0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0); + + HVX_Vector curr1 = should_fetch_src1 ? *src1_vec_ptr : prev1; + HVX_Vector s1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1); + + dst_vec_ptr[0] = _OpBinaryTransform(s0, s1, param); + + src0_vec_ptr += should_fetch_src0 ? 1 : 0; + src1_vec_ptr += should_fetch_src1 ? 1 : 0; + prev0 = curr0; + prev1 = curr1; + dst_vec_ptr++; + } + + if (leftover > 0) { + // handle the leftover elements + const size_t leftover_bytes = leftover * sizeof(_TyData); + HVX_Vector curr0 = (leftover_bytes + hexagon::unaligned_bytes(src0_vec_ptr) > hexagon::kBytesPerVector) ? + *src0_vec_ptr : + prev0; + curr0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0); + + HVX_Vector curr1 = (leftover_bytes + hexagon::unaligned_bytes(src1_vec_ptr) > hexagon::kBytesPerVector) ? 
+ *src1_vec_ptr : + prev1; + curr1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1); + + q6op_vstu_variable_ARV(dst_vec_ptr, leftover_bytes, _OpBinaryTransform(curr0, curr1, param)); + } +} + } // namespace hexagon::vec diff --git a/ggml/src/ggml-qnn/npu/device/vtcm_mem.hpp b/ggml/src/ggml-qnn/npu/device/vtcm_mem.hpp index b66ea7f348..9dd640d72f 100644 --- a/ggml/src/ggml-qnn/npu/device/vtcm_mem.hpp +++ b/ggml/src/ggml-qnn/npu/device/vtcm_mem.hpp @@ -1,10 +1,10 @@ #pragma once +#include "util.hpp" + #include #include -#include "util.hpp" - namespace hexagon { class vtcm_mem { diff --git a/ggml/src/ggml-qnn/npu/host/host_device.cpp b/ggml/src/ggml-qnn/npu/host/host_device.cpp index 7b9a13c82f..fca1167282 100644 --- a/ggml/src/ggml-qnn/npu/host/host_device.cpp +++ b/ggml/src/ggml-qnn/npu/host/host_device.cpp @@ -5,13 +5,13 @@ #include #pragma GCC diagnostic pop +#include "graph.hpp" +#include "util.hpp" + #include #include -#include "graph.hpp" -#include "util.hpp" - #define SKEL_URI_DEFINE(arch) ("file:///libhexagon_npu_skel_" arch ".so?npu_device_skel_handle_invoke&_modver=1.0") namespace { @@ -68,8 +68,10 @@ void backend_free(ggml_backend_t backend) { delete get_backend_object(backend); } -bool backend_cpy_tensor_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const ggml_tensor * src, - ggml_tensor * dst) { +bool backend_cpy_tensor_async(ggml_backend_t backend_src, + ggml_backend_t backend_dst, + const ggml_tensor * src, + ggml_tensor * dst) { // TODO: implement this return false; } @@ -194,13 +196,24 @@ bool npu_device::supports_op_impl(const ggml_tensor * op) { boolean supported = false; auto dst_spec = get_spec(op); - auto ret = npu_device_device_support_op(_device_handle, npu_op, &dst_spec, srcs, i, &supported); + + npu_device_tensor_op_spec npu_op_spec = { npu_op, {} }; + static_assert(sizeof(npu_op_spec.params) <= sizeof(op->op_params), + "npu_op_spec.params size should not exceed op->op_params size"); + memcpy(npu_op_spec.params, op->op_params, sizeof(npu_op_spec.params)); + auto ret = npu_device_device_support_op(_device_handle, &npu_op_spec, &dst_spec, srcs, i, &supported); if (ret != AEE_SUCCESS || !supported) { #ifndef NDEBUG auto * src0_type = i ? ggml_type_name(op->src[0]->type) : "null"; auto * src1_type = (i > 1) ?
ggml_type_name(op->src[1]->type) : "null"; - LOG_DEBUG("[%s][%s]unsupported %s(%s,%s), ret: 0x%x, supported: %d\n", get_name(), ggml_op_desc(op), - ggml_type_name(op->type), src0_type, src1_type, ret, supported); + LOG_DEBUG("[%s][%s]unsupported %s(%s,%s), ret: 0x%x, supported: %d\n", + get_name(), + ggml_op_desc(op), + ggml_type_name(op->type), + src0_type, + src1_type, + ret, + supported); #endif return false; } @@ -244,8 +257,8 @@ bool npu_device::init_device_lib() { } if (err != AEE_SUCCESS) { - LOG_ERROR("[%s]Unable to open NPU device, err: 0x%x, uri %s\n", get_name(), err, - device_lib_uri.c_str()); + LOG_ERROR( + "[%s]Unable to open NPU device, err: 0x%x, uri %s\n", get_name(), err, device_lib_uri.c_str()); _device_handle = 0; return false; } @@ -275,16 +288,26 @@ bool npu_device::supports_op(const ggml_tensor * op) { if (op->op != GGML_OP_NONE && op->op != GGML_OP_VIEW && op->op != GGML_OP_RESHAPE && op->op != GGML_OP_PERMUTE) { _supported_op++; - LOG_DEBUG("[%s][%s][%s]supported, %s, supported/unsupported: %u/%u\n", get_name(), ggml_op_desc(op), - ggml_get_name(op), op_desc, _supported_op.load(), _unsupported_op.load()); + LOG_DEBUG("[%s][%s][%s]supported, %s, supported/unsupported: %u/%u\n", + get_name(), + ggml_op_desc(op), + ggml_get_name(op), + op_desc, + _supported_op.load(), + _unsupported_op.load()); } return true; } _unsupported_op++; - LOG_DEBUG("[%s][%s][%s]unsupported, %s, supported/unsupported: %u/%u\n", get_name(), ggml_op_desc(op), - ggml_get_name(op), op_desc, _supported_op.load(), _unsupported_op.load()); + LOG_DEBUG("[%s][%s][%s]unsupported, %s, supported/unsupported: %u/%u\n", + get_name(), + ggml_op_desc(op), + ggml_get_name(op), + op_desc, + _supported_op.load(), + _unsupported_op.load()); return false; } #else diff --git a/ggml/src/ggml-qnn/npu/host/util.cpp b/ggml/src/ggml-qnn/npu/host/util.cpp index a07b4d0ed6..13a21c1f9e 100644 --- a/ggml/src/ggml-qnn/npu/host/util.cpp +++ b/ggml/src/ggml-qnn/npu/host/util.cpp @@ -35,6 +35,8 @@ enum npu_device_tensor_op op_to_npu_op(ggml_op op) { return NPU_OP_FLASH_ATTN; case GGML_OP_ROPE: return NPU_OP_ROPE; + case GGML_OP_GLU: + return NPU_OP_GLU; default: return NPU_OP_COUNT; } @@ -56,6 +58,8 @@ const char * get_npu_op_desc(enum npu_device_tensor_op op) { return ggml_op_name(GGML_OP_FLASH_ATTN_EXT); case NPU_OP_ROPE: return ggml_op_name(GGML_OP_ROPE); + case NPU_OP_GLU: + return ggml_op_name(GGML_OP_GLU); default: return "UNKNOWN"; } diff --git a/ggml/src/ggml-qnn/npu/idl/hexagon_npu.idl b/ggml/src/ggml-qnn/npu/idl/hexagon_npu.idl index 513b69d88a..1a9a4cb3a6 100644 --- a/ggml/src/ggml-qnn/npu/idl/hexagon_npu.idl +++ b/ggml/src/ggml-qnn/npu/idl/hexagon_npu.idl @@ -17,6 +17,7 @@ interface npu_device : remote_handle64{ typedef int64_t ne_type[DEVICE_TENSOR_MAX_DIMS]; typedef uint64_t nb_type[DEVICE_TENSOR_MAX_DIMS]; + typedef int32_t param_type[DEVICE_TENSOR_MAX_OP_PARAMS]; typedef uint64_t tensor_handle_t; typedef uint64_t graph_handle_t; @@ -50,9 +51,19 @@ interface npu_device : remote_handle64{ NPU_OP_RMS_NORM, NPU_OP_FLASH_ATTN, NPU_OP_ROPE, + NPU_OP_GLU, NPU_OP_COUNT }; + enum glu_op { + NPU_GLU_OP_REGLU, + NPU_GLU_OP_GEGLU, + NPU_GLU_OP_SWIGLU, + NPU_GLU_OP_GEGLU_ERF, + NPU_GLU_OP_GEGLU_QUICK, + NPU_GLU_OP_COUNT + }; + enum tensor_data_type { NPU_DATA_TYPE_F32, NPU_DATA_TYPE_F16, @@ -69,9 +80,14 @@ interface npu_device : remote_handle64{ tensor_data_type type; }; + struct tensor_op_spec { + tensor_op op; + param_type params; + }; + struct tensor_update_config { tensor_op op; - int32_t 
params[DEVICE_TENSOR_MAX_OP_PARAMS]; + param_type params; tensor_handle_t src_handles[DEVICE_TENSOR_MAX_SRC]; }; @@ -90,7 +106,7 @@ interface npu_device : remote_handle64{ ); AEEResult device_support_op( - in tensor_op op, + in tensor_op_spec op_spec, in tensor_spec dst, in sequence srcs, rout boolean is_supported
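For reference, a minimal scalar sketch of what the new SiLU/SwiGLU vector helpers compute, together with a hypothetical device-side check that consumes the new tensor_op_spec. The support_glu_op function and the params layout (glu op id in params[0], "swapped" flag in params[1], mirroring ggml's GGML_OP_GLU convention) are assumptions for illustration, not code from this patch.

#include <cmath>
#include <cstdint>

#include "hexagon_npu.h"  // generated from hexagon_npu.idl

// Scalar reference for vec_silu_*: x / (1.0f + expf(-x)), the formula the HVX helpers approximate.
static inline float silu_ref(float x) {
    return x / (1.0f + std::exp(-x));
}

// Scalar reference for vec_swiglu_*: the gate is multiplied in after the SiLU.
static inline float swiglu_ref(float x, float g) {
    return silu_ref(x) * g;
}

// Hypothetical device-side support check: accept only SWIGLU on f32/f16 tensors.
static bool support_glu_op(const npu_device_tensor_op_spec * op_spec, const npu_device_tensor_spec * dst) {
    if (op_spec->op != NPU_OP_GLU) {
        return false;
    }
    const int32_t glu_op  = op_spec->params[0];  // assumed: glu op id copied from ggml op_params[0]
    const int32_t swapped = op_spec->params[1];  // assumed: "swapped" flag copied from op_params[1]
    (void) swapped;
    return glu_op == NPU_GLU_OP_SWIGLU && (dst->type == NPU_DATA_TYPE_F32 || dst->type == NPU_DATA_TYPE_F16);
}

Carrying the op params in the support query lets the device reject specific GLU variants or parameter combinations up front, before the host builds a graph around an op the NPU cannot execute.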