feat: perf opt gemv (#54)

* add GEMV implementation for matrix multiplication in hexagon

* refactor: optimize GEMV implementation for matrix multiplication in hexagon

* wip

* refactor: enhance caching mechanism in GEMV implementation for matrix multiplication

* wip

* refactor: streamline caching logic in GEMV implementation for matrix multiplication

* wip

* wip

* fix broadcase in flash_attn

* format

* refactor: optimize memory fetching in matrix multiplication implementations

* wip

* fix aligned gemv

* rename

* refactor: remove unused memory cache functions and initialize VTCM cache

* wip

* feat: add vector math functions for IEEE float and half float operations

* feat: add vec_silu_f32 and vec_silu_f16 functions for SiLU activation

* feat: implement GLU operation support in tensor processing

* feat: add GLU operation support and related enhancements in tensor processing

* wip

* wip

* wip

* feat: add qhmath_hvx_div_vf functions for f32 vector operations

* feat: add qhmath_hvx_div_vhf functions for f16 vector operations

* fix: reorder parameters in vector operation functions for consistency

* wip

* feat: enhance vector operations with parameterized transformations and improved GLU implementations

* wip

* fix: increase default stack size and correct thread parameter indexing in thread pool

* fix f16 div

* fix f32 div

* fix: update GLU vector operations to use explicit denominator calculation

* wip

* wip

* Refactor cacheability check for matrix multiplication to handle multiple source tensors

* Revert "fix: increase default stack size and correct thread parameter indexing in thread pool"

This reverts commit 40e3f0974dbb04051aa30b397a9a171c6dd32678.

* wip

* fix comments

* replace copy with memcpy
This commit is contained in:
nullname 2025-08-08 20:40:26 +08:00 committed by GitHub
parent 6260c3166d
commit 379bdeb18c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
24 changed files with 2335 additions and 441 deletions

View File

@ -1,10 +1,4 @@
#include <AEEStdErr.h>
#include <HAP_compute_res.h>
#include <hexagon_types.h>
#include <memory>
#include "graph.hpp"
#include "hexagon_npu.h"
#include "op_impl.hpp"
@ -14,6 +8,12 @@
#include "type_traits.hpp"
#include "util.hpp"
#include <AEEStdErr.h>
#include <HAP_compute_res.h>
#include <hexagon_types.h>
#include <memory>
namespace {
struct npu_device_context {
@ -130,8 +130,12 @@ AEEResult npu_device_device_get_alignment(remote_handle64 _h, uint32_t * alignme
return AEE_SUCCESS;
}
AEEResult npu_device_device_support_op(remote_handle64 _h, npu_device_tensor_op op, const npu_device_tensor_spec * dst,
const npu_device_tensor_spec * srcs, int srcsLen, boolean * is_supported) {
AEEResult npu_device_device_support_op(remote_handle64 _h,
const npu_device_tensor_op_spec * op_spec,
const npu_device_tensor_spec * dst,
const npu_device_tensor_spec * srcs,
int srcsLen,
boolean * is_supported) {
NPU_UNUSED(_h);
if (!srcs || srcsLen <= 0 || !dst || !is_supported) {
@ -139,19 +143,21 @@ AEEResult npu_device_device_support_op(remote_handle64 _h, npu_device_tensor_op
return AEE_EINVARGS;
}
*is_supported = hexagon::support_op(op, dst, srcs, srcsLen);
*is_supported = hexagon::support_op(op_spec, dst, srcs, srcsLen);
return AEE_SUCCESS;
}
AEEResult npu_device_tensor_init(remote_handle64 _h, const npu_device_tensor_config * info,
npu_device_tensor_handle_t * tensor_handle) {
AEEResult npu_device_tensor_init(remote_handle64 _h,
const npu_device_tensor_config * info,
npu_device_tensor_handle_t * tensor_handle) {
NPU_UNUSED(_h);
auto * tensor = new hexagon::tensor(*info);
*tensor_handle = tensor_to_handle(tensor);
return AEE_SUCCESS;
}
AEEResult npu_device_tensor_update_params(remote_handle64 _h, npu_device_tensor_handle_t tensor_handle,
AEEResult npu_device_tensor_update_params(remote_handle64 _h,
npu_device_tensor_handle_t tensor_handle,
const npu_device_tensor_update_config * config) {
NPU_UNUSED(_h);
auto * tensor = tensor_from_handle(tensor_handle);
@ -174,8 +180,9 @@ AEEResult npu_device_tensor_free(remote_handle64 _h, npu_device_tensor_handle_t
return AEE_SUCCESS;
}
AEEResult npu_device_tensors_free(remote_handle64 _h, const npu_device_tensor_handle_t * tensor_handles,
int tensor_handlesLen) {
AEEResult npu_device_tensors_free(remote_handle64 _h,
const npu_device_tensor_handle_t * tensor_handles,
int tensor_handlesLen) {
NPU_UNUSED(_h);
if (!tensor_handles || tensor_handlesLen < 0) {
DEVICE_LOG_ERROR("npu_device_tensors_free: Invalid arguments");
@ -201,8 +208,10 @@ AEEResult npu_device_graph_init(remote_handle64 _h, npu_device_graph_handle_t *
return AEE_SUCCESS;
}
AEEResult npu_device_graph_set_tensor(remote_handle64 _h, npu_device_graph_handle_t graph_handle,
const npu_device_tensor_handle_t * tensor_handles, int tensor_handlesLen) {
AEEResult npu_device_graph_set_tensor(remote_handle64 _h,
npu_device_graph_handle_t graph_handle,
const npu_device_tensor_handle_t * tensor_handles,
int tensor_handlesLen) {
NPU_UNUSED(_h);
auto * graph = graph_from_handle(graph_handle);
if (!graph || !tensor_handles || tensor_handlesLen <= 0) {
@ -213,7 +222,8 @@ AEEResult npu_device_graph_set_tensor(remote_handle64 _h, npu_device_graph_handl
return AEE_SUCCESS;
}
AEEResult npu_device_graph_set_tensor_with_param(remote_handle64 _h, npu_device_graph_handle_t graph_handle,
AEEResult npu_device_graph_set_tensor_with_param(remote_handle64 _h,
npu_device_graph_handle_t graph_handle,
const npu_device_tensor_handle_t * tensor_handles,
int tensor_handlesLen,
const npu_device_tensor_update_config * tensor_params,

View File

@ -1,12 +1,12 @@
#include "graph.hpp"
#include <new>
#include "op_impl.hpp"
#include "util.hpp"
#include "vtcm_mem.hpp"
#include <new>
namespace hexagon {
graph::graph() noexcept {
@ -30,8 +30,12 @@ void graph::set_tensor(const npu_device_tensor_handle_t * tensors, int tensor_co
for (int i = 0; i < tensor_count; ++i) {
auto * tensor_obj = reinterpret_cast<tensor *>(tensors[i]);
_tensors[i] = tensor_obj;
DEVICE_LOG_DEBUG("graph(%p) set_tensor[%d]: %p(%p,%p), op: %s\n", (void *) this, i, (void *) tensor_obj,
(void *) tensor_obj->get_src(0), (void *) tensor_obj->get_src(1),
DEVICE_LOG_DEBUG("graph(%p) set_tensor[%d]: %p(%p,%p), op: %s\n",
(void *) this,
i,
(void *) tensor_obj,
(void *) tensor_obj->get_src(0),
(void *) tensor_obj->get_src(1),
op_get_name(tensor_obj->get_op()));
}
@ -64,8 +68,9 @@ bool graph::compute(default_thread_pool * thread_pool, const float * f16_to_f32_
return true;
}
void graph::thread_pool_task(default_thread_pool * pool, default_thread_pool::thread_params * thread_params,
void * graph) {
void graph::thread_pool_task(default_thread_pool * pool,
default_thread_pool::thread_params * thread_params,
void * graph) {
reinterpret_cast<hexagon::graph *>(graph)->compute_impl(pool, thread_params);
}
@ -86,8 +91,11 @@ void graph::compute_impl(default_thread_pool * pool, default_thread_pool::thread
const bool should_sync = requires_thread_barrier(op);
if (pool && should_sync && i < _tensor_count - 1) {
DEVICE_SCOPED_PERFORMANCE_TRACKER("[%p]sync_thread, tidx: %zu, tensor[%zu/%zu]", (void *) this,
params.get_thread_index(), i, _tensor_count);
DEVICE_SCOPED_PERFORMANCE_TRACKER("[%p]sync_thread, tidx: %zu, tensor[%zu/%zu]",
(void *) this,
params.get_thread_index(),
i,
_tensor_count);
pool->sync_thread();
}
}

View File

@ -1,11 +1,11 @@
#pragma once
#include <memory>
#include "hexagon_npu.h"
#include "tensor.hpp"
#include "thread_pool.hpp"
#include <memory>
namespace hexagon {
class graph {
@ -20,8 +20,9 @@ class graph {
bool compute(default_thread_pool * thread_pool, const float * f16_to_f32_table);
private:
static void thread_pool_task(default_thread_pool * pool, default_thread_pool::thread_params * thread_params,
void * graph);
static void thread_pool_task(default_thread_pool * pool,
default_thread_pool::thread_params * thread_params,
void * graph);
void compute_impl(default_thread_pool * pool, default_thread_pool::thread_params * thread_params);
std::unique_ptr<tensor *[]> _tensors;

View File

@ -14,15 +14,20 @@ inline float f16_to_f32(const npu_device_fp16_t src) {
// From: ggml/src/ggml-cpu/ops.cpp
template <bool _IsKvF16>
void flash_attn_impl(hexagon::tensor * out, const hexagon::tensor * q, const hexagon::tensor * k,
const hexagon::tensor * v, const hexagon::tensor * mask, hexagon::compute_params * params) {
void flash_attn_impl(hexagon::tensor * out,
const hexagon::tensor * q,
const hexagon::tensor * k,
const hexagon::tensor * v,
const hexagon::tensor * mask,
hexagon::compute_params * params) {
static_assert(3 <= hexagon::kMaxParamsCount, "flash_attn op params count exceeds max params count");
constexpr const npu_device_tensor_data_type kKvDataType = _IsKvF16 ? NPU_DATA_TYPE_F16 : NPU_DATA_TYPE_F32;
if (k->get_type() != kKvDataType || v->get_type() != k->get_type()) {
DEVICE_LOG_ERROR("flash_attn_impl: k and v must have same type, got k: %s, v: %s\n",
hexagon::get_type_name(k->get_type()), hexagon::get_type_name(v->get_type()));
hexagon::get_type_name(k->get_type()),
hexagon::get_type_name(v->get_type()));
return;
}
@ -80,7 +85,8 @@ void flash_attn_impl(hexagon::tensor * out, const hexagon::tensor * q, const hex
const auto out_rows_per_batch = out->get_ne(2) * out->get_ne(1);
uint8_t * dst_ptr = out->get_write_buffer();
if (!dst_ptr) {
DEVICE_LOG_ERROR("flash_attn_impl: dst_ptr is not writable, tensor: %p, type: %s\n", (void *) out,
DEVICE_LOG_ERROR("flash_attn_impl: dst_ptr is not writable, tensor: %p, type: %s\n",
(void *) out,
hexagon::get_type_name(out->get_type()));
return;
}
@ -118,7 +124,8 @@ void flash_attn_impl(hexagon::tensor * out, const hexagon::tensor * q, const hex
const npu_device_fp16_t * mp =
mask_ptr ? reinterpret_cast<const npu_device_fp16_t *>(mask_ptr + iq1 * mask->get_nb(1) +
(iq3 % mask->get_ne(2)) * mask->get_nb(2)) :
(iq2 % mask->get_ne(2)) * mask->get_nb(2) +
(iq3 % mask->get_ne(3)) * mask->get_nb(3)) :
nullptr;
// k indices
@ -251,8 +258,8 @@ bool flash_attn_f32(tensor * out, compute_params * params) {
const auto * v = out->get_src(2);
const auto * mask = out->get_src(3);
if (!q || !k || !v || !mask) {
DEVICE_LOG_DEBUG("invalid src tensors: q: %p, k: %p, v: %p, mask: %p\n", (void *) q, (void *) k, (void *) v,
(void *) mask);
DEVICE_LOG_DEBUG(
"invalid src tensors: q: %p, k: %p, v: %p, mask: %p\n", (void *) q, (void *) k, (void *) v, (void *) mask);
return false;
}
@ -264,8 +271,11 @@ bool flash_attn_f32(tensor * out, compute_params * params) {
return true;
}
bool is_flash_attn_supported(npu_device_tensor_op op, const npu_device_tensor_spec * dst,
const npu_device_tensor_spec * srcs, size_t src_len) {
bool is_flash_attn_supported(const npu_device_tensor_op_spec * op_spec,
const npu_device_tensor_spec * dst,
const npu_device_tensor_spec * srcs,
size_t src_len) {
const auto op = op_spec->op;
if (op != NPU_OP_FLASH_ATTN) {
DEVICE_LOG_DEBUG("op is not NPU_OP_FLASH_ATTN: %d\n", op);
return false;
@ -295,7 +305,9 @@ bool is_flash_attn_supported(npu_device_tensor_op op, const npu_device_tensor_sp
const auto * v = &srcs[2];
if (v->type != k->type) { // TODO: support more v types
DEVICE_LOG_DEBUG("[%s]v type is not the same as k: %s vs %s\n", op_get_name(op), get_type_name(v->type),
DEVICE_LOG_DEBUG("[%s]v type is not the same as k: %s vs %s\n",
op_get_name(op),
get_type_name(v->type),
get_type_name(k->type));
return false;
}
@ -310,28 +322,42 @@ bool is_flash_attn_supported(npu_device_tensor_op op, const npu_device_tensor_sp
DEVICE_LOG_DEBUG(
"[%s]dst shape does not match q and v: dst ne: %ld, %ld, %ld, %ld, q ne: %ld, %ld, %ld, %ld, "
"v ne: %ld, %ld, %ld, %ld\n",
op_get_name(op), dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], q->ne[0], q->ne[1], q->ne[2], q->ne[3],
v->ne[0], v->ne[1], v->ne[2], v->ne[3]);
op_get_name(op),
dst->ne[0],
dst->ne[1],
dst->ne[2],
dst->ne[3],
q->ne[0],
q->ne[1],
q->ne[2],
q->ne[3],
v->ne[0],
v->ne[1],
v->ne[2],
v->ne[3]);
return false;
}
if (is_transposed_or_permuted(dst->nb)) {
DEVICE_LOG_DEBUG("[%s]dst cannot be transposed or permuted, nb: %zu, %zu, %zu, %zu\n", op_get_name(op),
dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3]);
DEVICE_LOG_DEBUG("[%s]dst cannot be transposed or permuted, nb: %zu, %zu, %zu, %zu\n",
op_get_name(op),
dst->nb[0],
dst->nb[1],
dst->nb[2],
dst->nb[3]);
return false;
}
if (q->ne[0] != k->ne[0]) {
DEVICE_LOG_DEBUG("[%s]q and k shapes do not match: q ne: %ld, %ld, %ld, %ld, k ne: %ld, %ld, %ld, %ld\n",
op_get_name(op), q->ne[0], q->ne[1], q->ne[2], q->ne[3], k->ne[0], k->ne[1], k->ne[2],
k->ne[3]);
return false;
}
if (q->ne[2] != k->ne[2] || q->ne[3] != k->ne[3] || q->ne[3] != 1) {
// TODO: add broadcast support
DEVICE_LOG_DEBUG("[%s]q and k shapes do not match: q ne: %ld, %ld, %ld, %ld, k ne: %ld, %ld, %ld, %ld\n",
op_get_name(op), q->ne[0], q->ne[1], q->ne[2], q->ne[3], k->ne[0], k->ne[1], k->ne[2],
op_get_name(op),
q->ne[0],
q->ne[1],
q->ne[2],
q->ne[3],
k->ne[0],
k->ne[1],
k->ne[2],
k->ne[3]);
return false;
}

View File

@ -5,7 +5,9 @@
namespace hexagon {
bool flash_attn_f32(tensor * out, compute_params * params);
bool is_flash_attn_supported(npu_device_tensor_op op, const npu_device_tensor_spec * dst,
const npu_device_tensor_spec * srcs, size_t src_len);
bool is_flash_attn_supported(const npu_device_tensor_op_spec * op_spec,
const npu_device_tensor_spec * dst,
const npu_device_tensor_spec * srcs,
size_t src_len);
} // namespace hexagon

View File

@ -2,20 +2,20 @@
#include "op_impl.hpp"
#include <type_traits>
#include "op_flash_attn.hpp"
#include "op_mul_mat.hpp"
#include "op_rope.hpp"
#include "type_traits.hpp"
#include "vec_ops.hpp"
#include <type_traits>
namespace {
template <HVX_Vector (*_OpBinaryTransform)(HVX_Vector, HVX_Vector)>
inline void vec_op_f32_f32(const float * src0, const float * src1, size_t count, float * dst) {
inline void vec_op_f32_f32(const float * src0, const float * src1, float * dst, size_t count) {
using namespace hexagon::vec;
vec_trans_op_impl<_OpBinaryTransform, float>(src0, src1, count, dst);
vec_trans_impl<_OpBinaryTransform, float>(src0, src1, dst, count);
}
inline HVX_Vector vadd_f32_f32(HVX_Vector a, HVX_Vector b) {
@ -31,10 +31,12 @@ inline HVX_Vector vmul_f32_f32(HVX_Vector a, HVX_Vector b) {
}
template <HVX_Vector (*_OpBinaryTransform)(HVX_Vector, HVX_Vector)>
inline void vec_op_f16_f16(const npu_device_fp16_t * src0, const npu_device_fp16_t * src1, size_t count,
npu_device_fp16_t * dst) {
inline void vec_op_f16_f16(const npu_device_fp16_t * src0,
const npu_device_fp16_t * src1,
npu_device_fp16_t * dst,
size_t count) {
using namespace hexagon::vec;
vec_trans_op_impl<_OpBinaryTransform, npu_device_fp16_t>(src0, src1, count, dst);
vec_trans_impl<_OpBinaryTransform, npu_device_fp16_t>(src0, src1, dst, count);
}
inline HVX_Vector vadd_f16_f16(HVX_Vector a, HVX_Vector b) {
@ -53,12 +55,17 @@ inline HVX_Vector vmul_f16_f16(HVX_Vector a, HVX_Vector b) {
template <typename T> struct get_data_type {};
template <typename _TyData> struct get_data_type<void (*)(const _TyData *, const _TyData *, size_t, _TyData *)> {
template <typename _TyData> struct get_data_type<void (*)(const _TyData *, const _TyData *, _TyData *, size_t)> {
using type = _TyData;
};
template <typename _TyData>
struct get_data_type<void (*)(const _TyData *, const _TyData *, _TyData *, size_t, hexagon::HVX_VectorPair_x4)> {
using type = _TyData;
};
template <typename _TyData, typename _TyParam>
struct get_data_type<void (*)(const _TyData *, size_t, _TyParam, _TyData *)> {
struct get_data_type<void (*)(const _TyData *, _TyData *, size_t, _TyParam)> {
using type = _TyData;
using param_type = typename std::remove_cv<typename std::remove_reference<_TyData>::type>::type;
};
@ -85,7 +92,8 @@ template <auto _RowFunc> bool element_wise_op(hexagon::tensor * out, hexagon::co
uint8_t * dst_ptr = out->get_write_buffer();
if (!dst_ptr) {
DEVICE_LOG_ERROR("element_wise_op: dst_ptr is not writable, tensor: %p, type: %s\n", (void *) out,
DEVICE_LOG_ERROR("element_wise_op: dst_ptr is not writable, tensor: %p, type: %s\n",
(void *) out,
hexagon::get_type_name(out->get_type()));
return false;
}
@ -119,16 +127,21 @@ template <auto _RowFunc> bool element_wise_op(hexagon::tensor * out, hexagon::co
hexagon::l2fetch_row(src1_row + src1->get_nb(1), valid_row_bytes);
}
_RowFunc(reinterpret_cast<const data_type *>(src0_row), reinterpret_cast<const data_type *>(src1_row),
static_cast<size_t>(out->get_ne(0)), reinterpret_cast<data_type *>(dst_row));
_RowFunc(reinterpret_cast<const data_type *>(src0_row),
reinterpret_cast<const data_type *>(src1_row),
reinterpret_cast<data_type *>(dst_row),
static_cast<size_t>(out->get_ne(0)));
}
out->release_write_buffer(); // mark the output tensor as modified
return true;
}
bool is_element_wise_op_supported(npu_device_tensor_op op, const npu_device_tensor_spec * dst,
const npu_device_tensor_spec * srcs, size_t src_len) {
bool is_element_wise_op_supported(const npu_device_tensor_op_spec * op_spec,
const npu_device_tensor_spec * dst,
const npu_device_tensor_spec * srcs,
size_t src_len) {
const auto op = op_spec->op;
if (op != NPU_OP_ADD && op != NPU_OP_SUB && op != NPU_OP_MUL) {
DEVICE_LOG_DEBUG("[%s]unsupported\n", hexagon::op_get_name(op));
return false;
@ -142,27 +155,31 @@ bool is_element_wise_op_supported(npu_device_tensor_op op, const npu_device_tens
const auto & src0 = srcs[0];
const auto & src1 = srcs[1];
if (dst->type != src0.type || dst->type != src1.type) {
DEVICE_LOG_DEBUG("[%s]src0.type and dst.type mismatch: %s vs %s\n", hexagon::op_get_name(op),
hexagon::get_type_name(src0.type), hexagon::get_type_name(dst->type));
DEVICE_LOG_DEBUG("[%s]src0.type and dst.type mismatch: %s vs %s\n",
hexagon::op_get_name(op),
hexagon::get_type_name(src0.type),
hexagon::get_type_name(dst->type));
return false;
}
if (dst->type != NPU_DATA_TYPE_F32 && dst->type != NPU_DATA_TYPE_F16) {
DEVICE_LOG_DEBUG("[%s]unsupported data type: %s\n", hexagon::op_get_name(op),
hexagon::get_type_name(dst->type));
DEVICE_LOG_DEBUG(
"[%s]unsupported data type: %s\n", hexagon::op_get_name(op), hexagon::get_type_name(dst->type));
return false;
}
// TODO: fix FP16 add/sub
if (dst->type == NPU_DATA_TYPE_F16 && op != NPU_OP_MUL) {
DEVICE_LOG_DEBUG("[%s]unsupported data type: %s\n", hexagon::op_get_name(op),
hexagon::get_type_name(dst->type));
DEVICE_LOG_DEBUG(
"[%s]unsupported data type: %s\n", hexagon::op_get_name(op), hexagon::get_type_name(dst->type));
return false;
}
if (src0.ne[0] != src1.ne[0]) {
DEVICE_LOG_DEBUG("[%s]src0.ne[0] and src1.ne[0] not match: %ld vs %ld\n", hexagon::op_get_name(op),
(long) src0.ne[0], (long) src1.ne[0]);
DEVICE_LOG_DEBUG("[%s]src0.ne[0] and src1.ne[0] not match: %ld vs %ld\n",
hexagon::op_get_name(op),
(long) src0.ne[0],
(long) src1.ne[0]);
return false;
}
@ -174,7 +191,7 @@ bool is_element_wise_op_supported(npu_device_tensor_op op, const npu_device_tens
return true;
}
void rms_norm_vec_f32(const float * src, size_t count, float eps, float * dst) {
void rms_norm_vec_f32(const float * src, float * dst, size_t count, float eps) {
constexpr const size_t kElementsPerVector = hexagon::kBytesPerVector / sizeof(float);
HVX_Vector * src_vec_ptr = ((HVX_Vector *) src);
@ -206,7 +223,7 @@ void rms_norm_vec_f32(const float * src, size_t count, float eps, float * dst) {
(leftover_bytes + hexagon::unaligned_bytes(src_vec_ptr) > hexagon::kBytesPerVector) ? *src_vec_ptr : prev;
curr = Q6_V_valign_VVR(curr, prev, (size_t) src);
sum = Q6_Vqf32_vadd_Vqf32Vqf32(sum,
Q6_V_valign_VVR(Q6_Vqf32_vmpy_VsfVsf(curr, curr), Q6_V_vzero(), leftover_bytes));
Q6_V_valign_VVR(Q6_Vqf32_vmpy_VsfVsf(curr, curr), Q6_V_vzero(), leftover_bytes));
}
const float mean = hexagon::vec_reduction_f32_qf32(sum) / count; // TODO: figure out how to do division in vector
@ -231,7 +248,8 @@ template <auto _RowFunc> bool unary_op(hexagon::tensor * out, hexagon::compute_p
auto * dst_ptr = out->get_write_buffer();
if (!dst_ptr) {
DEVICE_LOG_ERROR("unary_op: dst_ptr is not writable, tensor: %p, type: %s\n", (void *) out,
DEVICE_LOG_ERROR("unary_op: dst_ptr is not writable, tensor: %p, type: %s\n",
(void *) out,
hexagon::get_type_name(out->get_type()));
return false;
}
@ -259,16 +277,21 @@ template <auto _RowFunc> bool unary_op(hexagon::tensor * out, hexagon::compute_p
hexagon::l2fetch_row(src0_row + src0->get_nb(1), valid_row_bytes);
}
_RowFunc(reinterpret_cast<const data_type *>(src0_row), static_cast<size_t>(out->get_ne(0)), param,
reinterpret_cast<data_type *>(dst_row));
_RowFunc(reinterpret_cast<const data_type *>(src0_row),
reinterpret_cast<data_type *>(dst_row),
static_cast<size_t>(out->get_ne(0)),
param);
}
out->release_write_buffer(); // mark the output tensor as modified
return true;
}
bool is_unary_op_supported(npu_device_tensor_op op, const npu_device_tensor_spec * dst,
const npu_device_tensor_spec * srcs, size_t src_len) {
bool is_unary_op_supported(const npu_device_tensor_op_spec * op_spec,
const npu_device_tensor_spec * dst,
const npu_device_tensor_spec * srcs,
size_t src_len) {
const auto op = op_spec->op;
if (op != NPU_OP_RMS_NORM) {
DEVICE_LOG_DEBUG("[%s]unsupported\n", hexagon::op_get_name(op));
return false;
@ -281,14 +304,16 @@ bool is_unary_op_supported(npu_device_tensor_op op, const npu_device_tensor_spec
const auto & src0 = srcs[0];
if (dst->type != src0.type) {
DEVICE_LOG_DEBUG("[%s]src0.type and dst.type mismatch: %s vs %s\n", hexagon::op_get_name(op),
hexagon::get_type_name(src0.type), hexagon::get_type_name(dst->type));
DEVICE_LOG_DEBUG("[%s]src0.type and dst.type mismatch: %s vs %s\n",
hexagon::op_get_name(op),
hexagon::get_type_name(src0.type),
hexagon::get_type_name(dst->type));
return false;
}
if (dst->type != NPU_DATA_TYPE_F32) {
DEVICE_LOG_DEBUG("[%s]unsupported data type: %s\n", hexagon::op_get_name(op),
hexagon::get_type_name(dst->type));
DEVICE_LOG_DEBUG(
"[%s]unsupported data type: %s\n", hexagon::op_get_name(op), hexagon::get_type_name(dst->type));
return false;
}
@ -300,6 +325,171 @@ bool is_unary_op_supported(npu_device_tensor_op op, const npu_device_tensor_spec
return true;
}
inline void glu_vec_op_f32_f32(const float * src0,
const float * src1,
float * dst,
size_t count,
hexagon::HVX_VectorPair_x4 coeff) {
using namespace hexagon::vec;
vec_trans_with_param_impl<float, hexagon::HVX_VectorPair_x4, hexagon::vec_swiglu_f32_f32>(
src0, src1, dst, count, coeff);
}
inline void glu_vec_op_f16_f16(const npu_device_fp16_t * src0,
const npu_device_fp16_t * src1,
npu_device_fp16_t * dst,
size_t count,
hexagon::HVX_VectorPair_x4 coeff) {
using namespace hexagon::vec;
vec_trans_with_param_impl<npu_device_fp16_t, hexagon::HVX_VectorPair_x4, hexagon::vec_swiglu_f16_f16>(
src0, src1, dst, count, coeff);
}
template <auto _GluRowFunc, hexagon::HVX_VectorPair_x4 (*_CoeffLoadFunc)()>
bool glu_impl(hexagon::tensor * out, hexagon::compute_params * params) {
using data_type = typename get_data_type<decltype(_GluRowFunc)>::type;
static_assert(DEVICE_TENSOR_MAX_DIMS == 4, "element_wise_op requires max dims 4");
if (!out) {
return false;
}
const bool has_src1 = out->get_src(1) != nullptr;
auto * src0 = out->get_src(0);
auto * src1 = has_src1 ? out->get_src(1) : src0;
if (!src0 || !src1) {
return true; // skip if no src
}
const auto total_cols = has_src1 ? src0->get_ne(0) : src0->get_ne(0) / 2;
if (out->get_ne(0) != total_cols) {
DEVICE_LOG_ERROR("out.ne[0] (%ld) != total_cols (%d)\n", (long) out->get_ne(0), (int) total_cols);
return false;
}
auto total_rows = out->get_ne(3) * out->get_ne(2) * out->get_ne(1);
const auto rows_per_cube = out->get_ne(2) * out->get_ne(1);
const auto start_end = params->get_work_slice(total_rows);
if (start_end.first >= start_end.second) {
return true;
}
uint8_t * dst_ptr = out->get_write_buffer();
if (!dst_ptr) {
DEVICE_LOG_ERROR("element_wise_op: dst_ptr is not writable, tensor: %p, type: %s\n",
(void *) out,
hexagon::get_type_name(out->get_type()));
return false;
}
const int32_t swapped = out->get_op_param<int32_t>(1);
const uint8_t * src0_ptr = src0->get_read_buffer();
const uint8_t * src1_ptr = has_src1 ? src1->get_read_buffer() : (src0_ptr + total_cols * sizeof(data_type));
if (swapped) {
std::swap(src0_ptr, src1_ptr);
}
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER(out, params->get_thread_index());
auto coeff = _CoeffLoadFunc();
const size_t valid_row_bytes = src0->get_ne(0) * sizeof(data_type);
for (int64_t ir = start_end.first; ir < start_end.second; ++ir) {
const auto i03 = ir / rows_per_cube;
const auto i02 = ir / out->get_ne(1) - i03 * out->get_ne(2);
const auto i01 = ir % out->get_ne(1); // TODO: should we use divide instead of mod?
const auto i13 = i03 % src1->get_ne(3);
const auto i12 = i02 % src1->get_ne(2);
const auto i11 = i01 % src1->get_ne(1);
auto * src1_plane = src1_ptr + i13 * src1->get_nb(3) + i12 * src1->get_nb(2);
auto * src0_row = src0_ptr + i03 * src0->get_nb(3) + i02 * src0->get_nb(2) + i01 * src0->get_nb(1);
auto * src1_row = src1_plane + i11 * src1->get_nb(1);
auto * dst_row = dst_ptr + i03 * out->get_nb(3) + i02 * out->get_nb(2) + i01 * out->get_nb(1);
if (ir + 1 < start_end.second) {
hexagon::l2fetch_row(src0_row + src0->get_nb(1), valid_row_bytes);
hexagon::l2fetch_row(src1_row + src1->get_nb(1), valid_row_bytes);
}
_GluRowFunc(reinterpret_cast<const data_type *>(src0_row),
reinterpret_cast<const data_type *>(src1_row),
reinterpret_cast<data_type *>(dst_row),
static_cast<size_t>(total_cols),
coeff);
}
out->release_write_buffer(); // mark the output tensor as modified
return true;
}
template <npu_device_tensor_data_type _DataType>
bool glu_compute(hexagon::tensor * out, hexagon::compute_params * params) {
using namespace hexagon::vec::math;
if (out->get_op_param<int32_t>(0) != NPU_GLU_OP_SWIGLU) {
DEVICE_LOG_ERROR("Invalid GLU op type: %d\n", out->get_op_param<int32_t>(0));
return false;
}
if (out->get_type() != _DataType) {
DEVICE_LOG_ERROR("GLU op type mismatch: %s vs %s\n",
hexagon::get_type_name(out->get_type()),
hexagon::get_type_name(_DataType));
return false;
}
if constexpr (_DataType == NPU_DATA_TYPE_F32) {
return glu_impl<glu_vec_op_f32_f32, qhmath_load_div_sf_ltu>(out, params);
} else if constexpr (_DataType == NPU_DATA_TYPE_F16) {
return glu_impl<glu_vec_op_f16_f16, qhmath_load_div_hf_ltu>(out, params);
}
DEVICE_LOG_ERROR("Unsupported GLU data type: %s\n", hexagon::get_type_name(out->get_type()));
return true;
}
bool is_glu_op_supported(const npu_device_tensor_op_spec * op_spec,
const npu_device_tensor_spec * dst,
const npu_device_tensor_spec * srcs,
size_t src_len) {
const auto op = op_spec->op;
if (op != NPU_OP_GLU) {
DEVICE_LOG_DEBUG("[%s]unsupported\n", hexagon::op_get_name(op));
return false;
}
if (op_spec->params[0] != NPU_GLU_OP_SWIGLU) {
DEVICE_LOG_DEBUG("[%s]unsupported GLU op type: %d\n", hexagon::op_get_name(op), op_spec->params[0]);
return false;
}
if (!dst || !srcs || src_len < 1) {
DEVICE_LOG_DEBUG("[%s]invalid dst or srcs\n", hexagon::op_get_name(op));
return false;
}
const auto & src0 = srcs[0];
if (dst->type != src0.type) {
DEVICE_LOG_DEBUG("[%s]src0.type and dst.type mismatch: %s vs %s\n",
hexagon::op_get_name(op),
hexagon::get_type_name(src0.type),
hexagon::get_type_name(dst->type));
return false;
}
if (dst->type != NPU_DATA_TYPE_F32 && dst->type != NPU_DATA_TYPE_F16) {
DEVICE_LOG_DEBUG(
"[%s]unsupported data type: %s\n", hexagon::op_get_name(op), hexagon::get_type_name(dst->type));
return false;
}
if (!hexagon::is_same_shape(src0, *dst)) {
DEVICE_LOG_DEBUG("[%s]src0 and dst have different shape\n", hexagon::op_get_name(op));
return false;
}
return false; // TODO: fix: for some input hexagon intrinsics will generate nan instead of inf.
}
struct op_capabilities {
npu_device_tensor_op op;
hexagon::op_is_supported_func_type is_supported;
@ -341,7 +531,7 @@ constexpr const op_capabilities kOpCapabilities[] = {
{
unary_op<rms_norm_vec_f32>, // NPU_DATA_TYPE_F32
nullptr, // NPU_DATA_TYPE_F16
}, false, // requires_thread_barrier
}, false, // requires_thread_barrier
},
{
NPU_OP_FLASH_ATTN,hexagon::is_flash_attn_supported,
@ -357,6 +547,13 @@ constexpr const op_capabilities kOpCapabilities[] = {
nullptr, // NPU_DATA_TYPE_F16
}, false, // requires_thread_barrier
},
{
NPU_OP_GLU, is_glu_op_supported,
{
glu_compute<NPU_DATA_TYPE_F32>, // NPU_DATA_TYPE_F32
glu_compute<NPU_DATA_TYPE_F16>, // NPU_DATA_TYPE_F16
}, false, // requires_thread_barrier
},
};
static_assert(kOpCapabilities[NPU_OP_MUL_MAT].compute_funcs[NPU_DATA_TYPE_F32] == hexagon::mul_mat_f32,
@ -370,6 +567,7 @@ static_assert(kOpCapabilities[NPU_OP_RMS_NORM].op == NPU_OP_RMS_NORM,
static_assert(kOpCapabilities[NPU_OP_FLASH_ATTN].op == NPU_OP_FLASH_ATTN,
"kOpArray[NPU_OP_FLASH_ATTN].op != NPU_OP_FLASH_ATTN");
static_assert(kOpCapabilities[NPU_OP_ROPE].op == NPU_OP_ROPE, "kOpArray[NPU_OP_ROPE].op != NPU_OP_ROPE");
static_assert(kOpCapabilities[NPU_OP_GLU].op == NPU_OP_GLU, "kOpArray[NPU_OP_GLU].op != NPU_OP_GLU");
hexagon::compute_func_type get_compute_func_impl(npu_device_tensor_op op, npu_device_tensor_data_type type) {
if (op >= NPU_OP_COUNT) {
@ -395,17 +593,25 @@ bool requires_thread_barrier(npu_device_tensor_op op) {
return kOpCapabilities[op].requires_thread_barrier;
}
bool support_op(npu_device_tensor_op op, const npu_device_tensor_spec * dst, const npu_device_tensor_spec * srcs,
size_t src_len) {
auto is_supported_func = kOpCapabilities[op].is_supported;
if (!is_supported_func || !is_supported_func(op, dst, srcs, src_len)) {
bool support_op(const npu_device_tensor_op_spec * op_spec,
const npu_device_tensor_spec * dst,
const npu_device_tensor_spec * srcs,
size_t src_len) {
if (!op_spec) {
DEVICE_LOG_ERROR("[hexagon-npu]invalid op_spec\n");
return false;
}
const auto op = op_spec->op;
auto is_supported_func = kOpCapabilities[op].is_supported;
if (!is_supported_func || !is_supported_func(op_spec, dst, srcs, src_len)) {
DEVICE_LOG_DEBUG("[%s]unsupported, is_supported_func return false\n", op_get_name(op));
return false;
}
if (get_compute_func_impl(op, dst->type) == nullptr) {
DEVICE_LOG_DEBUG("[%s]unsupported, get_compute_func failed, type: %s\n", op_get_name(op),
get_type_name(dst->type));
DEVICE_LOG_DEBUG(
"[%s]unsupported, get_compute_func failed, type: %s\n", op_get_name(op), get_type_name(dst->type));
return false;
}

View File

@ -8,7 +8,9 @@ compute_func_type get_compute_func(tensor * dst);
bool requires_thread_barrier(npu_device_tensor_op op);
bool support_op(npu_device_tensor_op op, const npu_device_tensor_spec * dst, const npu_device_tensor_spec * srcs,
size_t src_len);
bool support_op(const npu_device_tensor_op_spec * op_spec,
const npu_device_tensor_spec * dst,
const npu_device_tensor_spec * srcs,
size_t src_len);
} // namespace hexagon

View File

@ -29,7 +29,9 @@ template <> struct convert_vector<npu_device_fp16_t> {
};
template <auto _DotFunc, bool _ShouldCacheSrc0>
void mul_mat_impl(hexagon::tensor * src0, hexagon::tensor * src1, hexagon::tensor * dst,
void mul_mat_impl(hexagon::tensor * src0,
hexagon::tensor * src1,
hexagon::tensor * dst,
hexagon::compute_params * params) {
using data_type0 = typename get_data_type<decltype(_DotFunc)>::data_type0;
using data_type1 = typename get_data_type<decltype(_DotFunc)>::data_type1;
@ -62,8 +64,12 @@ void mul_mat_impl(hexagon::tensor * src0, hexagon::tensor * src1, hexagon::tenso
DEVICE_LOG_DEBUG(
"mul_mat_impl: no work to do, start_end_plane: (%ld, %ld), start_end_row: (%ld, %ld), "
"start_end_element: (%ld, %ld)\n",
start_end_plane.first, start_end_plane.second, start_end_row.first, start_end_row.second,
start_end_element.first, start_end_element.second);
start_end_plane.first,
start_end_plane.second,
start_end_row.first,
start_end_row.second,
start_end_element.first,
start_end_element.second);
return;
}
@ -81,7 +87,9 @@ void mul_mat_impl(hexagon::tensor * src0, hexagon::tensor * src1, hexagon::tenso
DEVICE_LOG_ERROR(
"mul_mat_impl: failed to get VTCM cache for src0, size: %zu, src0_plane_slice_row_count: %zu, "
"src0_actual_row_size: %zu, will fallback to mem cache\n",
src0_plane_cache_size, src0_plane_slice_row_count, src0_actual_row_size);
src0_plane_cache_size,
src0_plane_slice_row_count,
src0_actual_row_size);
return;
}
}
@ -89,7 +97,10 @@ void mul_mat_impl(hexagon::tensor * src0, hexagon::tensor * src1, hexagon::tenso
DEVICE_LOG_DEBUG(
"mul_mat_impl src0_actual_row_size: %zu, src0_plane_slice_row_count: %zu, is_quantized: %d, vtcm_mem: "
"%p(%zu)\n",
src0_actual_row_size, src0_plane_slice_row_count, _ShouldCacheSrc0, (void *) src0_plane_cache_ptr,
src0_actual_row_size,
src0_plane_slice_row_count,
_ShouldCacheSrc0,
(void *) src0_plane_cache_ptr,
src0_plane_cache_size);
const size_t valid_row0_bytes = src0->get_ne(0) * sizeof(data_type0);
@ -99,7 +110,8 @@ void mul_mat_impl(hexagon::tensor * src0, hexagon::tensor * src1, hexagon::tenso
uint8_t * dst_ptr = dst->get_write_buffer();
if (!dst_ptr) {
DEVICE_LOG_ERROR("mul_mat_impl: dst_ptr is not writable, tensor: %p, type: %s\n", (void *) dst,
DEVICE_LOG_ERROR("mul_mat_impl: dst_ptr is not writable, tensor: %p, type: %s\n",
(void *) dst,
hexagon::get_type_name(dst->get_type()));
return;
}
@ -114,16 +126,17 @@ void mul_mat_impl(hexagon::tensor * src0, hexagon::tensor * src1, hexagon::tenso
auto * dst_plane = dst_ptr + i3 * dst->get_nb(3) + i2 * dst->get_nb(2);
for (int64_t col_idx = start_end_element.first; col_idx < start_end_element.second;
col_idx += src0_plane_slice_row_count) {
const uint8_t * src0_plane =
src0_ptr + i3 / r03 * src0->get_nb(3) + i2 / r02 * src0->get_nb(2) + col_idx * src0->get_nb(1);
hexagon::l2fetch_row(src0_plane, src0->get_nb(1));
const int64_t actual_row_count =
std::min<int64_t>(src0_plane_slice_row_count,
start_end_element.second - col_idx); // number of rows in this slice
const uint8_t * src0_plane =
src0_ptr + i3 / r03 * src0->get_nb(3) + i2 / r02 * src0->get_nb(2) + col_idx * src0->get_nb(1);
if constexpr (_ShouldCacheSrc0) {
if (last_cached_plane_ptr != src0_plane) {
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 0, dequant);
hexagon::l2fetch_row(src0_plane, src0->get_nb(1));
for (int64_t ir = 0; ir < actual_row_count; ir++) {
auto * src0_row = src0_plane + ir * src0->get_nb(1);
if (ir + 1 < actual_row_count) {
@ -131,7 +144,8 @@ void mul_mat_impl(hexagon::tensor * src0, hexagon::tensor * src1, hexagon::tenso
}
auto * cached_row_ptr = src0_plane_cache_ptr + ir * src0_actual_row_size;
dequantize_row_func(src0_row, reinterpret_cast<hexagon::dequant_target_type *>(cached_row_ptr),
dequantize_row_func(src0_row,
reinterpret_cast<hexagon::dequant_output_type *>(cached_row_ptr),
src0->get_ne(0));
}
@ -158,12 +172,8 @@ void mul_mat_impl(hexagon::tensor * src0, hexagon::tensor * src1, hexagon::tenso
// TODO: figure dst how to handle a entire row
auto res0 = _DotFunc(reinterpret_cast<const data_type0 *>(src0_row),
reinterpret_cast<const data_type1 *>(src1_row), (size_t) src0->get_ne(0));
{
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 2, store);
dst_row[i0] = convert_vector<data_type1>::convert(res0);
}
reinterpret_cast<const data_type1 *>(src1_row),
(size_t) src0->get_ne(0));
if (should_fetch_src0_row && i0 + 2 < actual_row_count) {
hexagon::l2fetch_row(src0_row + src0_actual_row_size + src0_actual_row_size, valid_row0_bytes);
@ -171,10 +181,12 @@ void mul_mat_impl(hexagon::tensor * src0, hexagon::tensor * src1, hexagon::tenso
// TODO: figure dst how to handle a entire row
auto res1 = _DotFunc(reinterpret_cast<const data_type0 *>(src0_row + src0_actual_row_size),
reinterpret_cast<const data_type1 *>(src1_row), (size_t) src0->get_ne(0));
reinterpret_cast<const data_type1 *>(src1_row),
(size_t) src0->get_ne(0));
{
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 2, store);
dst_row[i0] = convert_vector<data_type1>::convert(res0);
dst_row[i0 + 1] = convert_vector<data_type1>::convert(res1);
}
}
@ -186,7 +198,8 @@ void mul_mat_impl(hexagon::tensor * src0, hexagon::tensor * src1, hexagon::tenso
if (i0 < actual_row_count) {
auto * src0_row = src0_plane + i0 * src0_actual_row_size;
auto res = _DotFunc(reinterpret_cast<const data_type0 *>(src0_row),
reinterpret_cast<const data_type1 *>(src1_row), (size_t) src0->get_ne(0));
reinterpret_cast<const data_type1 *>(src1_row),
(size_t) src0->get_ne(0));
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 2, store);
dst_row[i0] = convert_vector<data_type1>::convert(res);
}
@ -197,19 +210,194 @@ void mul_mat_impl(hexagon::tensor * src0, hexagon::tensor * src1, hexagon::tenso
dst->release_write_buffer(); // mark the output tensor as modified
}
bool is_row_size_cacheable(const npu_device_tensor_spec & src) {
const auto & type_traits = hexagon::get_type_traits(src.type);
if (type_traits.to_float == nullptr) {
DEVICE_LOG_DEBUG("[MUL_MAT]src.type(%s) cannot be cached, to_float is null\n",
hexagon::get_type_name(src.type));
template <auto _DotFunc, bool _ShouldCacheSrc0>
void mul_mat_gemv_impl(hexagon::tensor * src0,
hexagon::tensor * src1,
hexagon::tensor * dst,
hexagon::compute_params * params) {
using data_type0 = typename get_data_type<decltype(_DotFunc)>::data_type0;
using data_type1 = typename get_data_type<decltype(_DotFunc)>::data_type1;
const auto src0_actual_row_size = hexagon::get_dequantized_row_size(src0);
auto * dequantize_row_func = hexagon::get_type_traits(src0->get_type()).to_float;
if (_ShouldCacheSrc0 && dequantize_row_func == nullptr) {
DEVICE_LOG_ERROR("Unsupported quantized src0 type: %d, dequantize_row_func is null\n", src0->get_type());
return;
}
auto start_end_element = std::pair<int64_t, int64_t>{ 0, dst->get_ne(0) };
if (dst->get_ne(0) >= params->get_thread_count()) {
start_end_element = params->get_work_slice(dst->get_ne(0));
} else {
DEVICE_LOG_ERROR("Unsupported src1 tensor shape for gemv: %s, ne: %ldx%ldx%ldx%ld\n",
hexagon::get_type_name(src1->get_type()),
src1->get_ne(0),
src1->get_ne(1),
src1->get_ne(2),
src1->get_ne(3));
return;
}
if (start_end_element.second <= start_end_element.first) {
DEVICE_LOG_DEBUG(
"mul_mat_impl: no work to do, start_end_plane: [0, 1), start_end_row: [0, 1), "
"start_end_element: [%ld, %ld)\n",
start_end_element.first,
start_end_element.second);
return;
}
// cache the src0 plane in VTCM
size_t src0_plane_slice_row_count = start_end_element.second - start_end_element.first;
size_t src0_plane_cache_size = 0;
uint8_t * src0_plane_cache_ptr = nullptr;
const auto src1_actual_row_size = hexagon::get_aligned_size(src1->get_nb(1));
uint8_t * src1_row_cache_ptr = nullptr;
if constexpr (_ShouldCacheSrc0) {
src0_plane_slice_row_count = std::min(
(params->get_vtcm_quota_size() - src1_actual_row_size) / src0_actual_row_size, src0_plane_slice_row_count);
src0_plane_cache_size = src0_actual_row_size * src0_plane_slice_row_count;
src0_plane_cache_ptr = params->get_vtcm_cache(src0_plane_cache_size + src1_actual_row_size);
if (src0_plane_cache_ptr == nullptr) {
DEVICE_LOG_ERROR(
"mul_mat_impl: failed to get VTCM cache for src0, size: %zu, src0_plane_slice_row_count: %zu, "
"src0_actual_row_size: %zu, will fallback to mem cache\n",
src0_plane_cache_size,
src0_plane_slice_row_count,
src0_actual_row_size);
return;
}
src1_row_cache_ptr = src0_plane_cache_ptr + src0_plane_cache_size;
} else {
src1_row_cache_ptr = params->get_vtcm_cache(src1_actual_row_size);
if (src1_row_cache_ptr == nullptr) {
DEVICE_LOG_ERROR("mul_mat_impl: failed to get VTCM cache for src1, size: %zu\n", src1_actual_row_size);
return;
}
}
DEVICE_LOG_DEBUG(
"mul_mat_impl src0_actual_row_size: %zu, src0_plane_slice_row_count: %zu, is_quantized: %d, vtcm_mem: "
"%p(%zu)\n",
src0_actual_row_size,
src0_plane_slice_row_count,
_ShouldCacheSrc0,
(void *) src0_plane_cache_ptr,
src0_plane_cache_size);
const size_t valid_row0_bytes = src0->get_ne(0) * sizeof(data_type0);
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_WITH_MULTI_SUB_PROC(dst, params->get_thread_index(), mul_mat);
uint8_t * dst_ptr = dst->get_write_buffer();
if (!dst_ptr) {
DEVICE_LOG_ERROR("mul_mat_impl: dst_ptr is not writable, tensor: %p, type: %s\n",
(void *) dst,
hexagon::get_type_name(dst->get_type()));
return;
}
constexpr bool should_fetch_src0_row = !_ShouldCacheSrc0;
const uint8_t * src0_ptr = src0->get_read_buffer();
const uint8_t * src1_ptr = src1->get_read_buffer();
{
memcpy(src1_row_cache_ptr, src1_ptr, src1->get_ne(0) * sizeof(data_type1));
src1_ptr = src1_row_cache_ptr;
}
{
for (int64_t col_idx = start_end_element.first; col_idx < start_end_element.second;
col_idx += src0_plane_slice_row_count) {
const uint8_t * src0_plane = src0_ptr + col_idx * src0->get_nb(1);
hexagon::l2fetch_row(src0_plane, src0->get_nb(1));
const int64_t actual_row_count =
std::min<int64_t>(src0_plane_slice_row_count,
start_end_element.second - col_idx); // number of rows in this slice
if constexpr (_ShouldCacheSrc0) {
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 0, dequant);
for (int64_t ir = 0; ir < actual_row_count; ir++) {
auto * src0_row = src0_plane + ir * src0->get_nb(1);
if (ir + 1 < actual_row_count) {
hexagon::l2fetch_row(src0_row + src0->get_nb(1), src0->get_nb(1));
}
auto * cached_row_ptr = src0_plane_cache_ptr + ir * src0_actual_row_size;
dequantize_row_func(
src0_row, reinterpret_cast<hexagon::dequant_output_type *>(cached_row_ptr), src0->get_ne(0));
}
src0_plane = src0_plane_cache_ptr;
}
{
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 1, vec_dot);
auto * dst_row = reinterpret_cast<float *>(dst_ptr) + col_idx;
int64_t i0 = 0;
for (; i0 + 1 < actual_row_count; i0 += 2) {
auto * src0_row = src0_plane + i0 * src0_actual_row_size;
if constexpr (should_fetch_src0_row) {
hexagon::l2fetch_row(src0_row + src0_actual_row_size, valid_row0_bytes);
}
// TODO: figure dst how to handle a entire row
auto res0 = _DotFunc(reinterpret_cast<const data_type0 *>(src0_row),
reinterpret_cast<const data_type1 *>(src1_ptr),
(size_t) src0->get_ne(0));
if (should_fetch_src0_row && i0 + 2 < actual_row_count) {
hexagon::l2fetch_row(src0_row + src0_actual_row_size + src0_actual_row_size, valid_row0_bytes);
}
// TODO: figure dst how to handle a entire row
auto res1 = _DotFunc(reinterpret_cast<const data_type0 *>(src0_row + src0_actual_row_size),
reinterpret_cast<const data_type1 *>(src1_ptr),
(size_t) src0->get_ne(0));
{
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 2, store);
dst_row[i0] = convert_vector<data_type1>::convert(res0);
dst_row[i0 + 1] = convert_vector<data_type1>::convert(res1);
}
}
if (i0 < actual_row_count) {
auto * src0_row = src0_plane + i0 * src0_actual_row_size;
auto res = _DotFunc(reinterpret_cast<const data_type0 *>(src0_row),
reinterpret_cast<const data_type1 *>(src1_ptr),
(size_t) src0->get_ne(0));
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 2, store);
dst_row[i0] = convert_vector<data_type1>::convert(res);
}
}
}
}
dst->release_write_buffer(); // mark the output tensor as modified
}
bool is_src_cacheable(const npu_device_tensor_spec & src0, const npu_device_tensor_spec & src1) {
const auto & src0_type_traits = hexagon::get_type_traits(src0.type);
if (src0_type_traits.to_float == nullptr) {
DEVICE_LOG_DEBUG("[MUL_MAT]src0.type(%s) cannot be cached, to_float is null\n",
hexagon::get_type_name(src0.type));
return false;
}
const size_t type_size = type_traits.is_quantized ? sizeof(hexagon::dequant_target_type) : type_traits.type_size;
const auto vtcm_thread_quota_size = hexagon::default_thread_pool::get_per_thread_vtcm_quota();
if (src.ne[0] * type_size > vtcm_thread_quota_size) {
DEVICE_LOG_DEBUG("[MUL_MAT]src.type(%s) ne[0] is too large: %ld, vtcm_thread_quota_size: %zu\n",
hexagon::get_type_name(src.type), (long) src.ne[0], vtcm_thread_quota_size);
const size_t src0_type_size =
src0_type_traits.is_quantized ? sizeof(hexagon::dequant_output_type) : src0_type_traits.type_size;
const auto & src1_type_traits = hexagon::get_type_traits(src1.type);
const bool is_gemv = src1.ne[1] == 1 && src1.ne[2] == 1 && src1.ne[3] == 1;
size_t min_cache_size = is_gemv ? (src1.ne[0] * src1_type_traits.type_size) : 0;
min_cache_size += src0.ne[0] * src0_type_size;
if (min_cache_size > vtcm_thread_quota_size) {
DEVICE_LOG_DEBUG("[MUL_MAT]src0.type(%s) min_cache_size is too large: %ld, vtcm_thread_quota_size: %zu\n",
hexagon::get_type_name(src0.type),
(long) min_cache_size,
vtcm_thread_quota_size);
return false;
}
@ -219,36 +407,41 @@ bool is_row_size_cacheable(const npu_device_tensor_spec & src) {
bool is_quantized_mul_mat_supported(const npu_device_tensor_spec & src0, const npu_device_tensor_spec & src1) {
if (src1.type != NPU_DATA_TYPE_F32 && src1.type != NPU_DATA_TYPE_F16) {
DEVICE_LOG_DEBUG("[MUL_MAT]src0.type(%s) and src1.type(%s) mismatch and src1 is not F32\n",
hexagon::get_type_name(src0.type), hexagon::get_type_name(src1.type));
hexagon::get_type_name(src0.type),
hexagon::get_type_name(src1.type));
return false;
}
const auto type_traits = hexagon::get_type_traits(src0.type);
if (!type_traits.is_quantized || type_traits.to_float == nullptr) {
DEVICE_LOG_DEBUG("[MUL_MAT]src0.type(%s) and src1.type(%s) mismatch and src0 is not quantized\n",
hexagon::get_type_name(src0.type), hexagon::get_type_name(src1.type));
hexagon::get_type_name(src0.type),
hexagon::get_type_name(src1.type));
return false;
}
if (src0.ne[0] % type_traits.blck_size) {
DEVICE_LOG_DEBUG("[MUL_MAT]src0.type(%s) ne[0] is not aligned: %ld\n", hexagon::get_type_name(src0.type),
(long) src0.ne[0]);
DEVICE_LOG_DEBUG(
"[MUL_MAT]src0.type(%s) ne[0] is not aligned: %ld\n", hexagon::get_type_name(src0.type), (long) src0.ne[0]);
return false;
}
if (!is_row_size_cacheable(src0)) {
if (!is_src_cacheable(src0, src1)) {
return false;
}
DEVICE_LOG_DEBUG("[MUL_MAT]supported quantized src0.type(%s) and src1.type(%s)\n",
hexagon::get_type_name(src0.type), hexagon::get_type_name(src1.type));
hexagon::get_type_name(src0.type),
hexagon::get_type_name(src1.type));
return true;
}
bool is_mul_mat_f16_f32_src_tensors_aligned(hexagon::tensor * src0, hexagon::tensor * src1, bool is_src0_quantized) {
const auto * src1_ptr = src1->get_read_buffer_as<float>();
const auto * src0_ptr =
is_src0_quantized ? nullptr : src0->get_read_buffer_as<npu_device_fp16_t>(); // skip src0 for quantized tensors
bool is_mul_mat_f16_f32_src_tensors_aligned(hexagon::tensor * src0,
hexagon::tensor * src1,
bool is_src0_cached,
bool is_src1_cached) {
const auto * src1_ptr = is_src1_cached ? nullptr : src1->get_read_buffer_as<float>();
const auto * src0_ptr = is_src0_cached ? nullptr : src0->get_read_buffer_as<npu_device_fp16_t>();
if (!hexagon::is_f16_f32_dot_product_aligned(src0_ptr, src1_ptr, src0->get_ne(0))) {
DEVICE_LOG_DEBUG("[MUL_MAT]src_tensors_unaligned: ne[0]: %ld\n", (long) src0->get_ne(0));
@ -305,35 +498,49 @@ bool is_mul_mat_f32_f32_src_tensors_aligned(hexagon::tensor * src0, hexagon::ten
return true;
}
typedef void (*mul_mat_func_type)(hexagon::tensor * src0, hexagon::tensor * src1, hexagon::tensor * dst,
typedef void (*mul_mat_func_type)(hexagon::tensor * src0,
hexagon::tensor * src1,
hexagon::tensor * dst,
hexagon::compute_params * params);
constexpr const mul_mat_func_type kMulMatF32F32CachedFuncs[2] = {
constexpr const size_t kMulMatGemvBaseIndex = 2;
constexpr const mul_mat_func_type kMulMatF32F32CachedFuncs[4] = {
// quantized and non-quantized
mul_mat_impl<hexagon::vec_dot_product_vqf32_f32_f32, true>, // F32 * F32 quantized unaligned
mul_mat_impl<hexagon::vec_dot_product_aligned_vqf32_f32_f32, true>, // F32 * F32 quantized aligned
mul_mat_impl<hexagon::vec_dot_product_vqf32_f32_f32, true>, // F32 * F32 quantized unaligned
mul_mat_impl<hexagon::vec_dot_product_aligned_vqf32_f32_f32, true>, // F32 * F32 quantized aligned
mul_mat_gemv_impl<hexagon::vec_dot_product_aligned_vqf32_f32_f32, true>, // F32 * F32 quantized gemv
mul_mat_gemv_impl<hexagon::vec_dot_product_aligned_vqf32_f32_f32, true>, // F32 * F32 quantized gemv
};
constexpr const mul_mat_func_type kMulMatF32F32Funcs[2] = {
constexpr const mul_mat_func_type kMulMatF32F32Funcs[4] = {
// quantized and non-quantized
mul_mat_impl<hexagon::vec_dot_product_vqf32_f32_f32, false>, // F32 * F32 quantized unaligned
mul_mat_impl<hexagon::vec_dot_product_aligned_vqf32_f32_f32, false>, // F32 * F32 quantized aligned
mul_mat_impl<hexagon::vec_dot_product_vqf32_f32_f32, false>, // F32 * F32 quantized unaligned
mul_mat_impl<hexagon::vec_dot_product_aligned_vqf32_f32_f32, false>, // F32 * F32 quantized aligned
mul_mat_gemv_impl<hexagon::vec_dot_product_vqf32_f32_f32, false>, // F32 * F32 quantized gemv
mul_mat_gemv_impl<hexagon::vec_dot_product_aligned_vqf32_f32_f32, false>, // F32 * F32 quantized gemv
};
constexpr const mul_mat_func_type kMulMatF16CachedFuncs[2] = {
mul_mat_impl<hexagon::vec_dot_product_vqf16_f16_f16, true>, // F16 * F16 quantized unaligned
mul_mat_impl<hexagon::vec_dot_product_aligned_vqf16_f16_f16, true>, // F16 * F16 quantized aligned
};
constexpr const mul_mat_func_type kMulMatF16Funcs[2] = {
mul_mat_impl<hexagon::vec_dot_product_vqf16_f16_f16, false>, // F16 * F16 quantized unaligned
mul_mat_impl<hexagon::vec_dot_product_aligned_vqf16_f16_f16, false>, // F16 * F16 quantized aligned
};
constexpr const mul_mat_func_type kMulMatF16F32Funcs[2] = {
constexpr const mul_mat_func_type kMulMatF16F32Funcs[4] = {
// quantized and non-quantized
mul_mat_impl<hexagon::vec_dot_product_vqf32_f16_f32, true>, // F32 * F32 quantized unaligned
mul_mat_impl<hexagon::vec_dot_product_aligned_vqf32_f16_f32, true>, // F32 * F32 quantized aligned
mul_mat_impl<hexagon::vec_dot_product_vqf32_f16_f32, true>, // F32 * F32 quantized unaligned
mul_mat_impl<hexagon::vec_dot_product_aligned_vqf32_f16_f32, true>, // F32 * F32 quantized aligned
mul_mat_gemv_impl<hexagon::vec_dot_product_vqf32_f16_f32, true>, // F32 * F32 quantized unaligned
mul_mat_gemv_impl<hexagon::vec_dot_product_aligned_vqf32_f16_f32, true>, // F32 * F32 quantized aligned
};
constexpr const mul_mat_func_type kMulMatF16CachedFuncs[4] = {
mul_mat_impl<hexagon::vec_dot_product_vqf16_f16_f16, true>, // F16 * F16 quantized unaligned
mul_mat_impl<hexagon::vec_dot_product_aligned_vqf16_f16_f16, true>, // F16 * F16 quantized aligned
mul_mat_gemv_impl<hexagon::vec_dot_product_aligned_vqf16_f16_f16, true>, // F16 * F16 quantized gemv
mul_mat_gemv_impl<hexagon::vec_dot_product_aligned_vqf16_f16_f16, true>, // F16 * F16 quantized gemv
};
constexpr const mul_mat_func_type kMulMatF16Funcs[4] = {
mul_mat_impl<hexagon::vec_dot_product_vqf16_f16_f16, false>, // F16 * F16 quantized unaligned
mul_mat_impl<hexagon::vec_dot_product_aligned_vqf16_f16_f16, false>, // F16 * F16 quantized aligned
mul_mat_gemv_impl<hexagon::vec_dot_product_vqf16_f16_f16, false>, // F16 * F16 quantized gemv
mul_mat_gemv_impl<hexagon::vec_dot_product_aligned_vqf16_f16_f16, false>, // F16 * F16 quantized gemv
};
} // namespace
@ -342,9 +549,9 @@ namespace hexagon {
bool mul_mat_f32(hexagon::tensor * out, compute_params * params) {
static_assert(DEVICE_TENSOR_MAX_DIMS == 4, "mul_mat_f32 requires max dims 4");
static_assert(std::is_same<hexagon::dequant_target_type, float>::value ||
std::is_same<hexagon::dequant_target_type, npu_device_fp16_t>::value,
"dequant_target_type must be float or npu_device_fp16_t");
static_assert(std::is_same<hexagon::dequant_output_type, float>::value ||
std::is_same<hexagon::dequant_output_type, npu_device_fp16_t>::value,
"dequant_output_type must be float or npu_device_fp16_t");
if (!out) {
return false;
@ -358,24 +565,28 @@ bool mul_mat_f32(hexagon::tensor * out, compute_params * params) {
const bool is_src0_quantized = is_quantized_type(src0->get_type());
const bool should_cache_src0 = is_src0_quantized || src1->get_ne(1) > 1;
const bool is_gemv = src1->get_ne(1) == 1 && src1->get_ne(2) == 1 && src1->get_ne(3) == 1;
const auto base_index = is_gemv ? kMulMatGemvBaseIndex : 0;
switch (src1->get_type()) {
case NPU_DATA_TYPE_F32:
if (is_src0_quantized || src0->get_type() == NPU_DATA_TYPE_F16) {
kMulMatF16F32Funcs[is_mul_mat_f16_f32_src_tensors_aligned(src0, src1, is_src0_quantized)](src0, src1,
out, params);
kMulMatF16F32Funcs[is_mul_mat_f16_f32_src_tensors_aligned(src0, src1, is_src0_quantized, is_gemv) +
base_index](src0, src1, out, params);
} else if (should_cache_src0) {
kMulMatF32F32CachedFuncs[is_mul_mat_f32_f32_src_tensors_aligned(src0, src1)](src0, src1, out, params);
kMulMatF32F32CachedFuncs[is_mul_mat_f32_f32_src_tensors_aligned(src0, src1) + base_index](
src0, src1, out, params);
} else {
kMulMatF32F32Funcs[is_mul_mat_f32_f32_src_tensors_aligned(src0, src1)](src0, src1, out, params);
kMulMatF32F32Funcs[is_mul_mat_f32_f32_src_tensors_aligned(src0, src1) + base_index](
src0, src1, out, params);
}
return true;
case NPU_DATA_TYPE_F16:
if (should_cache_src0) {
kMulMatF16CachedFuncs[is_mul_mat_f16_f16_src_tensors_aligned(src0, src1, is_src0_quantized)](
src0, src1, out, params);
kMulMatF16CachedFuncs[is_mul_mat_f16_f16_src_tensors_aligned(src0, src1, is_src0_quantized) +
base_index](src0, src1, out, params);
} else {
kMulMatF16Funcs[is_mul_mat_f16_f16_src_tensors_aligned(src0, src1, is_src0_quantized)](src0, src1, out,
params);
kMulMatF16Funcs[is_mul_mat_f16_f16_src_tensors_aligned(src0, src1, is_src0_quantized) + base_index](
src0, src1, out, params);
}
return true;
default:
@ -386,8 +597,11 @@ bool mul_mat_f32(hexagon::tensor * out, compute_params * params) {
return false;
}
bool is_mul_mat_supported(npu_device_tensor_op op, const npu_device_tensor_spec * dst,
const npu_device_tensor_spec * srcs, size_t src_len) {
bool is_mul_mat_supported(const npu_device_tensor_op_spec * op_spec,
const npu_device_tensor_spec * dst,
const npu_device_tensor_spec * srcs,
size_t src_len) {
const auto op = op_spec->op;
if (op != NPU_OP_MUL_MAT) {
DEVICE_LOG_DEBUG("op is not MUL_MAT: %d\n", op);
return false;
@ -408,7 +622,9 @@ bool is_mul_mat_supported(npu_device_tensor_op op, const npu_device_tensor_spec
if (src0.type != src1.type) {
if (src1.type == NPU_DATA_TYPE_F32 && src0.type == NPU_DATA_TYPE_F16) {
DEVICE_LOG_DEBUG("[%s]src0.type(%s) and src1.type(%s) mismatch, but src0 is F16 and src1 is F32\n",
op_get_name(op), get_type_name(src0.type), get_type_name(src1.type));
op_get_name(op),
get_type_name(src0.type),
get_type_name(src1.type));
return true; // F16 * F32 is supported
}
@ -418,26 +634,40 @@ bool is_mul_mat_supported(npu_device_tensor_op op, const npu_device_tensor_spec
}
#else
DEVICE_LOG_DEBUG("[%s]src0.type(%s) and src1.type(%s) mismatch and quantized tensors are not supported\n",
op_get_name(op), get_type_name(src0.type), get_type_name(src1.type));
op_get_name(op),
get_type_name(src0.type),
get_type_name(src1.type));
return false;
#endif
}
if (src0.ne[0] != src1.ne[0] || src0.ne[1] != dst->ne[0]) {
DEVICE_LOG_DEBUG("[%s]src0 and src1 cannot multiply: %ldx%ld vs %ldx%ld\n", op_get_name(op), (long) src0.ne[0],
(long) src0.ne[1], (long) src1.ne[0], (long) src1.ne[1]);
DEVICE_LOG_DEBUG("[%s]src0 and src1 cannot multiply: %ldx%ld vs %ldx%ld\n",
op_get_name(op),
(long) src0.ne[0],
(long) src0.ne[1],
(long) src1.ne[0],
(long) src1.ne[1]);
return false;
}
if (src1.ne[1] != dst->ne[1] || src1.ne[2] != dst->ne[2] || src1.ne[3] != dst->ne[3]) {
DEVICE_LOG_DEBUG("[%s]src1 and dst dimensions not match: %ldx%ld vs %ldx%ld\n", op_get_name(op),
(long) src1.ne[2], (long) src1.ne[3], (long) dst->ne[2], (long) dst->ne[3]);
DEVICE_LOG_DEBUG("[%s]src1 and dst dimensions not match: %ldx%ld vs %ldx%ld\n",
op_get_name(op),
(long) src1.ne[2],
(long) src1.ne[3],
(long) dst->ne[2],
(long) dst->ne[3]);
return false;
}
if (src1.ne[2] % src0.ne[2] || src1.ne[3] % src0.ne[3]) {
DEVICE_LOG_DEBUG("[%s]src0 cannot broadcast to src1: %ldx%ld vs %ldx%ld\n", op_get_name(op), (long) src0.ne[2],
(long) src0.ne[3], (long) src1.ne[2], (long) src1.ne[3]);
DEVICE_LOG_DEBUG("[%s]src0 cannot broadcast to src1: %ldx%ld vs %ldx%ld\n",
op_get_name(op),
(long) src0.ne[2],
(long) src0.ne[3],
(long) src1.ne[2],
(long) src1.ne[3]);
return false;
}

View File

@ -1,14 +1,16 @@
#pragma once
#include <hexagon_types.h>
#include "op_types.hpp"
#include "tensor.hpp"
#include <hexagon_types.h>
namespace hexagon {
bool mul_mat_f32(tensor * out, compute_params * params);
bool is_mul_mat_supported(npu_device_tensor_op op, const npu_device_tensor_spec * dst,
const npu_device_tensor_spec * srcs, size_t src_len);
bool is_mul_mat_supported(const npu_device_tensor_op_spec * op_spec,
const npu_device_tensor_spec * dst,
const npu_device_tensor_spec * srcs,
size_t src_len);
} // namespace hexagon

View File

@ -29,8 +29,14 @@ float rope_yarn_ramp(const float low, const float high, const int i0) {
// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
void rope_yarn(float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
float * cos_theta, float * sin_theta) {
void rope_yarn(float theta_extrap,
float freq_scale,
float corr_dims[2],
int64_t i0,
float ext_factor,
float mscale,
float * cos_theta,
float * sin_theta) {
// Get n-d rotational scaling corrected for extrapolation
float theta_interp = freq_scale * theta_extrap;
float theta = theta_interp;
@ -45,8 +51,16 @@ void rope_yarn(float theta_extrap, float freq_scale, float corr_dims[2], int64_t
*sin_theta = sinf(theta) * mscale;
}
void rope_cache_init(float theta_base, float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0,
float ext_factor, float mscale, float * cache, float sin_sign, float theta_scale) {
void rope_cache_init(float theta_base,
float freq_scale,
const float * freq_factors,
float corr_dims[2],
int64_t ne0,
float ext_factor,
float mscale,
float * cache,
float sin_sign,
float theta_scale) {
// ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
float theta = theta_base;
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
@ -58,10 +72,21 @@ void rope_cache_init(float theta_base, float freq_scale, const float * freq_fact
}
}
void mrope_cache_init(float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e,
const int sections[4], bool indep_sects, float freq_scale, const float * freq_factors,
float corr_dims[2], int64_t ne0, float ext_factor, float mscale, float * cache, float sin_sign,
float theta_scale) {
void mrope_cache_init(float theta_base_t,
float theta_base_h,
float theta_base_w,
float theta_base_e,
const int sections[4],
bool indep_sects,
float freq_scale,
const float * freq_factors,
float corr_dims[2],
int64_t ne0,
float ext_factor,
float mscale,
float * cache,
float sin_sign,
float theta_scale) {
// ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
float theta_t = theta_base_t;
float theta_h = theta_base_h;
@ -181,16 +206,37 @@ bool rope_impl(hexagon::tensor * out, hexagon::compute_params * params) {
if constexpr (!_IsMrope) {
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(rope, 0, cache);
const int64_t p = pos[i2];
rope_cache_init(p, freq_scale, freq_factors, corr_dims, out->get_ne(0), ext_factor, attn_factor, cache,
sin_sign, theta_scale);
rope_cache_init(p,
freq_scale,
freq_factors,
corr_dims,
out->get_ne(0),
ext_factor,
attn_factor,
cache,
sin_sign,
theta_scale);
} else {
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(rope, 0, cache);
const int64_t p_t = pos[i2];
const int64_t p_h = pos[i2 + out->get_ne(2)];
const int64_t p_w = pos[i2 + out->get_ne(2) * 2];
const int64_t p_e = pos[i2 + out->get_ne(2) * 3];
mrope_cache_init(p_t, p_h, p_w, p_e, sections, _IsVision, freq_scale, freq_factors, corr_dims,
out->get_ne(0), ext_factor, attn_factor, cache, sin_sign, theta_scale);
mrope_cache_init(p_t,
p_h,
p_w,
p_e,
sections,
_IsVision,
freq_scale,
freq_factors,
corr_dims,
out->get_ne(0),
ext_factor,
attn_factor,
cache,
sin_sign,
theta_scale);
}
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(rope, 1, loop);
@ -316,8 +362,11 @@ bool rope_f32(tensor * out, compute_params * params) {
return kRopeImplFuncs[impl_index](out, params);
}
bool is_rope_supported(npu_device_tensor_op op, const npu_device_tensor_spec * dst, const npu_device_tensor_spec * srcs,
size_t src_len) {
bool is_rope_supported(const npu_device_tensor_op_spec * op_spec,
const npu_device_tensor_spec * dst,
const npu_device_tensor_spec * srcs,
size_t src_len) {
const auto op = op_spec->op;
if (op != NPU_OP_ROPE) {
DEVICE_LOG_DEBUG("[%s]op is not ROPE\n", op_get_name(op));
return false;
@ -336,8 +385,10 @@ bool is_rope_supported(npu_device_tensor_op op, const npu_device_tensor_spec * d
const auto & src0 = srcs[0];
if (src0.type != dst->type) {
DEVICE_LOG_DEBUG("[%s]src0 type is not the same as dst type: %s vs %s\n", op_get_name(op),
get_type_name(src0.type), get_type_name(dst->type));
DEVICE_LOG_DEBUG("[%s]src0 type is not the same as dst type: %s vs %s\n",
op_get_name(op),
get_type_name(src0.type),
get_type_name(dst->type));
return false; // unsupported src0 type
}

View File

@ -5,7 +5,9 @@
namespace hexagon {
bool rope_f32(tensor * out, compute_params * params);
bool is_rope_supported(npu_device_tensor_op op, const npu_device_tensor_spec * dst, const npu_device_tensor_spec * srcs,
size_t src_len);
bool is_rope_supported(const npu_device_tensor_op_spec * op_spec,
const npu_device_tensor_spec * dst,
const npu_device_tensor_spec * srcs,
size_t src_len);
} // namespace hexagon

View File

@ -1,5 +1,11 @@
#pragma once
#include "hexagon_npu.h"
#include "tensor.hpp"
#include "thread_pool.hpp"
#include "util.hpp"
#include "vec_ops.hpp"
#include <hexagon_types.h>
#include <algorithm>
@ -7,12 +13,6 @@
#include <memory>
#include <utility>
#include "hexagon_npu.h"
#include "tensor.hpp"
#include "thread_pool.hpp"
#include "util.hpp"
#include "vec_ops.hpp"
namespace hexagon {
inline constexpr std::pair<int64_t, int64_t> get_thread_work_slice(int64_t total, size_t tidx, size_t tcnt) {
@ -44,8 +44,6 @@ struct compute_params {
uint8_t * get_vtcm_cache(size_t size) { return thread_params->get_vtcm_cache(size); }
uint8_t * get_mem_cache(size_t size) { return thread_params->get_mem_cache(size); }
std::pair<int64_t, int64_t> get_work_slice(int64_t total) const {
return get_thread_work_slice(total, thread_params->tidx, thread_params->tcnt);
}
@ -58,7 +56,9 @@ struct compute_params {
};
typedef bool (*compute_func_type)(tensor * dst, compute_params * params);
typedef bool (*op_is_supported_func_type)(npu_device_tensor_op op, const npu_device_tensor_spec * dst,
const npu_device_tensor_spec * srcs, size_t src_len);
typedef bool (*op_is_supported_func_type)(const npu_device_tensor_op_spec * op_spec,
const npu_device_tensor_spec * dst,
const npu_device_tensor_spec * srcs,
size_t src_len);
} // namespace hexagon

View File

@ -1,13 +1,13 @@
#pragma once
#include "hexagon_npu.h"
#include "util.hpp"
#include <HAP_mem.h>
#include <qurt.h>
#include <atomic>
#include "hexagon_npu.h"
#include "util.hpp"
namespace hexagon {
constexpr const size_t kMaxTensorSrc = DEVICE_TENSOR_MAX_SRC;
@ -26,8 +26,15 @@ class tensor {
_data = static_cast<uint8_t *>(mmap_address);
DEVICE_LOG_INFO("tensor(%p[%ldx%ldx%ldx%ld]), fd: %d, offset: %zu, mmap_addr: %p, phy_addr: 0x%lx\n",
(void *) this, (long) _info.ne[0], (long) _info.ne[1], (long) _info.ne[2], (long) _info.ne[3],
_info.buffer_fd, _info.offset, (void *) mmap_address, phy_address);
(void *) this,
(long) _info.ne[0],
(long) _info.ne[1],
(long) _info.ne[2],
(long) _info.ne[3],
_info.buffer_fd,
_info.offset,
(void *) mmap_address,
phy_address);
}
~tensor() noexcept {
@ -41,15 +48,17 @@ class tensor {
void flush() const {
if (_data) {
qurt_mem_cache_clean((qurt_addr_t) (_data + _info.offset), (qurt_size_t) _info.size, QURT_MEM_CACHE_FLUSH,
QURT_MEM_DCACHE);
qurt_mem_cache_clean(
(qurt_addr_t) (_data + _info.offset), (qurt_size_t) _info.size, QURT_MEM_CACHE_FLUSH, QURT_MEM_DCACHE);
}
}
void invalidate() const {
if (_data) {
qurt_mem_cache_clean((qurt_addr_t) (_data + _info.offset), (qurt_size_t) _info.size,
QURT_MEM_CACHE_FLUSH_INVALIDATE_ALL, QURT_MEM_DCACHE);
qurt_mem_cache_clean((qurt_addr_t) (_data + _info.offset),
(qurt_size_t) _info.size,
QURT_MEM_CACHE_FLUSH_INVALIDATE_ALL,
QURT_MEM_DCACHE);
}
}

View File

@ -1,5 +1,8 @@
#pragma once
#include "util.hpp"
#include "vtcm_mem.hpp"
#include <qurt.h>
#include <array>
@ -8,27 +11,26 @@
#include <memory>
#include <string>
#include "util.hpp"
#include "vtcm_mem.hpp"
namespace hexagon {
constexpr const size_t kMaxThreadCount = 4;
constexpr const size_t kDefaultStackSize = 1024 * 32; // 32KB
constexpr const unsigned long long kThreadTaskPendingBit = 1;
constexpr const size_t kMaxThreadCount = 4;
constexpr const size_t kDefaultStackSize = 1024 * 64; // 64KB
template <size_t _stack_size> class qurt_thread {
public:
typedef void (*qurt_thread_func_type)(qurt_thread * thread, void * arg);
explicit qurt_thread(const std::string & thread_name, qurt_thread_func_type thread_func, void * arg,
unsigned short priority) {
explicit qurt_thread(const std::string & thread_name,
qurt_thread_func_type thread_func,
void * arg,
unsigned short priority) {
DEVICE_LOG_DEBUG("qurt_thread.create: %s", thread_name.c_str());
qurt_thread_attr_init(&_attributes);
qurt_thread_attr_set_name(&_attributes, (char *) thread_name.c_str());
qurt_thread_attr_set_stack_addr(&_attributes, _stack);
qurt_thread_attr_set_stack_size(&_attributes, _stack_size);
qurt_thread_attr_set_priority(&_attributes, priority);
qurt_thread_attr_set_bus_priority(&_attributes, QURT_THREAD_BUS_PRIO_ENABLED);
_func = thread_func;
_arg = arg;
@ -94,9 +96,9 @@ template <size_t _ThreadCount> class thread_pool {
thread_pool<kMaxThreadCount> * pool = nullptr;
size_t vtcm_quota_size;
std::unique_ptr<vtcm_mem> vtcm_cache;
std::unique_ptr<uint8_t[]> mem_cache;
size_t mem_cache_size = 0;
std::unique_ptr<vtcm_mem> vtcm_cache;
void init_vtcm_cache() { vtcm_cache = std::make_unique<vtcm_mem>(vtcm_quota_size, false); }
uint8_t * get_vtcm_cache(size_t size) {
if (!vtcm_cache || vtcm_cache->get_size() < size) {
@ -111,25 +113,18 @@ template <size_t _ThreadCount> class thread_pool {
return vtcm_cache->get_mem();
}
uint8_t * get_mem_cache(size_t size) {
if (!mem_cache || mem_cache_size < size) {
mem_cache.reset(); // reset the cache to create a new one
mem_cache = std::make_unique<uint8_t[]>(size + 256);
mem_cache_size = mem_cache ? size : 0;
}
return mem_cache.get();
}
};
typedef void (*task_type)(thread_pool * pool, thread_params * param, void * arg);
thread_pool() {
for (size_t i = 0; i < kMaxThreadCount; ++i) {
_thread_params[i].tidx = i;
_thread_params[i].vtcm_quota_size = hexagon::vtcm_mem::get_avail_block_size() / kMaxThreadCount;
_thread_params[i].pool = this;
auto & thread_param = _thread_params[i];
thread_param.tidx = i;
thread_param.vtcm_quota_size = hexagon::vtcm_mem::get_avail_block_size() / kMaxThreadCount;
thread_param.pool = this;
thread_param.init_vtcm_cache();
}
qurt_barrier_init(&_pending, kMaxSubThreadCount + 1);
@ -215,7 +210,8 @@ template <size_t _ThreadCount> class thread_pool {
#ifdef GGML_HEXAGON_ENABLE_PERFORMANCE_TRACKING
auto task_begin_cycles = pool._task_begin_cycles.load();
DEVICE_LOG_WARN("[profiler]worker_thread, tidx: %zu, prepare: %lluus", param->tidx,
DEVICE_LOG_WARN("[profiler]worker_thread, tidx: %zu, prepare: %lluus",
param->tidx,
static_cast<unsigned long long>(
HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - task_begin_cycles)));
#endif
@ -229,7 +225,8 @@ template <size_t _ThreadCount> class thread_pool {
qurt_barrier_wait(&pool._completed);
#ifdef GGML_HEXAGON_ENABLE_PERFORMANCE_TRACKING
DEVICE_LOG_WARN("[profiler]worker_thread, tidx: %zu, task_end: %lluus", param->tidx,
DEVICE_LOG_WARN("[profiler]worker_thread, tidx: %zu, task_end: %lluus",
param->tidx,
static_cast<unsigned long long>(
HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - task_begin_cycles)));
#endif

View File

@ -1,12 +1,12 @@
#include "type_traits.hpp"
#include "op_types.hpp" // TODO: remove this include
#include "vec_ops.hpp"
#include <hexagon_types.h>
#include <array>
#include "op_types.hpp" // TODO: remove this include
#include "vec_ops.hpp"
static_assert(sizeof(npu_device_block_q4_k) ==
2 * sizeof(npu_device_fp16_t) + QUANT_K_SCALE_SIZE + QUANT_K_BLOCK_SIZE / 2,
"wrong q4_K block size/padding");
@ -82,8 +82,17 @@ inline int nearest_int(float fval) {
return (i & 0x007fffff) - 0x00400000;
}
float make_qkx2_quants(int n, int nmax, const float * x, const float * weights, uint8_t * L, float * the_min,
uint8_t * Laux, float rmin, float rdelta, int nstep, bool use_mad) {
float make_qkx2_quants(int n,
int nmax,
const float * x,
const float * weights,
uint8_t * L,
float * the_min,
uint8_t * Laux,
float rmin,
float rdelta,
int nstep,
bool use_mad) {
float min = x[0];
float max = x[0];
float sum_w = weights[0];
@ -315,13 +324,13 @@ void quantize_row_q4_K(const float * src, void * dst, size_t count) {
}
}
void dequantize_row_q8_0(const void * src, hexagon::dequant_target_type * dst, size_t count) {
void dequantize_row_q8_0(const void * src, hexagon::dequant_output_type * dst, size_t count) {
constexpr const int qk = QUANT_BLOCK_SIZE;
static_assert(QUANT_BLOCK_SIZE == hexagon::kBytesPerVector / sizeof(float));
const int nb = count / qk;
const auto * src_ptr = reinterpret_cast<const npu_device_block_q8_0 *>(src);
auto * dst_ptr = ((hexagon::dequant_target_type *) dst); // TODO: opt for aligned access
auto * dst_ptr = ((hexagon::dequant_output_type *) dst); // TODO: opt for aligned access
int i = 0;
for (; i + 1 < nb; i += 2) {
@ -354,7 +363,7 @@ void dequantize_row_q8_0(const void * src, hexagon::dequant_target_type * dst, s
}
template <bool _IsDstAligned>
void dequantize_row_q4_0_impl(const void * src, hexagon::dequant_target_type * dst, size_t count) {
void dequantize_row_q4_0_impl(const void * src, hexagon::dequant_output_type * dst, size_t count) {
constexpr const int qk = QUANT_BLOCK_SIZE;
static_assert(qk % 2 == 0, "qk must be even");
static_assert(QUANT_BLOCK_SIZE == hexagon::kBytesPerVector / sizeof(float));
@ -364,7 +373,7 @@ void dequantize_row_q4_0_impl(const void * src, hexagon::dequant_target_type * d
const auto * src_ptr = reinterpret_cast<const npu_device_block_q4_0 *>(src);
const HVX_Vector mask = Q6_Vb_vsplat_R(0x0F);
const HVX_Vector minus = Q6_Vb_vsplat_R(8);
hexagon::dequant_target_type * dst_ptr = dst; // TODO: opt for aligned access
hexagon::dequant_output_type * dst_ptr = dst; // TODO: opt for aligned access
int i = 0;
for (; i + 3 < nb; i += 4) {
@ -402,7 +411,7 @@ void dequantize_row_q4_0_impl(const void * src, hexagon::dequant_target_type * d
reinterpret_cast<HVX_UVector *>(dst_ptr)[1] = q_hi;
}
dst_ptr += hexagon::kBytesPerVector / sizeof(hexagon::dequant_target_type) * 2;
dst_ptr += hexagon::kBytesPerVector / sizeof(hexagon::dequant_output_type) * 2;
}
for (; i + 1 < nb; i += 2) {
@ -428,7 +437,7 @@ void dequantize_row_q4_0_impl(const void * src, hexagon::dequant_target_type * d
*reinterpret_cast<HVX_UVector *>(dst_ptr) = q_lo;
}
dst_ptr += hexagon::kBytesPerVector / sizeof(hexagon::dequant_target_type);
dst_ptr += hexagon::kBytesPerVector / sizeof(hexagon::dequant_output_type);
}
if (i < nb) {
@ -453,7 +462,7 @@ void dequantize_row_q4_0_impl(const void * src, hexagon::dequant_target_type * d
}
}
void dequantize_row_q4_0(const void * src, hexagon::dequant_target_type * dst, size_t count) {
void dequantize_row_q4_0(const void * src, hexagon::dequant_output_type * dst, size_t count) {
const bool dst_aligned = hexagon::is_addr_aligned(dst);
if (dst_aligned) {
dequantize_row_q4_0_impl<true>(src, dst, count);
@ -462,7 +471,7 @@ void dequantize_row_q4_0(const void * src, hexagon::dequant_target_type * dst, s
}
}
void dequantize_row_q4_K(const void * src, hexagon::dequant_target_type * dst, size_t count) {
void dequantize_row_q4_K(const void * src, hexagon::dequant_output_type * dst, size_t count) {
const int nb = count / QUANT_K_BLOCK_SIZE;
const auto * src_ptr = reinterpret_cast<const npu_device_block_q4_k *>(src);
auto * dst_ptr = reinterpret_cast<__fp16 *>(dst);
@ -497,29 +506,44 @@ void dequantize_row_q4_K(const void * src, hexagon::dequant_target_type * dst, s
}
}
void copy_row_f16(const void * src, hexagon::dequant_target_type * dst, size_t count) {
void copy_row_f16(const void * src, hexagon::dequant_output_type * dst, size_t count) {
hexagon::vec_cpy_f16(reinterpret_cast<const npu_device_fp16_t *>(src), dst, count);
}
void copy_row_f32(const void * src, hexagon::dequant_target_type * dst, size_t count) {
void copy_row_f32(const void * src, hexagon::dequant_output_type * dst, size_t count) {
hexagon::vec_cpy_f32(reinterpret_cast<const float *>(src), reinterpret_cast<float *>(dst), count);
}
constexpr const hexagon::device_type_traits kDeviceTypeTraits[] = {
{ NPU_DATA_TYPE_F32, "F32", 1, sizeof(float), false, copy_row_f32, nullptr,
hexagon::type_erase_dot_func<hexagon::vec_dot_product_f32_f32>,
{ NPU_DATA_TYPE_F32,
"F32", 1,
sizeof(float),
false, copy_row_f32,
nullptr, hexagon::type_erase_dot_func<hexagon::vec_dot_product_f32_f32>,
hexagon::type_erase_dot_func<hexagon::vec_dot_product_aligned_f32_f32>,
hexagon::type_erase_dot_func<hexagon::is_f32_f32_dot_product_aligned> },
{ NPU_DATA_TYPE_F16, "F16", 1, sizeof(npu_device_fp16_t), false, copy_row_f16, quantize_row_fp16,
hexagon::type_erase_dot_func<hexagon::vec_dot_product_f16_f16>,
{ NPU_DATA_TYPE_F16,
"F16", 1,
sizeof(npu_device_fp16_t),
false, copy_row_f16,
quantize_row_fp16, hexagon::type_erase_dot_func<hexagon::vec_dot_product_f16_f16>,
hexagon::type_erase_dot_func<hexagon::vec_dot_product_aligned_f16_f16>,
hexagon::type_erase_dot_func<hexagon::is_f16_f16_dot_product_aligned> },
{ NPU_DATA_TYPE_I32, "I32", 1, sizeof(int32_t), false },
{ NPU_DATA_TYPE_Q8_0, "Q8_0", QUANT_BLOCK_SIZE, sizeof(npu_device_block_q8_0), true, dequantize_row_q8_0,
{ NPU_DATA_TYPE_Q8_0,
"Q8_0", QUANT_BLOCK_SIZE,
sizeof(npu_device_block_q8_0),
true, dequantize_row_q8_0,
quantize_row_q8_0 },
{ NPU_DATA_TYPE_Q4_0, "Q4_0", QUANT_BLOCK_SIZE, sizeof(npu_device_block_q4_0), true, dequantize_row_q4_0,
{ NPU_DATA_TYPE_Q4_0,
"Q4_0", QUANT_BLOCK_SIZE,
sizeof(npu_device_block_q4_0),
true, dequantize_row_q4_0,
quantize_row_q4_0 },
{ NPU_DATA_TYPE_Q4_K, "Q4_K", QUANT_K_BLOCK_SIZE, sizeof(npu_device_block_q4_k), true, dequantize_row_q4_K,
{ NPU_DATA_TYPE_Q4_K,
"Q4_K", QUANT_K_BLOCK_SIZE,
sizeof(npu_device_block_q4_k),
true, dequantize_row_q4_K,
quantize_row_q4_K },
};
@ -566,7 +590,7 @@ size_t get_dequantized_row_size(const tensor * tensor) {
auto row_elems_count = tensor->get_ne(0);
return hexagon::get_aligned_size(
row_elems_count * sizeof(dequant_target_type)); // dequant_target_type is currently restricted to f32
row_elems_count * sizeof(dequant_output_type)); // dequant_output_type is currently restricted to f32
}
} // namespace hexagon

View File

@ -5,12 +5,12 @@
namespace hexagon {
using dequant_target_type = npu_device_fp16_t;
using dequant_output_type = npu_device_fp16_t;
bool init_f16_f32_table(float * table, size_t count);
typedef void (*quantize_row_type)(const float * src, void * dst, size_t count);
typedef void (*dequantize_row_type)(const void * src, dequant_target_type * dst, size_t count);
typedef void (*dequantize_row_type)(const void * src, dequant_output_type * dst, size_t count);
typedef float (*vec_dot_type)(const void * src0, const void * src1, size_t count);
typedef bool (*can_use_aligned_vec_dot_type)(const void * src0, const void * src1, size_t count);
@ -51,14 +51,32 @@ inline auto make_scoped_op_perf_timer(tensor * op, size_t tidx) {
auto * src1 = op->get_src(1);
char buffer[1024];
if (src1 == nullptr) {
snprintf(buffer, sizeof(buffer), "[%s][%lldx%lldx%lldx%lld%s], tidx: %zu", op_get_name(op->get_op()),
src0->get_ne(0), src0->get_ne(1), src0->get_ne(2), src0->get_ne(3), get_type_name(src0->get_type()),
snprintf(buffer,
sizeof(buffer),
"[%s][%lldx%lldx%lldx%lld%s], tidx: %zu",
op_get_name(op->get_op()),
src0->get_ne(0),
src0->get_ne(1),
src0->get_ne(2),
src0->get_ne(3),
get_type_name(src0->get_type()),
tidx);
} else {
snprintf(buffer, sizeof(buffer), "[%s][%lldx%lldx%lldx%lld%s],[%lldx%lldx%lldx%lld%s], tidx: %zu",
op_get_name(op->get_op()), src0->get_ne(0), src0->get_ne(1), src0->get_ne(2), src0->get_ne(3),
get_type_name(src0->get_type()), src1->get_ne(0), src1->get_ne(1), src1->get_ne(2), src1->get_ne(3),
get_type_name(src1->get_type()), tidx);
snprintf(buffer,
sizeof(buffer),
"[%s][%lldx%lldx%lldx%lld%s],[%lldx%lldx%lldx%lld%s], tidx: %zu",
op_get_name(op->get_op()),
src0->get_ne(0),
src0->get_ne(1),
src0->get_ne(2),
src0->get_ne(3),
get_type_name(src0->get_type()),
src1->get_ne(0),
src1->get_ne(1),
src1->get_ne(2),
src1->get_ne(3),
get_type_name(src1->get_type()),
tidx);
}
return npu_scoped_timer<1024>(buffer);
}

View File

@ -1,5 +1,7 @@
#pragma once
#include "hexagon_npu.h"
#include <AEEStdDef.h>
#include <HAP_farf.h>
#include <HAP_perf.h>
@ -9,8 +11,6 @@
#include <cstring>
#include <utility>
#include "hexagon_npu.h"
#define DEVICE_LOG_ERROR(...) FARF(FATAL, __VA_ARGS__)
#define DEVICE_LOG_WARN(...) FARF(ERROR, __VA_ARGS__)
#define DEVICE_LOG_INFO(...) FARF(HIGH, __VA_ARGS__)
@ -56,6 +56,8 @@ inline constexpr const char * op_get_name(npu_device_tensor_op op) {
return "FLASH_ATTN_EXT";
case NPU_OP_ROPE:
return "ROPE";
case NPU_OP_GLU:
return "GLU";
default:
return "UNKNOWN";
}
@ -252,44 +254,68 @@ template <size_t _buffer_count> class npu_scoped_timer {
"[profiler]%s, pcyc: %llu, dur: %lluus, [%s]cnt: %llu, dur: %lluus, "
"[%s]cnt: %llu, dur: %lluus, [%s]cnt: %llu, dur: %lluus, "
"[%s]cnt: %llu, dur: %lluus\n",
_log_prefix, (unsigned long long) total_pcycles, (unsigned long long) duration,
_sub_proc_data[0].log_prefix, (unsigned long long) _sub_proc_data[0].proc_count,
(unsigned long long) sub_proc0_duration, _sub_proc_data[1].log_prefix,
(unsigned long long) _sub_proc_data[1].proc_count, (unsigned long long) sub_proc1_duration,
_sub_proc_data[2].log_prefix, (unsigned long long) _sub_proc_data[2].proc_count,
(unsigned long long) sub_proc2_duration, _sub_proc_data[3].log_prefix,
(unsigned long long) _sub_proc_data[3].proc_count, (unsigned long long) sub_proc3_duration);
_log_prefix,
(unsigned long long) total_pcycles,
(unsigned long long) duration,
_sub_proc_data[0].log_prefix,
(unsigned long long) _sub_proc_data[0].proc_count,
(unsigned long long) sub_proc0_duration,
_sub_proc_data[1].log_prefix,
(unsigned long long) _sub_proc_data[1].proc_count,
(unsigned long long) sub_proc1_duration,
_sub_proc_data[2].log_prefix,
(unsigned long long) _sub_proc_data[2].proc_count,
(unsigned long long) sub_proc2_duration,
_sub_proc_data[3].log_prefix,
(unsigned long long) _sub_proc_data[3].proc_count,
(unsigned long long) sub_proc3_duration);
break;
case 3:
DEVICE_LOG_WARN(
"[profiler]%s, pcyc: %llu, dur: %lluus, [%s]cnt: %llu, dur: %lluus, "
"[%s]cnt: %llu, dur: %lluus, [%s]cnt: %llu, dur: %lluus\n",
_log_prefix, (unsigned long long) total_pcycles, (unsigned long long) duration,
_sub_proc_data[0].log_prefix, (unsigned long long) _sub_proc_data[0].proc_count,
(unsigned long long) sub_proc0_duration, _sub_proc_data[1].log_prefix,
(unsigned long long) _sub_proc_data[1].proc_count, (unsigned long long) sub_proc1_duration,
_sub_proc_data[2].log_prefix, (unsigned long long) _sub_proc_data[2].proc_count,
_log_prefix,
(unsigned long long) total_pcycles,
(unsigned long long) duration,
_sub_proc_data[0].log_prefix,
(unsigned long long) _sub_proc_data[0].proc_count,
(unsigned long long) sub_proc0_duration,
_sub_proc_data[1].log_prefix,
(unsigned long long) _sub_proc_data[1].proc_count,
(unsigned long long) sub_proc1_duration,
_sub_proc_data[2].log_prefix,
(unsigned long long) _sub_proc_data[2].proc_count,
(unsigned long long) sub_proc2_duration);
break;
case 2:
DEVICE_LOG_WARN(
"[profiler]%s, pcyc: %llu, dur: %lluus, [%s]cnt: %llu, dur: %lluus, "
"[%s]cnt: %llu, dur: %lluus\n",
_log_prefix, (unsigned long long) total_pcycles, (unsigned long long) duration,
_sub_proc_data[0].log_prefix, (unsigned long long) _sub_proc_data[0].proc_count,
(unsigned long long) sub_proc0_duration, _sub_proc_data[1].log_prefix,
(unsigned long long) _sub_proc_data[1].proc_count, (unsigned long long) sub_proc1_duration);
_log_prefix,
(unsigned long long) total_pcycles,
(unsigned long long) duration,
_sub_proc_data[0].log_prefix,
(unsigned long long) _sub_proc_data[0].proc_count,
(unsigned long long) sub_proc0_duration,
_sub_proc_data[1].log_prefix,
(unsigned long long) _sub_proc_data[1].proc_count,
(unsigned long long) sub_proc1_duration);
break;
case 1:
DEVICE_LOG_WARN("[profiler]%s, pcyc: %llu, dur: %lluus, [%s]cnt: %llu, dur: %lluus\n", _log_prefix,
(unsigned long long) total_pcycles, (unsigned long long) duration,
_sub_proc_data[0].log_prefix, (unsigned long long) _sub_proc_data[0].proc_count,
DEVICE_LOG_WARN("[profiler]%s, pcyc: %llu, dur: %lluus, [%s]cnt: %llu, dur: %lluus\n",
_log_prefix,
(unsigned long long) total_pcycles,
(unsigned long long) duration,
_sub_proc_data[0].log_prefix,
(unsigned long long) _sub_proc_data[0].proc_count,
(unsigned long long) sub_proc0_duration);
break;
default:
case 0:
DEVICE_LOG_WARN("[profiler]%s, pcyc: %llu, dur: %lluus\n", _log_prefix,
(unsigned long long) total_pcycles, (unsigned long long) duration);
DEVICE_LOG_WARN("[profiler]%s, pcyc: %llu, dur: %lluus\n",
_log_prefix,
(unsigned long long) total_pcycles,
(unsigned long long) duration);
break;
}
}
@ -317,8 +343,8 @@ template <size_t _buffer_count, size_t _sub_idx> class npu_sub_process_scoped_ti
}
~npu_sub_process_scoped_timer() {
_timer.add_sub_proc_cycles(_sub_idx, _prefix, HAP_perf_get_qtimer_count() - _begin_cycles,
HAP_perf_get_pcycles() - _begin_pcycles);
_timer.add_sub_proc_cycles(
_sub_idx, _prefix, HAP_perf_get_qtimer_count() - _begin_cycles, HAP_perf_get_pcycles() - _begin_pcycles);
}
private:

File diff suppressed because it is too large Load Diff

View File

@ -1,13 +1,29 @@
#pragma once
#include "hexagon_npu.h"
#include <hexagon_types.h>
#include <cstdint>
#include "hexagon_npu.h"
namespace hexagon {
template <typename T, int N> struct HEXAGON_pack {
T val[N];
};
using HVX_Vector_x2 = std::pair<HVX_Vector, HVX_Vector>;
using HVX_VectorPair_x4 = HEXAGON_pack<HVX_VectorPair, 4>;
typedef union {
HVX_VectorPair VV;
struct {
HVX_Vector lo;
HVX_Vector hi;
} V;
} HVX_DV;
constexpr const size_t kBytesPerVector = sizeof(HVX_Vector); // 128 for v73
constexpr const size_t kAlignMask = kBytesPerVector - 1;
@ -63,67 +79,6 @@ inline void l2fetch_row(const uint8_t * row_ptr, size_t bytes) {
hexagon::l2fetch(row_ptr, kBytesPerVector, kBytesPerVector, l2fetch_vectors, 0);
}
/*
* This function converts a vector of IEEE float elements to a vector of qf32 elements
* See also: libs\qfe\inc\qhmath_hvx_convert.h
*/
inline HVX_Vector qhmath_hvx_vqf32_convert_vsf(HVX_Vector vin) {
return Q6_Vqf32_vadd_VsfVsf(vin, Q6_V_vzero());
}
/*
* This function converts a vector of IEEE half float elements to a vector of qf16 elements
* See also: libs\qfe\inc\qhmath_hvx_convert.h
*/
inline HVX_Vector qhmath_hvx_vqf16_convert_vhf(HVX_Vector vin) {
return Q6_Vqf16_vadd_VhfVhf(vin, Q6_V_vzero());
}
/*
* This function converts a pair of vectors of qf32 elements to a vector of IEEE half float elements
* See also: libs\qfe\inc\qhmath_hvx_convert.h
*/
inline HVX_Vector qhmath_hvx_vhf_convert_vqf32(HVX_VectorPair vin_vp) {
return Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(vin_vp));
}
/*
* This function converts a vector of qf16 elements to a pair of vectors of qf32 elements
* See also: libs\qfe\inc\qhmath_hvx_convert.h
*/
inline HVX_VectorPair qhmath_hvx_vqf32_convert_vqf16(HVX_Vector vxl) {
HVX_VectorPair vxw_vp, exponent_vp;
HVX_Vector mantissa_mask = Q6_Vh_vsplat_R(0xffe0);
HVX_Vector exp_mask = Q6_Vh_vsplat_R(0x1f);
HVX_Vector exp_offset = Q6_Vh_vsplat_R(0x70);
HVX_Vector mant32_shift = Q6_Vh_vsplat_R(0x10);
HVX_Vector reql, reqh, vxl_w, vxh_w, mantissa;
HVX_Vector el_exponent, eh_exponent;
el_exponent = Q6_V_vand_VV(exp_mask, vxl);
// Obtain the mantissa part: bits (5-15)
mantissa = Q6_V_vand_VV(mantissa_mask, vxl);
// Convert qf16 biassed exponent to qf32 biased exponent
// new exp = exp + ( 127 (qf32 bias) -15(qf16 biass) ) = 112
el_exponent = Q6_Vh_vadd_VhVh(exp_offset, el_exponent);
vxw_vp = Q6_Ww_vunpack_Vh(mantissa);
vxl_w = Q6_V_lo_W(vxw_vp);
vxh_w = Q6_V_hi_W(vxw_vp);
exponent_vp = Q6_Ww_vunpack_Vh(el_exponent);
el_exponent = Q6_V_lo_W(exponent_vp);
eh_exponent = Q6_V_hi_W(exponent_vp);
// Convert q16 mantiss to q32 mantissa
reql = Q6_Vw_vasl_VwVw(vxl_w, mant32_shift);
reqh = Q6_Vw_vasl_VwVw(vxh_w, mant32_shift);
// Add the exponent
vxl_w = Q6_Vw_vadd_VwVw(reql, el_exponent);
vxh_w = Q6_Vw_vadd_VwVw(reqh, eh_exponent);
return Q6_W_vcombine_VV(vxh_w, vxl_w);
}
template <uint32_t _TyBytes> inline void q6op_vstu_variable_ARV(void * addr, HVX_Vector vin) {
vin = Q6_V_vlalign_VVR(vin, vin, (size_t) addr); //rotate as needed.
uint32_t left_off = unaligned_bytes(addr);
@ -157,20 +112,6 @@ inline void q6op_vstu_variable_ARV(void * addr, int n, HVX_Vector vin) {
Q6_vmaskedstorenq_QAV(qL_not, (HVX_Vector *) addr, vin);
}
inline HVX_VectorPair hvx_vqf32_convert_vhf(HVX_Vector vxl) {
return qhmath_hvx_vqf32_convert_vqf16(qhmath_hvx_vqf16_convert_vhf(vxl));
}
using HVX_Vector_Dual = std::pair<HVX_Vector, HVX_Vector>;
inline HVX_Vector_Dual hvx_vsf_convert_vhf(HVX_Vector vxl, HVX_Vector one) {
HVX_VectorPair res = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vxl), one);
return {
Q6_Vsf_equals_Vqf32(Q6_V_lo_W(res)),
Q6_Vsf_equals_Vqf32(Q6_V_hi_W(res)),
};
}
inline HVX_Vector vec_reduction_qf32(HVX_Vector sums) {
constexpr const size_t kFloatsPerVector = hexagon::kBytesPerVector / sizeof(float);
static_assert(kFloatsPerVector == 32, "kFloatsPerVector should be 32");
@ -247,6 +188,7 @@ inline HVX_Vector hvx_passthru(HVX_Vector src, HVX_UVector *, HVX_Vector) {
} // namespace hexagon
#include "vec_math.inl"
#include "vec_ops.inl"
namespace hexagon {
@ -313,8 +255,8 @@ inline HVX_Vector vec_dot_product_vqf32_f32_f32(const float * src0, const float
inline HVX_Vector vec_dot_product_aligned_vqf32_f32_f32(const float * src0, const float * src1, size_t count) {
using namespace hexagon::vec;
return vec_dot_product_aligned_impl<float, HVX_Vector, vec_mpy_qf32, vec_add_qf32, vec_reduction_qf32>(src0, src1,
count);
return vec_dot_product_aligned_impl<float, HVX_Vector, vec_mpy_qf32, vec_add_qf32, vec_reduction_qf32>(
src0, src1, count);
}
inline float vec_dot_product_f32_f32(const float * src0, const float * src1, size_t count) {
@ -324,23 +266,25 @@ inline float vec_dot_product_f32_f32(const float * src0, const float * src1, siz
inline float vec_dot_product_aligned_f32_f32(const float * src0, const float * src1, size_t count) {
using namespace hexagon::vec;
return vec_dot_product_aligned_impl<float, float, vec_mpy_qf32, vec_add_qf32, vec_reduction_f32_qf32>(src0, src1,
count);
return vec_dot_product_aligned_impl<float, float, vec_mpy_qf32, vec_add_qf32, vec_reduction_f32_qf32>(
src0, src1, count);
}
inline bool is_f32_f32_dot_product_aligned(const float * src0, const float * src1, size_t count) {
return is_dot_product_aligned<float, float>(src0, src1, count);
}
inline HVX_Vector vec_dot_product_vqf16_f16_f16(const npu_device_fp16_t * src0, const npu_device_fp16_t * src1,
size_t count) {
inline HVX_Vector vec_dot_product_vqf16_f16_f16(const npu_device_fp16_t * src0,
const npu_device_fp16_t * src1,
size_t count) {
using namespace hexagon::vec;
return vec_dot_product_impl<npu_device_fp16_t, HVX_Vector, vec_mpy_qf16, vec_add_qf16, vec_reduction_qf16>(
src0, src1, count);
}
inline HVX_Vector vec_dot_product_aligned_vqf16_f16_f16(const npu_device_fp16_t * src0, const npu_device_fp16_t * src1,
size_t count) {
inline HVX_Vector vec_dot_product_aligned_vqf16_f16_f16(const npu_device_fp16_t * src0,
const npu_device_fp16_t * src1,
size_t count) {
using namespace hexagon::vec;
return vec_dot_product_aligned_impl<npu_device_fp16_t, HVX_Vector, vec_mpy_qf16, vec_add_qf16, vec_reduction_qf16>(
src0, src1, count);
@ -352,41 +296,68 @@ inline float vec_dot_product_f16_f16(const npu_device_fp16_t * src0, const npu_d
src0, src1, count);
}
inline float vec_dot_product_aligned_f16_f16(const npu_device_fp16_t * src0, const npu_device_fp16_t * src1,
size_t count) {
inline float vec_dot_product_aligned_f16_f16(const npu_device_fp16_t * src0,
const npu_device_fp16_t * src1,
size_t count) {
using namespace hexagon::vec;
return vec_dot_product_aligned_impl<npu_device_fp16_t, float, vec_mpy_qf16, vec_add_qf16, vec_reduction_qf16_f32>(
src0, src1, count);
}
inline bool is_f16_f16_dot_product_aligned(const npu_device_fp16_t * src0, const npu_device_fp16_t * src1,
size_t count) {
inline bool is_f16_f16_dot_product_aligned(const npu_device_fp16_t * src0,
const npu_device_fp16_t * src1,
size_t count) {
return is_dot_product_aligned<npu_device_fp16_t, npu_device_fp16_t>(src0, src1, count);
}
inline HVX_Vector vec_dot_product_vqf32_f16_f32(const npu_device_fp16_t * src0, const float * src1, size_t count) {
using namespace hexagon::vec;
return vec_dot_product_mixed_impl<npu_device_fp16_t, float, HVX_Vector, hvx_vsf_convert_vhf, vec_mpy_qf32,
vec_add_qf32, vec_reduction_qf32>(src0, src1, count);
using namespace hexagon::vec::math;
return vec_dot_product_mixed_impl<npu_device_fp16_t,
float,
HVX_Vector,
hvx_vsf_convert_vhf,
vec_mpy_qf32,
vec_add_qf32,
vec_reduction_qf32>(src0, src1, count);
}
inline HVX_Vector vec_dot_product_aligned_vqf32_f16_f32(const npu_device_fp16_t * src0, const float * src1,
size_t count) {
inline HVX_Vector vec_dot_product_aligned_vqf32_f16_f32(const npu_device_fp16_t * src0,
const float * src1,
size_t count) {
using namespace hexagon::vec;
return vec_dot_product_mix_aligned_impl<npu_device_fp16_t, float, HVX_Vector, hvx_vsf_convert_vhf, vec_mpy_qf32,
vec_add_qf32, vec_reduction_qf32>(src0, src1, count);
using namespace hexagon::vec::math;
return vec_dot_product_mix_aligned_impl<npu_device_fp16_t,
float,
HVX_Vector,
hvx_vsf_convert_vhf,
vec_mpy_qf32,
vec_add_qf32,
vec_reduction_qf32>(src0, src1, count);
}
inline float vec_dot_product_f16_f32(const npu_device_fp16_t * src0, const float * src1, size_t count) {
using namespace hexagon::vec;
return vec_dot_product_mixed_impl<npu_device_fp16_t, float, float, hvx_vsf_convert_vhf, vec_mpy_qf32, vec_add_qf32,
using namespace hexagon::vec::math;
return vec_dot_product_mixed_impl<npu_device_fp16_t,
float,
float,
hvx_vsf_convert_vhf,
vec_mpy_qf32,
vec_add_qf32,
vec_reduction_f32_qf32>(src0, src1, count);
}
inline float vec_dot_product_aligned_f16_f32(const npu_device_fp16_t * src0, const float * src1, size_t count) {
using namespace hexagon::vec;
return vec_dot_product_mix_aligned_impl<npu_device_fp16_t, float, float, hvx_vsf_convert_vhf, vec_mpy_qf32,
vec_add_qf32, vec_reduction_f32_qf32>(src0, src1, count);
using namespace hexagon::vec::math;
return vec_dot_product_mix_aligned_impl<npu_device_fp16_t,
float,
float,
hvx_vsf_convert_vhf,
vec_mpy_qf32,
vec_add_qf32,
vec_reduction_f32_qf32>(src0, src1, count);
}
inline bool is_f16_f32_dot_product_aligned(const npu_device_fp16_t * src0, const float * src1, size_t count) {
@ -409,4 +380,35 @@ _TReturn type_erase_dot_func(const void * src0, const void * src1, size_t count)
return _DotFunc(src0_typed, src1_typed, count);
}
inline HVX_Vector vec_silu_f32_f32(HVX_Vector x, HVX_VectorPair_x4 coeff) {
using namespace hexagon::vec::math;
HVX_Vector one = Q6_V_vsplat_R(0x3F800000);
// x/(1.0f + expf(-x));
HVX_Vector exp_neg_x = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(Q6_V_vzero(), x));
HVX_Vector denom = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(qhmath_hvx_exp_vf(exp_neg_x), one));
return qhmath_hvx_div_vf(x, denom, coeff);
}
inline HVX_Vector vec_silu_f16_f16(HVX_Vector x, HVX_VectorPair_x4 coeff) {
using namespace hexagon::vec::math;
HVX_Vector one = Q6_Vh_vsplat_R(0x3c00);
// x/(1.0f + expf(-x));
HVX_Vector exp_neg_x = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vsub_VhfVhf(Q6_V_vzero(), x));
HVX_Vector denom = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_VhfVhf(qhmath_hvx_exp_vhf(exp_neg_x), one));
return qhmath_hvx_div_vhf(x, denom, coeff);
}
inline HVX_Vector vec_swiglu_f32_f32(HVX_Vector x, HVX_Vector g, HVX_VectorPair_x4 coeff) {
HVX_Vector silu = vec_silu_f32_f32(x, coeff);
return Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(silu, g));
}
inline HVX_Vector vec_swiglu_f16_f16(HVX_Vector x, HVX_Vector g, HVX_VectorPair_x4 coeff) {
HVX_Vector silu = vec_silu_f16_f16(x, coeff);
return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(silu, g));
}
} // namespace hexagon

View File

@ -1,15 +1,18 @@
#pragma once
#include "hexagon_npu.h"
#include <hexagon_types.h>
#include <cstdint>
#include "hexagon_npu.h"
namespace hexagon::vec {
template <typename _TElem, typename _TRet, HVX_Vector (*_MpyFunc)(HVX_Vector, HVX_Vector),
HVX_Vector (*_AddFunc)(HVX_Vector, HVX_Vector), _TRet (*_ReduceFunc)(HVX_Vector)>
template <typename _TElem,
typename _TRet,
HVX_Vector (*_MpyFunc)(HVX_Vector, HVX_Vector),
HVX_Vector (*_AddFunc)(HVX_Vector, HVX_Vector),
_TRet (*_ReduceFunc)(HVX_Vector)>
inline _TRet vec_dot_product_impl(const _TElem * src0, const _TElem * src1, size_t count) {
constexpr const size_t kElementsPerVector = hexagon::kBytesPerVector / sizeof(_TElem);
@ -95,8 +98,11 @@ inline _TRet vec_dot_product_impl(const _TElem * src0, const _TElem * src1, size
return _ReduceFunc(sum);
}
template <typename _TElem, typename _TRet, HVX_Vector (*_MpyFunc)(HVX_Vector, HVX_Vector),
HVX_Vector (*_AddFunc)(HVX_Vector, HVX_Vector), _TRet (*_ReduceFunc)(HVX_Vector)>
template <typename _TElem,
typename _TRet,
HVX_Vector (*_MpyFunc)(HVX_Vector, HVX_Vector),
HVX_Vector (*_AddFunc)(HVX_Vector, HVX_Vector),
_TRet (*_ReduceFunc)(HVX_Vector)>
inline _TRet vec_dot_product_aligned_impl(const _TElem * src0, const _TElem * src1, size_t count) {
constexpr const size_t kElementsPerVector = hexagon::kBytesPerVector / sizeof(_TElem);
@ -170,8 +176,12 @@ inline HVX_Vector vec_add_qf16(HVX_Vector sum, HVX_Vector result) {
return Q6_Vqf16_vadd_Vqf16Vqf16(sum, result);
}
template <typename _TElem0, typename _TElem1, typename _TRet, HVX_Vector_Dual (*_ExpandFunc)(HVX_Vector, HVX_Vector),
HVX_Vector (*_MpyFunc)(HVX_Vector, HVX_Vector), HVX_Vector (*_AddFunc)(HVX_Vector, HVX_Vector),
template <typename _TElem0,
typename _TElem1,
typename _TRet,
HVX_Vector_x2 (*_ExpandFunc)(HVX_Vector, HVX_Vector),
HVX_Vector (*_MpyFunc)(HVX_Vector, HVX_Vector),
HVX_Vector (*_AddFunc)(HVX_Vector, HVX_Vector),
_TRet (*_ReduceFunc)(HVX_Vector)>
inline _TRet vec_dot_product_mixed_impl(const _TElem0 * src0, const _TElem1 * src1, size_t count) {
static_assert(sizeof(_TElem0) < sizeof(_TElem1), "Element size mismatch: _TElem0 must be smaller than _TElem1");
@ -202,8 +212,8 @@ inline _TRet vec_dot_product_mixed_impl(const _TElem0 * src0, const _TElem1 * sr
HVX_Vector curr0 = src0_vec_ptr[0];
HVX_VectorPair curr1 = reinterpret_cast<HVX_VectorPair *>(src1_vec_ptr)[0];
HVX_Vector s0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0);
HVX_Vector_Dual s0_pair = _ExpandFunc(s0, kOneV);
HVX_Vector s0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0);
HVX_Vector_x2 s0_pair = _ExpandFunc(s0, kOneV);
HVX_Vector l1 = Q6_V_valign_VVR(Q6_V_lo_W(curr1), prev1, (size_t) src1);
sum0 = _AddFunc(_MpyFunc(s0_pair.first, l1), sum0);
@ -228,8 +238,8 @@ inline _TRet vec_dot_product_mixed_impl(const _TElem0 * src0, const _TElem1 * sr
HVX_Vector curr0 = should_fetch_src0 ? *src0_vec_ptr : prev0;
src0_vec_ptr += should_fetch_src0 ? 1 : 0;
HVX_Vector s0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0);
HVX_Vector_Dual s0_pair = _ExpandFunc(s0, kOneV);
HVX_Vector s0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0);
HVX_Vector_x2 s0_pair = _ExpandFunc(s0, kOneV);
const bool has_remaining_src1_vector = src1_vec_ptr_end - src1_vec_ptr > 0;
if (has_remaining_src1_vector) {
@ -262,7 +272,7 @@ inline _TRet vec_dot_product_mixed_impl(const _TElem0 * src0, const _TElem1 * sr
prev1;
curr1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1);
HVX_Vector_Dual curr0_pair = _ExpandFunc(curr0, kOneV);
HVX_Vector_x2 curr0_pair = _ExpandFunc(curr0, kOneV);
curr0 = leftover1 == leftover0 ? curr0_pair.first : curr0_pair.second;
sum = _AddFunc(Q6_V_valign_VVR(_MpyFunc(curr0, curr1), Q6_V_vzero(), leftover_bytes1), sum);
@ -271,8 +281,12 @@ inline _TRet vec_dot_product_mixed_impl(const _TElem0 * src0, const _TElem1 * sr
return _ReduceFunc(sum);
}
template <typename _TElem0, typename _TElem1, typename _TRet, HVX_Vector_Dual (*_ExpandFunc)(HVX_Vector, HVX_Vector),
HVX_Vector (*_MpyFunc)(HVX_Vector, HVX_Vector), HVX_Vector (*_AddFunc)(HVX_Vector, HVX_Vector),
template <typename _TElem0,
typename _TElem1,
typename _TRet,
HVX_Vector_x2 (*_ExpandFunc)(HVX_Vector, HVX_Vector),
HVX_Vector (*_MpyFunc)(HVX_Vector, HVX_Vector),
HVX_Vector (*_AddFunc)(HVX_Vector, HVX_Vector),
_TRet (*_ReduceFunc)(HVX_Vector)>
inline _TRet vec_dot_product_mix_aligned_impl(const _TElem0 * src0, const _TElem1 * src1, size_t count) {
static_assert(sizeof(_TElem0) < sizeof(_TElem1), "Element size mismatch: _TElem0 must be smaller than _TElem1");
@ -297,16 +311,16 @@ inline _TRet vec_dot_product_mix_aligned_impl(const _TElem0 * src0, const _TElem
HVX_Vector sum3 = Q6_V_vzero();
do {
HVX_VectorPair curr0 = reinterpret_cast<HVX_VectorPair *>(src0_vec_ptr)[0];
HVX_Vector_Dual curr00 = _ExpandFunc(Q6_V_lo_W(curr0), kOneV);
HVX_VectorPair curr10 = reinterpret_cast<HVX_VectorPair *>(src1_vec_ptr)[0];
sum0 = _AddFunc(_MpyFunc(curr00.first, Q6_V_lo_W(curr10)), sum0);
sum1 = _AddFunc(_MpyFunc(curr00.second, Q6_V_hi_W(curr10)), sum1);
HVX_VectorPair curr0 = reinterpret_cast<HVX_VectorPair *>(src0_vec_ptr)[0];
HVX_Vector_x2 curr00 = _ExpandFunc(Q6_V_lo_W(curr0), kOneV);
HVX_VectorPair curr10 = reinterpret_cast<HVX_VectorPair *>(src1_vec_ptr)[0];
sum0 = _AddFunc(_MpyFunc(curr00.first, Q6_V_lo_W(curr10)), sum0);
sum1 = _AddFunc(_MpyFunc(curr00.second, Q6_V_hi_W(curr10)), sum1);
HVX_Vector_Dual curr01 = _ExpandFunc(Q6_V_hi_W(curr0), kOneV);
HVX_VectorPair curr11 = reinterpret_cast<HVX_VectorPair *>(src1_vec_ptr)[1];
sum2 = _AddFunc(_MpyFunc(curr01.first, Q6_V_lo_W(curr11)), sum2);
sum3 = _AddFunc(_MpyFunc(curr01.second, Q6_V_hi_W(curr11)), sum3);
HVX_Vector_x2 curr01 = _ExpandFunc(Q6_V_hi_W(curr0), kOneV);
HVX_VectorPair curr11 = reinterpret_cast<HVX_VectorPair *>(src1_vec_ptr)[1];
sum2 = _AddFunc(_MpyFunc(curr01.first, Q6_V_lo_W(curr11)), sum2);
sum3 = _AddFunc(_MpyFunc(curr01.second, Q6_V_hi_W(curr11)), sum3);
src0_vec_ptr += 2;
src1_vec_ptr += 4;
@ -317,8 +331,8 @@ inline _TRet vec_dot_product_mix_aligned_impl(const _TElem0 * src0, const _TElem
}
if (src1_vec_ptr_end - src1_vec_ptr > 1) {
HVX_Vector curr0 = src0_vec_ptr[0];
HVX_Vector_Dual s0_pair = _ExpandFunc(curr0, kOneV);
HVX_Vector curr0 = src0_vec_ptr[0];
HVX_Vector_x2 s0_pair = _ExpandFunc(curr0, kOneV);
HVX_VectorPair curr1 = reinterpret_cast<HVX_VectorPair *>(src1_vec_ptr)[0];
sum0 = _AddFunc(_MpyFunc(s0_pair.first, Q6_V_lo_W(curr1)), sum0);
@ -328,7 +342,8 @@ inline _TRet vec_dot_product_mix_aligned_impl(const _TElem0 * src0, const _TElem
return _ReduceFunc(_AddFunc(sum0, sum1));
}
template <HVX_Vector (*_Func)(HVX_Vector, HVX_UVector *, HVX_Vector), HVX_Vector (*_FuncScaleConvert)(float),
template <HVX_Vector (*_Func)(HVX_Vector, HVX_UVector *, HVX_Vector),
HVX_Vector (*_FuncScaleConvert)(float),
typename _TParam>
inline void vec_scale_impl(const _TParam * src, float scale, _TParam * dst, size_t count) {
constexpr const size_t kElementsPerVector = hexagon::kBytesPerVector / sizeof(_TParam);
@ -410,7 +425,7 @@ template <typename _TData> inline void vec_zero_impl(_TData * src, size_t count)
}
template <HVX_Vector (*_OpBinaryTransform)(HVX_Vector, HVX_Vector), typename _TyData>
inline void vec_trans_op_impl(const _TyData * src0, const _TyData * src1, size_t count, _TyData * dst) {
inline void vec_trans_impl(const _TyData * src0, const _TyData * src1, _TyData * dst, size_t count) {
constexpr const size_t kElementsPerVector = hexagon::kBytesPerVector / sizeof(_TyData);
HVX_Vector * src0_vec_ptr = ((HVX_Vector *) src0);
@ -496,4 +511,95 @@ inline void vec_trans_op_impl(const _TyData * src0, const _TyData * src1, size_t
}
}
template <typename _TyData, typename _TyParam, HVX_Vector (*_OpBinaryTransform)(HVX_Vector, HVX_Vector, _TyParam)>
inline void vec_trans_with_param_impl(const _TyData * src0,
const _TyData * src1,
_TyData * dst,
size_t count,
_TyParam param) {
constexpr const size_t kElementsPerVector = hexagon::kBytesPerVector / sizeof(_TyData);
HVX_Vector * src0_vec_ptr = ((HVX_Vector *) src0);
HVX_Vector * const src0_vec_ptr_end = ((HVX_Vector *) src0) + count / kElementsPerVector;
HVX_Vector * src1_vec_ptr = ((HVX_Vector *) src1);
HVX_Vector * dst_vec_ptr = ((HVX_Vector *) dst); // framework will ensure the dst is aligned
HVX_Vector prev0 = *src0_vec_ptr++;
HVX_Vector prev1 = *src1_vec_ptr++;
{
while (src0_vec_ptr_end - src0_vec_ptr > 1) {
HVX_VectorPair curr0 = reinterpret_cast<HVX_VectorPair *>(src0_vec_ptr)[0];
HVX_VectorPair curr1 = reinterpret_cast<HVX_VectorPair *>(src1_vec_ptr)[0];
HVX_Vector l0 = Q6_V_valign_VVR(Q6_V_lo_W(curr0), prev0, (size_t) src0);
HVX_Vector l1 = Q6_V_valign_VVR(Q6_V_lo_W(curr1), prev1, (size_t) src1);
dst_vec_ptr[0] = _OpBinaryTransform(l0, l1, param);
HVX_Vector h0 = Q6_V_valign_VVR(Q6_V_hi_W(curr0), Q6_V_lo_W(curr0), (size_t) src0);
HVX_Vector h1 = Q6_V_valign_VVR(Q6_V_hi_W(curr1), Q6_V_lo_W(curr1), (size_t) src1);
dst_vec_ptr[1] = _OpBinaryTransform(h0, h1, param);
prev0 = Q6_V_hi_W(curr0);
prev1 = Q6_V_hi_W(curr1);
src0_vec_ptr += 2;
src1_vec_ptr += 2;
dst_vec_ptr += 2;
}
}
if (src0_vec_ptr_end - src0_vec_ptr > 0) {
HVX_Vector curr0 = *src0_vec_ptr++;
HVX_Vector s0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0);
HVX_Vector curr1 = *src1_vec_ptr++;
HVX_Vector s1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1);
dst_vec_ptr[0] = _OpBinaryTransform(s0, s1, param);
prev0 = curr0;
prev1 = curr1;
dst_vec_ptr++;
}
const size_t leftover = count % kElementsPerVector;
if ((src0_vec_ptr_end - ((HVX_Vector *) src0)) > 0) {
// handle the last vector
// see also:
// https://github.com/UbiquitousLearning/mllm/blob/babf4410352ce8730824c87699c025a0d4ce3a6f/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/LLaMAMul.cpp#L147
// or qualcomm sdk libs\qhl_hvx\src\qhblas_hvx\qhblas_hvx_aw_vector_add_ah.c
bool should_fetch_src0 = leftover != 0 || !hexagon::is_addr_aligned(src0_vec_ptr);
bool should_fetch_src1 = leftover != 0 || !hexagon::is_addr_aligned(src1_vec_ptr);
HVX_Vector curr0 = should_fetch_src0 ? *src0_vec_ptr : prev0;
HVX_Vector s0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0);
HVX_Vector curr1 = should_fetch_src1 ? *src1_vec_ptr : prev1;
HVX_Vector s1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1);
dst_vec_ptr[0] = _OpBinaryTransform(s0, s1, param);
src0_vec_ptr += should_fetch_src0 ? 1 : 0;
src1_vec_ptr += should_fetch_src1 ? 1 : 0;
prev0 = curr0;
prev1 = curr1;
dst_vec_ptr++;
}
if (leftover > 0) {
// handle the leftover elements
const size_t leftover_bytes = leftover * sizeof(_TyData);
HVX_Vector curr0 = (leftover_bytes + hexagon::unaligned_bytes(src0_vec_ptr) > hexagon::kBytesPerVector) ?
*src0_vec_ptr :
prev0;
curr0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0);
HVX_Vector curr1 = (leftover_bytes + hexagon::unaligned_bytes(src1_vec_ptr) > hexagon::kBytesPerVector) ?
*src1_vec_ptr :
prev1;
curr1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1);
q6op_vstu_variable_ARV(dst_vec_ptr, leftover_bytes, _OpBinaryTransform(curr0, curr1, param));
}
}
} // namespace hexagon::vec

View File

@ -1,10 +1,10 @@
#pragma once
#include "util.hpp"
#include <HAP_compute_res.h>
#include <HAP_vtcm_mgr.h>
#include "util.hpp"
namespace hexagon {
class vtcm_mem {

View File

@ -5,13 +5,13 @@
#include <domain_default.h>
#pragma GCC diagnostic pop
#include "graph.hpp"
#include "util.hpp"
#include <remote.h>
#include <type_traits>
#include "graph.hpp"
#include "util.hpp"
#define SKEL_URI_DEFINE(arch) ("file:///libhexagon_npu_skel_" arch ".so?npu_device_skel_handle_invoke&_modver=1.0")
namespace {
@ -68,8 +68,10 @@ void backend_free(ggml_backend_t backend) {
delete get_backend_object(backend);
}
bool backend_cpy_tensor_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const ggml_tensor * src,
ggml_tensor * dst) {
bool backend_cpy_tensor_async(ggml_backend_t backend_src,
ggml_backend_t backend_dst,
const ggml_tensor * src,
ggml_tensor * dst) {
// TODO: implement this
return false;
}
@ -194,13 +196,24 @@ bool npu_device::supports_op_impl(const ggml_tensor * op) {
boolean supported = false;
auto dst_spec = get_spec(op);
auto ret = npu_device_device_support_op(_device_handle, npu_op, &dst_spec, srcs, i, &supported);
npu_device_tensor_op_spec npu_op_spec = { npu_op, {} };
static_assert(sizeof(npu_op_spec.params) <= sizeof(op->op_params),
"npu_op_spec.params size should less than op->op_params size");
memcpy(npu_op_spec.params, op->op_params, sizeof(npu_op_spec.params));
auto ret = npu_device_device_support_op(_device_handle, &npu_op_spec, &dst_spec, srcs, i, &supported);
if (ret != AEE_SUCCESS || !supported) {
#ifndef NDEBUG
auto * src0_type = i ? ggml_type_name(op->src[0]->type) : "null";
auto * src1_type = (i > 1) ? ggml_type_name(op->src[1]->type) : "null";
LOG_DEBUG("[%s][%s]unsupported %s(%s,%s), ret: 0x%x, supported: %d\n", get_name(), ggml_op_desc(op),
ggml_type_name(op->type), src0_type, src1_type, ret, supported);
LOG_DEBUG("[%s][%s]unsupported %s(%s,%s), ret: 0x%x, supported: %d\n",
get_name(),
ggml_op_desc(op),
ggml_type_name(op->type),
src0_type,
src1_type,
ret,
supported);
#endif
return false;
}
@ -244,8 +257,8 @@ bool npu_device::init_device_lib() {
}
if (err != AEE_SUCCESS) {
LOG_ERROR("[%s]Unable to open NPU device, err: 0x%x, uri %s\n", get_name(), err,
device_lib_uri.c_str());
LOG_ERROR(
"[%s]Unable to open NPU device, err: 0x%x, uri %s\n", get_name(), err, device_lib_uri.c_str());
_device_handle = 0;
return false;
}
@ -275,16 +288,26 @@ bool npu_device::supports_op(const ggml_tensor * op) {
if (op->op != GGML_OP_NONE && op->op != GGML_OP_VIEW && op->op != GGML_OP_RESHAPE &&
op->op != GGML_OP_PERMUTE) {
_supported_op++;
LOG_DEBUG("[%s][%s][%s]supported, %s, supported/unsupported: %u/%u\n", get_name(), ggml_op_desc(op),
ggml_get_name(op), op_desc, _supported_op.load(), _unsupported_op.load());
LOG_DEBUG("[%s][%s][%s]supported, %s, supported/unsupported: %u/%u\n",
get_name(),
ggml_op_desc(op),
ggml_get_name(op),
op_desc,
_supported_op.load(),
_unsupported_op.load());
}
return true;
}
_unsupported_op++;
LOG_DEBUG("[%s][%s][%s]unsupported, %s, supported/unsupported: %u/%u\n", get_name(), ggml_op_desc(op),
ggml_get_name(op), op_desc, _supported_op.load(), _unsupported_op.load());
LOG_DEBUG("[%s][%s][%s]unsupported, %s, supported/unsupported: %u/%u\n",
get_name(),
ggml_op_desc(op),
ggml_get_name(op),
op_desc,
_supported_op.load(),
_unsupported_op.load());
return false;
}
#else

View File

@ -35,6 +35,8 @@ enum npu_device_tensor_op op_to_npu_op(ggml_op op) {
return NPU_OP_FLASH_ATTN;
case GGML_OP_ROPE:
return NPU_OP_ROPE;
case GGML_OP_GLU:
return NPU_OP_GLU;
default:
return NPU_OP_COUNT;
}
@ -56,6 +58,8 @@ const char * get_npu_op_desc(enum npu_device_tensor_op op) {
return ggml_op_name(GGML_OP_FLASH_ATTN_EXT);
case NPU_OP_ROPE:
return ggml_op_name(GGML_OP_ROPE);
case NPU_OP_GLU:
return ggml_op_name(GGML_OP_GLU);
default:
return "UNKNOWN";
}

View File

@ -17,6 +17,7 @@ interface npu_device : remote_handle64{
typedef int64_t ne_type[DEVICE_TENSOR_MAX_DIMS];
typedef uint64_t nb_type[DEVICE_TENSOR_MAX_DIMS];
typedef int32_t param_type[DEVICE_TENSOR_MAX_OP_PARAMS];
typedef uint64_t tensor_handle_t;
typedef uint64_t graph_handle_t;
@ -50,9 +51,19 @@ interface npu_device : remote_handle64{
NPU_OP_RMS_NORM,
NPU_OP_FLASH_ATTN,
NPU_OP_ROPE,
NPU_OP_GLU,
NPU_OP_COUNT
};
enum glu_op {
NPU_GLU_OP_REGLU,
NPU_GLU_OP_GEGLU,
NPU_GLU_OP_SWIGLU,
NPU_GLU_OP_GEGLU_ERF,
NPU_GLU_OP_GEGLU_QUICK,
NPU_GLU_OP_COUNT
};
enum tensor_data_type {
NPU_DATA_TYPE_F32,
NPU_DATA_TYPE_F16,
@ -69,9 +80,14 @@ interface npu_device : remote_handle64{
tensor_data_type type;
};
struct tensor_op_spec {
tensor_op op;
param_type params;
};
struct tensor_update_config {
tensor_op op;
int32_t params[DEVICE_TENSOR_MAX_OP_PARAMS];
param_type params;
tensor_handle_t src_handles[DEVICE_TENSOR_MAX_SRC];
};
@ -90,7 +106,7 @@ interface npu_device : remote_handle64{
);
AEEResult device_support_op(
in tensor_op op,
in tensor_op_spec op_spec,
in tensor_spec dst,
in sequence<tensor_spec> srcs,
rout boolean is_supported