feat: optimize GEMV performance in the hexagon backend (#54)
* add GEMV implementation for matrix multiplication in hexagon (a scalar sketch of the GEMV case follows the commit metadata below)
* refactor: optimize GEMV implementation for matrix multiplication in hexagon
* wip
* refactor: enhance caching mechanism in GEMV implementation for matrix multiplication
* wip
* refactor: streamline caching logic in GEMV implementation for matrix multiplication
* wip
* wip
* fix broadcast in flash_attn
* format
* refactor: optimize memory fetching in matrix multiplication implementations
* wip
* fix aligned gemv
* rename
* refactor: remove unused memory cache functions and initialize VTCM cache
* wip
* feat: add vector math functions for IEEE float and half float operations
* feat: add vec_silu_f32 and vec_silu_f16 functions for SiLU activation
* feat: implement GLU operation support in tensor processing
* feat: add GLU operation support and related enhancements in tensor processing
* wip
* wip
* wip
* feat: add qhmath_hvx_div_vf functions for f32 vector operations
* feat: add qhmath_hvx_div_vhf functions for f16 vector operations
* fix: reorder parameters in vector operation functions for consistency
* wip
* feat: enhance vector operations with parameterized transformations and improved GLU implementations
* wip
* fix: increase default stack size and correct thread parameter indexing in thread pool
* fix f16 div
* fix f32 div
* fix: update GLU vector operations to use explicit denominator calculation
* wip
* wip
* Refactor cacheability check for matrix multiplication to handle multiple source tensors
* Revert "fix: increase default stack size and correct thread parameter indexing in thread pool" (this reverts commit 40e3f0974dbb04051aa30b397a9a171c6dd32678)
* wip
* fix comments
* replace copy with memcpy
parent 6260c3166d
commit 379bdeb18c
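For readers skimming the diff below: the GEMV fast path added in this commit handles the matrix-vector case of MUL_MAT, where src1 has a single row (ne[1] == ne[2] == ne[3] == 1), so src1 can be copied into VTCM once and every output element is a single dot product. A minimal scalar sketch of what the HVX kernels compute (the function name and layout parameters here are illustrative, not the device API):

    // Scalar reference for the GEMV case of mul_mat:
    // dst[i] = dot(src0 row i, src1 row 0), for i in [0, n_rows).
    // The kernels in this diff compute the same thing with HVX dot products
    // and VTCM caching of src0/src1 rows.
    void gemv_reference(const float * src0, size_t src0_row_stride_elems,
                        const float * src1, float * dst,
                        size_t n_rows, size_t n_cols) {
        for (size_t i = 0; i < n_rows; ++i) {
            const float * row = src0 + i * src0_row_stride_elems;
            float acc = 0.0f;
            for (size_t j = 0; j < n_cols; ++j) {
                acc += row[j] * src1[j];
            }
            dst[i] = acc;
        }
    }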
@ -1,10 +1,4 @@
|
|||
|
||||
#include <AEEStdErr.h>
|
||||
#include <HAP_compute_res.h>
|
||||
#include <hexagon_types.h>
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "graph.hpp"
|
||||
#include "hexagon_npu.h"
|
||||
#include "op_impl.hpp"
|
||||
|
|
@ -14,6 +8,12 @@
|
|||
#include "type_traits.hpp"
|
||||
#include "util.hpp"
|
||||
|
||||
#include <AEEStdErr.h>
|
||||
#include <HAP_compute_res.h>
|
||||
#include <hexagon_types.h>
|
||||
|
||||
#include <memory>
|
||||
|
||||
namespace {
|
||||
|
||||
struct npu_device_context {
|
||||
|
|
@ -130,8 +130,12 @@ AEEResult npu_device_device_get_alignment(remote_handle64 _h, uint32_t * alignme
|
|||
return AEE_SUCCESS;
|
||||
}
|
||||
|
||||
AEEResult npu_device_device_support_op(remote_handle64 _h, npu_device_tensor_op op, const npu_device_tensor_spec * dst,
|
||||
const npu_device_tensor_spec * srcs, int srcsLen, boolean * is_supported) {
|
||||
AEEResult npu_device_device_support_op(remote_handle64 _h,
|
||||
const npu_device_tensor_op_spec * op_spec,
|
||||
const npu_device_tensor_spec * dst,
|
||||
const npu_device_tensor_spec * srcs,
|
||||
int srcsLen,
|
||||
boolean * is_supported) {
|
||||
NPU_UNUSED(_h);
|
||||
|
||||
if (!srcs || srcsLen <= 0 || !dst || !is_supported) {
|
||||
|
|
@ -139,19 +143,21 @@ AEEResult npu_device_device_support_op(remote_handle64 _h, npu_device_tensor_op
|
|||
return AEE_EINVARGS;
|
||||
}
|
||||
|
||||
*is_supported = hexagon::support_op(op, dst, srcs, srcsLen);
|
||||
*is_supported = hexagon::support_op(op_spec, dst, srcs, srcsLen);
|
||||
return AEE_SUCCESS;
|
||||
}
|
||||
|
||||
AEEResult npu_device_tensor_init(remote_handle64 _h, const npu_device_tensor_config * info,
|
||||
npu_device_tensor_handle_t * tensor_handle) {
|
||||
AEEResult npu_device_tensor_init(remote_handle64 _h,
|
||||
const npu_device_tensor_config * info,
|
||||
npu_device_tensor_handle_t * tensor_handle) {
|
||||
NPU_UNUSED(_h);
|
||||
auto * tensor = new hexagon::tensor(*info);
|
||||
*tensor_handle = tensor_to_handle(tensor);
|
||||
return AEE_SUCCESS;
|
||||
}
|
||||
|
||||
AEEResult npu_device_tensor_update_params(remote_handle64 _h, npu_device_tensor_handle_t tensor_handle,
|
||||
AEEResult npu_device_tensor_update_params(remote_handle64 _h,
|
||||
npu_device_tensor_handle_t tensor_handle,
|
||||
const npu_device_tensor_update_config * config) {
|
||||
NPU_UNUSED(_h);
|
||||
auto * tensor = tensor_from_handle(tensor_handle);
|
||||
|
|
@ -174,8 +180,9 @@ AEEResult npu_device_tensor_free(remote_handle64 _h, npu_device_tensor_handle_t
|
|||
return AEE_SUCCESS;
|
||||
}
|
||||
|
||||
AEEResult npu_device_tensors_free(remote_handle64 _h, const npu_device_tensor_handle_t * tensor_handles,
|
||||
int tensor_handlesLen) {
|
||||
AEEResult npu_device_tensors_free(remote_handle64 _h,
|
||||
const npu_device_tensor_handle_t * tensor_handles,
|
||||
int tensor_handlesLen) {
|
||||
NPU_UNUSED(_h);
|
||||
if (!tensor_handles || tensor_handlesLen < 0) {
|
||||
DEVICE_LOG_ERROR("npu_device_tensors_free: Invalid arguments");
|
||||
|
|
@ -201,8 +208,10 @@ AEEResult npu_device_graph_init(remote_handle64 _h, npu_device_graph_handle_t *
|
|||
return AEE_SUCCESS;
|
||||
}
|
||||
|
||||
AEEResult npu_device_graph_set_tensor(remote_handle64 _h, npu_device_graph_handle_t graph_handle,
|
||||
const npu_device_tensor_handle_t * tensor_handles, int tensor_handlesLen) {
|
||||
AEEResult npu_device_graph_set_tensor(remote_handle64 _h,
|
||||
npu_device_graph_handle_t graph_handle,
|
||||
const npu_device_tensor_handle_t * tensor_handles,
|
||||
int tensor_handlesLen) {
|
||||
NPU_UNUSED(_h);
|
||||
auto * graph = graph_from_handle(graph_handle);
|
||||
if (!graph || !tensor_handles || tensor_handlesLen <= 0) {
|
||||
|
|
@ -213,7 +222,8 @@ AEEResult npu_device_graph_set_tensor(remote_handle64 _h, npu_device_graph_handl
|
|||
return AEE_SUCCESS;
|
||||
}
|
||||
|
||||
AEEResult npu_device_graph_set_tensor_with_param(remote_handle64 _h, npu_device_graph_handle_t graph_handle,
|
||||
AEEResult npu_device_graph_set_tensor_with_param(remote_handle64 _h,
|
||||
npu_device_graph_handle_t graph_handle,
|
||||
const npu_device_tensor_handle_t * tensor_handles,
|
||||
int tensor_handlesLen,
|
||||
const npu_device_tensor_update_config * tensor_params,
|
||||
|
|
|
|||
|
|
@ -1,12 +1,12 @@
|
|||
|
||||
#include "graph.hpp"
|
||||
|
||||
#include <new>
|
||||
|
||||
#include "op_impl.hpp"
|
||||
#include "util.hpp"
|
||||
#include "vtcm_mem.hpp"
|
||||
|
||||
#include <new>
|
||||
|
||||
namespace hexagon {
|
||||
|
||||
graph::graph() noexcept {
|
||||
|
|
@ -30,8 +30,12 @@ void graph::set_tensor(const npu_device_tensor_handle_t * tensors, int tensor_co
|
|||
for (int i = 0; i < tensor_count; ++i) {
|
||||
auto * tensor_obj = reinterpret_cast<tensor *>(tensors[i]);
|
||||
_tensors[i] = tensor_obj;
|
||||
DEVICE_LOG_DEBUG("graph(%p) set_tensor[%d]: %p(%p,%p), op: %s\n", (void *) this, i, (void *) tensor_obj,
|
||||
(void *) tensor_obj->get_src(0), (void *) tensor_obj->get_src(1),
|
||||
DEVICE_LOG_DEBUG("graph(%p) set_tensor[%d]: %p(%p,%p), op: %s\n",
|
||||
(void *) this,
|
||||
i,
|
||||
(void *) tensor_obj,
|
||||
(void *) tensor_obj->get_src(0),
|
||||
(void *) tensor_obj->get_src(1),
|
||||
op_get_name(tensor_obj->get_op()));
|
||||
}
|
||||
|
||||
|
|
@ -64,8 +68,9 @@ bool graph::compute(default_thread_pool * thread_pool, const float * f16_to_f32_
|
|||
return true;
|
||||
}
|
||||
|
||||
void graph::thread_pool_task(default_thread_pool * pool, default_thread_pool::thread_params * thread_params,
|
||||
void * graph) {
|
||||
void graph::thread_pool_task(default_thread_pool * pool,
|
||||
default_thread_pool::thread_params * thread_params,
|
||||
void * graph) {
|
||||
reinterpret_cast<hexagon::graph *>(graph)->compute_impl(pool, thread_params);
|
||||
}
|
||||
|
||||
|
|
@ -86,8 +91,11 @@ void graph::compute_impl(default_thread_pool * pool, default_thread_pool::thread
|
|||
|
||||
const bool should_sync = requires_thread_barrier(op);
|
||||
if (pool && should_sync && i < _tensor_count - 1) {
|
||||
DEVICE_SCOPED_PERFORMANCE_TRACKER("[%p]sync_thread, tidx: %zu, tensor[%zu/%zu]", (void *) this,
|
||||
params.get_thread_index(), i, _tensor_count);
|
||||
DEVICE_SCOPED_PERFORMANCE_TRACKER("[%p]sync_thread, tidx: %zu, tensor[%zu/%zu]",
|
||||
(void *) this,
|
||||
params.get_thread_index(),
|
||||
i,
|
||||
_tensor_count);
|
||||
pool->sync_thread();
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "hexagon_npu.h"
|
||||
#include "tensor.hpp"
|
||||
#include "thread_pool.hpp"
|
||||
|
||||
#include <memory>
|
||||
|
||||
namespace hexagon {
|
||||
|
||||
class graph {
|
||||
|
|
@ -20,8 +20,9 @@ class graph {
|
|||
bool compute(default_thread_pool * thread_pool, const float * f16_to_f32_table);
|
||||
|
||||
private:
|
||||
static void thread_pool_task(default_thread_pool * pool, default_thread_pool::thread_params * thread_params,
|
||||
void * graph);
|
||||
static void thread_pool_task(default_thread_pool * pool,
|
||||
default_thread_pool::thread_params * thread_params,
|
||||
void * graph);
|
||||
void compute_impl(default_thread_pool * pool, default_thread_pool::thread_params * thread_params);
|
||||
|
||||
std::unique_ptr<tensor *[]> _tensors;
|
||||
|
|
|
|||
|
|
@ -14,15 +14,20 @@ inline float f16_to_f32(const npu_device_fp16_t src) {
|
|||
|
||||
// From: ggml/src/ggml-cpu/ops.cpp
|
||||
template <bool _IsKvF16>
|
||||
void flash_attn_impl(hexagon::tensor * out, const hexagon::tensor * q, const hexagon::tensor * k,
|
||||
const hexagon::tensor * v, const hexagon::tensor * mask, hexagon::compute_params * params) {
|
||||
void flash_attn_impl(hexagon::tensor * out,
|
||||
const hexagon::tensor * q,
|
||||
const hexagon::tensor * k,
|
||||
const hexagon::tensor * v,
|
||||
const hexagon::tensor * mask,
|
||||
hexagon::compute_params * params) {
|
||||
static_assert(3 <= hexagon::kMaxParamsCount, "flash_attn op params count exceeds max params count");
|
||||
|
||||
constexpr const npu_device_tensor_data_type kKvDataType = _IsKvF16 ? NPU_DATA_TYPE_F16 : NPU_DATA_TYPE_F32;
|
||||
|
||||
if (k->get_type() != kKvDataType || v->get_type() != k->get_type()) {
|
||||
DEVICE_LOG_ERROR("flash_attn_impl: k and v must have same type, got k: %s, v: %s\n",
|
||||
hexagon::get_type_name(k->get_type()), hexagon::get_type_name(v->get_type()));
|
||||
hexagon::get_type_name(k->get_type()),
|
||||
hexagon::get_type_name(v->get_type()));
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
@ -80,7 +85,8 @@ void flash_attn_impl(hexagon::tensor * out, const hexagon::tensor * q, const hex
|
|||
const auto out_rows_per_batch = out->get_ne(2) * out->get_ne(1);
|
||||
uint8_t * dst_ptr = out->get_write_buffer();
|
||||
if (!dst_ptr) {
|
||||
DEVICE_LOG_ERROR("flash_attn_impl: dst_ptr is not writable, tensor: %p, type: %s\n", (void *) out,
|
||||
DEVICE_LOG_ERROR("flash_attn_impl: dst_ptr is not writable, tensor: %p, type: %s\n",
|
||||
(void *) out,
|
||||
hexagon::get_type_name(out->get_type()));
|
||||
return;
|
||||
}
|
||||
|
|
@ -118,7 +124,8 @@ void flash_attn_impl(hexagon::tensor * out, const hexagon::tensor * q, const hex
|
|||
|
||||
const npu_device_fp16_t * mp =
|
||||
mask_ptr ? reinterpret_cast<const npu_device_fp16_t *>(mask_ptr + iq1 * mask->get_nb(1) +
|
||||
(iq3 % mask->get_ne(2)) * mask->get_nb(2)) :
|
||||
(iq2 % mask->get_ne(2)) * mask->get_nb(2) +
|
||||
(iq3 % mask->get_ne(3)) * mask->get_nb(3)) :
|
||||
nullptr;
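This hunk is the "fix broadcast in flash_attn" item from the commit message: the mask row pointer previously folded iq3 into the mask's ne(2)/nb(2) dimension only, and now broadcasts iq2 over ne(2) and iq3 over ne(3) separately. A condensed restatement of the corrected addressing (a sketch that mirrors the new lines above; parameter names are illustrative):

    // Byte offset of the mask row for query indices (iq1, iq2, iq3), broadcasting
    // iq2 over the mask's ne(2) and iq3 over its ne(3).
    inline int64_t mask_row_offset(int64_t iq1, int64_t iq2, int64_t iq3,
                                   int64_t mask_ne2, int64_t mask_ne3,
                                   int64_t nb1, int64_t nb2, int64_t nb3) {
        return iq1 * nb1 + (iq2 % mask_ne2) * nb2 + (iq3 % mask_ne3) * nb3;
    }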
|
||||
|
||||
// k indices
|
||||
|
|
@ -251,8 +258,8 @@ bool flash_attn_f32(tensor * out, compute_params * params) {
|
|||
const auto * v = out->get_src(2);
|
||||
const auto * mask = out->get_src(3);
|
||||
if (!q || !k || !v || !mask) {
|
||||
DEVICE_LOG_DEBUG("invalid src tensors: q: %p, k: %p, v: %p, mask: %p\n", (void *) q, (void *) k, (void *) v,
|
||||
(void *) mask);
|
||||
DEVICE_LOG_DEBUG(
|
||||
"invalid src tensors: q: %p, k: %p, v: %p, mask: %p\n", (void *) q, (void *) k, (void *) v, (void *) mask);
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -264,8 +271,11 @@ bool flash_attn_f32(tensor * out, compute_params * params) {
|
|||
return true;
|
||||
}
|
||||
|
||||
bool is_flash_attn_supported(npu_device_tensor_op op, const npu_device_tensor_spec * dst,
|
||||
const npu_device_tensor_spec * srcs, size_t src_len) {
|
||||
bool is_flash_attn_supported(const npu_device_tensor_op_spec * op_spec,
|
||||
const npu_device_tensor_spec * dst,
|
||||
const npu_device_tensor_spec * srcs,
|
||||
size_t src_len) {
|
||||
const auto op = op_spec->op;
|
||||
if (op != NPU_OP_FLASH_ATTN) {
|
||||
DEVICE_LOG_DEBUG("op is not NPU_OP_FLASH_ATTN: %d\n", op);
|
||||
return false;
|
||||
|
|
@ -295,7 +305,9 @@ bool is_flash_attn_supported(npu_device_tensor_op op, const npu_device_tensor_sp
|
|||
|
||||
const auto * v = &srcs[2];
|
||||
if (v->type != k->type) { // TODO: support more v types
|
||||
DEVICE_LOG_DEBUG("[%s]v type is not the same as k: %s vs %s\n", op_get_name(op), get_type_name(v->type),
|
||||
DEVICE_LOG_DEBUG("[%s]v type is not the same as k: %s vs %s\n",
|
||||
op_get_name(op),
|
||||
get_type_name(v->type),
|
||||
get_type_name(k->type));
|
||||
return false;
|
||||
}
|
||||
|
|
@ -310,28 +322,42 @@ bool is_flash_attn_supported(npu_device_tensor_op op, const npu_device_tensor_sp
|
|||
DEVICE_LOG_DEBUG(
|
||||
"[%s]dst shape does not match q and v: dst ne: %ld, %ld, %ld, %ld, q ne: %ld, %ld, %ld, %ld, "
|
||||
"v ne: %ld, %ld, %ld, %ld\n",
|
||||
op_get_name(op), dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], q->ne[0], q->ne[1], q->ne[2], q->ne[3],
|
||||
v->ne[0], v->ne[1], v->ne[2], v->ne[3]);
|
||||
op_get_name(op),
|
||||
dst->ne[0],
|
||||
dst->ne[1],
|
||||
dst->ne[2],
|
||||
dst->ne[3],
|
||||
q->ne[0],
|
||||
q->ne[1],
|
||||
q->ne[2],
|
||||
q->ne[3],
|
||||
v->ne[0],
|
||||
v->ne[1],
|
||||
v->ne[2],
|
||||
v->ne[3]);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (is_transposed_or_permuted(dst->nb)) {
|
||||
DEVICE_LOG_DEBUG("[%s]dst cannot be transposed or permuted, nb: %zu, %zu, %zu, %zu\n", op_get_name(op),
|
||||
dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3]);
|
||||
DEVICE_LOG_DEBUG("[%s]dst cannot be transposed or permuted, nb: %zu, %zu, %zu, %zu\n",
|
||||
op_get_name(op),
|
||||
dst->nb[0],
|
||||
dst->nb[1],
|
||||
dst->nb[2],
|
||||
dst->nb[3]);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (q->ne[0] != k->ne[0]) {
|
||||
DEVICE_LOG_DEBUG("[%s]q and k shapes do not match: q ne: %ld, %ld, %ld, %ld, k ne: %ld, %ld, %ld, %ld\n",
|
||||
op_get_name(op), q->ne[0], q->ne[1], q->ne[2], q->ne[3], k->ne[0], k->ne[1], k->ne[2],
|
||||
k->ne[3]);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (q->ne[2] != k->ne[2] || q->ne[3] != k->ne[3] || q->ne[3] != 1) {
|
||||
// TODO: add broadcast support
|
||||
DEVICE_LOG_DEBUG("[%s]q and k shapes do not match: q ne: %ld, %ld, %ld, %ld, k ne: %ld, %ld, %ld, %ld\n",
|
||||
op_get_name(op), q->ne[0], q->ne[1], q->ne[2], q->ne[3], k->ne[0], k->ne[1], k->ne[2],
|
||||
op_get_name(op),
|
||||
q->ne[0],
|
||||
q->ne[1],
|
||||
q->ne[2],
|
||||
q->ne[3],
|
||||
k->ne[0],
|
||||
k->ne[1],
|
||||
k->ne[2],
|
||||
k->ne[3]);
|
||||
return false;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,7 +5,9 @@
|
|||
namespace hexagon {
|
||||
|
||||
bool flash_attn_f32(tensor * out, compute_params * params);
|
||||
bool is_flash_attn_supported(npu_device_tensor_op op, const npu_device_tensor_spec * dst,
|
||||
const npu_device_tensor_spec * srcs, size_t src_len);
|
||||
bool is_flash_attn_supported(const npu_device_tensor_op_spec * op_spec,
|
||||
const npu_device_tensor_spec * dst,
|
||||
const npu_device_tensor_spec * srcs,
|
||||
size_t src_len);
|
||||
|
||||
} // namespace hexagon
|
||||
|
|
|
|||
|
|
@ -2,20 +2,20 @@
|
|||
|
||||
#include "op_impl.hpp"
|
||||
|
||||
#include <type_traits>
|
||||
|
||||
#include "op_flash_attn.hpp"
|
||||
#include "op_mul_mat.hpp"
|
||||
#include "op_rope.hpp"
|
||||
#include "type_traits.hpp"
|
||||
#include "vec_ops.hpp"
|
||||
|
||||
#include <type_traits>
|
||||
|
||||
namespace {
|
||||
|
||||
template <HVX_Vector (*_OpBinaryTransform)(HVX_Vector, HVX_Vector)>
|
||||
inline void vec_op_f32_f32(const float * src0, const float * src1, size_t count, float * dst) {
|
||||
inline void vec_op_f32_f32(const float * src0, const float * src1, float * dst, size_t count) {
|
||||
using namespace hexagon::vec;
|
||||
vec_trans_op_impl<_OpBinaryTransform, float>(src0, src1, count, dst);
|
||||
vec_trans_impl<_OpBinaryTransform, float>(src0, src1, dst, count);
|
||||
}
|
||||
|
||||
inline HVX_Vector vadd_f32_f32(HVX_Vector a, HVX_Vector b) {
|
||||
|
|
@ -31,10 +31,12 @@ inline HVX_Vector vmul_f32_f32(HVX_Vector a, HVX_Vector b) {
|
|||
}
|
||||
|
||||
template <HVX_Vector (*_OpBinaryTransform)(HVX_Vector, HVX_Vector)>
|
||||
inline void vec_op_f16_f16(const npu_device_fp16_t * src0, const npu_device_fp16_t * src1, size_t count,
|
||||
npu_device_fp16_t * dst) {
|
||||
inline void vec_op_f16_f16(const npu_device_fp16_t * src0,
|
||||
const npu_device_fp16_t * src1,
|
||||
npu_device_fp16_t * dst,
|
||||
size_t count) {
|
||||
using namespace hexagon::vec;
|
||||
vec_trans_op_impl<_OpBinaryTransform, npu_device_fp16_t>(src0, src1, count, dst);
|
||||
vec_trans_impl<_OpBinaryTransform, npu_device_fp16_t>(src0, src1, dst, count);
|
||||
}
|
||||
|
||||
inline HVX_Vector vadd_f16_f16(HVX_Vector a, HVX_Vector b) {
|
||||
|
|
@ -53,12 +55,17 @@ inline HVX_Vector vmul_f16_f16(HVX_Vector a, HVX_Vector b) {
|
|||
|
||||
template <typename T> struct get_data_type {};
|
||||
|
||||
template <typename _TyData> struct get_data_type<void (*)(const _TyData *, const _TyData *, size_t, _TyData *)> {
|
||||
template <typename _TyData> struct get_data_type<void (*)(const _TyData *, const _TyData *, _TyData *, size_t)> {
|
||||
using type = _TyData;
|
||||
};
|
||||
|
||||
template <typename _TyData>
|
||||
struct get_data_type<void (*)(const _TyData *, const _TyData *, _TyData *, size_t, hexagon::HVX_VectorPair_x4)> {
|
||||
using type = _TyData;
|
||||
};
|
||||
|
||||
template <typename _TyData, typename _TyParam>
|
||||
struct get_data_type<void (*)(const _TyData *, size_t, _TyParam, _TyData *)> {
|
||||
struct get_data_type<void (*)(const _TyData *, _TyData *, size_t, _TyParam)> {
|
||||
using type = _TyData;
|
||||
using param_type = typename std::remove_cv<typename std::remove_reference<_TyParam>::type>::type;
|
||||
};
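These get_data_type specializations let the dispatchers recover the element type (and, for parameterized rows, the parameter type) from a row function's pointer type, so a single template handles both f32 and f16 rows. A small usage sketch, assuming the alias names below (the trait and <type_traits> are already available in this file):

    // Deduce the element type from a row-function pointer type.
    using add_f32_row_t = void (*)(const float *, const float *, float *, size_t);
    using elem_t        = get_data_type<add_f32_row_t>::type;  // resolves to float
    static_assert(std::is_same<elem_t, float>::value, "trait yields the row element type");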
|
||||
|
|
@ -85,7 +92,8 @@ template <auto _RowFunc> bool element_wise_op(hexagon::tensor * out, hexagon::co
|
|||
|
||||
uint8_t * dst_ptr = out->get_write_buffer();
|
||||
if (!dst_ptr) {
|
||||
DEVICE_LOG_ERROR("element_wise_op: dst_ptr is not writable, tensor: %p, type: %s\n", (void *) out,
|
||||
DEVICE_LOG_ERROR("element_wise_op: dst_ptr is not writable, tensor: %p, type: %s\n",
|
||||
(void *) out,
|
||||
hexagon::get_type_name(out->get_type()));
|
||||
return false;
|
||||
}
|
||||
|
|
@ -119,16 +127,21 @@ template <auto _RowFunc> bool element_wise_op(hexagon::tensor * out, hexagon::co
|
|||
hexagon::l2fetch_row(src1_row + src1->get_nb(1), valid_row_bytes);
|
||||
}
|
||||
|
||||
_RowFunc(reinterpret_cast<const data_type *>(src0_row), reinterpret_cast<const data_type *>(src1_row),
|
||||
static_cast<size_t>(out->get_ne(0)), reinterpret_cast<data_type *>(dst_row));
|
||||
_RowFunc(reinterpret_cast<const data_type *>(src0_row),
|
||||
reinterpret_cast<const data_type *>(src1_row),
|
||||
reinterpret_cast<data_type *>(dst_row),
|
||||
static_cast<size_t>(out->get_ne(0)));
|
||||
}
|
||||
|
||||
out->release_write_buffer(); // mark the output tensor as modified
|
||||
return true;
|
||||
}
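The row loop above follows the prefetch-one-row-ahead pattern used throughout this diff: fetch the next row into L2 with l2fetch_row while the current row is being processed, then call the row kernel. A stripped-down sketch of that pattern (the function name and the RowFunc callback are placeholders; l2fetch_row is the same helper the real code uses):

    // Prefetch-one-row-ahead loop, mirroring element_wise_op above.
    template <typename RowFunc>
    void for_each_row_prefetched(const uint8_t * src_base, size_t src_stride,
                                 uint8_t * dst_base, size_t dst_stride,
                                 int64_t row_begin, int64_t row_end,
                                 size_t valid_row_bytes, RowFunc process_row) {
        for (int64_t ir = row_begin; ir < row_end; ++ir) {
            const uint8_t * src_row = src_base + ir * src_stride;
            uint8_t *       dst_row = dst_base + ir * dst_stride;
            if (ir + 1 < row_end) {
                // hide DDR latency behind the current row's compute
                hexagon::l2fetch_row(src_row + src_stride, valid_row_bytes);
            }
            process_row(src_row, dst_row);
        }
    }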
|
||||
|
||||
bool is_element_wise_op_supported(npu_device_tensor_op op, const npu_device_tensor_spec * dst,
|
||||
const npu_device_tensor_spec * srcs, size_t src_len) {
|
||||
bool is_element_wise_op_supported(const npu_device_tensor_op_spec * op_spec,
|
||||
const npu_device_tensor_spec * dst,
|
||||
const npu_device_tensor_spec * srcs,
|
||||
size_t src_len) {
|
||||
const auto op = op_spec->op;
|
||||
if (op != NPU_OP_ADD && op != NPU_OP_SUB && op != NPU_OP_MUL) {
|
||||
DEVICE_LOG_DEBUG("[%s]unsupported\n", hexagon::op_get_name(op));
|
||||
return false;
|
||||
|
|
@ -142,27 +155,31 @@ bool is_element_wise_op_supported(npu_device_tensor_op op, const npu_device_tens
|
|||
const auto & src0 = srcs[0];
|
||||
const auto & src1 = srcs[1];
|
||||
if (dst->type != src0.type || dst->type != src1.type) {
|
||||
DEVICE_LOG_DEBUG("[%s]src0.type and dst.type mismatch: %s vs %s\n", hexagon::op_get_name(op),
|
||||
hexagon::get_type_name(src0.type), hexagon::get_type_name(dst->type));
|
||||
DEVICE_LOG_DEBUG("[%s]src0.type and dst.type mismatch: %s vs %s\n",
|
||||
hexagon::op_get_name(op),
|
||||
hexagon::get_type_name(src0.type),
|
||||
hexagon::get_type_name(dst->type));
|
||||
return false;
|
||||
}
|
||||
|
||||
if (dst->type != NPU_DATA_TYPE_F32 && dst->type != NPU_DATA_TYPE_F16) {
|
||||
DEVICE_LOG_DEBUG("[%s]unsupported data type: %s\n", hexagon::op_get_name(op),
|
||||
hexagon::get_type_name(dst->type));
|
||||
DEVICE_LOG_DEBUG(
|
||||
"[%s]unsupported data type: %s\n", hexagon::op_get_name(op), hexagon::get_type_name(dst->type));
|
||||
return false;
|
||||
}
|
||||
|
||||
// TODO: fix FP16 add/sub
|
||||
if (dst->type == NPU_DATA_TYPE_F16 && op != NPU_OP_MUL) {
|
||||
DEVICE_LOG_DEBUG("[%s]unsupported data type: %s\n", hexagon::op_get_name(op),
|
||||
hexagon::get_type_name(dst->type));
|
||||
DEVICE_LOG_DEBUG(
|
||||
"[%s]unsupported data type: %s\n", hexagon::op_get_name(op), hexagon::get_type_name(dst->type));
|
||||
return false;
|
||||
}
|
||||
|
||||
if (src0.ne[0] != src1.ne[0]) {
|
||||
DEVICE_LOG_DEBUG("[%s]src0.ne[0] and src1.ne[0] not match: %ld vs %ld\n", hexagon::op_get_name(op),
|
||||
(long) src0.ne[0], (long) src1.ne[0]);
|
||||
DEVICE_LOG_DEBUG("[%s]src0.ne[0] and src1.ne[0] not match: %ld vs %ld\n",
|
||||
hexagon::op_get_name(op),
|
||||
(long) src0.ne[0],
|
||||
(long) src1.ne[0]);
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -174,7 +191,7 @@ bool is_element_wise_op_supported(npu_device_tensor_op op, const npu_device_tens
|
|||
return true;
|
||||
}
|
||||
|
||||
void rms_norm_vec_f32(const float * src, size_t count, float eps, float * dst) {
|
||||
void rms_norm_vec_f32(const float * src, float * dst, size_t count, float eps) {
|
||||
constexpr const size_t kElementsPerVector = hexagon::kBytesPerVector / sizeof(float);
|
||||
|
||||
HVX_Vector * src_vec_ptr = ((HVX_Vector *) src);
|
||||
|
|
@ -206,7 +223,7 @@ void rms_norm_vec_f32(const float * src, size_t count, float eps, float * dst) {
|
|||
(leftover_bytes + hexagon::unaligned_bytes(src_vec_ptr) > hexagon::kBytesPerVector) ? *src_vec_ptr : prev;
|
||||
curr = Q6_V_valign_VVR(curr, prev, (size_t) src);
|
||||
sum = Q6_Vqf32_vadd_Vqf32Vqf32(sum,
|
||||
Q6_V_valign_VVR(Q6_Vqf32_vmpy_VsfVsf(curr, curr), Q6_V_vzero(), leftover_bytes));
|
||||
Q6_V_valign_VVR(Q6_Vqf32_vmpy_VsfVsf(curr, curr), Q6_V_vzero(), leftover_bytes));
|
||||
}
|
||||
|
||||
const float mean = hexagon::vec_reduction_f32_qf32(sum) / count; // TODO: figure out how to do division in vector
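For reference, the scalar computation rms_norm_vec_f32 vectorizes is the standard RMS normalization of one row: accumulate the sum of squares, take its mean, and scale each element by 1/sqrt(mean + eps). A scalar sketch under that assumption (the HVX code above accumulates the sum in qf32 and still derives the mean with a scalar division, per the TODO; sqrtf needs <math.h>):

    // Scalar reference for rms_norm_vec_f32 over one row of `count` floats.
    void rms_norm_reference(const float * src, float * dst, size_t count, float eps) {
        float sum_sq = 0.0f;
        for (size_t i = 0; i < count; ++i) {
            sum_sq += src[i] * src[i];
        }
        const float mean  = sum_sq / count;           // same `mean` as the HVX code above
        const float scale = 1.0f / sqrtf(mean + eps);
        for (size_t i = 0; i < count; ++i) {
            dst[i] = src[i] * scale;
        }
    }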
|
||||
|
|
@ -231,7 +248,8 @@ template <auto _RowFunc> bool unary_op(hexagon::tensor * out, hexagon::compute_p
|
|||
|
||||
auto * dst_ptr = out->get_write_buffer();
|
||||
if (!dst_ptr) {
|
||||
DEVICE_LOG_ERROR("unary_op: dst_ptr is not writable, tensor: %p, type: %s\n", (void *) out,
|
||||
DEVICE_LOG_ERROR("unary_op: dst_ptr is not writable, tensor: %p, type: %s\n",
|
||||
(void *) out,
|
||||
hexagon::get_type_name(out->get_type()));
|
||||
return false;
|
||||
}
|
||||
|
|
@ -259,16 +277,21 @@ template <auto _RowFunc> bool unary_op(hexagon::tensor * out, hexagon::compute_p
|
|||
hexagon::l2fetch_row(src0_row + src0->get_nb(1), valid_row_bytes);
|
||||
}
|
||||
|
||||
_RowFunc(reinterpret_cast<const data_type *>(src0_row), static_cast<size_t>(out->get_ne(0)), param,
|
||||
reinterpret_cast<data_type *>(dst_row));
|
||||
_RowFunc(reinterpret_cast<const data_type *>(src0_row),
|
||||
reinterpret_cast<data_type *>(dst_row),
|
||||
static_cast<size_t>(out->get_ne(0)),
|
||||
param);
|
||||
}
|
||||
|
||||
out->release_write_buffer(); // mark the output tensor as modified
|
||||
return true;
|
||||
}
|
||||
|
||||
bool is_unary_op_supported(npu_device_tensor_op op, const npu_device_tensor_spec * dst,
|
||||
const npu_device_tensor_spec * srcs, size_t src_len) {
|
||||
bool is_unary_op_supported(const npu_device_tensor_op_spec * op_spec,
|
||||
const npu_device_tensor_spec * dst,
|
||||
const npu_device_tensor_spec * srcs,
|
||||
size_t src_len) {
|
||||
const auto op = op_spec->op;
|
||||
if (op != NPU_OP_RMS_NORM) {
|
||||
DEVICE_LOG_DEBUG("[%s]unsupported\n", hexagon::op_get_name(op));
|
||||
return false;
|
||||
|
|
@ -281,14 +304,16 @@ bool is_unary_op_supported(npu_device_tensor_op op, const npu_device_tensor_spec
|
|||
|
||||
const auto & src0 = srcs[0];
|
||||
if (dst->type != src0.type) {
|
||||
DEVICE_LOG_DEBUG("[%s]src0.type and dst.type mismatch: %s vs %s\n", hexagon::op_get_name(op),
|
||||
hexagon::get_type_name(src0.type), hexagon::get_type_name(dst->type));
|
||||
DEVICE_LOG_DEBUG("[%s]src0.type and dst.type mismatch: %s vs %s\n",
|
||||
hexagon::op_get_name(op),
|
||||
hexagon::get_type_name(src0.type),
|
||||
hexagon::get_type_name(dst->type));
|
||||
return false;
|
||||
}
|
||||
|
||||
if (dst->type != NPU_DATA_TYPE_F32) {
|
||||
DEVICE_LOG_DEBUG("[%s]unsupported data type: %s\n", hexagon::op_get_name(op),
|
||||
hexagon::get_type_name(dst->type));
|
||||
DEVICE_LOG_DEBUG(
|
||||
"[%s]unsupported data type: %s\n", hexagon::op_get_name(op), hexagon::get_type_name(dst->type));
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -300,6 +325,171 @@ bool is_unary_op_supported(npu_device_tensor_op op, const npu_device_tensor_spec
|
|||
return true;
|
||||
}
|
||||
|
||||
inline void glu_vec_op_f32_f32(const float * src0,
|
||||
const float * src1,
|
||||
float * dst,
|
||||
size_t count,
|
||||
hexagon::HVX_VectorPair_x4 coeff) {
|
||||
using namespace hexagon::vec;
|
||||
vec_trans_with_param_impl<float, hexagon::HVX_VectorPair_x4, hexagon::vec_swiglu_f32_f32>(
|
||||
src0, src1, dst, count, coeff);
|
||||
}
|
||||
|
||||
inline void glu_vec_op_f16_f16(const npu_device_fp16_t * src0,
|
||||
const npu_device_fp16_t * src1,
|
||||
npu_device_fp16_t * dst,
|
||||
size_t count,
|
||||
hexagon::HVX_VectorPair_x4 coeff) {
|
||||
using namespace hexagon::vec;
|
||||
vec_trans_with_param_impl<npu_device_fp16_t, hexagon::HVX_VectorPair_x4, hexagon::vec_swiglu_f16_f16>(
|
||||
src0, src1, dst, count, coeff);
|
||||
}
|
||||
|
||||
template <auto _GluRowFunc, hexagon::HVX_VectorPair_x4 (*_CoeffLoadFunc)()>
|
||||
bool glu_impl(hexagon::tensor * out, hexagon::compute_params * params) {
|
||||
using data_type = typename get_data_type<decltype(_GluRowFunc)>::type;
|
||||
static_assert(DEVICE_TENSOR_MAX_DIMS == 4, "element_wise_op requires max dims 4");
|
||||
|
||||
if (!out) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const bool has_src1 = out->get_src(1) != nullptr;
|
||||
auto * src0 = out->get_src(0);
|
||||
auto * src1 = has_src1 ? out->get_src(1) : src0;
|
||||
if (!src0 || !src1) {
|
||||
return true; // skip if no src
|
||||
}
|
||||
|
||||
const auto total_cols = has_src1 ? src0->get_ne(0) : src0->get_ne(0) / 2;
|
||||
if (out->get_ne(0) != total_cols) {
|
||||
DEVICE_LOG_ERROR("out.ne[0] (%ld) != total_cols (%d)\n", (long) out->get_ne(0), (int) total_cols);
|
||||
return false;
|
||||
}
|
||||
|
||||
auto total_rows = out->get_ne(3) * out->get_ne(2) * out->get_ne(1);
|
||||
const auto rows_per_cube = out->get_ne(2) * out->get_ne(1);
|
||||
const auto start_end = params->get_work_slice(total_rows);
|
||||
if (start_end.first >= start_end.second) {
|
||||
return true;
|
||||
}
|
||||
|
||||
uint8_t * dst_ptr = out->get_write_buffer();
|
||||
if (!dst_ptr) {
|
||||
DEVICE_LOG_ERROR("element_wise_op: dst_ptr is not writable, tensor: %p, type: %s\n",
|
||||
(void *) out,
|
||||
hexagon::get_type_name(out->get_type()));
|
||||
return false;
|
||||
}
|
||||
|
||||
const int32_t swapped = out->get_op_param<int32_t>(1);
|
||||
const uint8_t * src0_ptr = src0->get_read_buffer();
|
||||
const uint8_t * src1_ptr = has_src1 ? src1->get_read_buffer() : (src0_ptr + total_cols * sizeof(data_type));
|
||||
if (swapped) {
|
||||
std::swap(src0_ptr, src1_ptr);
|
||||
}
|
||||
|
||||
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER(out, params->get_thread_index());
|
||||
|
||||
auto coeff = _CoeffLoadFunc();
|
||||
const size_t valid_row_bytes = src0->get_ne(0) * sizeof(data_type);
|
||||
for (int64_t ir = start_end.first; ir < start_end.second; ++ir) {
|
||||
const auto i03 = ir / rows_per_cube;
|
||||
const auto i02 = ir / out->get_ne(1) - i03 * out->get_ne(2);
|
||||
const auto i01 = ir % out->get_ne(1); // TODO: should we use divide instead of mod?
|
||||
const auto i13 = i03 % src1->get_ne(3);
|
||||
const auto i12 = i02 % src1->get_ne(2);
|
||||
const auto i11 = i01 % src1->get_ne(1);
|
||||
|
||||
auto * src1_plane = src1_ptr + i13 * src1->get_nb(3) + i12 * src1->get_nb(2);
|
||||
auto * src0_row = src0_ptr + i03 * src0->get_nb(3) + i02 * src0->get_nb(2) + i01 * src0->get_nb(1);
|
||||
auto * src1_row = src1_plane + i11 * src1->get_nb(1);
|
||||
auto * dst_row = dst_ptr + i03 * out->get_nb(3) + i02 * out->get_nb(2) + i01 * out->get_nb(1);
|
||||
if (ir + 1 < start_end.second) {
|
||||
hexagon::l2fetch_row(src0_row + src0->get_nb(1), valid_row_bytes);
|
||||
hexagon::l2fetch_row(src1_row + src1->get_nb(1), valid_row_bytes);
|
||||
}
|
||||
|
||||
_GluRowFunc(reinterpret_cast<const data_type *>(src0_row),
|
||||
reinterpret_cast<const data_type *>(src1_row),
|
||||
reinterpret_cast<data_type *>(dst_row),
|
||||
static_cast<size_t>(total_cols),
|
||||
coeff);
|
||||
}
|
||||
|
||||
out->release_write_buffer(); // mark the output tensor as modified
|
||||
return true;
|
||||
}
|
||||
|
||||
template <npu_device_tensor_data_type _DataType>
|
||||
bool glu_compute(hexagon::tensor * out, hexagon::compute_params * params) {
|
||||
using namespace hexagon::vec::math;
|
||||
|
||||
if (out->get_op_param<int32_t>(0) != NPU_GLU_OP_SWIGLU) {
|
||||
DEVICE_LOG_ERROR("Invalid GLU op type: %d\n", out->get_op_param<int32_t>(0));
|
||||
return false;
|
||||
}
|
||||
|
||||
if (out->get_type() != _DataType) {
|
||||
DEVICE_LOG_ERROR("GLU op type mismatch: %s vs %s\n",
|
||||
hexagon::get_type_name(out->get_type()),
|
||||
hexagon::get_type_name(_DataType));
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr (_DataType == NPU_DATA_TYPE_F32) {
|
||||
return glu_impl<glu_vec_op_f32_f32, qhmath_load_div_sf_ltu>(out, params);
|
||||
} else if constexpr (_DataType == NPU_DATA_TYPE_F16) {
|
||||
return glu_impl<glu_vec_op_f16_f16, qhmath_load_div_hf_ltu>(out, params);
|
||||
}
|
||||
|
||||
DEVICE_LOG_ERROR("Unsupported GLU data type: %s\n", hexagon::get_type_name(out->get_type()));
|
||||
return true;
|
||||
}
|
||||
|
||||
bool is_glu_op_supported(const npu_device_tensor_op_spec * op_spec,
|
||||
const npu_device_tensor_spec * dst,
|
||||
const npu_device_tensor_spec * srcs,
|
||||
size_t src_len) {
|
||||
const auto op = op_spec->op;
|
||||
if (op != NPU_OP_GLU) {
|
||||
DEVICE_LOG_DEBUG("[%s]unsupported\n", hexagon::op_get_name(op));
|
||||
return false;
|
||||
}
|
||||
|
||||
if (op_spec->params[0] != NPU_GLU_OP_SWIGLU) {
|
||||
DEVICE_LOG_DEBUG("[%s]unsupported GLU op type: %d\n", hexagon::op_get_name(op), op_spec->params[0]);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!dst || !srcs || src_len < 1) {
|
||||
DEVICE_LOG_DEBUG("[%s]invalid dst or srcs\n", hexagon::op_get_name(op));
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto & src0 = srcs[0];
|
||||
if (dst->type != src0.type) {
|
||||
DEVICE_LOG_DEBUG("[%s]src0.type and dst.type mismatch: %s vs %s\n",
|
||||
hexagon::op_get_name(op),
|
||||
hexagon::get_type_name(src0.type),
|
||||
hexagon::get_type_name(dst->type));
|
||||
return false;
|
||||
}
|
||||
|
||||
if (dst->type != NPU_DATA_TYPE_F32 && dst->type != NPU_DATA_TYPE_F16) {
|
||||
DEVICE_LOG_DEBUG(
|
||||
"[%s]unsupported data type: %s\n", hexagon::op_get_name(op), hexagon::get_type_name(dst->type));
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!hexagon::is_same_shape(src0, *dst)) {
|
||||
DEVICE_LOG_DEBUG("[%s]src0 and dst have different shape\n", hexagon::op_get_name(op));
|
||||
return false;
|
||||
}
|
||||
|
||||
return false; // TODO: fix: for some input hexagon intrinsics will generate nan instead of inf.
|
||||
}
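For reference, the row operation that glu_impl vectorizes is SwiGLU: the gate half goes through SiLU and multiplies the other half elementwise, with the halves optionally swapped via op param 1 and taken from the two halves of src0 when src1 is absent. A scalar sketch, assuming the usual SiLU definition x / (1 + exp(-x)); the HVX path instead builds the denominator with the qhmath division tables loaded by _CoeffLoadFunc (expf needs <math.h>):

    // Scalar SwiGLU reference: dst[i] = silu(gate[i]) * up[i].
    void swiglu_reference(const float * gate, const float * up, float * dst, size_t count) {
        for (size_t i = 0; i < count; ++i) {
            const float x    = gate[i];
            const float silu = x / (1.0f + expf(-x));  // SiLU / swish activation
            dst[i] = silu * up[i];
        }
    }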
|
||||
|
||||
struct op_capabilities {
|
||||
npu_device_tensor_op op;
|
||||
hexagon::op_is_supported_func_type is_supported;
|
||||
|
|
@ -341,7 +531,7 @@ constexpr const op_capabilities kOpCapabilities[] = {
|
|||
{
|
||||
unary_op<rms_norm_vec_f32>, // NPU_DATA_TYPE_F32
|
||||
nullptr, // NPU_DATA_TYPE_F16
|
||||
}, false, // requires_thread_barrier
|
||||
}, false, // requires_thread_barrier
|
||||
},
|
||||
{
|
||||
NPU_OP_FLASH_ATTN, hexagon::is_flash_attn_supported,
|
||||
|
|
@ -357,6 +547,13 @@ constexpr const op_capabilities kOpCapabilities[] = {
|
|||
nullptr, // NPU_DATA_TYPE_F16
|
||||
}, false, // requires_thread_barrier
|
||||
},
|
||||
{
|
||||
NPU_OP_GLU, is_glu_op_supported,
|
||||
{
|
||||
glu_compute<NPU_DATA_TYPE_F32>, // NPU_DATA_TYPE_F32
|
||||
glu_compute<NPU_DATA_TYPE_F16>, // NPU_DATA_TYPE_F16
|
||||
}, false, // requires_thread_barrier
|
||||
},
|
||||
};
|
||||
|
||||
static_assert(kOpCapabilities[NPU_OP_MUL_MAT].compute_funcs[NPU_DATA_TYPE_F32] == hexagon::mul_mat_f32,
|
||||
|
|
@ -370,6 +567,7 @@ static_assert(kOpCapabilities[NPU_OP_RMS_NORM].op == NPU_OP_RMS_NORM,
|
|||
static_assert(kOpCapabilities[NPU_OP_FLASH_ATTN].op == NPU_OP_FLASH_ATTN,
|
||||
"kOpArray[NPU_OP_FLASH_ATTN].op != NPU_OP_FLASH_ATTN");
|
||||
static_assert(kOpCapabilities[NPU_OP_ROPE].op == NPU_OP_ROPE, "kOpArray[NPU_OP_ROPE].op != NPU_OP_ROPE");
|
||||
static_assert(kOpCapabilities[NPU_OP_GLU].op == NPU_OP_GLU, "kOpArray[NPU_OP_GLU].op != NPU_OP_GLU");
|
||||
|
||||
hexagon::compute_func_type get_compute_func_impl(npu_device_tensor_op op, npu_device_tensor_data_type type) {
|
||||
if (op >= NPU_OP_COUNT) {
|
||||
|
|
@ -395,17 +593,25 @@ bool requires_thread_barrier(npu_device_tensor_op op) {
|
|||
return kOpCapabilities[op].requires_thread_barrier;
|
||||
}
|
||||
|
||||
bool support_op(npu_device_tensor_op op, const npu_device_tensor_spec * dst, const npu_device_tensor_spec * srcs,
|
||||
size_t src_len) {
|
||||
auto is_supported_func = kOpCapabilities[op].is_supported;
|
||||
if (!is_supported_func || !is_supported_func(op, dst, srcs, src_len)) {
|
||||
bool support_op(const npu_device_tensor_op_spec * op_spec,
|
||||
const npu_device_tensor_spec * dst,
|
||||
const npu_device_tensor_spec * srcs,
|
||||
size_t src_len) {
|
||||
if (!op_spec) {
|
||||
DEVICE_LOG_ERROR("[hexagon-npu]invalid op_spec\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto op = op_spec->op;
|
||||
auto is_supported_func = kOpCapabilities[op].is_supported;
|
||||
if (!is_supported_func || !is_supported_func(op_spec, dst, srcs, src_len)) {
|
||||
DEVICE_LOG_DEBUG("[%s]unsupported, is_supported_func return false\n", op_get_name(op));
|
||||
return false;
|
||||
}
|
||||
|
||||
if (get_compute_func_impl(op, dst->type) == nullptr) {
|
||||
DEVICE_LOG_DEBUG("[%s]unsupported, get_compute_func failed, type: %s\n", op_get_name(op),
|
||||
get_type_name(dst->type));
|
||||
DEVICE_LOG_DEBUG(
|
||||
"[%s]unsupported, get_compute_func failed, type: %s\n", op_get_name(op), get_type_name(dst->type));
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -8,7 +8,9 @@ compute_func_type get_compute_func(tensor * dst);
|
|||
|
||||
bool requires_thread_barrier(npu_device_tensor_op op);
|
||||
|
||||
bool support_op(npu_device_tensor_op op, const npu_device_tensor_spec * dst, const npu_device_tensor_spec * srcs,
|
||||
size_t src_len);
|
||||
bool support_op(const npu_device_tensor_op_spec * op_spec,
|
||||
const npu_device_tensor_spec * dst,
|
||||
const npu_device_tensor_spec * srcs,
|
||||
size_t src_len);
|
||||
|
||||
} // namespace hexagon
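The support_op interface change threads the whole op description (opcode plus its parameters) through the capability checks instead of just the opcode, which is what lets is_glu_op_supported inspect params[0] for NPU_GLU_OP_SWIGLU. Based only on the fields dereferenced in this diff, the spec behaves roughly like the sketch below; the field layout and the params array size are assumptions, and the real definition presumably lives in hexagon_npu.h:

    // Hypothetical shape of npu_device_tensor_op_spec as consumed here.
    // Only `op` and `params[...]` are used in this diff.
    struct npu_device_tensor_op_spec_sketch {
        npu_device_tensor_op op;   // e.g. NPU_OP_GLU
        int32_t params[4];         // e.g. params[0] == NPU_GLU_OP_SWIGLU (size assumed)
    };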
|
||||
|
|
|
|||
|
|
@ -29,7 +29,9 @@ template <> struct convert_vector<npu_device_fp16_t> {
|
|||
};
|
||||
|
||||
template <auto _DotFunc, bool _ShouldCacheSrc0>
|
||||
void mul_mat_impl(hexagon::tensor * src0, hexagon::tensor * src1, hexagon::tensor * dst,
|
||||
void mul_mat_impl(hexagon::tensor * src0,
|
||||
hexagon::tensor * src1,
|
||||
hexagon::tensor * dst,
|
||||
hexagon::compute_params * params) {
|
||||
using data_type0 = typename get_data_type<decltype(_DotFunc)>::data_type0;
|
||||
using data_type1 = typename get_data_type<decltype(_DotFunc)>::data_type1;
|
||||
|
|
@ -62,8 +64,12 @@ void mul_mat_impl(hexagon::tensor * src0, hexagon::tensor * src1, hexagon::tenso
|
|||
DEVICE_LOG_DEBUG(
|
||||
"mul_mat_impl: no work to do, start_end_plane: (%ld, %ld), start_end_row: (%ld, %ld), "
|
||||
"start_end_element: (%ld, %ld)\n",
|
||||
start_end_plane.first, start_end_plane.second, start_end_row.first, start_end_row.second,
|
||||
start_end_element.first, start_end_element.second);
|
||||
start_end_plane.first,
|
||||
start_end_plane.second,
|
||||
start_end_row.first,
|
||||
start_end_row.second,
|
||||
start_end_element.first,
|
||||
start_end_element.second);
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
@ -81,7 +87,9 @@ void mul_mat_impl(hexagon::tensor * src0, hexagon::tensor * src1, hexagon::tenso
|
|||
DEVICE_LOG_ERROR(
|
||||
"mul_mat_impl: failed to get VTCM cache for src0, size: %zu, src0_plane_slice_row_count: %zu, "
|
||||
"src0_actual_row_size: %zu, will fallback to mem cache\n",
|
||||
src0_plane_cache_size, src0_plane_slice_row_count, src0_actual_row_size);
|
||||
src0_plane_cache_size,
|
||||
src0_plane_slice_row_count,
|
||||
src0_actual_row_size);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
|
@ -89,7 +97,10 @@ void mul_mat_impl(hexagon::tensor * src0, hexagon::tensor * src1, hexagon::tenso
|
|||
DEVICE_LOG_DEBUG(
|
||||
"mul_mat_impl src0_actual_row_size: %zu, src0_plane_slice_row_count: %zu, is_quantized: %d, vtcm_mem: "
|
||||
"%p(%zu)\n",
|
||||
src0_actual_row_size, src0_plane_slice_row_count, _ShouldCacheSrc0, (void *) src0_plane_cache_ptr,
|
||||
src0_actual_row_size,
|
||||
src0_plane_slice_row_count,
|
||||
_ShouldCacheSrc0,
|
||||
(void *) src0_plane_cache_ptr,
|
||||
src0_plane_cache_size);
|
||||
|
||||
const size_t valid_row0_bytes = src0->get_ne(0) * sizeof(data_type0);
|
||||
|
|
@ -99,7 +110,8 @@ void mul_mat_impl(hexagon::tensor * src0, hexagon::tensor * src1, hexagon::tenso
|
|||
|
||||
uint8_t * dst_ptr = dst->get_write_buffer();
|
||||
if (!dst_ptr) {
|
||||
DEVICE_LOG_ERROR("mul_mat_impl: dst_ptr is not writable, tensor: %p, type: %s\n", (void *) dst,
|
||||
DEVICE_LOG_ERROR("mul_mat_impl: dst_ptr is not writable, tensor: %p, type: %s\n",
|
||||
(void *) dst,
|
||||
hexagon::get_type_name(dst->get_type()));
|
||||
return;
|
||||
}
|
||||
|
|
@ -114,16 +126,17 @@ void mul_mat_impl(hexagon::tensor * src0, hexagon::tensor * src1, hexagon::tenso
|
|||
auto * dst_plane = dst_ptr + i3 * dst->get_nb(3) + i2 * dst->get_nb(2);
|
||||
for (int64_t col_idx = start_end_element.first; col_idx < start_end_element.second;
|
||||
col_idx += src0_plane_slice_row_count) {
|
||||
const uint8_t * src0_plane =
|
||||
src0_ptr + i3 / r03 * src0->get_nb(3) + i2 / r02 * src0->get_nb(2) + col_idx * src0->get_nb(1);
|
||||
hexagon::l2fetch_row(src0_plane, src0->get_nb(1));
|
||||
|
||||
const int64_t actual_row_count =
|
||||
std::min<int64_t>(src0_plane_slice_row_count,
|
||||
start_end_element.second - col_idx); // number of rows in this slice
|
||||
const uint8_t * src0_plane =
|
||||
src0_ptr + i3 / r03 * src0->get_nb(3) + i2 / r02 * src0->get_nb(2) + col_idx * src0->get_nb(1);
|
||||
if constexpr (_ShouldCacheSrc0) {
|
||||
if (last_cached_plane_ptr != src0_plane) {
|
||||
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 0, dequant);
|
||||
|
||||
hexagon::l2fetch_row(src0_plane, src0->get_nb(1));
|
||||
for (int64_t ir = 0; ir < actual_row_count; ir++) {
|
||||
auto * src0_row = src0_plane + ir * src0->get_nb(1);
|
||||
if (ir + 1 < actual_row_count) {
|
||||
|
|
@ -131,7 +144,8 @@ void mul_mat_impl(hexagon::tensor * src0, hexagon::tensor * src1, hexagon::tenso
|
|||
}
|
||||
|
||||
auto * cached_row_ptr = src0_plane_cache_ptr + ir * src0_actual_row_size;
|
||||
dequantize_row_func(src0_row, reinterpret_cast<hexagon::dequant_target_type *>(cached_row_ptr),
|
||||
dequantize_row_func(src0_row,
|
||||
reinterpret_cast<hexagon::dequant_output_type *>(cached_row_ptr),
|
||||
src0->get_ne(0));
|
||||
}
|
||||
|
||||
|
|
@ -158,12 +172,8 @@ void mul_mat_impl(hexagon::tensor * src0, hexagon::tensor * src1, hexagon::tenso
|
|||
|
||||
// TODO: figure out how to handle an entire row
|
||||
auto res0 = _DotFunc(reinterpret_cast<const data_type0 *>(src0_row),
|
||||
reinterpret_cast<const data_type1 *>(src1_row), (size_t) src0->get_ne(0));
|
||||
|
||||
{
|
||||
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 2, store);
|
||||
dst_row[i0] = convert_vector<data_type1>::convert(res0);
|
||||
}
|
||||
reinterpret_cast<const data_type1 *>(src1_row),
|
||||
(size_t) src0->get_ne(0));
|
||||
|
||||
if (should_fetch_src0_row && i0 + 2 < actual_row_count) {
|
||||
hexagon::l2fetch_row(src0_row + src0_actual_row_size + src0_actual_row_size, valid_row0_bytes);
|
||||
|
|
@ -171,10 +181,12 @@ void mul_mat_impl(hexagon::tensor * src0, hexagon::tensor * src1, hexagon::tenso
|
|||
|
||||
// TODO: figure out how to handle an entire row
|
||||
auto res1 = _DotFunc(reinterpret_cast<const data_type0 *>(src0_row + src0_actual_row_size),
|
||||
reinterpret_cast<const data_type1 *>(src1_row), (size_t) src0->get_ne(0));
|
||||
reinterpret_cast<const data_type1 *>(src1_row),
|
||||
(size_t) src0->get_ne(0));
|
||||
|
||||
{
|
||||
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 2, store);
|
||||
dst_row[i0] = convert_vector<data_type1>::convert(res0);
|
||||
dst_row[i0 + 1] = convert_vector<data_type1>::convert(res1);
|
||||
}
|
||||
}
|
||||
|
|
@ -186,7 +198,8 @@ void mul_mat_impl(hexagon::tensor * src0, hexagon::tensor * src1, hexagon::tenso
|
|||
if (i0 < actual_row_count) {
|
||||
auto * src0_row = src0_plane + i0 * src0_actual_row_size;
|
||||
auto res = _DotFunc(reinterpret_cast<const data_type0 *>(src0_row),
|
||||
reinterpret_cast<const data_type1 *>(src1_row), (size_t) src0->get_ne(0));
|
||||
reinterpret_cast<const data_type1 *>(src1_row),
|
||||
(size_t) src0->get_ne(0));
|
||||
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 2, store);
|
||||
dst_row[i0] = convert_vector<data_type1>::convert(res);
|
||||
}
|
||||
|
|
@ -197,19 +210,194 @@ void mul_mat_impl(hexagon::tensor * src0, hexagon::tensor * src1, hexagon::tenso
|
|||
dst->release_write_buffer(); // mark the output tensor as modified
|
||||
}
|
||||
|
||||
bool is_row_size_cacheable(const npu_device_tensor_spec & src) {
|
||||
const auto & type_traits = hexagon::get_type_traits(src.type);
|
||||
if (type_traits.to_float == nullptr) {
|
||||
DEVICE_LOG_DEBUG("[MUL_MAT]src.type(%s) cannot be cached, to_float is null\n",
|
||||
hexagon::get_type_name(src.type));
|
||||
template <auto _DotFunc, bool _ShouldCacheSrc0>
|
||||
void mul_mat_gemv_impl(hexagon::tensor * src0,
|
||||
hexagon::tensor * src1,
|
||||
hexagon::tensor * dst,
|
||||
hexagon::compute_params * params) {
|
||||
using data_type0 = typename get_data_type<decltype(_DotFunc)>::data_type0;
|
||||
using data_type1 = typename get_data_type<decltype(_DotFunc)>::data_type1;
|
||||
|
||||
const auto src0_actual_row_size = hexagon::get_dequantized_row_size(src0);
|
||||
auto * dequantize_row_func = hexagon::get_type_traits(src0->get_type()).to_float;
|
||||
if (_ShouldCacheSrc0 && dequantize_row_func == nullptr) {
|
||||
DEVICE_LOG_ERROR("Unsupported quantized src0 type: %d, dequantize_row_func is null\n", src0->get_type());
|
||||
return;
|
||||
}
|
||||
|
||||
auto start_end_element = std::pair<int64_t, int64_t>{ 0, dst->get_ne(0) };
|
||||
if (dst->get_ne(0) >= params->get_thread_count()) {
|
||||
start_end_element = params->get_work_slice(dst->get_ne(0));
|
||||
} else {
|
||||
DEVICE_LOG_ERROR("Unsupported src1 tensor shape for gemv: %s, ne: %ldx%ldx%ldx%ld\n",
|
||||
hexagon::get_type_name(src1->get_type()),
|
||||
src1->get_ne(0),
|
||||
src1->get_ne(1),
|
||||
src1->get_ne(2),
|
||||
src1->get_ne(3));
|
||||
return;
|
||||
}
|
||||
|
||||
if (start_end_element.second <= start_end_element.first) {
|
||||
DEVICE_LOG_DEBUG(
|
||||
"mul_mat_impl: no work to do, start_end_plane: [0, 1), start_end_row: [0, 1), "
|
||||
"start_end_element: [%ld, %ld)\n",
|
||||
start_end_element.first,
|
||||
start_end_element.second);
|
||||
return;
|
||||
}
|
||||
|
||||
// cache the src0 plane in VTCM
|
||||
size_t src0_plane_slice_row_count = start_end_element.second - start_end_element.first;
|
||||
size_t src0_plane_cache_size = 0;
|
||||
uint8_t * src0_plane_cache_ptr = nullptr;
|
||||
const auto src1_actual_row_size = hexagon::get_aligned_size(src1->get_nb(1));
|
||||
uint8_t * src1_row_cache_ptr = nullptr;
|
||||
if constexpr (_ShouldCacheSrc0) {
|
||||
src0_plane_slice_row_count = std::min(
|
||||
(params->get_vtcm_quota_size() - src1_actual_row_size) / src0_actual_row_size, src0_plane_slice_row_count);
|
||||
src0_plane_cache_size = src0_actual_row_size * src0_plane_slice_row_count;
|
||||
src0_plane_cache_ptr = params->get_vtcm_cache(src0_plane_cache_size + src1_actual_row_size);
|
||||
if (src0_plane_cache_ptr == nullptr) {
|
||||
DEVICE_LOG_ERROR(
|
||||
"mul_mat_impl: failed to get VTCM cache for src0, size: %zu, src0_plane_slice_row_count: %zu, "
|
||||
"src0_actual_row_size: %zu, will fallback to mem cache\n",
|
||||
src0_plane_cache_size,
|
||||
src0_plane_slice_row_count,
|
||||
src0_actual_row_size);
|
||||
return;
|
||||
}
|
||||
|
||||
src1_row_cache_ptr = src0_plane_cache_ptr + src0_plane_cache_size;
|
||||
} else {
|
||||
src1_row_cache_ptr = params->get_vtcm_cache(src1_actual_row_size);
|
||||
if (src1_row_cache_ptr == nullptr) {
|
||||
DEVICE_LOG_ERROR("mul_mat_impl: failed to get VTCM cache for src1, size: %zu\n", src1_actual_row_size);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
DEVICE_LOG_DEBUG(
|
||||
"mul_mat_impl src0_actual_row_size: %zu, src0_plane_slice_row_count: %zu, is_quantized: %d, vtcm_mem: "
|
||||
"%p(%zu)\n",
|
||||
src0_actual_row_size,
|
||||
src0_plane_slice_row_count,
|
||||
_ShouldCacheSrc0,
|
||||
(void *) src0_plane_cache_ptr,
|
||||
src0_plane_cache_size);
|
||||
|
||||
const size_t valid_row0_bytes = src0->get_ne(0) * sizeof(data_type0);
|
||||
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_WITH_MULTI_SUB_PROC(dst, params->get_thread_index(), mul_mat);
|
||||
|
||||
uint8_t * dst_ptr = dst->get_write_buffer();
|
||||
if (!dst_ptr) {
|
||||
DEVICE_LOG_ERROR("mul_mat_impl: dst_ptr is not writable, tensor: %p, type: %s\n",
|
||||
(void *) dst,
|
||||
hexagon::get_type_name(dst->get_type()));
|
||||
return;
|
||||
}
|
||||
|
||||
constexpr bool should_fetch_src0_row = !_ShouldCacheSrc0;
|
||||
const uint8_t * src0_ptr = src0->get_read_buffer();
|
||||
const uint8_t * src1_ptr = src1->get_read_buffer();
|
||||
|
||||
{
|
||||
memcpy(src1_row_cache_ptr, src1_ptr, src1->get_ne(0) * sizeof(data_type1));
|
||||
src1_ptr = src1_row_cache_ptr;
|
||||
}
|
||||
|
||||
{
|
||||
for (int64_t col_idx = start_end_element.first; col_idx < start_end_element.second;
|
||||
col_idx += src0_plane_slice_row_count) {
|
||||
const uint8_t * src0_plane = src0_ptr + col_idx * src0->get_nb(1);
|
||||
hexagon::l2fetch_row(src0_plane, src0->get_nb(1));
|
||||
|
||||
const int64_t actual_row_count =
|
||||
std::min<int64_t>(src0_plane_slice_row_count,
|
||||
start_end_element.second - col_idx); // number of rows in this slice
|
||||
if constexpr (_ShouldCacheSrc0) {
|
||||
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 0, dequant);
|
||||
|
||||
for (int64_t ir = 0; ir < actual_row_count; ir++) {
|
||||
auto * src0_row = src0_plane + ir * src0->get_nb(1);
|
||||
if (ir + 1 < actual_row_count) {
|
||||
hexagon::l2fetch_row(src0_row + src0->get_nb(1), src0->get_nb(1));
|
||||
}
|
||||
|
||||
auto * cached_row_ptr = src0_plane_cache_ptr + ir * src0_actual_row_size;
|
||||
dequantize_row_func(
|
||||
src0_row, reinterpret_cast<hexagon::dequant_output_type *>(cached_row_ptr), src0->get_ne(0));
|
||||
}
|
||||
|
||||
src0_plane = src0_plane_cache_ptr;
|
||||
}
|
||||
|
||||
{
|
||||
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 1, vec_dot);
|
||||
auto * dst_row = reinterpret_cast<float *>(dst_ptr) + col_idx;
|
||||
int64_t i0 = 0;
|
||||
for (; i0 + 1 < actual_row_count; i0 += 2) {
|
||||
auto * src0_row = src0_plane + i0 * src0_actual_row_size;
|
||||
if constexpr (should_fetch_src0_row) {
|
||||
hexagon::l2fetch_row(src0_row + src0_actual_row_size, valid_row0_bytes);
|
||||
}
|
||||
|
||||
// TODO: figure out how to handle an entire row
|
||||
auto res0 = _DotFunc(reinterpret_cast<const data_type0 *>(src0_row),
|
||||
reinterpret_cast<const data_type1 *>(src1_ptr),
|
||||
(size_t) src0->get_ne(0));
|
||||
|
||||
if (should_fetch_src0_row && i0 + 2 < actual_row_count) {
|
||||
hexagon::l2fetch_row(src0_row + src0_actual_row_size + src0_actual_row_size, valid_row0_bytes);
|
||||
}
|
||||
|
||||
// TODO: figure out how to handle an entire row
|
||||
auto res1 = _DotFunc(reinterpret_cast<const data_type0 *>(src0_row + src0_actual_row_size),
|
||||
reinterpret_cast<const data_type1 *>(src1_ptr),
|
||||
(size_t) src0->get_ne(0));
|
||||
|
||||
{
|
||||
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 2, store);
|
||||
dst_row[i0] = convert_vector<data_type1>::convert(res0);
|
||||
dst_row[i0 + 1] = convert_vector<data_type1>::convert(res1);
|
||||
}
|
||||
}
|
||||
|
||||
if (i0 < actual_row_count) {
|
||||
auto * src0_row = src0_plane + i0 * src0_actual_row_size;
|
||||
auto res = _DotFunc(reinterpret_cast<const data_type0 *>(src0_row),
|
||||
reinterpret_cast<const data_type1 *>(src1_ptr),
|
||||
(size_t) src0->get_ne(0));
|
||||
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 2, store);
|
||||
dst_row[i0] = convert_vector<data_type1>::convert(res);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
dst->release_write_buffer(); // mark the output tensor as modified
|
||||
}
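Unlike mul_mat_impl, which slices work over planes and src1 rows, mul_mat_gemv_impl has exactly one src1 row, so it copies that row into VTCM once (the memcpy above) and splits the dst columns, i.e. src0 rows, across threads. When src0 is also cached, one VTCM allocation holds a slice of dequantized src0 rows followed by the src1 row. A sketch of that sizing (the function name is illustrative and simply restates the std::min clamp on src0_plane_slice_row_count; std::min needs <algorithm>):

    // VTCM layout in the cached GEMV path:
    // [ src0 slice: rows_per_slice * src0_row_bytes | src1 row: src1_row_bytes ]
    size_t plan_gemv_vtcm(size_t vtcm_quota, size_t src0_row_bytes, size_t src1_row_bytes,
                          size_t wanted_rows, size_t * rows_per_slice) {
        *rows_per_slice = std::min((vtcm_quota - src1_row_bytes) / src0_row_bytes, wanted_rows);
        return *rows_per_slice * src0_row_bytes + src1_row_bytes;  // total bytes requested from VTCM
    }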
|
||||
|
||||
bool is_src_cacheable(const npu_device_tensor_spec & src0, const npu_device_tensor_spec & src1) {
|
||||
const auto & src0_type_traits = hexagon::get_type_traits(src0.type);
|
||||
if (src0_type_traits.to_float == nullptr) {
|
||||
DEVICE_LOG_DEBUG("[MUL_MAT]src0.type(%s) cannot be cached, to_float is null\n",
|
||||
hexagon::get_type_name(src0.type));
|
||||
return false;
|
||||
}
|
||||
|
||||
const size_t type_size = type_traits.is_quantized ? sizeof(hexagon::dequant_target_type) : type_traits.type_size;
|
||||
const auto vtcm_thread_quota_size = hexagon::default_thread_pool::get_per_thread_vtcm_quota();
|
||||
if (src.ne[0] * type_size > vtcm_thread_quota_size) {
|
||||
DEVICE_LOG_DEBUG("[MUL_MAT]src.type(%s) ne[0] is too large: %ld, vtcm_thread_quota_size: %zu\n",
|
||||
hexagon::get_type_name(src.type), (long) src.ne[0], vtcm_thread_quota_size);
|
||||
const size_t src0_type_size =
|
||||
src0_type_traits.is_quantized ? sizeof(hexagon::dequant_output_type) : src0_type_traits.type_size;
|
||||
const auto & src1_type_traits = hexagon::get_type_traits(src1.type);
|
||||
const bool is_gemv = src1.ne[1] == 1 && src1.ne[2] == 1 && src1.ne[3] == 1;
|
||||
size_t min_cache_size = is_gemv ? (src1.ne[0] * src1_type_traits.type_size) : 0;
|
||||
min_cache_size += src0.ne[0] * src0_type_size;
|
||||
if (min_cache_size > vtcm_thread_quota_size) {
|
||||
DEVICE_LOG_DEBUG("[MUL_MAT]src0.type(%s) min_cache_size is too large: %ld, vtcm_thread_quota_size: %zu\n",
|
||||
hexagon::get_type_name(src0.type),
|
||||
(long) min_cache_size,
|
||||
vtcm_thread_quota_size);
|
||||
return false;
|
||||
}
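A worked instance of the check above, assuming dequant_output_type is float (4 bytes per element) and an F32 src1 row, both with ne[0] == 4096 (these sizes are illustrative, not taken from the diff):

    constexpr size_t ne0             = 4096;
    constexpr size_t src0_part_bytes = ne0 * sizeof(float);               // 16384, one dequantized src0 row
    constexpr size_t src1_part_bytes = ne0 * sizeof(float);               // 16384, src1 row (GEMV case only)
    constexpr size_t min_cache_bytes = src0_part_bytes + src1_part_bytes; // 32768, must fit the per-thread VTCM quota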
|
||||
|
||||
|
|
@ -219,36 +407,41 @@ bool is_row_size_cacheable(const npu_device_tensor_spec & src) {
|
|||
bool is_quantized_mul_mat_supported(const npu_device_tensor_spec & src0, const npu_device_tensor_spec & src1) {
|
||||
if (src1.type != NPU_DATA_TYPE_F32 && src1.type != NPU_DATA_TYPE_F16) {
|
||||
DEVICE_LOG_DEBUG("[MUL_MAT]src0.type(%s) and src1.type(%s) mismatch and src1 is not F32\n",
|
||||
hexagon::get_type_name(src0.type), hexagon::get_type_name(src1.type));
|
||||
hexagon::get_type_name(src0.type),
|
||||
hexagon::get_type_name(src1.type));
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto type_traits = hexagon::get_type_traits(src0.type);
|
||||
if (!type_traits.is_quantized || type_traits.to_float == nullptr) {
|
||||
DEVICE_LOG_DEBUG("[MUL_MAT]src0.type(%s) and src1.type(%s) mismatch and src0 is not quantized\n",
|
||||
hexagon::get_type_name(src0.type), hexagon::get_type_name(src1.type));
|
||||
hexagon::get_type_name(src0.type),
|
||||
hexagon::get_type_name(src1.type));
|
||||
return false;
|
||||
}
|
||||
|
||||
if (src0.ne[0] % type_traits.blck_size) {
|
||||
DEVICE_LOG_DEBUG("[MUL_MAT]src0.type(%s) ne[0] is not aligned: %ld\n", hexagon::get_type_name(src0.type),
|
||||
(long) src0.ne[0]);
|
||||
DEVICE_LOG_DEBUG(
|
||||
"[MUL_MAT]src0.type(%s) ne[0] is not aligned: %ld\n", hexagon::get_type_name(src0.type), (long) src0.ne[0]);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!is_row_size_cacheable(src0)) {
|
||||
if (!is_src_cacheable(src0, src1)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
DEVICE_LOG_DEBUG("[MUL_MAT]supported quantized src0.type(%s) and src1.type(%s)\n",
|
||||
hexagon::get_type_name(src0.type), hexagon::get_type_name(src1.type));
|
||||
hexagon::get_type_name(src0.type),
|
||||
hexagon::get_type_name(src1.type));
|
||||
return true;
|
||||
}
|
||||
|
||||
bool is_mul_mat_f16_f32_src_tensors_aligned(hexagon::tensor * src0, hexagon::tensor * src1, bool is_src0_quantized) {
|
||||
const auto * src1_ptr = src1->get_read_buffer_as<float>();
|
||||
const auto * src0_ptr =
|
||||
is_src0_quantized ? nullptr : src0->get_read_buffer_as<npu_device_fp16_t>(); // skip src0 for quantized tensors
|
||||
bool is_mul_mat_f16_f32_src_tensors_aligned(hexagon::tensor * src0,
|
||||
hexagon::tensor * src1,
|
||||
bool is_src0_cached,
|
||||
bool is_src1_cached) {
|
||||
const auto * src1_ptr = is_src1_cached ? nullptr : src1->get_read_buffer_as<float>();
|
||||
const auto * src0_ptr = is_src0_cached ? nullptr : src0->get_read_buffer_as<npu_device_fp16_t>();
|
||||
|
||||
if (!hexagon::is_f16_f32_dot_product_aligned(src0_ptr, src1_ptr, src0->get_ne(0))) {
|
||||
DEVICE_LOG_DEBUG("[MUL_MAT]src_tensors_unaligned: ne[0]: %ld\n", (long) src0->get_ne(0));
|
||||
|
|
@ -305,35 +498,49 @@ bool is_mul_mat_f32_f32_src_tensors_aligned(hexagon::tensor * src0, hexagon::ten
|
|||
return true;
|
||||
}
|
||||
|
||||
typedef void (*mul_mat_func_type)(hexagon::tensor * src0, hexagon::tensor * src1, hexagon::tensor * dst,
typedef void (*mul_mat_func_type)(hexagon::tensor * src0,
hexagon::tensor * src1,
hexagon::tensor * dst,
hexagon::compute_params * params);

constexpr const mul_mat_func_type kMulMatF32F32CachedFuncs[2] = {
constexpr const size_t kMulMatGemvBaseIndex = 2;

constexpr const mul_mat_func_type kMulMatF32F32CachedFuncs[4] = {
// quantized and non-quantized
mul_mat_impl<hexagon::vec_dot_product_vqf32_f32_f32, true>, // F32 * F32 quantized unaligned
mul_mat_impl<hexagon::vec_dot_product_aligned_vqf32_f32_f32, true>, // F32 * F32 quantized aligned
mul_mat_impl<hexagon::vec_dot_product_vqf32_f32_f32, true>, // F32 * F32 quantized unaligned
mul_mat_impl<hexagon::vec_dot_product_aligned_vqf32_f32_f32, true>, // F32 * F32 quantized aligned
mul_mat_gemv_impl<hexagon::vec_dot_product_aligned_vqf32_f32_f32, true>, // F32 * F32 quantized gemv
mul_mat_gemv_impl<hexagon::vec_dot_product_aligned_vqf32_f32_f32, true>, // F32 * F32 quantized gemv
};

constexpr const mul_mat_func_type kMulMatF32F32Funcs[2] = {
constexpr const mul_mat_func_type kMulMatF32F32Funcs[4] = {
// quantized and non-quantized
mul_mat_impl<hexagon::vec_dot_product_vqf32_f32_f32, false>, // F32 * F32 quantized unaligned
mul_mat_impl<hexagon::vec_dot_product_aligned_vqf32_f32_f32, false>, // F32 * F32 quantized aligned
mul_mat_impl<hexagon::vec_dot_product_vqf32_f32_f32, false>, // F32 * F32 quantized unaligned
mul_mat_impl<hexagon::vec_dot_product_aligned_vqf32_f32_f32, false>, // F32 * F32 quantized aligned
mul_mat_gemv_impl<hexagon::vec_dot_product_vqf32_f32_f32, false>, // F32 * F32 quantized gemv
mul_mat_gemv_impl<hexagon::vec_dot_product_aligned_vqf32_f32_f32, false>, // F32 * F32 quantized gemv
};

constexpr const mul_mat_func_type kMulMatF16CachedFuncs[2] = {
mul_mat_impl<hexagon::vec_dot_product_vqf16_f16_f16, true>, // F16 * F16 quantized unaligned
mul_mat_impl<hexagon::vec_dot_product_aligned_vqf16_f16_f16, true>, // F16 * F16 quantized aligned
};

constexpr const mul_mat_func_type kMulMatF16Funcs[2] = {
mul_mat_impl<hexagon::vec_dot_product_vqf16_f16_f16, false>, // F16 * F16 quantized unaligned
mul_mat_impl<hexagon::vec_dot_product_aligned_vqf16_f16_f16, false>, // F16 * F16 quantized aligned
};

constexpr const mul_mat_func_type kMulMatF16F32Funcs[2] = {
constexpr const mul_mat_func_type kMulMatF16F32Funcs[4] = {
// quantized and non-quantized
mul_mat_impl<hexagon::vec_dot_product_vqf32_f16_f32, true>, // F32 * F32 quantized unaligned
mul_mat_impl<hexagon::vec_dot_product_aligned_vqf32_f16_f32, true>, // F32 * F32 quantized aligned
mul_mat_impl<hexagon::vec_dot_product_vqf32_f16_f32, true>, // F32 * F32 quantized unaligned
mul_mat_impl<hexagon::vec_dot_product_aligned_vqf32_f16_f32, true>, // F32 * F32 quantized aligned
mul_mat_gemv_impl<hexagon::vec_dot_product_vqf32_f16_f32, true>, // F32 * F32 quantized unaligned
mul_mat_gemv_impl<hexagon::vec_dot_product_aligned_vqf32_f16_f32, true>, // F32 * F32 quantized aligned
};

constexpr const mul_mat_func_type kMulMatF16CachedFuncs[4] = {
mul_mat_impl<hexagon::vec_dot_product_vqf16_f16_f16, true>, // F16 * F16 quantized unaligned
mul_mat_impl<hexagon::vec_dot_product_aligned_vqf16_f16_f16, true>, // F16 * F16 quantized aligned
mul_mat_gemv_impl<hexagon::vec_dot_product_aligned_vqf16_f16_f16, true>, // F16 * F16 quantized gemv
mul_mat_gemv_impl<hexagon::vec_dot_product_aligned_vqf16_f16_f16, true>, // F16 * F16 quantized gemv
};

constexpr const mul_mat_func_type kMulMatF16Funcs[4] = {
mul_mat_impl<hexagon::vec_dot_product_vqf16_f16_f16, false>, // F16 * F16 quantized unaligned
mul_mat_impl<hexagon::vec_dot_product_aligned_vqf16_f16_f16, false>, // F16 * F16 quantized aligned
mul_mat_gemv_impl<hexagon::vec_dot_product_vqf16_f16_f16, false>, // F16 * F16 quantized gemv
mul_mat_gemv_impl<hexagon::vec_dot_product_aligned_vqf16_f16_f16, false>, // F16 * F16 quantized gemv
};

} // namespace
@@ -342,9 +549,9 @@ namespace hexagon {

bool mul_mat_f32(hexagon::tensor * out, compute_params * params) {
static_assert(DEVICE_TENSOR_MAX_DIMS == 4, "mul_mat_f32 requires max dims 4");
static_assert(std::is_same<hexagon::dequant_target_type, float>::value ||
std::is_same<hexagon::dequant_target_type, npu_device_fp16_t>::value,
"dequant_target_type must be float or npu_device_fp16_t");
static_assert(std::is_same<hexagon::dequant_output_type, float>::value ||
std::is_same<hexagon::dequant_output_type, npu_device_fp16_t>::value,
"dequant_output_type must be float or npu_device_fp16_t");

if (!out) {
return false;

@@ -358,24 +565,28 @@ bool mul_mat_f32(hexagon::tensor * out, compute_params * params) {

const bool is_src0_quantized = is_quantized_type(src0->get_type());
const bool should_cache_src0 = is_src0_quantized || src1->get_ne(1) > 1;
const bool is_gemv = src1->get_ne(1) == 1 && src1->get_ne(2) == 1 && src1->get_ne(3) == 1;
const auto base_index = is_gemv ? kMulMatGemvBaseIndex : 0;
switch (src1->get_type()) {
case NPU_DATA_TYPE_F32:
if (is_src0_quantized || src0->get_type() == NPU_DATA_TYPE_F16) {
kMulMatF16F32Funcs[is_mul_mat_f16_f32_src_tensors_aligned(src0, src1, is_src0_quantized)](src0, src1,
out, params);
kMulMatF16F32Funcs[is_mul_mat_f16_f32_src_tensors_aligned(src0, src1, is_src0_quantized, is_gemv) +
base_index](src0, src1, out, params);
} else if (should_cache_src0) {
kMulMatF32F32CachedFuncs[is_mul_mat_f32_f32_src_tensors_aligned(src0, src1)](src0, src1, out, params);
kMulMatF32F32CachedFuncs[is_mul_mat_f32_f32_src_tensors_aligned(src0, src1) + base_index](
src0, src1, out, params);
} else {
kMulMatF32F32Funcs[is_mul_mat_f32_f32_src_tensors_aligned(src0, src1)](src0, src1, out, params);
kMulMatF32F32Funcs[is_mul_mat_f32_f32_src_tensors_aligned(src0, src1) + base_index](
src0, src1, out, params);
}
return true;
case NPU_DATA_TYPE_F16:
if (should_cache_src0) {
kMulMatF16CachedFuncs[is_mul_mat_f16_f16_src_tensors_aligned(src0, src1, is_src0_quantized)](
src0, src1, out, params);
kMulMatF16CachedFuncs[is_mul_mat_f16_f16_src_tensors_aligned(src0, src1, is_src0_quantized) +
base_index](src0, src1, out, params);
} else {
kMulMatF16Funcs[is_mul_mat_f16_f16_src_tensors_aligned(src0, src1, is_src0_quantized)](src0, src1, out,
params);
kMulMatF16Funcs[is_mul_mat_f16_f16_src_tensors_aligned(src0, src1, is_src0_quantized) + base_index](
src0, src1, out, params);
}
return true;
default:

@@ -386,8 +597,11 @@ bool mul_mat_f32(hexagon::tensor * out, compute_params * params) {

return false;
}
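// Editor's sketch (not part of the diff): how the 4-entry tables used by the dispatch above
// are indexed. Entries 0/1 are the unaligned/aligned GEMM paths and entries 2/3 are the GEMV
// paths, so the lookup adds the alignment flag (0 or 1) to kMulMatGemvBaseIndex (2) when src1
// collapses to a single row. The helper name below is illustrative only.
inline size_t select_mul_mat_func_index(bool is_aligned, bool is_gemv) {
    return (is_aligned ? 1u : 0u) + (is_gemv ? kMulMatGemvBaseIndex : 0u);
}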
bool is_mul_mat_supported(npu_device_tensor_op op, const npu_device_tensor_spec * dst,
const npu_device_tensor_spec * srcs, size_t src_len) {
bool is_mul_mat_supported(const npu_device_tensor_op_spec * op_spec,
const npu_device_tensor_spec * dst,
const npu_device_tensor_spec * srcs,
size_t src_len) {
const auto op = op_spec->op;
if (op != NPU_OP_MUL_MAT) {
DEVICE_LOG_DEBUG("op is not MUL_MAT: %d\n", op);
return false;
@ -408,7 +622,9 @@ bool is_mul_mat_supported(npu_device_tensor_op op, const npu_device_tensor_spec
|
|||
if (src0.type != src1.type) {
|
||||
if (src1.type == NPU_DATA_TYPE_F32 && src0.type == NPU_DATA_TYPE_F16) {
|
||||
DEVICE_LOG_DEBUG("[%s]src0.type(%s) and src1.type(%s) mismatch, but src0 is F16 and src1 is F32\n",
|
||||
op_get_name(op), get_type_name(src0.type), get_type_name(src1.type));
|
||||
op_get_name(op),
|
||||
get_type_name(src0.type),
|
||||
get_type_name(src1.type));
|
||||
return true; // F16 * F32 is supported
|
||||
}
|
||||
|
||||
|
|
@ -418,26 +634,40 @@ bool is_mul_mat_supported(npu_device_tensor_op op, const npu_device_tensor_spec
|
|||
}
|
||||
#else
|
||||
DEVICE_LOG_DEBUG("[%s]src0.type(%s) and src1.type(%s) mismatch and quantized tensors are not supported\n",
|
||||
op_get_name(op), get_type_name(src0.type), get_type_name(src1.type));
|
||||
op_get_name(op),
|
||||
get_type_name(src0.type),
|
||||
get_type_name(src1.type));
|
||||
return false;
|
||||
#endif
|
||||
}
|
||||
|
||||
if (src0.ne[0] != src1.ne[0] || src0.ne[1] != dst->ne[0]) {
|
||||
DEVICE_LOG_DEBUG("[%s]src0 and src1 cannot multiply: %ldx%ld vs %ldx%ld\n", op_get_name(op), (long) src0.ne[0],
|
||||
(long) src0.ne[1], (long) src1.ne[0], (long) src1.ne[1]);
|
||||
DEVICE_LOG_DEBUG("[%s]src0 and src1 cannot multiply: %ldx%ld vs %ldx%ld\n",
|
||||
op_get_name(op),
|
||||
(long) src0.ne[0],
|
||||
(long) src0.ne[1],
|
||||
(long) src1.ne[0],
|
||||
(long) src1.ne[1]);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (src1.ne[1] != dst->ne[1] || src1.ne[2] != dst->ne[2] || src1.ne[3] != dst->ne[3]) {
|
||||
DEVICE_LOG_DEBUG("[%s]src1 and dst dimensions not match: %ldx%ld vs %ldx%ld\n", op_get_name(op),
|
||||
(long) src1.ne[2], (long) src1.ne[3], (long) dst->ne[2], (long) dst->ne[3]);
|
||||
DEVICE_LOG_DEBUG("[%s]src1 and dst dimensions not match: %ldx%ld vs %ldx%ld\n",
|
||||
op_get_name(op),
|
||||
(long) src1.ne[2],
|
||||
(long) src1.ne[3],
|
||||
(long) dst->ne[2],
|
||||
(long) dst->ne[3]);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (src1.ne[2] % src0.ne[2] || src1.ne[3] % src0.ne[3]) {
DEVICE_LOG_DEBUG("[%s]src0 cannot broadcast to src1: %ldx%ld vs %ldx%ld\n", op_get_name(op), (long) src0.ne[2],
(long) src0.ne[3], (long) src1.ne[2], (long) src1.ne[3]);
DEVICE_LOG_DEBUG("[%s]src0 cannot broadcast to src1: %ldx%ld vs %ldx%ld\n",
op_get_name(op),
(long) src0.ne[2],
(long) src0.ne[3],
(long) src1.ne[2],
(long) src1.ne[3]);
return false;
}
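// Editor's sketch (illustrative, not from the diff): the check above enforces ggml-style
// broadcasting, where src0 may be repeated along dims 2 and 3 as long as src1's extents are
// exact multiples of src0's.
inline bool can_broadcast_src0_to_src1(const npu_device_tensor_spec & src0, const npu_device_tensor_spec & src1) {
    return (src1.ne[2] % src0.ne[2]) == 0 && (src1.ne[3] % src0.ne[3]) == 0;
}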
@ -1,14 +1,16 @@
|
|||
#pragma once
|
||||
|
||||
#include <hexagon_types.h>
|
||||
|
||||
#include "op_types.hpp"
|
||||
#include "tensor.hpp"
|
||||
|
||||
#include <hexagon_types.h>
|
||||
|
||||
namespace hexagon {
|
||||
|
||||
bool mul_mat_f32(tensor * out, compute_params * params);
|
||||
bool is_mul_mat_supported(npu_device_tensor_op op, const npu_device_tensor_spec * dst,
|
||||
const npu_device_tensor_spec * srcs, size_t src_len);
|
||||
bool is_mul_mat_supported(const npu_device_tensor_op_spec * op_spec,
|
||||
const npu_device_tensor_spec * dst,
|
||||
const npu_device_tensor_spec * srcs,
|
||||
size_t src_len);
|
||||
|
||||
} // namespace hexagon
|
||||
|
|
|
|||
|
|
@@ -29,8 +29,14 @@ float rope_yarn_ramp(const float low, const float high, const int i0) {

// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
void rope_yarn(float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
float * cos_theta, float * sin_theta) {
void rope_yarn(float theta_extrap,
float freq_scale,
float corr_dims[2],
int64_t i0,
float ext_factor,
float mscale,
float * cos_theta,
float * sin_theta) {
// Get n-d rotational scaling corrected for extrapolation
float theta_interp = freq_scale * theta_extrap;
float theta = theta_interp;

@@ -45,8 +51,16 @@ void rope_yarn(float theta_extrap, float freq_scale, float corr_dims[2], int64_t

*sin_theta = sinf(theta) * mscale;
}
|
||||
|
||||
void rope_cache_init(float theta_base, float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0,
|
||||
float ext_factor, float mscale, float * cache, float sin_sign, float theta_scale) {
|
||||
void rope_cache_init(float theta_base,
|
||||
float freq_scale,
|
||||
const float * freq_factors,
|
||||
float corr_dims[2],
|
||||
int64_t ne0,
|
||||
float ext_factor,
|
||||
float mscale,
|
||||
float * cache,
|
||||
float sin_sign,
|
||||
float theta_scale) {
|
||||
// ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
|
||||
float theta = theta_base;
|
||||
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
|
||||
|
|
@ -58,10 +72,21 @@ void rope_cache_init(float theta_base, float freq_scale, const float * freq_fact
|
|||
}
|
||||
}
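// Editor's sketch of what rope_cache_init fills in, following the upstream ggml layout of
// interleaved {cos, sin} per rotated pair; the loop body is elided by the diff above, so treat
// this as a hedge rather than the verbatim implementation.
inline void rope_cache_init_sketch(float theta_base, float freq_scale, const float * freq_factors,
                                   float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
                                   float * cache, float sin_sign, float theta_scale) {
    float theta = theta_base;
    for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
        const float ff = freq_factors ? freq_factors[i0 / 2] : 1.0f;
        rope_yarn(theta / ff, freq_scale, corr_dims, i0, ext_factor, mscale,
                  &cache[i0 + 0], &cache[i0 + 1]);  // {cos, sin} stored side by side
        cache[i0 + 1] *= sin_sign;                  // sign flip used by the backward pass
        theta *= theta_scale;
    }
}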
|
||||
|
||||
void mrope_cache_init(float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e,
|
||||
const int sections[4], bool indep_sects, float freq_scale, const float * freq_factors,
|
||||
float corr_dims[2], int64_t ne0, float ext_factor, float mscale, float * cache, float sin_sign,
|
||||
float theta_scale) {
|
||||
void mrope_cache_init(float theta_base_t,
|
||||
float theta_base_h,
|
||||
float theta_base_w,
|
||||
float theta_base_e,
|
||||
const int sections[4],
|
||||
bool indep_sects,
|
||||
float freq_scale,
|
||||
const float * freq_factors,
|
||||
float corr_dims[2],
|
||||
int64_t ne0,
|
||||
float ext_factor,
|
||||
float mscale,
|
||||
float * cache,
|
||||
float sin_sign,
|
||||
float theta_scale) {
|
||||
// ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
|
||||
float theta_t = theta_base_t;
|
||||
float theta_h = theta_base_h;
|
||||
|
|
@ -181,16 +206,37 @@ bool rope_impl(hexagon::tensor * out, hexagon::compute_params * params) {
|
|||
if constexpr (!_IsMrope) {
|
||||
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(rope, 0, cache);
|
||||
const int64_t p = pos[i2];
|
||||
rope_cache_init(p, freq_scale, freq_factors, corr_dims, out->get_ne(0), ext_factor, attn_factor, cache,
|
||||
sin_sign, theta_scale);
|
||||
rope_cache_init(p,
|
||||
freq_scale,
|
||||
freq_factors,
|
||||
corr_dims,
|
||||
out->get_ne(0),
|
||||
ext_factor,
|
||||
attn_factor,
|
||||
cache,
|
||||
sin_sign,
|
||||
theta_scale);
|
||||
} else {
|
||||
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(rope, 0, cache);
|
||||
const int64_t p_t = pos[i2];
|
||||
const int64_t p_h = pos[i2 + out->get_ne(2)];
|
||||
const int64_t p_w = pos[i2 + out->get_ne(2) * 2];
|
||||
const int64_t p_e = pos[i2 + out->get_ne(2) * 3];
|
||||
mrope_cache_init(p_t, p_h, p_w, p_e, sections, _IsVision, freq_scale, freq_factors, corr_dims,
|
||||
out->get_ne(0), ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
||||
mrope_cache_init(p_t,
|
||||
p_h,
|
||||
p_w,
|
||||
p_e,
|
||||
sections,
|
||||
_IsVision,
|
||||
freq_scale,
|
||||
freq_factors,
|
||||
corr_dims,
|
||||
out->get_ne(0),
|
||||
ext_factor,
|
||||
attn_factor,
|
||||
cache,
|
||||
sin_sign,
|
||||
theta_scale);
|
||||
}
|
||||
|
||||
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(rope, 1, loop);
|
||||
|
|
@ -316,8 +362,11 @@ bool rope_f32(tensor * out, compute_params * params) {
|
|||
return kRopeImplFuncs[impl_index](out, params);
|
||||
}
|
||||
|
||||
bool is_rope_supported(npu_device_tensor_op op, const npu_device_tensor_spec * dst, const npu_device_tensor_spec * srcs,
|
||||
size_t src_len) {
|
||||
bool is_rope_supported(const npu_device_tensor_op_spec * op_spec,
|
||||
const npu_device_tensor_spec * dst,
|
||||
const npu_device_tensor_spec * srcs,
|
||||
size_t src_len) {
|
||||
const auto op = op_spec->op;
|
||||
if (op != NPU_OP_ROPE) {
|
||||
DEVICE_LOG_DEBUG("[%s]op is not ROPE\n", op_get_name(op));
|
||||
return false;
|
||||
|
|
@ -336,8 +385,10 @@ bool is_rope_supported(npu_device_tensor_op op, const npu_device_tensor_spec * d
|
|||
|
||||
const auto & src0 = srcs[0];
|
||||
if (src0.type != dst->type) {
|
||||
DEVICE_LOG_DEBUG("[%s]src0 type is not the same as dst type: %s vs %s\n", op_get_name(op),
|
||||
get_type_name(src0.type), get_type_name(dst->type));
|
||||
DEVICE_LOG_DEBUG("[%s]src0 type is not the same as dst type: %s vs %s\n",
|
||||
op_get_name(op),
|
||||
get_type_name(src0.type),
|
||||
get_type_name(dst->type));
|
||||
return false; // unsupported src0 type
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,9 @@
|
|||
namespace hexagon {
|
||||
|
||||
bool rope_f32(tensor * out, compute_params * params);
|
||||
bool is_rope_supported(npu_device_tensor_op op, const npu_device_tensor_spec * dst, const npu_device_tensor_spec * srcs,
|
||||
size_t src_len);
|
||||
bool is_rope_supported(const npu_device_tensor_op_spec * op_spec,
|
||||
const npu_device_tensor_spec * dst,
|
||||
const npu_device_tensor_spec * srcs,
|
||||
size_t src_len);
|
||||
|
||||
} // namespace hexagon
|
||||
|
|
|
|||
|
|
@ -1,5 +1,11 @@
|
|||
#pragma once
|
||||
|
||||
#include "hexagon_npu.h"
|
||||
#include "tensor.hpp"
|
||||
#include "thread_pool.hpp"
|
||||
#include "util.hpp"
|
||||
#include "vec_ops.hpp"
|
||||
|
||||
#include <hexagon_types.h>
|
||||
|
||||
#include <algorithm>
|
||||
|
|
@ -7,12 +13,6 @@
|
|||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
#include "hexagon_npu.h"
|
||||
#include "tensor.hpp"
|
||||
#include "thread_pool.hpp"
|
||||
#include "util.hpp"
|
||||
#include "vec_ops.hpp"
|
||||
|
||||
namespace hexagon {
|
||||
|
||||
inline constexpr std::pair<int64_t, int64_t> get_thread_work_slice(int64_t total, size_t tidx, size_t tcnt) {
|
||||
|
|
@ -44,8 +44,6 @@ struct compute_params {
|
|||
|
||||
uint8_t * get_vtcm_cache(size_t size) { return thread_params->get_vtcm_cache(size); }
|
||||
|
||||
uint8_t * get_mem_cache(size_t size) { return thread_params->get_mem_cache(size); }
|
||||
|
||||
std::pair<int64_t, int64_t> get_work_slice(int64_t total) const {
|
||||
return get_thread_work_slice(total, thread_params->tidx, thread_params->tcnt);
|
||||
}
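// Editor's sketch of an even per-thread split consistent with the call above; the real
// get_thread_work_slice lives earlier in this header and may round differently.
inline std::pair<int64_t, int64_t> work_slice_sketch(int64_t total, size_t tidx, size_t tcnt) {
    const int64_t per_thread = (total + (int64_t) tcnt - 1) / (int64_t) tcnt;   // ceil(total / tcnt)
    const int64_t start      = std::min<int64_t>((int64_t) tidx * per_thread, total);
    const int64_t end        = std::min<int64_t>(start + per_thread, total);
    return { start, end };  // thread tidx processes rows [start, end)
}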
|
||||
|
|
@ -58,7 +56,9 @@ struct compute_params {
|
|||
};
|
||||
|
||||
typedef bool (*compute_func_type)(tensor * dst, compute_params * params);
|
||||
typedef bool (*op_is_supported_func_type)(npu_device_tensor_op op, const npu_device_tensor_spec * dst,
|
||||
const npu_device_tensor_spec * srcs, size_t src_len);
|
||||
typedef bool (*op_is_supported_func_type)(const npu_device_tensor_op_spec * op_spec,
|
||||
const npu_device_tensor_spec * dst,
|
||||
const npu_device_tensor_spec * srcs,
|
||||
size_t src_len);
|
||||
|
||||
} // namespace hexagon
|
||||
|
|
|
|||
|
|
@ -1,13 +1,13 @@
|
|||
#pragma once
|
||||
|
||||
#include "hexagon_npu.h"
|
||||
#include "util.hpp"
|
||||
|
||||
#include <HAP_mem.h>
|
||||
#include <qurt.h>
|
||||
|
||||
#include <atomic>
|
||||
|
||||
#include "hexagon_npu.h"
|
||||
#include "util.hpp"
|
||||
|
||||
namespace hexagon {
|
||||
|
||||
constexpr const size_t kMaxTensorSrc = DEVICE_TENSOR_MAX_SRC;
|
||||
|
|
@ -26,8 +26,15 @@ class tensor {
|
|||
|
||||
_data = static_cast<uint8_t *>(mmap_address);
|
||||
DEVICE_LOG_INFO("tensor(%p[%ldx%ldx%ldx%ld]), fd: %d, offset: %zu, mmap_addr: %p, phy_addr: 0x%lx\n",
|
||||
(void *) this, (long) _info.ne[0], (long) _info.ne[1], (long) _info.ne[2], (long) _info.ne[3],
|
||||
_info.buffer_fd, _info.offset, (void *) mmap_address, phy_address);
|
||||
(void *) this,
|
||||
(long) _info.ne[0],
|
||||
(long) _info.ne[1],
|
||||
(long) _info.ne[2],
|
||||
(long) _info.ne[3],
|
||||
_info.buffer_fd,
|
||||
_info.offset,
|
||||
(void *) mmap_address,
|
||||
phy_address);
|
||||
}
|
||||
|
||||
~tensor() noexcept {
|
||||
|
|
@ -41,15 +48,17 @@ class tensor {
|
|||
|
||||
void flush() const {
|
||||
if (_data) {
|
||||
qurt_mem_cache_clean((qurt_addr_t) (_data + _info.offset), (qurt_size_t) _info.size, QURT_MEM_CACHE_FLUSH,
|
||||
QURT_MEM_DCACHE);
|
||||
qurt_mem_cache_clean(
|
||||
(qurt_addr_t) (_data + _info.offset), (qurt_size_t) _info.size, QURT_MEM_CACHE_FLUSH, QURT_MEM_DCACHE);
|
||||
}
|
||||
}
|
||||
|
||||
void invalidate() const {
|
||||
if (_data) {
|
||||
qurt_mem_cache_clean((qurt_addr_t) (_data + _info.offset), (qurt_size_t) _info.size,
|
||||
QURT_MEM_CACHE_FLUSH_INVALIDATE_ALL, QURT_MEM_DCACHE);
|
||||
qurt_mem_cache_clean((qurt_addr_t) (_data + _info.offset),
|
||||
(qurt_size_t) _info.size,
|
||||
QURT_MEM_CACHE_FLUSH_INVALIDATE_ALL,
|
||||
QURT_MEM_DCACHE);
|
||||
}
|
||||
}
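// Editor's note (illustrative only; the call sites are elsewhere in the repo): the usual
// discipline with these helpers is
//   tensor->invalidate();   // before the DSP reads buffers the host wrote
//   ...compute...
//   tensor->flush();        // after the DSP writes results so the host observes them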
@ -1,5 +1,8 @@
|
|||
#pragma once
|
||||
|
||||
#include "util.hpp"
|
||||
#include "vtcm_mem.hpp"
|
||||
|
||||
#include <qurt.h>
|
||||
|
||||
#include <array>
|
||||
|
|
@ -8,27 +11,26 @@
|
|||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "util.hpp"
|
||||
#include "vtcm_mem.hpp"
|
||||
|
||||
namespace hexagon {
|
||||
|
||||
constexpr const size_t kMaxThreadCount = 4;
constexpr const size_t kDefaultStackSize = 1024 * 32; // 32KB
constexpr const unsigned long long kThreadTaskPendingBit = 1;
constexpr const size_t kMaxThreadCount = 4;
constexpr const size_t kDefaultStackSize = 1024 * 64; // 64KB
|
||||
template <size_t _stack_size> class qurt_thread {
|
||||
public:
|
||||
typedef void (*qurt_thread_func_type)(qurt_thread * thread, void * arg);
|
||||
|
||||
explicit qurt_thread(const std::string & thread_name, qurt_thread_func_type thread_func, void * arg,
|
||||
unsigned short priority) {
|
||||
explicit qurt_thread(const std::string & thread_name,
|
||||
qurt_thread_func_type thread_func,
|
||||
void * arg,
|
||||
unsigned short priority) {
|
||||
DEVICE_LOG_DEBUG("qurt_thread.create: %s", thread_name.c_str());
|
||||
qurt_thread_attr_init(&_attributes);
|
||||
qurt_thread_attr_set_name(&_attributes, (char *) thread_name.c_str());
|
||||
qurt_thread_attr_set_stack_addr(&_attributes, _stack);
|
||||
qurt_thread_attr_set_stack_size(&_attributes, _stack_size);
|
||||
qurt_thread_attr_set_priority(&_attributes, priority);
|
||||
qurt_thread_attr_set_bus_priority(&_attributes, QURT_THREAD_BUS_PRIO_ENABLED);
|
||||
|
||||
_func = thread_func;
|
||||
_arg = arg;
|
||||
|
|
@@ -94,9 +96,9 @@ template <size_t _ThreadCount> class thread_pool {

thread_pool<kMaxThreadCount> * pool = nullptr;
size_t vtcm_quota_size;

std::unique_ptr<vtcm_mem> vtcm_cache;
std::unique_ptr<uint8_t[]> mem_cache;
size_t mem_cache_size = 0;
std::unique_ptr<vtcm_mem> vtcm_cache;

void init_vtcm_cache() { vtcm_cache = std::make_unique<vtcm_mem>(vtcm_quota_size, false); }

uint8_t * get_vtcm_cache(size_t size) {
if (!vtcm_cache || vtcm_cache->get_size() < size) {

@@ -111,25 +113,18 @@ template <size_t _ThreadCount> class thread_pool {

return vtcm_cache->get_mem();
}
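// Editor's sketch (illustrative): ops request their per-thread VTCM scratch through
// hexagon::compute_params::get_vtcm_cache, which forwards to the thread_params above.
// A null return signals that the quota cannot satisfy the request.
inline uint8_t * try_get_scratch(hexagon::compute_params & params, size_t bytes) {
    uint8_t * scratch = params.get_vtcm_cache(bytes);
    return scratch;  // caller falls back to the non-cached path when this is null
}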
|
||||
|
||||
uint8_t * get_mem_cache(size_t size) {
|
||||
if (!mem_cache || mem_cache_size < size) {
|
||||
mem_cache.reset(); // reset the cache to create a new one
|
||||
mem_cache = std::make_unique<uint8_t[]>(size + 256);
|
||||
mem_cache_size = mem_cache ? size : 0;
|
||||
}
|
||||
|
||||
return mem_cache.get();
|
||||
}
|
||||
};
|
||||
|
||||
typedef void (*task_type)(thread_pool * pool, thread_params * param, void * arg);
|
||||
|
||||
thread_pool() {
|
||||
for (size_t i = 0; i < kMaxThreadCount; ++i) {
|
||||
_thread_params[i].tidx = i;
|
||||
_thread_params[i].vtcm_quota_size = hexagon::vtcm_mem::get_avail_block_size() / kMaxThreadCount;
|
||||
_thread_params[i].pool = this;
|
||||
auto & thread_param = _thread_params[i];
|
||||
thread_param.tidx = i;
|
||||
thread_param.vtcm_quota_size = hexagon::vtcm_mem::get_avail_block_size() / kMaxThreadCount;
|
||||
thread_param.pool = this;
|
||||
|
||||
thread_param.init_vtcm_cache();
|
||||
}
|
||||
|
||||
qurt_barrier_init(&_pending, kMaxSubThreadCount + 1);
|
||||
|
|
@ -215,7 +210,8 @@ template <size_t _ThreadCount> class thread_pool {
|
|||
|
||||
#ifdef GGML_HEXAGON_ENABLE_PERFORMANCE_TRACKING
|
||||
auto task_begin_cycles = pool._task_begin_cycles.load();
|
||||
DEVICE_LOG_WARN("[profiler]worker_thread, tidx: %zu, prepare: %lluus", param->tidx,
|
||||
DEVICE_LOG_WARN("[profiler]worker_thread, tidx: %zu, prepare: %lluus",
|
||||
param->tidx,
|
||||
static_cast<unsigned long long>(
|
||||
HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - task_begin_cycles)));
|
||||
#endif
|
||||
|
|
@ -229,7 +225,8 @@ template <size_t _ThreadCount> class thread_pool {
|
|||
qurt_barrier_wait(&pool._completed);
|
||||
|
||||
#ifdef GGML_HEXAGON_ENABLE_PERFORMANCE_TRACKING
|
||||
DEVICE_LOG_WARN("[profiler]worker_thread, tidx: %zu, task_end: %lluus", param->tidx,
|
||||
DEVICE_LOG_WARN("[profiler]worker_thread, tidx: %zu, task_end: %lluus",
|
||||
param->tidx,
|
||||
static_cast<unsigned long long>(
|
||||
HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - task_begin_cycles)));
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -1,12 +1,12 @@
|
|||
#include "type_traits.hpp"
|
||||
|
||||
#include "op_types.hpp" // TODO: remove this include
|
||||
#include "vec_ops.hpp"
|
||||
|
||||
#include <hexagon_types.h>
|
||||
|
||||
#include <array>
|
||||
|
||||
#include "op_types.hpp" // TODO: remove this include
|
||||
#include "vec_ops.hpp"
|
||||
|
||||
static_assert(sizeof(npu_device_block_q4_k) ==
|
||||
2 * sizeof(npu_device_fp16_t) + QUANT_K_SCALE_SIZE + QUANT_K_BLOCK_SIZE / 2,
|
||||
"wrong q4_K block size/padding");
|
||||
|
|
@ -82,8 +82,17 @@ inline int nearest_int(float fval) {
|
|||
return (i & 0x007fffff) - 0x00400000;
|
||||
}
|
||||
|
||||
float make_qkx2_quants(int n, int nmax, const float * x, const float * weights, uint8_t * L, float * the_min,
|
||||
uint8_t * Laux, float rmin, float rdelta, int nstep, bool use_mad) {
|
||||
float make_qkx2_quants(int n,
|
||||
int nmax,
|
||||
const float * x,
|
||||
const float * weights,
|
||||
uint8_t * L,
|
||||
float * the_min,
|
||||
uint8_t * Laux,
|
||||
float rmin,
|
||||
float rdelta,
|
||||
int nstep,
|
||||
bool use_mad) {
|
||||
float min = x[0];
|
||||
float max = x[0];
|
||||
float sum_w = weights[0];
|
||||
|
|
@@ -315,13 +324,13 @@ void quantize_row_q4_K(const float * src, void * dst, size_t count) {

}
}

void dequantize_row_q8_0(const void * src, hexagon::dequant_target_type * dst, size_t count) {
void dequantize_row_q8_0(const void * src, hexagon::dequant_output_type * dst, size_t count) {
constexpr const int qk = QUANT_BLOCK_SIZE;
static_assert(QUANT_BLOCK_SIZE == hexagon::kBytesPerVector / sizeof(float));

const int nb = count / qk;
const auto * src_ptr = reinterpret_cast<const npu_device_block_q8_0 *>(src);
auto * dst_ptr = ((hexagon::dequant_target_type *) dst); // TODO: opt for aligned access
auto * dst_ptr = ((hexagon::dequant_output_type *) dst); // TODO: opt for aligned access

int i = 0;
for (; i + 1 < nb; i += 2) {

@@ -354,7 +363,7 @@ void dequantize_row_q8_0(const void * src, hexagon::dequant_target_type * dst, s

}
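// Editor's scalar reference for the HVX loop above, assuming the usual q8_0 layout of one
// fp16 scale `d` followed by QUANT_BLOCK_SIZE int8 values `qs` per block (the field names are
// assumptions based on the upstream ggml block layout, not taken from this diff).
inline void dequantize_row_q8_0_scalar(const npu_device_block_q8_0 * blocks, float * dst, size_t count) {
    const size_t nb = count / QUANT_BLOCK_SIZE;
    for (size_t b = 0; b < nb; ++b) {
        const float d = blocks[b].d;  // fp16 scale promoted to float
        for (size_t j = 0; j < QUANT_BLOCK_SIZE; ++j) {
            dst[b * QUANT_BLOCK_SIZE + j] = d * (float) blocks[b].qs[j];
        }
    }
}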
|
||||
|
||||
template <bool _IsDstAligned>
|
||||
void dequantize_row_q4_0_impl(const void * src, hexagon::dequant_target_type * dst, size_t count) {
|
||||
void dequantize_row_q4_0_impl(const void * src, hexagon::dequant_output_type * dst, size_t count) {
|
||||
constexpr const int qk = QUANT_BLOCK_SIZE;
|
||||
static_assert(qk % 2 == 0, "qk must be even");
|
||||
static_assert(QUANT_BLOCK_SIZE == hexagon::kBytesPerVector / sizeof(float));
|
||||
|
|
@ -364,7 +373,7 @@ void dequantize_row_q4_0_impl(const void * src, hexagon::dequant_target_type * d
|
|||
const auto * src_ptr = reinterpret_cast<const npu_device_block_q4_0 *>(src);
|
||||
const HVX_Vector mask = Q6_Vb_vsplat_R(0x0F);
|
||||
const HVX_Vector minus = Q6_Vb_vsplat_R(8);
|
||||
hexagon::dequant_target_type * dst_ptr = dst; // TODO: opt for aligned access
|
||||
hexagon::dequant_output_type * dst_ptr = dst; // TODO: opt for aligned access
|
||||
|
||||
int i = 0;
|
||||
for (; i + 3 < nb; i += 4) {
|
||||
|
|
@ -402,7 +411,7 @@ void dequantize_row_q4_0_impl(const void * src, hexagon::dequant_target_type * d
|
|||
reinterpret_cast<HVX_UVector *>(dst_ptr)[1] = q_hi;
|
||||
}
|
||||
|
||||
dst_ptr += hexagon::kBytesPerVector / sizeof(hexagon::dequant_target_type) * 2;
|
||||
dst_ptr += hexagon::kBytesPerVector / sizeof(hexagon::dequant_output_type) * 2;
|
||||
}
|
||||
|
||||
for (; i + 1 < nb; i += 2) {
|
||||
|
|
@ -428,7 +437,7 @@ void dequantize_row_q4_0_impl(const void * src, hexagon::dequant_target_type * d
|
|||
*reinterpret_cast<HVX_UVector *>(dst_ptr) = q_lo;
|
||||
}
|
||||
|
||||
dst_ptr += hexagon::kBytesPerVector / sizeof(hexagon::dequant_target_type);
|
||||
dst_ptr += hexagon::kBytesPerVector / sizeof(hexagon::dequant_output_type);
|
||||
}
|
||||
|
||||
if (i < nb) {
|
||||
|
|
@ -453,7 +462,7 @@ void dequantize_row_q4_0_impl(const void * src, hexagon::dequant_target_type * d
|
|||
}
|
||||
}
|
||||
|
||||
void dequantize_row_q4_0(const void * src, hexagon::dequant_target_type * dst, size_t count) {
|
||||
void dequantize_row_q4_0(const void * src, hexagon::dequant_output_type * dst, size_t count) {
|
||||
const bool dst_aligned = hexagon::is_addr_aligned(dst);
|
||||
if (dst_aligned) {
|
||||
dequantize_row_q4_0_impl<true>(src, dst, count);
|
||||
|
|
@ -462,7 +471,7 @@ void dequantize_row_q4_0(const void * src, hexagon::dequant_target_type * dst, s
|
|||
}
|
||||
}
|
||||
|
||||
void dequantize_row_q4_K(const void * src, hexagon::dequant_target_type * dst, size_t count) {
|
||||
void dequantize_row_q4_K(const void * src, hexagon::dequant_output_type * dst, size_t count) {
|
||||
const int nb = count / QUANT_K_BLOCK_SIZE;
|
||||
const auto * src_ptr = reinterpret_cast<const npu_device_block_q4_k *>(src);
|
||||
auto * dst_ptr = reinterpret_cast<__fp16 *>(dst);
|
||||
|
|
@ -497,29 +506,44 @@ void dequantize_row_q4_K(const void * src, hexagon::dequant_target_type * dst, s
|
|||
}
|
||||
}
|
||||
|
||||
void copy_row_f16(const void * src, hexagon::dequant_target_type * dst, size_t count) {
|
||||
void copy_row_f16(const void * src, hexagon::dequant_output_type * dst, size_t count) {
|
||||
hexagon::vec_cpy_f16(reinterpret_cast<const npu_device_fp16_t *>(src), dst, count);
|
||||
}
|
||||
|
||||
void copy_row_f32(const void * src, hexagon::dequant_target_type * dst, size_t count) {
|
||||
void copy_row_f32(const void * src, hexagon::dequant_output_type * dst, size_t count) {
|
||||
hexagon::vec_cpy_f32(reinterpret_cast<const float *>(src), reinterpret_cast<float *>(dst), count);
|
||||
}
|
||||
|
||||
constexpr const hexagon::device_type_traits kDeviceTypeTraits[] = {
|
||||
{ NPU_DATA_TYPE_F32, "F32", 1, sizeof(float), false, copy_row_f32, nullptr,
|
||||
hexagon::type_erase_dot_func<hexagon::vec_dot_product_f32_f32>,
|
||||
{ NPU_DATA_TYPE_F32,
|
||||
"F32", 1,
|
||||
sizeof(float),
|
||||
false, copy_row_f32,
|
||||
nullptr, hexagon::type_erase_dot_func<hexagon::vec_dot_product_f32_f32>,
|
||||
hexagon::type_erase_dot_func<hexagon::vec_dot_product_aligned_f32_f32>,
|
||||
hexagon::type_erase_dot_func<hexagon::is_f32_f32_dot_product_aligned> },
|
||||
{ NPU_DATA_TYPE_F16, "F16", 1, sizeof(npu_device_fp16_t), false, copy_row_f16, quantize_row_fp16,
|
||||
hexagon::type_erase_dot_func<hexagon::vec_dot_product_f16_f16>,
|
||||
{ NPU_DATA_TYPE_F16,
|
||||
"F16", 1,
|
||||
sizeof(npu_device_fp16_t),
|
||||
false, copy_row_f16,
|
||||
quantize_row_fp16, hexagon::type_erase_dot_func<hexagon::vec_dot_product_f16_f16>,
|
||||
hexagon::type_erase_dot_func<hexagon::vec_dot_product_aligned_f16_f16>,
|
||||
hexagon::type_erase_dot_func<hexagon::is_f16_f16_dot_product_aligned> },
|
||||
{ NPU_DATA_TYPE_I32, "I32", 1, sizeof(int32_t), false },
|
||||
{ NPU_DATA_TYPE_Q8_0, "Q8_0", QUANT_BLOCK_SIZE, sizeof(npu_device_block_q8_0), true, dequantize_row_q8_0,
|
||||
{ NPU_DATA_TYPE_Q8_0,
|
||||
"Q8_0", QUANT_BLOCK_SIZE,
|
||||
sizeof(npu_device_block_q8_0),
|
||||
true, dequantize_row_q8_0,
|
||||
quantize_row_q8_0 },
|
||||
{ NPU_DATA_TYPE_Q4_0, "Q4_0", QUANT_BLOCK_SIZE, sizeof(npu_device_block_q4_0), true, dequantize_row_q4_0,
|
||||
{ NPU_DATA_TYPE_Q4_0,
|
||||
"Q4_0", QUANT_BLOCK_SIZE,
|
||||
sizeof(npu_device_block_q4_0),
|
||||
true, dequantize_row_q4_0,
|
||||
quantize_row_q4_0 },
|
||||
{ NPU_DATA_TYPE_Q4_K, "Q4_K", QUANT_K_BLOCK_SIZE, sizeof(npu_device_block_q4_k), true, dequantize_row_q4_K,
|
||||
{ NPU_DATA_TYPE_Q4_K,
|
||||
"Q4_K", QUANT_K_BLOCK_SIZE,
|
||||
sizeof(npu_device_block_q4_k),
|
||||
true, dequantize_row_q4_K,
|
||||
quantize_row_q4_K },
|
||||
};
|
||||
|
||||
|
|
@ -566,7 +590,7 @@ size_t get_dequantized_row_size(const tensor * tensor) {
|
|||
|
||||
auto row_elems_count = tensor->get_ne(0);
|
||||
return hexagon::get_aligned_size(
|
||||
row_elems_count * sizeof(dequant_target_type)); // dequant_target_type is currently restricted to f32
|
||||
row_elems_count * sizeof(dequant_output_type)); // dequant_output_type is currently restricted to f32
|
||||
}
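// Editor's sketch of the rounding performed by get_aligned_size, assuming it aligns to the
// HVX vector width (kBytesPerVector, 128 bytes on v73); the real helper may differ in detail.
inline size_t get_aligned_size_sketch(size_t size) {
    return (size + kAlignMask) & ~kAlignMask;  // round up to the next vector boundary
}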
|
||||
|
||||
} // namespace hexagon
|
||||
|
|
|
|||
|
|
@ -5,12 +5,12 @@
|
|||
|
||||
namespace hexagon {
|
||||
|
||||
using dequant_target_type = npu_device_fp16_t;
|
||||
using dequant_output_type = npu_device_fp16_t;
|
||||
|
||||
bool init_f16_f32_table(float * table, size_t count);
|
||||
|
||||
typedef void (*quantize_row_type)(const float * src, void * dst, size_t count);
|
||||
typedef void (*dequantize_row_type)(const void * src, dequant_target_type * dst, size_t count);
|
||||
typedef void (*dequantize_row_type)(const void * src, dequant_output_type * dst, size_t count);
|
||||
typedef float (*vec_dot_type)(const void * src0, const void * src1, size_t count);
|
||||
typedef bool (*can_use_aligned_vec_dot_type)(const void * src0, const void * src1, size_t count);
|
||||
|
||||
|
|
@ -51,14 +51,32 @@ inline auto make_scoped_op_perf_timer(tensor * op, size_t tidx) {
|
|||
auto * src1 = op->get_src(1);
|
||||
char buffer[1024];
|
||||
if (src1 == nullptr) {
|
||||
snprintf(buffer, sizeof(buffer), "[%s][%lldx%lldx%lldx%lld%s], tidx: %zu", op_get_name(op->get_op()),
|
||||
src0->get_ne(0), src0->get_ne(1), src0->get_ne(2), src0->get_ne(3), get_type_name(src0->get_type()),
|
||||
snprintf(buffer,
|
||||
sizeof(buffer),
|
||||
"[%s][%lldx%lldx%lldx%lld%s], tidx: %zu",
|
||||
op_get_name(op->get_op()),
|
||||
src0->get_ne(0),
|
||||
src0->get_ne(1),
|
||||
src0->get_ne(2),
|
||||
src0->get_ne(3),
|
||||
get_type_name(src0->get_type()),
|
||||
tidx);
|
||||
} else {
|
||||
snprintf(buffer, sizeof(buffer), "[%s][%lldx%lldx%lldx%lld%s],[%lldx%lldx%lldx%lld%s], tidx: %zu",
|
||||
op_get_name(op->get_op()), src0->get_ne(0), src0->get_ne(1), src0->get_ne(2), src0->get_ne(3),
|
||||
get_type_name(src0->get_type()), src1->get_ne(0), src1->get_ne(1), src1->get_ne(2), src1->get_ne(3),
|
||||
get_type_name(src1->get_type()), tidx);
|
||||
snprintf(buffer,
|
||||
sizeof(buffer),
|
||||
"[%s][%lldx%lldx%lldx%lld%s],[%lldx%lldx%lldx%lld%s], tidx: %zu",
|
||||
op_get_name(op->get_op()),
|
||||
src0->get_ne(0),
|
||||
src0->get_ne(1),
|
||||
src0->get_ne(2),
|
||||
src0->get_ne(3),
|
||||
get_type_name(src0->get_type()),
|
||||
src1->get_ne(0),
|
||||
src1->get_ne(1),
|
||||
src1->get_ne(2),
|
||||
src1->get_ne(3),
|
||||
get_type_name(src1->get_type()),
|
||||
tidx);
|
||||
}
|
||||
return npu_scoped_timer<1024>(buffer);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
#pragma once
|
||||
|
||||
#include "hexagon_npu.h"
|
||||
|
||||
#include <AEEStdDef.h>
|
||||
#include <HAP_farf.h>
|
||||
#include <HAP_perf.h>
|
||||
|
|
@ -9,8 +11,6 @@
|
|||
#include <cstring>
|
||||
#include <utility>
|
||||
|
||||
#include "hexagon_npu.h"
|
||||
|
||||
#define DEVICE_LOG_ERROR(...) FARF(FATAL, __VA_ARGS__)
|
||||
#define DEVICE_LOG_WARN(...) FARF(ERROR, __VA_ARGS__)
|
||||
#define DEVICE_LOG_INFO(...) FARF(HIGH, __VA_ARGS__)
|
||||
|
|
@ -56,6 +56,8 @@ inline constexpr const char * op_get_name(npu_device_tensor_op op) {
|
|||
return "FLASH_ATTN_EXT";
|
||||
case NPU_OP_ROPE:
|
||||
return "ROPE";
|
||||
case NPU_OP_GLU:
|
||||
return "GLU";
|
||||
default:
|
||||
return "UNKNOWN";
|
||||
}
|
||||
|
|
@ -252,44 +254,68 @@ template <size_t _buffer_count> class npu_scoped_timer {
|
|||
"[profiler]%s, pcyc: %llu, dur: %lluus, [%s]cnt: %llu, dur: %lluus, "
|
||||
"[%s]cnt: %llu, dur: %lluus, [%s]cnt: %llu, dur: %lluus, "
|
||||
"[%s]cnt: %llu, dur: %lluus\n",
|
||||
_log_prefix, (unsigned long long) total_pcycles, (unsigned long long) duration,
|
||||
_sub_proc_data[0].log_prefix, (unsigned long long) _sub_proc_data[0].proc_count,
|
||||
(unsigned long long) sub_proc0_duration, _sub_proc_data[1].log_prefix,
|
||||
(unsigned long long) _sub_proc_data[1].proc_count, (unsigned long long) sub_proc1_duration,
|
||||
_sub_proc_data[2].log_prefix, (unsigned long long) _sub_proc_data[2].proc_count,
|
||||
(unsigned long long) sub_proc2_duration, _sub_proc_data[3].log_prefix,
|
||||
(unsigned long long) _sub_proc_data[3].proc_count, (unsigned long long) sub_proc3_duration);
|
||||
_log_prefix,
|
||||
(unsigned long long) total_pcycles,
|
||||
(unsigned long long) duration,
|
||||
_sub_proc_data[0].log_prefix,
|
||||
(unsigned long long) _sub_proc_data[0].proc_count,
|
||||
(unsigned long long) sub_proc0_duration,
|
||||
_sub_proc_data[1].log_prefix,
|
||||
(unsigned long long) _sub_proc_data[1].proc_count,
|
||||
(unsigned long long) sub_proc1_duration,
|
||||
_sub_proc_data[2].log_prefix,
|
||||
(unsigned long long) _sub_proc_data[2].proc_count,
|
||||
(unsigned long long) sub_proc2_duration,
|
||||
_sub_proc_data[3].log_prefix,
|
||||
(unsigned long long) _sub_proc_data[3].proc_count,
|
||||
(unsigned long long) sub_proc3_duration);
|
||||
break;
|
||||
case 3:
|
||||
DEVICE_LOG_WARN(
|
||||
"[profiler]%s, pcyc: %llu, dur: %lluus, [%s]cnt: %llu, dur: %lluus, "
|
||||
"[%s]cnt: %llu, dur: %lluus, [%s]cnt: %llu, dur: %lluus\n",
|
||||
_log_prefix, (unsigned long long) total_pcycles, (unsigned long long) duration,
|
||||
_sub_proc_data[0].log_prefix, (unsigned long long) _sub_proc_data[0].proc_count,
|
||||
(unsigned long long) sub_proc0_duration, _sub_proc_data[1].log_prefix,
|
||||
(unsigned long long) _sub_proc_data[1].proc_count, (unsigned long long) sub_proc1_duration,
|
||||
_sub_proc_data[2].log_prefix, (unsigned long long) _sub_proc_data[2].proc_count,
|
||||
_log_prefix,
|
||||
(unsigned long long) total_pcycles,
|
||||
(unsigned long long) duration,
|
||||
_sub_proc_data[0].log_prefix,
|
||||
(unsigned long long) _sub_proc_data[0].proc_count,
|
||||
(unsigned long long) sub_proc0_duration,
|
||||
_sub_proc_data[1].log_prefix,
|
||||
(unsigned long long) _sub_proc_data[1].proc_count,
|
||||
(unsigned long long) sub_proc1_duration,
|
||||
_sub_proc_data[2].log_prefix,
|
||||
(unsigned long long) _sub_proc_data[2].proc_count,
|
||||
(unsigned long long) sub_proc2_duration);
|
||||
break;
|
||||
case 2:
|
||||
DEVICE_LOG_WARN(
|
||||
"[profiler]%s, pcyc: %llu, dur: %lluus, [%s]cnt: %llu, dur: %lluus, "
|
||||
"[%s]cnt: %llu, dur: %lluus\n",
|
||||
_log_prefix, (unsigned long long) total_pcycles, (unsigned long long) duration,
|
||||
_sub_proc_data[0].log_prefix, (unsigned long long) _sub_proc_data[0].proc_count,
|
||||
(unsigned long long) sub_proc0_duration, _sub_proc_data[1].log_prefix,
|
||||
(unsigned long long) _sub_proc_data[1].proc_count, (unsigned long long) sub_proc1_duration);
|
||||
_log_prefix,
|
||||
(unsigned long long) total_pcycles,
|
||||
(unsigned long long) duration,
|
||||
_sub_proc_data[0].log_prefix,
|
||||
(unsigned long long) _sub_proc_data[0].proc_count,
|
||||
(unsigned long long) sub_proc0_duration,
|
||||
_sub_proc_data[1].log_prefix,
|
||||
(unsigned long long) _sub_proc_data[1].proc_count,
|
||||
(unsigned long long) sub_proc1_duration);
|
||||
break;
|
||||
case 1:
|
||||
DEVICE_LOG_WARN("[profiler]%s, pcyc: %llu, dur: %lluus, [%s]cnt: %llu, dur: %lluus\n", _log_prefix,
|
||||
(unsigned long long) total_pcycles, (unsigned long long) duration,
|
||||
_sub_proc_data[0].log_prefix, (unsigned long long) _sub_proc_data[0].proc_count,
|
||||
DEVICE_LOG_WARN("[profiler]%s, pcyc: %llu, dur: %lluus, [%s]cnt: %llu, dur: %lluus\n",
|
||||
_log_prefix,
|
||||
(unsigned long long) total_pcycles,
|
||||
(unsigned long long) duration,
|
||||
_sub_proc_data[0].log_prefix,
|
||||
(unsigned long long) _sub_proc_data[0].proc_count,
|
||||
(unsigned long long) sub_proc0_duration);
|
||||
break;
|
||||
default:
|
||||
case 0:
|
||||
DEVICE_LOG_WARN("[profiler]%s, pcyc: %llu, dur: %lluus\n", _log_prefix,
|
||||
(unsigned long long) total_pcycles, (unsigned long long) duration);
|
||||
DEVICE_LOG_WARN("[profiler]%s, pcyc: %llu, dur: %lluus\n",
|
||||
_log_prefix,
|
||||
(unsigned long long) total_pcycles,
|
||||
(unsigned long long) duration);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
|
@ -317,8 +343,8 @@ template <size_t _buffer_count, size_t _sub_idx> class npu_sub_process_scoped_ti
|
|||
}
|
||||
|
||||
~npu_sub_process_scoped_timer() {
|
||||
_timer.add_sub_proc_cycles(_sub_idx, _prefix, HAP_perf_get_qtimer_count() - _begin_cycles,
|
||||
HAP_perf_get_pcycles() - _begin_pcycles);
|
||||
_timer.add_sub_proc_cycles(
|
||||
_sub_idx, _prefix, HAP_perf_get_qtimer_count() - _begin_cycles, HAP_perf_get_pcycles() - _begin_pcycles);
|
||||
}
|
||||
|
||||
private:
|
||||
File diff suppressed because it is too large
@ -1,13 +1,29 @@
|
|||
#pragma once
|
||||
|
||||
#include "hexagon_npu.h"
|
||||
|
||||
#include <hexagon_types.h>
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
#include "hexagon_npu.h"
|
||||
|
||||
namespace hexagon {
|
||||
|
||||
template <typename T, int N> struct HEXAGON_pack {
|
||||
T val[N];
|
||||
};
|
||||
|
||||
using HVX_Vector_x2 = std::pair<HVX_Vector, HVX_Vector>;
|
||||
using HVX_VectorPair_x4 = HEXAGON_pack<HVX_VectorPair, 4>;
|
||||
|
||||
typedef union {
|
||||
HVX_VectorPair VV;
|
||||
|
||||
struct {
|
||||
HVX_Vector lo;
|
||||
HVX_Vector hi;
|
||||
} V;
|
||||
} HVX_DV;
|
||||
|
||||
constexpr const size_t kBytesPerVector = sizeof(HVX_Vector); // 128 for v73
|
||||
constexpr const size_t kAlignMask = kBytesPerVector - 1;
|
||||
|
||||
|
|
@ -63,67 +79,6 @@ inline void l2fetch_row(const uint8_t * row_ptr, size_t bytes) {
|
|||
hexagon::l2fetch(row_ptr, kBytesPerVector, kBytesPerVector, l2fetch_vectors, 0);
|
||||
}
|
||||
|
||||
/*
|
||||
* This function converts a vector of IEEE float elements to a vector of qf32 elements
|
||||
* See also: libs\qfe\inc\qhmath_hvx_convert.h
|
||||
*/
|
||||
inline HVX_Vector qhmath_hvx_vqf32_convert_vsf(HVX_Vector vin) {
|
||||
return Q6_Vqf32_vadd_VsfVsf(vin, Q6_V_vzero());
|
||||
}
|
||||
|
||||
/*
|
||||
* This function converts a vector of IEEE half float elements to a vector of qf16 elements
|
||||
* See also: libs\qfe\inc\qhmath_hvx_convert.h
|
||||
*/
|
||||
inline HVX_Vector qhmath_hvx_vqf16_convert_vhf(HVX_Vector vin) {
|
||||
return Q6_Vqf16_vadd_VhfVhf(vin, Q6_V_vzero());
|
||||
}
|
||||
|
||||
/*
|
||||
* This function converts a pair of vectors of qf32 elements to a vector of IEEE half float elements
|
||||
* See also: libs\qfe\inc\qhmath_hvx_convert.h
|
||||
*/
|
||||
inline HVX_Vector qhmath_hvx_vhf_convert_vqf32(HVX_VectorPair vin_vp) {
|
||||
return Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(vin_vp));
|
||||
}
|
||||
|
||||
/*
|
||||
* This function converts a vector of qf16 elements to a pair of vectors of qf32 elements
|
||||
* See also: libs\qfe\inc\qhmath_hvx_convert.h
|
||||
*/
|
||||
inline HVX_VectorPair qhmath_hvx_vqf32_convert_vqf16(HVX_Vector vxl) {
|
||||
HVX_VectorPair vxw_vp, exponent_vp;
|
||||
HVX_Vector mantissa_mask = Q6_Vh_vsplat_R(0xffe0);
|
||||
HVX_Vector exp_mask = Q6_Vh_vsplat_R(0x1f);
|
||||
HVX_Vector exp_offset = Q6_Vh_vsplat_R(0x70);
|
||||
HVX_Vector mant32_shift = Q6_Vh_vsplat_R(0x10);
|
||||
HVX_Vector reql, reqh, vxl_w, vxh_w, mantissa;
|
||||
HVX_Vector el_exponent, eh_exponent;
|
||||
|
||||
el_exponent = Q6_V_vand_VV(exp_mask, vxl);
|
||||
// Obtain the mantissa part: bits (5-15)
|
||||
mantissa = Q6_V_vand_VV(mantissa_mask, vxl);
|
||||
// Convert qf16 biassed exponent to qf32 biased exponent
|
||||
// new exp = exp + ( 127 (qf32 bias) -15(qf16 biass) ) = 112
|
||||
el_exponent = Q6_Vh_vadd_VhVh(exp_offset, el_exponent);
|
||||
|
||||
vxw_vp = Q6_Ww_vunpack_Vh(mantissa);
|
||||
vxl_w = Q6_V_lo_W(vxw_vp);
|
||||
vxh_w = Q6_V_hi_W(vxw_vp);
|
||||
|
||||
exponent_vp = Q6_Ww_vunpack_Vh(el_exponent);
|
||||
el_exponent = Q6_V_lo_W(exponent_vp);
|
||||
eh_exponent = Q6_V_hi_W(exponent_vp);
|
||||
// Convert q16 mantiss to q32 mantissa
|
||||
reql = Q6_Vw_vasl_VwVw(vxl_w, mant32_shift);
|
||||
reqh = Q6_Vw_vasl_VwVw(vxh_w, mant32_shift);
|
||||
// Add the exponent
|
||||
vxl_w = Q6_Vw_vadd_VwVw(reql, el_exponent);
|
||||
vxh_w = Q6_Vw_vadd_VwVw(reqh, eh_exponent);
|
||||
|
||||
return Q6_W_vcombine_VV(vxh_w, vxl_w);
|
||||
}
|
||||
|
||||
template <uint32_t _TyBytes> inline void q6op_vstu_variable_ARV(void * addr, HVX_Vector vin) {
|
||||
vin = Q6_V_vlalign_VVR(vin, vin, (size_t) addr); //rotate as needed.
|
||||
uint32_t left_off = unaligned_bytes(addr);
|
||||
|
|
@ -157,20 +112,6 @@ inline void q6op_vstu_variable_ARV(void * addr, int n, HVX_Vector vin) {
|
|||
Q6_vmaskedstorenq_QAV(qL_not, (HVX_Vector *) addr, vin);
|
||||
}
|
||||
|
||||
inline HVX_VectorPair hvx_vqf32_convert_vhf(HVX_Vector vxl) {
|
||||
return qhmath_hvx_vqf32_convert_vqf16(qhmath_hvx_vqf16_convert_vhf(vxl));
|
||||
}
|
||||
|
||||
using HVX_Vector_Dual = std::pair<HVX_Vector, HVX_Vector>;
|
||||
|
||||
inline HVX_Vector_Dual hvx_vsf_convert_vhf(HVX_Vector vxl, HVX_Vector one) {
|
||||
HVX_VectorPair res = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vxl), one);
|
||||
return {
|
||||
Q6_Vsf_equals_Vqf32(Q6_V_lo_W(res)),
|
||||
Q6_Vsf_equals_Vqf32(Q6_V_hi_W(res)),
|
||||
};
|
||||
}
|
||||
|
||||
inline HVX_Vector vec_reduction_qf32(HVX_Vector sums) {
|
||||
constexpr const size_t kFloatsPerVector = hexagon::kBytesPerVector / sizeof(float);
|
||||
static_assert(kFloatsPerVector == 32, "kFloatsPerVector should be 32");
|
||||
|
|
@ -247,6 +188,7 @@ inline HVX_Vector hvx_passthru(HVX_Vector src, HVX_UVector *, HVX_Vector) {
|
|||
|
||||
} // namespace hexagon
|
||||
|
||||
#include "vec_math.inl"
|
||||
#include "vec_ops.inl"
|
||||
|
||||
namespace hexagon {
|
||||
|
|
@ -313,8 +255,8 @@ inline HVX_Vector vec_dot_product_vqf32_f32_f32(const float * src0, const float
|
|||
|
||||
inline HVX_Vector vec_dot_product_aligned_vqf32_f32_f32(const float * src0, const float * src1, size_t count) {
|
||||
using namespace hexagon::vec;
|
||||
return vec_dot_product_aligned_impl<float, HVX_Vector, vec_mpy_qf32, vec_add_qf32, vec_reduction_qf32>(src0, src1,
|
||||
count);
|
||||
return vec_dot_product_aligned_impl<float, HVX_Vector, vec_mpy_qf32, vec_add_qf32, vec_reduction_qf32>(
|
||||
src0, src1, count);
|
||||
}
|
||||
|
||||
inline float vec_dot_product_f32_f32(const float * src0, const float * src1, size_t count) {
|
||||
|
|
@ -324,23 +266,25 @@ inline float vec_dot_product_f32_f32(const float * src0, const float * src1, siz
|
|||
|
||||
inline float vec_dot_product_aligned_f32_f32(const float * src0, const float * src1, size_t count) {
|
||||
using namespace hexagon::vec;
|
||||
return vec_dot_product_aligned_impl<float, float, vec_mpy_qf32, vec_add_qf32, vec_reduction_f32_qf32>(src0, src1,
|
||||
count);
|
||||
return vec_dot_product_aligned_impl<float, float, vec_mpy_qf32, vec_add_qf32, vec_reduction_f32_qf32>(
|
||||
src0, src1, count);
|
||||
}
|
||||
|
||||
inline bool is_f32_f32_dot_product_aligned(const float * src0, const float * src1, size_t count) {
|
||||
return is_dot_product_aligned<float, float>(src0, src1, count);
|
||||
}
|
||||
|
||||
inline HVX_Vector vec_dot_product_vqf16_f16_f16(const npu_device_fp16_t * src0, const npu_device_fp16_t * src1,
|
||||
size_t count) {
|
||||
inline HVX_Vector vec_dot_product_vqf16_f16_f16(const npu_device_fp16_t * src0,
|
||||
const npu_device_fp16_t * src1,
|
||||
size_t count) {
|
||||
using namespace hexagon::vec;
|
||||
return vec_dot_product_impl<npu_device_fp16_t, HVX_Vector, vec_mpy_qf16, vec_add_qf16, vec_reduction_qf16>(
|
||||
src0, src1, count);
|
||||
}
|
||||
|
||||
inline HVX_Vector vec_dot_product_aligned_vqf16_f16_f16(const npu_device_fp16_t * src0, const npu_device_fp16_t * src1,
|
||||
size_t count) {
|
||||
inline HVX_Vector vec_dot_product_aligned_vqf16_f16_f16(const npu_device_fp16_t * src0,
|
||||
const npu_device_fp16_t * src1,
|
||||
size_t count) {
|
||||
using namespace hexagon::vec;
|
||||
return vec_dot_product_aligned_impl<npu_device_fp16_t, HVX_Vector, vec_mpy_qf16, vec_add_qf16, vec_reduction_qf16>(
|
||||
src0, src1, count);
|
||||
|
|
@ -352,41 +296,68 @@ inline float vec_dot_product_f16_f16(const npu_device_fp16_t * src0, const npu_d
|
|||
src0, src1, count);
|
||||
}
|
||||
|
||||
inline float vec_dot_product_aligned_f16_f16(const npu_device_fp16_t * src0, const npu_device_fp16_t * src1,
|
||||
size_t count) {
|
||||
inline float vec_dot_product_aligned_f16_f16(const npu_device_fp16_t * src0,
|
||||
const npu_device_fp16_t * src1,
|
||||
size_t count) {
|
||||
using namespace hexagon::vec;
|
||||
return vec_dot_product_aligned_impl<npu_device_fp16_t, float, vec_mpy_qf16, vec_add_qf16, vec_reduction_qf16_f32>(
|
||||
src0, src1, count);
|
||||
}
|
||||
|
||||
inline bool is_f16_f16_dot_product_aligned(const npu_device_fp16_t * src0, const npu_device_fp16_t * src1,
|
||||
size_t count) {
|
||||
inline bool is_f16_f16_dot_product_aligned(const npu_device_fp16_t * src0,
|
||||
const npu_device_fp16_t * src1,
|
||||
size_t count) {
|
||||
return is_dot_product_aligned<npu_device_fp16_t, npu_device_fp16_t>(src0, src1, count);
|
||||
}
|
||||
|
||||
inline HVX_Vector vec_dot_product_vqf32_f16_f32(const npu_device_fp16_t * src0, const float * src1, size_t count) {
|
||||
using namespace hexagon::vec;
|
||||
return vec_dot_product_mixed_impl<npu_device_fp16_t, float, HVX_Vector, hvx_vsf_convert_vhf, vec_mpy_qf32,
|
||||
vec_add_qf32, vec_reduction_qf32>(src0, src1, count);
|
||||
using namespace hexagon::vec::math;
|
||||
return vec_dot_product_mixed_impl<npu_device_fp16_t,
|
||||
float,
|
||||
HVX_Vector,
|
||||
hvx_vsf_convert_vhf,
|
||||
vec_mpy_qf32,
|
||||
vec_add_qf32,
|
||||
vec_reduction_qf32>(src0, src1, count);
|
||||
}
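// Editor's scalar reference for the mixed-precision dot product above: every f16 element of
// src0 is widened to f32 and accumulated against src1. Assumes npu_device_fp16_t converts to
// float (e.g. when it is __fp16 on the DSP toolchain); treat this as a sketch, not the HVX path.
inline float vec_dot_f16_f32_scalar(const npu_device_fp16_t * src0, const float * src1, size_t count) {
    float sum = 0.0f;
    for (size_t i = 0; i < count; ++i) {
        sum += (float) src0[i] * src1[i];
    }
    return sum;
}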
|
||||
|
||||
inline HVX_Vector vec_dot_product_aligned_vqf32_f16_f32(const npu_device_fp16_t * src0, const float * src1,
|
||||
size_t count) {
|
||||
inline HVX_Vector vec_dot_product_aligned_vqf32_f16_f32(const npu_device_fp16_t * src0,
|
||||
const float * src1,
|
||||
size_t count) {
|
||||
using namespace hexagon::vec;
|
||||
return vec_dot_product_mix_aligned_impl<npu_device_fp16_t, float, HVX_Vector, hvx_vsf_convert_vhf, vec_mpy_qf32,
|
||||
vec_add_qf32, vec_reduction_qf32>(src0, src1, count);
|
||||
using namespace hexagon::vec::math;
|
||||
return vec_dot_product_mix_aligned_impl<npu_device_fp16_t,
|
||||
float,
|
||||
HVX_Vector,
|
||||
hvx_vsf_convert_vhf,
|
||||
vec_mpy_qf32,
|
||||
vec_add_qf32,
|
||||
vec_reduction_qf32>(src0, src1, count);
|
||||
}
|
||||
|
||||
inline float vec_dot_product_f16_f32(const npu_device_fp16_t * src0, const float * src1, size_t count) {
|
||||
using namespace hexagon::vec;
|
||||
return vec_dot_product_mixed_impl<npu_device_fp16_t, float, float, hvx_vsf_convert_vhf, vec_mpy_qf32, vec_add_qf32,
|
||||
using namespace hexagon::vec::math;
|
||||
return vec_dot_product_mixed_impl<npu_device_fp16_t,
|
||||
float,
|
||||
float,
|
||||
hvx_vsf_convert_vhf,
|
||||
vec_mpy_qf32,
|
||||
vec_add_qf32,
|
||||
vec_reduction_f32_qf32>(src0, src1, count);
|
||||
}
|
||||
|
||||
inline float vec_dot_product_aligned_f16_f32(const npu_device_fp16_t * src0, const float * src1, size_t count) {
|
||||
using namespace hexagon::vec;
|
||||
return vec_dot_product_mix_aligned_impl<npu_device_fp16_t, float, float, hvx_vsf_convert_vhf, vec_mpy_qf32,
|
||||
vec_add_qf32, vec_reduction_f32_qf32>(src0, src1, count);
|
||||
using namespace hexagon::vec::math;
|
||||
return vec_dot_product_mix_aligned_impl<npu_device_fp16_t,
|
||||
float,
|
||||
float,
|
||||
hvx_vsf_convert_vhf,
|
||||
vec_mpy_qf32,
|
||||
vec_add_qf32,
|
||||
vec_reduction_f32_qf32>(src0, src1, count);
|
||||
}
|
||||
|
||||
inline bool is_f16_f32_dot_product_aligned(const npu_device_fp16_t * src0, const float * src1, size_t count) {
|
||||
|
|
@ -409,4 +380,35 @@ _TReturn type_erase_dot_func(const void * src0, const void * src1, size_t count)
|
|||
return _DotFunc(src0_typed, src1_typed, count);
|
||||
}
|
||||
|
||||
inline HVX_Vector vec_silu_f32_f32(HVX_Vector x, HVX_VectorPair_x4 coeff) {
using namespace hexagon::vec::math;

HVX_Vector one = Q6_V_vsplat_R(0x3F800000);

// x/(1.0f + expf(-x));
HVX_Vector exp_neg_x = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(Q6_V_vzero(), x));
HVX_Vector denom = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(qhmath_hvx_exp_vf(exp_neg_x), one));
return qhmath_hvx_div_vf(x, denom, coeff);
}

inline HVX_Vector vec_silu_f16_f16(HVX_Vector x, HVX_VectorPair_x4 coeff) {
using namespace hexagon::vec::math;
HVX_Vector one = Q6_Vh_vsplat_R(0x3c00);

// x/(1.0f + expf(-x));
HVX_Vector exp_neg_x = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vsub_VhfVhf(Q6_V_vzero(), x));
HVX_Vector denom = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_VhfVhf(qhmath_hvx_exp_vhf(exp_neg_x), one));
return qhmath_hvx_div_vhf(x, denom, coeff);
}

inline HVX_Vector vec_swiglu_f32_f32(HVX_Vector x, HVX_Vector g, HVX_VectorPair_x4 coeff) {
HVX_Vector silu = vec_silu_f32_f32(x, coeff);
return Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(silu, g));
}

inline HVX_Vector vec_swiglu_f16_f16(HVX_Vector x, HVX_Vector g, HVX_VectorPair_x4 coeff) {
HVX_Vector silu = vec_silu_f16_f16(x, coeff);
return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(silu, g));
}
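// Editor's scalar reference for the HVX SiLU/SwiGLU helpers above: per lane they compute
// silu(x) = x / (1 + exp(-x)) and swiglu(x, g) = silu(x) * g, with the division handled by
// qhmath_hvx_div_vf/vhf and the coefficient table `coeff` (expf here comes from <cmath>).
inline float silu_scalar(float x) {
    return x / (1.0f + expf(-x));
}

inline float swiglu_scalar(float x, float g) {
    return silu_scalar(x) * g;
}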
|
||||
|
||||
} // namespace hexagon
|
||||
|
|
|
|||
|
|
@ -1,15 +1,18 @@
|
|||
#pragma once
|
||||
|
||||
#include "hexagon_npu.h"
|
||||
|
||||
#include <hexagon_types.h>
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
#include "hexagon_npu.h"
|
||||
|
||||
namespace hexagon::vec {
|
||||
|
||||
template <typename _TElem, typename _TRet, HVX_Vector (*_MpyFunc)(HVX_Vector, HVX_Vector),
|
||||
HVX_Vector (*_AddFunc)(HVX_Vector, HVX_Vector), _TRet (*_ReduceFunc)(HVX_Vector)>
|
||||
template <typename _TElem,
|
||||
typename _TRet,
|
||||
HVX_Vector (*_MpyFunc)(HVX_Vector, HVX_Vector),
|
||||
HVX_Vector (*_AddFunc)(HVX_Vector, HVX_Vector),
|
||||
_TRet (*_ReduceFunc)(HVX_Vector)>
|
||||
inline _TRet vec_dot_product_impl(const _TElem * src0, const _TElem * src1, size_t count) {
|
||||
constexpr const size_t kElementsPerVector = hexagon::kBytesPerVector / sizeof(_TElem);
|
||||
|
||||
|
|
@ -95,8 +98,11 @@ inline _TRet vec_dot_product_impl(const _TElem * src0, const _TElem * src1, size
|
|||
return _ReduceFunc(sum);
|
||||
}
|
||||
|
||||
template <typename _TElem, typename _TRet, HVX_Vector (*_MpyFunc)(HVX_Vector, HVX_Vector),
|
||||
HVX_Vector (*_AddFunc)(HVX_Vector, HVX_Vector), _TRet (*_ReduceFunc)(HVX_Vector)>
|
||||
template <typename _TElem,
|
||||
typename _TRet,
|
||||
HVX_Vector (*_MpyFunc)(HVX_Vector, HVX_Vector),
|
||||
HVX_Vector (*_AddFunc)(HVX_Vector, HVX_Vector),
|
||||
_TRet (*_ReduceFunc)(HVX_Vector)>
|
||||
inline _TRet vec_dot_product_aligned_impl(const _TElem * src0, const _TElem * src1, size_t count) {
|
||||
constexpr const size_t kElementsPerVector = hexagon::kBytesPerVector / sizeof(_TElem);
|
||||
|
||||
|
|
@ -170,8 +176,12 @@ inline HVX_Vector vec_add_qf16(HVX_Vector sum, HVX_Vector result) {
|
|||
return Q6_Vqf16_vadd_Vqf16Vqf16(sum, result);
|
||||
}
|
||||
|
||||
template <typename _TElem0, typename _TElem1, typename _TRet, HVX_Vector_Dual (*_ExpandFunc)(HVX_Vector, HVX_Vector),
|
||||
HVX_Vector (*_MpyFunc)(HVX_Vector, HVX_Vector), HVX_Vector (*_AddFunc)(HVX_Vector, HVX_Vector),
|
||||
template <typename _TElem0,
|
||||
typename _TElem1,
|
||||
typename _TRet,
|
||||
HVX_Vector_x2 (*_ExpandFunc)(HVX_Vector, HVX_Vector),
|
||||
HVX_Vector (*_MpyFunc)(HVX_Vector, HVX_Vector),
|
||||
HVX_Vector (*_AddFunc)(HVX_Vector, HVX_Vector),
|
||||
_TRet (*_ReduceFunc)(HVX_Vector)>
|
||||
inline _TRet vec_dot_product_mixed_impl(const _TElem0 * src0, const _TElem1 * src1, size_t count) {
|
||||
static_assert(sizeof(_TElem0) < sizeof(_TElem1), "Element size mismatch: _TElem0 must be smaller than _TElem1");
|
||||
|
|
@ -202,8 +212,8 @@ inline _TRet vec_dot_product_mixed_impl(const _TElem0 * src0, const _TElem1 * sr
|
|||
HVX_Vector curr0 = src0_vec_ptr[0];
|
||||
HVX_VectorPair curr1 = reinterpret_cast<HVX_VectorPair *>(src1_vec_ptr)[0];
|
||||
|
||||
HVX_Vector s0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0);
|
||||
HVX_Vector_Dual s0_pair = _ExpandFunc(s0, kOneV);
|
||||
HVX_Vector s0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0);
|
||||
HVX_Vector_x2 s0_pair = _ExpandFunc(s0, kOneV);
|
||||
|
||||
HVX_Vector l1 = Q6_V_valign_VVR(Q6_V_lo_W(curr1), prev1, (size_t) src1);
|
||||
sum0 = _AddFunc(_MpyFunc(s0_pair.first, l1), sum0);
|
||||
|
|
@ -228,8 +238,8 @@ inline _TRet vec_dot_product_mixed_impl(const _TElem0 * src0, const _TElem1 * sr
|
|||
HVX_Vector curr0 = should_fetch_src0 ? *src0_vec_ptr : prev0;
|
||||
src0_vec_ptr += should_fetch_src0 ? 1 : 0;
|
||||
|
||||
HVX_Vector s0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0);
|
||||
HVX_Vector_Dual s0_pair = _ExpandFunc(s0, kOneV);
|
||||
HVX_Vector s0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0);
|
||||
HVX_Vector_x2 s0_pair = _ExpandFunc(s0, kOneV);
|
||||
|
||||
const bool has_remaining_src1_vector = src1_vec_ptr_end - src1_vec_ptr > 0;
|
||||
if (has_remaining_src1_vector) {
|
||||
|
|
@ -262,7 +272,7 @@ inline _TRet vec_dot_product_mixed_impl(const _TElem0 * src0, const _TElem1 * sr
|
|||
prev1;
|
||||
curr1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1);
|
||||
|
||||
HVX_Vector_Dual curr0_pair = _ExpandFunc(curr0, kOneV);
|
||||
HVX_Vector_x2 curr0_pair = _ExpandFunc(curr0, kOneV);
|
||||
|
||||
curr0 = leftover1 == leftover0 ? curr0_pair.first : curr0_pair.second;
|
||||
sum = _AddFunc(Q6_V_valign_VVR(_MpyFunc(curr0, curr1), Q6_V_vzero(), leftover_bytes1), sum);
|
||||
|
|
@@ -271,8 +281,12 @@ inline _TRet vec_dot_product_mixed_impl(const _TElem0 * src0, const _TElem1 * sr
     return _ReduceFunc(sum);
 }
 
-template <typename _TElem0, typename _TElem1, typename _TRet, HVX_Vector_Dual (*_ExpandFunc)(HVX_Vector, HVX_Vector),
-          HVX_Vector (*_MpyFunc)(HVX_Vector, HVX_Vector), HVX_Vector (*_AddFunc)(HVX_Vector, HVX_Vector),
+template <typename _TElem0,
+          typename _TElem1,
+          typename _TRet,
+          HVX_Vector_x2 (*_ExpandFunc)(HVX_Vector, HVX_Vector),
+          HVX_Vector (*_MpyFunc)(HVX_Vector, HVX_Vector),
+          HVX_Vector (*_AddFunc)(HVX_Vector, HVX_Vector),
           _TRet (*_ReduceFunc)(HVX_Vector)>
 inline _TRet vec_dot_product_mix_aligned_impl(const _TElem0 * src0, const _TElem1 * src1, size_t count) {
     static_assert(sizeof(_TElem0) < sizeof(_TElem1), "Element size mismatch: _TElem0 must be smaller than _TElem1");
@@ -297,16 +311,16 @@ inline _TRet vec_dot_product_mix_aligned_impl(const _TElem0 * src0, const _TElem
     HVX_Vector sum3 = Q6_V_vzero();
 
     do {
-        HVX_VectorPair curr0 = reinterpret_cast<HVX_VectorPair *>(src0_vec_ptr)[0];
-        HVX_Vector_Dual curr00 = _ExpandFunc(Q6_V_lo_W(curr0), kOneV);
-        HVX_VectorPair curr10 = reinterpret_cast<HVX_VectorPair *>(src1_vec_ptr)[0];
-        sum0 = _AddFunc(_MpyFunc(curr00.first, Q6_V_lo_W(curr10)), sum0);
-        sum1 = _AddFunc(_MpyFunc(curr00.second, Q6_V_hi_W(curr10)), sum1);
+        HVX_VectorPair curr0 = reinterpret_cast<HVX_VectorPair *>(src0_vec_ptr)[0];
+        HVX_Vector_x2 curr00 = _ExpandFunc(Q6_V_lo_W(curr0), kOneV);
+        HVX_VectorPair curr10 = reinterpret_cast<HVX_VectorPair *>(src1_vec_ptr)[0];
+        sum0 = _AddFunc(_MpyFunc(curr00.first, Q6_V_lo_W(curr10)), sum0);
+        sum1 = _AddFunc(_MpyFunc(curr00.second, Q6_V_hi_W(curr10)), sum1);
 
-        HVX_Vector_Dual curr01 = _ExpandFunc(Q6_V_hi_W(curr0), kOneV);
-        HVX_VectorPair curr11 = reinterpret_cast<HVX_VectorPair *>(src1_vec_ptr)[1];
-        sum2 = _AddFunc(_MpyFunc(curr01.first, Q6_V_lo_W(curr11)), sum2);
-        sum3 = _AddFunc(_MpyFunc(curr01.second, Q6_V_hi_W(curr11)), sum3);
+        HVX_Vector_x2 curr01 = _ExpandFunc(Q6_V_hi_W(curr0), kOneV);
+        HVX_VectorPair curr11 = reinterpret_cast<HVX_VectorPair *>(src1_vec_ptr)[1];
+        sum2 = _AddFunc(_MpyFunc(curr01.first, Q6_V_lo_W(curr11)), sum2);
+        sum3 = _AddFunc(_MpyFunc(curr01.second, Q6_V_hi_W(curr11)), sum3);
 
         src0_vec_ptr += 2;
         src1_vec_ptr += 4;
@@ -317,8 +331,8 @@ inline _TRet vec_dot_product_mix_aligned_impl(const _TElem0 * src0, const _TElem
     }
 
     if (src1_vec_ptr_end - src1_vec_ptr > 1) {
-        HVX_Vector curr0 = src0_vec_ptr[0];
-        HVX_Vector_Dual s0_pair = _ExpandFunc(curr0, kOneV);
+        HVX_Vector curr0 = src0_vec_ptr[0];
+        HVX_Vector_x2 s0_pair = _ExpandFunc(curr0, kOneV);
 
         HVX_VectorPair curr1 = reinterpret_cast<HVX_VectorPair *>(src1_vec_ptr)[0];
         sum0 = _AddFunc(_MpyFunc(s0_pair.first, Q6_V_lo_W(curr1)), sum0);
@@ -328,7 +342,8 @@ inline _TRet vec_dot_product_mix_aligned_impl(const _TElem0 * src0, const _TElem
     return _ReduceFunc(_AddFunc(sum0, sum1));
 }
 
-template <HVX_Vector (*_Func)(HVX_Vector, HVX_UVector *, HVX_Vector), HVX_Vector (*_FuncScaleConvert)(float),
+template <HVX_Vector (*_Func)(HVX_Vector, HVX_UVector *, HVX_Vector),
+          HVX_Vector (*_FuncScaleConvert)(float),
           typename _TParam>
 inline void vec_scale_impl(const _TParam * src, float scale, _TParam * dst, size_t count) {
     constexpr const size_t kElementsPerVector = hexagon::kBytesPerVector / sizeof(_TParam);
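Functionally, vec_scale_impl is an element-wise multiply by a scalar; a minimal scalar sketch, assuming _Func implements a multiply-by-scale-and-store (illustrative names, not from the diff):

// Illustrative scalar equivalent of vec_scale_impl (not part of the commit):
// the HVX version converts/broadcasts `scale` into a vector once via
// _FuncScaleConvert and then applies _Func to one 128-byte vector per iteration.
#include <cstddef>

void scale_scalar(const float * src, float scale, float * dst, size_t count) {
    for (size_t i = 0; i < count; ++i) {
        dst[i] = src[i] * scale;
    }
}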
@@ -410,7 +425,7 @@ template <typename _TData> inline void vec_zero_impl(_TData * src, size_t count)
 }
 
 template <HVX_Vector (*_OpBinaryTransform)(HVX_Vector, HVX_Vector), typename _TyData>
-inline void vec_trans_op_impl(const _TyData * src0, const _TyData * src1, size_t count, _TyData * dst) {
+inline void vec_trans_impl(const _TyData * src0, const _TyData * src1, _TyData * dst, size_t count) {
     constexpr const size_t kElementsPerVector = hexagon::kBytesPerVector / sizeof(_TyData);
 
     HVX_Vector * src0_vec_ptr = ((HVX_Vector *) src0);
@@ -496,4 +511,95 @@ inline void vec_trans_op_impl(const _TyData * src0, const _TyData * src1, size_t
     }
 }
 
+template <typename _TyData, typename _TyParam, HVX_Vector (*_OpBinaryTransform)(HVX_Vector, HVX_Vector, _TyParam)>
+inline void vec_trans_with_param_impl(const _TyData * src0,
+                                      const _TyData * src1,
+                                      _TyData * dst,
+                                      size_t count,
+                                      _TyParam param) {
+    constexpr const size_t kElementsPerVector = hexagon::kBytesPerVector / sizeof(_TyData);
+
+    HVX_Vector * src0_vec_ptr = ((HVX_Vector *) src0);
+    HVX_Vector * const src0_vec_ptr_end = ((HVX_Vector *) src0) + count / kElementsPerVector;
+    HVX_Vector * src1_vec_ptr = ((HVX_Vector *) src1);
+    HVX_Vector * dst_vec_ptr = ((HVX_Vector *) dst);  // framework will ensure the dst is aligned
+    HVX_Vector prev0 = *src0_vec_ptr++;
+    HVX_Vector prev1 = *src1_vec_ptr++;
+
+    {
+        while (src0_vec_ptr_end - src0_vec_ptr > 1) {
+            HVX_VectorPair curr0 = reinterpret_cast<HVX_VectorPair *>(src0_vec_ptr)[0];
+            HVX_VectorPair curr1 = reinterpret_cast<HVX_VectorPair *>(src1_vec_ptr)[0];
+
+            HVX_Vector l0 = Q6_V_valign_VVR(Q6_V_lo_W(curr0), prev0, (size_t) src0);
+            HVX_Vector l1 = Q6_V_valign_VVR(Q6_V_lo_W(curr1), prev1, (size_t) src1);
+            dst_vec_ptr[0] = _OpBinaryTransform(l0, l1, param);
+
+            HVX_Vector h0 = Q6_V_valign_VVR(Q6_V_hi_W(curr0), Q6_V_lo_W(curr0), (size_t) src0);
+            HVX_Vector h1 = Q6_V_valign_VVR(Q6_V_hi_W(curr1), Q6_V_lo_W(curr1), (size_t) src1);
+            dst_vec_ptr[1] = _OpBinaryTransform(h0, h1, param);
+
+            prev0 = Q6_V_hi_W(curr0);
+            prev1 = Q6_V_hi_W(curr1);
+            src0_vec_ptr += 2;
+            src1_vec_ptr += 2;
+            dst_vec_ptr += 2;
+        }
+    }
+
+    if (src0_vec_ptr_end - src0_vec_ptr > 0) {
+        HVX_Vector curr0 = *src0_vec_ptr++;
+        HVX_Vector s0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0);
+
+        HVX_Vector curr1 = *src1_vec_ptr++;
+        HVX_Vector s1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1);
+
+        dst_vec_ptr[0] = _OpBinaryTransform(s0, s1, param);
+
+        prev0 = curr0;
+        prev1 = curr1;
+        dst_vec_ptr++;
+    }
+
+    const size_t leftover = count % kElementsPerVector;
+    if ((src0_vec_ptr_end - ((HVX_Vector *) src0)) > 0) {
+        // handle the last vector
+        // see also:
+        // https://github.com/UbiquitousLearning/mllm/blob/babf4410352ce8730824c87699c025a0d4ce3a6f/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/LLaMAMul.cpp#L147
+        // or qualcomm sdk libs\qhl_hvx\src\qhblas_hvx\qhblas_hvx_aw_vector_add_ah.c
+        bool should_fetch_src0 = leftover != 0 || !hexagon::is_addr_aligned(src0_vec_ptr);
+        bool should_fetch_src1 = leftover != 0 || !hexagon::is_addr_aligned(src1_vec_ptr);
+
+        HVX_Vector curr0 = should_fetch_src0 ? *src0_vec_ptr : prev0;
+        HVX_Vector s0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0);
+
+        HVX_Vector curr1 = should_fetch_src1 ? *src1_vec_ptr : prev1;
+        HVX_Vector s1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1);
+
+        dst_vec_ptr[0] = _OpBinaryTransform(s0, s1, param);
+
+        src0_vec_ptr += should_fetch_src0 ? 1 : 0;
+        src1_vec_ptr += should_fetch_src1 ? 1 : 0;
+        prev0 = curr0;
+        prev1 = curr1;
+        dst_vec_ptr++;
+    }
+
+    if (leftover > 0) {
+        // handle the leftover elements
+        const size_t leftover_bytes = leftover * sizeof(_TyData);
+        HVX_Vector curr0 = (leftover_bytes + hexagon::unaligned_bytes(src0_vec_ptr) > hexagon::kBytesPerVector) ?
+                               *src0_vec_ptr :
+                               prev0;
+        curr0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0);
+
+        HVX_Vector curr1 = (leftover_bytes + hexagon::unaligned_bytes(src1_vec_ptr) > hexagon::kBytesPerVector) ?
+                               *src1_vec_ptr :
+                               prev1;
+        curr1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1);
+
+        q6op_vstu_variable_ARV(dst_vec_ptr, leftover_bytes, _OpBinaryTransform(curr0, curr1, param));
+    }
+}
+
 } // namespace hexagon::vec
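A minimal scalar sketch of what the new vec_trans_with_param_impl computes (the template names below are illustrative, not from the diff):

// Illustrative scalar equivalent of vec_trans_with_param_impl (not part of the
// commit): apply a binary transform that also takes a runtime parameter to each
// element pair. The HVX version processes two 128-byte vectors per loop, uses
// Q6_V_valign_VVR to tolerate unaligned src0/src1, and writes the final partial
// vector with q6op_vstu_variable_ARV; dst is assumed vector-aligned by the framework.
#include <cstddef>

template <typename TData, typename TParam, TData (*Op)(TData, TData, TParam)>
void trans_with_param_scalar(const TData * src0, const TData * src1, TData * dst, size_t count, TParam param) {
    for (size_t i = 0; i < count; ++i) {
        dst[i] = Op(src0[i], src1[i], param);
    }
}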
@@ -1,10 +1,10 @@
 #pragma once
 
+#include "util.hpp"
+
 #include <HAP_compute_res.h>
 #include <HAP_vtcm_mgr.h>
 
-#include "util.hpp"
-
 namespace hexagon {
 
 class vtcm_mem {
@@ -5,13 +5,13 @@
 #include <domain_default.h>
 #pragma GCC diagnostic pop
 
+#include "graph.hpp"
+#include "util.hpp"
+
 #include <remote.h>
 
 #include <type_traits>
 
-#include "graph.hpp"
-#include "util.hpp"
-
 #define SKEL_URI_DEFINE(arch) ("file:///libhexagon_npu_skel_" arch ".so?npu_device_skel_handle_invoke&_modver=1.0")
 
 namespace {
@@ -68,8 +68,10 @@ void backend_free(ggml_backend_t backend) {
     delete get_backend_object(backend);
 }
 
-bool backend_cpy_tensor_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const ggml_tensor * src,
-                              ggml_tensor * dst) {
+bool backend_cpy_tensor_async(ggml_backend_t backend_src,
+                              ggml_backend_t backend_dst,
+                              const ggml_tensor * src,
+                              ggml_tensor * dst) {
     // TODO: implement this
     return false;
 }
@@ -194,13 +196,24 @@ bool npu_device::supports_op_impl(const ggml_tensor * op) {
 
         boolean supported = false;
         auto dst_spec = get_spec(op);
-        auto ret = npu_device_device_support_op(_device_handle, npu_op, &dst_spec, srcs, i, &supported);
+
+        npu_device_tensor_op_spec npu_op_spec = { npu_op, {} };
+        static_assert(sizeof(npu_op_spec.params) <= sizeof(op->op_params),
+                      "npu_op_spec.params size should less than op->op_params size");
+        memcpy(npu_op_spec.params, op->op_params, sizeof(npu_op_spec.params));
+        auto ret = npu_device_device_support_op(_device_handle, &npu_op_spec, &dst_spec, srcs, i, &supported);
         if (ret != AEE_SUCCESS || !supported) {
 #ifndef NDEBUG
             auto * src0_type = i ? ggml_type_name(op->src[0]->type) : "null";
             auto * src1_type = (i > 1) ? ggml_type_name(op->src[1]->type) : "null";
-            LOG_DEBUG("[%s][%s]unsupported %s(%s,%s), ret: 0x%x, supported: %d\n", get_name(), ggml_op_desc(op),
-                      ggml_type_name(op->type), src0_type, src1_type, ret, supported);
+            LOG_DEBUG("[%s][%s]unsupported %s(%s,%s), ret: 0x%x, supported: %d\n",
+                      get_name(),
+                      ggml_op_desc(op),
+                      ggml_type_name(op->type),
+                      src0_type,
+                      src1_type,
+                      ret,
+                      supported);
 #endif
             return false;
         }
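With this change the host forwards op->op_params to the device inside npu_device_tensor_op_spec instead of sending only the op id. One plausible way the device-side check could consume those params is sketched below; is_glu_variant_supported is a hypothetical helper, not part of this commit, and it assumes params[0] carries the glu_op value, mirroring how ggml stores the GLU variant in op_params.

// Hypothetical device-side helper (assumption, not from the diff): reads the
// GLU variant out of the forwarded op params and rejects variants that are not
// implemented. Assumes params[0] holds the glu_op value.
static bool is_glu_variant_supported(const npu_device_tensor_op_spec * op_spec) {
    if (op_spec->op != NPU_OP_GLU) {
        return true;  // in this sketch the params are only inspected for GLU
    }
    switch (op_spec->params[0]) {
        case NPU_GLU_OP_SWIGLU:
        case NPU_GLU_OP_GEGLU:
            return true;  // variants assumed to be wired up
        default:
            return false;
    }
}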
@@ -244,8 +257,8 @@ bool npu_device::init_device_lib() {
     }
 
     if (err != AEE_SUCCESS) {
-        LOG_ERROR("[%s]Unable to open NPU device, err: 0x%x, uri %s\n", get_name(), err,
-                  device_lib_uri.c_str());
+        LOG_ERROR(
+            "[%s]Unable to open NPU device, err: 0x%x, uri %s\n", get_name(), err, device_lib_uri.c_str());
         _device_handle = 0;
         return false;
     }
@@ -275,16 +288,26 @@ bool npu_device::supports_op(const ggml_tensor * op) {
         if (op->op != GGML_OP_NONE && op->op != GGML_OP_VIEW && op->op != GGML_OP_RESHAPE &&
             op->op != GGML_OP_PERMUTE) {
             _supported_op++;
-            LOG_DEBUG("[%s][%s][%s]supported, %s, supported/unsupported: %u/%u\n", get_name(), ggml_op_desc(op),
-                      ggml_get_name(op), op_desc, _supported_op.load(), _unsupported_op.load());
+            LOG_DEBUG("[%s][%s][%s]supported, %s, supported/unsupported: %u/%u\n",
+                      get_name(),
+                      ggml_op_desc(op),
+                      ggml_get_name(op),
+                      op_desc,
+                      _supported_op.load(),
+                      _unsupported_op.load());
         }
 
         return true;
     }
 
     _unsupported_op++;
-    LOG_DEBUG("[%s][%s][%s]unsupported, %s, supported/unsupported: %u/%u\n", get_name(), ggml_op_desc(op),
-              ggml_get_name(op), op_desc, _supported_op.load(), _unsupported_op.load());
+    LOG_DEBUG("[%s][%s][%s]unsupported, %s, supported/unsupported: %u/%u\n",
+              get_name(),
+              ggml_op_desc(op),
+              ggml_get_name(op),
+              op_desc,
+              _supported_op.load(),
+              _unsupported_op.load());
     return false;
 }
 #else
@@ -35,6 +35,8 @@ enum npu_device_tensor_op op_to_npu_op(ggml_op op) {
             return NPU_OP_FLASH_ATTN;
         case GGML_OP_ROPE:
             return NPU_OP_ROPE;
+        case GGML_OP_GLU:
+            return NPU_OP_GLU;
         default:
             return NPU_OP_COUNT;
     }
@@ -56,6 +58,8 @@ const char * get_npu_op_desc(enum npu_device_tensor_op op) {
             return ggml_op_name(GGML_OP_FLASH_ATTN_EXT);
         case NPU_OP_ROPE:
             return ggml_op_name(GGML_OP_ROPE);
+        case NPU_OP_GLU:
+            return ggml_op_name(GGML_OP_GLU);
         default:
             return "UNKNOWN";
     }
@@ -17,6 +17,7 @@ interface npu_device : remote_handle64{
 
     typedef int64_t ne_type[DEVICE_TENSOR_MAX_DIMS];
     typedef uint64_t nb_type[DEVICE_TENSOR_MAX_DIMS];
+    typedef int32_t param_type[DEVICE_TENSOR_MAX_OP_PARAMS];
     typedef uint64_t tensor_handle_t;
     typedef uint64_t graph_handle_t;
 
@@ -50,9 +51,19 @@ interface npu_device : remote_handle64{
         NPU_OP_RMS_NORM,
         NPU_OP_FLASH_ATTN,
         NPU_OP_ROPE,
+        NPU_OP_GLU,
         NPU_OP_COUNT
     };
 
+    enum glu_op {
+        NPU_GLU_OP_REGLU,
+        NPU_GLU_OP_GEGLU,
+        NPU_GLU_OP_SWIGLU,
+        NPU_GLU_OP_GEGLU_ERF,
+        NPU_GLU_OP_GEGLU_QUICK,
+        NPU_GLU_OP_COUNT
+    };
+
     enum tensor_data_type {
         NPU_DATA_TYPE_F32,
         NPU_DATA_TYPE_F16,
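For reference, the GLU family of ops splits the activation row into a gate half a and a linear half b and computes gate(a) * b; a minimal scalar sketch of the SwiGLU case (illustrative only, not part of the IDL or the device code):

// Illustrative scalar SwiGLU (NPU_GLU_OP_SWIGLU): dst[i] = silu(a[i]) * b[i].
// The other variants replace the SiLU gate with GELU (GEGLU), its erf/quick
// approximations (GEGLU_ERF/GEGLU_QUICK), or ReLU (REGLU).
#include <cmath>
#include <cstddef>

void swiglu_row_scalar(const float * a, const float * b, float * dst, size_t count) {
    for (size_t i = 0; i < count; ++i) {
        const float silu = a[i] / (1.0f + std::exp(-a[i]));  // SiLU gate
        dst[i] = silu * b[i];
    }
}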
@@ -69,9 +80,14 @@ interface npu_device : remote_handle64{
         tensor_data_type type;
     };
 
+    struct tensor_op_spec {
+        tensor_op op;
+        param_type params;
+    };
+
     struct tensor_update_config {
         tensor_op op;
-        int32_t params[DEVICE_TENSOR_MAX_OP_PARAMS];
+        param_type params;
         tensor_handle_t src_handles[DEVICE_TENSOR_MAX_SRC];
     };
 
@@ -90,7 +106,7 @@ interface npu_device : remote_handle64{
     );
 
     AEEResult device_support_op(
-        in tensor_op op,
+        in tensor_op_spec op_spec,
         in tensor_spec dst,
         in sequence<tensor_spec> srcs,
         rout boolean is_supported