From af620a12f7a83ce74e1a94e34eaa88b99dc1127b Mon Sep 17 00:00:00 2001 From: nullname Date: Wed, 18 Jun 2025 10:32:08 +0800 Subject: [PATCH] feat: flash attention support for hexagon-npu (#45) * add flash attn op * expend src tensor size * add flash attn sources * add quantize row functions * make a separated file for vec_dot * wip * wip * refactor: rename quants.hpp includes and add vec_dot to type traits * add flash_attn impl * split vec_scale_f32 * move vec_reduction_qf32 to vec_ops * add vec_scale_f16 * opt * add vec_mad * implement vec_mad_f16 * opt * add op template * opt * add align version * enable flash attn * wip * log print improve * add profiler log * wip * wip * add multi sub proc perf tracker * increase log buffer * remove sub prov pcycle * wip * wip * add prefetch for vec_dot * wip * wip * opt f16 vec dot * opt f16 vecdot * reuse vec_dot_product_impl in vec dot f32 * small opt to unblock pipeline * opt on aligned address wip * Revert "opt on aligned address" This reverts commit 27be1eb61a7d29d2f5fa6f90383e1b5d7fdf9b6a. * add profiler log at thread_pool * wip * invalidate all... * Reapply "opt on aligned address" This reverts commit f075a4c4586e32b7e5819c1fe7f9b6ed218b1767. * add is_constant for tensor config * disable align tensor opt in mul_mat * wip * wip * vec_scale_impl: unrolling the loop * wip * wip * replace reinterpret_cast with direct pointer access for write/read buffers * add fetch * wip * wip * wip * add log * check tensor shape at flash_attn * wip * wip * fix: update tensor type handling in flash_attn_impl * wip * fix: align cache size * fix: qf16->hf * fix: swap order of elements in vector combine for correct scaling * fix: opt f16 scale and mad * fix leftover fetch * wip * load into vector pair * opt cache size calculation in flash_attn_impl * refactoring: hold vtcm at thread local object * wip * add profiler log * mark tensors as modified * restrict tensor invalidation to the first thread in compute_impl * Revert "restrict tensor invalidation to the first thread in compute_impl" This reverts commit 0a8ff2b1bcf366097c16d7437c091382eacbef8b. 
* invalidate last tensor in compute_impl * invalidate last tensor in compute function * wip * refactor dequantize_row_q4_0 to simplify vector alignment * wip * refactoring: move VTCM quota calculation to thread pool * wip * fix: correct condition check for HEXAGON_SDK_ROOT existence * wip * wip * wip * wip * fix: update condition checks match the naming * fix: improve tensor handling checks and logging in graph and operation implementations * wip --- ggml/src/ggml-qnn/npu/CMakeLists.txt | 2 + ggml/src/ggml-qnn/npu/device/device.cpp | 28 +- ggml/src/ggml-qnn/npu/device/graph.cpp | 29 +- ggml/src/ggml-qnn/npu/device/graph.hpp | 6 +- .../src/ggml-qnn/npu/device/op_flash_attn.cpp | 321 ++++++++++++ .../src/ggml-qnn/npu/device/op_flash_attn.hpp | 11 + ggml/src/ggml-qnn/npu/device/op_impl.cpp | 233 ++++----- ggml/src/ggml-qnn/npu/device/op_impl.hpp | 4 +- ggml/src/ggml-qnn/npu/device/op_mul_mat.cpp | 226 ++------- ggml/src/ggml-qnn/npu/device/op_mul_mat.hpp | 60 +-- ggml/src/ggml-qnn/npu/device/op_types.hpp | 64 +-- ggml/src/ggml-qnn/npu/device/quants.cpp | 213 -------- ggml/src/ggml-qnn/npu/device/tensor.hpp | 34 +- ggml/src/ggml-qnn/npu/device/thread_pool.hpp | 126 +++-- ggml/src/ggml-qnn/npu/device/type_traits.cpp | 467 ++++++++++++++++++ .../device/{quants.hpp => type_traits.hpp} | 37 +- ggml/src/ggml-qnn/npu/device/util.hpp | 139 ++++-- ggml/src/ggml-qnn/npu/device/vec_ops.cpp | 156 ++++++ ggml/src/ggml-qnn/npu/device/vec_ops.hpp | 274 ++++++++++ ggml/src/ggml-qnn/npu/host/buffer.cpp | 4 +- ggml/src/ggml-qnn/npu/host/graph.cpp | 16 +- ggml/src/ggml-qnn/npu/host/host_device.cpp | 59 +-- ggml/src/ggml-qnn/npu/host/tensor.hpp | 52 +- ggml/src/ggml-qnn/npu/host/util.cpp | 78 ++- ggml/src/ggml-qnn/npu/host/util.hpp | 2 + ggml/src/ggml-qnn/npu/idl/hexagon_npu.idl | 16 +- ggml/src/ggml-qnn/shared/CMakeLists.txt | 2 + 27 files changed, 1859 insertions(+), 800 deletions(-) create mode 100644 ggml/src/ggml-qnn/npu/device/op_flash_attn.cpp create mode 100644 ggml/src/ggml-qnn/npu/device/op_flash_attn.hpp delete mode 100644 ggml/src/ggml-qnn/npu/device/quants.cpp create mode 100644 ggml/src/ggml-qnn/npu/device/type_traits.cpp rename ggml/src/ggml-qnn/npu/device/{quants.hpp => type_traits.hpp} (65%) create mode 100644 ggml/src/ggml-qnn/npu/device/vec_ops.cpp create mode 100644 ggml/src/ggml-qnn/npu/device/vec_ops.hpp diff --git a/ggml/src/ggml-qnn/npu/CMakeLists.txt b/ggml/src/ggml-qnn/npu/CMakeLists.txt index 5e1281c3d5..e8ce255fec 100644 --- a/ggml/src/ggml-qnn/npu/CMakeLists.txt +++ b/ggml/src/ggml-qnn/npu/CMakeLists.txt @@ -3,6 +3,8 @@ cmake_policy(SET CMP0115 OLD) if(DEFINED ENV{HEXAGON_SDK_ROOT}) set(HEXAGON_SDK_ROOT $ENV{HEXAGON_SDK_ROOT}) + message("HEXAGON_SDK_ROOT (from environment): ${HEXAGON_SDK_ROOT}") +elseif(DEFINED HEXAGON_SDK_ROOT) message("HEXAGON_SDK_ROOT: ${HEXAGON_SDK_ROOT}") else() message(FATAL_ERROR "HEXAGON_SDK_ROOT not defined") diff --git a/ggml/src/ggml-qnn/npu/device/device.cpp b/ggml/src/ggml-qnn/npu/device/device.cpp index 8a10e9e752..ff2819bae6 100644 --- a/ggml/src/ggml-qnn/npu/device/device.cpp +++ b/ggml/src/ggml-qnn/npu/device/device.cpp @@ -9,10 +9,10 @@ #include "graph.hpp" #include "hexagon_npu.h" #include "op_impl.hpp" -#include "quants.hpp" #include "remote.h" #include "tensor.hpp" #include "thread_pool.hpp" +#include "type_traits.hpp" #include "util.hpp" namespace { @@ -124,21 +124,20 @@ int npu_device_close(remote_handle64 h) { AEEResult npu_device_device_get_alignment(remote_handle64 _h, uint32_t * alignment) { NPU_UNUSED(_h); - *alignment = 
sizeof(HVX_Vector); + *alignment = sizeof(HVX_VectorPair); return AEE_SUCCESS; } -AEEResult npu_device_device_support_op(remote_handle64 _h, const npu_device_tensor_spec * src0, - const npu_device_tensor_spec * src1, const npu_device_tensor_spec * dst, - npu_device_tensor_op op, boolean * is_supported) { +AEEResult npu_device_device_support_op(remote_handle64 _h, npu_device_tensor_op op, const npu_device_tensor_spec * dst, + const npu_device_tensor_spec * srcs, int srcsLen, boolean * is_supported) { NPU_UNUSED(_h); - if (!src0 || !src1 || !dst || !is_supported) { + if (!srcs || srcsLen <= 0 || !dst || !is_supported) { DEVICE_LOG_ERROR("npu_device_device_support_op: Invalid arguments"); return AEE_EINVARGS; } - *is_supported = hexagon::support_op(*src0, *src1, *dst, op); + *is_supported = hexagon::support_op(op, dst, srcs, srcsLen); return AEE_SUCCESS; } @@ -208,19 +207,20 @@ AEEResult npu_device_graph_set_tensor_with_param(remote_handle64 _h, npu_device_ int tensor_paramsLen) { NPU_UNUSED(_h); auto * graph = graph_from_handle(graph_handle); - if (!graph || !tensor_handles || tensor_handlesLen <= 0 || !tensor_params || - tensor_handlesLen != tensor_paramsLen) { + if (!graph || tensor_handlesLen != tensor_paramsLen || tensor_handlesLen < 0) { return AEE_EINVHANDLE; } - graph->set_tensor(tensor_handles, tensor_handlesLen); - for (int i = 0; i < tensor_handlesLen; ++i) { - auto * tensor = tensor_from_handle(tensor_handles[i]); - if (tensor) { - tensor->update_config(tensor_params[i]); + if (tensor_params && tensor_handles) { + for (int i = 0; i < tensor_handlesLen; ++i) { + auto * tensor = tensor_from_handle(tensor_handles[i]); + if (tensor) { + tensor->update_config(tensor_params[i]); + } } } + graph->set_tensor(tensor_handles, tensor_handlesLen); return AEE_SUCCESS; } diff --git a/ggml/src/ggml-qnn/npu/device/graph.cpp b/ggml/src/ggml-qnn/npu/device/graph.cpp index c9cad77232..5bc14a0aca 100644 --- a/ggml/src/ggml-qnn/npu/device/graph.cpp +++ b/ggml/src/ggml-qnn/npu/device/graph.cpp @@ -10,8 +10,7 @@ namespace hexagon { graph::graph() noexcept { - _vtcm_quota_size = hexagon::vtcm_mem::get_avail_block_size(); // TODO: move to device init? 
- DEVICE_LOG_DEBUG("graph(%p) created: vtcm quota size: %zu\n", (void *) this, _vtcm_quota_size); + DEVICE_LOG_DEBUG("graph(%p) created\n", (void *) this); } graph::~graph() noexcept { @@ -20,9 +19,10 @@ graph::~graph() noexcept { } void graph::set_tensor(const npu_device_tensor_handle_t * tensors, int tensor_count) { - if (tensor_count <= 0) { + if (tensor_count <= 0 || !tensors) { _tensors.reset(); _tensor_count = 0; + DEVICE_LOG_DEBUG("graph(%p) set_tensor: no tensors to set\n", (void *) this); return; } @@ -50,21 +50,27 @@ bool graph::compute(default_thread_pool * thread_pool, const float * f16_to_f32_ DEVICE_SCOPED_PERFORMANCE_TRACKER("[%p]compute", (void *) this); _f16_to_f32_table = f16_to_f32_table; if (thread_pool) { - thread_pool->sync_execute(reinterpret_cast(&graph::thread_pool_task), this); + thread_pool->sync_execute(&graph::thread_pool_task, this); } else { - compute_impl(nullptr, 0, 1); + default_thread_pool::thread_params param = { + 0, 1, nullptr, hexagon::vtcm_mem::get_avail_block_size() + }; // TODO: should have a better way to initialize thread_params + + compute_impl(nullptr, ¶m); } + _tensors[_tensor_count - 1]->invalidate(); _f16_to_f32_table = nullptr; return true; } -void graph::thread_pool_task(default_thread_pool * pool, size_t thread_idx, size_t thread_count, graph * graph) { - graph->compute_impl(pool, thread_idx, thread_count); +void graph::thread_pool_task(default_thread_pool * pool, default_thread_pool::thread_params * thread_params, + void * graph) { + reinterpret_cast(graph)->compute_impl(pool, thread_params); } -void graph::compute_impl(default_thread_pool * pool, size_t thread_idx, size_t thread_count) { - hexagon::compute_params params = { thread_idx, thread_count, _vtcm_quota_size / thread_count, _f16_to_f32_table }; +void graph::compute_impl(default_thread_pool * pool, default_thread_pool::thread_params * thread_params) { + hexagon::compute_params params = { thread_params, _f16_to_f32_table }; for (size_t i = 0; i < _tensor_count; ++i) { auto * dst = _tensors[i]; @@ -78,13 +84,12 @@ void graph::compute_impl(default_thread_pool * pool, size_t thread_idx, size_t t DEVICE_LOG_ERROR("graph(%p) tensor[%zu] op %d compute failed\n", (void *) this, i, op); } - DEVICE_SCOPED_PERFORMANCE_TRACKER("[%p]sync_thread, tidx: %zu", (void *) this, thread_idx); - const bool should_sync = requires_thread_barrier(op); if (pool && should_sync && i < _tensor_count - 1) { + DEVICE_SCOPED_PERFORMANCE_TRACKER("[%p]sync_thread, tidx: %zu, tensor[%zu/%zu]", (void *) this, + params.get_thread_index(), i, _tensor_count); pool->sync_thread(); } - dst->invalidate(); } } diff --git a/ggml/src/ggml-qnn/npu/device/graph.hpp b/ggml/src/ggml-qnn/npu/device/graph.hpp index c6b68c4eea..36cf5bfc5c 100644 --- a/ggml/src/ggml-qnn/npu/device/graph.hpp +++ b/ggml/src/ggml-qnn/npu/device/graph.hpp @@ -20,12 +20,12 @@ class graph { bool compute(default_thread_pool * thread_pool, const float * f16_to_f32_table); private: - static void thread_pool_task(default_thread_pool * pool, size_t thread_idx, size_t thread_count, graph * graph); - void compute_impl(default_thread_pool * pool, size_t thread_idx, size_t thread_count); + static void thread_pool_task(default_thread_pool * pool, default_thread_pool::thread_params * thread_params, + void * graph); + void compute_impl(default_thread_pool * pool, default_thread_pool::thread_params * thread_params); std::unique_ptr _tensors; size_t _tensor_count = 0; - size_t _vtcm_quota_size = 0; const float * _f16_to_f32_table = nullptr; 
DISABLE_COPY_AND_MOVE(graph); diff --git a/ggml/src/ggml-qnn/npu/device/op_flash_attn.cpp b/ggml/src/ggml-qnn/npu/device/op_flash_attn.cpp new file mode 100644 index 0000000000..0c1ac778ba --- /dev/null +++ b/ggml/src/ggml-qnn/npu/device/op_flash_attn.cpp @@ -0,0 +1,321 @@ + +#include "op_flash_attn.hpp" + +#include "type_traits.hpp" +#include "util.hpp" +#include "vec_ops.hpp" + +namespace { + +// TODO: use a more efficient conversion +inline float f16_to_f32(const npu_device_fp16_t src) { + return reinterpret_cast(src); +} + +// From: ggml/src/ggml-cpu/ops.cpp +void flash_attn_impl(hexagon::tensor * out, const hexagon::tensor * q, const hexagon::tensor * k, + const hexagon::tensor * v, const hexagon::tensor * mask, hexagon::compute_params * params) { + static_assert(3 <= hexagon::kMaxParamsCount, "flash_attn op params count exceeds max params count"); + + float scale = out->get_op_param(0); + const float max_bias = out->get_op_param(1); + const float logit_softcap = out->get_op_param(2); + + if (logit_softcap != 0) { + scale /= logit_softcap; + } + + // broadcast factors + const int64_t rk2 = q->get_ne(2) / k->get_ne(2); + const int64_t rk3 = q->get_ne(3) / k->get_ne(3); + const int64_t rv2 = q->get_ne(2) / v->get_ne(2); + const int64_t rv3 = q->get_ne(3) / v->get_ne(3); + + const uint32_t n_head = q->get_ne(2); + const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head)); + + const float m0 = powf(2.0f, -(max_bias) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + const auto q_to_vec_dot = hexagon::get_type_traits(k->get_type()).from_float; // TODO: fix this + const auto kq_vec_dot = hexagon::get_type_traits(k->get_type()).vec_dot; + const auto v_to_float = hexagon::get_type_traits(v->get_type()).to_float; + if (!q_to_vec_dot || !kq_vec_dot) { + DEVICE_LOG_ERROR("flash_attn_impl: unsupported data type for q, k, or v\n"); + return; + } + + const int64_t total_rows = q->get_ne(1) * q->get_ne(2) * q->get_ne(3); // total number of rows in Q + const auto start_end_row = params->get_work_slice(total_rows); // work slice for this thread + + const auto DK = k->get_ne(0); + const auto DV = v->get_ne(0); + const auto row_bytes_q = q->get_ne(0) * hexagon::get_type_traits(q->get_type()).type_size; + const auto row_bytes_k = DK * hexagon::get_type_traits(k->get_type()).type_size; + const auto row_bytes_v = DV * hexagon::get_type_traits(v->get_type()).type_size; + + constexpr const size_t kFloatsPerVector = hexagon::kBytesPerVector / sizeof(float); + const auto aligned_dk = (DK + kFloatsPerVector - 1) / kFloatsPerVector * kFloatsPerVector; + const auto aligned_dv = (DV + kFloatsPerVector - 1) / kFloatsPerVector * kFloatsPerVector; + size_t total_cache_size = sizeof(float) * (aligned_dk + 2 * aligned_dv); + auto * cache_ptr = params->get_vtcm_cache(total_cache_size); + if (!cache_ptr) { + DEVICE_LOG_ERROR("Failed to allocate VTCM cache for flash_attn: %zu bytes\n", total_cache_size); + return; + } + + // loop over n_batch and n_head + const auto rows_per_batch = q->get_ne(2) * q->get_ne(1); + const auto out_rows_per_batch = out->get_ne(2) * out->get_ne(1); + const bool is_v_f16 = + v->get_type() == NPU_DATA_TYPE_F16; // check if V is in FP16 format, otherwise it is in FP32 format + uint8_t * dst_ptr = out->get_write_buffer(); + if (!dst_ptr) { + DEVICE_LOG_ERROR("flash_attn_impl: dst_ptr is not writable, tensor: %p, type: %s\n", (void *) out, + hexagon::get_type_name(out->get_type())); + return; + } + + 
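// For review: the loop below implements the online-softmax accumulation from the
// referenced paper (https://arxiv.org/pdf/2112.05682.pdf). In scalar form, per KV
// column ic with masked score
//   s   = scale * dot(K[ic], Q)  (optionally soft-capped)  + slope * mask[ic],
// the update performed by the code is equivalent to
//   M'  = max(M, s);   ms = expf(M - M');   vs = expf(s - M');
//   VKQ = VKQ * ms + V[ic] * vs;
//   S   = S   * ms + vs;          M = M';
// and the output row is VKQ / S. vec_scale_f16/f32 and vec_mad_f16/f32 are the
// vectorized forms of the "VKQ * ms" and "+ V[ic] * vs" steps.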
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_WITH_MULTI_SUB_PROC(out, params->get_thread_index(), flash_attn); + const uint8_t * q_ptr = q->get_read_buffer(); + const uint8_t * k_ptr = k->get_read_buffer(); + const uint8_t * v_ptr = v->get_read_buffer(); + const uint8_t * mask_ptr = mask ? mask->get_read_buffer() : nullptr; + for (auto ir = start_end_row.first; ir < start_end_row.second; ++ir) { + // q indices + const auto iq3 = ir / rows_per_batch; + const auto iq2 = (ir - iq3 * rows_per_batch) / q->get_ne(1); + const auto iq1 = (ir - iq3 * rows_per_batch - iq2 * q->get_ne(1)); + + const uint32_t h = iq2; // head index + const float slope = + (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2 * (h - n_head_log2) + 1) : 1.0f; + + float S = 0.0f; // sum + float M = -INFINITY; // maximum KQ value + + float * VKQ32 = reinterpret_cast(cache_ptr); // FP32 VKQ accumulator + float * V32 = VKQ32 + aligned_dv; // (temporary) FP32 V buffer + auto * VKQ16 = reinterpret_cast(VKQ32 + aligned_dv); // (temporary) FP16 VKQ accumulator + auto * Q_q = reinterpret_cast( + VKQ32 + 2 * aligned_dv); // (temporary) buffer for Q converted to quantized/FP16 + + if (is_v_f16) { + memset(VKQ16, 0, DV * sizeof(npu_device_fp16_t)); + } else { + memset(VKQ32, 0, DV * sizeof(float)); + } + + const npu_device_fp16_t * mp = + mask_ptr ? reinterpret_cast(mask_ptr + iq1 * mask->get_nb(1)) : nullptr; + + // k indices + const int ik3 = iq3 / rk3; + const int ik2 = iq2 / rk2; + + // v indices + const int iv3 = iq3 / rv3; + const int iv2 = iq2 / rv2; + + const auto * q_data = q_ptr + (iq1 * q->get_nb(1) + iq2 * q->get_nb(2) + iq3 * q->get_nb(3)); + if (iq1 < q->get_ne(1) - 1) { + hexagon::l2fetch_row(q_data + q->get_nb(1), row_bytes_q); + } + + q_to_vec_dot(reinterpret_cast(q_data), Q_q, DK, params->f16_to_f32_table); + + // online softmax / attention + // loop over n_kv and n_head_kv + // ref: https://arxiv.org/pdf/2112.05682.pdf + for (int64_t ic = 0; ic < k->get_ne(1); ++ic) { + DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(flash_attn, 0, loop); + float mv = mp ? (slope * f16_to_f32(mp[ic])) : 0.0f; + if (mv == -INFINITY) { + continue; + } + + float s = 0.f; + { + DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(flash_attn, 1, kq_dot); + const auto * k_data = k_ptr + (ic * k->get_nb(1) + ik2 * k->get_nb(2) + ik3 * k->get_nb(3)); + if (ic < k->get_ne(1) - 1) { + hexagon::l2fetch_row(k_data + k->get_nb(1), row_bytes_k); + } + + s = kq_vec_dot(k_data, Q_q, DK); // KQ value + s = s * scale; // scale KQ value + if (logit_softcap != 0.0f) { + s = logit_softcap * tanhf(s); // TODO: vectorize this? 
+ } + + s += mv; // apply mask + } + + const float Mold = M; + + float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value + float vs = 1.0f; // post-softmax KQ value, expf(s - M) + + const auto * v_data = v_ptr + (ic * v->get_nb(1) + iv2 * v->get_nb(2) + iv3 * v->get_nb(3)); + if (ic < v->get_ne(1)) { + hexagon::l2fetch_row(v_data, row_bytes_v); + } + + if (is_v_f16) { + if (s > M) { + // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f + M = s; + ms = expf(Mold - M); + + // V = V*expf(Mold - M) + hexagon::vec_scale_f16(VKQ16, ms, VKQ16, DV); + } else { + // no new maximum, ms == 1.0f, vs != 1.0f + vs = expf(s - M); + } + + // V += v*expf(s - M) + DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(flash_attn, 2, mad); + hexagon::vec_mad_f16(reinterpret_cast(v_data), vs, VKQ16, DV); + } else { + if (s > M) { + // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f + M = s; + ms = expf(Mold - M); + + // V = V*expf(Mold - M) + hexagon::vec_scale_f32(VKQ32, ms, VKQ32, DV); + } else { + // no new maximum, ms == 1.0f, vs != 1.0f + vs = expf(s - M); + } + + // V += v*expf(s - M) + DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(flash_attn, 2, mad); + if (v_to_float) { + v_to_float(v_data, V32, DV, params->f16_to_f32_table); + hexagon::vec_mad_f32(V32, vs, VKQ32, DV); + } else { + // V is F32 + hexagon::vec_mad_f32(reinterpret_cast(v_data), vs, VKQ32, DV); + } + } + + S = S * ms + vs; // scale and increment sum with partial sum + } + + if (is_v_f16) { + // TODO: use a more efficient conversion + for (int64_t d = 0; d < DV; ++d) { + VKQ32[d] = f16_to_f32(VKQ16[d]); + } + } + + // V /= S + const float S_inv = 1.0f / S; + hexagon::vec_scale_f32(VKQ32, S_inv, VKQ32, DV); + + // dst indices + const int i1 = iq1; + const int i2 = iq2; + const int i3 = iq3; + + // permute(0, 2, 1, 3) + memcpy(dst_ptr + (i3 * out_rows_per_batch + i2 + i1 * out->get_ne(1)) * out->get_nb(1), VKQ32, out->get_nb(1)); + } + + out->release_write_buffer(); // mark the output tensor as modified +} + +} // namespace + +namespace hexagon { + +bool flash_attn_f32(tensor * out, compute_params * params) { + if (!out || !params) { + DEVICE_LOG_DEBUG("invalid out or params\n"); + return false; + } + + const auto * q = out->get_src(0); + const auto * k = out->get_src(1); + const auto * v = out->get_src(2); + const auto * mask = out->get_src(3); + if (!q || !k || !v || !mask) { + DEVICE_LOG_DEBUG("invalid src tensors: q: %p, k: %p, v: %p, mask: %p\n", (void *) q, (void *) k, (void *) v, + (void *) mask); + return false; + } + + flash_attn_impl(out, q, k, v, mask, params); + return true; +} + +bool is_flash_attn_supported(npu_device_tensor_op op, const npu_device_tensor_spec * dst, + const npu_device_tensor_spec * srcs, size_t src_len) { + if (op != NPU_OP_FLASH_ATTN) { + DEVICE_LOG_DEBUG("op is not NPU_OP_FLASH_ATTN: %d\n", op); + return false; + } + + if (!dst || !srcs || src_len < 4) { + DEVICE_LOG_DEBUG("[%s]invalid dst or srcs\n", op_get_name(op)); + return false; + } + + if (dst->type != NPU_DATA_TYPE_F32) { + DEVICE_LOG_DEBUG("[%s]dst type is not F32: %s\n", op_get_name(op), get_type_name(dst->type)); + return false; + } + + const auto * q = &srcs[0]; + if (q->type != NPU_DATA_TYPE_F32) { + DEVICE_LOG_DEBUG("[%s]q type is not F32: %s\n", op_get_name(op), get_type_name(q->type)); + return false; + } + + const auto * k = &srcs[1]; + if (k->type != NPU_DATA_TYPE_F16) { // TODO: support more k types + DEVICE_LOG_DEBUG("[%s]k type is not F16: %s\n", op_get_name(op), get_type_name(k->type)); + 
return false; + } + + const auto * v = &srcs[2]; + if (v->type != k->type) { // TODO: support more v types + DEVICE_LOG_DEBUG("[%s]v type is not the same as k: %s vs %s\n", op_get_name(op), get_type_name(v->type), + get_type_name(k->type)); + return false; + } + + const auto * mask = &srcs[3]; + if (mask->type != NPU_DATA_TYPE_F16) { + DEVICE_LOG_DEBUG("[%s]mask type is not F16: %s\n", op_get_name(op), get_type_name(mask->type)); + return false; + } + + if (dst->ne[0] != v->ne[0] || dst->ne[2] != q->ne[1]) { + DEVICE_LOG_DEBUG( + "[%s]dst shape does not match q and v: dst ne: %ld, %ld, %ld, %ld, q ne: %ld, %ld, %ld, %ld, " + "v ne: %ld, %ld, %ld, %ld\n", + op_get_name(op), dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], q->ne[0], q->ne[1], q->ne[2], q->ne[3], + v->ne[0], v->ne[1], v->ne[2], v->ne[3]); + return false; + } + + if (is_transposed_or_permuted(dst->nb)) { + DEVICE_LOG_DEBUG("[%s]dst cannot be transposed or permuted, nb: %zu, %zu, %zu, %zu\n", op_get_name(op), + dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3]); + return false; + } + + if (q->ne[0] != k->ne[0]) { + DEVICE_LOG_DEBUG("[%s]q and k shapes do not match: q ne: %ld, %ld, %ld, %ld, k ne: %ld, %ld, %ld, %ld\n", + op_get_name(op), q->ne[0], q->ne[1], q->ne[2], q->ne[3], k->ne[0], k->ne[1], k->ne[2], + k->ne[3]); + return false; + } + + return true; +} + +} // namespace hexagon diff --git a/ggml/src/ggml-qnn/npu/device/op_flash_attn.hpp b/ggml/src/ggml-qnn/npu/device/op_flash_attn.hpp new file mode 100644 index 0000000000..63d6d09d54 --- /dev/null +++ b/ggml/src/ggml-qnn/npu/device/op_flash_attn.hpp @@ -0,0 +1,11 @@ +#pragma once + +#include "op_types.hpp" + +namespace hexagon { + +bool flash_attn_f32(tensor * out, compute_params * params); +bool is_flash_attn_supported(npu_device_tensor_op op, const npu_device_tensor_spec * dst, + const npu_device_tensor_spec * srcs, size_t src_len); + +} // namespace hexagon diff --git a/ggml/src/ggml-qnn/npu/device/op_impl.cpp b/ggml/src/ggml-qnn/npu/device/op_impl.cpp index 777072024a..4d271a4899 100644 --- a/ggml/src/ggml-qnn/npu/device/op_impl.cpp +++ b/ggml/src/ggml-qnn/npu/device/op_impl.cpp @@ -2,13 +2,12 @@ #include "op_impl.hpp" -#include -#include - #include +#include "op_flash_attn.hpp" #include "op_mul_mat.hpp" -#include "quants.hpp" +#include "type_traits.hpp" +#include "vec_ops.hpp" namespace { @@ -16,12 +15,12 @@ template inline void vec_op_impl(const _TyData * src0, const _TyData * src1, size_t count, _TyData * dst) { constexpr const size_t kElementsPerVector = hexagon::kBytesPerVector / sizeof(_TyData); - HVX_Vector * iptr0 = ((HVX_Vector *) src0); - HVX_Vector * iptr0_end = ((HVX_Vector *) src0) + (count / kElementsPerVector); - HVX_Vector * iptr1 = ((HVX_Vector *) src1); - HVX_Vector * optr = ((HVX_Vector *) dst); // framework will ensure the dst is aligned - HVX_Vector prev0 = *iptr0++; - HVX_Vector prev1 = *iptr1++; + HVX_Vector * iptr0 = ((HVX_Vector *) src0); + HVX_Vector * const iptr0_end = ((HVX_Vector *) src0) + (count / kElementsPerVector); + HVX_Vector * iptr1 = ((HVX_Vector *) src1); + HVX_Vector * optr = ((HVX_Vector *) dst); // framework will ensure the dst is aligned + HVX_Vector prev0 = *iptr0++; + HVX_Vector prev1 = *iptr1++; while (iptr0 < iptr0_end) { HVX_Vector curr0 = *iptr0++; @@ -33,25 +32,25 @@ inline void vec_op_impl(const _TyData * src0, const _TyData * src1, size_t count prev1 = curr1; } + const size_t leftover = count % kElementsPerVector; if ((iptr0_end - ((HVX_Vector *) src0)) > 0) { // handle the last vector // see also: // 
https://github.com/UbiquitousLearning/mllm/blob/babf4410352ce8730824c87699c025a0d4ce3a6f/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/LLaMAMul.cpp#L147 // or qualcomm sdk libs\qhl_hvx\src\qhblas_hvx\qhblas_hvx_aw_vector_add_ah.c - bool iptr0_aligned = hexagon::is_addr_aligned(iptr0); - HVX_Vector curr0 = iptr0_aligned ? prev0 : *iptr0; - iptr0 = iptr0_aligned ? iptr0 : iptr0 + 1; - bool iptr1_aligned = hexagon::is_addr_aligned(iptr1); - HVX_Vector curr1 = iptr1_aligned ? prev1 : *iptr1; - iptr1 = iptr1_aligned ? iptr1 : iptr1 + 1; - HVX_Vector s0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0); - HVX_Vector s1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1); - *optr++ = _OpIntrinsic(s0, s1); - prev0 = curr0; - prev1 = curr1; + bool should_fetch_src0 = leftover != 0 || !hexagon::is_addr_aligned(iptr0); + bool should_fetch_src1 = leftover != 0 || !hexagon::is_addr_aligned(iptr1); + HVX_Vector curr0 = should_fetch_src0 ? *iptr0 : prev0; + HVX_Vector curr1 = should_fetch_src1 ? *iptr1 : prev1; + iptr0 += should_fetch_src0 ? 1 : 0; + iptr1 += should_fetch_src1 ? 1 : 0; + HVX_Vector s0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0); + HVX_Vector s1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1); + *optr++ = _OpIntrinsic(s0, s1); + prev0 = curr0; + prev1 = curr1; } - const size_t leftover = count % kElementsPerVector; const size_t leftover_bytes = leftover * sizeof(_TyData); if (leftover > 0) { // handle the leftover elements @@ -136,18 +135,23 @@ template bool element_wise_op(hexagon::tensor * out, hexagon::co return false; } - const auto * src0_ptr = reinterpret_cast(src0->get_read_buffer()); - const auto * src1_ptr = reinterpret_cast(src1->get_read_buffer()); - auto * dst_ptr = reinterpret_cast(out->get_write_buffer()); - auto total_rows = out->get_ne(3) * out->get_ne(2) * out->get_ne(1); - const auto rows_per_cube = out->get_ne(2) * out->get_ne(1); - const auto start_end = hexagon::get_thread_work_slice(total_rows, params->tidx, params->tcnt); + uint8_t * dst_ptr = out->get_write_buffer(); + if (!dst_ptr) { + DEVICE_LOG_ERROR("element_wise_op: dst_ptr is not writable, tensor: %p, type: %s\n", (void *) out, + hexagon::get_type_name(out->get_type())); + return false; + } + const uint8_t * src0_ptr = src0->get_read_buffer(); + const uint8_t * src1_ptr = src1->get_read_buffer(); + auto total_rows = out->get_ne(3) * out->get_ne(2) * out->get_ne(1); + const auto rows_per_cube = out->get_ne(2) * out->get_ne(1); + const auto start_end = params->get_work_slice(total_rows); if (start_end.first >= start_end.second) { return true; } - DEVICE_SCOPED_OP_PERFORMANCE_TRACKER(out, params->tidx); + DEVICE_SCOPED_OP_PERFORMANCE_TRACKER(out, params->get_thread_index()); const size_t valid_row_bytes = src0->get_ne(0) * sizeof(data_type); for (int64_t ir = start_end.first; ir < start_end.second; ++ir) { @@ -171,6 +175,7 @@ template bool element_wise_op(hexagon::tensor * out, hexagon::co static_cast(out->get_ne(0)), reinterpret_cast(dst_row)); } + out->release_write_buffer(); // mark the output tensor as modified return true; } @@ -184,27 +189,36 @@ bool is_same_shape(const npu_device_tensor_spec & src, const npu_device_tensor_s return true; } -bool is_element_wise_op_supported(const npu_device_tensor_spec & src0, const npu_device_tensor_spec & src1, - const npu_device_tensor_spec & dst, npu_device_tensor_op op) { +bool is_element_wise_op_supported(npu_device_tensor_op op, const npu_device_tensor_spec * dst, + const npu_device_tensor_spec * srcs, size_t src_len) { if (op != NPU_OP_ADD && op != NPU_OP_SUB 
&& op != NPU_OP_MUL) { DEVICE_LOG_DEBUG("[%s]unsupported\n", hexagon::op_get_name(op)); return false; } - if (dst.type != src0.type || dst.type != src1.type) { - DEVICE_LOG_DEBUG("[%s]src0.type and dst.type mismatch: %s vs %s\n", hexagon::op_get_name(op), - hexagon::get_type_name(src0.type), hexagon::get_type_name(dst.type)); + if (!dst || !srcs || src_len < 2) { + DEVICE_LOG_DEBUG("[%s]invalid dst or srcs\n", hexagon::op_get_name(op)); return false; } - if (dst.type != NPU_DATA_TYPE_F32 && dst.type != NPU_DATA_TYPE_F16) { - DEVICE_LOG_DEBUG("[%s]unsupported data type: %s\n", hexagon::op_get_name(op), hexagon::get_type_name(dst.type)); + const auto & src0 = srcs[0]; + const auto & src1 = srcs[1]; + if (dst->type != src0.type || dst->type != src1.type) { + DEVICE_LOG_DEBUG("[%s]src0.type and dst.type mismatch: %s vs %s\n", hexagon::op_get_name(op), + hexagon::get_type_name(src0.type), hexagon::get_type_name(dst->type)); + return false; + } + + if (dst->type != NPU_DATA_TYPE_F32 && dst->type != NPU_DATA_TYPE_F16) { + DEVICE_LOG_DEBUG("[%s]unsupported data type: %s\n", hexagon::op_get_name(op), + hexagon::get_type_name(dst->type)); return false; } // TODO: fix FP16 add/sub - if (dst.type == NPU_DATA_TYPE_F16 && op != NPU_OP_MUL) { - DEVICE_LOG_DEBUG("[%s]unsupported data type: %s\n", hexagon::op_get_name(op), hexagon::get_type_name(dst.type)); + if (dst->type == NPU_DATA_TYPE_F16 && op != NPU_OP_MUL) { + DEVICE_LOG_DEBUG("[%s]unsupported data type: %s\n", hexagon::op_get_name(op), + hexagon::get_type_name(dst->type)); return false; } @@ -214,7 +228,7 @@ bool is_element_wise_op_supported(const npu_device_tensor_spec & src0, const npu return false; } - if (!is_same_shape(src0, dst)) { + if (!is_same_shape(src0, *dst)) { DEVICE_LOG_DEBUG("[%s]src0 and dst have different shape\n", hexagon::op_get_name(op)); return false; } @@ -225,10 +239,10 @@ bool is_element_wise_op_supported(const npu_device_tensor_spec & src0, const npu void rms_norm_vec_f32(const float * src, size_t count, float eps, float * dst) { constexpr const size_t kElementsPerVector = hexagon::kBytesPerVector / sizeof(float); - HVX_Vector * src_vec_ptr = ((HVX_Vector *) src); - HVX_Vector * src_vec_end = ((HVX_Vector *) src) + (count / kElementsPerVector); - HVX_Vector prev = *src_vec_ptr++; - HVX_Vector sum = Q6_V_vzero(); + HVX_Vector * src_vec_ptr = ((HVX_Vector *) src); + HVX_Vector * const src_vec_end = ((HVX_Vector *) src) + (count / kElementsPerVector); + HVX_Vector prev = *src_vec_ptr++; + HVX_Vector sum = Q6_V_vzero(); while (src_vec_ptr < src_vec_end) { HVX_Vector curr = *src_vec_ptr++; HVX_Vector s0 = Q6_V_valign_VVR(curr, prev, (size_t) src); @@ -236,17 +250,17 @@ void rms_norm_vec_f32(const float * src, size_t count, float eps, float * dst) { prev = curr; } + const size_t leftover = count % kElementsPerVector; if ((src_vec_end - ((HVX_Vector *) src)) > 0) { // handle the last vector - bool src_ptr_aligned = hexagon::is_addr_aligned(src_vec_ptr); - HVX_Vector curr = src_ptr_aligned ? prev : *src_vec_ptr; - src_vec_ptr = src_ptr_aligned ? src_vec_ptr : src_vec_ptr + 1; - HVX_Vector s0 = Q6_V_valign_VVR(curr, prev, (size_t) src); - sum = Q6_Vqf32_vadd_Vqf32Vqf32(sum, Q6_Vqf32_vmpy_VsfVsf(s0, s0)); - prev = curr; + bool should_fetch_src = leftover != 0 || !hexagon::is_addr_aligned(src_vec_ptr); + HVX_Vector curr = should_fetch_src ? *src_vec_ptr : prev; + src_vec_ptr += should_fetch_src ? 
1 : 0; + HVX_Vector s0 = Q6_V_valign_VVR(curr, prev, (size_t) src); + sum = Q6_Vqf32_vadd_Vqf32Vqf32(sum, Q6_Vqf32_vmpy_VsfVsf(s0, s0)); + prev = curr; } - const size_t leftover = count % kElementsPerVector; const size_t leftover_bytes = leftover * sizeof(float); if (leftover > 0) { // handle the leftover elements @@ -257,37 +271,9 @@ void rms_norm_vec_f32(const float * src, size_t count, float eps, float * dst) { Q6_V_valign_VVR(Q6_Vqf32_vmpy_VsfVsf(curr, curr), Q6_V_vzero(), leftover_bytes)); } - const float mean = hexagon::vec_reduction_f32(sum) / count; // TODO: figure out how to do division in vector - const float scale = 1.0f / sqrtf(mean + eps); // TODO: use buildin blas sqrtf? - - HVX_Vector scale_vec = Q6_V_vsplat_R(reinterpret_cast(scale)); - src_vec_ptr = ((HVX_Vector *) src); - prev = *src_vec_ptr++; - HVX_Vector * dst_vec_ptr = ((HVX_Vector *) dst); // framework will ensure the dst is aligned - while (src_vec_ptr < src_vec_end) { - HVX_Vector curr = *src_vec_ptr++; - HVX_Vector s0 = Q6_V_valign_VVR(curr, prev, (size_t) src); - *dst_vec_ptr++ = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(s0, scale_vec)); - prev = curr; - } - - if ((src_vec_end - ((HVX_Vector *) src)) > 0) { - // handle the last vector - bool src_ptr_aligned = hexagon::is_addr_aligned(src_vec_ptr); - HVX_Vector curr = src_ptr_aligned ? prev : *src_vec_ptr; - src_vec_ptr = src_ptr_aligned ? src_vec_ptr : src_vec_ptr + 1; - HVX_Vector s0 = Q6_V_valign_VVR(curr, prev, (size_t) src); - *dst_vec_ptr++ = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(s0, scale_vec)); - prev = curr; - } - - if (leftover > 0) { - // handle the leftover elements - HVX_Vector curr = - (leftover_bytes + hexagon::unaligned_bytes(src_vec_ptr) > hexagon::kBytesPerVector) ? *src_vec_ptr : prev; - curr = Q6_V_valign_VVR(curr, prev, (size_t) src); - q6op_vstu_variable_ARV(dst_vec_ptr, leftover_bytes, Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(curr, scale_vec))); - } + const float mean = hexagon::vec_reduction_qf32_f32(sum) / count; // TODO: figure out how to do division in vector + const float scale = 1.0f / sqrtf(mean + eps); // TODO: use buildin blas sqrtf? + hexagon::vec_scale_f32(src, scale, dst, count); } // TODO: merge with element_wise_op? 
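For review, a minimal scalar sketch of what the rewritten rms_norm_vec_f32 computes: the HVX code above accumulates the sum of squares in qf32, reduces it with vec_reduction_qf32_f32, and now delegates the final multiply to hexagon::vec_scale_f32. The helper name below is illustrative only:

#include <cmath>
#include <cstddef>

// Scalar reference: dst[i] = src[i] / sqrt(mean(src^2) + eps)
static void rms_norm_ref_f32(const float * src, size_t count, float eps, float * dst) {
    float sum = 0.0f;
    for (size_t i = 0; i < count; ++i) {
        sum += src[i] * src[i];                               // HVX: Q6_Vqf32_vmpy_VsfVsf + qf32 adds
    }
    const float mean  = sum / static_cast<float>(count);      // vec_reduction_qf32_f32(sum) / count
    const float scale = 1.0f / std::sqrt(mean + eps);
    for (size_t i = 0; i < count; ++i) {
        dst[i] = src[i] * scale;                              // HVX: hexagon::vec_scale_f32(src, scale, dst, count)
    }
}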
@@ -305,16 +291,22 @@ template bool unary_op(hexagon::tensor * out, hexagon::compute_p return true; // skip if no src } - const auto * src0_ptr = reinterpret_cast(src0->get_read_buffer()); - auto * dst_ptr = reinterpret_cast(out->get_write_buffer()); + auto * dst_ptr = out->get_write_buffer(); + if (!dst_ptr) { + DEVICE_LOG_ERROR("unary_op: dst_ptr is not writable, tensor: %p, type: %s\n", (void *) out, + hexagon::get_type_name(out->get_type())); + return false; + } + + const auto * src0_ptr = src0->get_read_buffer(); auto total_rows = out->get_ne(3) * out->get_ne(2) * out->get_ne(1); const auto rows_per_cube = out->get_ne(2) * out->get_ne(1); - const auto start_end = hexagon::get_thread_work_slice(total_rows, params->tidx, params->tcnt); + const auto start_end = params->get_work_slice(total_rows); if (start_end.first >= start_end.second) { return true; } - DEVICE_SCOPED_OP_PERFORMANCE_TRACKER(out, params->tidx); + DEVICE_SCOPED_OP_PERFORMANCE_TRACKER(out, params->get_thread_index()); const auto param = out->get_op_param(0); const size_t valid_row_bytes = src0->get_ne(0) * sizeof(data_type); @@ -333,28 +325,36 @@ template bool unary_op(hexagon::tensor * out, hexagon::compute_p reinterpret_cast(dst_row)); } + out->release_write_buffer(); // mark the output tensor as modified return true; } -bool is_unary_op_supported(const npu_device_tensor_spec & src0, const npu_device_tensor_spec & src1, - const npu_device_tensor_spec & dst, npu_device_tensor_op op) { +bool is_unary_op_supported(npu_device_tensor_op op, const npu_device_tensor_spec * dst, + const npu_device_tensor_spec * srcs, size_t src_len) { if (op != NPU_OP_RMS_NORM) { DEVICE_LOG_DEBUG("[%s]unsupported\n", hexagon::op_get_name(op)); return false; } - if (dst.type != src0.type) { + if (!dst || !srcs || src_len < 1) { + DEVICE_LOG_DEBUG("[%s]invalid dst or srcs\n", hexagon::op_get_name(op)); + return false; + } + + const auto & src0 = srcs[0]; + if (dst->type != src0.type) { DEVICE_LOG_DEBUG("[%s]src0.type and dst.type mismatch: %s vs %s\n", hexagon::op_get_name(op), - hexagon::get_type_name(src0.type), hexagon::get_type_name(dst.type)); + hexagon::get_type_name(src0.type), hexagon::get_type_name(dst->type)); return false; } - if (dst.type != NPU_DATA_TYPE_F32) { - DEVICE_LOG_DEBUG("[%s]unsupported data type: %s\n", hexagon::op_get_name(op), hexagon::get_type_name(dst.type)); + if (dst->type != NPU_DATA_TYPE_F32) { + DEVICE_LOG_DEBUG("[%s]unsupported data type: %s\n", hexagon::op_get_name(op), + hexagon::get_type_name(dst->type)); return false; } - if (!is_same_shape(src0, dst)) { + if (!is_same_shape(src0, *dst)) { DEVICE_LOG_DEBUG("[%s]src0 and dst have different shape\n", hexagon::op_get_name(op)); return false; } @@ -371,40 +371,47 @@ struct op_capabilities { constexpr const op_capabilities kOpCapabilities[] = { { - NPU_OP_MUL_MAT, hexagon::is_mul_mat_supported, + NPU_OP_MUL_MAT, hexagon::is_mul_mat_supported, { hexagon::mul_mat_f32, // NPU_DATA_TYPE_F32 nullptr, // NPU_DATA_TYPE_F16 - }, true, - }, + }, true, // requires_thread_barrier + }, { - NPU_OP_ADD, is_element_wise_op_supported, + NPU_OP_ADD, is_element_wise_op_supported, { element_wise_op>, // NPU_DATA_TYPE_F32 element_wise_op>, // NPU_DATA_TYPE_F16 - }, false, - }, + }, false, // requires_thread_barrier + }, { - NPU_OP_SUB, is_element_wise_op_supported, + NPU_OP_SUB, is_element_wise_op_supported, { element_wise_op>, // NPU_DATA_TYPE_F32 element_wise_op>, // NPU_DATA_TYPE_F16 - }, false, - }, + }, false, // requires_thread_barrier + }, { - NPU_OP_MUL, 
is_element_wise_op_supported, + NPU_OP_MUL, is_element_wise_op_supported, { element_wise_op>, // NPU_DATA_TYPE_F32 element_wise_op>, // NPU_DATA_TYPE_F16 - }, false, - }, + }, false, // requires_thread_barrier + }, { - NPU_OP_RMS_NORM, is_unary_op_supported, + NPU_OP_RMS_NORM, is_unary_op_supported, { unary_op, // NPU_DATA_TYPE_F32 nullptr, // NPU_DATA_TYPE_F16 - }, false, - }, + }, false, // requires_thread_barrier + }, + { + NPU_OP_FLASH_ATTN,hexagon::is_flash_attn_supported, + { + hexagon::flash_attn_f32, // NPU_DATA_TYPE_F32 + nullptr, // NPU_DATA_TYPE_F16 + }, true, // requires_thread_barrier + }, }; static_assert(kOpCapabilities[NPU_OP_MUL_MAT].compute_funcs[NPU_DATA_TYPE_F32] == hexagon::mul_mat_f32, @@ -415,6 +422,8 @@ static_assert(kOpCapabilities[NPU_OP_MUL_MAT].op == NPU_OP_MUL_MAT, "kOpArray[NP static_assert(kOpCapabilities[NPU_OP_MUL].op == NPU_OP_MUL, "kOpArray[NPU_OP_MUL].op != NPU_OP_MUL"); static_assert(kOpCapabilities[NPU_OP_RMS_NORM].op == NPU_OP_RMS_NORM, "kOpArray[NPU_OP_RMS_NORM].op != NPU_OP_RMS_NORM"); +static_assert(kOpCapabilities[NPU_OP_FLASH_ATTN].op == NPU_OP_FLASH_ATTN, + "kOpArray[NPU_OP_FLASH_ATTN].op != NPU_OP_FLASH_ATTN"); hexagon::compute_func_type get_compute_func_impl(npu_device_tensor_op op, npu_device_tensor_data_type type) { if (op >= NPU_OP_COUNT) { @@ -440,16 +449,16 @@ bool requires_thread_barrier(npu_device_tensor_op op) { return kOpCapabilities[op].requires_thread_barrier; } -bool support_op(const npu_device_tensor_spec & src0, const npu_device_tensor_spec & src1, - const npu_device_tensor_spec & dst, npu_device_tensor_op op) { - if (get_compute_func_impl(op, dst.type) == nullptr) { +bool support_op(npu_device_tensor_op op, const npu_device_tensor_spec * dst, const npu_device_tensor_spec * srcs, + size_t src_len) { + if (get_compute_func_impl(op, dst->type) == nullptr) { DEVICE_LOG_ERROR("[%s]unsupported, get_compute_func failed\n", op_get_name(op)); return false; } auto is_supported_func = kOpCapabilities[op].is_supported; - if (!is_supported_func || !is_supported_func(src0, src1, dst, op)) { - DEVICE_LOG_DEBUG("[%s]unsupported, is_supported_func failed\n", op_get_name(op)); + if (!is_supported_func || !is_supported_func(op, dst, srcs, src_len)) { + DEVICE_LOG_DEBUG("[%s]unsupported, is_supported_func return false\n", op_get_name(op)); return false; } diff --git a/ggml/src/ggml-qnn/npu/device/op_impl.hpp b/ggml/src/ggml-qnn/npu/device/op_impl.hpp index 9b75ec6d47..709d493428 100644 --- a/ggml/src/ggml-qnn/npu/device/op_impl.hpp +++ b/ggml/src/ggml-qnn/npu/device/op_impl.hpp @@ -8,7 +8,7 @@ compute_func_type get_compute_func(tensor * dst); bool requires_thread_barrier(npu_device_tensor_op op); -bool support_op(const npu_device_tensor_spec & src0, const npu_device_tensor_spec & src1, - const npu_device_tensor_spec & dst, npu_device_tensor_op op); +bool support_op(npu_device_tensor_op op, const npu_device_tensor_spec * dst, const npu_device_tensor_spec * srcs, + size_t src_len); } // namespace hexagon diff --git a/ggml/src/ggml-qnn/npu/device/op_mul_mat.cpp b/ggml/src/ggml-qnn/npu/device/op_mul_mat.cpp index 6087673ac6..449f0edee1 100644 --- a/ggml/src/ggml-qnn/npu/device/op_mul_mat.cpp +++ b/ggml/src/ggml-qnn/npu/device/op_mul_mat.cpp @@ -1,167 +1,12 @@ #include "op_mul_mat.hpp" -#include - -#include "quants.hpp" #include "thread_pool.hpp" // TODO: remove this dependency +#include "type_traits.hpp" +#include "vec_ops.hpp" #include "vtcm_mem.hpp" namespace { -inline float vec_dot_product_f32_f32(const float * src0, const float * src1, size_t 
count) { - constexpr const size_t kElementsPerVector = hexagon::kBytesPerVector / sizeof(float); - - HVX_Vector * src0_vec_ptr = ((HVX_Vector *) src0); - HVX_Vector * src0_vec_ptr_end = ((HVX_Vector *) src0) + count / kElementsPerVector; - HVX_Vector * src1_vec_ptr = ((HVX_Vector *) src1); - HVX_Vector prev0 = *src0_vec_ptr++; - HVX_Vector prev1 = *src1_vec_ptr++; - HVX_Vector sum = Q6_V_vzero(); - - while (src0_vec_ptr_end - src0_vec_ptr > 1) { - HVX_Vector curr0_lo = src0_vec_ptr[0]; - HVX_Vector curr0_hi = src0_vec_ptr[1]; - HVX_Vector curr1_lo = src1_vec_ptr[0]; - HVX_Vector curr1_hi = src1_vec_ptr[1]; - - HVX_Vector l0 = Q6_V_valign_VVR(curr0_lo, prev0, (size_t) src0); - HVX_Vector l1 = Q6_V_valign_VVR(curr1_lo, prev1, (size_t) src1); - HVX_Vector h0 = Q6_V_valign_VVR(curr0_hi, curr0_lo, (size_t) src0); - HVX_Vector h1 = Q6_V_valign_VVR(curr1_hi, curr1_lo, (size_t) src1); - sum = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_Vqf32_vmpy_VsfVsf(l0, l1), sum); - sum = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_Vqf32_vmpy_VsfVsf(h0, h1), sum); - - prev0 = curr0_hi; - prev1 = curr1_hi; - src0_vec_ptr += 2; - src1_vec_ptr += 2; - } - - if (src0_vec_ptr_end - src0_vec_ptr > 0) { - HVX_Vector curr0 = *src0_vec_ptr++; - HVX_Vector curr1 = *src1_vec_ptr++; - HVX_Vector s0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0); - HVX_Vector s1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1); - sum = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_Vqf32_vmpy_VsfVsf(s0, s1), sum); - prev0 = curr0; - prev1 = curr1; - } - - if ((src0_vec_ptr_end - ((HVX_Vector *) src0)) > 0) { - // handle the last vector - // see also: - // https://github.com/UbiquitousLearning/mllm/blob/babf4410352ce8730824c87699c025a0d4ce3a6f/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/LLaMAMul.cpp#L147 - // or qualcomm sdk libs\qhl_hvx\src\qhblas_hvx\qhblas_hvx_aw_vector_add_ah.c - bool iptr0_aligned = hexagon::is_addr_aligned(src0_vec_ptr); - HVX_Vector curr0 = iptr0_aligned ? prev0 : *src0_vec_ptr; - src0_vec_ptr = iptr0_aligned ? src0_vec_ptr : src0_vec_ptr + 1; - bool iptr1_aligned = hexagon::is_addr_aligned(src1_vec_ptr); - HVX_Vector curr1 = iptr1_aligned ? prev1 : *src1_vec_ptr; - src1_vec_ptr = iptr1_aligned ? src1_vec_ptr : src1_vec_ptr + 1; - HVX_Vector s0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0); - HVX_Vector s1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1); - sum = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_Vqf32_vmpy_VsfVsf(s0, s1), sum); - prev0 = curr0; - prev1 = curr1; - } - - const size_t leftover = count % kElementsPerVector; - const size_t leftover_bytes = leftover * sizeof(float); - if (leftover > 0) { - // handle the leftover elements - HVX_Vector curr0 = (leftover_bytes + hexagon::unaligned_bytes(src0_vec_ptr) > hexagon::kBytesPerVector) ? - *src0_vec_ptr : - prev0; - curr0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0); - - HVX_Vector curr1 = (leftover_bytes + hexagon::unaligned_bytes(src1_vec_ptr) > hexagon::kBytesPerVector) ? - *src1_vec_ptr : - prev1; - curr1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1); - - sum = Q6_Vqf32_vadd_Vqf32Vqf32( - Q6_V_valign_VVR(Q6_Vqf32_vmpy_VsfVsf(curr0, curr1), Q6_V_vzero(), leftover_bytes), sum); - } - - return hexagon::vec_reduction_f32(sum); -} - -// TODO: merge with vec_dot_product_f32_f32? 
-inline float vec_dot_product_f16_f16(const npu_device_fp16_t * src0, const npu_device_fp16_t * src1, size_t count) { - constexpr const size_t kElementsPerVector = hexagon::kBytesPerVector / sizeof(npu_device_fp16_t); - constexpr const size_t kFloatsPerVector = hexagon::kBytesPerVector / sizeof(float); - - HVX_Vector * src0_vec_ptr = ((HVX_Vector *) src0); - HVX_Vector * src0_vec_ptr_end = ((HVX_Vector *) src0) + (count / kElementsPerVector); - HVX_Vector * src1_vec_ptr = ((HVX_Vector *) src1); - HVX_Vector prev0 = *src0_vec_ptr++; - HVX_Vector prev1 = *src1_vec_ptr++; - HVX_Vector sum_hi = Q6_V_vzero(); - HVX_Vector sum_lo = Q6_V_vzero(); - - while (src0_vec_ptr < src0_vec_ptr_end) { - HVX_Vector curr0 = *src0_vec_ptr++; - HVX_Vector curr1 = *src1_vec_ptr++; - HVX_Vector s0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0); - HVX_Vector s1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1); - HVX_VectorPair result = Q6_Wqf32_vmpy_VhfVhf(s0, s1); - sum_hi = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_hi_W(result), sum_hi); - sum_lo = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(result), sum_lo); - prev0 = curr0; - prev1 = curr1; - } - - if ((src0_vec_ptr_end - ((HVX_Vector *) src0)) > 0) { - // handle the last vector - // see also: - // https://github.com/UbiquitousLearning/mllm/blob/babf4410352ce8730824c87699c025a0d4ce3a6f/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/LLaMAMul.cpp#L147 - // or qualcomm sdk libs\qhl_hvx\src\qhblas_hvx\qhblas_hvx_aw_vector_add_ah.c - bool iptr0_aligned = hexagon::is_addr_aligned(src0_vec_ptr); - HVX_Vector curr0 = iptr0_aligned ? prev0 : *src0_vec_ptr; - src0_vec_ptr = iptr0_aligned ? src0_vec_ptr : src0_vec_ptr + 1; - bool iptr1_aligned = hexagon::is_addr_aligned(src1_vec_ptr); - HVX_Vector curr1 = iptr1_aligned ? prev1 : *src1_vec_ptr; - src1_vec_ptr = iptr1_aligned ? src1_vec_ptr : src1_vec_ptr + 1; - HVX_Vector s0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0); - HVX_Vector s1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1); - HVX_VectorPair result = Q6_Wqf32_vmpy_VhfVhf(s0, s1); - sum_hi = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_hi_W(result), sum_hi); - sum_lo = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(result), sum_lo); - prev0 = curr0; - prev1 = curr1; - } - - const size_t leftover = count % kElementsPerVector; - const size_t leftover_bytes = leftover * sizeof(npu_device_fp16_t); - if (leftover > 0) { - // handle the leftover elements - HVX_Vector curr0 = (leftover_bytes + hexagon::unaligned_bytes(src0_vec_ptr) > hexagon::kBytesPerVector) ? - *src0_vec_ptr : - prev0; - curr0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0); - - HVX_Vector curr1 = (leftover_bytes + hexagon::unaligned_bytes(src1_vec_ptr) > hexagon::kBytesPerVector) ? - *src1_vec_ptr : - prev1; - curr1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1); - - HVX_VectorPair result = Q6_Wqf32_vmpy_VhfVhf(curr0, curr1); - - // TODO: can we do this better? 
- if (leftover > kFloatsPerVector) { - sum_hi = Q6_Vqf32_vadd_Vqf32Vqf32( - Q6_V_valign_VVR(Q6_V_hi_W(result), Q6_V_vzero(), (leftover % kFloatsPerVector) * sizeof(float)), - sum_hi); - sum_lo = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(result), sum_lo); - } else { - sum_lo = Q6_Vqf32_vadd_Vqf32Vqf32( - Q6_V_valign_VVR(Q6_V_lo_W(result), Q6_V_vzero(), leftover * sizeof(float)), sum_lo); - } - } - - return hexagon::vec_reduction_f32(Q6_Vqf32_vadd_Vqf32Vqf32(sum_hi, sum_lo)); -} - template struct get_data_type {}; template struct get_data_type { @@ -175,29 +20,26 @@ void mul_mat_impl(hexagon::tensor * src0, hexagon::tensor * src1, hexagon::tenso const bool is_quantized = hexagon::is_quantized_type(src0->get_type()); const auto src0_actual_row_size = hexagon::get_dequantized_row_size(src0); - auto * dequantize_row_func = hexagon::get_type_traits(src0->get_type()).dequantize_row; + auto * dequantize_row_func = hexagon::get_type_traits(src0->get_type()).to_float; if (is_quantized && dequantize_row_func == nullptr) { DEVICE_LOG_ERROR("Unsupported quantized src0 type: %d, dequantize_row_func is null\n", src0->get_type()); return; } - const auto r02 = src1->get_ne(2) / src0->get_ne(2); - const auto r03 = src1->get_ne(3) / src0->get_ne(3); - const auto * src0_ptr = reinterpret_cast(src0->get_read_buffer()); - const auto * src1_ptr = reinterpret_cast(src1->get_read_buffer()); - auto * dst_ptr = reinterpret_cast(dst->get_write_buffer()); - const auto total_planes = dst->get_ne(3) * dst->get_ne(2); + const auto r02 = src1->get_ne(2) / src0->get_ne(2); + const auto r03 = src1->get_ne(3) / src0->get_ne(3); + const auto total_planes = dst->get_ne(3) * dst->get_ne(2); auto start_end_plane = std::pair{ 0, total_planes }; auto start_end_row = std::pair{ 0, dst->get_ne(1) }; auto start_end_element = std::pair{ 0, dst->get_ne(0) }; - if (total_planes >= params->tcnt) { - start_end_plane = hexagon::get_thread_work_slice(total_planes, params->tidx, params->tcnt); - } else if (dst->get_ne(1) >= params->tcnt) { - start_end_row = hexagon::get_thread_work_slice(dst->get_ne(1), params->tidx, params->tcnt); + if (total_planes >= params->get_thread_count()) { + start_end_plane = params->get_work_slice(total_planes); + } else if (dst->get_ne(1) >= params->get_thread_count()) { + start_end_row = params->get_work_slice(dst->get_ne(1)); } else { - start_end_element = hexagon::get_thread_work_slice(dst->get_ne(0), params->tidx, params->tcnt); + start_end_element = params->get_work_slice(dst->get_ne(0)); } if (start_end_plane.second <= start_end_plane.first || start_end_row.second <= start_end_row.first || @@ -218,7 +60,7 @@ void mul_mat_impl(hexagon::tensor * src0, hexagon::tensor * src1, hexagon::tenso bool is_mem_cache = false; if (is_quantized) { src0_plane_slice_row_count = - std::min(params->vtcm_quota_size / src0_actual_row_size, src0_plane_slice_row_count); + std::min(params->get_vtcm_quota_size() / src0_actual_row_size, src0_plane_slice_row_count); src0_plane_cache_size = src0_actual_row_size * src0_plane_slice_row_count; src0_plane_cache_ptr = params->get_vtcm_cache(src0_plane_cache_size); if (src0_plane_cache_ptr == nullptr) { @@ -238,7 +80,17 @@ void mul_mat_impl(hexagon::tensor * src0, hexagon::tensor * src1, hexagon::tenso src0_plane_cache_size); const size_t valid_row_bytes = src1->get_ne(0) * sizeof(data_type); - DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_WITH_SUB_PROC(dst, params->tidx, dequant); + DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_WITH_SUB_PROC(dst, params->get_thread_index(), dequant); + + uint8_t * dst_ptr = 
dst->get_write_buffer(); + if (!dst_ptr) { + DEVICE_LOG_ERROR("mul_mat_impl: dst_ptr is not writable, tensor: %p, type: %s\n", (void *) dst, + hexagon::get_type_name(dst->get_type())); + return; + } + + const uint8_t * src0_ptr = src0->get_read_buffer(); + const uint8_t * src1_ptr = src1->get_read_buffer(); for (int64_t ip = start_end_plane.first; ip < start_end_plane.second; ip++) { const auto i3 = ip / dst->get_ne(2); const auto i2 = ip - i3 * dst->get_ne(2); @@ -289,6 +141,8 @@ void mul_mat_impl(hexagon::tensor * src0, hexagon::tensor * src1, hexagon::tenso } } } + + dst->release_write_buffer(); // mark the output tensor as modified } bool is_quantized_mul_mat_supported(const npu_device_tensor_spec & src0, const npu_device_tensor_spec & src1) { @@ -299,7 +153,7 @@ bool is_quantized_mul_mat_supported(const npu_device_tensor_spec & src0, const n } const auto type_traits = hexagon::get_type_traits(src0.type); - if (!type_traits.is_quantized || type_traits.dequantize_row == nullptr) { + if (!type_traits.is_quantized || type_traits.to_float == nullptr) { DEVICE_LOG_DEBUG("[MUL_MAT]src0.type(%s) and src1.type(%s) mismatch and src0 is not quantized\n", hexagon::get_type_name(src0.type), hexagon::get_type_name(src1.type)); return false; @@ -311,7 +165,7 @@ bool is_quantized_mul_mat_supported(const npu_device_tensor_spec & src0, const n return false; } - const auto vtcm_thread_quota_size = hexagon::vtcm_mem::get_total_size() / hexagon::kMaxThreadCount; + const auto vtcm_thread_quota_size = hexagon::default_thread_pool::get_per_thread_vtcm_quota(); if (src0.ne[0] * sizeof(hexagon::dequantized_element_type) > vtcm_thread_quota_size) { DEVICE_LOG_DEBUG("[MUL_MAT]src0.type(%s) ne[0] is too large: %ld, vtcm_thread_quota_size: %zu\n", hexagon::get_type_name(src0.type), (long) src0.ne[0], vtcm_thread_quota_size); @@ -339,14 +193,13 @@ bool mul_mat_f32(hexagon::tensor * out, compute_params * params) { return true; // skip if no src } - // TODO: array? 
switch (src1->get_type()) { case NPU_DATA_TYPE_F32: - mul_mat_impl(src0, src1, out, params); + mul_mat_impl(src0, src1, out, params); return true; case NPU_DATA_TYPE_F16: - mul_mat_impl(src0, src1, out, params); + mul_mat_impl(src0, src1, out, params); return true; default: break; @@ -356,18 +209,25 @@ bool mul_mat_f32(hexagon::tensor * out, compute_params * params) { return false; } -bool is_mul_mat_supported(const npu_device_tensor_spec & src0, const npu_device_tensor_spec & src1, - const npu_device_tensor_spec & dst, npu_device_tensor_op op) { +bool is_mul_mat_supported(npu_device_tensor_op op, const npu_device_tensor_spec * dst, + const npu_device_tensor_spec * srcs, size_t src_len) { if (op != NPU_OP_MUL_MAT) { DEVICE_LOG_DEBUG("op is not MUL_MAT: %d\n", op); return false; } - if (dst.type != NPU_DATA_TYPE_F32) { - DEVICE_LOG_DEBUG("[%s]dst type is not F32: %s\n", op_get_name(op), get_type_name(dst.type)); + if (!dst || !srcs || src_len < 2) { + DEVICE_LOG_DEBUG("[%s]invalid dst or srcs\n", hexagon::op_get_name(op)); return false; } + if (dst->type != NPU_DATA_TYPE_F32) { + DEVICE_LOG_DEBUG("[%s]dst type is not F32: %s\n", op_get_name(op), get_type_name(dst->type)); + return false; + } + + const auto & src0 = srcs[0]; + const auto & src1 = srcs[1]; if (src0.type != src1.type) { #ifdef GGML_HEXAGON_ENABLE_QUANTIZED_TENSORS if (!is_quantized_mul_mat_supported(src0, src1)) { @@ -380,15 +240,15 @@ bool is_mul_mat_supported(const npu_device_tensor_spec & src0, const npu_device_ #endif } - if (src0.ne[0] != src1.ne[0] || src0.ne[1] != dst.ne[0]) { + if (src0.ne[0] != src1.ne[0] || src0.ne[1] != dst->ne[0]) { DEVICE_LOG_DEBUG("[%s]src0 and src1 cannot multiply: %ldx%ld vs %ldx%ld\n", op_get_name(op), (long) src0.ne[0], (long) src0.ne[1], (long) src1.ne[0], (long) src1.ne[1]); return false; } - if (src1.ne[1] != dst.ne[1] || src1.ne[2] != dst.ne[2] || src1.ne[3] != dst.ne[3]) { + if (src1.ne[1] != dst->ne[1] || src1.ne[2] != dst->ne[2] || src1.ne[3] != dst->ne[3]) { DEVICE_LOG_DEBUG("[%s]src1 and dst dimensions not match: %ldx%ld vs %ldx%ld\n", op_get_name(op), - (long) src1.ne[2], (long) src1.ne[3], (long) dst.ne[2], (long) dst.ne[3]); + (long) src1.ne[2], (long) src1.ne[3], (long) dst->ne[2], (long) dst->ne[3]); return false; } diff --git a/ggml/src/ggml-qnn/npu/device/op_mul_mat.hpp b/ggml/src/ggml-qnn/npu/device/op_mul_mat.hpp index 8cf41e0a99..434406f930 100644 --- a/ggml/src/ggml-qnn/npu/device/op_mul_mat.hpp +++ b/ggml/src/ggml-qnn/npu/device/op_mul_mat.hpp @@ -7,64 +7,8 @@ namespace hexagon { -inline size_t unaligned_bytes(const void * addr) { - return ((size_t) addr) & kAlignMask; -} - -inline bool is_addr_aligned(void * addr) { - return unaligned_bytes(addr) == 0; -} - -inline void l2fetch(const void * p, uint32_t stride, uint32_t width, uint32_t height, uint32_t dir) { - uint64_t control = HEXAGON_V64_CREATE_H(dir, stride, width, height); - __asm__ __volatile__(" l2fetch(%0,%1) " : : "r"(p), "r"(control)); -} - -inline void l2fetch_row(const uint8_t * curr_row, size_t bytes) { - // TODO: should we use small kL2FetchAheadVectors? 
- int32_t l2fetch_vectors = Q6_R_min_RR(bytes / kBytesPerVector, kL2FetchAheadVectors); - hexagon::l2fetch(curr_row, kBytesPerVector, kBytesPerVector, l2fetch_vectors, 0); -} - -inline float get_flt0_from_fltv(HVX_Vector vect) { - // See also: tools\HEXAGON_Tools\8.6.07\Examples\StandAlone_Applications\QFloat\QFloat.c - - union { - int32_t i; - float f; - } cvt; - - cvt.i = vect[0]; - return cvt.f; -} - -inline HVX_Vector vec_reduction_qf32(HVX_Vector sums) { - constexpr const size_t kFloatsPerVector = hexagon::kBytesPerVector / sizeof(float); - static_assert(kFloatsPerVector == 32 || kFloatsPerVector == 16, "kFloatsPerVector should be 16 or 32"); - - // TODO: do we have a better way to do the reduction? - switch (kFloatsPerVector) { - default: - case 32: - sums = Q6_Vqf32_vadd_Vqf32Vqf32(sums, Q6_V_vror_VR(sums, 16 * sizeof(float))); - // fallthrough - case 16: - sums = Q6_Vqf32_vadd_Vqf32Vqf32(sums, Q6_V_vror_VR(sums, 8 * sizeof(float))); - sums = Q6_Vqf32_vadd_Vqf32Vqf32(sums, Q6_V_vror_VR(sums, 4 * sizeof(float))); - sums = Q6_Vqf32_vadd_Vqf32Vqf32(sums, Q6_V_vror_VR(sums, 2 * sizeof(float))); - sums = Q6_Vqf32_vadd_Vqf32Vqf32(sums, Q6_V_vror_VR(sums, sizeof(float))); - break; - } - - return sums; -} - -inline float vec_reduction_f32(HVX_Vector sums) { - return hexagon::get_flt0_from_fltv(Q6_Vsf_equals_Vqf32(vec_reduction_qf32(sums))); -} - bool mul_mat_f32(tensor * out, compute_params * params); -bool is_mul_mat_supported(const npu_device_tensor_spec & src0, const npu_device_tensor_spec & src1, - const npu_device_tensor_spec & dst, npu_device_tensor_op op); +bool is_mul_mat_supported(npu_device_tensor_op op, const npu_device_tensor_spec * dst, + const npu_device_tensor_spec * srcs, size_t src_len); } // namespace hexagon diff --git a/ggml/src/ggml-qnn/npu/device/op_types.hpp b/ggml/src/ggml-qnn/npu/device/op_types.hpp index 153bbab058..bad83ad95e 100644 --- a/ggml/src/ggml-qnn/npu/device/op_types.hpp +++ b/ggml/src/ggml-qnn/npu/device/op_types.hpp @@ -9,46 +9,12 @@ #include "hexagon_npu.h" #include "tensor.hpp" +#include "thread_pool.hpp" #include "util.hpp" -#include "vtcm_mem.hpp" +#include "vec_ops.hpp" namespace hexagon { -struct compute_params { - const size_t tidx; - const size_t tcnt; - const size_t vtcm_quota_size; - const float * f16_to_f32_table; - std::unique_ptr vtcm_cache; - std::unique_ptr mem_cache; - size_t mem_cache_size = 0; - - uint8_t * get_vtcm_cache(size_t size) { - if (!vtcm_cache || vtcm_cache->get_size() < size) { - vtcm_cache = std::make_unique(size, false); - } - - if (!vtcm_cache->is_valid()) { - return nullptr; - } - - return vtcm_cache->get_mem(); - } - - uint8_t * get_mem_cache(size_t size) { - if (!mem_cache || mem_cache_size < size) { - mem_cache = std::make_unique(size + 256); - mem_cache_size = mem_cache ? 
size : 0; - } - - return mem_cache.get(); - } -}; - -typedef bool (*compute_func_type)(tensor * dst, compute_params * params); -typedef bool (*op_is_supported_func_type)(const npu_device_tensor_spec & src0, const npu_device_tensor_spec & src1, - const npu_device_tensor_spec & dst, npu_device_tensor_op op); - inline constexpr std::pair get_thread_work_slice(int64_t total, size_t tidx, size_t tcnt) { if (total <= 0 || tidx >= tcnt) { return { 0, 0 }; // No work for this thread @@ -72,9 +38,27 @@ inline constexpr std::pair get_thread_work_slice(int64_t total return { start, std::min(end, total) }; } -constexpr const size_t kBytesPerVector = sizeof(HVX_Vector); // 128 for v73 -constexpr const size_t kAlignMask = kBytesPerVector - 1; -constexpr const size_t kL2CacheSize = 8 * 1024; // // 8KB L2 cache -constexpr const size_t kL2FetchAheadVectors = kL2CacheSize / kBytesPerVector; +struct compute_params { + default_thread_pool::thread_params * const thread_params; + const float * f16_to_f32_table; + + uint8_t * get_vtcm_cache(size_t size) { return thread_params->get_vtcm_cache(size); } + + uint8_t * get_mem_cache(size_t size) { return thread_params->get_mem_cache(size); } + + std::pair get_work_slice(int64_t total) const { + return get_thread_work_slice(total, thread_params->tidx, thread_params->tcnt); + } + + size_t get_vtcm_quota_size() const { return thread_params->vtcm_quota_size; } + + size_t get_thread_count() const { return thread_params->tcnt; } + + size_t get_thread_index() const { return thread_params->tidx; } +}; + +typedef bool (*compute_func_type)(tensor * dst, compute_params * params); +typedef bool (*op_is_supported_func_type)(npu_device_tensor_op op, const npu_device_tensor_spec * dst, + const npu_device_tensor_spec * srcs, size_t src_len); } // namespace hexagon diff --git a/ggml/src/ggml-qnn/npu/device/quants.cpp b/ggml/src/ggml-qnn/npu/device/quants.cpp deleted file mode 100644 index 67e77c2fc2..0000000000 --- a/ggml/src/ggml-qnn/npu/device/quants.cpp +++ /dev/null @@ -1,213 +0,0 @@ -#include "quants.hpp" - -#include - -#include - -#include "op_types.hpp" // TODO: remove this include - -static_assert(sizeof(npu_device_block_q4_K) == - 2 * sizeof(npu_device_fp16_t) + QUANT_K_SCALE_SIZE + QUANT_K_BLOCK_SIZE / 2, - "wrong q4_K block size/padding"); - -static_assert(sizeof(npu_device_block_q4_0) == sizeof(npu_device_fp16_t) + QUANT_BLOCK_SIZE / 2, - "wrong q4_0 block size/padding"); - -static_assert(sizeof(npu_device_block_q8_0) == sizeof(npu_device_fp16_t) + QUANT_BLOCK_SIZE, - "wrong q8_0 block size/padding"); - -namespace { - -inline HVX_Vector vmemu(const void * unaligned_ptr) { - HVX_Vector ret = *reinterpret_cast(unaligned_ptr); - return ret; -} - -inline float to_float(const npu_device_fp16_t src) { - return reinterpret_cast(src); -} - -template inline HVX_Vector load_block_generic(const _TBlock & src) { - uint8_t buffer[hexagon::kBytesPerVector]; - - static_assert(sizeof(buffer) == sizeof(HVX_Vector), "wrong cvt size/padding"); - static_assert(sizeof(buffer) >= sizeof(src.qs), "wrong q4_0 block size/padding"); - - memcpy(&buffer[0], src.qs, sizeof(src.qs)); - return *reinterpret_cast(buffer); -} - -template inline HVX_Vector load_dual_block_generic(const _TBlock & src1, const _TBlock & src2) { - uint8_t buffer[hexagon::kBytesPerVector]; - - static_assert(sizeof(buffer) == sizeof(HVX_Vector), "wrong cvt size/padding"); - static_assert(sizeof(buffer) >= sizeof(src1.qs) * 2, "wrong q4_0 block size/padding"); - - memcpy(&buffer[0], src1.qs, sizeof(src1.qs)); - 
memcpy(&buffer[sizeof(src1.qs)], src2.qs, sizeof(src2.qs)); - return *reinterpret_cast(buffer); -} - -inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t * d, uint8_t * m) { - if (j < 4) { - *d = q[j] & 63; - *m = q[j + 4] & 63; - } else { - *d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4); - *m = (q[j + 4] >> 4) | ((q[j - 0] >> 6) << 4); - } -} - -void dequantize_row_q8_0(const void * src, float * dst, size_t count, const float * f16_to_f32_table) { - constexpr const int qk = QUANT_BLOCK_SIZE; - static_assert(QUANT_BLOCK_SIZE == hexagon::kBytesPerVector / sizeof(float)); - - const int nb = count / qk; - const auto * src_ptr = reinterpret_cast(src); - HVX_UVector * out = ((HVX_UVector *) dst); // TODO: opt for aligned access - - for (int i = 0; i < nb; i++) { - const auto & src = src_ptr[i]; - HVX_Vector d = Q6_Vh_vsplat_R(src.d); - - HVX_Vector q_lo = load_block_generic(src); - HVX_VectorPair q = Q6_Wh_vunpack_Vb(q_lo); - q = Q6_Wh_vunpack_Vb(Q6_V_lo_W(q)); - q_lo = Q6_Vhf_equals_Vh(Q6_V_lo_W(q)); - q = Q6_Wqf32_vmpy_VhfVhf(q_lo, d); - out[i] = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(q)); - } -} - -void dequantize_row_q4_0(const void * src, float * dst, size_t count, const float * f16_to_f32_table) { - constexpr const int qk = QUANT_BLOCK_SIZE; - static_assert(qk % 2 == 0, "qk must be even"); - static_assert(QUANT_BLOCK_SIZE == hexagon::kBytesPerVector / sizeof(float)); - constexpr const uint32_t kSizeOfQs = sizeof(npu_device_block_q4_0::qs); - - const int nb = count / qk; - const auto * src_ptr = reinterpret_cast(src); - HVX_Vector mask = Q6_Vb_vsplat_R(0x0F); - HVX_Vector minus = Q6_Vb_vsplat_R(8); - HVX_UVector * out = ((HVX_UVector *) dst); // TODO: opt for aligned access - - const int loop_count = nb - (nb % 2); - for (int i = 0; i < loop_count; i += 2) { - const auto & src1 = src_ptr[i]; - const auto & src2 = src_ptr[i + 1]; - - HVX_Vector d1 = Q6_Vh_vsplat_R(src1.d); - HVX_Vector d2 = Q6_Vh_vsplat_R(src2.d); - d1 = Q6_V_valign_VVR(d1, Q6_V_vzero(), hexagon::kBytesPerVector / 2); - d1 = Q6_V_valign_VVR(d2, d1, hexagon::kBytesPerVector / 2); - HVX_Vector d = Q6_Vh_vshuff_Vh(d1); - - HVX_Vector q_lo = load_dual_block_generic(src1, src2); - HVX_Vector q_hi = Q6_Vub_vlsr_VubR(q_lo, 4); - HVX_VectorPair q = Q6_W_vshuff_VVR(q_hi, Q6_V_vand_VV(q_lo, mask), kSizeOfQs); - q_lo = Q6_V_valign_VVR(Q6_V_lo_W(q), Q6_V_vzero(), hexagon::kBytesPerVector / 2); - q_lo = Q6_V_valign_VVR(Q6_V_hi_W(q), q_lo, hexagon::kBytesPerVector / 2); - q_lo = Q6_Vb_vshuff_Vb(q_lo); - q_lo = Q6_Vb_vsub_VbVb(q_lo, minus); - q = Q6_Wh_vunpack_Vb(q_lo); - q_lo = Q6_Vhf_equals_Vh(Q6_V_lo_W(q)); - q = Q6_Wqf32_vmpy_VhfVhf(q_lo, d); - out[i] = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(q)); - out[i + 1] = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(q)); - } - - if (loop_count < nb) { - const auto & curr_blk = src_ptr[nb - 1]; - HVX_Vector d = Q6_Vh_vsplat_R(curr_blk.d); - - HVX_Vector q_lo = load_block_generic(curr_blk); - HVX_Vector q_hi = Q6_Vub_vlsr_VubR(q_lo, 4); - q_lo = Q6_V_valign_VVR(Q6_V_vand_VV(q_lo, mask), Q6_V_vzero(), sizeof(curr_blk.qs)); - q_lo = Q6_V_valign_VVR(q_hi, q_lo, hexagon::kBytesPerVector - sizeof(curr_blk.qs)); - q_lo = Q6_Vb_vsub_VbVb(q_lo, minus); - - HVX_VectorPair q = Q6_Wh_vunpack_Vb(q_lo); - q = Q6_Wh_vunpack_Vb(Q6_V_lo_W(q)); - q_lo = Q6_Vhf_equals_Vh(Q6_V_lo_W(q)); - q = Q6_Wqf32_vmpy_VhfVhf(q_lo, d); - out[nb - 1] = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(q)); - } -} - -void dequantize_row_q4_K(const void * src, float * dst, size_t count, const float * f16_to_f32_table) { - const int nb = count / QUANT_K_BLOCK_SIZE; - const 
auto * src_ptr = reinterpret_cast(src); - - // TODO: use intrinsics - for (int i = 0; i < nb; i++) { - const uint8_t * q = src_ptr[i].qs; - - const float d = f16_to_f32_table[src_ptr[i].d]; - const float min = f16_to_f32_table[src_ptr[i].dmin]; - - int is = 0; - uint8_t sc = 0; - uint8_t m = 0; - const auto * scales = src_ptr[i].scales; - for (int j = 0; j < QUANT_K_BLOCK_SIZE; j += 64) { - get_scale_min_k4(is + 0, scales, &sc, &m); - const float d1 = d * sc; - const float m1 = min * m; - get_scale_min_k4(is + 1, scales, &sc, &m); - const float d2 = d * sc; - const float m2 = min * m; - for (int l = 0; l < 32; ++l) { - dst[0] = d1 * (q[l] & 0xF) - m1; - dst[32] = d2 * ((q[l] >> 4) & 0xF) - m2; - dst++; - } - dst += 32; - q += 32; - is += 2; - } - } -} - -constexpr const hexagon::device_type_traits kDeviceTypeTraits[] = { - { NPU_DATA_TYPE_F32, "F32", 1, false, nullptr }, - { NPU_DATA_TYPE_F16, "F16", 1, false, nullptr }, - { NPU_DATA_TYPE_Q8_0, "Q8_0", QUANT_BLOCK_SIZE, true, dequantize_row_q8_0 }, - { NPU_DATA_TYPE_Q4_0, "Q4_0", QUANT_BLOCK_SIZE, true, dequantize_row_q4_0 }, - { NPU_DATA_TYPE_Q4_K, "Q4_K", QUANT_K_BLOCK_SIZE, true, dequantize_row_q4_K }, -}; - -static_assert(std::size(kDeviceTypeTraits) == NPU_DATA_TYPE_COUNT, - "kDeviceTypeTraits size mismatch with npu_device_tensor_data_type enum"); -static_assert(kDeviceTypeTraits[NPU_DATA_TYPE_F32].type == NPU_DATA_TYPE_F32, - "kDeviceTypeTraits F32 type mismatch with npu_device_tensor_data_type enum"); -static_assert(kDeviceTypeTraits[NPU_DATA_TYPE_F16].type == NPU_DATA_TYPE_F16, - "kDeviceTypeTraits F16 type mismatch with npu_device_tensor_data_type enum"); -static_assert(kDeviceTypeTraits[NPU_DATA_TYPE_Q8_0].type == NPU_DATA_TYPE_Q8_0, - "kDeviceTypeTraits Q8_0 type mismatch with npu_device_tensor_data_type enum"); -static_assert(kDeviceTypeTraits[NPU_DATA_TYPE_Q4_0].type == NPU_DATA_TYPE_Q4_0, - "kDeviceTypeTraits Q4_0 type mismatch with npu_device_tensor_data_type enum"); -static_assert(kDeviceTypeTraits[NPU_DATA_TYPE_Q4_K].type == NPU_DATA_TYPE_Q4_K, - "kDeviceTypeTraits Q4_K type mismatch with npu_device_tensor_data_type enum"); - -} // namespace - -namespace hexagon { - -bool init_f16_f32_table(float * table, size_t count) { - constexpr const size_t kTableSize = (1U << 16); - if (count < kTableSize) { - return false; - } - - for (size_t i = 0; i < count; ++i) { - table[i] = to_float(i); - } - - return true; -} - -const device_type_traits & get_type_traits(npu_device_tensor_data_type type) { - return kDeviceTypeTraits[type]; -} - -} // namespace hexagon diff --git a/ggml/src/ggml-qnn/npu/device/tensor.hpp b/ggml/src/ggml-qnn/npu/device/tensor.hpp index 7e980d8402..bad260e5e5 100644 --- a/ggml/src/ggml-qnn/npu/device/tensor.hpp +++ b/ggml/src/ggml-qnn/npu/device/tensor.hpp @@ -3,6 +3,8 @@ #include #include +#include + #include "hexagon_npu.h" #include "util.hpp" @@ -23,7 +25,7 @@ class tensor { } _data = static_cast(mmap_address); - DEVICE_LOG_INFO("tensor(%p[%ldx%ldx%ldx%ld]), fd: %d, offset: %zu, mmap_address: %p, phy_address: 0x%lx\n", + DEVICE_LOG_INFO("tensor(%p[%ldx%ldx%ldx%ld]), fd: %d, offset: %zu, mmap_addr: %p, phy_addr: 0x%lx\n", (void *) this, (long) _info.ne[0], (long) _info.ne[1], (long) _info.ne[2], (long) _info.ne[3], _info.buffer_fd, _info.offset, (void *) mmap_address, phy_address); } @@ -47,14 +49,14 @@ class tensor { void invalidate() const { if (_data) { qurt_mem_cache_clean((qurt_addr_t) (_data + _info.offset), (qurt_size_t) _info.size, - QURT_MEM_CACHE_INVALIDATE, QURT_MEM_DCACHE); + 
QURT_MEM_CACHE_FLUSH_INVALIDATE_ALL, QURT_MEM_DCACHE); } } void update_config(const npu_device_tensor_update_config & config) { static_assert(sizeof(_op_params) == sizeof(config.params), "op params size mismatch"); - _info.op = config.op; + _op_type = config.op; memcpy(_op_params, config.params, sizeof(_op_params)); for (size_t i = 0; i < DEVICE_TENSOR_MAX_SRC; ++i) { auto src_handle = config.src_handles[i]; @@ -76,7 +78,12 @@ class tensor { const size_t get_nb(size_t index) const { return _info.nb[index]; } - npu_device_tensor_op get_op() const { return _info.op; } + const bool is_permuted() const { + // Check if the tensor is permuted by comparing the nb values + return is_transposed_or_permuted(_info.nb); + } + + npu_device_tensor_op get_op() const { return _op_type; } template const _TyParam get_op_param(size_t index) const { static_assert(sizeof(_TyParam) <= sizeof(_op_params), "_op_param type size exceeds op params size"); @@ -95,19 +102,34 @@ class tensor { npu_device_tensor_data_type get_type() const { return _info.type; } const uint8_t * get_read_buffer() const { - invalidate(); + if (!_info.is_constant && _has_modified) { + invalidate(); + const_cast(this)->_has_modified = false; // TODO: avoid const_cast + } + return _data + _info.offset; } - uint8_t * get_write_buffer() const { return _data + _info.offset; } + uint8_t * get_write_buffer() const { + if (_info.is_constant) { + DEVICE_LOG_ERROR("Attempt to write to a constant tensor: %p", (void *) this); + return nullptr; // Do not allow writing to constant tensors + } + + return _data + _info.offset; + } + + void release_write_buffer() { _has_modified = true; } bool is_valid() const { return _data != nullptr; } private: npu_device_tensor_config _info = {}; + npu_device_tensor_op _op_type = NPU_OP_COUNT; int32_t _op_params[kMaxParamsCount] = {}; tensor * _src[kMaxTensorSrc] = {}; uint8_t * _data = nullptr; + std::atomic_bool _has_modified = false; DISABLE_COPY_AND_MOVE(tensor); }; diff --git a/ggml/src/ggml-qnn/npu/device/thread_pool.hpp b/ggml/src/ggml-qnn/npu/device/thread_pool.hpp index 9a525213c9..9661c00670 100644 --- a/ggml/src/ggml-qnn/npu/device/thread_pool.hpp +++ b/ggml/src/ggml-qnn/npu/device/thread_pool.hpp @@ -8,6 +8,7 @@ #include #include "util.hpp" +#include "vtcm_mem.hpp" namespace hexagon { @@ -78,28 +79,65 @@ template class qurt_thread { using qurt_thread_ptr = std::unique_ptr>; -template class thread_pool { - static_assert(_thread_count > 1, "Thread count must be greater than 1"); - constexpr const static size_t kMaxSubThreadCount = _thread_count - 1; +template class thread_pool { + static_assert(_ThreadCount > 1, "Thread count must be greater than 1"); + constexpr const static size_t kMaxThreadCount = _ThreadCount; + constexpr const static size_t kMaxSubThreadCount = _ThreadCount - 1; public: typedef qurt_thread thread_type; - typedef void (*task_type)(thread_pool * pool, size_t thread_idx, size_t thread_count, void * arg); + + struct thread_params { + size_t tidx; + const size_t tcnt = kMaxThreadCount; + thread_pool * pool = nullptr; + size_t vtcm_quota_size; + + std::unique_ptr vtcm_cache; + std::unique_ptr mem_cache; + size_t mem_cache_size = 0; + + uint8_t * get_vtcm_cache(size_t size) { + if (!vtcm_cache || vtcm_cache->get_size() < size) { + DEVICE_SCOPED_PERFORMANCE_TRACKER("[thread_params]get_vtcm_cache, size: %zu, tidx: %zu", size, tidx); + vtcm_cache.reset(); // reset the cache to create a new one + vtcm_cache = std::make_unique(size, false); + } + + if (!vtcm_cache->is_valid()) { + return nullptr; + } 
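A minimal sketch (hypothetical op, illustrative names) of how a compute function is expected to combine the new compute_params work-slice helper with the tensor buffer protocol above: get_write_buffer() refuses constant tensors, get_read_buffer() only flushes/invalidates the cache after the tensor has been marked modified, and release_write_buffer() is what sets that flag once writing is done. The exact call site of release_write_buffer() in this patch may differ.

    bool copy_rows_f32(hexagon::tensor * out, hexagon::compute_params * params) {
        auto *    src     = out->get_src(0);
        uint8_t * dst_ptr = out->get_write_buffer();       // nullptr for constant tensors
        if (!src || !dst_ptr) {
            return false;
        }
        const uint8_t * src_ptr      = src->get_read_buffer();  // lazy cache invalidate happens here
        const auto [start, end]      = params->get_work_slice(out->get_ne(1));  // rows owned by this thread
        for (int64_t r = start; r < end; ++r) {
            memcpy(dst_ptr + r * out->get_nb(1), src_ptr + r * src->get_nb(1),
                   out->get_ne(0) * sizeof(float));
        }
        out->release_write_buffer();                        // mark modified so readers re-invalidate
        return true;
    }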
+ + return vtcm_cache->get_mem(); + } + + uint8_t * get_mem_cache(size_t size) { + if (!mem_cache || mem_cache_size < size) { + mem_cache.reset(); // reset the cache to create a new one + mem_cache = std::make_unique(size + 256); + mem_cache_size = mem_cache ? size : 0; + } + + return mem_cache.get(); + } + }; + + typedef void (*task_type)(thread_pool * pool, thread_params * param, void * arg); thread_pool() { - std::string thread_name_base = "thread_pool_"; + for (size_t i = 0; i < kMaxThreadCount; ++i) { + _thread_params[i].tidx = i; + _thread_params[i].vtcm_quota_size = hexagon::vtcm_mem::get_avail_block_size() / kMaxThreadCount; + _thread_params[i].pool = this; + } + qurt_barrier_init(&_pending, kMaxSubThreadCount + 1); qurt_barrier_init(&_completed, kMaxSubThreadCount + 1); - const auto priority = qurt_thread_get_priority(qurt_thread_get_id()); + const auto priority = qurt_thread_get_priority(qurt_thread_get_id()); + std::string thread_name_base = "thread_pool_"; for (size_t i = 0; i < kMaxSubThreadCount; ++i) { - auto & thread_arg = _thread_args[i]; - thread_arg.pool = this; - thread_arg.thread_idx = i + 1; - auto thread = std::make_unique( - thread_name_base + std::to_string(i), - reinterpret_cast(&thread_pool::thread_func_impl), &thread_arg, - priority); + thread_name_base + std::to_string(i), &thread_pool::thread_func_impl, &_thread_params[i + 1], priority); if (!thread->is_valid()) { DEVICE_LOG_ERROR("Failed to create thread: %zu", i); // destroy all barriers and threads at destructor @@ -108,6 +146,7 @@ template class thread_pool { _threads[i] = std::move(thread); } + DEVICE_LOG_DEBUG("thread_pool.created: %zu", kMaxSubThreadCount); } @@ -130,60 +169,85 @@ template class thread_pool { return false; } +#ifdef GGML_HEXAGON_ENABLE_PERFORMANCE_TRACKING + _task_begin_cycles = HAP_perf_get_qtimer_count(); +#endif + _task = task; _arg = arg; qurt_barrier_wait(&_pending); - task(this, 0, kMaxSubThreadCount + 1, arg); + task(this, &_thread_params[0], arg); DEVICE_LOG_DEBUG("main_thread.task_completed: 0"); qurt_barrier_wait(&_completed); _task = nullptr; _arg = nullptr; + +#ifdef GGML_HEXAGON_ENABLE_PERFORMANCE_TRACKING + _task_begin_cycles = 0; +#endif + return true; } void sync_thread() { qurt_barrier_wait(&_completed); } - private: - struct thread_pool_arg { - thread_pool * pool = nullptr; - size_t thread_idx = 0; - }; + static size_t get_per_thread_vtcm_quota() { return vtcm_mem::get_total_size() / kMaxThreadCount; } - static void thread_func_impl(thread_type * thread, thread_pool_arg * arg) { + private: + static void thread_func_impl(thread_type * thread, void * arg) { NPU_UNUSED(thread); - DEVICE_LOG_DEBUG("thread_func_impl.start: %zu", arg->thread_idx); + auto * param = reinterpret_cast(arg); - auto & pool = *arg->pool; + DEVICE_LOG_DEBUG("thread_func_impl.start: %zu", param->tidx); + + auto & pool = *(param->pool); for (;;) { qurt_barrier_wait(&pool._pending); if (pool._thread_exit) { - DEVICE_LOG_DEBUG("thread_func_impl.exit: %zu", arg->thread_idx); + DEVICE_LOG_DEBUG("thread_func_impl.exit: %zu", param->tidx); break; } +#ifdef GGML_HEXAGON_ENABLE_PERFORMANCE_TRACKING + auto task_begin_cycles = pool._task_begin_cycles.load(); + DEVICE_LOG_WARN("[profiler]worker_thread, tidx: %zu, prepare: %lluus", param->tidx, + static_cast( + HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - task_begin_cycles))); +#endif + auto task = pool._task; if (task) { - task(arg->pool, arg->thread_idx, kMaxSubThreadCount + 1, pool._arg); + task(param->pool, param, pool._arg); } - 
DEVICE_LOG_DEBUG("thread_func_impl.task_completed: %zu", arg->thread_idx); + DEVICE_LOG_DEBUG("thread_func_impl.task_completed: %zu", param->tidx); qurt_barrier_wait(&pool._completed); + +#ifdef GGML_HEXAGON_ENABLE_PERFORMANCE_TRACKING + DEVICE_LOG_WARN("[profiler]worker_thread, tidx: %zu, task_end: %lluus", param->tidx, + static_cast( + HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - task_begin_cycles))); +#endif } - DEVICE_LOG_DEBUG("thread_func_impl.end: %zu", arg->thread_idx); + DEVICE_LOG_DEBUG("thread_func_impl.end: %zu", param->tidx); } std::atomic_bool _thread_exit = false; std::array _threads; - thread_pool_arg _thread_args[kMaxSubThreadCount] = {}; - qurt_barrier_t _pending = {}; - qurt_barrier_t _completed = {}; - task_type _task = nullptr; - void * _arg = nullptr; + qurt_barrier_t _pending = {}; + qurt_barrier_t _completed = {}; + thread_params _thread_params[kMaxThreadCount] = {}; + task_type _task = nullptr; + void * _arg = nullptr; + +#ifdef GGML_HEXAGON_ENABLE_PERFORMANCE_TRACKING + std::atomic _task_begin_cycles = 0; +#endif DISABLE_COPY_AND_MOVE(thread_pool); }; diff --git a/ggml/src/ggml-qnn/npu/device/type_traits.cpp b/ggml/src/ggml-qnn/npu/device/type_traits.cpp new file mode 100644 index 0000000000..87d361819d --- /dev/null +++ b/ggml/src/ggml-qnn/npu/device/type_traits.cpp @@ -0,0 +1,467 @@ +#include "type_traits.hpp" + +#include + +#include + +#include "op_types.hpp" // TODO: remove this include +#include "vec_ops.hpp" + +static_assert(sizeof(npu_device_block_q4_k) == + 2 * sizeof(npu_device_fp16_t) + QUANT_K_SCALE_SIZE + QUANT_K_BLOCK_SIZE / 2, + "wrong q4_K block size/padding"); + +static_assert(sizeof(npu_device_block_q4_0) == sizeof(npu_device_fp16_t) + QUANT_BLOCK_SIZE / 2, + "wrong q4_0 block size/padding"); + +static_assert(sizeof(npu_device_block_q8_0) == sizeof(npu_device_fp16_t) + QUANT_BLOCK_SIZE, + "wrong q8_0 block size/padding"); + +namespace { + +inline float to_float(const npu_device_fp16_t src) { + return reinterpret_cast(src); +} + +inline npu_device_fp16_t to_fp16(const float src) { + __fp16 f16_value = static_cast<__fp16>(src); + return reinterpret_cast(f16_value); +} + +template inline HVX_Vector load_block_generic(const _TBlock & src) { + uint8_t buffer[hexagon::kBytesPerVector]; + + static_assert(sizeof(buffer) == sizeof(HVX_Vector), "wrong cvt size/padding"); + static_assert(sizeof(buffer) >= sizeof(src.qs), "wrong q4_0 block size/padding"); + + memcpy(&buffer[0], src.qs, sizeof(src.qs)); + return *reinterpret_cast(buffer); +} + +template inline HVX_Vector load_dual_block_generic(const _TBlock & src1, const _TBlock & src2) { + uint8_t buffer[hexagon::kBytesPerVector]; + + static_assert(sizeof(buffer) == sizeof(HVX_Vector), "wrong cvt size/padding"); + static_assert(sizeof(buffer) >= sizeof(src1.qs) * 2, "wrong q4_0 block size/padding"); + + memcpy(&buffer[0], src1.qs, sizeof(src1.qs)); + memcpy(&buffer[sizeof(src1.qs)], src2.qs, sizeof(src2.qs)); + return *reinterpret_cast(buffer); +} + +inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t * d, uint8_t * m) { + if (j < 4) { + *d = q[j] & 63; + *m = q[j + 4] & 63; + } else { + *d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4); + *m = (q[j + 4] >> 4) | ((q[j - 0] >> 6) << 4); + } +} + +inline int nearest_int(float fval) { + float val = fval + 12582912.f; + int i = reinterpret_cast(val); + return (i & 0x007fffff) - 0x00400000; +} + +float make_qkx2_quants(int n, int nmax, const float * x, const float * weights, uint8_t * L, float * the_min, + uint8_t * Laux, float rmin, float 
rdelta, int nstep, bool use_mad) { + float min = x[0]; + float max = x[0]; + float sum_w = weights[0]; + float sum_x = sum_w * x[0]; + for (int i = 1; i < n; ++i) { + if (x[i] < min) { + min = x[i]; + } + if (x[i] > max) { + max = x[i]; + } + float w = weights[i]; + sum_w += w; + sum_x += w * x[i]; + } + if (min > 0) { + min = 0; + } + if (max == min) { + for (int i = 0; i < n; ++i) { + L[i] = 0; + } + *the_min = -min; + return 0.f; + } + float iscale = nmax / (max - min); + float scale = 1 / iscale; + float best_mad = 0; + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale * (x[i] - min)); + L[i] = std::max(0, std::min(nmax, l)); + float diff = scale * L[i] + min - x[i]; + diff = use_mad ? fabsf(diff) : diff * diff; + float w = weights[i]; + best_mad += w * diff; + } + if (nstep < 1) { + *the_min = -min; + return scale; + } + for (int is = 0; is <= nstep; ++is) { + iscale = (rmin + rdelta * is + nmax) / (max - min); + float sum_l = 0, sum_l2 = 0, sum_xl = 0; + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale * (x[i] - min)); + l = std::max(0, std::min(nmax, l)); + Laux[i] = l; + float w = weights[i]; + sum_l += w * l; + sum_l2 += w * l * l; + sum_xl += w * l * x[i]; + } + float D = sum_w * sum_l2 - sum_l * sum_l; + if (D > 0) { + float this_scale = (sum_w * sum_xl - sum_x * sum_l) / D; + float this_min = (sum_l2 * sum_x - sum_l * sum_xl) / D; + if (this_min > 0) { + this_min = 0; + this_scale = sum_xl / sum_l2; + } + float mad = 0; + for (int i = 0; i < n; ++i) { + float diff = this_scale * Laux[i] + this_min - x[i]; + diff = use_mad ? fabsf(diff) : diff * diff; + float w = weights[i]; + mad += w * diff; + } + if (mad < best_mad) { + for (int i = 0; i < n; ++i) { + L[i] = Laux[i]; + } + best_mad = mad; + scale = this_scale; + min = this_min; + } + } + } + *the_min = -min; + return scale; +} + +void quantize_row_fp16(const float * src, void * dst, size_t count, const float * f16_to_f32_table) { + auto * out = reinterpret_cast(dst); + // TODO: use hvx intrinsics for better performance + for (size_t i = 0; i < count; i++) { + out[i] = to_fp16(src[i]); + } +} + +void quantize_row_q8_0(const float * src, void * dst, size_t count, const float * f16_to_f32_table) { + const int nb = count / QUANT_BLOCK_SIZE; + auto * out = reinterpret_cast(dst); + + for (int i = 0; i < nb; i++) { + float amax = 0.0f; // absolute max + + for (int j = 0; j < QUANT_BLOCK_SIZE; j++) { + const float v = src[i * QUANT_BLOCK_SIZE + j]; + amax = std::max(amax, fabsf(v)); + } + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f / d : 0.0f; + + out[i].d = to_fp16(d); + + for (int j = 0; j < QUANT_BLOCK_SIZE; ++j) { + const float x0 = src[i * QUANT_BLOCK_SIZE + j] * id; + + out[i].qs[j] = roundf(x0); + } + } +} + +void quantize_row_q4_0(const float * src, void * dst, size_t count, const float * f16_to_f32_table) { + constexpr const int qk = QUANT_BLOCK_SIZE; + + const int nb = count / qk; + auto * out = reinterpret_cast(dst); + + for (int i = 0; i < nb; i++) { + float amax = 0.0f; // absolute max + float max = 0.0f; + + for (int j = 0; j < qk; j++) { + const float v = src[i * qk + j]; + if (amax < fabsf(v)) { + amax = fabsf(v); + max = v; + } + } + + const float d = max / -8; + const float id = d ? 
1.0f / d : 0.0f; + + out[i].d = to_fp16(d); + + for (int j = 0; j < qk / 2; ++j) { + const float x0 = src[i * qk + 0 + j] * id; + const float x1 = src[i * qk + qk / 2 + j] * id; + + const uint8_t xi0 = std::min(15, (x0 + 8.5f)); + const uint8_t xi1 = std::min(15, (x1 + 8.5f)); + + out[i].qs[j] = xi0; + out[i].qs[j] |= xi1 << 4; + } + } +} + +void quantize_row_q4_K(const float * src, void * dst, size_t count, const float * f16_to_f32_table) { + const int nb = count / QUANT_K_BLOCK_SIZE; + auto * out = reinterpret_cast(dst); + + uint8_t L[QUANT_K_BLOCK_SIZE]; + uint8_t Laux[32]; + float weights[32]; + float mins[QUANT_K_BLOCK_SIZE / 32]; + float scales[QUANT_K_BLOCK_SIZE / 32]; + + for (int i = 0; i < nb; i++) { + float max_scale = 0; // as we are deducting the min, scales are always positive + float max_min = 0; + for (int j = 0; j < QUANT_K_BLOCK_SIZE / 32; ++j) { + //scales[j] = make_qkx1_quants(32, 15, x + 32*j, L + 32*j, &mins[j], 9, 0.5f); + float sum_x2 = 0; + for (int l = 0; l < 32; ++l) { + sum_x2 += src[32 * j + l] * src[32 * j + l]; + } + float av_x = sqrtf(sum_x2 / 32); + for (int l = 0; l < 32; ++l) { + weights[l] = av_x + fabsf(src[32 * j + l]); + } + scales[j] = + make_qkx2_quants(32, 15, src + 32 * j, weights, L + 32 * j, &mins[j], Laux, -1.f, 0.1f, 20, false); + float scale = scales[j]; + if (scale > max_scale) { + max_scale = scale; + } + float min = mins[j]; + if (min > max_min) { + max_min = min; + } + } + + float inv_scale = max_scale > 0 ? 63.f / max_scale : 0.f; + float inv_min = max_min > 0 ? 63.f / max_min : 0.f; + for (int j = 0; j < QUANT_K_BLOCK_SIZE / 32; ++j) { + uint8_t ls = nearest_int(inv_scale * scales[j]); + uint8_t lm = nearest_int(inv_min * mins[j]); + ls = std::min(63, ls); + lm = std::min(63, lm); + if (j < 4) { + out[i].scales[j] = ls; + out[i].scales[j + 4] = lm; + } else { + out[i].scales[j + 4] = (ls & 0xF) | ((lm & 0xF) << 4); + out[i].scales[j - 4] |= ((ls >> 4) << 6); + out[i].scales[j - 0] |= ((lm >> 4) << 6); + } + } + out[i].d = to_fp16(max_scale / 63.f); + out[i].dmin = to_fp16(max_min / 63.f); + + uint8_t sc, m; + for (int j = 0; j < QUANT_K_BLOCK_SIZE / 32; ++j) { + get_scale_min_k4(j, out[i].scales, &sc, &m); + const float d = f16_to_f32_table[out[i].d] * sc; + if (!d) { + continue; + } + const float dm = f16_to_f32_table[out[i].dmin] * m; + for (int ii = 0; ii < 32; ++ii) { + int l = nearest_int((src[32 * j + ii] + dm) / d); + l = std::max(0, std::min(15, l)); + L[32 * j + ii] = l; + } + } + + uint8_t * q = out[i].qs; + for (int j = 0; j < QUANT_K_BLOCK_SIZE; j += 64) { + for (int l = 0; l < 32; ++l) { + q[l] = L[j + l] | (L[j + l + 32] << 4); + } + q += 32; + } + + src += QUANT_K_BLOCK_SIZE; + } +} + +void dequantize_row_q8_0(const void * src, float * dst, size_t count, const float * f16_to_f32_table) { + constexpr const int qk = QUANT_BLOCK_SIZE; + static_assert(QUANT_BLOCK_SIZE == hexagon::kBytesPerVector / sizeof(float)); + + const int nb = count / qk; + const auto * src_ptr = reinterpret_cast(src); + HVX_UVector * out = ((HVX_UVector *) dst); // TODO: opt for aligned access + + for (int i = 0; i < nb; i++) { + const auto & src = src_ptr[i]; + HVX_Vector d = Q6_Vh_vsplat_R(src.d); + + HVX_Vector q_lo = load_block_generic(src); + HVX_VectorPair q = Q6_Wh_vunpack_Vb(q_lo); + q = Q6_Wh_vunpack_Vb(Q6_V_lo_W(q)); + q_lo = Q6_Vhf_equals_Vh(Q6_V_lo_W(q)); + q = Q6_Wqf32_vmpy_VhfVhf(q_lo, d); + out[i] = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(q)); + } +} + +void dequantize_row_q4_0(const void * src, float * dst, size_t count, const float * 
f16_to_f32_table) { + constexpr const int qk = QUANT_BLOCK_SIZE; + static_assert(qk % 2 == 0, "qk must be even"); + static_assert(QUANT_BLOCK_SIZE == hexagon::kBytesPerVector / sizeof(float)); + constexpr const uint32_t kSizeOfQs = sizeof(npu_device_block_q4_0::qs); + + const int nb = count / qk; + const auto * src_ptr = reinterpret_cast(src); + HVX_Vector mask = Q6_Vb_vsplat_R(0x0F); + HVX_Vector minus = Q6_Vb_vsplat_R(8); + HVX_UVector * out = ((HVX_UVector *) dst); // TODO: opt for aligned access + + const int loop_count = nb - (nb % 2); + for (int i = 0; i < loop_count; i += 2) { + const auto & src1 = src_ptr[i]; + const auto & src2 = src_ptr[i + 1]; + + HVX_Vector d1 = Q6_Vh_vsplat_R(src1.d); + HVX_Vector d2 = Q6_Vh_vsplat_R(src2.d); + HVX_Vector d = Q6_Vh_vshuff_Vh(Q6_V_valign_VVR(d2, d1, hexagon::kBytesPerVector / 2)); + + HVX_Vector q_lo = load_dual_block_generic(src1, src2); + HVX_Vector q_hi = Q6_Vub_vlsr_VubR(q_lo, 4); + HVX_VectorPair q = Q6_W_vshuff_VVR(q_hi, Q6_V_vand_VV(q_lo, mask), kSizeOfQs); + q_lo = Q6_V_valign_VVR(Q6_V_lo_W(q), Q6_V_vzero(), hexagon::kBytesPerVector / 2); + q_lo = Q6_V_valign_VVR(Q6_V_hi_W(q), q_lo, hexagon::kBytesPerVector / 2); + q_lo = Q6_Vb_vshuff_Vb(q_lo); + q_lo = Q6_Vb_vsub_VbVb(q_lo, minus); + q = Q6_Wh_vunpack_Vb(q_lo); + q_lo = Q6_Vhf_equals_Vh(Q6_V_lo_W(q)); + q = Q6_Wqf32_vmpy_VhfVhf(q_lo, d); + out[i] = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(q)); + out[i + 1] = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(q)); + } + + if (loop_count < nb) { + const auto & curr_blk = src_ptr[nb - 1]; + HVX_Vector d = Q6_Vh_vsplat_R(curr_blk.d); + + HVX_Vector q_lo = load_block_generic(curr_blk); + HVX_Vector q_hi = Q6_Vub_vlsr_VubR(q_lo, 4); + q_lo = Q6_V_valign_VVR(Q6_V_vand_VV(q_lo, mask), Q6_V_vzero(), sizeof(curr_blk.qs)); + q_lo = Q6_V_valign_VVR(q_hi, q_lo, hexagon::kBytesPerVector - sizeof(curr_blk.qs)); + q_lo = Q6_Vb_vsub_VbVb(q_lo, minus); + + HVX_VectorPair q = Q6_Wh_vunpack_Vb(q_lo); + q = Q6_Wh_vunpack_Vb(Q6_V_lo_W(q)); + q_lo = Q6_Vhf_equals_Vh(Q6_V_lo_W(q)); + q = Q6_Wqf32_vmpy_VhfVhf(q_lo, d); + out[nb - 1] = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(q)); + } +} + +void dequantize_row_q4_K(const void * src, float * dst, size_t count, const float * f16_to_f32_table) { + const int nb = count / QUANT_K_BLOCK_SIZE; + const auto * src_ptr = reinterpret_cast(src); + + // TODO: use intrinsics + for (int i = 0; i < nb; i++) { + const uint8_t * q = src_ptr[i].qs; + + const float d = f16_to_f32_table[src_ptr[i].d]; + const float min = f16_to_f32_table[src_ptr[i].dmin]; + + int is = 0; + uint8_t sc = 0; + uint8_t m = 0; + const auto * scales = src_ptr[i].scales; + for (int j = 0; j < QUANT_K_BLOCK_SIZE; j += 64) { + get_scale_min_k4(is + 0, scales, &sc, &m); + const float d1 = d * sc; + const float m1 = min * m; + get_scale_min_k4(is + 1, scales, &sc, &m); + const float d2 = d * sc; + const float m2 = min * m; + for (int l = 0; l < 32; ++l) { + dst[0] = d1 * (q[l] & 0xF) - m1; + dst[32] = d2 * ((q[l] >> 4) & 0xF) - m2; + dst++; + } + dst += 32; + q += 32; + is += 2; + } + } +} + +template struct dot_func_traits {}; + +template struct dot_func_traits { + using param_type = std::remove_const_t>; +}; + +template float wrap_dot_func(const void * src0, const void * src1, size_t count) { + using param_type = typename dot_func_traits::param_type; + return _Func(reinterpret_cast(src0), reinterpret_cast(src1), count); +} + +constexpr const hexagon::device_type_traits kDeviceTypeTraits[] = { + { NPU_DATA_TYPE_F32, "F32", 1, sizeof(float), false, nullptr, nullptr, + wrap_dot_func }, + { 
NPU_DATA_TYPE_F16, "F16", 1, sizeof(npu_device_fp16_t), false, nullptr, quantize_row_fp16, + wrap_dot_func }, + { NPU_DATA_TYPE_Q8_0, "Q8_0", QUANT_BLOCK_SIZE, sizeof(npu_device_block_q8_0), true, dequantize_row_q8_0, + quantize_row_q8_0 }, + { NPU_DATA_TYPE_Q4_0, "Q4_0", QUANT_BLOCK_SIZE, sizeof(npu_device_block_q4_0), true, dequantize_row_q4_0, + quantize_row_q4_0 }, + { NPU_DATA_TYPE_Q4_K, "Q4_K", QUANT_K_BLOCK_SIZE, sizeof(npu_device_block_q4_k), true, dequantize_row_q4_K, + quantize_row_q4_K }, +}; + +static_assert(std::size(kDeviceTypeTraits) == NPU_DATA_TYPE_COUNT, + "kDeviceTypeTraits size mismatch with npu_device_tensor_data_type enum"); +static_assert(kDeviceTypeTraits[NPU_DATA_TYPE_F32].type == NPU_DATA_TYPE_F32, + "kDeviceTypeTraits F32 type mismatch with npu_device_tensor_data_type enum"); +static_assert(kDeviceTypeTraits[NPU_DATA_TYPE_F16].type == NPU_DATA_TYPE_F16, + "kDeviceTypeTraits F16 type mismatch with npu_device_tensor_data_type enum"); +static_assert(kDeviceTypeTraits[NPU_DATA_TYPE_Q8_0].type == NPU_DATA_TYPE_Q8_0, + "kDeviceTypeTraits Q8_0 type mismatch with npu_device_tensor_data_type enum"); +static_assert(kDeviceTypeTraits[NPU_DATA_TYPE_Q4_0].type == NPU_DATA_TYPE_Q4_0, + "kDeviceTypeTraits Q4_0 type mismatch with npu_device_tensor_data_type enum"); +static_assert(kDeviceTypeTraits[NPU_DATA_TYPE_Q4_K].type == NPU_DATA_TYPE_Q4_K, + "kDeviceTypeTraits Q4_K type mismatch with npu_device_tensor_data_type enum"); + +} // namespace + +namespace hexagon { + +bool init_f16_f32_table(float * table, size_t count) { + constexpr const size_t kTableSize = (1U << 16); + if (count < kTableSize) { + return false; + } + + for (size_t i = 0; i < count; ++i) { + table[i] = to_float(i); + } + + return true; +} + +const device_type_traits & get_type_traits(npu_device_tensor_data_type type) { + return kDeviceTypeTraits[type]; +} + +} // namespace hexagon diff --git a/ggml/src/ggml-qnn/npu/device/quants.hpp b/ggml/src/ggml-qnn/npu/device/type_traits.hpp similarity index 65% rename from ggml/src/ggml-qnn/npu/device/quants.hpp rename to ggml/src/ggml-qnn/npu/device/type_traits.hpp index 6006cd22e9..1a0b1665ae 100644 --- a/ggml/src/ggml-qnn/npu/device/quants.hpp +++ b/ggml/src/ggml-qnn/npu/device/type_traits.hpp @@ -7,14 +7,20 @@ namespace hexagon { bool init_f16_f32_table(float * table, size_t count); +typedef void (*quantize_row_type)(const float * src, void * dst, size_t count, const float * f16_to_f32_table); typedef void (*dequantize_row_type)(const void * src, float * dst, size_t count, const float * f16_to_f32_table); +typedef float (*vec_dot_type)(const void * src0, const void * src1, size_t count); struct device_type_traits { npu_device_tensor_data_type type; const char * type_name; int64_t blck_size; + size_t type_size; bool is_quantized; - dequantize_row_type dequantize_row; + + dequantize_row_type to_float; + quantize_row_type from_float; + vec_dot_type vec_dot; }; const device_type_traits & get_type_traits(npu_device_tensor_data_type type); @@ -44,10 +50,10 @@ inline const char * get_type_name(npu_device_tensor_data_type type) { #ifdef GGML_HEXAGON_ENABLE_PERFORMANCE_TRACKING namespace hexagon { -inline auto make_scoped_op_perf_timer(tensor * op, size_t tidx, const char * sub_proc_log_prefix = nullptr) { +inline auto make_scoped_op_perf_timer(tensor * op, size_t tidx) { auto * src0 = op->get_src(0); auto * src1 = op->get_src(1); - char buffer[512]; + char buffer[1024]; if (src1 == nullptr) { snprintf(buffer, sizeof(buffer), "[%s][%lldx%lldx%lldx%lld%s], tidx: %zu", 
op_get_name(op->get_op()), src0->get_ne(0), src0->get_ne(1), src0->get_ne(2), src0->get_ne(3), get_type_name(src0->get_type()), @@ -58,7 +64,7 @@ inline auto make_scoped_op_perf_timer(tensor * op, size_t tidx, const char * sub get_type_name(src0->get_type()), src1->get_ne(0), src1->get_ne(1), src1->get_ne(2), src1->get_ne(3), get_type_name(src1->get_type()), tidx); } - return npu_scoped_timer<512>(buffer, sub_proc_log_prefix); + return npu_scoped_timer<1024>(buffer); } } // namespace hexagon @@ -67,14 +73,23 @@ inline auto make_scoped_op_perf_timer(tensor * op, size_t tidx, const char * sub auto __npu_op_timer_##__LINE__ = hexagon::make_scoped_op_perf_timer(op, tidx) # define DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_WITH_SUB_PROC(op, tidx, sub_prefix) \ - auto __npu_op_timer_##sub_prefix = hexagon::make_scoped_op_perf_timer(op, tidx, #sub_prefix) + auto __npu_op_timer_##sub_prefix = hexagon::make_scoped_op_perf_timer(op, tidx) -# define DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_SUB_PROC(sub_prefix) \ - hexagon::npu_sub_process_scoped_timer \ - __npu_op_sub_timer##sub_prefix(__npu_op_timer_##sub_prefix) +# define DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_SUB_PROC(sub_prefix) \ + hexagon::npu_sub_process_scoped_timer \ + __npu_op_sub_timer##sub_prefix(__npu_op_timer_##sub_prefix, #sub_prefix) + +# define DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_WITH_MULTI_SUB_PROC(op, tidx, tracker_name) \ + auto __npu_op_timer_##tracker_name = hexagon::make_scoped_op_perf_timer(op, tidx) + +# define DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(tracker_name, idx, sub_prefix) \ + hexagon::npu_sub_process_scoped_timer \ + __npu_op_sub_timer##sub_prefix(__npu_op_timer_##tracker_name, #sub_prefix) #else -# define DEVICE_SCOPED_OP_PERFORMANCE_TRACKER(op, tidx) ((void) 0) -# define DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_WITH_SUB_PROC(op, tidx, sub_prefix) ((void) 0) -# define DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_SUB_PROC(sub_prefix) ((void) 0) +# define DEVICE_SCOPED_OP_PERFORMANCE_TRACKER(op, tidx) ((void) 0) +# define DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_WITH_SUB_PROC(op, tidx, sub_prefix) ((void) 0) +# define DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_SUB_PROC(sub_prefix) ((void) 0) +# define DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_WITH_MULTI_SUB_PROC(op, tidx, tracker_name) ((void) 0) +# define DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(tracker_name, idx, sub_prefix) ((void) 0) #endif diff --git a/ggml/src/ggml-qnn/npu/device/util.hpp b/ggml/src/ggml-qnn/npu/device/util.hpp index 3ae7f100de..8c819fe583 100644 --- a/ggml/src/ggml-qnn/npu/device/util.hpp +++ b/ggml/src/ggml-qnn/npu/device/util.hpp @@ -52,11 +52,18 @@ inline constexpr const char * op_get_name(npu_device_tensor_op op) { return "MUL"; case NPU_OP_RMS_NORM: return "RMS_NORM"; + case NPU_OP_FLASH_ATTN: + return "FLASH_ATTN_EXT"; default: return "UNKNOWN"; } } +inline bool is_transposed_or_permuted(const npu_device_nb_type & nb) { + // Check if the tensor is transposed or permuted + return (nb[0] > nb[1]) || (nb[1] > nb[2]) || (nb[2] > nb[3]); +} + class power_utils { public: power_utils() { @@ -160,16 +167,22 @@ class power_utils { #ifdef GGML_HEXAGON_ENABLE_PERFORMANCE_TRACKING +struct sub_process_data { + char log_prefix[32] = {}; + uint64_t proc_cycles = 0; + uint64_t proc_pcycles = 0; + uint64_t proc_count = 0; +}; + template class npu_scoped_timer { public: - enum { kBufferCount = _buffer_count }; + enum { + kBufferCount = _buffer_count, + kSubProcCount = 4, + }; - explicit npu_scoped_timer(const char * log_prefix, const char * 
sub_proc_log_prefix) { + explicit npu_scoped_timer(const char * log_prefix) { strncpy(_log_prefix, log_prefix, kBufferCount - 1); - if (sub_proc_log_prefix != nullptr) { - strncpy(_sub_proc_log_prefix, sub_proc_log_prefix, kBufferCount - 1); - } - _begin_cycles = HAP_perf_get_qtimer_count(); _begin_pcycles = HAP_perf_get_pcycles(); } @@ -180,61 +193,121 @@ template class npu_scoped_timer { void operator=(npu_scoped_timer && other) { strncpy(_log_prefix, other._log_prefix, kBufferCount - 1); - strncpy(_sub_proc_log_prefix, other._sub_proc_log_prefix, kBufferCount - 1); - _begin_cycles = other._begin_cycles; - _sub_proc_cycles = other._sub_proc_cycles; - _sub_proc_count = other._sub_proc_count; + _begin_cycles = other._begin_cycles; + _begin_pcycles = other._begin_pcycles; + memcpy(&_sub_proc_data, &other._sub_proc_data, sizeof(_sub_proc_data)); } - void add_sub_proc_cycles(uint64_t cycles, uint64_t pcycles) { - _sub_proc_cycles += cycles; - _sub_proc_pcycles += pcycles; - _sub_proc_count++; + void add_sub_proc_cycles(size_t sub_proc_idx, const char * sub_proc_prefix, uint64_t cycles, uint64_t pcycles) { + auto & sub_proc_data = _sub_proc_data[sub_proc_idx]; + sub_proc_data.proc_cycles += cycles; + sub_proc_data.proc_pcycles += pcycles; + + if (!sub_proc_data.proc_count) { + strncpy(sub_proc_data.log_prefix, sub_proc_prefix, sizeof(sub_proc_data.log_prefix) - 1); + } + + sub_proc_data.proc_count++; } void print() const { + static_assert(kSubProcCount == 4, "Sub process count must be 4 for logging format"); + auto total_cycles = HAP_perf_get_qtimer_count() - _begin_cycles; auto total_pcycles = HAP_perf_get_pcycles() - _begin_pcycles; auto duration = HAP_perf_qtimer_count_to_us(total_cycles); - if (_sub_proc_count > 0) { - auto sub_proc_duration = HAP_perf_qtimer_count_to_us(_sub_proc_cycles); - DEVICE_LOG_WARN("[profiler]%s, pcyc: %llu, dur: %lluus, [%s]cnt: %llu, pcyc: %llu, dur: %lluus\n", - _log_prefix, total_pcycles, duration, _sub_proc_log_prefix, _sub_proc_count, - _sub_proc_pcycles, sub_proc_duration); - } else { - DEVICE_LOG_WARN("[profiler]%s, pcyc: %llu, dur: %lluus\n", _log_prefix, total_pcycles, duration); + int sub_proc_count = 0; + for (int i = kSubProcCount; i > 0; --i) { + if (_sub_proc_data[i - 1].proc_count > 0) { + sub_proc_count = i; + break; + } + } + + auto sub_proc0_duration = HAP_perf_qtimer_count_to_us(_sub_proc_data[0].proc_cycles); + auto sub_proc1_duration = HAP_perf_qtimer_count_to_us(_sub_proc_data[1].proc_cycles); + auto sub_proc2_duration = HAP_perf_qtimer_count_to_us(_sub_proc_data[2].proc_cycles); + auto sub_proc3_duration = HAP_perf_qtimer_count_to_us(_sub_proc_data[3].proc_cycles); + + switch (sub_proc_count) { + case 4: + DEVICE_LOG_WARN( + "[profiler]%s, pcyc: %llu, dur: %lluus, [%s]cnt: %llu, dur: %lluus, " + "[%s]cnt: %llu, dur: %lluus, [%s]cnt: %llu, dur: %lluus, " + "[%s]cnt: %llu, dur: %lluus\n", + _log_prefix, (unsigned long long) total_pcycles, (unsigned long long) duration, + _sub_proc_data[0].log_prefix, (unsigned long long) _sub_proc_data[0].proc_count, + (unsigned long long) sub_proc0_duration, _sub_proc_data[1].log_prefix, + (unsigned long long) _sub_proc_data[1].proc_count, (unsigned long long) sub_proc1_duration, + _sub_proc_data[2].log_prefix, (unsigned long long) _sub_proc_data[2].proc_count, + (unsigned long long) sub_proc2_duration, _sub_proc_data[3].log_prefix, + (unsigned long long) _sub_proc_data[3].proc_count, (unsigned long long) sub_proc3_duration); + break; + case 3: + DEVICE_LOG_WARN( + "[profiler]%s, pcyc: %llu, dur: %lluus, 
[%s]cnt: %llu, dur: %lluus, " + "[%s]cnt: %llu, dur: %lluus, [%s]cnt: %llu, dur: %lluus\n", + _log_prefix, (unsigned long long) total_pcycles, (unsigned long long) duration, + _sub_proc_data[0].log_prefix, (unsigned long long) _sub_proc_data[0].proc_count, + (unsigned long long) sub_proc0_duration, _sub_proc_data[1].log_prefix, + (unsigned long long) _sub_proc_data[1].proc_count, (unsigned long long) sub_proc1_duration, + _sub_proc_data[2].log_prefix, (unsigned long long) _sub_proc_data[2].proc_count, + (unsigned long long) sub_proc2_duration); + break; + case 2: + DEVICE_LOG_WARN( + "[profiler]%s, pcyc: %llu, dur: %lluus, [%s]cnt: %llu, dur: %lluus, " + "[%s]cnt: %llu, dur: %lluus\n", + _log_prefix, (unsigned long long) total_pcycles, (unsigned long long) duration, + _sub_proc_data[0].log_prefix, (unsigned long long) _sub_proc_data[0].proc_count, + (unsigned long long) sub_proc0_duration, _sub_proc_data[1].log_prefix, + (unsigned long long) _sub_proc_data[1].proc_count, (unsigned long long) sub_proc1_duration); + break; + case 1: + DEVICE_LOG_WARN("[profiler]%s, pcyc: %llu, dur: %lluus, [%s]cnt: %llu, dur: %lluus\n", _log_prefix, + (unsigned long long) total_pcycles, (unsigned long long) duration, + _sub_proc_data[0].log_prefix, (unsigned long long) _sub_proc_data[0].proc_count, + (unsigned long long) sub_proc0_duration); + break; + default: + case 0: + DEVICE_LOG_WARN("[profiler]%s, pcyc: %llu, dur: %lluus\n", _log_prefix, + (unsigned long long) total_pcycles, (unsigned long long) duration); + break; } } private: - char _log_prefix[kBufferCount] = {}; - char _sub_proc_log_prefix[kBufferCount] = {}; - uint64_t _begin_cycles = 0; - uint64_t _begin_pcycles = 0; - uint64_t _sub_proc_cycles = 0; - uint64_t _sub_proc_pcycles = 0; - uint64_t _sub_proc_count = 0; + char _log_prefix[kBufferCount] = {}; + uint64_t _begin_cycles = 0; + uint64_t _begin_pcycles = 0; + sub_process_data _sub_proc_data[kSubProcCount] = {}; DISABLE_COPY(npu_scoped_timer); }; -template class npu_sub_process_scoped_timer { +template class npu_sub_process_scoped_timer { public: + static_assert(_sub_idx < npu_scoped_timer<_buffer_count>::kSubProcCount, + "Sub process index must be less than kSubProcCount"); using npu_scoped_timer = npu_scoped_timer<_buffer_count>; - explicit npu_sub_process_scoped_timer(npu_scoped_timer & timer) : _timer(timer) { + explicit npu_sub_process_scoped_timer(npu_scoped_timer & timer, const char * prefix) : + _timer(timer), + _prefix(prefix) { _begin_cycles = HAP_perf_get_qtimer_count(); _begin_pcycles = HAP_perf_get_pcycles(); } ~npu_sub_process_scoped_timer() { - _timer.add_sub_proc_cycles(HAP_perf_get_qtimer_count() - _begin_cycles, + _timer.add_sub_proc_cycles(_sub_idx, _prefix, HAP_perf_get_qtimer_count() - _begin_cycles, HAP_perf_get_pcycles() - _begin_pcycles); } private: npu_scoped_timer & _timer; + const char * _prefix = nullptr; uint64_t _begin_cycles = 0; uint64_t _begin_pcycles = 0; @@ -244,10 +317,10 @@ template class npu_sub_process_scoped_timer { inline auto make_scoped_perf_timer(const char * format, ...) 
{ va_list args; va_start(args, format); - char buffer[512]; + char buffer[1024]; vsnprintf(buffer, sizeof(buffer), format, args); va_end(args); - return npu_scoped_timer<512>(buffer, nullptr); + return npu_scoped_timer<1024>(buffer); } #endif diff --git a/ggml/src/ggml-qnn/npu/device/vec_ops.cpp b/ggml/src/ggml-qnn/npu/device/vec_ops.cpp new file mode 100644 index 0000000000..5bdf183e5b --- /dev/null +++ b/ggml/src/ggml-qnn/npu/device/vec_ops.cpp @@ -0,0 +1,156 @@ +#include "vec_ops.hpp" + +#include + +#include "util.hpp" + +namespace { + +template +inline float vec_dot_product_impl(const _TElem * src0, const _TElem * src1, size_t count) { + constexpr const size_t kElementsPerVector = hexagon::kBytesPerVector / sizeof(_TElem); + + HVX_Vector * src0_vec_ptr = ((HVX_Vector *) src0); + HVX_Vector * const src0_vec_ptr_end = ((HVX_Vector *) src0) + count / kElementsPerVector; + HVX_Vector * src1_vec_ptr = ((HVX_Vector *) src1); + HVX_Vector prev0 = *src0_vec_ptr++; + HVX_Vector prev1 = *src1_vec_ptr++; + HVX_Vector sum = Q6_V_vzero(); + HVX_Vector sum0 = Q6_V_vzero(); + HVX_Vector sum1 = Q6_V_vzero(); + + while (src0_vec_ptr_end - src0_vec_ptr > 1) { + HVX_VectorPair curr0 = reinterpret_cast(src0_vec_ptr)[0]; + HVX_VectorPair curr1 = reinterpret_cast(src1_vec_ptr)[0]; + + HVX_Vector l0 = Q6_V_valign_VVR(Q6_V_lo_W(curr0), prev0, (size_t) src0); + HVX_Vector l1 = Q6_V_valign_VVR(Q6_V_lo_W(curr1), prev1, (size_t) src1); + HVX_Vector h0 = Q6_V_valign_VVR(Q6_V_hi_W(curr0), Q6_V_lo_W(curr0), (size_t) src0); + HVX_Vector h1 = Q6_V_valign_VVR(Q6_V_hi_W(curr1), Q6_V_lo_W(curr1), (size_t) src1); + prev0 = Q6_V_hi_W(curr0); + prev1 = Q6_V_hi_W(curr1); + src0_vec_ptr += 2; + src1_vec_ptr += 2; + + sum0 = _AddFunc(_MpyFunc(l0, l1), sum0); + sum1 = _AddFunc(_MpyFunc(h0, h1), sum1); + } + + sum = _AddFunc(sum0, sum1); + if (src0_vec_ptr_end - src0_vec_ptr > 0) { + HVX_Vector curr0 = *src0_vec_ptr++; + HVX_Vector curr1 = *src1_vec_ptr++; + HVX_Vector s0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0); + HVX_Vector s1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1); + prev0 = curr0; + prev1 = curr1; + + sum = _AddFunc(_MpyFunc(s0, s1), sum); + } + + const size_t leftover = count % kElementsPerVector; + if ((src0_vec_ptr_end - ((HVX_Vector *) src0)) > 0) { + // handle the last vector + // see also: + // https://github.com/UbiquitousLearning/mllm/blob/babf4410352ce8730824c87699c025a0d4ce3a6f/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/LLaMAMul.cpp#L147 + // or qualcomm sdk libs\qhl_hvx\src\qhblas_hvx\qhblas_hvx_aw_vector_add_ah.c + bool should_fetch_src0 = leftover != 0 || !hexagon::is_addr_aligned(src0_vec_ptr); + bool should_fetch_src1 = leftover != 0 || !hexagon::is_addr_aligned(src1_vec_ptr); + HVX_Vector curr0 = should_fetch_src0 ? *src0_vec_ptr : prev0; + HVX_Vector curr1 = should_fetch_src1 ? *src1_vec_ptr : prev1; + src0_vec_ptr += should_fetch_src0 ? 1 : 0; + src1_vec_ptr += should_fetch_src1 ? 1 : 0; + HVX_Vector s0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0); + HVX_Vector s1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1); + prev0 = curr0; + prev1 = curr1; + + sum = _AddFunc(_MpyFunc(s0, s1), sum); + } + + const size_t leftover_bytes = leftover * sizeof(_TElem); + if (leftover > 0) { + // handle the leftover elements + HVX_Vector curr0 = (leftover_bytes + hexagon::unaligned_bytes(src0_vec_ptr) > hexagon::kBytesPerVector) ? 
+ *src0_vec_ptr : + prev0; + curr0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0); + + HVX_Vector curr1 = (leftover_bytes + hexagon::unaligned_bytes(src1_vec_ptr) > hexagon::kBytesPerVector) ? + *src1_vec_ptr : + prev1; + curr1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1); + + sum = _AddFunc(Q6_V_valign_VVR(_MpyFunc(curr0, curr1), Q6_V_vzero(), leftover_bytes), sum); + } + + return _ReduceFunc(sum); +} + +template +inline float vec_dot_product_aligned_impl(const _TElem * src0, const _TElem * src1, size_t count) { + constexpr const size_t kElementsPerVector = hexagon::kBytesPerVector / sizeof(_TElem); + + HVX_Vector * src0_vec_ptr = ((HVX_Vector *) src0); + HVX_Vector * const src0_vec_ptr_end = ((HVX_Vector *) src0) + count / kElementsPerVector; + HVX_Vector * src1_vec_ptr = ((HVX_Vector *) src1); + HVX_Vector sum0 = Q6_V_vzero(); + HVX_Vector sum1 = Q6_V_vzero(); + + while (src0_vec_ptr_end - src0_vec_ptr > 1) { + HVX_Vector curr0_lo = src0_vec_ptr[0]; + HVX_Vector curr0_hi = src0_vec_ptr[1]; + HVX_Vector curr1_lo = src1_vec_ptr[0]; + HVX_Vector curr1_hi = src1_vec_ptr[1]; + src0_vec_ptr += 2; + src1_vec_ptr += 2; + + sum0 = _AddFunc(_MpyFunc(curr0_lo, curr1_lo), sum0); + sum1 = _AddFunc(_MpyFunc(curr0_hi, curr1_hi), sum1); + } + + return _ReduceFunc(_AddFunc(sum0, sum1)); +} + +inline HVX_Vector vec_mpy_qf32(HVX_Vector src0, HVX_Vector src1) { + return Q6_Vqf32_vmpy_VsfVsf(src0, src1); +} + +inline HVX_Vector vec_add_qf32(HVX_Vector sum, HVX_Vector result) { + return Q6_Vqf32_vadd_Vqf32Vqf32(sum, result); +} + +inline HVX_Vector vec_mpy_qf16(HVX_Vector src0, HVX_Vector src1) { + return Q6_Vqf16_vmpy_VhfVhf(src0, src1); +} + +inline HVX_Vector vec_add_qf16(HVX_Vector sum, HVX_Vector result) { + return Q6_Vqf16_vadd_Vqf16Vqf16(sum, result); +} + +} // namespace + +namespace hexagon { + +float vec_dot_product_f32_f32(const float * src0, const float * src1, size_t count) { + return vec_dot_product_impl(src0, src1, count); +} + +float vec_dot_product_aligned_f32_f32(const float * src0, const float * src1, size_t count) { + return vec_dot_product_aligned_impl(src0, src1, + count); +} + +float vec_dot_product_f16_f16(const npu_device_fp16_t * src0, const npu_device_fp16_t * src1, size_t count) { + return vec_dot_product_impl( + src0, src1, count); +} + +float vec_dot_product_aligned_f16_f16(const npu_device_fp16_t * src0, const npu_device_fp16_t * src1, size_t count) { + return vec_dot_product_aligned_impl( + src0, src1, count); +} + +} // namespace hexagon diff --git a/ggml/src/ggml-qnn/npu/device/vec_ops.hpp b/ggml/src/ggml-qnn/npu/device/vec_ops.hpp new file mode 100644 index 0000000000..406075fc9c --- /dev/null +++ b/ggml/src/ggml-qnn/npu/device/vec_ops.hpp @@ -0,0 +1,274 @@ +#pragma once + +#include +#include + +#include + +#include "hexagon_npu.h" + +namespace hexagon { + +constexpr const size_t kBytesPerVector = sizeof(HVX_Vector); // 128 for v73 +constexpr const size_t kAlignMask = kBytesPerVector - 1; + +inline size_t unaligned_bytes(const void * addr) { + return ((size_t) addr) & kAlignMask; +} + +inline bool is_addr_aligned(const void * addr) { + return unaligned_bytes(addr) == 0; +} + +inline float get_flt0_from_fltv(HVX_Vector vect) { + static_assert(sizeof(vect[0]) == sizeof(float), "vect[0] should be a float"); + int32_t i = vect[0]; + return reinterpret_cast(i); +} + +inline HVX_UVector Q6_V_vmemu_R(const void * unaligned_ptr) { + return *reinterpret_cast(unaligned_ptr); +} + +inline HVX_Vector Q6_V_vmem_R(const void * aligned_ptr) { + return *reinterpret_cast(aligned_ptr); +} 
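The split between vec_dot_product_impl and the *_aligned_* variant exists because the general path pays for Q6_V_valign_VVR fix-ups on every iteration. A small, hypothetical dispatcher (names are illustrative) showing when the aligned path is safe to pick; note that, as written above, the aligned variant only consumes whole pairs of HVX vectors and has no tail handling.

    float row_dot_f32(const float * a, const float * b, size_t n) {
        constexpr size_t kFloatsPerVec = hexagon::kBytesPerVector / sizeof(float);  // 32 on v73
        const bool can_use_aligned = hexagon::is_addr_aligned(a) && hexagon::is_addr_aligned(b) &&
                                     (n % (2 * kFloatsPerVec)) == 0;
        return can_use_aligned ? hexagon::vec_dot_product_aligned_f32_f32(a, b, n)
                               : hexagon::vec_dot_product_f32_f32(a, b, n);
    }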
+ +constexpr const size_t kL2CacheSize = 8 * 1024; // // 8KB L2 cache +constexpr const size_t kL2FetchAheadVectors = kL2CacheSize / kBytesPerVector; + +inline void l2fetch(const void * p, uint32_t stride, uint32_t width, uint32_t height, uint32_t dir) { + uint64_t control = HEXAGON_V64_CREATE_H(dir, stride, width, height); + __asm__ __volatile__(" l2fetch(%0,%1) " : : "r"(p), "r"(control)); +} + +inline void l2fetch_row(const uint8_t * row_ptr, size_t bytes) { + // TODO: should we use small kL2FetchAheadVectors? + int32_t l2fetch_vectors = Q6_R_min_RR(bytes / kBytesPerVector, kL2FetchAheadVectors); + hexagon::l2fetch(row_ptr, kBytesPerVector, kBytesPerVector, l2fetch_vectors, 0); +} + +/* + * This function converts a vector of IEEE float elements to a vector of qf32 elements + * See also: libs\qfe\inc\qhmath_hvx_convert.h + */ +inline HVX_Vector qhmath_hvx_vqf32_convert_vsf(HVX_Vector vin) { + return Q6_Vqf32_vadd_VsfVsf(vin, Q6_V_vzero()); +} + +/* + * This function converts a vector of IEEE half float elements to a vector of qf16 elements + * See also: libs\qfe\inc\qhmath_hvx_convert.h + */ +inline HVX_Vector qhmath_hvx_vqf16_convert_vhf(HVX_Vector vin) { + return Q6_Vqf16_vadd_VhfVhf(vin, Q6_V_vzero()); +} + +/* + * This function converts a pair of vectors of qf32 elements to a vector of IEEE half float elements + * See also: libs\qfe\inc\qhmath_hvx_convert.h + */ +inline HVX_Vector qhmath_hvx_vhf_convert_vqf32(HVX_VectorPair vin_vp) { + return Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(vin_vp)); +} + +/* + * This function converts a vector of qf16 elements to a pair of vectors of qf32 elements + * See also: libs\qfe\inc\qhmath_hvx_convert.h + */ +inline HVX_VectorPair qhmath_hvx_vqf32_convert_vqf16(HVX_Vector vxl) { + HVX_VectorPair vxw_vp, exponent_vp; + HVX_Vector mantissa_mask = Q6_Vh_vsplat_R(0xffe0); + HVX_Vector exp_mask = Q6_Vh_vsplat_R(0x1f); + HVX_Vector exp_offset = Q6_Vh_vsplat_R(0x70); + HVX_Vector mant32_shift = Q6_Vh_vsplat_R(0x10); + HVX_Vector reql, reqh, vxl_w, vxh_w, mantissa; + HVX_Vector el_exponent, eh_exponent; + + el_exponent = Q6_V_vand_VV(exp_mask, vxl); + // Obtain the mantissa part: bits (5-15) + mantissa = Q6_V_vand_VV(mantissa_mask, vxl); + // Convert qf16 biassed exponent to qf32 biased exponent + // new exp = exp + ( 127 (qf32 bias) -15(qf16 biass) ) = 112 + el_exponent = Q6_Vh_vadd_VhVh(exp_offset, el_exponent); + + vxw_vp = Q6_Ww_vunpack_Vh(mantissa); + vxl_w = Q6_V_lo_W(vxw_vp); + vxh_w = Q6_V_hi_W(vxw_vp); + + exponent_vp = Q6_Ww_vunpack_Vh(el_exponent); + el_exponent = Q6_V_lo_W(exponent_vp); + eh_exponent = Q6_V_hi_W(exponent_vp); + // Convert q16 mantiss to q32 mantissa + reql = Q6_Vw_vasl_VwVw(vxl_w, mant32_shift); + reqh = Q6_Vw_vasl_VwVw(vxh_w, mant32_shift); + // Add the exponent + vxl_w = Q6_Vw_vadd_VwVw(reql, el_exponent); + vxh_w = Q6_Vw_vadd_VwVw(reqh, eh_exponent); + + return Q6_W_vcombine_VV(vxh_w, vxl_w); +} + +inline HVX_Vector vec_reduction_qf32(HVX_Vector sums) { + constexpr const size_t kFloatsPerVector = hexagon::kBytesPerVector / sizeof(float); + static_assert(kFloatsPerVector == 32 || kFloatsPerVector == 16, "kFloatsPerVector should be 16 or 32"); + + // TODO: do we have a better way to do the reduction? 
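+    // Rotate-and-add reduction: each step adds the accumulator to a copy of itself
+    // rotated by half of the remaining span (16, 8, 4, 2, then 1 float lanes for a
+    // 32-lane vector), so after log2(lanes) steps every lane holds the total and
+    // lane 0 can be read out via get_flt0_from_fltv() by the callers below.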
+    switch (kFloatsPerVector) {
+        default:
+        case 32:
+            sums = Q6_Vqf32_vadd_Vqf32Vqf32(sums, Q6_V_vror_VR(sums, 16 * sizeof(float)));
+            // fallthrough
+        case 16:
+            sums = Q6_Vqf32_vadd_Vqf32Vqf32(sums, Q6_V_vror_VR(sums, 8 * sizeof(float)));
+            sums = Q6_Vqf32_vadd_Vqf32Vqf32(sums, Q6_V_vror_VR(sums, 4 * sizeof(float)));
+            sums = Q6_Vqf32_vadd_Vqf32Vqf32(sums, Q6_V_vror_VR(sums, 2 * sizeof(float)));
+            sums = Q6_Vqf32_vadd_Vqf32Vqf32(sums, Q6_V_vror_VR(sums, sizeof(float)));
+            break;
+    }
+
+    return sums;
+}
+
+inline float vec_reduction_qf32_f32(HVX_Vector sums) {
+    return get_flt0_from_fltv(Q6_Vsf_equals_Vqf32(vec_reduction_qf32(sums)));
+}
+
+inline HVX_Vector vec_reduction_qf16(HVX_Vector sums) {
+    constexpr const size_t kFloatsPerVector = hexagon::kBytesPerVector / sizeof(npu_device_fp16_t);
+    static_assert(kFloatsPerVector == 64 || kFloatsPerVector == 32, "kFloatsPerVector should be 32 or 64");
+
+    // TODO: do we have a better way to do the reduction?
+    switch (kFloatsPerVector) {
+        default:
+        case 64:
+            sums = Q6_Vqf16_vadd_Vqf16Vqf16(sums, Q6_V_vror_VR(sums, 32 * sizeof(npu_device_fp16_t)));
+            // fallthrough
+        case 32:
+            sums = Q6_Vqf16_vadd_Vqf16Vqf16(sums, Q6_V_vror_VR(sums, 16 * sizeof(npu_device_fp16_t)));
+            sums = Q6_Vqf16_vadd_Vqf16Vqf16(sums, Q6_V_vror_VR(sums, 8 * sizeof(npu_device_fp16_t)));
+            sums = Q6_Vqf16_vadd_Vqf16Vqf16(sums, Q6_V_vror_VR(sums, 4 * sizeof(npu_device_fp16_t)));
+            sums = Q6_Vqf16_vadd_Vqf16Vqf16(sums, Q6_V_vror_VR(sums, 2 * sizeof(npu_device_fp16_t)));
+            sums = Q6_Vqf16_vadd_Vqf16Vqf16(sums, Q6_V_vror_VR(sums, sizeof(npu_device_fp16_t)));
+            break;
+    }
+
+    return sums;
+}
+
+inline float vec_reduction_qf16_f32(HVX_Vector sums) {
+    HVX_Vector vect = Q6_Vhf_equals_Vqf16(vec_reduction_qf16(sums));
+    uint16_t   i    = (vect[0] & 0xffff);
+    return reinterpret_cast<__fp16 &>(i);
+}
+
+inline HVX_Vector hvx_scale_f32(float scale) {
+    return Q6_V_vsplat_R(reinterpret_cast<uint32_t &>(scale));
+}
+
+template <typename _TParam, HVX_Vector (*_Func)(HVX_Vector, HVX_UVector *, HVX_Vector),
+          HVX_Vector (*_FuncScaleConvert)(float)>
+inline void vec_scale_impl(const _TParam * src, float scale, _TParam * dst, size_t count) {
+    constexpr const size_t kElementsPerVector = hexagon::kBytesPerVector / sizeof(_TParam);
+
+    HVX_Vector *       src_vec_ptr    = ((HVX_Vector *) src);
+    HVX_Vector * const src_vec_end    = ((HVX_Vector *) src) + (count / kElementsPerVector);
+    HVX_UVector *      dst_vec_ptr    = ((HVX_UVector *) dst);  // TODO: opt the unaligned case?
+    HVX_Vector         prev           = *src_vec_ptr++;
+    const size_t       leftover       = count % kElementsPerVector;
+    const size_t       leftover_bytes = leftover * sizeof(_TParam);
+
+    HVX_Vector scale_vec = _FuncScaleConvert(scale);
+
+    while (src_vec_end - src_vec_ptr > 1) {
+        HVX_VectorPair curr = reinterpret_cast<HVX_VectorPair *>(src_vec_ptr)[0];
+        src_vec_ptr += 2;
+
+        HVX_Vector lo = Q6_V_valign_VVR(Q6_V_lo_W(curr), prev, (size_t) src);
+        HVX_Vector hi = Q6_V_valign_VVR(Q6_V_hi_W(curr), Q6_V_lo_W(curr), (size_t) src);
+
+        dst_vec_ptr[0] = _Func(lo, dst_vec_ptr, scale_vec);
+        dst_vec_ptr[1] = _Func(hi, dst_vec_ptr + 1, scale_vec);
+
+        dst_vec_ptr += 2;
+        prev = Q6_V_hi_W(curr);
+    }
+
+    if (src_vec_end - src_vec_ptr > 0) {
+        HVX_Vector curr = *src_vec_ptr++;
+        HVX_Vector s0   = Q6_V_valign_VVR(curr, prev, (size_t) src);
+        dst_vec_ptr[0]  = _Func(s0, dst_vec_ptr, scale_vec);
+        dst_vec_ptr++;
+        prev = curr;
+    }
+
+    if ((src_vec_end - ((HVX_Vector *) src)) > 0) {
+        // handle the last vector
+        bool       should_fetch_next = leftover == 0 && hexagon::is_addr_aligned(src_vec_ptr);
+        HVX_Vector curr              = should_fetch_next ? prev : *src_vec_ptr;
+        src_vec_ptr                  = should_fetch_next ? src_vec_ptr : src_vec_ptr + 1;
+        HVX_Vector s0                = Q6_V_valign_VVR(curr, prev, (size_t) src);
+        dst_vec_ptr[0]               = _Func(s0, dst_vec_ptr, scale_vec);
+        dst_vec_ptr++;
+        prev = curr;
+    }
+
+    if (leftover > 0) {
+        // handle the leftover elements
+        HVX_Vector curr =
+            (leftover_bytes + hexagon::unaligned_bytes(src_vec_ptr) > hexagon::kBytesPerVector) ? *src_vec_ptr : prev;
+        curr = Q6_V_valign_VVR(curr, prev, (size_t) src);
+        q6op_vstu_variable_ARV(dst_vec_ptr, leftover_bytes, _Func(curr, dst_vec_ptr, scale_vec));
+    }
+}
+
+inline HVX_Vector hvx_vec_scale_f32_f32(HVX_Vector src, HVX_UVector *, HVX_Vector scale_vec) {
+    return Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(src, scale_vec));
+}
+
+inline HVX_Vector hvx_vec_mad_f32_f32(HVX_Vector src, HVX_UVector * dst_ptr, HVX_Vector scale_vec) {
+    HVX_Vector dst = *dst_ptr;  // TODO: opt the unaligned case?
+    src            = Q6_Vqf32_vmpy_VsfVsf(src, scale_vec);
+    src            = Q6_Vqf32_vadd_Vqf32Vsf(src, dst);
+    return Q6_Vsf_equals_Vqf32(src);
+}
+
+inline void vec_scale_f32(const float * src, float scale, float * dst, size_t count) {
+    vec_scale_impl<float, hvx_vec_scale_f32_f32, hvx_scale_f32>(src, scale, dst, count);
+}
+
+inline void vec_mad_f32(const float * src, float scale, float * dst, size_t count) {
+    vec_scale_impl<float, hvx_vec_mad_f32_f32, hvx_scale_f32>(src, scale, dst, count);
+}
+
+inline HVX_Vector hvx_scale_f16(float scale) {
+    __fp16 f16_scale = scale;
+    return Q6_Vh_vsplat_R(reinterpret_cast<uint16_t &>(f16_scale));
+}
+
+inline HVX_Vector hvx_vec_scale_f16_f16(HVX_Vector src, HVX_UVector *, HVX_Vector scale_vec) {
+    return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(src, scale_vec));
+}
+
+inline HVX_Vector hvx_vec_mad_f16_f16(HVX_Vector src, HVX_UVector * dst_ptr, HVX_Vector scale_vec) {
+    HVX_Vector dst    = *dst_ptr;  // TODO: opt the unaligned case?
+    HVX_Vector scaled = Q6_Vqf16_vmpy_VhfVhf(src, scale_vec);
+    HVX_Vector result = Q6_Vqf16_vadd_Vqf16Vhf(scaled, dst);
+    return Q6_Vhf_equals_Vqf16(result);
+}
+
+inline void vec_scale_f16(const npu_device_fp16_t * src, float scale, npu_device_fp16_t * dst, size_t count) {
+    vec_scale_impl<npu_device_fp16_t, hvx_vec_scale_f16_f16, hvx_scale_f16>(src, scale, dst, count);
+}
+
+inline void vec_mad_f16(const npu_device_fp16_t * src, float scale, npu_device_fp16_t * dst, size_t count) {
+    vec_scale_impl<npu_device_fp16_t, hvx_vec_mad_f16_f16, hvx_scale_f16>(src, scale, dst, count);
+}
+
+float vec_dot_product_f32_f32(const float * src0, const float * src1, size_t count);
+float vec_dot_product_aligned_f32_f32(const float * src0, const float * src1, size_t count);
+
+float vec_dot_product_f16_f16(const npu_device_fp16_t * src0, const npu_device_fp16_t * src1, size_t count);
+float vec_dot_product_aligned_f16_f16(const npu_device_fp16_t * src0, const npu_device_fp16_t * src1, size_t count);
+
+} // namespace hexagon
diff --git a/ggml/src/ggml-qnn/npu/host/buffer.cpp b/ggml/src/ggml-qnn/npu/host/buffer.cpp
index 7d3c1fbd9f..c7482f8b59 100644
--- a/ggml/src/ggml-qnn/npu/host/buffer.cpp
+++ b/ggml/src/ggml-qnn/npu/host/buffer.cpp
@@ -114,7 +114,9 @@ size_t backend_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
 size_t backend_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
     auto * buffer_type_obj = get_buffer_type_object(buft);
     GGML_ASSERT(buffer_type_obj != nullptr);
-    return buffer_type_obj->get_max_buffer_size();
+    auto size = buffer_type_obj->get_max_buffer_size();
+    LOG_DEBUG("[hexagon-npu][%s]max_buffer_size: %zu\n", buffer_type_obj->get_name(), size);
+    return size;
 }
 
 bool backend_buffer_is_host(ggml_backend_buffer_type_t buft) {
diff --git a/ggml/src/ggml-qnn/npu/host/graph.cpp b/ggml/src/ggml-qnn/npu/host/graph.cpp
index d891280e56..9ac69924d3 100644
---
a/ggml/src/ggml-qnn/npu/host/graph.cpp +++ b/ggml/src/ggml-qnn/npu/host/graph.cpp @@ -29,6 +29,8 @@ bool host_graph::update(ggml_cgraph * cgraph) { return false; } + LOG_DEBUG("[%p]host_graph::update started\n", (void *) this); + SCOPED_PERFORMANCE_TRACKER("[hexagon-npu][%p]update, handle(%p)", (void *) this, (void *) _graph_handle); _tensor_handles.clear(); @@ -40,8 +42,9 @@ bool host_graph::update(ggml_cgraph * cgraph) { if (node->op == GGML_OP_NONE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_RESHAPE) { // skip view liked ops - LOG_DEBUG("node[%d]%s(%s), addr: %p, type: %s, skipped\n", i, ggml_get_name(node), ggml_op_desc(node), - (void *) node, ggml_type_name(node->type)); + LOG_DEBUG("node[%d]%s(%s), addr: %p, type: %s, dims: %ldx%ldx%ldx%ld, skipped\n", i, ggml_get_name(node), + ggml_op_desc(node), (void *) node, ggml_type_name(node->type), (long) node->ne[0], + (long) node->ne[1], (long) node->ne[2], (long) node->ne[3]); continue; } @@ -54,9 +57,10 @@ bool host_graph::update(ggml_cgraph * cgraph) { _tensor_handles.push_back(tensor_obj->get_device_tensor_handle()); _tensor_update_configs.push_back(tensor_obj->update_hosts_params_only(node)); - LOG_DEBUG("[%p]node[%d]%s(%s), addr: %p, type: %s, tensor_handle: %p\n", (void *) this, i, ggml_get_name(node), - ggml_op_desc(node), (void *) node, ggml_type_name(node->type), - (void *) tensor_obj->get_device_tensor_handle()); + LOG_DEBUG("node[%d]%s(%s), addr: %p, type: %s, dims: %ldx%ldx%ldx%ld, tensor_handle: %p\n", i, + ggml_get_name(node), ggml_op_desc(node), (void *) node, ggml_type_name(node->type), + (long) tensor_obj->get_ne(0), (long) tensor_obj->get_ne(1), (long) tensor_obj->get_ne(2), + (long) tensor_obj->get_ne(3), (void *) tensor_obj->get_device_tensor_handle()); } GGML_ASSERT(_tensor_handles.size() == _tensor_update_configs.size()); @@ -71,7 +75,7 @@ bool host_graph::update(ggml_cgraph * cgraph) { (int) _tensor_update_configs.size()); if (ret != AEE_SUCCESS) { - LOG_ERROR("Failed to set tensors in host_graph: 0x%x\n", (int) ret); + LOG_ERROR("[%p]failed to set tensors in host_graph: 0x%x\n", (void *) this, (int) ret); return false; } diff --git a/ggml/src/ggml-qnn/npu/host/host_device.cpp b/ggml/src/ggml-qnn/npu/host/host_device.cpp index 443abe5c9e..e88ef002bf 100644 --- a/ggml/src/ggml-qnn/npu/host/host_device.cpp +++ b/ggml/src/ggml-qnn/npu/host/host_device.cpp @@ -149,37 +149,17 @@ bool npu_device::supports_op_impl(const ggml_tensor * op) { return false; } - auto * src0 = op->src[0]; - if (!src0) { - LOG_DEBUG("[%s]Unsupported inplace op: %s\n", get_name(), ggml_op_desc(op)); - return false; - } - - if (type_to_npu_type(src0->type) == NPU_DATA_TYPE_COUNT) { - LOG_DEBUG("[%s]Unsupported src0 tensor type: %s\n", get_name(), ggml_type_name(src0->type)); - return false; - } - - auto * src1 = op->src[1]; - if (src1 && type_to_npu_type(src1->type) == NPU_DATA_TYPE_COUNT) { - LOG_DEBUG("[%s]Unsupported src1 tensor type: %s\n", get_name(), ggml_type_name(src1->type)); - return false; - } - auto npu_op = op_to_npu_op(op->op); if (npu_op == NPU_OP_COUNT) { LOG_DEBUG("[%s]Unsupported op: %s\n", get_name(), ggml_op_desc(op)); return false; } - if (!_device_handle && !init_device()) { - LOG_DEBUG("[%s]NPU device initialization failed\n", get_name()); - return false; - } - - constexpr const auto get_spec = [](const ggml_tensor * tensor) -> npu_device_tensor_spec { + int i = 0; + npu_device_tensor_spec srcs[DEVICE_TENSOR_MAX_SRC] = {}; + constexpr const auto get_spec = [](const ggml_tensor * tensor) -> 
npu_device_tensor_spec { if (!tensor) { - return npu_device_tensor_spec{ {}, NPU_DATA_TYPE_COUNT }; + return npu_device_tensor_spec{ {}, {}, NPU_DATA_TYPE_COUNT }; } static_assert(DEVICE_TENSOR_MAX_DIMS == GGML_MAX_DIMS, "tensor dimensions mismatch"); @@ -188,19 +168,40 @@ bool npu_device::supports_op_impl(const ggml_tensor * op) { spec.ne[1] = tensor->ne[1]; spec.ne[2] = tensor->ne[2]; spec.ne[3] = tensor->ne[3]; + + spec.nb[0] = tensor->nb[0]; + spec.nb[1] = tensor->nb[1]; + spec.nb[2] = tensor->nb[2]; + spec.nb[3] = tensor->nb[3]; spec.type = type_to_npu_type(tensor->type); return spec; }; + for (; i < (int) DEVICE_TENSOR_MAX_SRC && op->src[i]; ++i) { + auto * src = op->src[i]; + if (type_to_npu_type(src->type) == NPU_DATA_TYPE_COUNT) { + LOG_DEBUG("[%s]Unsupported src%d tensor type: %s\n", get_name(), i, ggml_type_name(src->type)); + return false; + } + + srcs[i] = get_spec(src); + } + + if (!_device_handle && !init_device()) { + LOG_DEBUG("[%s]NPU device initialization failed\n", get_name()); + return false; + } + boolean supported = false; - auto src0_spec = get_spec(src0); - auto src1_spec = get_spec(src1); auto dst_spec = get_spec(op); - auto ret = npu_device_device_support_op(_device_handle, &src0_spec, &src1_spec, &dst_spec, npu_op, &supported); + auto ret = npu_device_device_support_op(_device_handle, npu_op, &dst_spec, srcs, i, &supported); if (ret != AEE_SUCCESS || !supported) { +#ifndef NDEBUG + auto * src0_type = i ? ggml_type_name(op->src[0]->type) : "null"; + auto * src1_type = (i > 1) ? ggml_type_name(op->src[1]->type) : "null"; LOG_DEBUG("[%s][%s]unsupported %s(%s,%s), ret: 0x%x, supported: %d\n", get_name(), ggml_op_name(op->op), - ggml_type_name(op->type), ggml_type_name(src0->type), (src1 ? ggml_type_name(src1->type) : "null"), - ret, supported); + ggml_type_name(op->type), src0_type, src1_type, ret, supported); +#endif return false; } diff --git a/ggml/src/ggml-qnn/npu/host/tensor.hpp b/ggml/src/ggml-qnn/npu/host/tensor.hpp index 71205b39fb..07e092049c 100644 --- a/ggml/src/ggml-qnn/npu/host/tensor.hpp +++ b/ggml/src/ggml-qnn/npu/host/tensor.hpp @@ -21,15 +21,17 @@ class host_tensor { explicit host_tensor(ggml_tensor * tensor, int buffer_fd, uint64_t offset, remote_handle64 device_handle) : _device_handle(device_handle) { - // TODO: figure out why the npu_device_tensor_config can't be larger than 100 bytes - static_assert(sizeof(npu_device_tensor_config) < 100, "npu_device_tensor_config size too large"); + static_assert(sizeof(npu_device_tensor_config) < kMaxNpuRpcStructSize, + "npu_device_tensor_config size too large"); - _info.buffer_fd = buffer_fd; - _info.offset = offset; - _info.type = type_to_npu_type(tensor->type); - _info.size = ggml_nbytes(tensor); + _info.buffer_fd = buffer_fd; + _info.offset = offset; + _info.type = type_to_npu_type(tensor->type); + _info.size = ggml_nbytes(tensor); + _info.is_constant = false; // TODO: support constant tensors in the future // _info.op will be updated in update_params() + _info_update.op = NPU_OP_COUNT; static_assert(DEVICE_TENSOR_MAX_DIMS == GGML_MAX_DIMS, "tensor dimensions mismatch"); static_assert(sizeof(_info.ne) == sizeof(tensor->ne), "tensor ne size mismatch"); @@ -46,10 +48,10 @@ class host_tensor { tensor->extra = this; _ggml_tensor = tensor; - LOG_DEBUG("host_tensor(%p), ggml_tensor(%p[%ldx%ldx%ldx%ld], nb[%ld][%ld][%ld][%ld], %s), handle(%p)\n", - (void *) this, (void *) tensor, (long) tensor->ne[0], (long) tensor->ne[1], (long) tensor->ne[2], + LOG_DEBUG("host_tensor(%p), ggml_tensor(%s[%ldx%ldx%ldx%ld], 
nb[%ld][%ld][%ld][%ld], %s, %p), handle(%p)\n", + (void *) this, tensor->name, (long) tensor->ne[0], (long) tensor->ne[1], (long) tensor->ne[2], (long) tensor->ne[3], (long) tensor->nb[0], (long) tensor->nb[1], (long) tensor->nb[2], - (long) tensor->nb[3], ggml_type_name(tensor->type), (void *) _device_tensor_handle); + (long) tensor->nb[3], ggml_type_name(tensor->type), (void *) tensor, (void *) _device_tensor_handle); } ~host_tensor() { @@ -76,11 +78,9 @@ class host_tensor { auto new_op = op_to_npu_op(_ggml_tensor->op); bool params_changed = new_op != _info_update.op; if (params_changed) { - LOG_DEBUG("host_tensor(%p) op changed: %s -> %s\n", (void *) this, get_npu_op_desc(_info.op), - get_npu_op_desc(new_op)); + LOG_DEBUG("host_tensor(%p) op changed: %s\n", (void *) this, get_npu_op_desc(new_op)); } - _info.op = new_op; _info_update.op = new_op; if (memcmp(_info_update.params, _ggml_tensor->op_params, sizeof(_info_update.params)) != 0) { @@ -92,15 +92,16 @@ class host_tensor { } npu_device_tensor_handle_t src_tensor_handles[DEVICE_TENSOR_MAX_SRC] = {}; - for (size_t j = 0; j < DEVICE_TENSOR_MAX_SRC && _ggml_tensor->src[j]; ++j) { - auto * src = host_tensor::from_ggml_tensor(_ggml_tensor->src[j]); - src_tensor_handles[j] = src->get_device_tensor_handle(); - LOG_DEBUG("host_tensor(%p) set_src[%zu]: %p\n", (void *) this, j, (void *) src); - } - static_assert(std::is_same::value, "src tensor handles type mismatch"); + for (size_t j = 0; j < DEVICE_TENSOR_MAX_SRC && _ggml_tensor->src[j]; ++j) { + auto * ggml_src = _ggml_tensor->src[j]; + auto * src = host_tensor::from_ggml_tensor(ggml_src); + src_tensor_handles[j] = src->get_device_tensor_handle(); + LOG_DEBUG("host_tensor(%p) set_src[%zu]: %p(%s)\n", (void *) this, j, (void *) src, ggml_src->name); + } + if (memcmp(_info_update.src_handles, src_tensor_handles, sizeof(_info_update.src_handles)) != 0) { params_changed = true; memcpy(_info_update.src_handles, src_tensor_handles, sizeof(_info_update.src_handles)); @@ -128,14 +129,14 @@ class host_tensor { GGML_ASSERT(ggml_tensor == _ggml_tensor); auto new_op = op_to_npu_op(_ggml_tensor->op); - _info.op = new_op; _info_update.op = new_op; memcpy(_info_update.params, _ggml_tensor->op_params, sizeof(_info_update.params)); for (size_t j = 0; j < DEVICE_TENSOR_MAX_SRC && _ggml_tensor->src[j]; ++j) { - auto * src = host_tensor::from_ggml_tensor(_ggml_tensor->src[j]); + auto * ggml_src = _ggml_tensor->src[j]; + auto * src = host_tensor::from_ggml_tensor(ggml_src); _info_update.src_handles[j] = src->get_device_tensor_handle(); - LOG_DEBUG("host_tensor(%p) set_src[%zu]: %p\n", (void *) this, j, (void *) src); + LOG_DEBUG("host_tensor(%p) set_src[%zu]: %p(%s)\n", (void *) this, j, (void *) src, ggml_src->name); } LOG_DEBUG("host_tensor(%p) update_params, op: %s, params: [%x, %x, %x, %x]\n", (void *) this, @@ -146,6 +147,15 @@ class host_tensor { bool is_valid() const { return _device_tensor_handle != 0; } + int64_t get_ne(size_t index) const { + if (index >= DEVICE_TENSOR_MAX_DIMS) { + LOG_ERROR("host_tensor(%p) get_ne: index out of bounds: %zu\n", (void *) this, index); + return 0; + } + + return _info.ne[index]; + } + private: remote_handle64 _device_handle = 0; npu_device_tensor_handle_t _device_tensor_handle = 0; diff --git a/ggml/src/ggml-qnn/npu/host/util.cpp b/ggml/src/ggml-qnn/npu/host/util.cpp index b62370d1ad..0b00512333 100644 --- a/ggml/src/ggml-qnn/npu/host/util.cpp +++ b/ggml/src/ggml-qnn/npu/host/util.cpp @@ -6,7 +6,7 @@ #include "ggml-common.h" #undef GGML_COMMON_DECL_CPP 
-static_assert(sizeof(npu_device_block_q4_K) == sizeof(block_q4_K), "npu_device_block_q4_K size mismatch"); +static_assert(sizeof(npu_device_block_q4_k) == sizeof(block_q4_K), "npu_device_block_q4_k size mismatch"); static_assert(sizeof(npu_device_block_q4_0) == sizeof(block_q4_0), "npu_device_block_q4_0 size mismatch"); static_assert(sizeof(npu_device_block_q8_0) == sizeof(block_q8_0), "npu_device_block_q8_0 size mismatch"); static_assert(QUANT_K_SCALE_SIZE == K_SCALE_SIZE, "QUANT_K_SCALE_SIZE size mismatch"); @@ -27,6 +27,8 @@ enum npu_device_tensor_op op_to_npu_op(ggml_op op) { return NPU_OP_MUL; case GGML_OP_RMS_NORM: return NPU_OP_RMS_NORM; + case GGML_OP_FLASH_ATTN_EXT: + return NPU_OP_FLASH_ATTN; default: return NPU_OP_COUNT; } @@ -44,6 +46,8 @@ const char * get_npu_op_desc(enum npu_device_tensor_op op) { return ggml_op_name(GGML_OP_MUL); case NPU_OP_RMS_NORM: return ggml_op_name(GGML_OP_RMS_NORM); + case NPU_OP_FLASH_ATTN: + return ggml_op_name(GGML_OP_FLASH_ATTN_EXT); default: return "UNKNOWN"; } @@ -160,27 +164,65 @@ void get_op_tensor_desc(const ggml_tensor * dst, char * out, size_t max_len) { } }; - auto * src0 = dst->src[0]; - if (src0 == nullptr) { - print_tensor(dst, out, max_len); - return; - } + constexpr const auto get_src_tensor_count = [](const ggml_tensor * tensor) -> size_t { + for (size_t i = 0; i < GGML_MAX_SRC; ++i) { + if (!tensor->src[i]) { + return i; + } + } + + return GGML_MAX_SRC; + }; char dst_desc[256]; print_tensor(dst, dst_desc, sizeof(dst_desc)); - - char src0_desc[256]; - print_tensor(src0, src0_desc, sizeof(src0_desc)); - - auto * src1 = dst->src[1]; - if (src1 == nullptr) { - snprintf(out, max_len, "dst: %s, src0: %s", dst_desc, src0_desc); - return; + switch (get_src_tensor_count(dst)) { + case 4: + { + char src0_desc[256]; + print_tensor(dst->src[0], src0_desc, sizeof(src0_desc)); + char src1_desc[256]; + print_tensor(dst->src[1], src1_desc, sizeof(src1_desc)); + char src2_desc[256]; + print_tensor(dst->src[2], src2_desc, sizeof(src2_desc)); + char src3_desc[256]; + print_tensor(dst->src[3], src3_desc, sizeof(src3_desc)); + snprintf(out, max_len, "dst: %s, src0: %s, src1: %s, src2: %s, src3: %s", dst_desc, src0_desc, + src1_desc, src2_desc, src3_desc); + return; + } + case 3: + { + char src0_desc[256]; + print_tensor(dst->src[0], src0_desc, sizeof(src0_desc)); + char src1_desc[256]; + print_tensor(dst->src[1], src1_desc, sizeof(src1_desc)); + char src2_desc[256]; + print_tensor(dst->src[2], src2_desc, sizeof(src2_desc)); + snprintf(out, max_len, "dst: %s, src0: %s, src1: %s, src2: %s", dst_desc, src0_desc, src1_desc, + src2_desc); + return; + } + case 2: + { + char src0_desc[256]; + print_tensor(dst->src[0], src0_desc, sizeof(src0_desc)); + char src1_desc[256]; + print_tensor(dst->src[1], src1_desc, sizeof(src1_desc)); + snprintf(out, max_len, "dst: %s, src0: %s, src1: %s", dst_desc, src0_desc, src1_desc); + return; + } + case 1: + { + char src0_desc[256]; + print_tensor(dst->src[0], src0_desc, sizeof(src0_desc)); + snprintf(out, max_len, "dst: %s, src0: %s", dst_desc, src0_desc); + return; + } + default: + snprintf(out, max_len, "dst: %s", dst_desc); + return; } - - char src1_desc[256]; - print_tensor(src1, src1_desc, sizeof(src1_desc)); - snprintf(out, max_len, "dst: %s, src0: %s, src1: %s", dst_desc, src0_desc, src1_desc); } } // namespace hexagon diff --git a/ggml/src/ggml-qnn/npu/host/util.hpp b/ggml/src/ggml-qnn/npu/host/util.hpp index f8ec5c3b9f..b4c2355cac 100644 --- a/ggml/src/ggml-qnn/npu/host/util.hpp +++ 
b/ggml/src/ggml-qnn/npu/host/util.hpp
@@ -26,4 +26,6 @@ void enable_unsigned_dsp_module(common::rpc_interface_ptr rpc_interface, uint32_
 
 void get_op_tensor_desc(const ggml_tensor * dst, char * out, size_t max_len);
 
+constexpr const size_t kMaxNpuRpcStructSize = 100;  // TODO: figure out the actual size
+
 } // namespace hexagon
diff --git a/ggml/src/ggml-qnn/npu/idl/hexagon_npu.idl b/ggml/src/ggml-qnn/npu/idl/hexagon_npu.idl
index ed20c125b3..70626c90cb 100644
--- a/ggml/src/ggml-qnn/npu/idl/hexagon_npu.idl
+++ b/ggml/src/ggml-qnn/npu/idl/hexagon_npu.idl
@@ -3,7 +3,7 @@
 #include "remote.idl"
 
 const uint32_t DEVICE_TENSOR_MAX_DIMS = 4;
-const uint32_t DEVICE_TENSOR_MAX_SRC = 2;
+const uint32_t DEVICE_TENSOR_MAX_SRC = 4;
 const uint32_t DEVICE_TENSOR_MAX_OP_PARAMS = 4;
 const uint32_t QUANT_BLOCK_SIZE = 32;
 const uint32_t QUANT_K_BLOCK_SIZE = 256;
@@ -12,6 +12,7 @@ const uint32_t QUANT_K_SCALE_SIZE = 12;
 
 interface npu_device : remote_handle64{
 
     typedef int64_t ne_type[DEVICE_TENSOR_MAX_DIMS];
+    typedef uint64_t nb_type[DEVICE_TENSOR_MAX_DIMS];
     typedef uint64_t tensor_handle_t;
     typedef uint64_t graph_handle_t;
@@ -22,7 +23,7 @@ interface npu_device : remote_handle64{
         uint8_t qs[QUANT_BLOCK_SIZE / 2];
     };
 
-    struct block_q4_K {
+    struct block_q4_k {
         fp16_t d;
         fp16_t dmin;
         uint8_t scales[QUANT_K_SCALE_SIZE];
@@ -40,6 +41,7 @@ interface npu_device : remote_handle64{
         NPU_OP_SUB,
         NPU_OP_MUL,
         NPU_OP_RMS_NORM,
+        NPU_OP_FLASH_ATTN,
        NPU_OP_COUNT
     };
 
@@ -54,6 +56,7 @@ interface npu_device : remote_handle64{
 
     struct tensor_spec {
         ne_type ne;
+        nb_type nb;
         tensor_data_type type;
     };
 
@@ -65,12 +68,12 @@ interface npu_device : remote_handle64{
 
     struct tensor_config {
         ne_type ne;
-        uint64_t nb[DEVICE_TENSOR_MAX_DIMS];
+        nb_type nb;
         long buffer_fd;
         uint64_t offset;
         uint64_t size;
         tensor_data_type type;
-        tensor_op op;
+        boolean is_constant;
     };
 
     AEEResult device_get_alignment(
@@ -78,10 +81,9 @@ interface npu_device : remote_handle64{
     );
 
     AEEResult device_support_op(
-        in tensor_spec src0,
-        in tensor_spec src1,
-        in tensor_spec dst,
         in tensor_op op,
+        in tensor_spec dst,
+        in sequence<tensor_spec> srcs,
         rout boolean is_supported
     );
diff --git a/ggml/src/ggml-qnn/shared/CMakeLists.txt b/ggml/src/ggml-qnn/shared/CMakeLists.txt
index b08b2f07eb..c8f9cf7a84 100644
--- a/ggml/src/ggml-qnn/shared/CMakeLists.txt
+++ b/ggml/src/ggml-qnn/shared/CMakeLists.txt
@@ -20,6 +20,8 @@ if(GGML_QNN_ENABLE_HEXAGON_BACKEND)
     if(DEFINED ENV{QNN_SDK_PATH})
         set(HEXAGON_SDK_ROOT $ENV{HEXAGON_SDK_ROOT})
         message("found HEXAGON_SDK_ROOT, setting to ${HEXAGON_SDK_ROOT}")
+    elseif(EXISTS ${HEXAGON_SDK_ROOT})
+        message("HEXAGON_SDK_ROOT: ${HEXAGON_SDK_ROOT}")
     else()
         message(FATAL_ERROR "HEXAGON_SDK_ROOT not defined")
     endif()
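
Reviewer note (not part of the patch): the element-wise semantics of the new vec_scale_f32/vec_mad_f32 helpers and their f16 counterparts in vec_ops.hpp boil down to the reference loops below. This is an illustrative sketch only; the *_ref names do not exist in the tree, and the real implementations add the HVX alignment, vector-pair and leftover handling shown above.

    // Scalar reference for the HVX scale/mad helpers; for review only.
    #include <cstddef>

    // vec_scale_f32 / vec_scale_f16: dst[i] = src[i] * scale
    inline void vec_scale_ref(const float * src, float scale, float * dst, size_t count) {
        for (size_t i = 0; i < count; ++i) {
            dst[i] = src[i] * scale;
        }
    }

    // vec_mad_f32 / vec_mad_f16: dst[i] += src[i] * scale (multiply-accumulate into dst)
    inline void vec_mad_ref(const float * src, float scale, float * dst, size_t count) {
        for (size_t i = 0; i < count; ++i) {
            dst[i] += src[i] * scale;
        }
    }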
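
Reviewer note (not part of the patch): the vec_dot_product_* declarations at the end of vec_ops.hpp compute an ordinary dot product, and vec_reduction_qf32/vec_reduction_qf16 are rotate-and-add horizontal sums across the vector lanes. A scalar sketch of the expected result, assuming the HVX implementation in vec_ops.cpp accumulates per lane and then reduces (vec_dot_product_ref is illustrative only):

    // Scalar reference for vec_dot_product_f32_f32; for review only.
    inline float vec_dot_product_ref(const float * src0, const float * src1, size_t count) {
        float acc = 0.0f;
        for (size_t i = 0; i < count; ++i) {
            acc += src0[i] * src1[i];
        }
        return acc;  // the HVX version presumably finishes with vec_reduction_qf32_f32
    }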