feat: perf opt gemv phase2 (#58)

* Add power management utilities to NPU device context and update DCVS settings

* Update DCVS settings in power_utils to use v3 API and enhance power management

* wip

* Enhance dequantization functions by adding load_dequant_table support and updating signatures for improved performance

* use lut

* wip

* fix test failure

* wip

* Refactor load_qual_block_generic to improve block handling and optimize vector operations

* Enhance load_dual_block_generic and load_qual_block_generic to accept a mask parameter for improved block handling

* Refactor flash_attn_impl to optimize mask l2 prefetch

* wip

* wip

* wip

* wip

* add log

* link against shared libraries instead of static ones

* fix swiglu

* wip

* refactor expf_fix to handle overflow for different data types

* enhance is_glu_op_supported to validate shapes for multiple sources

* wip

* refactor logging macros to use hexagon namespace and improve formatting

* fix printf format error

* wip

* refactor: update static_assert messages for block size validation and add HVX_VectorPred_x3 type alias

* rename

* feat: enhance fa with mask

* wip

* wip

* refactor: replace instances of Q6_V_vzero() with kZeroV for consistency

* wip

* wip

* wip

* fix: improve address alignment check in HVX_Vector handling

* refactor: streamline vector dot product implementations for improved readability

* refactor: q4k add hvx intrinsic impl

* refactor: enhance dequantize_row_q4_K for clarity and performance

* refactor: optimize scale mask usage in dequantization functions for improved performance

* refactor: optimize dequantize_row_q4_K for intrinsic usage and performance improvements

* refactor: move GLU operation implementation into separated file

* sync after swiglu

* wip

* wip

* wip

* feat: increase prc main thread stack size

* fix: replace hardcoded stack size with NPU_THREAD_STACK_SIZE constant

* wip

* feat: add optimized vector operations for exponential and division with overflow handling

* wip

* feat: refactor exponential function to handle overflow and underflow with improved logic

* wip

* wip

* feat: add vector loading and scaling functions for improved performance in block processing

* wip

* feat: optimize block loading by refactoring scale index handling for improved performance

* use Q6_Vb_vlut32_VbVbR_nomatch instead

* feat: enhance scale loading by adding static assertion and restructuring block handling

* wip

* feat: refactor vec_dot_product_mixed_impl for improved clarity and performance

* wip

* feat: simplify vector loading functions and improve alignment handling

* wip

* feat: enhance scale loading mask with quantization block size validation

* wip

* feat: implement make_scale_load_mask function and refactor vector handling in vec_ops

* feat: enhance load_dual_block_generic to include scale indices for improved vector loading

* revert q8 dequant

* wip

* feat: optimize dequantization functions by removing unnecessary masking and updating lookup methods

* wip

* wip

* add qurt_mutex

* Add DMA transfer class and integrate into thread pool

* Enhance DMA transfer functionality by adding support for multiple descriptors and initiating transfers in parallel

* fix dma crash

* fix failed unit tests

* wip

* use alignas

* Improve DMA transfer error handling and update descriptor completion check

* Fix VTCM cache size calculation in element-wise operations

* Add cache clean operations before DMA transfers in element-wise operations

* reduce cache clean operations

* Refactor DMA transfer functions to support 1D operations and rename for clarity

* Enhance DMA transfer functionality by adding 2D submission support and improving descriptor initialization

* Update read buffer method to support forced invalidation and remove unnecessary invalidation calls in element-wise operations

* wip

* Improve DMA transfer handling in mul_mat_gemv_impl by replacing memcpy with initiate_dma_row_transfer and adding wait_for_dma logic

* fix 2d dma

* feat: add DMA plane cache

* rename

* wip

* use memcpy for debug

* fix cache plane calc

* refactor: remove debug logging from mul_mat_impl and optimize cache handling

* rename

* fix 2d dma type

* refactor: enhance DMA transfer handling in mul_mat_gemv_impl and wait functions

* refactor: optimize DMA transfer handling in mul_mat_gemv_impl and wait functions

* wip

* wip

* move op impl into sub dir

* add log

* fix: correct pointer usage in mul_mat_gemv_impl for next plane access

* fix: improve DMA transfer error handling in mul_mat_impl and mul_mat_gemv_impl

* fix: fix crash by using the entire row bytes

* wip

* wip

* fix: prevent parallelization for scalar src1 in is_mul_mat_supported

* fix: add dimension checks for 2D DMA transfers and fallback to 1D if necessary

* wip

* fix: enable thread barrier for matrix multiplication operations

* feat: add synchronization checks for tensor operations and update related functions

* wip

* fix: remove invalidation flag from get_read_buffer calls in element-wise and matrix multiplication operations

* Revert "fix: remove invalidation flag from get_read_buffer calls in element-wise and matrix multiplication operations"

This reverts commit af3441e67e706b2e5122369dc160353796867dd3.

* wip

* wip

* add comment

* fix: improve DMA transfer handling in mul_mat_gemv_impl for quantized source tensors

* add log

* try fix mulmat gemv

* wip

* fix: enhance DMA transfer handling in mul_mat_gemv_impl for quantized source tensors

* fix: optimize cache offset calculation and remove redundant swap in mul_mat_gemv_impl

* fix: refactor DMA transfer handling in mul_mat_gemv_impl for improved clarity and maintainability

* wip

* wip

* wip

* fix: enhance mul_mat_impl for improved cache handling and clarity

* fix: refactor tensor unflattening and DMA transfer initialization for improved clarity and type safety

* fix: improve cache handling of quant

* wip

* fix: improve cache handling in mul_mat_impl and mul_mat_gemv_impl for better memory efficiency

* rename

* add load_hexa_block_generic

* wip

* extract dequant block into separated function

* refactor: enhance dequantization functions with table parameter

* fix load_dual_block_generic

* refactor: rename dequantization functions for clarity and enhance block handling

* refactor: simplify dequantization logic by consolidating block handling and removing unused parameters

* wip

* wip

* feat: add make_qs_load_mask function and update load_dual_block_generic to use qs_indices

* fix load_dual_block_generic

* refactor: update load functions to use qs_indices for improved block loading

* wip

* fix: update loop indices and boundary checks to use size_t for better efficiency

* wip

* update make_scale_load_mask to make it available for q8

* feat: add vec_dot_product_quant_impl for quantized dot product computation

* refactoring: move some quant funcs to dedicated file

* refactor: rename dequantization functions for clarity and consistency

* wip

* feat: enhance vec_dot_product_quant_impl with dual dequantization and improved assertions

* add vec_dot_product_vqf32_q40_f32

* wip

* wip

* wip

* wip

* implement vec_mpy_qf32_qf32_qf32 function and update vec_dot_product_vqf32_q40_f32 to use it

* wip

* add src0_plane_write_cache_offset

* wip

* enhance mul_mat_f32 to handle NPU_DATA_TYPE_Q4_0 for quantized matrix multiplication

* wip

* wip

* update test func

* refactor mul_mat_gemv_quant_impl to use get_nb for row stride and remove unused test function in init_f16_f32_table

* wip

* Add support for 4-block dequantization in vec_quant and update dot product implementation

* Refactor vec_dot_product_quant_impl to improve variable handling and enhance readability

* Refactor vec_dot_product_quant_impl to replace template function with inline vector operations

* use Q6_Vqf32_vmpy_VsfVsf instead of Q6_Vqf32_vmpy_Vqf32Vqf32

* Revert "use Q6_Vqf32_vmpy_VsfVsf instead of Q6_Vqf32_vmpy_Vqf32Vqf32"

This reverts commit 54839166fddbe40a0392adee5863c59070ccdbe4.

* wip

* improve log print in graph

* Refactor batched_row_dot to accept additional arguments and remove batched_row_dot_with_table

* Refactor synchronization functions to include previous operation and NE type parameters

* Refactor synchronization checks in several operations

* Update synchronization checks to include NPU_OP_COUNT in required conditions

* Add performance tracking to buffer management functions

* add memset

* add log

* fix: update backend device type from ACCEL to IGPU

* fix comment
nullname 2025-10-16 23:21:51 +08:00 committed by GitHub
parent 40893e58c6
commit 36bc6f3213
19 changed files with 899 additions and 454 deletions


@ -74,27 +74,32 @@ void graph::compute_impl(default_thread_pool * pool, default_thread_pool::thread
hexagon::compute_params params = { thread_params, _f16_to_f32_table };
npu_device_tensor_op prev_op = NPU_OP_COUNT;
npu_device_ne_type prev_ne = {};
for (size_t i = 0; i < _tensor_count; ++i) {
auto * dst = _tensors[i];
auto op = dst->get_op();
auto * func = get_compute_func(dst);
auto * dst = _tensors[i];
auto op = dst->get_op();
const auto & ne = dst->get_info().ne;
const auto * op_name = op_get_name(op);
auto * func = get_compute_func(dst);
if (!func) {
DEVICE_LOG_ERROR("graph(%p) tensor[%zu] op %d not supported\n", (void *) this, i, op);
DEVICE_LOG_ERROR("[%p][%s]graph tensor[%zu] op not supported\n", (void *) this, op_name, i);
return;
}
const bool should_sync = requires_thread_barrier(prev_op, op);
const bool should_sync = requires_thread_barrier(prev_op, prev_ne, op, ne);
if (pool && should_sync) {
// For the last tensor, the thread pool will handle synchronization
DEVICE_SCOPED_PERFORMANCE_TRACKER("[%p]sync_thread, tidx: %zu, tensor[%zu/%zu]", (void *) this,
params.get_thread_index(), i, _tensor_count);
DEVICE_SCOPED_PERFORMANCE_TRACKER("[%p][%s]sync_thread, tidx: %zu, tensor[%zu/%zu]", (void *) this, op_name,
params.get_thread_index(), i + 1, _tensor_count);
pool->sync_thread();
}
prev_op = op;
memcpy(&prev_ne, &ne, sizeof(prev_ne));
if (!func(dst, &params)) {
DEVICE_LOG_ERROR("graph(%p) tensor[%zu] op %d compute failed\n", (void *) this, i, op);
DEVICE_LOG_ERROR("[%p][%s]graph tensor[%zu] op %d compute failed\n", (void *) this, op_name, i, op);
}
}
}
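
Aside: the memcpy into prev_ne (rather than a plain assignment) suggests npu_device_ne_type is a C-style array typedef from the generated device header. A minimal sketch of the shape bookkeeping under that assumption:

#include <cstring>

// Assumption (not shown in this diff): ne is a fixed-size extent array, ggml-style.
typedef long long npu_device_ne_type[4];

struct prev_shape_tracker {
    npu_device_ne_type prev_ne = {};

    void remember(const npu_device_ne_type & ne) {
        // C arrays cannot be assigned with '=', which is why compute_impl records the
        // previous tensor's shape with memcpy before moving on to the next tensor.
        std::memcpy(&prev_ne, &ne, sizeof(prev_ne));
    }
};
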


@ -375,10 +375,14 @@ bool is_flash_attn_supported(const npu_device_tensor_op_spec * op_spec,
return true;
}
bool is_flash_attn_required_sync(const npu_device_tensor_op op, const npu_device_tensor_op next_op) {
bool is_flash_attn_required_sync(npu_device_tensor_op prev_op,
const npu_device_ne_type & prev_ne,
npu_device_tensor_op op,
const npu_device_ne_type & ne) {
NPU_UNUSED(prev_ne);
NPU_UNUSED(op);
NPU_UNUSED(next_op);
return true;
NPU_UNUSED(ne);
return prev_op != NPU_OP_COUNT;
}
} // namespace hexagon


@ -9,6 +9,9 @@ bool is_flash_attn_supported(const npu_device_tensor_op_spec * op_spec,
const npu_device_tensor_spec * dst,
const npu_device_tensor_spec * srcs,
size_t src_len);
bool is_flash_attn_required_sync(const npu_device_tensor_op op, const npu_device_tensor_op next_op);
bool is_flash_attn_required_sync(npu_device_tensor_op prev_op,
const npu_device_ne_type & prev_ne,
npu_device_tensor_op op,
const npu_device_ne_type & ne);
} // namespace hexagon


@ -48,8 +48,8 @@ inline void glu_vec_op_f32_f32(const float * src0,
size_t count,
hexagon::HVX_VectorPair_x4 coeff) {
using namespace hexagon::vec;
vec_trans_with_param_impl<float, hexagon::HVX_VectorPair_x4, hexagon::vec_swiglu_f32_f32>(
src0, src1, dst, count, coeff);
vec_trans_with_param_impl<float, hexagon::HVX_VectorPair_x4, hexagon::vec_swiglu_f32_f32>(src0, src1, dst, count,
coeff);
}
template <auto _GluRowFunc, auto _CoeffLoadFunc>
@ -73,8 +73,8 @@ bool glu_impl(hexagon::tensor * out, hexagon::compute_params * params) {
const auto total_cols = has_src1 ? src0->get_ne(0) : src0->get_ne(0) / 2;
if (out->get_ne(0) != total_cols) {
DEVICE_LOG_ERROR(
"[hexagon-npu][GLU]out.ne[0] (%ld) != total_cols (%d)\n", (long) out->get_ne(0), (int) total_cols);
DEVICE_LOG_ERROR("[hexagon-npu][GLU]out.ne[0] (%ld) != total_cols (%d)\n", (long) out->get_ne(0),
(int) total_cols);
return false;
}
@ -87,8 +87,7 @@ bool glu_impl(hexagon::tensor * out, hexagon::compute_params * params) {
uint8_t * dst_ptr = out->get_write_buffer();
if (!dst_ptr) {
DEVICE_LOG_ERROR("[hexagon-npu][GLU]glu_impl: dst_ptr is not writable, tensor: %p, type: %s\n",
(void *) out,
DEVICE_LOG_ERROR("[hexagon-npu][GLU]glu_impl: dst_ptr is not writable, tensor: %p, type: %s\n", (void *) out,
hexagon::get_type_name(out->get_type()));
return false;
}
@ -121,11 +120,8 @@ bool glu_impl(hexagon::tensor * out, hexagon::compute_params * params) {
hexagon::l2fetch_row(src1_row + src1->get_nb(1), valid_row_bytes);
}
_GluRowFunc(reinterpret_cast<const data_type *>(src0_row),
reinterpret_cast<const data_type *>(src1_row),
reinterpret_cast<data_type *>(dst_row),
static_cast<size_t>(total_cols),
coeff);
_GluRowFunc(reinterpret_cast<const data_type *>(src0_row), reinterpret_cast<const data_type *>(src1_row),
reinterpret_cast<data_type *>(dst_row), static_cast<size_t>(total_cols), coeff);
}
out->release_write_buffer(); // mark the output tensor as modified
@ -142,8 +138,7 @@ bool glu_compute(hexagon::tensor * out, hexagon::compute_params * params) {
}
if (out->get_type() != _DataType) {
DEVICE_LOG_ERROR("GLU op type mismatch: %s vs %s\n",
hexagon::get_type_name(out->get_type()),
DEVICE_LOG_ERROR("GLU op type mismatch: %s vs %s\n", hexagon::get_type_name(out->get_type()),
hexagon::get_type_name(_DataType));
return false;
}
@ -192,16 +187,14 @@ bool is_glu_op_supported(const npu_device_tensor_op_spec * op_spec,
const auto & src0 = srcs[0];
if (dst->type != src0.type) {
DEVICE_LOG_DEBUG("[%s]src0.type and dst.type mismatch: %s vs %s\n",
hexagon::op_get_name(op),
hexagon::get_type_name(src0.type),
hexagon::get_type_name(dst->type));
DEVICE_LOG_DEBUG("[%s]src0.type and dst.type mismatch: %s vs %s\n", hexagon::op_get_name(op),
hexagon::get_type_name(src0.type), hexagon::get_type_name(dst->type));
return false;
}
if (dst->type != NPU_DATA_TYPE_F32 && dst->type != NPU_DATA_TYPE_F16) {
DEVICE_LOG_DEBUG(
"[%s]unsupported data type: %s\n", hexagon::op_get_name(op), hexagon::get_type_name(dst->type));
DEVICE_LOG_DEBUG("[%s]unsupported data type: %s\n", hexagon::op_get_name(op),
hexagon::get_type_name(dst->type));
return false;
}
@ -215,9 +208,7 @@ bool is_glu_op_supported(const npu_device_tensor_op_spec * op_spec,
if (src0.ne[0] / 2 != dst->ne[0] || src0.ne[1] != dst->ne[1] || src0.ne[2] != dst->ne[2] ||
src0.ne[3] != dst->ne[3]) {
DEVICE_LOG_DEBUG("[%s]src0 and dst have different shape: src0.ne[0]: %ld, dst.ne[0]: %ld\n",
hexagon::op_get_name(op),
(long) src0.ne[0],
(long) dst->ne[0]);
hexagon::op_get_name(op), (long) src0.ne[0], (long) dst->ne[0]);
return false;
}
}
@ -225,9 +216,14 @@ bool is_glu_op_supported(const npu_device_tensor_op_spec * op_spec,
return true;
}
bool is_glu_required_sync(const npu_device_tensor_op op, const npu_device_tensor_op next_op) {
bool is_glu_required_sync(npu_device_tensor_op prev_op,
const npu_device_ne_type & prev_ne,
npu_device_tensor_op op,
const npu_device_ne_type & ne) {
NPU_UNUSED(prev_ne);
NPU_UNUSED(op);
return next_op == NPU_OP_MUL_MAT;
NPU_UNUSED(ne);
return prev_op == NPU_OP_MUL_MAT;
}
} // namespace hexagon
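
For reference, a scalar sketch of the SwiGLU row that glu_vec_op_f32_f32 vectorizes. The coeff pack drives the HVX exp approximation and is omitted here; when src1 is absent, the gate comes from the second half of src0, which is what the total_cols = ne0 / 2 branch above handles.

#include <cmath>
#include <cstddef>

// Scalar SwiGLU reference (sketch): dst[i] = silu(src0[i]) * src1[i].
static void swiglu_row_ref(const float * src0, const float * src1, float * dst, size_t count) {
    for (size_t i = 0; i < count; ++i) {
        const float x = src0[i];
        dst[i] = (x / (1.0f + std::exp(-x))) * src1[i];   // silu(x) = x * sigmoid(x)
    }
}
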


@ -11,6 +11,9 @@ bool is_glu_op_supported(const npu_device_tensor_op_spec * op_spec,
const npu_device_tensor_spec * dst,
const npu_device_tensor_spec * srcs,
size_t src_len);
bool is_glu_required_sync(const npu_device_tensor_op op, const npu_device_tensor_op next_op);
bool is_glu_required_sync(npu_device_tensor_op prev_op,
const npu_device_ne_type & prev_ne,
npu_device_tensor_op op,
const npu_device_ne_type & ne);
} // namespace hexagon


@ -227,9 +227,15 @@ bool is_element_wise_op_supported(const npu_device_tensor_op_spec * op_spec,
return true;
}
bool is_element_wise_op_required_sync(const npu_device_tensor_op op, const npu_device_tensor_op next_op) {
bool is_element_wise_op_required_sync(npu_device_tensor_op prev_op,
const npu_device_ne_type & prev_ne,
npu_device_tensor_op op,
const npu_device_ne_type & ne) {
NPU_UNUSED(prev_ne);
NPU_UNUSED(op);
return next_op == NPU_OP_MUL_MAT;
NPU_UNUSED(ne);
return prev_op != NPU_OP_ADD && prev_op != NPU_OP_SUB && prev_op != NPU_OP_MUL && prev_op != NPU_OP_RMS_NORM &&
prev_op != NPU_OP_COUNT;
}
void rms_norm_vec_f32(const float * src, float * dst, size_t count, float eps) {
@ -361,9 +367,15 @@ bool is_unary_op_supported(const npu_device_tensor_op_spec * op_spec,
return true;
}
bool is_unary_op_required_sync(const npu_device_tensor_op op, const npu_device_tensor_op next_op) {
bool is_unary_op_required_sync(npu_device_tensor_op prev_op,
const npu_device_ne_type & prev_ne,
npu_device_tensor_op op,
const npu_device_ne_type & ne) {
NPU_UNUSED(prev_ne);
NPU_UNUSED(op);
return next_op == NPU_OP_MUL_MAT;
NPU_UNUSED(ne);
return prev_op != NPU_OP_ADD && prev_op != NPU_OP_SUB && prev_op != NPU_OP_MUL && prev_op != NPU_OP_RMS_NORM &&
prev_op != NPU_OP_COUNT;
}
struct op_capabilities {
@ -457,13 +469,16 @@ compute_func_type get_compute_func(tensor * dst) {
return get_compute_func_impl(dst->get_op(), dst->get_type());
}
bool requires_thread_barrier(npu_device_tensor_op op, npu_device_tensor_op next_op) {
bool requires_thread_barrier(npu_device_tensor_op prev_op,
const npu_device_ne_type & prev_ne,
npu_device_tensor_op op,
const npu_device_ne_type & ne) {
if (op >= NPU_OP_COUNT) {
return false;
}
auto requires_thread_barrier_func = kOpCapabilities[op].requires_thread_barrier_func;
return requires_thread_barrier_func && requires_thread_barrier_func(op, next_op);
return requires_thread_barrier_func && requires_thread_barrier_func(prev_op, prev_ne, op, ne);
}
bool support_op(const npu_device_tensor_op_spec * op_spec,
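
Taken together with the new prev-op signature, the rules above read roughly as follows. This is illustrative only, not a real unit test; it assumes the op headers from this commit are available and that npu_device_ne_type is a brace-initialisable extent array.

#include <cassert>

void sanity_check_sync_rules() {
    const npu_device_ne_type ne  = { 4096, 32, 1, 1 };
    const npu_device_ne_type ne2 = { 4096, 1, 1, 1 };

    // Element-wise ops barrier only when the previous op is a real op outside the fused
    // set (ADD/SUB/MUL/RMS_NORM); the first op of a graph (prev == NPU_OP_COUNT) never does.
    assert( is_element_wise_op_required_sync(NPU_OP_MUL_MAT, ne, NPU_OP_ADD, ne));
    assert(!is_element_wise_op_required_sync(NPU_OP_ADD,     ne, NPU_OP_MUL, ne));

    // MUL_MAT barriers unless it directly follows another MUL_MAT with the same output shape.
    assert(!is_mul_mat_required_sync(NPU_OP_MUL_MAT,  ne, NPU_OP_MUL_MAT, ne));
    assert( is_mul_mat_required_sync(NPU_OP_MUL_MAT,  ne, NPU_OP_MUL_MAT, ne2));
    assert( is_mul_mat_required_sync(NPU_OP_RMS_NORM, ne, NPU_OP_MUL_MAT, ne));
}
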


@ -6,7 +6,10 @@ namespace hexagon {
compute_func_type get_compute_func(tensor * dst);
bool requires_thread_barrier(npu_device_tensor_op op, npu_device_tensor_op next_op);
bool requires_thread_barrier(npu_device_tensor_op prev_op,
const npu_device_ne_type & prev_ne,
npu_device_tensor_op op,
const npu_device_ne_type & ne);
bool support_op(const npu_device_tensor_op_spec * op_spec,
const npu_device_tensor_spec * dst,


@ -6,6 +6,12 @@
namespace {
inline std::pair<size_t, size_t> unflatten_i3_i2(size_t idx, const hexagon::tensor * t) {
const auto i3 = idx / t->get_ne(2);
const auto i2 = idx - i3 * t->get_ne(2);
return { i3, i2 };
}
template <typename _T> struct get_data_type {};
template <typename _TData0, typename _TData1>
@ -14,6 +20,12 @@ struct get_data_type<HVX_Vector (*)(const _TData0 *, const _TData1 *, size_t)> {
using data_type1 = _TData1;
};
template <typename _TData0, typename _TData1>
struct get_data_type<HVX_Vector (*)(const _TData0 *, const _TData1 *, size_t, const HVX_Vector)> {
using data_type0 = _TData0;
using data_type1 = _TData1;
};
template <typename _TRet> struct convert_vector {};
template <> struct convert_vector<float> {
@ -55,29 +67,30 @@ inline bool init_dma_transfer(hexagon::compute_params * params,
return true;
}
template <auto _DotFunc>
template <auto _DotFunc, typename... _TExtraArgs>
inline void batched_row_dot(const uint8_t * src0_plane,
const size_t src0_ne0,
const size_t src0_nb1,
const uint8_t * src1_row,
const size_t src1_nb1,
float * dst_row,
const size_t actual_row_count,
const size_t src1_fetch_row_bytes) {
const size_t slice_rows,
const size_t src1_fetch_row_bytes,
_TExtraArgs... args) {
using data_type0 = typename get_data_type<decltype(_DotFunc)>::data_type0;
using data_type1 = typename get_data_type<decltype(_DotFunc)>::data_type1;
size_t i0 = 0;
for (; i0 + 1 < actual_row_count; i0 += 2) {
for (; i0 + 1 < slice_rows; i0 += 2) {
auto * src0_row = src0_plane + i0 * src0_nb1;
// TODO: figure dst how to handle a entire row
auto res0 = _DotFunc(reinterpret_cast<const data_type0 *>(src0_row),
reinterpret_cast<const data_type1 *>(src1_row), src0_ne0);
reinterpret_cast<const data_type1 *>(src1_row), src0_ne0, args...);
// TODO: figure dst how to handle a entire row
auto res1 = _DotFunc(reinterpret_cast<const data_type0 *>(src0_row + src0_nb1),
reinterpret_cast<const data_type1 *>(src1_row), src0_ne0);
reinterpret_cast<const data_type1 *>(src1_row), src0_ne0, args...);
{
dst_row[i0] = convert_vector<data_type1>::convert(res0);
@ -89,10 +102,10 @@ inline void batched_row_dot(const uint8_t * src0_plane,
hexagon::l2fetch_row(src1_row + src1_nb1, src1_fetch_row_bytes);
}
if (i0 < actual_row_count) {
if (i0 < slice_rows) {
auto * src0_row = src0_plane + i0 * src0_nb1;
auto res = _DotFunc(reinterpret_cast<const data_type0 *>(src0_row),
reinterpret_cast<const data_type1 *>(src1_row), src0_ne0);
reinterpret_cast<const data_type1 *>(src1_row), src0_ne0, args...);
dst_row[i0] = convert_vector<data_type1>::convert(res);
}
}
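
The template change above lets batched_row_dot forward extra per-call state (such as the dequant table) to the dot kernel. A scalar sketch of the same pattern, with the HVX casts, the vector-to-float conversion and the l2fetch reduced to comments; the row and dot types here are simplified stand-ins.

#include <cstddef>
#include <cstdint>

// Scalar sketch: two src0 rows per iteration against one src1 row, extra arguments
// passed straight through to the dot kernel via the parameter pack.
template <typename DotFn, typename... Extra>
static void batched_row_dot_ref(const uint8_t * src0_plane, size_t src0_ne0, size_t src0_nb1,
                                const uint8_t * src1_row, float * dst_row, size_t slice_rows,
                                DotFn dot, Extra... extra) {
    size_t i0 = 0;
    for (; i0 + 1 < slice_rows; i0 += 2) {
        const uint8_t * row0 = src0_plane + i0 * src0_nb1;
        dst_row[i0]     = dot(row0,            src1_row, src0_ne0, extra...);
        dst_row[i0 + 1] = dot(row0 + src0_nb1, src1_row, src0_ne0, extra...);
        // real code: hexagon::l2fetch_row(src1_row + src1_nb1, ...) to pre-warm the next src1 row
    }
    if (i0 < slice_rows) {   // odd tail row
        dst_row[i0] = dot(src0_plane + i0 * src0_nb1, src1_row, src0_ne0, extra...);
    }
}
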
@ -105,7 +118,7 @@ inline void mul_mat_impl(hexagon::tensor * src0,
using data_type0 = typename get_data_type<decltype(_DotFunc)>::data_type0;
using data_type1 = typename get_data_type<decltype(_DotFunc)>::data_type1;
const auto src0_actual_row_stride = hexagon::get_dequantized_row_size(src0);
const auto src0_row_stride = hexagon::get_dequantized_row_size(src0);
auto * dequantize_row_func = hexagon::get_type_traits(src0->get_type()).to_float;
auto * load_dequant_table_func = hexagon::get_type_traits(src0->get_type()).load_dequant_table;
if (_IsSrcQuantized && dequantize_row_func == nullptr) {
@ -130,7 +143,8 @@ inline void mul_mat_impl(hexagon::tensor * src0,
}
if (start_end_plane.second <= start_end_plane.first || start_end_row.second <= start_end_row.first ||
start_end_element.second <= start_end_element.first) {
start_end_element.second <= start_end_element.first || start_end_plane.first < 0 || start_end_row.first < 0 ||
start_end_element.first < 0) {
DEVICE_LOG_DEBUG(
"mul_mat_impl: no work to do, start_end_plane: (%lld, %lld), start_end_row: (%lld, %lld), "
"start_end_element: (%lld, %lld)\n",
@ -142,12 +156,12 @@ inline void mul_mat_impl(hexagon::tensor * src0,
const uint8_t * src0_ptr = src0->get_read_buffer(true); // TODO: avoid invalidation
// cache the src0 plane in VTCM
const size_t valid_src0_row_bytes = _IsSrcQuantized ? src0->get_nb(1) : (src0->get_ne(0) * sizeof(data_type0));
const size_t src1_actual_row_stride = hexagon::get_aligned_size(src1->get_nb(1));
const size_t valid_src0_row_bytes = _IsSrcQuantized ? src0->get_nb(1) : (src0->get_ne(0) * sizeof(data_type0));
const size_t src1_row_stride = hexagon::get_aligned_size(src1->get_nb(1));
// TODO: figure out why we have to add padding after src0 plane cache
const size_t src0_plane_slice_row_count =
std::min<size_t>((params->get_vtcm_quota_size() - src1_actual_row_stride) / (src0_actual_row_stride * 2),
std::min<size_t>((params->get_vtcm_quota_size() - src1_row_stride) / (src0_row_stride * 2),
start_end_element.second - start_end_element.first);
uint8_t * src0_plane_read_cache_ptr = nullptr;
uint8_t * src0_plane_write_cache_ptr = nullptr;
@ -156,13 +170,13 @@ inline void mul_mat_impl(hexagon::tensor * src0,
const uint8_t * last_read_cached_plane_ptr = nullptr;
{
const size_t src0_plane_cache_size = src0_actual_row_stride * src0_plane_slice_row_count;
const size_t src0_plane_cache_size = src0_row_stride * src0_plane_slice_row_count;
src0_plane_read_cache_ptr = params->get_vtcm_cache(src0_plane_cache_size * 2);
if (!src0_plane_read_cache_ptr) {
DEVICE_LOG_ERROR(
"mul_mat_impl: failed to get VTCM cache for src0, size: %zu, src0_plane_slice_row_count: %zu, "
"src0_actual_row_stride: %zu, will fallback to mem cache\n",
src0_plane_cache_size, src0_plane_slice_row_count, src0_actual_row_stride);
"src0_row_stride: %zu, will fallback to mem cache\n",
src0_plane_cache_size, src0_plane_slice_row_count, src0_row_stride);
return;
}
@ -173,11 +187,11 @@ inline void mul_mat_impl(hexagon::tensor * src0,
}
DEVICE_LOG_DEBUG(
"[%d]mul_mat_impl, src0_actual_row_stride:%zu, valid_src0_row_bytes:%zu, src_nb0:%zu, "
"[%d]mul_mat_impl, src0_row_stride:%zu, valid_src0_row_bytes:%zu, src_nb0:%zu, "
"slice_row_count:%zu, write_cache_offset: %zu, "
"total_planes:%lld, planes:[%d,%d), rows:[%d,%d), elems:[%d,%d), is_quant:%d, "
"vtcm_mem:%p(%zu)\n",
(int) params->get_thread_index(), src0_actual_row_stride, valid_src0_row_bytes, (size_t) src0->get_nb(1),
(int) params->get_thread_index(), src0_row_stride, valid_src0_row_bytes, (size_t) src0->get_nb(1),
src0_plane_slice_row_count, src0_plane_write_cache_offset, total_planes, (int) start_end_plane.first,
(int) start_end_plane.second, (int) start_end_row.first, (int) start_end_row.second,
(int) start_end_element.first, (int) start_end_element.second, _IsSrcQuantized,
@ -205,7 +219,7 @@ inline void mul_mat_impl(hexagon::tensor * src0,
last_write_cached_plane_ptr = src0_plane;
}
const size_t valid_row1_bytes =
const size_t valid_src1_row_bytes =
src0->get_ne(0) * sizeof(data_type1); // src0 and src1 should have the same element count in the 1st dimension
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_WITH_MULTI_SUB_PROC(dst, params->get_thread_index(), mul_mat);
@ -216,19 +230,19 @@ inline void mul_mat_impl(hexagon::tensor * src0,
return;
}
const auto dequant_table = load_dequant_table_func ? load_dequant_table_func() : HVX_Vector();
const uint8_t * src1_ptr = src1->get_read_buffer();
for (int64_t ip = start_end_plane.first; ip < start_end_plane.second; ip++) {
const auto dequant_table = load_dequant_table_func ? load_dequant_table_func() : HVX_Vector();
for (size_t ip = start_end_plane.first; ip < size_t(start_end_plane.second); ip++) {
const auto [i3, i2] = unflatten_i3_i2(ip, dst);
const auto * src1_plane = src1_ptr + i3 * src1->get_nb(3) + i2 * src1->get_nb(2);
auto * dst_plane = dst_ptr + i3 * dst->get_nb(3) + i2 * dst->get_nb(2);
const uint8_t * src0_plane_base = src0_ptr + i3 / r03 * src0->get_nb(3) + i2 / r02 * src0->get_nb(2);
for (int64_t col_idx = start_end_element.first; col_idx < start_end_element.second;
for (size_t col_idx = start_end_element.first; col_idx < size_t(start_end_element.second);
col_idx += src0_plane_slice_row_count) {
const uint8_t * src0_plane = src0_plane_base + col_idx * src0->get_nb(1);
const int64_t actual_row_count =
std::min<int64_t>(src0_plane_slice_row_count,
start_end_element.second - col_idx); // number of rows in this slice
const size_t slice_rows =
std::min<size_t>(src0_plane_slice_row_count,
start_end_element.second - col_idx); // number of rows in this slice
{
const uint8_t * src0_next_plane = last_write_cached_plane_ptr;
@ -272,9 +286,9 @@ inline void mul_mat_impl(hexagon::tensor * src0,
if (last_read_cached_plane_ptr != src0_plane) {
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 0, dequant);
const uint8_t * src0_quant_plane = src0_plane_read_cache_ptr + src0_plane_write_cache_offset;
for (int64_t ir = 0; ir < actual_row_count; ir++) {
for (size_t ir = 0; ir < slice_rows; ir++) {
auto * src0_row = src0_quant_plane + ir * src0->get_nb(1);
auto * cached_row_ptr = src0_plane_read_cache_ptr + ir * src0_actual_row_stride;
auto * cached_row_ptr = src0_plane_read_cache_ptr + ir * src0_row_stride;
dequantize_row_func(src0_row, reinterpret_cast<hexagon::dequant_output_type *>(cached_row_ptr),
src0->get_ne(0), dequant_table);
}
@ -284,16 +298,16 @@ inline void mul_mat_impl(hexagon::tensor * src0,
last_read_cached_plane_ptr = src0_plane;
if (start_end_row.second > start_end_row.first) {
hexagon::l2fetch_row(src1_plane + start_end_row.first * src1->get_nb(1), valid_row1_bytes);
hexagon::l2fetch_row(src1_plane + start_end_row.first * src1->get_nb(1), valid_src1_row_bytes);
}
for (int64_t i1 = start_end_row.first; i1 < start_end_row.second; i1++) {
for (size_t i1 = start_end_row.first; i1 < size_t(start_end_row.second); i1++) {
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 1, dot);
auto * src1_row = src1_plane + i1 * src1->get_nb(1);
auto * dst_row = reinterpret_cast<float *>(dst_plane + i1 * dst->get_nb(1)) + col_idx;
batched_row_dot<_DotFunc>(src0_plane_read_cache_ptr, src0->get_ne(0), src0_actual_row_stride, src1_row,
src1->get_nb(1), dst_row, actual_row_count,
(ip + 1 < start_end_plane.second) ? valid_row1_bytes : 0);
batched_row_dot<_DotFunc>(src0_plane_read_cache_ptr, src0->get_ne(0), src0_row_stride, src1_row,
src1->get_nb(1), dst_row, slice_rows,
(ip + 1 < start_end_plane.second) ? valid_src1_row_bytes : 0);
}
}
}
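
To make the slice sizing above concrete, a back-of-the-envelope example. The quota value is an assumption; the real one comes from params->get_vtcm_quota_size().

#include <cstddef>

size_t example_slice_rows() {
    const size_t vtcm_quota      = 256 * 1024;   // assumed per-thread VTCM quota
    const size_t src0_row_stride = 4096 * 2;     // one dequantized fp16 row, ne0 = 4096
    const size_t src1_row_stride = 4096 * 4;     // one aligned f32 src1 row

    // Two copies of the src0 slice must fit (compute buffer + DMA landing buffer),
    // plus one src1 row in the gemv path, hence the "* 2" and the subtraction.
    return (vtcm_quota - src1_row_stride) / (src0_row_stride * 2);   // = 15 rows per slice
}
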
@ -309,25 +323,22 @@ inline void mul_mat_gemv_impl(hexagon::tensor * src0,
using data_type0 = typename get_data_type<decltype(_DotFunc)>::data_type0;
using data_type1 = typename get_data_type<decltype(_DotFunc)>::data_type1;
const auto src0_actual_row_stride = hexagon::get_dequantized_row_size(src0);
auto * dequantize_row_func = hexagon::get_type_traits(src0->get_type()).to_float;
auto * load_dequant_table_func = hexagon::get_type_traits(src0->get_type()).load_dequant_table;
auto * dequantize_row_func = hexagon::get_type_traits(src0->get_type()).to_float;
auto * load_dequant_table_func = hexagon::get_type_traits(src0->get_type()).load_dequant_table;
if (_IsSrcQuantized && dequantize_row_func == nullptr) {
DEVICE_LOG_ERROR("Unsupported quantized src0 type: %d, dequantize_row_func is null\n", src0->get_type());
return;
}
auto start_end_element = std::pair<int64_t, int64_t>{ 0, dst->get_ne(0) };
if (dst->get_ne(0) >= params->get_thread_count()) {
start_end_element = params->get_work_slice(dst->get_ne(0));
} else {
if (dst->get_ne(0) < params->get_thread_count()) {
DEVICE_LOG_ERROR("Unsupported src1 tensor shape for gemv: %s, ne: %lldx%lldx%lldx%lld\n",
hexagon::get_type_name(src1->get_type()), src1->get_ne(0), src1->get_ne(1), src1->get_ne(2),
src1->get_ne(3));
return;
}
if (start_end_element.second <= start_end_element.first) {
const auto start_end_element = params->get_work_slice(dst->get_ne(0));
if (start_end_element.second <= start_end_element.first || start_end_element.first < 0) {
DEVICE_LOG_DEBUG(
"mul_mat_gemv_impl: no work to do, start_end_plane: [0, 1), start_end_row: [0, 1), "
"start_end_element: [%lld, %lld)\n",
@ -335,13 +346,14 @@ inline void mul_mat_gemv_impl(hexagon::tensor * src0,
return;
}
const auto src0_row_stride = hexagon::get_dequantized_row_size(src0);
const uint8_t * src0_ptr = src0->get_read_buffer(true); // TODO: avoid invalidation
const size_t valid_src0_row_bytes = _IsSrcQuantized ? src0->get_nb(1) : (src0->get_ne(0) * sizeof(data_type0));
// cache the src0 plane in VTCM
const size_t src1_actual_row_stride = hexagon::get_aligned_size(src1->get_nb(1));
const size_t src1_row_stride = hexagon::get_aligned_size(src1->get_nb(1));
const size_t src0_plane_slice_row_count =
std::min<size_t>((params->get_vtcm_quota_size() - src1_actual_row_stride) / (src0_actual_row_stride * 2),
std::min<size_t>((params->get_vtcm_quota_size() - src1_row_stride) / (src0_row_stride * 2),
start_end_element.second - start_end_element.first);
uint8_t * src0_plane_read_cache_ptr = nullptr;
@ -351,13 +363,13 @@ inline void mul_mat_gemv_impl(hexagon::tensor * src0,
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_WITH_MULTI_SUB_PROC(dst, params->get_thread_index(), mul_mat);
{
const size_t src0_plane_cache_size = src0_actual_row_stride * src0_plane_slice_row_count;
src0_plane_read_cache_ptr = params->get_vtcm_cache(src0_plane_cache_size * 2 + src1_actual_row_stride);
const size_t src0_plane_cache_size = src0_row_stride * src0_plane_slice_row_count;
src0_plane_read_cache_ptr = params->get_vtcm_cache(src0_plane_cache_size * 2 + src1_row_stride);
if (!src0_plane_read_cache_ptr) {
DEVICE_LOG_ERROR(
"mul_mat_gemv_impl: failed to get VTCM cache for src0, size: %zu, src0_plane_slice_row_count: %zu, "
"src0_actual_row_stride: %zu, will fallback to mem cache\n",
src0_plane_cache_size, src0_plane_slice_row_count, src0_actual_row_stride);
"src0_row_stride: %zu, will fallback to mem cache\n",
src0_plane_cache_size, src0_plane_slice_row_count, src0_row_stride);
return;
}
@ -369,9 +381,9 @@ inline void mul_mat_gemv_impl(hexagon::tensor * src0,
}
DEVICE_LOG_DEBUG(
"mul_mat_gemv_impl: src0_actual_row_stride: %zu, src0_plane_slice_row_count: %zu, "
"mul_mat_gemv_impl: src0_row_stride: %zu, src0_plane_slice_row_count: %zu, "
"src0_plane_write_cache_offset: %zu, src0.nb[1]: %d, is_quantized: %d, vtcm_mem: %p(%zu)\n",
src0_actual_row_stride, src0_plane_slice_row_count, src0_plane_write_cache_offset, int(src0->get_nb(1)),
src0_row_stride, src0_plane_slice_row_count, src0_plane_write_cache_offset, int(src0->get_nb(1)),
_IsSrcQuantized, (void *) src0_plane_read_cache_ptr, src0_plane_cache_size);
}
@ -382,8 +394,7 @@ inline void mul_mat_gemv_impl(hexagon::tensor * src0,
return;
}
auto dequant_table = load_dequant_table_func ? load_dequant_table_func() : HVX_Vector();
const uint8_t * src1_ptr = src1->get_read_buffer();
const uint8_t * src1_ptr = src1->get_read_buffer();
{
if (!params->initiate_dma_row_transfer(src1_ptr, src1_row_cache_ptr, src1->get_ne(0) * sizeof(data_type1))) {
@ -392,9 +403,9 @@ inline void mul_mat_gemv_impl(hexagon::tensor * src0,
}
const uint8_t * src0_plane = src0_ptr + start_end_element.first * src0->get_nb(1);
const int64_t next_row_count =
std::min<int64_t>(src0_plane_slice_row_count,
start_end_element.second - start_end_element.first); // number of rows in this slice
const size_t next_row_count =
std::min<size_t>(src0_plane_slice_row_count,
start_end_element.second - start_end_element.first); // number of rows in this slice
params->wait_for_dma();
if (!init_dma_transfer<_IsSrcQuantized>(
@ -406,22 +417,23 @@ inline void mul_mat_gemv_impl(hexagon::tensor * src0,
}
}
const auto dequant_table = load_dequant_table_func ? load_dequant_table_func() : HVX_Vector();
{
for (int64_t col_idx = start_end_element.first; col_idx < start_end_element.second;
for (size_t col_idx = start_end_element.first; col_idx < size_t(start_end_element.second);
col_idx += src0_plane_slice_row_count) {
const int64_t actual_row_count =
std::min<int64_t>(src0_plane_slice_row_count,
start_end_element.second - col_idx); // number of rows in this slice
const auto next_col_idx = col_idx + src0_plane_slice_row_count;
const size_t slice_rows =
std::min<size_t>(src0_plane_slice_row_count,
start_end_element.second - col_idx); // number of rows in this slice
const size_t next_col_idx = col_idx + src0_plane_slice_row_count;
std::swap(src0_plane_read_cache_ptr, src0_plane_write_cache_ptr);
params->wait_for_dma();
if (next_col_idx < start_end_element.second) {
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 2, dma);
const uint8_t * src0_next_plane = src0_ptr + next_col_idx * src0->get_nb(1);
const int64_t next_row_count =
std::min<int64_t>(src0_plane_slice_row_count,
start_end_element.second - next_col_idx); // number of rows in this slice
const size_t next_row_count =
std::min<size_t>(src0_plane_slice_row_count,
start_end_element.second - next_col_idx); // number of rows in this slice
if (!init_dma_transfer<_IsSrcQuantized>(
params, src0_next_plane, src0_plane_write_cache_ptr + src0_plane_write_cache_offset,
valid_src0_row_bytes, next_row_count, src0->get_nb(1), src0->get_nb(1))) {
@ -435,9 +447,9 @@ inline void mul_mat_gemv_impl(hexagon::tensor * src0,
if constexpr (_IsSrcQuantized) {
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 0, dequant);
const uint8_t * src0_quant_plane = src0_plane_read_cache_ptr + src0_plane_write_cache_offset;
for (int64_t ir = 0; ir < actual_row_count; ir++) {
for (size_t ir = 0; ir < slice_rows; ir++) {
auto * src0_row = src0_quant_plane + ir * src0->get_nb(1);
auto * cached_row_ptr = src0_plane_read_cache_ptr + ir * src0_actual_row_stride;
auto * cached_row_ptr = src0_plane_read_cache_ptr + ir * src0_row_stride;
dequantize_row_func(src0_row, reinterpret_cast<hexagon::dequant_output_type *>(cached_row_ptr),
src0->get_ne(0), dequant_table);
}
@ -446,8 +458,135 @@ inline void mul_mat_gemv_impl(hexagon::tensor * src0,
{
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 1, dot);
auto * dst_row = reinterpret_cast<float *>(dst_ptr) + col_idx;
batched_row_dot<_DotFunc>(src0_plane_read_cache_ptr, src0->get_ne(0), src0_actual_row_stride,
src1_row_cache_ptr, src1->get_nb(1), dst_row, actual_row_count, 0);
batched_row_dot<_DotFunc>(src0_plane_read_cache_ptr, src0->get_ne(0), src0_row_stride,
src1_row_cache_ptr, src1->get_nb(1), dst_row, slice_rows, 0);
}
}
}
dst->release_write_buffer(); // mark the output tensor as modified
}
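
The loop above is a simple ping-pong: while the dot products consume the slice that just landed in the read buffer, the next slice is already being DMA'd into the write buffer. A stripped-down sketch of that control flow; start_dma, compute_dots and wait_for_dma are stand-ins for init_dma_transfer, batched_row_dot and compute_params::wait_for_dma.

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <utility>

// Hypothetical stand-ins; only the control flow matters here.
void wait_for_dma();
void start_dma(const uint8_t * src, uint8_t * dst, size_t rows, size_t row_stride);
void compute_dots(const uint8_t * rows, const uint8_t * src1, float * dst, size_t n_rows);

void gemv_double_buffered(const uint8_t * src0_base, const uint8_t * src1_cache, float * dst,
                          size_t first, size_t last, size_t slice_rows_max, size_t row_stride,
                          uint8_t * read_buf, uint8_t * write_buf) {
    // Priming (done before this loop in mul_mat_gemv_impl): DMA the src1 row into its
    // cache and the first src0 slice into write_buf.
    for (size_t col = first; col < last; col += slice_rows_max) {
        std::swap(read_buf, write_buf);
        wait_for_dma();                                   // slice 'col' is now in read_buf
        const size_t rows = std::min(slice_rows_max, last - col);
        if (col + slice_rows_max < last) {
            // Kick off the transfer of the next slice while this one is being consumed.
            start_dma(src0_base + (col + slice_rows_max) * row_stride, write_buf,
                      std::min(slice_rows_max, last - (col + slice_rows_max)), row_stride);
        }
        compute_dots(read_buf, src1_cache, dst + col, rows);   // overlaps with the DMA above
    }
}
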
template <auto _DotFunc>
inline void mul_mat_gemv_quant_impl(hexagon::tensor * src0,
hexagon::tensor * src1,
hexagon::tensor * dst,
hexagon::compute_params * params) {
// TODO: merge with mul_mat_gemv_impl?
using data_type1 = typename get_data_type<decltype(_DotFunc)>::data_type1;
if (dst->get_ne(0) < params->get_thread_count()) {
DEVICE_LOG_ERROR("Unsupported src1 tensor shape for gemv: %s, ne: %lldx%lldx%lldx%lld\n",
hexagon::get_type_name(src1->get_type()), src1->get_ne(0), src1->get_ne(1), src1->get_ne(2),
src1->get_ne(3));
return;
}
const auto src0_row_stride = src0->get_nb(1);
const auto start_end_element = params->get_work_slice(dst->get_ne(0));
if (start_end_element.second <= start_end_element.first || start_end_element.first < 0) {
DEVICE_LOG_DEBUG(
"mul_mat_gemv_quant_impl: no work to do, start_end_plane: [0, 1), start_end_row: [0, 1), "
"start_end_element: [%lld, %lld)\n",
start_end_element.first, start_end_element.second);
return;
}
const uint8_t * src0_ptr = src0->get_read_buffer(true); // TODO: avoid invalidation
// cache the src0 plane in VTCM
const size_t src1_row_stride = hexagon::get_aligned_size(src1->get_nb(1));
const size_t src0_plane_slice_row_count =
std::min<size_t>((params->get_vtcm_quota_size() - src1_row_stride) / (src0_row_stride * 2),
start_end_element.second - start_end_element.first);
uint8_t * src0_plane_read_cache_ptr = nullptr;
uint8_t * src0_plane_write_cache_ptr = nullptr;
uint8_t * src1_row_cache_ptr = nullptr;
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_WITH_MULTI_SUB_PROC(dst, params->get_thread_index(), mul_mat);
{
const size_t src0_plane_cache_size = src0_row_stride * src0_plane_slice_row_count;
src0_plane_read_cache_ptr = params->get_vtcm_cache(src0_plane_cache_size * 2 + src1_row_stride);
if (!src0_plane_read_cache_ptr) {
DEVICE_LOG_ERROR(
"mul_mat_gemv_quant_impl: failed to get VTCM cache for src0, size: %zu, src0_plane_slice_row_count: "
"%zu, "
"src0_row_stride: %zu, will fallback to mem cache\n",
src0_plane_cache_size, src0_plane_slice_row_count, src0_row_stride);
return;
}
src0_plane_write_cache_ptr = src0_plane_read_cache_ptr + src0_plane_cache_size;
src1_row_cache_ptr = src0_plane_write_cache_ptr + src0_plane_cache_size;
DEVICE_LOG_DEBUG(
"mul_mat_gemv_quant_impl: src0_row_stride: %zu, src0_plane_slice_row_count: %zu, src0.nb[1]: %d, vtcm_mem: "
"%p(%zu)\n",
src0_row_stride, src0_plane_slice_row_count, int(src0->get_nb(1)), (void *) src0_plane_read_cache_ptr,
src0_plane_cache_size);
}
uint8_t * dst_ptr = dst->get_write_buffer();
if (!dst_ptr) {
DEVICE_LOG_ERROR("mul_mat_gemv_quant_impl: dst_ptr is not writable, tensor: %p, type: %s\n", (void *) dst,
hexagon::get_type_name(dst->get_type()));
return;
}
const uint8_t * src1_ptr = src1->get_read_buffer();
{
if (!params->initiate_dma_row_transfer(src1_ptr, src1_row_cache_ptr, src1->get_ne(0) * sizeof(data_type1))) {
DEVICE_LOG_ERROR("mul_mat_gemv_quant_impl: failed to initiate dma transfer for src1\n");
return;
}
const uint8_t * src0_plane = src0_ptr + start_end_element.first * src0_row_stride;
const size_t next_row_count =
std::min<size_t>(src0_plane_slice_row_count,
start_end_element.second - start_end_element.first); // number of rows in this slice
params->wait_for_dma();
if (!init_dma_transfer<true>(params, src0_plane, src0_plane_write_cache_ptr, src0_row_stride, next_row_count,
src0_row_stride, src0_row_stride)) {
DEVICE_LOG_ERROR("mul_mat_gemv_quant_impl: failed to initiate dma plane transfer for src0 plane\n");
return;
}
}
auto * load_dequant_table_func = hexagon::get_type_traits(src0->get_type()).load_dequant_table;
const auto dequant_table = load_dequant_table_func ? load_dequant_table_func() : HVX_Vector();
{
for (size_t col_idx = start_end_element.first; col_idx < size_t(start_end_element.second);
col_idx += src0_plane_slice_row_count) {
const size_t slice_rows =
std::min<size_t>(src0_plane_slice_row_count,
start_end_element.second - col_idx); // number of rows in this slice
const size_t next_col_idx = col_idx + src0_plane_slice_row_count;
std::swap(src0_plane_read_cache_ptr, src0_plane_write_cache_ptr);
params->wait_for_dma();
if (next_col_idx < start_end_element.second) {
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 1, dma);
const uint8_t * src0_next_plane = src0_ptr + next_col_idx * src0_row_stride;
const size_t next_row_count =
std::min<size_t>(src0_plane_slice_row_count,
start_end_element.second - next_col_idx); // number of rows in this slice
if (!init_dma_transfer<true>(params, src0_next_plane, src0_plane_write_cache_ptr, src0_row_stride,
next_row_count, src0_row_stride, src0_row_stride)) {
DEVICE_LOG_ERROR("mul_mat_gemv_quant_impl: failed to continue dma plane transfer for src0 plane\n");
return;
}
}
{
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 0, dot);
auto * dst_row = reinterpret_cast<float *>(dst_ptr) + col_idx;
batched_row_dot<_DotFunc, const HVX_Vector>(src0_plane_read_cache_ptr, src0->get_ne(0), src0_row_stride,
src1_row_cache_ptr, src1->get_nb(1), dst_row, slice_rows, 0,
dequant_table);
}
}
}
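
vec_dot_product_vqf32_q40_f32 itself lives in the vector code; as a mental model, here is a scalar reference of the same dot product, assuming the usual ggml q4_0 layout (32 values per block, low nibbles packed first, count a multiple of 32) and with the lookup table folded into the "- 8". The HVX path instead keeps the rows quantized in VTCM and passes the dequant table through batched_row_dot as the trailing argument.

#include <cstddef>
#include <cstdint>

struct block_q4_0_ref {            // stand-in for npu_device_block_q4_0
    uint16_t d;                    // fp16 scale
    uint8_t  qs[16];               // 32 packed 4-bit codes
};

float fp16_to_f32(uint16_t h);     // assumed conversion helper

float dot_q4_0_f32_ref(const block_q4_0_ref * x, const float * y, size_t count) {
    float sum = 0.0f;
    for (size_t ib = 0; ib < count / 32; ++ib) {
        const float d = fp16_to_f32(x[ib].d);
        float block_sum = 0.0f;
        for (size_t j = 0; j < 16; ++j) {
            const int lo = (x[ib].qs[j] & 0x0F) - 8;   // table[code] == code - 8
            const int hi = (x[ib].qs[j] >> 4)   - 8;
            block_sum += lo * y[ib * 32 + j] + hi * y[ib * 32 + j + 16];
        }
        sum += d * block_sum;                          // scale applied once per block
    }
    return sum;
}
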
@ -641,8 +780,13 @@ bool mul_mat_f32(hexagon::tensor * out, compute_params * params) {
switch (src1->get_type()) {
case NPU_DATA_TYPE_F32:
if (is_src0_quantized) {
kMulMatF16F32QuantizedFuncs[is_mul_mat_f16_f32_src_tensors_aligned(src0, src1, true, is_gemv) +
base_index](src0, src1, out, params);
if (is_gemv && src0->get_type() == NPU_DATA_TYPE_Q4_0) {
// TODO: move to array
mul_mat_gemv_quant_impl<hexagon::vec_dot_product_vqf32_q40_f32>(src0, src1, out, params);
} else {
kMulMatF16F32QuantizedFuncs[is_mul_mat_f16_f32_src_tensors_aligned(src0, src1, true, is_gemv) +
base_index](src0, src1, out, params);
}
} else if (src0->get_type() == NPU_DATA_TYPE_F16) {
kMulMatF16F32Funcs[is_mul_mat_f16_f32_src_tensors_aligned(src0, src1, true, is_gemv) + base_index](
src0, src1, out, params);
@ -664,7 +808,7 @@ bool mul_mat_f32(hexagon::tensor * out, compute_params * params) {
break;
}
DEVICE_LOG_ERROR("Unsupported src1 tensor type: %s\n", get_type_name(src1->get_type()));
DEVICE_LOG_ERROR("[MUL_MAT]Unsupported src1 tensor type: %s\n", get_type_name(src1->get_type()));
return false;
}
@ -741,10 +885,15 @@ bool is_mul_mat_supported(const npu_device_tensor_op_spec * op_spec,
return true;
}
bool is_mul_mat_required_sync(const npu_device_tensor_op op, const npu_device_tensor_op next_op) {
bool is_mul_mat_required_sync(npu_device_tensor_op prev_op,
const npu_device_ne_type & prev_ne,
npu_device_tensor_op op,
const npu_device_ne_type & ne) {
NPU_UNUSED(prev_op);
NPU_UNUSED(prev_ne);
NPU_UNUSED(op);
NPU_UNUSED(next_op);
return true;
NPU_UNUSED(ne);
return prev_op != NPU_OP_MUL_MAT || !is_same_shape(prev_ne, ne);
}
} // namespace hexagon
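
is_same_shape is not part of this diff; a minimal sketch of the intended check, assuming npu_device_ne_type is a 4-element extent array:

#include <cstddef>

inline bool is_same_shape_ref(const npu_device_ne_type & a, const npu_device_ne_type & b) {
    for (size_t i = 0; i < 4; ++i) {
        if (a[i] != b[i]) {
            return false;
        }
    }
    return true;   // identical output shapes let back-to-back MUL_MATs skip the barrier
}
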


@ -12,6 +12,9 @@ bool is_mul_mat_supported(const npu_device_tensor_op_spec * op_spec,
const npu_device_tensor_spec * dst,
const npu_device_tensor_spec * srcs,
size_t src_len);
bool is_mul_mat_required_sync(const npu_device_tensor_op op, const npu_device_tensor_op next_op);
bool is_mul_mat_required_sync(npu_device_tensor_op prev_op,
const npu_device_ne_type & prev_ne,
npu_device_tensor_op op,
const npu_device_ne_type & ne);
} // namespace hexagon


@ -392,9 +392,14 @@ bool is_rope_supported(const npu_device_tensor_op_spec * op_spec,
return true; // ROPE operation is not supported yet
}
bool is_rope_required_sync(const npu_device_tensor_op op, const npu_device_tensor_op next_op) {
bool is_rope_required_sync(npu_device_tensor_op prev_op,
const npu_device_ne_type & prev_ne,
npu_device_tensor_op op,
const npu_device_ne_type & ne) {
NPU_UNUSED(prev_op);
NPU_UNUSED(prev_ne);
NPU_UNUSED(op);
NPU_UNUSED(next_op);
NPU_UNUSED(ne);
return false;
}


@ -9,6 +9,9 @@ bool is_rope_supported(const npu_device_tensor_op_spec * op_spec,
const npu_device_tensor_spec * dst,
const npu_device_tensor_spec * srcs,
size_t src_len);
bool is_rope_required_sync(const npu_device_tensor_op op, const npu_device_tensor_op next_op);
bool is_rope_required_sync(npu_device_tensor_op prev_op,
const npu_device_ne_type & prev_ne,
npu_device_tensor_op op,
const npu_device_ne_type & ne);
} // namespace hexagon


@ -83,6 +83,9 @@ typedef bool (*op_is_supported_func_type)(const npu_device_tensor_op_spec * op_s
const npu_device_tensor_spec * dst,
const npu_device_tensor_spec * srcs,
size_t src_len);
typedef bool (*op_required_sync_func_type)(const npu_device_tensor_op op, const npu_device_tensor_op next_op);
typedef bool (*op_required_sync_func_type)(npu_device_tensor_op prev_op,
const npu_device_ne_type & prev_ne,
npu_device_tensor_op op,
const npu_device_ne_type & ne);
} // namespace hexagon


@ -3,8 +3,6 @@
#include "op_types.hpp" // TODO: remove this include
#include "vec_ops.hpp"
#include <array>
static_assert(sizeof(npu_device_block_q4_k) ==
2 * sizeof(npu_device_fp16_t) + QUANT_K_SCALE_SIZE + QUANT_K_BLOCK_SIZE / 2,
"wrong q4_K block size/padding");
@ -26,168 +24,6 @@ inline npu_device_fp16_t to_fp16(const float src) {
return reinterpret_cast<const npu_device_fp16_t &>(f16_value);
}
template <typename _TStruct, size_t _Count, auto _MemberPtr> inline HVX_Vector load_into_vector(const _TStruct * src) {
static_assert(hexagon::kBytesPerVector >= sizeof(_TStruct) * _Count, "_TStruct too large for vector load");
return *reinterpret_cast<const HVX_UVector *>(&(src->*_MemberPtr));
}
template <typename _TStruct, size_t _Count> inline HVX_Vector load_struct_into_vector(const _TStruct * src) {
static_assert(hexagon::kBytesPerVector >= sizeof(_TStruct) * _Count, "_TStruct too large for vector load");
return *reinterpret_cast<const HVX_UVector *>(src);
}
template <typename _TBlock> inline HVX_Vector load_block_generic(const _TBlock & src) {
static_assert(hexagon::kBytesPerVector >= sizeof(_TBlock), "wrong block size/padding");
return load_into_vector<_TBlock, 1, &_TBlock::qs>(&src);
}
template <typename _TBlock> inline HVX_Vector make_scale_load_mask() {
static_assert(sizeof(_TBlock) + sizeof(npu_device_fp16_t) < 32, "wrong block size/padding");
static_assert(sizeof(_TBlock::qs) == 16 || sizeof(_TBlock::qs) == 32, "wrong quantization block size");
constexpr const size_t kScaleBlockSize = QUANT_BLOCK_SIZE * sizeof(hexagon::dequant_output_type);
// TODO: handle the case that scale not at the start of struct
hexagon::HVX_VectorAlias ret;
for (size_t i = 0; i < QUANT_BLOCK_SIZE; ++i) {
size_t base = i * 2;
ret.u8[base] = 0;
ret.u8[base + 1] = 1;
ret.u8[base + kScaleBlockSize] = sizeof(_TBlock);
ret.u8[base + kScaleBlockSize + 1] = sizeof(_TBlock) + 1;
}
return ret.v;
}
template <typename _TBlock> inline HVX_Vector load_dual_block_generic(const _TBlock * srcs, HVX_VectorPred mask) {
static_assert(hexagon::kBytesPerVector >= sizeof(_TBlock) * 2, "wrong block size/padding");
constexpr const uint32_t kSizeOfQs = sizeof(_TBlock::qs);
constexpr const uint32_t kSizeOfScale = sizeof(_TBlock) - kSizeOfQs;
HVX_Vector blocks = load_into_vector<_TBlock, 2, &_TBlock::qs>(srcs);
HVX_Vector block1 = Q6_V_vror_VR(blocks, kSizeOfScale);
return Q6_V_vmux_QVV(mask, blocks, block1);
}
template <typename _TBlock>
inline hexagon::HVX_Vector_x2 load_dual_block_generic(const _TBlock * srcs,
HVX_VectorPred mask,
const HVX_Vector scale_indices) {
static_assert(hexagon::kBytesPerVector >= sizeof(_TBlock) * 2, "wrong block size/padding");
constexpr const uint32_t kSizeOfQs = sizeof(_TBlock::qs);
constexpr const uint32_t kSizeOfScale = sizeof(_TBlock) - kSizeOfQs;
hexagon::HVX_Vector_x2 result;
const HVX_Vector blocks = load_struct_into_vector<_TBlock, 2>(srcs);
HVX_Vector block0 = Q6_V_vror_VR(blocks, kSizeOfScale);
HVX_Vector block1 = Q6_V_vror_VR(blocks, kSizeOfScale * 2);
HVX_Vector scale01 = Q6_Vb_vshuff_Vb(blocks);
result.val[0] = Q6_V_vmux_QVV(mask, block0, block1);
result.val[1] = Q6_Vb_vlut32_VbVbR_nomatch(scale_indices, scale01, 0);
return result;
}
template <typename _TBlock> inline hexagon::HVX_VectorPred_x3 make_quad_block_mask() {
static_assert(hexagon::kBytesPerVector >= sizeof(_TBlock) * 4, "wrong block size/padding");
constexpr const uint32_t kSizeOfQs = sizeof(_TBlock::qs);
hexagon::HVX_VectorPred_x3 mask;
mask.val[0] = Q6_Q_vsetq_R(kSizeOfQs);
mask.val[1] = Q6_Q_vsetq_R(kSizeOfQs * 3);
mask.val[2] = Q6_Q_vsetq_R(kSizeOfQs * 2);
return mask;
}
template <typename _TBlock>
inline hexagon::HVX_Vector_x3 load_qual_block_generic(const _TBlock * srcs,
const hexagon::HVX_VectorPred_x3 mask,
const HVX_Vector scale_indices) {
static_assert(hexagon::kBytesPerVector >= sizeof(_TBlock) * 4, "wrong block size/padding");
constexpr const uint32_t kSizeOfQs = sizeof(_TBlock::qs);
constexpr const uint32_t kSizeOfScale = sizeof(_TBlock) - kSizeOfQs;
hexagon::HVX_Vector_x3 result;
const HVX_Vector blocks = load_struct_into_vector<_TBlock, 4>(srcs);
{
HVX_Vector block0 = Q6_V_vror_VR(blocks, kSizeOfScale);
HVX_Vector block1 = Q6_V_vror_VR(blocks, kSizeOfScale * 2);
HVX_Vector block2 = Q6_V_vror_VR(blocks, kSizeOfScale * 3);
HVX_Vector block3 = Q6_V_vror_VR(blocks, kSizeOfScale * 4);
HVX_Vector block01 = Q6_V_vmux_QVV(mask.val[0], block0, block1);
HVX_Vector block23 = Q6_V_vmux_QVV(mask.val[1], block2, block3);
result.val[0] = Q6_V_vmux_QVV(mask.val[2], block01, block23);
}
{
HVX_Vector scale23 = Q6_V_vror_VR(blocks, sizeof(_TBlock) * 2);
HVX_Vector scale01 = Q6_Vb_vshuff_Vb(blocks);
scale23 = Q6_Vb_vshuff_Vb(scale23);
result.val[1] = Q6_Vb_vlut32_VbVbR_nomatch(scale_indices, scale01, 0);
result.val[2] = Q6_Vb_vlut32_VbVbR_nomatch(scale_indices, scale23, 0);
}
return result;
}
template <typename _TBlock>
inline hexagon::HVX_Vector_x5 load_hexa_block_generic(const _TBlock * srcs,
const hexagon::HVX_VectorPred_x3 mask,
const HVX_Vector scale_indices) {
static_assert(hexagon::kBytesPerVector >= sizeof(_TBlock) * 6, "wrong block size/padding");
constexpr const uint32_t kSizeOfQs = sizeof(_TBlock::qs);
constexpr const uint32_t kSizeOfScale = sizeof(_TBlock) - kSizeOfQs;
const HVX_Vector blocks = load_struct_into_vector<_TBlock, 6>(srcs);
hexagon::HVX_Vector_x5 result;
{
HVX_Vector block0 = Q6_V_vror_VR(blocks, kSizeOfScale);
HVX_Vector block1 = Q6_V_vror_VR(blocks, kSizeOfScale * 2);
HVX_Vector block2 = Q6_V_vror_VR(blocks, kSizeOfScale * 3);
HVX_Vector block3 = Q6_V_vror_VR(blocks, kSizeOfScale * 4);
HVX_Vector block4 = Q6_V_vror_VR(blocks, kSizeOfScale + sizeof(_TBlock) * 4);
HVX_Vector block5 = Q6_V_vror_VR(blocks, kSizeOfScale * 2 + sizeof(_TBlock) * 4);
HVX_Vector block01 = Q6_V_vmux_QVV(mask.val[0], block0, block1);
HVX_Vector block23 = Q6_V_vmux_QVV(mask.val[1], block2, block3);
result.val[0] = Q6_V_vmux_QVV(mask.val[2], block01, block23);
result.val[3] = Q6_V_vmux_QVV(mask.val[0], block4, block5); // block45
}
{
HVX_Vector scale23 = Q6_V_vror_VR(blocks, sizeof(_TBlock) * 2);
HVX_Vector scale45 = Q6_V_vror_VR(blocks, sizeof(_TBlock) * 4);
HVX_Vector scale01 = Q6_Vb_vshuff_Vb(blocks);
scale23 = Q6_Vb_vshuff_Vb(scale23);
scale45 = Q6_Vb_vshuff_Vb(scale45);
result.val[1] = Q6_Vb_vlut32_VbVbR_nomatch(scale_indices, scale01, 0);
result.val[2] = Q6_Vb_vlut32_VbVbR_nomatch(scale_indices, scale23, 0);
result.val[4] = Q6_Vb_vlut32_VbVbR_nomatch(scale_indices, scale45, 0);
}
return result;
}
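
The index vector built by make_scale_load_mask is easier to read in scalar form: for the two blocks packed into one HVX vector, every fp16 output lane of a block receives the byte offset of that block's 2-byte scale, so a single table lookup broadcasts d0 over the first QUANT_BLOCK_SIZE lanes and d1 over the second (the old q8_0 path did the same thing with vsplat + vmux). A scalar sketch of that fill loop, assuming the scale sits at the start of the block struct as the TODO above notes:

#include <cstddef>
#include <cstdint>

// QUANT_BLOCK_SIZE is 32 here and the dequant output is fp16, so each block owns 64
// output bytes; both bytes of an fp16 lane point at that lane's block scale.
void make_scale_indices_ref(uint8_t out[128], size_t block_struct_size /* sizeof(_TBlock) */) {
    const size_t half = 32 * sizeof(uint16_t);   // bytes covered by block 0's output lanes
    for (size_t i = 0; i < 32; ++i) {
        out[i * 2]            = 0;                                   // block 0 scale, low byte
        out[i * 2 + 1]        = 1;                                   // block 0 scale, high byte
        out[i * 2 + half]     = (uint8_t)  block_struct_size;        // block 1 scale, low byte
        out[i * 2 + half + 1] = (uint8_t) (block_struct_size + 1);   // block 1 scale, high byte
    }
}
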
inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t * d, uint8_t * m) {
// TODO: use intrinsics
if (j < 4) {
@ -448,25 +284,25 @@ void quantize_row_q4_K(const float * src, void * dst, size_t count) {
}
void dequantize_row_q8_0(const void * src, hexagon::dequant_output_type * dst, size_t count, HVX_Vector) {
using namespace hexagon::vec::quant;
constexpr const int qk = QUANT_BLOCK_SIZE;
static_assert(QUANT_BLOCK_SIZE == hexagon::kBytesPerVector / sizeof(float));
const int nb = count / qk;
const auto * src_ptr = reinterpret_cast<const npu_device_block_q8_0 *>(src);
auto * dst_ptr = ((hexagon::dequant_output_type *) dst); // TODO: opt for aligned access
const HVX_VectorPred mask = Q6_Q_vsetq_R(sizeof(npu_device_block_q8_0::qs));
const HVX_VectorPred scale_mask = Q6_Q_vsetq_R(hexagon::kBytesPerVector / 2);
alignas(hexagon::kBytesPerVector) static const HVX_Vector qs_indices = make_qs_load_mask<npu_device_block_q8_0>();
alignas(hexagon::kBytesPerVector) static const HVX_Vector scale_indices =
make_scale_load_mask<npu_device_block_q8_0>();
const int nb = count / qk;
const auto * src_ptr = reinterpret_cast<const npu_device_block_q8_0 *>(src);
auto * dst_ptr = ((hexagon::dequant_output_type *) dst); // TODO: opt for aligned access
int i = 0;
for (; i + 1 < nb; i += 2) {
const auto & src0 = src_ptr[i];
const auto & src1 = src_ptr[i + 1];
auto qs = load_dual_block_generic(src_ptr + i, qs_indices, scale_indices);
HVX_Vector scales01 = Q6_V_vmux_QVV(scale_mask, Q6_Vh_vsplat_R(src0.d), Q6_Vh_vsplat_R(src1.d));
HVX_Vector qs = load_dual_block_generic(src_ptr + i, mask);
HVX_Vector q_lo = Q6_Vhf_equals_Vh(Q6_V_lo_W(Q6_Wh_vunpack_Vb(qs)));
HVX_Vector result = Q6_Vqf16_vmpy_VhfVhf(q_lo, scales01);
HVX_Vector q_lo = Q6_Vhf_equals_Vh(Q6_V_lo_W(Q6_Wh_vunpack_Vb(qs.val[0])));
HVX_Vector result = Q6_Vqf16_vmpy_VhfVhf(q_lo, qs.val[1]);
*reinterpret_cast<HVX_UVector *>(dst_ptr) = Q6_Vhf_equals_Vqf16(result);
dst_ptr += qk * 2;
@ -486,75 +322,17 @@ void dequantize_row_q8_0(const void * src, hexagon::dequant_output_type * dst, s
}
}
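
For comparison, the scalar equivalent of the q8_0 path above (two blocks per HVX vector in the intrinsic version, one value at a time here), assuming the usual ggml q8_0 layout of one fp16 scale followed by 32 signed 8-bit codes and that dequant_output_type is fp16:

#include <cstddef>
#include <cstdint>

float fp16_to_f32(uint16_t h);   // assumed conversion helper for the stored scale

void dequantize_row_q8_0_ref(const npu_device_block_q8_0 * src, __fp16 * dst, size_t count) {
    const size_t nb = count / QUANT_BLOCK_SIZE;              // QUANT_BLOCK_SIZE == 32
    for (size_t ib = 0; ib < nb; ++ib) {
        const float d = fp16_to_f32(src[ib].d);              // per-block scale
        for (size_t j = 0; j < QUANT_BLOCK_SIZE; ++j) {
            dst[ib * QUANT_BLOCK_SIZE + j] = (__fp16) (d * (float) src[ib].qs[j]);
        }
    }
}
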
template <bool _IsDstAligned>
inline void dequantize_row_q4_0_2blocks(HVX_Vector qs,
HVX_Vector scale01,
HVX_Vector table,
hexagon::dequant_output_type * dst_ptr) {
constexpr const uint32_t kSizeOfQs = sizeof(npu_device_block_q4_0::qs);
HVX_Vector q_lo = qs;
HVX_Vector q_hi = Q6_Vub_vlsr_VubR(qs, 4);
HVX_VectorPair qp0 = Q6_W_vshuff_VVR(q_hi, q_lo, kSizeOfQs * (1 + 2));
q_lo = Q6_V_lo_W(qp0);
q_lo = Q6_Vb_vshuff_Vb(q_lo);
qp0 = Q6_Wh_vlut16_VbVhR_nomatch(q_lo, table, 0);
q_lo = Q6_Vqf16_vmpy_VhfVhf(Q6_V_lo_W(qp0), scale01);
q_lo = Q6_Vhf_equals_Vqf16(q_lo);
if constexpr (_IsDstAligned) {
*reinterpret_cast<HVX_Vector *>(dst_ptr) = q_lo;
} else {
*reinterpret_cast<HVX_UVector *>(dst_ptr) = q_lo;
}
}
template <bool _IsDstAligned>
inline void dequantize_row_q4_0_4blocks(HVX_Vector qs,
HVX_Vector scale01,
HVX_Vector scale23,
HVX_Vector table,
hexagon::dequant_output_type * dst_ptr) {
constexpr const uint32_t kSizeOfQs = sizeof(npu_device_block_q4_0::qs);
HVX_Vector q_lo = qs;
HVX_Vector q_hi = Q6_Vub_vlsr_VubR(qs, 4);
HVX_VectorPair qp0 = Q6_W_vshuff_VVR(q_hi, q_lo, kSizeOfQs * (1 + 2 + 4));
q_lo = Q6_V_lo_W(qp0);
q_lo = Q6_Vb_vshuff_Vb(q_lo);
qp0 = Q6_Wh_vlut16_VbVhR_nomatch(q_lo, table, 0);
q_lo = Q6_V_lo_W(qp0);
q_hi = Q6_V_hi_W(qp0);
q_lo = Q6_Vqf16_vmpy_VhfVhf(q_lo, scale01);
q_hi = Q6_Vqf16_vmpy_VhfVhf(q_hi, scale23);
q_lo = Q6_Vhf_equals_Vqf16(q_lo);
q_hi = Q6_Vhf_equals_Vqf16(q_hi);
if constexpr (_IsDstAligned) {
reinterpret_cast<HVX_Vector *>(dst_ptr)[0] = q_lo;
reinterpret_cast<HVX_Vector *>(dst_ptr)[1] = q_hi;
} else {
reinterpret_cast<HVX_UVector *>(dst_ptr)[0] = q_lo;
reinterpret_cast<HVX_UVector *>(dst_ptr)[1] = q_hi;
}
}
template <bool _IsDstAligned>
void dequantize_row_q4_0_impl(const void * src, hexagon::dequant_output_type * dst, size_t count, HVX_Vector table) {
using namespace hexagon::vec::quant;
constexpr const size_t kElemsPerVec = hexagon::kBytesPerVector / sizeof(hexagon::dequant_output_type);
constexpr const uint32_t kSizeOfQs = sizeof(npu_device_block_q4_0::qs);
constexpr const int qk = QUANT_BLOCK_SIZE;
static_assert(qk % 2 == 0, "qk must be even");
static_assert(QUANT_BLOCK_SIZE == hexagon::kBytesPerVector / sizeof(float));
static const auto load_masks = make_quad_block_mask<npu_device_block_q4_0>();
alignas(hexagon::kBytesPerVector) static const HVX_Vector qs_indices = make_qs_load_mask<npu_device_block_q4_0>();
alignas(hexagon::kBytesPerVector) static const HVX_Vector scale_indices =
make_scale_load_mask<npu_device_block_q4_0>();
@ -565,21 +343,41 @@ void dequantize_row_q4_0_impl(const void * src, hexagon::dequant_output_type * d
int i = 0;
for (; i + 5 < nb; i += 6) {
auto qs = load_hexa_block_generic(src_ptr + i, load_masks, scale_indices);
dequantize_row_q4_0_4blocks<_IsDstAligned>(qs.val[0], qs.val[1], qs.val[2], table, dst_ptr);
dequantize_row_q4_0_2blocks<_IsDstAligned>(qs.val[3], qs.val[4], table, dst_ptr + kElemsPerVec * 2);
auto qs = load_hexa_block_generic(src_ptr + i, qs_indices, scale_indices);
auto res01 = dequantize_vec_q40_qf16_4blocks(qs.val[0], qs.val[1], qs.val[2], table);
auto res2 = dequantize_vec_q40_qf16_2blocks(qs.val[3], qs.val[4], table);
if constexpr (_IsDstAligned) {
reinterpret_cast<HVX_Vector *>(dst_ptr)[0] = Q6_Vhf_equals_Vqf16(res01.val[0]);
reinterpret_cast<HVX_Vector *>(dst_ptr)[1] = Q6_Vhf_equals_Vqf16(res01.val[1]);
reinterpret_cast<HVX_Vector *>(dst_ptr)[2] = Q6_Vhf_equals_Vqf16(res2);
} else {
reinterpret_cast<HVX_UVector *>(dst_ptr)[0] = Q6_Vhf_equals_Vqf16(res01.val[0]);
reinterpret_cast<HVX_UVector *>(dst_ptr)[1] = Q6_Vhf_equals_Vqf16(res01.val[1]);
reinterpret_cast<HVX_UVector *>(dst_ptr)[2] = Q6_Vhf_equals_Vqf16(res2);
}
dst_ptr += kElemsPerVec * 3;
}
for (; i + 3 < nb; i += 4) {
auto qs = load_qual_block_generic(src_ptr + i, load_masks, scale_indices);
dequantize_row_q4_0_4blocks<_IsDstAligned>(qs.val[0], qs.val[1], qs.val[2], table, dst_ptr);
auto qs = load_qual_block_generic(src_ptr + i, qs_indices, scale_indices);
auto res01 = dequantize_vec_q40_qf16_4blocks(qs.val[0], qs.val[1], qs.val[2], table);
if constexpr (_IsDstAligned) {
reinterpret_cast<HVX_Vector *>(dst_ptr)[0] = Q6_Vhf_equals_Vqf16(res01.val[0]);
reinterpret_cast<HVX_Vector *>(dst_ptr)[1] = Q6_Vhf_equals_Vqf16(res01.val[1]);
} else {
reinterpret_cast<HVX_UVector *>(dst_ptr)[0] = Q6_Vhf_equals_Vqf16(res01.val[0]);
reinterpret_cast<HVX_UVector *>(dst_ptr)[1] = Q6_Vhf_equals_Vqf16(res01.val[1]);
}
dst_ptr += kElemsPerVec * 2;
}
for (; i + 1 < nb; i += 2) {
auto qs = load_dual_block_generic(src_ptr + i, load_masks.val[0], scale_indices);
dequantize_row_q4_0_2blocks<_IsDstAligned>(qs.val[0], qs.val[1], table, dst_ptr);
auto res = load_dequant_vec_q40_qf16_2blocks(src_ptr + i, qs_indices, scale_indices, table);
if constexpr (_IsDstAligned) {
*reinterpret_cast<HVX_Vector *>(dst_ptr) = Q6_Vhf_equals_Vqf16(res);
} else {
*reinterpret_cast<HVX_UVector *>(dst_ptr) = Q6_Vhf_equals_Vqf16(res);
}
dst_ptr += kElemsPerVec;
}
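// Commentary (not from the original source): a hypothetical scalar model of the
// q4_0 dequantization that the 6/4/2-block HVX loops above vectorize, assuming
// the same layout as ggml's block_q4_0 (one fp16 scale `d`, then 16 bytes of
// packed nibbles where byte j holds element j in its low nibble and element
// j + 16 in its high nibble) and the zero point of 8 used by the lookup table.
struct q4_0_block_ref {
    __fp16  d;
    uint8_t qs[16];
};

inline void dequantize_row_q4_0_ref(const q4_0_block_ref * src, __fp16 * dst, size_t count) {
    constexpr int qk = 32;  // QUANT_BLOCK_SIZE assumed to be 32
    const size_t nb = count / qk;
    for (size_t i = 0; i < nb; ++i) {
        const float d = static_cast<float>(src[i].d);
        for (int j = 0; j < qk / 2; ++j) {
            const int lo = (src[i].qs[j] & 0x0F) - 8;
            const int hi = (src[i].qs[j] >> 4) - 8;
            dst[i * qk + j]          = static_cast<__fp16>(lo * d);
            dst[i * qk + j + qk / 2] = static_cast<__fp16>(hi * d);
        }
    }
}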
@ -611,12 +409,8 @@ HVX_Vector load_dequant_table_q4_0() {
constexpr const int kQ4ZeroPoint = 8; // zero point for q4_0 quantization
static_assert(kTableSize <= hexagon::kBytesPerVector / sizeof(__fp16), "table too large");
static const HVX_Vector result = []() -> HVX_Vector {
alignas(hexagon::kBytesPerVector) union {
HVX_Vector v;
__fp16 f16[sizeof(HVX_Vector) / sizeof(__fp16)];
} table;
alignas(hexagon::kBytesPerVector) static const HVX_Vector result = []() -> HVX_Vector {
alignas(hexagon::kBytesPerVector) hexagon::HVX_VectorAlias table;
table.v = Q6_V_vzero();
for (int i = 0; i < kTableSize; ++i) {
table.f16[i * 2] = i - kQ4ZeroPoint; // TODO: vectorize this?
@ -640,12 +434,8 @@ HVX_Vector load_dequant_table_q4_k() {
constexpr const int kTableSize = 1 << 4; // 4 bits per value, 16 values
static_assert(kTableSize <= hexagon::kBytesPerVector / sizeof(__fp16), "table too large");
const static HVX_Vector result = []() -> HVX_Vector {
alignas(hexagon::kBytesPerVector) union {
HVX_Vector v;
__fp16 f16[sizeof(HVX_Vector) / sizeof(__fp16)];
} table;
alignas(hexagon::kBytesPerVector) static const HVX_Vector result = []() -> HVX_Vector {
alignas(hexagon::kBytesPerVector) hexagon::HVX_VectorAlias table;
table.v = Q6_V_vzero();
for (int i = 0; i < kTableSize; ++i) {
table.f16[i * 2] = i; // TODO: vectorize this?


@ -19,6 +19,7 @@ using HVX_Vector_x2 = HEXAGON_pack<HVX_Vector, 2>;
using HVX_Vector_x3 = HEXAGON_pack<HVX_Vector, 3>;
using HVX_Vector_x4 = HEXAGON_pack<HVX_Vector, 4>;
using HVX_Vector_x5 = HEXAGON_pack<HVX_Vector, 5>;
using HVX_VectorPair_x2 = HEXAGON_pack<HVX_VectorPair, 2>;
using HVX_VectorPair_x4 = HEXAGON_pack<HVX_VectorPair, 4>;
using HVX_VectorPred_x3 = HEXAGON_pack<HVX_VectorPred, 3>;
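// Commentary (assumption, not the actual definition): these aliases are used as
// small fixed-size bundles of HVX registers, accessed via .val[i] in the code
// below. A minimal sketch of the shape HEXAGON_pack is assumed to have:
//
//     template <typename _TVec, size_t _Count> struct HEXAGON_pack {
//         _TVec val[_Count];
//     };
//
// e.g. load_dual_block_generic returns an HVX_Vector_x2 whose val[0] holds the
// gathered quants and val[1] the broadcast scales.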
@ -203,6 +204,7 @@ inline HVX_Vector hvx_passthru(HVX_Vector src, HVX_UVector *, HVX_Vector) {
#include "vec_math.inl"
#include "vec_ops.inl"
#include "vec_quant.inl"
namespace hexagon {
@ -268,8 +270,8 @@ inline HVX_Vector vec_dot_product_vqf32_f32_f32(const float * src0, const float
inline HVX_Vector vec_dot_product_aligned_vqf32_f32_f32(const float * src0, const float * src1, size_t count) {
using namespace hexagon::vec;
return vec_dot_product_aligned_impl<float, HVX_Vector, vec_mpy_qf32, vec_add_qf32, vec_reduction_qf32>(
src0, src1, count);
return vec_dot_product_aligned_impl<float, HVX_Vector, vec_mpy_qf32, vec_add_qf32, vec_reduction_qf32>(src0, src1,
count);
}
inline float vec_dot_product_f32_f32(const float * src0, const float * src1, size_t count) {
@ -279,8 +281,8 @@ inline float vec_dot_product_f32_f32(const float * src0, const float * src1, siz
inline float vec_dot_product_aligned_f32_f32(const float * src0, const float * src1, size_t count) {
using namespace hexagon::vec;
return vec_dot_product_aligned_impl<float, float, vec_mpy_qf32, vec_add_qf32, vec_reduction_f32_qf32>(
src0, src1, count);
return vec_dot_product_aligned_impl<float, float, vec_mpy_qf32, vec_add_qf32, vec_reduction_f32_qf32>(src0, src1,
count);
}
inline bool is_f32_f32_dot_product_aligned(const float * src0, const float * src1, size_t count) {
@ -326,13 +328,8 @@ inline bool is_f16_f16_dot_product_aligned(const npu_device_fp16_t * src0,
inline HVX_Vector vec_dot_product_vqf32_f16_f32(const npu_device_fp16_t * src0, const float * src1, size_t count) {
using namespace hexagon::vec;
using namespace hexagon::vec::math;
return vec_dot_product_mixed_impl<npu_device_fp16_t,
float,
HVX_Vector,
hvx_vsf_convert_vhf,
vec_mpy_qf32,
vec_add_qf32,
vec_reduction_qf32>(src0, src1, count);
return vec_dot_product_mixed_impl<npu_device_fp16_t, float, HVX_Vector, hvx_vsf_convert_vhf, vec_mpy_qf32,
vec_add_qf32, vec_reduction_qf32>(src0, src1, count);
}
inline HVX_Vector vec_dot_product_aligned_vqf32_f16_f32(const npu_device_fp16_t * src0,
@ -340,37 +337,39 @@ inline HVX_Vector vec_dot_product_aligned_vqf32_f16_f32(const npu_device_fp16_t
size_t count) {
using namespace hexagon::vec;
using namespace hexagon::vec::math;
return vec_dot_product_mix_aligned_impl<npu_device_fp16_t,
float,
HVX_Vector,
hvx_vsf_convert_vhf,
vec_mpy_qf32,
vec_add_qf32,
vec_reduction_qf32>(src0, src1, count);
return vec_dot_product_mix_aligned_impl<npu_device_fp16_t, float, HVX_Vector, hvx_vsf_convert_vhf, vec_mpy_qf32,
vec_add_qf32, vec_reduction_qf32>(src0, src1, count);
}
inline float vec_dot_product_f16_f32(const npu_device_fp16_t * src0, const float * src1, size_t count) {
using namespace hexagon::vec;
using namespace hexagon::vec::math;
return vec_dot_product_mixed_impl<npu_device_fp16_t,
float,
float,
hvx_vsf_convert_vhf,
vec_mpy_qf32,
vec_add_qf32,
return vec_dot_product_mixed_impl<npu_device_fp16_t, float, float, hvx_vsf_convert_vhf, vec_mpy_qf32, vec_add_qf32,
vec_reduction_f32_qf32>(src0, src1, count);
}
inline float vec_dot_product_aligned_f16_f32(const npu_device_fp16_t * src0, const float * src1, size_t count) {
using namespace hexagon::vec;
using namespace hexagon::vec::math;
return vec_dot_product_mix_aligned_impl<npu_device_fp16_t,
float,
float,
hvx_vsf_convert_vhf,
vec_mpy_qf32,
vec_add_qf32,
vec_reduction_f32_qf32>(src0, src1, count);
return vec_dot_product_mix_aligned_impl<npu_device_fp16_t, float, float, hvx_vsf_convert_vhf, vec_mpy_qf32,
vec_add_qf32, vec_reduction_f32_qf32>(src0, src1, count);
}
inline HVX_Vector vec_dot_product_vqf32_q40_f32(const npu_device_block_q4_0 * src0,
const float * src1,
size_t count,
const HVX_Vector table) {
using namespace hexagon::vec;
using namespace hexagon::vec::math;
using namespace hexagon::vec::quant;
alignas(hexagon::kBytesPerVector) static const HVX_Vector qs_indices = make_qs_load_mask<npu_device_block_q4_0>();
alignas(hexagon::kBytesPerVector) static const HVX_Vector scale_indices =
make_scale_load_mask<npu_device_block_q4_0>();
return vec_dot_product_quant_impl<npu_device_block_q4_0, float, HVX_Vector, load_dequant_vec_q40_qf32_4blocks,
load_dequant_vec_q40_qf32_2blocks, load_dequant_vec_q40_qf32_1block,
vec_reduction_qf32>(src0, src1, count, qs_indices, scale_indices, table);
}
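// Commentary (not from the original source): a hypothetical scalar model of the
// q4_0 x f32 dot product that the instantiation above vectorizes, using the same
// assumed layout as the earlier dequant sketch (fp16 scale, 16 packed nibble
// bytes, zero point 8).
struct q4_0_block_ref {  // hypothetical reference layout, not the device struct
    __fp16  d;
    uint8_t qs[16];
};

inline float vec_dot_product_q40_f32_ref(const q4_0_block_ref * src0, const float * src1, size_t count) {
    constexpr int qk = 32;  // QUANT_BLOCK_SIZE assumed to be 32
    float sum = 0.0f;
    for (size_t i = 0; i < count / qk; ++i) {
        const float d = static_cast<float>(src0[i].d);
        for (int j = 0; j < qk / 2; ++j) {
            sum += d * ((src0[i].qs[j] & 0x0F) - 8) * src1[i * qk + j];
            sum += d * ((src0[i].qs[j] >> 4) - 8) * src1[i * qk + j + qk / 2];
        }
    }
    return sum;
}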
inline bool is_f16_f32_dot_product_aligned(const npu_device_fp16_t * src0, const float * src1, size_t count) {


@ -4,7 +4,9 @@
#include <hexagon_types.h>
#include <cassert>
#include <cstdint>
#include <type_traits>
namespace hexagon::vec {
@ -378,6 +380,156 @@ inline _TRet vec_dot_product_mix_aligned_impl(const _TElem0 * src0, const _TElem
return _ReduceFunc(_AddFunc(sum0, sum1));
}
template <typename _TQuantElem0,
typename _TElem1,
typename _TRet,
HVX_VectorPair_x2 (*_DequantQuadFunc)(const _TQuantElem0 * src,
const HVX_Vector qs_indices,
const HVX_Vector scale_indices,
const HVX_Vector table),
HVX_VectorPair (*_DequantDualFunc)(const _TQuantElem0 * src,
const HVX_Vector qs_indices,
const HVX_Vector scale_indices,
const HVX_Vector table),
HVX_Vector (*_DequantFunc)(const _TQuantElem0 * src,
const HVX_Vector qs_indices,
const HVX_Vector scale_indices,
const HVX_Vector table),
_TRet (*_ReduceFunc)(HVX_Vector)>
inline _TRet vec_dot_product_quant_impl(const _TQuantElem0 * src0,
const _TElem1 * src1,
size_t count,
const HVX_Vector qs_indices,
const HVX_Vector scale_indices,
const HVX_Vector table) {
constexpr const size_t kElementsPerVector = hexagon::kBytesPerVector / sizeof(_TElem1);
static_assert(std::is_same_v<_TQuantElem0, npu_device_block_q4_0> ||
std::is_same_v<_TQuantElem0, npu_device_block_q4_k> ||
std::is_same_v<_TQuantElem0, npu_device_block_q8_0>,
"Element type mismatch: _TQuantElem0 must be a supported quantization block type");
static_assert(QUANT_BLOCK_SIZE == kElementsPerVector,
"Quant block size mismatch: QUANT_BLOCK_SIZE must be equal to kElementsPerVector");
assert(count % kElementsPerVector == 0 && "Count must be a multiple of kElementsPerVector");
const HVX_Vector kZeroV = Q6_V_vzero();
const _TQuantElem0 * src0_ptr = src0;
HVX_Vector * src1_vec_ptr = ((HVX_Vector *) src1);
HVX_Vector * const src1_vec_ptr_end = ((HVX_Vector *) src1) + count / kElementsPerVector;
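// Commentary on the loop below (descriptive, not from the original source):
// src1 may be unaligned, so the code relies on HVX aligned loads dropping the
// low bits of the address and stitches each logical vector together with
// Q6_V_valign_VVR(curr, prev, (size_t) src1); `prev1` carries the previous
// aligned load across iterations, the same rolling-window technique referenced
// in the qhblas link further down.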
HVX_Vector prev1 = *src1_vec_ptr++;
HVX_Vector sum = kZeroV;
if (src1_vec_ptr_end - src1_vec_ptr > 1) {
HVX_Vector sum0 = kZeroV;
HVX_Vector sum1 = kZeroV;
while (src1_vec_ptr_end - src1_vec_ptr > 3) {
HVX_VectorPair_x2 s01 = _DequantQuadFunc(src0_ptr, qs_indices, scale_indices, table);
HVX_Vector curr100 = src1_vec_ptr[0];
HVX_Vector curr101 = src1_vec_ptr[1];
HVX_Vector curr110 = src1_vec_ptr[2];
HVX_Vector curr111 = src1_vec_ptr[3];
HVX_Vector l00 = Q6_V_lo_W(s01.val[0]);
HVX_Vector l10 = Q6_V_valign_VVR(curr100, prev1, (size_t) src1);
HVX_Vector l01 = Q6_V_lo_W(s01.val[1]);
HVX_Vector l11 = Q6_V_valign_VVR(curr110, curr101, (size_t) src1);
HVX_Vector h00 = Q6_V_hi_W(s01.val[0]);
HVX_Vector h10 = Q6_V_valign_VVR(curr101, curr100, (size_t) src1);
HVX_Vector h01 = Q6_V_hi_W(s01.val[1]);
HVX_Vector h11 = Q6_V_valign_VVR(curr111, curr110, (size_t) src1);
l10 = Q6_Vqf32_vadd_VsfVsf(kZeroV, l10);
l11 = Q6_Vqf32_vadd_VsfVsf(kZeroV, l11);
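// The Q6_Vqf32_vadd_VsfVsf(kZeroV, ...) calls convert the IEEE float (sf)
// lanes of src1 into qf32 so they match the qf32 dequant results in the
// Q6_Vqf32_vmpy_Vqf32Vqf32 multiplies.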
HVX_Vector mpy0 = Q6_Vqf32_vmpy_Vqf32Vqf32(l00, l10);
HVX_Vector mpy1 = Q6_Vqf32_vmpy_Vqf32Vqf32(l01, l11);
h10 = Q6_Vqf32_vadd_VsfVsf(kZeroV, h10);
h11 = Q6_Vqf32_vadd_VsfVsf(kZeroV, h11);
HVX_Vector mpy2 = Q6_Vqf32_vmpy_Vqf32Vqf32(h00, h10);
HVX_Vector mpy3 = Q6_Vqf32_vmpy_Vqf32Vqf32(h01, h11);
prev1 = curr111;
sum0 = Q6_Vqf32_vadd_Vqf32Vqf32(mpy0, sum0);
sum1 = Q6_Vqf32_vadd_Vqf32Vqf32(mpy1, sum1);
sum0 = Q6_Vqf32_vadd_Vqf32Vqf32(mpy2, sum0);
sum1 = Q6_Vqf32_vadd_Vqf32Vqf32(mpy3, sum1);
src0_ptr += 4;
src1_vec_ptr += 4;
}
while (src1_vec_ptr_end - src1_vec_ptr > 1) {
HVX_VectorPair s0 = _DequantDualFunc(src0_ptr, qs_indices, scale_indices, table);
HVX_Vector curr10 = src1_vec_ptr[0];
HVX_Vector curr11 = src1_vec_ptr[1];
HVX_Vector l0 = Q6_V_lo_W(s0);
HVX_Vector l1 = Q6_V_valign_VVR(curr10, prev1, (size_t) src1);
HVX_Vector h0 = Q6_V_hi_W(s0);
HVX_Vector h1 = Q6_V_valign_VVR(curr11, curr10, (size_t) src1);
l1 = Q6_Vqf32_vadd_VsfVsf(kZeroV, l1);
h1 = Q6_Vqf32_vadd_VsfVsf(kZeroV, h1);
HVX_Vector mpy0 = Q6_Vqf32_vmpy_Vqf32Vqf32(l0, l1);
HVX_Vector mpy1 = Q6_Vqf32_vmpy_Vqf32Vqf32(h0, h1);
prev1 = curr11;
sum0 = Q6_Vqf32_vadd_Vqf32Vqf32(mpy0, sum0);
sum1 = Q6_Vqf32_vadd_Vqf32Vqf32(mpy1, sum1);
src0_ptr += 2;
src1_vec_ptr += 2;
}
sum = Q6_Vqf32_vadd_Vqf32Vqf32(sum0, sum1);
}
if (src1_vec_ptr_end - src1_vec_ptr > 0) {
HVX_Vector curr1 = *src1_vec_ptr++;
HVX_Vector s1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1);
HVX_Vector s0 = _DequantFunc(src0_ptr++, qs_indices, scale_indices, table);
s1 = Q6_Vqf32_vadd_VsfVsf(kZeroV, s1);
HVX_Vector mpy0 = Q6_Vqf32_vmpy_Vqf32Vqf32(s0, s1);
prev1 = curr1;
sum = Q6_Vqf32_vadd_Vqf32Vqf32(mpy0, sum);
}
if ((src1_vec_ptr_end - ((HVX_Vector *) src1)) > 0) {
// handle the last vector
// see also:
// https://github.com/UbiquitousLearning/mllm/blob/babf4410352ce8730824c87699c025a0d4ce3a6f/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/LLaMAMul.cpp#L147
// or qualcomm sdk libs\qhl_hvx\src\qhblas_hvx\qhblas_hvx_aw_vector_add_ah.c
bool should_fetch_src1 = !hexagon::is_addr_aligned(src1_vec_ptr);
HVX_Vector curr1 = should_fetch_src1 ? *src1_vec_ptr : prev1;
src1_vec_ptr += should_fetch_src1 ? 1 : 0;
HVX_Vector s1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1);
HVX_Vector s0 = _DequantFunc(src0_ptr, qs_indices, scale_indices, table);
s1 = Q6_Vqf32_vadd_VsfVsf(kZeroV, s1);
HVX_Vector mpy0 = Q6_Vqf32_vmpy_Vqf32Vqf32(s0, s1);
prev1 = curr1;
sum = Q6_Vqf32_vadd_Vqf32Vqf32(mpy0, sum);
}
return _ReduceFunc(sum);
}
template <HVX_Vector (*_Func)(HVX_Vector, HVX_UVector *, HVX_Vector),
HVX_Vector (*_FuncScaleConvert)(float),
typename _TParam>


@ -0,0 +1,304 @@
#pragma once
#include "hexagon_npu.h"
#include <hexagon_types.h>
#include <cstdint>
namespace hexagon::vec::quant {
template <typename _TStruct, size_t _Count, auto _MemberPtr> inline HVX_Vector load_into_vector(const _TStruct * src) {
static_assert(hexagon::kBytesPerVector >= sizeof(_TStruct) * _Count, "_TStruct too large for vector load");
return *reinterpret_cast<const HVX_UVector *>(&(src->*_MemberPtr));
}
template <typename _TStruct, size_t _Count> inline HVX_Vector load_struct_into_vector(const _TStruct * src) {
static_assert(hexagon::kBytesPerVector >= sizeof(_TStruct) * _Count, "_TStruct too large for vector load");
return *reinterpret_cast<const HVX_UVector *>(src);
}
template <typename _TBlock> inline HVX_Vector load_block_generic(const _TBlock & src) {
static_assert(hexagon::kBytesPerVector >= sizeof(_TBlock), "wrong block size/padding");
return load_into_vector<_TBlock, 1, &_TBlock::qs>(&src);
}
template <typename _TBlock> inline HVX_Vector make_scale_load_mask() {
static_assert(sizeof(_TBlock) < hexagon::kBytesPerVector, "wrong block size/padding");
static_assert(std::is_same_v<decltype(_TBlock::d), npu_device_fp16_t>,
"scale field d must be of type npu_device_fp16_t");
constexpr const size_t kBytesPerScale = QUANT_BLOCK_SIZE * sizeof(_TBlock::d);
const size_t qs_start_offset = offsetof(_TBlock, d);
hexagon::HVX_VectorAlias ret;
size_t base_i = qs_start_offset;
for (size_t ret_idx = 0; ret_idx < hexagon::kBytesPerVector; ++ret_idx) {
const auto offset = ret_idx % kBytesPerScale;
const auto i = base_i + (offset % sizeof(_TBlock::d));
ret.u8[ret_idx] = (i & 1) ? (i / 2 + 64) : (i / 2);
if (offset == kBytesPerScale - 1) {
base_i += sizeof(_TBlock);
}
}
return ret.v;
}
template <typename _TBlock> inline HVX_Vector make_qs_load_mask() {
static_assert(sizeof(_TBlock) < hexagon::kBytesPerVector, "wrong block size/padding");
const size_t qs_start_offset = offsetof(_TBlock, qs);
const size_t qs_end_offset = qs_start_offset + sizeof(_TBlock::qs);
hexagon::HVX_VectorAlias ret;
size_t ret_idx = 0;
for (size_t i = 0; i < hexagon::kBytesPerVector; ++i) {
auto offset = i % sizeof(_TBlock);
if (offset >= qs_start_offset && offset < qs_end_offset) {
ret.u8[ret_idx++] = (i & 1) ? (i / 2 + 64) : (i / 2);
}
}
return ret.v;
}
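// Commentary (assumption, not part of the original source): the two index
// vectors built above are meant to split N consecutive blocks into two vectors,
// one holding all qs bytes back to back and one holding each block's fp16 scale
// replicated across that block's QUANT_BLOCK_SIZE output half-words. A scalar
// model of the intended result (ignoring the vlut32 index encoding):
template <typename _TBlock, size_t _N>
inline void split_blocks_ref(const _TBlock * blocks, uint8_t * qs_out, npu_device_fp16_t * scale_out) {
    for (size_t n = 0; n < _N; ++n) {
        for (size_t j = 0; j < sizeof(_TBlock::qs); ++j) {
            qs_out[n * sizeof(_TBlock::qs) + j] = blocks[n].qs[j];  // packed quants
        }
        for (size_t j = 0; j < QUANT_BLOCK_SIZE; ++j) {
            scale_out[n * QUANT_BLOCK_SIZE + j] = blocks[n].d;      // broadcast scale
        }
    }
}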
template <typename _TBlock> inline HVX_Vector load_dual_block_generic(const _TBlock * srcs, HVX_VectorPred mask) {
static_assert(hexagon::kBytesPerVector >= sizeof(_TBlock) * 2, "wrong block size/padding");
constexpr const uint32_t kSizeOfQs = sizeof(_TBlock::qs);
constexpr const uint32_t kSizeOfScale = sizeof(_TBlock) - kSizeOfQs;
HVX_Vector blocks = load_into_vector<_TBlock, 2, &_TBlock::qs>(srcs);
HVX_Vector block1 = Q6_V_vror_VR(blocks, kSizeOfScale);
return Q6_V_vmux_QVV(mask, blocks, block1);
}
template <typename _TBlock>
inline hexagon::HVX_Vector_x2 load_dual_block_generic(const _TBlock * srcs,
const HVX_Vector qs_indices,
const HVX_Vector scale_indices) {
static_assert(hexagon::kBytesPerVector >= sizeof(_TBlock) * 2, "wrong block size/padding");
const HVX_Vector blocks = load_struct_into_vector<_TBlock, 2>(srcs);
HVX_Vector block01 = Q6_Vb_vlut32_VbVbI(qs_indices, blocks, 0);
block01 = Q6_Vb_vlut32or_VbVbVbI(block01, qs_indices, blocks, 2);
HVX_Vector scale01 = Q6_Vb_vlut32_VbVbI(scale_indices, blocks, 0);
scale01 = Q6_Vb_vlut32or_VbVbVbI(scale01, scale_indices, blocks, 2);
if constexpr (sizeof(_TBlock) * 4 > hexagon::kBytesPerVector) {
block01 = Q6_Vb_vlut32or_VbVbVbI(block01, qs_indices, blocks, 1);
block01 = Q6_Vb_vlut32or_VbVbVbI(block01, qs_indices, blocks, 3);
scale01 = Q6_Vb_vlut32or_VbVbVbI(scale01, scale_indices, blocks, 1);
scale01 = Q6_Vb_vlut32or_VbVbVbI(scale01, scale_indices, blocks, 3);
}
hexagon::HVX_Vector_x2 result;
result.val[0] = block01;
result.val[1] = scale01;
return result;
}
template <typename _TBlock>
inline hexagon::HVX_Vector_x3 load_qual_block_generic(const _TBlock * srcs,
const HVX_Vector qs_indices,
const HVX_Vector scale_indices) {
static_assert(hexagon::kBytesPerVector >= sizeof(_TBlock) * 4, "wrong block size/padding");
hexagon::HVX_Vector_x3 result;
const HVX_Vector blocks = load_struct_into_vector<_TBlock, 4>(srcs);
{
HVX_Vector block0123 = Q6_Vb_vlut32_VbVbI(qs_indices, blocks, 0);
block0123 = Q6_Vb_vlut32or_VbVbVbI(block0123, qs_indices, blocks, 1);
block0123 = Q6_Vb_vlut32or_VbVbVbI(block0123, qs_indices, blocks, 2);
block0123 = Q6_Vb_vlut32or_VbVbVbI(block0123, qs_indices, blocks, 3);
result.val[0] = block0123;
}
{
HVX_Vector blocks23 = Q6_V_vror_VR(blocks, sizeof(_TBlock) * 2);
HVX_Vector scale01 = Q6_Vb_vlut32_VbVbI(scale_indices, blocks, 0);
scale01 = Q6_Vb_vlut32or_VbVbVbI(scale01, scale_indices, blocks, 2);
HVX_Vector scale23 = Q6_Vb_vlut32_VbVbI(scale_indices, blocks23, 0);
scale23 = Q6_Vb_vlut32or_VbVbVbI(scale23, scale_indices, blocks23, 2);
result.val[1] = scale01;
result.val[2] = scale23;
}
return result;
}
template <typename _TBlock>
inline hexagon::HVX_Vector_x5 load_hexa_block_generic(const _TBlock * srcs,
const HVX_Vector qs_indices,
const HVX_Vector scale_indices) {
static_assert(hexagon::kBytesPerVector >= sizeof(_TBlock) * 6, "wrong block size/padding");
constexpr const uint32_t kSizeOfQs = sizeof(_TBlock::qs);
const HVX_Vector blocks = load_struct_into_vector<_TBlock, 6>(srcs);
hexagon::HVX_Vector_x5 result;
{
HVX_Vector block012345 = Q6_Vb_vlut32_VbVbI(qs_indices, blocks, 0);
block012345 = Q6_Vb_vlut32or_VbVbVbI(block012345, qs_indices, blocks, 1);
block012345 = Q6_Vb_vlut32or_VbVbVbI(block012345, qs_indices, blocks, 2);
block012345 = Q6_Vb_vlut32or_VbVbVbI(block012345, qs_indices, blocks, 3);
result.val[0] = block012345;
result.val[3] = Q6_V_vror_VR(block012345, kSizeOfQs * 4); // block45
}
{
HVX_Vector blocks23 = Q6_V_vror_VR(blocks, sizeof(_TBlock) * 2);
HVX_Vector blocks45 = Q6_V_vror_VR(blocks, sizeof(_TBlock) * 4);
HVX_Vector scale01 = Q6_Vb_vlut32_VbVbI(scale_indices, blocks, 0);
scale01 = Q6_Vb_vlut32or_VbVbVbI(scale01, scale_indices, blocks, 2);
HVX_Vector scale23 = Q6_Vb_vlut32_VbVbI(scale_indices, blocks23, 0);
scale23 = Q6_Vb_vlut32or_VbVbVbI(scale23, scale_indices, blocks23, 2);
HVX_Vector scale45 = Q6_Vb_vlut32_VbVbI(scale_indices, blocks45, 0);
scale45 = Q6_Vb_vlut32or_VbVbVbI(scale45, scale_indices, blocks45, 2);
result.val[1] = scale01;
result.val[2] = scale23;
result.val[4] = scale45;
}
return result;
}
inline HVX_Vector dequantize_vec_q40_qf16_2blocks(HVX_Vector qs, HVX_Vector scale01, HVX_Vector table) {
constexpr const uint32_t kSizeOfQs = sizeof(npu_device_block_q4_0::qs);
HVX_Vector q_lo = qs;
HVX_Vector q_hi = Q6_Vub_vlsr_VubR(qs, 4);
HVX_VectorPair qp0 = Q6_W_vshuff_VVR(q_hi, q_lo, kSizeOfQs * (1 + 2));
q_lo = Q6_V_lo_W(qp0);
q_lo = Q6_Vb_vshuff_Vb(q_lo);
qp0 = Q6_Wh_vlut16_VbVhR_nomatch(q_lo, table, 0);
return Q6_Vqf16_vmpy_VhfVhf(Q6_V_lo_W(qp0), scale01);
}
inline HVX_VectorPair dequantize_vec_q40_qf32_2blocks(HVX_Vector qs, HVX_Vector scale01, HVX_Vector table) {
constexpr const uint32_t kSizeOfQs = sizeof(npu_device_block_q4_0::qs);
HVX_Vector q_lo = qs;
HVX_Vector q_hi = Q6_Vub_vlsr_VubR(qs, 4);
HVX_VectorPair qp0 = Q6_W_vshuff_VVR(q_hi, q_lo, kSizeOfQs * (1 + 2));
q_lo = Q6_V_lo_W(qp0);
q_lo = Q6_Vb_vshuff_Vb(q_lo);
qp0 = Q6_Wh_vlut16_VbVhR_nomatch(q_lo, table, 0);
q_lo = Q6_V_lo_W(qp0);
scale01 = Q6_Vh_vshuff_Vh(scale01);
q_lo = Q6_Vh_vshuff_Vh(q_lo); // TODO: avoid vshuff here
return Q6_Wqf32_vmpy_VhfVhf(q_lo, scale01);
}
inline HVX_Vector_x2 dequantize_vec_q40_qf16_4blocks(HVX_Vector qs,
HVX_Vector scale01,
HVX_Vector scale23,
HVX_Vector table) {
constexpr const uint32_t kSizeOfQs = sizeof(npu_device_block_q4_0::qs);
HVX_Vector q_lo = qs;
HVX_Vector q_hi = Q6_Vub_vlsr_VubR(qs, 4);
HVX_VectorPair qp0 = Q6_W_vshuff_VVR(q_hi, q_lo, kSizeOfQs * (1 + 2 + 4));
q_lo = Q6_V_lo_W(qp0);
q_lo = Q6_Vb_vshuff_Vb(q_lo);
qp0 = Q6_Wh_vlut16_VbVhR_nomatch(q_lo, table, 0);
q_lo = Q6_V_lo_W(qp0);
q_hi = Q6_V_hi_W(qp0);
q_lo = Q6_Vqf16_vmpy_VhfVhf(q_lo, scale01);
q_hi = Q6_Vqf16_vmpy_VhfVhf(q_hi, scale23);
hexagon::HVX_Vector_x2 result;
result.val[0] = q_lo;
result.val[1] = q_hi;
return result;
}
inline HVX_VectorPair_x2 dequantize_vec_q40_qf32_4blocks(HVX_Vector qs,
HVX_Vector scale01,
HVX_Vector scale23,
HVX_Vector table) {
constexpr const uint32_t kSizeOfQs = sizeof(npu_device_block_q4_0::qs);
HVX_Vector q_lo = qs;
HVX_Vector q_hi = Q6_Vub_vlsr_VubR(qs, 4);
HVX_VectorPair qp0 = Q6_W_vshuff_VVR(q_hi, q_lo, kSizeOfQs * (1 + 2 + 4));
q_lo = Q6_V_lo_W(qp0);
q_lo = Q6_Vb_vshuff_Vb(q_lo);
qp0 = Q6_Wh_vlut16_VbVhR_nomatch(q_lo, table, 0);
q_lo = Q6_V_lo_W(qp0);
q_hi = Q6_V_hi_W(qp0);
q_lo = Q6_Vh_vshuff_Vh(q_lo);
scale01 = Q6_Vh_vshuff_Vh(scale01);
q_hi = Q6_Vh_vshuff_Vh(q_hi);
scale23 = Q6_Vh_vshuff_Vh(scale23); // TODO: avoid vshuff here
hexagon::HVX_VectorPair_x2 result;
result.val[0] = Q6_Wqf32_vmpy_VhfVhf(q_lo, scale01);
result.val[1] = Q6_Wqf32_vmpy_VhfVhf(q_hi, scale23);
return result;
}
inline HVX_Vector load_dequant_vec_q40_qf32_1block(const npu_device_block_q4_0 * src,
const HVX_Vector qs_indices,
const HVX_Vector scale_indices,
const HVX_Vector table) {
// TODO: can we have a single-block version of load and dequantize?
auto qs = load_dual_block_generic(src, qs_indices, scale_indices);
return Q6_V_lo_W(dequantize_vec_q40_qf32_2blocks(qs.val[0], qs.val[1], table));
}
inline HVX_Vector load_dequant_vec_q40_qf16_2blocks(const npu_device_block_q4_0 * src,
const HVX_Vector qs_indices,
const HVX_Vector scale_indices,
const HVX_Vector table) {
auto qs = load_dual_block_generic(src, qs_indices, scale_indices);
return dequantize_vec_q40_qf16_2blocks(qs.val[0], qs.val[1], table);
}
inline HVX_VectorPair load_dequant_vec_q40_qf32_2blocks(const npu_device_block_q4_0 * src,
const HVX_Vector qs_indices,
const HVX_Vector scale_indices,
const HVX_Vector table) {
auto qs = load_dual_block_generic(src, qs_indices, scale_indices);
return dequantize_vec_q40_qf32_2blocks(qs.val[0], qs.val[1], table);
}
inline HVX_VectorPair_x2 load_dequant_vec_q40_qf32_4blocks(const npu_device_block_q4_0 * src,
const HVX_Vector qs_indices,
const HVX_Vector scale_indices,
const HVX_Vector table) {
auto qs = load_qual_block_generic(src, qs_indices, scale_indices);
return dequantize_vec_q40_qf32_4blocks(qs.val[0], qs.val[1], qs.val[2], table);
}
} // namespace hexagon::vec::quant
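// Usage sketch (assumption, mirrors the call sites above): dequantize two q4_0
// blocks into one fp16 vector with the LUT and index vectors prepared once;
// `blocks` (npu_device_block_q4_0 *) and `dst` (fp16 output) are hypothetical.
//
//     using namespace hexagon::vec::quant;
//     const HVX_Vector table = load_dequant_table_q4_0();
//     alignas(hexagon::kBytesPerVector) static const HVX_Vector qs_idx =
//         make_qs_load_mask<npu_device_block_q4_0>();
//     alignas(hexagon::kBytesPerVector) static const HVX_Vector scale_idx =
//         make_scale_load_mask<npu_device_block_q4_0>();
//     HVX_Vector out_qf16 = load_dequant_vec_q40_qf16_2blocks(blocks, qs_idx, scale_idx, table);
//     *reinterpret_cast<HVX_UVector *>(dst) = Q6_Vhf_equals_Vqf16(out_qf16);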


@ -20,6 +20,7 @@ static hexagon::host_buffer_type * get_buffer_type_object(ggml_backend_buffer_ty
}
void backend_buffer_free_buffer(ggml_backend_buffer_t buffer) {
SCOPED_PERFORMANCE_TRACKER("[hexagon-npu][%p]backend_buffer_free_buffer", (void *) get_buffer_object(buffer));
delete get_buffer_object(buffer);
}
@ -39,6 +40,7 @@ ggml_status backend_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor
auto * buffer_obj = get_buffer_object(buffer);
GGML_ASSERT(buffer_obj != nullptr);
SCOPED_PERFORMANCE_TRACKER("[hexagon-npu][%p]backend_buffer_init_tensor", (void *) buffer_obj);
auto tensor_object = buffer_obj->init_tensor(tensor, device_object->get_device_handle());
if (!tensor_object) {
LOG_ERROR("Failed to init tensor\n");
@ -48,12 +50,23 @@ ggml_status backend_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor
return GGML_STATUS_SUCCESS;
}
void backend_buffer_memset_tensor(ggml_backend_buffer_t buffer,
ggml_tensor * tensor,
uint8_t value,
size_t offset,
size_t size) {
SCOPED_PERFORMANCE_TRACKER("[hexagon-npu][%p]backend_buffer_memset_tensor.size.%zu",
(void *) get_buffer_object(buffer), size);
memset((char *) tensor->data + offset, value, size);
}
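// Usage sketch (assumption: ggml dispatches ggml_backend_tensor_memset through
// the memset_tensor callback registered in backend_buffer_interface below),
// e.g. zero-filling a tensor that lives in an NPU host buffer:
//
//     ggml_backend_tensor_memset(tensor, 0, 0, ggml_nbytes(tensor));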
void backend_buffer_set_tensor(ggml_backend_buffer_t buffer,
ggml_tensor * tensor,
const void * data,
size_t offset,
size_t size) {
GGML_UNUSED(buffer);
SCOPED_PERFORMANCE_TRACKER("[hexagon-npu][%p]backend_buffer_set_tensor.size.%zu",
(void *) get_buffer_object(buffer), size);
memcpy((char *) tensor->data + offset, data, size);
}
@ -62,23 +75,27 @@ void backend_buffer_get_tensor(ggml_backend_buffer_t buffer,
void * data,
size_t offset,
size_t size) {
GGML_UNUSED(buffer);
SCOPED_PERFORMANCE_TRACKER("[hexagon-npu][%p]backend_buffer_get_tensor", (void *) get_buffer_object(buffer));
memcpy(data, (const char *) tensor->data + offset, size);
}
bool backend_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
GGML_UNUSED(buffer);
SCOPED_PERFORMANCE_TRACKER("[hexagon-npu][%p]backend_buffer_cpy_tensor", (void *) get_buffer_object(buffer));
if (ggml_backend_buffer_is_host(src->buffer)) {
memcpy(dst->data, src->data, ggml_nbytes(src));
return true;
}
LOG_DEBUG("[hexagon-npu][%p]backend_buffer_cpy_tensor: copy from non-host buffer not supported\n",
(void *) get_buffer_object(buffer));
return false;
}
void backend_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
auto * buffer_obj = get_buffer_object(buffer);
GGML_ASSERT(buffer_obj != nullptr);
SCOPED_PERFORMANCE_TRACKER("[hexagon-npu][%p]backend_buffer_clear", (void *) buffer_obj);
memset(buffer_obj->get_buffer(), value, buffer_obj->get_size());
}
@ -94,7 +111,7 @@ constexpr const ggml_backend_buffer_i backend_buffer_interface = {
/* .free_buffer = */ backend_buffer_free_buffer,
/* .get_base = */ backend_buffer_get_base,
/* .init_tensor = */ backend_buffer_init_tensor,
/* .memset_tensor = */ nullptr,
/* .memset_tensor = */ backend_buffer_memset_tensor,
/* .set_tensor = */ backend_buffer_set_tensor,
/* .get_tensor = */ backend_buffer_get_tensor,
/* .cpy_tensor = */ backend_buffer_cpy_tensor,
@ -146,8 +163,8 @@ host_buffer::host_buffer(common::rpc_mem_ptr allocator, size_t size, uint32_t do
}
if (size > _allocator->get_max_alloc_size()) {
LOG_ERROR(
"[hexagon-npu]rpc memory size %zu exceeds max alloc size %zu\n", size, _allocator->get_max_alloc_size());
LOG_ERROR("[hexagon-npu]rpc memory size %zu exceeds max alloc size %zu\n", size,
_allocator->get_max_alloc_size());
return;
}
@ -161,8 +178,8 @@ host_buffer::host_buffer(common::rpc_mem_ptr allocator, size_t size, uint32_t do
}
host_buffer::~host_buffer() {
LOG_DEBUG(
"[hexagon-npu]destroy host_buffer(%p), size: %zu, domain_id: %d\n", (void *) _data, _size, (int) _domain_id);
LOG_DEBUG("[hexagon-npu]destroy host_buffer(%p), size: %zu, domain_id: %d\n", (void *) _data, _size,
(int) _domain_id);
_tensors.clear();
if (_buffer_fd != -1) {
auto ret = _allocator->fastrpc_munmap((int) _domain_id, _buffer_fd, nullptr, 0);
@ -194,17 +211,12 @@ std::shared_ptr<host_tensor> host_buffer::init_tensor(ggml_tensor * tensor, remo
return std::shared_ptr<host_tensor>();
}
LOG_DEBUG("[hexagon-npu]mmap rpc memory(%p), fd: %d, addr: %p, size: %zu\n",
(void *) _data,
_buffer_fd,
_data,
LOG_DEBUG("[hexagon-npu]mmap rpc memory(%p), fd: %d, addr: %p, size: %zu\n", (void *) _data, _buffer_fd, _data,
_size);
}
auto tensor_object = std::make_shared<host_tensor>(
tensor,
_buffer_fd,
(uint64_t) (reinterpret_cast<uint8_t *>(tensor->data) - reinterpret_cast<uint8_t *>(_data)),
tensor, _buffer_fd, (uint64_t) (reinterpret_cast<uint8_t *>(tensor->data) - reinterpret_cast<uint8_t *>(_data)),
device_handle);
if (!tensor_object->is_valid()) {
LOG_ERROR("[hexagon-npu]failed to init tensor, device handle: %p\n", (void *) device_handle);


@ -1,12 +1,13 @@
#include <memory>
#include <string>
#include "buffer.hpp"
#include "common.hpp"
#include "ggml-backend-impl.h"
#include "ggml-impl.h"
#include "host_device.hpp"
#include "profiler.hpp"
#include <memory>
#include <string>
namespace {
@ -42,7 +43,9 @@ void backend_dev_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * tota
enum ggml_backend_dev_type backend_dev_get_type(ggml_backend_dev_t dev) {
GGML_UNUSED(dev);
return GGML_BACKEND_DEVICE_TYPE_ACCEL;
// TODO: figure out why some ops are skipped when the device type is GGML_BACKEND_DEVICE_TYPE_ACCEL
return GGML_BACKEND_DEVICE_TYPE_IGPU;
}
void backend_dev_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
@ -62,6 +65,7 @@ ggml_backend_t backend_dev_init_backend(ggml_backend_dev_t dev, const char * par
return nullptr;
}
SCOPED_PERFORMANCE_TRACKER("[hexagon-npu][%p]backend_dev_init_backend", (void *) dev_obj);
return new hexagon::npu_backend(dev);
}
@ -71,8 +75,10 @@ ggml_backend_buffer_type_t backend_dev_get_buffer_type(ggml_backend_dev_t dev) {
return dev_obj->get_default_buffer_type(dev);
}
ggml_backend_buffer_t backend_dev_buffer_from_host_ptr(ggml_backend_dev_t dev, void * ptr, size_t size,
size_t max_tensor_size) {
ggml_backend_buffer_t backend_dev_buffer_from_host_ptr(ggml_backend_dev_t dev,
void * ptr,
size_t size,
size_t max_tensor_size) {
// TODO: should we use the device memory here?
GGML_UNUSED(dev);
GGML_UNUSED(max_tensor_size);
@ -86,6 +92,8 @@ bool backend_dev_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor *
auto * dev_obj = get_device_object(dev);
GGML_ASSERT(dev_obj != nullptr);
SCOPED_PERFORMANCE_TRACKER("[hexagon-npu][%p]backend_dev_supports_op", (void *) dev_obj);
return dev_obj->supports_op(op);
}
@ -96,6 +104,7 @@ bool backend_dev_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_
auto * dev_obj = get_device_object(dev);
GGML_ASSERT(dev_obj != nullptr);
return dev_obj->supports_buft(buft);
}
@ -106,6 +115,8 @@ bool backend_dev_offload_op(ggml_backend_dev_t dev, const struct ggml_tensor * o
auto * dev_obj = get_device_object(dev);
GGML_ASSERT(dev_obj != nullptr);
SCOPED_PERFORMANCE_TRACKER("[hexagon-npu][%p]backend_dev_offload_op", (void *) dev_obj);
return dev_obj->offload_op(op);
}


@ -143,17 +143,18 @@ bool npu_device::supports_op_impl(const ggml_tensor * op) {
}
if (op->op == GGML_OP_VIEW || op->op == GGML_OP_RESHAPE || op->op == GGML_OP_PERMUTE) {
LOG_DEBUG("[%s]view/reshape/permute op is always supported\n", get_name());
return true;
}
if (type_to_npu_type(op->type) == NPU_DATA_TYPE_COUNT) {
LOG_DEBUG("[%s]Unsupported op tensor type: %s\n", get_name(), ggml_type_name(op->type));
LOG_DEBUG("[%s][%s]Unsupported op tensor type: %s\n", get_name(), ggml_get_name(op), ggml_type_name(op->type));
return false;
}
auto npu_op = op_to_npu_op(op->op);
if (npu_op == NPU_OP_COUNT) {
LOG_DEBUG("[%s]Unsupported op: %s\n", get_name(), ggml_op_desc(op));
LOG_DEBUG("[%s][%s]Unsupported op: %s\n", get_name(), ggml_get_name(op), ggml_op_desc(op));
return false;
}
@ -206,14 +207,8 @@ bool npu_device::supports_op_impl(const ggml_tensor * op) {
#ifndef NDEBUG
auto * src0_type = i ? ggml_type_name(op->src[0]->type) : "null";
auto * src1_type = (i > 1) ? ggml_type_name(op->src[1]->type) : "null";
LOG_DEBUG("[%s][%s]unsupported %s(%s,%s), ret: 0x%x, supported: %d\n",
get_name(),
ggml_op_desc(op),
ggml_type_name(op->type),
src0_type,
src1_type,
ret,
supported);
LOG_DEBUG("[%s][%s]unsupported %s(%s,%s), ret: 0x%x, supported: %d\n", get_name(), ggml_op_desc(op),
ggml_type_name(op->type), src0_type, src1_type, ret, supported);
#endif
return false;
}
@ -258,8 +253,8 @@ bool npu_device::init_device_lib() {
}
if (err != AEE_SUCCESS) {
LOG_ERROR(
"[%s]Unable to open NPU device, err: 0x%x, uri %s\n", get_name(), err, device_lib_uri.c_str());
LOG_ERROR("[%s]Unable to open NPU device, err: 0x%x, uri %s\n", get_name(), err,
device_lib_uri.c_str());
_device_handle = 0;
return false;
}
@ -289,26 +284,16 @@ bool npu_device::supports_op(const ggml_tensor * op) {
if (op->op != GGML_OP_NONE && op->op != GGML_OP_VIEW && op->op != GGML_OP_RESHAPE &&
op->op != GGML_OP_PERMUTE) {
_supported_op++;
LOG_DEBUG("[%s][%s][%s]supported, %s, supported/unsupported: %u/%u\n",
get_name(),
ggml_op_desc(op),
ggml_get_name(op),
op_desc,
_supported_op.load(),
_unsupported_op.load());
LOG_DEBUG("[%s][%s][%s]supported, %s, supported/unsupported: %u/%u\n", get_name(), ggml_op_desc(op),
ggml_get_name(op), op_desc, _supported_op.load(), _unsupported_op.load());
}
return true;
}
_unsupported_op++;
LOG_DEBUG("[%s][%s][%s]unsupported, %s, supported/unsupported: %u/%u\n",
get_name(),
ggml_op_desc(op),
ggml_get_name(op),
op_desc,
_supported_op.load(),
_unsupported_op.load());
LOG_DEBUG("[%s][%s][%s]unsupported, %s, supported/unsupported: %u/%u\n", get_name(), ggml_op_desc(op),
ggml_get_name(op), op_desc, _supported_op.load(), _unsupported_op.load());
return false;
}
#else