feat: perf opt dma phase2 (#57)

* Add power management utilities to NPU device context and update DCVS settings

* Update DCVS settings in power_utils to use v3 API and enhance power management

* wip

* Enhance dequantization functions by adding load_dequant_table support and updating signatures for improved performance

* use lut

* wip

* fix test failure

* wip

* Refactor load_qual_block_generic to improve block handling and optimize vector operations

* Enhance load_dual_block_generic and load_qual_block_generic to accept a mask parameter for improved block handling

* Refactor flash_attn_impl to optimize mask l2 prefetch

* wip

* wip

* wip

* wip

* add log

* link against shared libraries instead of static ones

* fix swiglu

* wip

* refactor expf_fix to handle overflow for different data types

* enhance is_glu_op_supported to validate shapes for multiple sources

* wip

* refactor logging macros to use hexagon namespace and improve formatting

* fix printf format error

* wip

* refactor: update static_assert messages for block size validation and add HVX_VectorPred_x3 type alias

* rename

* feat: enhance flash attention with mask

* wip

* wip

* refactor: replace instances of Q6_V_vzero() with kZeroV for consistency

* wip

* wip

* wip

* fix: improve address alignment check in HVX_Vector handling

* refactor: streamline vector dot product implementations for improved readability

* refactor: add HVX intrinsic implementation for q4_K

* refactor: enhance dequantize_row_q4_K for clarity and performance

* refactor: optimize scale mask usage in dequantization functions for improved performance

* refactor: optimize dequantize_row_q4_K for intrinsic usage and performance improvements

* refactor: move GLU operation implementation into separated file

* sync after swiglu

* wip

* wip

* wip

* feat: increase prc main thread stack size

* fix: replace hardcoded stack size with NPU_THREAD_STACK_SIZE constant

* wip

* feat: add optimized vector operations for exponential and division with overflow handling

* wip

* feat: refactor exponential function to handle overflow and underflow with improved logic
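
  A rough scalar illustration of the idea (not the HVX code in this change; the function name and the
  clamp constants below are an assumed sketch): the input is clamped to the range where 32-bit float
  exp() stays finite and non-zero, so out-of-range arguments can no longer produce +inf or flush to 0.

      #include <algorithm>
      #include <cmath>

      inline float expf_fix_scalar(float x) {
          constexpr float kMaxExpArg = 88.72283f;   // ~ln(FLT_MAX): larger inputs overflow to +inf
          constexpr float kMinExpArg = -87.33654f;  // ~ln(FLT_MIN): smaller inputs underflow toward 0
          // clamp first, then exponentiate, so the result is always a finite float
          return std::exp(std::clamp(x, kMinExpArg, kMaxExpArg));
      }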

* wip

* wip

* feat: add vector loading and scaling functions for improved performance in block processing

* wip

* feat: optimize block loading by refactoring scale index handling for improved performance

* use Q6_Vb_vlut32_VbVbR_nomatch instead

* feat: enhance scale loading by adding static assertion and restructuring block handling

* wip

* feat: refactor vec_dot_product_mixed_impl for improved clarity and performance

* wip

* feat: simplify vector loading functions and improve alignment handling

* wip

* feat: enhance scale loading mask with quantization block size validation

* wip

* feat: implement make_scale_load_mask function and refactor vector handling in vec_ops

* feat: enhance load_dual_block_generic to include scale indices for improved vector loading

* revert q8 dequant

* wip

* feat: optimize dequantization functions by removing unnecessary masking and updating lookup methods

* wip

* wip

* add qurt_mutex

* Add DMA transfer class and integrate into thread pool

* Enhance DMA transfer functionality by adding support for multiple descriptors and initiating transfers in parallel

* fix dma crash

* fix failed unit tests

* wip

* use alignas

* Improve DMA transfer error handling and update descriptor completion check

* Fix VTCM cache size calculation in element-wise operations

* Add cache clean operations before DMA transfers in element-wise operations

* reduce cache clean operations

* Refactor DMA transfer functions to support 1D operations and rename for clarity

* Enhance DMA transfer functionality by adding 2D submission support and improving descriptor initialization

* Update read buffer method to support forced invalidation and remove unnecessary invalidation calls in element-wise operations

* wip

* Improve DMA transfer handling in mul_mat_gemv_impl by replacing memcpy with initiate_dma_row_transfer and adding wait_for_dma logic

* fix 2d dma

* feat: add DMA plane cache (see the double-buffering sketch after this list)

* rename

* wip

* use memcpy for debug

* fix cache plane calc

* refactor: remove debug logging from mul_mat_impl and optimize cache handling

* rename

* fix 2d dma type

* refactor: enhance DMA transfer handling in mul_mat_gemv_impl and wait functions

* refactor: optimize DMA transfer handling in mul_mat_gemv_impl and wait functions

* wip

* wip

* move op impl into sub dir

* add log

* fix: correct pointer usage in mul_mat_gemv_impl for next plane access

* fix: improve DMA transfer error handling in mul_mat_impl and mul_mat_gemv_impl

* fix: fix crash by using the entire row bytes

* wip

* wip

* fix: prevent parallelization for scalar src1 in is_mul_mat_supported

* fix: add dimension checks for 2D DMA transfers and fallback to 1D if necessary

* wip

* fix: enable thread barrier for matrix multiplication operations

* feat: add synchronization checks for tensor operations and update related functions

* wip

* fix: remove invalidation flag from get_read_buffer calls in element-wise and matrix multiplication operations

* Revert "fix: remove invalidation flag from get_read_buffer calls in element-wise and matrix multiplication operations"

This reverts commit af3441e67e706b2e5122369dc160353796867dd3.

* wip

* wip

* add comment

* fix: improve DMA transfer handling in mul_mat_gemv_impl for quantized source tensors

* add log

* try fix mulmat gemv

* wip

* fix: enhance DMA transfer handling in mul_mat_gemv_impl for quantized source tensors

* fix: optimize cache offset calculation and remove redundant swap in mul_mat_gemv_impl

* fix: refactor DMA transfer handling in mul_mat_gemv_impl for improved clarity and maintainability

* wip

* wip

* wip

* fix: enhance mul_mat_impl for improved cache handling and clarity

* fix: refactor tensor unflattening and DMA transfer initialization for improved clarity and type safety

* fix: improve cache handling of quant

* wip

* fix: improve cache handling in mul_mat_impl and mul_mat_gemv_impl for better memory efficiency

* rename

* add load_hexa_block_generic

* wip

* extract dequant block into separated function

* refactor: enhance dequantization functions with table parameter

* fix load_dual_block_generic

* refactor: rename dequantization functions for clarity and enhance block handling

* refactor: simplify dequantization logic by consolidating block handling and removing unused parameters

* wip

* wip
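
A minimal sketch of the double-buffered ("ping-pong") VTCM plane cache that the DMA commits above
converge on. This is simplified and hypothetical: it reuses the compute_params helpers that appear in
the diff below (get_vtcm_cache, initiate_dma_plane_transfer, wait_for_dma), but drops the quantized
source, stride, and multi-plane indexing details handled by the real mul_mat paths.

    #include <cstddef>
    #include <cstdint>
    #include <utility>

    // Overlap DMA of the next src0 plane with compute on the current one.
    static void process_planes(hexagon::compute_params * params,
                               const uint8_t * src, size_t plane_bytes, size_t plane_count,
                               size_t width, size_t height, size_t stride) {
        uint8_t * read_cache = params->get_vtcm_cache(plane_bytes * 2);  // two halves: read + write
        if (!read_cache) {
            return;  // fall back to the non-cached path
        }
        uint8_t * write_cache = read_cache + plane_bytes;

        // prefetch the first plane into the write half
        params->initiate_dma_plane_transfer(src, write_cache, width, height, stride, stride);

        for (size_t ip = 0; ip < plane_count; ++ip) {
            std::swap(read_cache, write_cache);  // the plane just transferred becomes the read half
            params->wait_for_dma();              // make sure that transfer has completed

            if (ip + 1 < plane_count) {          // kick off the next transfer before computing
                params->initiate_dma_plane_transfer(src + (ip + 1) * plane_bytes, write_cache,
                                                    width, height, stride, stride);
            }

            // ... dequantize + batched_row_dot on read_cache while the DMA engine fills write_cache ...
        }
    }
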
Authored by nullname on 2025-10-05 22:56:08 +08:00, committed by GitHub
parent 3994a9b7df · commit e1727af06c
4 changed files with 313 additions and 223 deletions


@@ -12,7 +12,7 @@ dma_transfer::dma_transfer() {
     dma_desc_set_next(_dma_1d_desc0, 0);
     dma_desc_set_dstate(_dma_1d_desc0, DESC_DSTATE_INCOMPLETE);
     dma_desc_set_desctype(_dma_1d_desc0, DMA_DESC_TYPE_1D);
-    dma_desc_set_order(_dma_1d_desc0, DESC_ORDER_ORDER);
+    dma_desc_set_order(_dma_1d_desc0, DESC_ORDER_NOORDER);
     dma_desc_set_bypasssrc(_dma_1d_desc0, DESC_BYPASS_ON);   // for dram
     dma_desc_set_bypassdst(_dma_1d_desc0, DESC_BYPASS_OFF);  // for vtcm
     dma_desc_set_length(_dma_1d_desc0, 0);
@@ -20,7 +20,7 @@ dma_transfer::dma_transfer() {
     dma_desc_set_next(_dma_1d_desc1, 0);
     dma_desc_set_dstate(_dma_1d_desc1, DESC_DSTATE_INCOMPLETE);
     dma_desc_set_desctype(_dma_1d_desc1, DMA_DESC_TYPE_1D);
-    dma_desc_set_order(_dma_1d_desc1, DESC_ORDER_ORDER);
+    dma_desc_set_order(_dma_1d_desc1, DESC_ORDER_NOORDER);
     dma_desc_set_bypasssrc(_dma_1d_desc1, DESC_BYPASS_ON);   // for dram
     dma_desc_set_bypassdst(_dma_1d_desc1, DESC_BYPASS_OFF);  // for vtcm
     dma_desc_set_length(_dma_1d_desc1, 0);
@@ -28,7 +28,7 @@ dma_transfer::dma_transfer() {
     dma_desc_set_next(_dma_2d_desc0, 0);
     dma_desc_set_dstate(_dma_2d_desc0, DESC_DSTATE_INCOMPLETE);
     dma_desc_set_desctype(_dma_2d_desc0, DMA_DESC_TYPE_2D);
-    dma_desc_set_order(_dma_2d_desc0, DESC_ORDER_ORDER);
+    dma_desc_set_order(_dma_2d_desc0, DESC_ORDER_NOORDER);
     dma_desc_set_bypasssrc(_dma_2d_desc0, DESC_BYPASS_ON);   // for dram
     dma_desc_set_bypassdst(_dma_2d_desc0, DESC_BYPASS_OFF);  // for vtcm
     dma_desc_set_cachealloc(_dma_2d_desc0, DESC_CACHEALLOC_NONE);


@@ -20,6 +20,12 @@ template <> struct convert_vector<float> {
     static float convert(HVX_Vector vec) { return hexagon::get_flt0_from_fltv(Q6_Vsf_equals_Vqf32(vec)); }
 };
 
+inline std::pair<int64_t, int64_t> unflatten_i3_i2(int64_t idx, const hexagon::tensor * t) {
+    const auto i3 = idx / t->get_ne(2);
+    const auto i2 = idx - i3 * t->get_ne(2);
+    return { i3, i2 };
+}
+
 template <> struct convert_vector<npu_device_fp16_t> {
     static float convert(HVX_Vector vec) {
         HVX_Vector vect = Q6_Vhf_equals_Vqf16(vec);
@@ -28,6 +34,27 @@ template <> struct convert_vector<npu_device_fp16_t> {
     }
 };
 
+template <bool _IsQuantized>
+inline bool init_dma_transfer(hexagon::compute_params * params,
+                              const uint8_t *           src,
+                              uint8_t *                 dst,
+                              size_t                    width,
+                              size_t                    height,
+                              size_t                    src_stride,
+                              size_t                    dst_stride) {
+    if constexpr (_IsQuantized) {
+        if (!params->initiate_dma_row_transfer(src, dst, src_stride * height)) {
+            return false;
+        }
+    } else {
+        if (!params->initiate_dma_plane_transfer(src, dst, width, height, src_stride, dst_stride)) {
+            return false;
+        }
+    }
+
+    return true;
+}
+
 template <auto _DotFunc>
 inline void batched_row_dot(const uint8_t * src0_plane,
                             const size_t    src0_ne0,
@@ -75,9 +102,10 @@ inline void mul_mat_impl(hexagon::tensor * src0,
                          hexagon::tensor *         src1,
                          hexagon::tensor *         dst,
                          hexagon::compute_params * params) {
+    using data_type0 = typename get_data_type<decltype(_DotFunc)>::data_type0;
     using data_type1 = typename get_data_type<decltype(_DotFunc)>::data_type1;
 
-    const auto src0_actual_row_size   = hexagon::get_dequantized_row_size(src0);
+    const auto src0_actual_row_stride = hexagon::get_dequantized_row_size(src0);
     auto * dequantize_row_func     = hexagon::get_type_traits(src0->get_type()).to_float;
     auto * load_dequant_table_func = hexagon::get_type_traits(src0->get_type()).load_dequant_table;
     if (_IsSrcQuantized && dequantize_row_func == nullptr) {
@@ -113,66 +141,70 @@ inline void mul_mat_impl(hexagon::tensor * src0,
     const uint8_t * src0_ptr = src0->get_read_buffer(true);  // TODO: avoid invalidation
 
     // cache the src0 plane in VTCM
-    size_t    src0_plane_slice_row_count = start_end_element.second - start_end_element.first;
-    size_t    src0_plane_cache_size      = 0;
-    uint8_t * src0_plane_read_cache_ptr  = nullptr;
-    uint8_t * src0_plane_write_cache_ptr = nullptr;
-    const uint8_t * last_write_cached_plane_ptr = nullptr;
-    const uint8_t * last_read_cached_plane_ptr  = nullptr;
-    if constexpr (_IsSrcQuantized) {
-        src0_plane_slice_row_count =
-            std::min(params->get_vtcm_quota_size() / src0_actual_row_size, src0_plane_slice_row_count);
-        src0_plane_cache_size     = src0_actual_row_size * src0_plane_slice_row_count;
-        src0_plane_read_cache_ptr = params->get_vtcm_cache(src0_plane_cache_size);
-        if (src0_plane_read_cache_ptr == nullptr) {
+    const size_t valid_src0_row_bytes   = _IsSrcQuantized ? src0->get_nb(1) : (src0->get_ne(0) * sizeof(data_type0));
+    const size_t src1_actual_row_stride = hexagon::get_aligned_size(src1->get_nb(1));
+
+    // TODO: figure out why we have to add padding after src0 plane cache
+    const size_t src0_plane_slice_row_count =
+        std::min<size_t>((params->get_vtcm_quota_size() - src1_actual_row_stride) / (src0_actual_row_stride * 2),
+                         start_end_element.second - start_end_element.first);
+    uint8_t *       src0_plane_read_cache_ptr     = nullptr;
+    uint8_t *       src0_plane_write_cache_ptr    = nullptr;
+    size_t          src0_plane_write_cache_offset = 0;
+    const uint8_t * last_write_cached_plane_ptr   = nullptr;
+    const uint8_t * last_read_cached_plane_ptr    = nullptr;
+    {
+        const size_t src0_plane_cache_size = src0_actual_row_stride * src0_plane_slice_row_count;
+        src0_plane_read_cache_ptr          = params->get_vtcm_cache(src0_plane_cache_size * 2);
+        if (!src0_plane_read_cache_ptr) {
             DEVICE_LOG_ERROR(
                 "mul_mat_impl: failed to get VTCM cache for src0, size: %zu, src0_plane_slice_row_count: %zu, "
-                "src0_actual_row_size: %zu, will fallback to mem cache\n",
-                src0_plane_cache_size, src0_plane_slice_row_count, src0_actual_row_size);
-            return;
-        }
-    } else {
-        src0_plane_slice_row_count =
-            std::min(params->get_vtcm_quota_size() / (src0_actual_row_size * 2), src0_plane_slice_row_count);
-        src0_plane_cache_size     = src0_actual_row_size * src0_plane_slice_row_count;
-        src0_plane_read_cache_ptr = params->get_vtcm_cache(src0_plane_cache_size * 2);
-        if (src0_plane_read_cache_ptr == nullptr) {
-            DEVICE_LOG_ERROR(
-                "mul_mat_impl: failed to get VTCM cache for src0, size: %zu, src0_plane_slice_row_count: %zu, "
-                "src0_actual_row_size: %zu, will fallback to mem cache\n",
-                src0_plane_cache_size, src0_plane_slice_row_count, src0_actual_row_size);
+                "src0_actual_row_stride: %zu, will fallback to mem cache\n",
+                src0_plane_cache_size, src0_plane_slice_row_count, src0_actual_row_stride);
             return;
         }
 
         src0_plane_write_cache_ptr = src0_plane_read_cache_ptr + src0_plane_cache_size;
+        if constexpr (_IsSrcQuantized) {
+            src0_plane_write_cache_offset =
+                src0_plane_cache_size - size_t(src0->get_nb(1) * src0_plane_slice_row_count);
+        }
 
-        const auto i3 = start_end_plane.first / dst->get_ne(2);
-        const auto i2 = start_end_plane.first - i3 * dst->get_ne(2);
+        DEVICE_LOG_DEBUG(
+            "[%d]mul_mat_impl, src0_actual_row_stride:%zu, valid_src0_row_bytes:%zu, src_nb0:%zu, "
+            "slice_row_count:%zu, write_cache_offset: %zu, "
+            "total_planes:%lld, planes:[%d,%d), rows:[%d,%d), elems:[%d,%d), is_quant:%d, "
+            "vtcm_mem:%p(%zu)\n",
+            (int) params->get_thread_index(), src0_actual_row_stride, valid_src0_row_bytes, (size_t) src0->get_nb(1),
+            src0_plane_slice_row_count, src0_plane_write_cache_offset, total_planes, (int) start_end_plane.first,
+            (int) start_end_plane.second, (int) start_end_row.first, (int) start_end_row.second,
+            (int) start_end_element.first, (int) start_end_element.second, _IsSrcQuantized,
+            (void *) src0_plane_read_cache_ptr, params->get_vtcm_quota_size());
+    }
+
+    {
+        const auto [i3, i2] = unflatten_i3_i2(start_end_plane.first, dst);
         const uint8_t * src0_plane = src0_ptr + i3 / r03 * src0->get_nb(3) + i2 / r02 * src0->get_nb(2) +
                                      start_end_element.first * src0->get_nb(1);
-        const int64_t next_row_count =
-            std::min<int64_t>(src0_plane_slice_row_count,
-                              start_end_element.second - start_end_element.first);  // number of rows in this slice
-        if (!params->initiate_dma_plane_transfer(src0_plane, src0_plane_write_cache_ptr,
-                                                 src0_actual_row_size,  // TODO: reduce to aligned valid_row0_bytes?
-                                                 next_row_count, src0_actual_row_size, src0_actual_row_size)) {
-            DEVICE_LOG_ERROR("mul_mat_impl: failed to continue dma transfer for src0 plane\n");
+        const size_t next_row_count =
+            std::min<size_t>(src0_plane_slice_row_count,
+                             start_end_element.second - start_end_element.first);  // number of rows in this slice
+        if (!init_dma_transfer<_IsSrcQuantized>(
+                params, src0_plane, src0_plane_write_cache_ptr + src0_plane_write_cache_offset, valid_src0_row_bytes,
+                next_row_count, src0->get_nb(1), src0->get_nb(1))) {
+            DEVICE_LOG_ERROR("mul_mat_impl: failed to continue dma transfer for src0 plane, is_quant: %d\n",
+                             (int) _IsSrcQuantized);
             return;
         }
 
+        DEVICE_LOG_DEBUG("mul_mat_impl: [i2,i3]:[%d,%d], src0_plane:%p, row_count:%zu\n", (int) i2, (int) i3,
+                         (void *) src0_plane, next_row_count);
         last_write_cached_plane_ptr = src0_plane;
     }
 
-    DEVICE_LOG_DEBUG(
-        "[%d]mul_mat_impl src0_actual_row_size: %zu, src0_plane_slice_row_count: %zu, total_planes: %lld, "
-        "start_end_plane: "
-        "[%d,%d), start_end_row: [%d,%d), start_end_element: [%d,%d), is_quantized: %d, vtcm_mem: %p(%zu)\n",
-        (int) params->get_thread_index(), src0_actual_row_size, src0_plane_slice_row_count, total_planes,
-        (int) start_end_plane.first, (int) start_end_plane.second, (int) start_end_row.first,
-        (int) start_end_row.second, (int) start_end_element.first, (int) start_end_element.second, _IsSrcQuantized,
-        (void *) src0_plane_read_cache_ptr, params->get_vtcm_quota_size());
-
     const size_t valid_row1_bytes =
         src0->get_ne(0) * sizeof(data_type1);  // src0 and src1 should have the same element count in the 1st dimension
     DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_WITH_MULTI_SUB_PROC(dst, params->get_thread_index(), mul_mat);
@@ -184,11 +216,10 @@ inline void mul_mat_impl(hexagon::tensor * src0,
         return;
     }
 
-    auto       dequant_table = load_dequant_table_func ? load_dequant_table_func() : HVX_Vector();
+    const auto dequant_table = load_dequant_table_func ? load_dequant_table_func() : HVX_Vector();
     const uint8_t * src1_ptr = src1->get_read_buffer();
     for (int64_t ip = start_end_plane.first; ip < start_end_plane.second; ip++) {
-        const auto i3 = ip / dst->get_ne(2);
-        const auto i2 = ip - i3 * dst->get_ne(2);
+        const auto [i3, i2] = unflatten_i3_i2(ip, dst);
         const auto * src1_plane = src1_ptr + i3 * src1->get_nb(3) + i2 * src1->get_nb(2);
         auto *       dst_plane  = dst_ptr + i3 * dst->get_nb(3) + i2 * dst->get_nb(2);
         const uint8_t * src0_plane_base = src0_ptr + i3 / r03 * src0->get_nb(3) + i2 / r02 * src0->get_nb(2);
@@ -198,57 +229,37 @@ inline void mul_mat_impl(hexagon::tensor * src0,
             const int64_t actual_row_count =
                 std::min<int64_t>(src0_plane_slice_row_count,
                                   start_end_element.second - col_idx);  // number of rows in this slice
-            if constexpr (_IsSrcQuantized) {
-                if (last_write_cached_plane_ptr != src0_plane) {
-                    DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 0, dequant);
-                    for (int64_t ir = 0; ir < actual_row_count; ir++) {
-                        auto * src0_row = src0_plane + ir * src0->get_nb(1);
-                        if (ir + 1 < actual_row_count) {
-                            hexagon::l2fetch_row(src0_row + src0->get_nb(1), src0->get_nb(1));
-                        }
-                        auto * cached_row_ptr = src0_plane_read_cache_ptr + ir * src0_actual_row_size;
-                        dequantize_row_func(src0_row, reinterpret_cast<hexagon::dequant_output_type *>(cached_row_ptr),
-                                            src0->get_ne(0), dequant_table);
-                    }
-                    last_write_cached_plane_ptr = src0_plane;
-                }
-            } else {
-                if (last_read_cached_plane_ptr != src0_plane) {
-                    std::swap(src0_plane_read_cache_ptr, src0_plane_write_cache_ptr);
-                    last_read_cached_plane_ptr = src0_plane;
-                    params->wait_for_dma();
-                }
 
+            {
                 const uint8_t * src0_next_plane = last_write_cached_plane_ptr;
-                int64_t         next_row_count  = 0;
+                size_t          next_row_count  = 0;
                 if (col_idx + src0_plane_slice_row_count < start_end_element.second) {
                     const auto next_col_idx = col_idx + src0_plane_slice_row_count;
-                    src0_next_plane = src0_plane_base + next_col_idx * src0_actual_row_size;
+                    src0_next_plane = src0_plane_base + next_col_idx * src0->get_nb(1);
                     next_row_count =
-                        std::min<int64_t>(src0_plane_slice_row_count,
+                        std::min<size_t>(src0_plane_slice_row_count,
                                           start_end_element.second - next_col_idx);  // number of rows in this slice
                 } else if (ip + 1 < start_end_plane.second) {
                     // prefetch the next plane's first slice
-                    const auto ip_next = ip + 1;
-                    const auto i3_next = ip_next / dst->get_ne(2);
-                    const auto i2_next = ip_next - i3_next * dst->get_ne(2);
+                    const auto [i3_next, i2_next] = unflatten_i3_i2(ip + 1, dst);
                     const uint8_t * src0_next_plane_base =
                         src0_ptr + i3_next / r03 * src0->get_nb(3) + i2_next / r02 * src0->get_nb(2);
-                    src0_next_plane = src0_next_plane_base + start_end_element.first * src0_actual_row_size;
-                    next_row_count  = std::min<int64_t>(
+                    src0_next_plane = src0_next_plane_base + start_end_element.first * src0->get_nb(1);
+                    next_row_count  = std::min<size_t>(
                         src0_plane_slice_row_count,
                         start_end_element.second - start_end_element.first);  // number of rows in this slice
                 }
 
+                if (last_read_cached_plane_ptr != src0_plane) {
+                    std::swap(src0_plane_read_cache_ptr, src0_plane_write_cache_ptr);
+                    params->wait_for_dma();
+                }
+
                 if (last_write_cached_plane_ptr != src0_next_plane) {
-                    DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 0, dma);
-                    if (!params->initiate_dma_plane_transfer(
-                            src0_next_plane, src0_plane_write_cache_ptr,
-                            src0_actual_row_size,  // TODO: reduce to aligned valid_row0_bytes?
-                            next_row_count, src0_actual_row_size, src0_actual_row_size)) {
+                    DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 2, dma);
+                    if (!init_dma_transfer<_IsSrcQuantized>(
+                            params, src0_next_plane, src0_plane_write_cache_ptr + src0_plane_write_cache_offset,
+                            valid_src0_row_bytes, next_row_count, src0->get_nb(1), src0->get_nb(1))) {
                         DEVICE_LOG_ERROR("mul_mat_impl: failed to continue dma transfer for src0 plane\n");
                         return;
                     }
@@ -257,15 +268,30 @@ inline void mul_mat_impl(hexagon::tensor * src0,
                 }
             }
 
+            if constexpr (_IsSrcQuantized) {
+                if (last_read_cached_plane_ptr != src0_plane) {
+                    DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 0, dequant);
+                    const uint8_t * src0_quant_plane = src0_plane_read_cache_ptr + src0_plane_write_cache_offset;
+                    for (int64_t ir = 0; ir < actual_row_count; ir++) {
+                        auto * src0_row       = src0_quant_plane + ir * src0->get_nb(1);
+                        auto * cached_row_ptr = src0_plane_read_cache_ptr + ir * src0_actual_row_stride;
+                        dequantize_row_func(src0_row, reinterpret_cast<hexagon::dequant_output_type *>(cached_row_ptr),
+                                            src0->get_ne(0), dequant_table);
+                    }
+                }
+            }
+
+            last_read_cached_plane_ptr = src0_plane;
+
             if (start_end_row.second > start_end_row.first) {
                 hexagon::l2fetch_row(src1_plane + start_end_row.first * src1->get_nb(1), valid_row1_bytes);
             }
 
             for (int64_t i1 = start_end_row.first; i1 < start_end_row.second; i1++) {
-                DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 1, vec_dot);
+                DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 1, dot);
                 auto * src1_row = src1_plane + i1 * src1->get_nb(1);
                 auto * dst_row  = reinterpret_cast<float *>(dst_plane + i1 * dst->get_nb(1)) + col_idx;
-                batched_row_dot<_DotFunc>(src0_plane_read_cache_ptr, src0->get_ne(0), src0_actual_row_size, src1_row,
+                batched_row_dot<_DotFunc>(src0_plane_read_cache_ptr, src0->get_ne(0), src0_actual_row_stride, src1_row,
                                           src1->get_nb(1), dst_row, actual_row_count,
                                           (ip + 1 < start_end_plane.second) ? valid_row1_bytes : 0);
             }
@@ -283,7 +309,7 @@ inline void mul_mat_gemv_impl(hexagon::tensor * src0,
     using data_type0 = typename get_data_type<decltype(_DotFunc)>::data_type0;
     using data_type1 = typename get_data_type<decltype(_DotFunc)>::data_type1;
 
-    const auto src0_actual_row_size   = hexagon::get_dequantized_row_size(src0);
+    const auto src0_actual_row_stride = hexagon::get_dequantized_row_size(src0);
     auto * dequantize_row_func     = hexagon::get_type_traits(src0->get_type()).to_float;
     auto * load_dequant_table_func = hexagon::get_type_traits(src0->get_type()).load_dequant_table;
     if (_IsSrcQuantized && dequantize_row_func == nullptr) {
@@ -309,53 +335,46 @@ inline void mul_mat_gemv_impl(hexagon::tensor * src0,
         return;
     }
 
     const uint8_t * src0_ptr = src0->get_read_buffer(true);  // TODO: avoid invalidation
-    const size_t valid_row0_bytes = src0->get_ne(0) * sizeof(data_type0);
+    const size_t valid_src0_row_bytes = _IsSrcQuantized ? src0->get_nb(1) : (src0->get_ne(0) * sizeof(data_type0));
 
     // cache the src0 plane in VTCM
-    size_t    src0_plane_slice_row_count = start_end_element.second - start_end_element.first;
-    size_t    src0_plane_cache_size      = 0;
-    uint8_t * src0_plane_read_cache_ptr  = nullptr;
-    uint8_t * src0_plane_write_cache_ptr = nullptr;
-    const auto src1_actual_row_size = hexagon::get_aligned_size(src1->get_nb(1));
-    uint8_t *  src1_row_cache_ptr   = nullptr;
-    if constexpr (_IsSrcQuantized) {
-        src0_plane_slice_row_count = std::min(
-            (params->get_vtcm_quota_size() - src1_actual_row_size) / src0_actual_row_size, src0_plane_slice_row_count);
-        src0_plane_cache_size     = src0_actual_row_size * src0_plane_slice_row_count;
-        src0_plane_read_cache_ptr = params->get_vtcm_cache(src0_plane_cache_size + src1_actual_row_size);
-        if (src0_plane_read_cache_ptr == nullptr) {
+    const size_t src1_actual_row_stride = hexagon::get_aligned_size(src1->get_nb(1));
+    const size_t src0_plane_slice_row_count =
+        std::min<size_t>((params->get_vtcm_quota_size() - src1_actual_row_stride) / (src0_actual_row_stride * 2),
+                         start_end_element.second - start_end_element.first);
+
+    uint8_t * src0_plane_read_cache_ptr     = nullptr;
+    uint8_t * src0_plane_write_cache_ptr    = nullptr;
+    size_t    src0_plane_write_cache_offset = 0;
+    uint8_t * src1_row_cache_ptr            = nullptr;
+
+    DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_WITH_MULTI_SUB_PROC(dst, params->get_thread_index(), mul_mat);
+    {
+        const size_t src0_plane_cache_size = src0_actual_row_stride * src0_plane_slice_row_count;
+        src0_plane_read_cache_ptr = params->get_vtcm_cache(src0_plane_cache_size * 2 + src1_actual_row_stride);
+        if (!src0_plane_read_cache_ptr) {
             DEVICE_LOG_ERROR(
                 "mul_mat_gemv_impl: failed to get VTCM cache for src0, size: %zu, src0_plane_slice_row_count: %zu, "
-                "src0_actual_row_size: %zu, will fallback to mem cache\n",
-                src0_plane_cache_size, src0_plane_slice_row_count, src0_actual_row_size);
-            return;
-        }
-        src1_row_cache_ptr = src0_plane_read_cache_ptr + src0_plane_cache_size;
-    } else {
-        src0_plane_slice_row_count =
-            std::min((params->get_vtcm_quota_size() - src1_actual_row_size) / (src0_actual_row_size * 2),
-                     src0_plane_slice_row_count);
-        src0_plane_cache_size     = src0_actual_row_size * src0_plane_slice_row_count;
-        src0_plane_read_cache_ptr = params->get_vtcm_cache(src0_plane_cache_size * 2 + src1_actual_row_size);
-        if (src0_plane_read_cache_ptr == nullptr) {
-            DEVICE_LOG_ERROR("mul_mat_gemv_impl: failed to get VTCM cache for src1, size: %zu\n", src1_actual_row_size);
+                "src0_actual_row_stride: %zu, will fallback to mem cache\n",
+                src0_plane_cache_size, src0_plane_slice_row_count, src0_actual_row_stride);
             return;
         }
 
         src0_plane_write_cache_ptr = src0_plane_read_cache_ptr + src0_plane_cache_size;
         src1_row_cache_ptr         = src0_plane_write_cache_ptr + src0_plane_cache_size;
+        if constexpr (_IsSrcQuantized) {
+            src0_plane_write_cache_offset = src0_plane_cache_size - (src0->get_nb(1) * src0_plane_slice_row_count);
+        }
+
+        DEVICE_LOG_DEBUG(
+            "mul_mat_gemv_impl: src0_actual_row_stride: %zu, src0_plane_slice_row_count: %zu, "
+            "src0_plane_write_cache_offset: %zu, src0.nb[1]: %d, is_quantized: %d, vtcm_mem: %p(%zu)\n",
+            src0_actual_row_stride, src0_plane_slice_row_count, src0_plane_write_cache_offset, int(src0->get_nb(1)),
+            _IsSrcQuantized, (void *) src0_plane_read_cache_ptr, src0_plane_cache_size);
     }
 
-    DEVICE_LOG_DEBUG(
-        "mul_mat_gemv_impl: src0_actual_row_size: %zu, src0_plane_slice_row_count: %zu, is_quantized: %d, vtcm_mem: "
-        "%p(%zu)\n",
-        src0_actual_row_size, src0_plane_slice_row_count, _IsSrcQuantized, (void *) src0_plane_read_cache_ptr,
-        src0_plane_cache_size);
-
-    DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_WITH_MULTI_SUB_PROC(dst, params->get_thread_index(), mul_mat);
     uint8_t * dst_ptr = dst->get_write_buffer();
     if (!dst_ptr) {
         DEVICE_LOG_ERROR("mul_mat_gemv_impl: dst_ptr is not writable, tensor: %p, type: %s\n", (void *) dst,
@@ -372,66 +391,62 @@ inline void mul_mat_gemv_impl(hexagon::tensor * src0,
             return;
         }
 
-        if constexpr (!_IsSrcQuantized) {
-            const uint8_t * src0_plane = src0_ptr + start_end_element.first * src0_actual_row_size;
-            const int64_t next_row_count =
-                std::min<int64_t>(src0_plane_slice_row_count,
-                                  start_end_element.second - start_end_element.first);  // number of rows in this slice
-            params->wait_for_dma();
-            if (!params->initiate_dma_plane_transfer(src0_plane, src0_plane_write_cache_ptr, valid_row0_bytes,
-                                                     next_row_count, src0_actual_row_size, src0_actual_row_size)) {
-                DEVICE_LOG_ERROR("mul_mat_gemv_impl: failed to initiate dma transfer for src0 plane\n");
-                return;
-            }
-        } else {
-            params->wait_for_dma();
+        const uint8_t * src0_plane = src0_ptr + start_end_element.first * src0->get_nb(1);
+        const int64_t next_row_count =
+            std::min<int64_t>(src0_plane_slice_row_count,
+                              start_end_element.second - start_end_element.first);  // number of rows in this slice
+        params->wait_for_dma();
+
+        if (!init_dma_transfer<_IsSrcQuantized>(
+                params, src0_plane, src0_plane_write_cache_ptr + src0_plane_write_cache_offset, valid_src0_row_bytes,
+                next_row_count, src0->get_nb(1), src0->get_nb(1))) {
+            DEVICE_LOG_ERROR("mul_mat_gemv_impl: failed to initiate dma plane transfer for src0 plane, is_quant: %d\n",
+                             (int) _IsSrcQuantized);
+            return;
         }
     }
 
     {
         for (int64_t col_idx = start_end_element.first; col_idx < start_end_element.second;
              col_idx += src0_plane_slice_row_count) {
-            const uint8_t * src0_plane = src0_ptr + col_idx * src0->get_nb(1);
             const int64_t actual_row_count =
                 std::min<int64_t>(src0_plane_slice_row_count,
                                   start_end_element.second - col_idx);  // number of rows in this slice
 
+            const auto next_col_idx = col_idx + src0_plane_slice_row_count;
+            std::swap(src0_plane_read_cache_ptr, src0_plane_write_cache_ptr);
+            params->wait_for_dma();
+            if (next_col_idx < start_end_element.second) {
+                DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 2, dma);
+                const uint8_t * src0_next_plane = src0_ptr + next_col_idx * src0->get_nb(1);
+                const int64_t next_row_count =
+                    std::min<int64_t>(src0_plane_slice_row_count,
+                                      start_end_element.second - next_col_idx);  // number of rows in this slice
+                if (!init_dma_transfer<_IsSrcQuantized>(
+                        params, src0_next_plane, src0_plane_write_cache_ptr + src0_plane_write_cache_offset,
+                        valid_src0_row_bytes, next_row_count, src0->get_nb(1), src0->get_nb(1))) {
+                    DEVICE_LOG_ERROR(
+                        "mul_mat_gemv_impl: failed to continue dma plane transfer for src0 plane, is_quant: %d\n",
+                        (int) _IsSrcQuantized);
+                    return;
+                }
+            }
+
             if constexpr (_IsSrcQuantized) {
                 DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 0, dequant);
+                const uint8_t * src0_quant_plane = src0_plane_read_cache_ptr + src0_plane_write_cache_offset;
                 for (int64_t ir = 0; ir < actual_row_count; ir++) {
-                    auto * src0_row = src0_plane + ir * src0->get_nb(1);
-                    if (ir + 1 < actual_row_count) {
-                        hexagon::l2fetch_row(src0_row + src0->get_nb(1), src0->get_nb(1));
-                    }
-                    auto * cached_row_ptr = src0_plane_read_cache_ptr + ir * src0_actual_row_size;
+                    auto * src0_row       = src0_quant_plane + ir * src0->get_nb(1);
+                    auto * cached_row_ptr = src0_plane_read_cache_ptr + ir * src0_actual_row_stride;
                     dequantize_row_func(src0_row, reinterpret_cast<hexagon::dequant_output_type *>(cached_row_ptr),
                                         src0->get_ne(0), dequant_table);
                 }
-            } else {
-                DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 0, dma);
-                std::swap(src0_plane_read_cache_ptr, src0_plane_write_cache_ptr);
-                params->wait_for_dma();
-                const auto next_col_idx = col_idx + src0_plane_slice_row_count;
-                if (next_col_idx < start_end_element.second) {
-                    const uint8_t * src0_next_plane = src0_ptr + next_col_idx * src0_actual_row_size;
-                    const int64_t next_row_count =
-                        std::min<int64_t>(src0_plane_slice_row_count,
-                                          start_end_element.second - next_col_idx);  // number of rows in this slice
-                    if (!params->initiate_dma_plane_transfer(src0_next_plane, src0_plane_write_cache_ptr,
-                                                             valid_row0_bytes, next_row_count, src0_actual_row_size,
-                                                             src0_actual_row_size)) {
-                        DEVICE_LOG_ERROR("mul_mat_gemv_impl: failed to continue dma transfer for src0 plane\n");
-                        return;
-                    }
-                }
             }
 
             {
-                DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 1, vec_dot);
+                DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 1, dot);
                 auto * dst_row = reinterpret_cast<float *>(dst_ptr) + col_idx;
-                batched_row_dot<_DotFunc>(src0_plane_read_cache_ptr, src0->get_ne(0), src0_actual_row_size,
+                batched_row_dot<_DotFunc>(src0_plane_read_cache_ptr, src0->get_ne(0), src0_actual_row_stride,
                                           src1_row_cache_ptr, src1->get_nb(1), dst_row, actual_row_count, 0);
             }
         }


@@ -44,7 +44,7 @@ template <typename _TBlock> inline HVX_Vector load_block_generic(const _TBlock &
 }
 
 template <typename _TBlock> inline HVX_Vector make_scale_load_mask() {
-    static_assert(sizeof(_TBlock) < 32, "wrong block size/padding");
+    static_assert(sizeof(_TBlock) + sizeof(npu_device_fp16_t) < 32, "wrong block size/padding");
     static_assert(sizeof(_TBlock::qs) == 16 || sizeof(_TBlock::qs) == 32, "wrong quantization block size");
 
     constexpr const size_t kScaleBlockSize = QUANT_BLOCK_SIZE * sizeof(hexagon::dequant_output_type);
@@ -83,12 +83,13 @@ inline hexagon::HVX_Vector_x2 load_dual_block_generic(const _TBlock * srcs,
     hexagon::HVX_Vector_x2 result;
 
-    HVX_Vector blocks  = load_into_vector<_TBlock, 2, &_TBlock::qs>(srcs);
-    HVX_Vector block1  = Q6_V_vror_VR(blocks, kSizeOfScale);
+    const HVX_Vector blocks = load_struct_into_vector<_TBlock, 2>(srcs);
+    HVX_Vector block0  = Q6_V_vror_VR(blocks, kSizeOfScale);
+    HVX_Vector block1  = Q6_V_vror_VR(blocks, kSizeOfScale * 2);
     HVX_Vector scale01 = Q6_Vb_vshuff_Vb(blocks);
 
-    result.val[0] = Q6_V_vmux_QVV(mask, blocks, block1);
+    result.val[0] = Q6_V_vmux_QVV(mask, block0, block1);
     result.val[1] = Q6_Vb_vlut32_VbVbR_nomatch(scale_indices, scale01, 0);
 
     return result;
@@ -143,6 +144,50 @@ inline hexagon::HVX_Vector_x3 load_qual_block_generic(const _TBlock *
     return result;
 }
 
+template <typename _TBlock>
+inline hexagon::HVX_Vector_x5 load_hexa_block_generic(const _TBlock *                  srcs,
+                                                      const hexagon::HVX_VectorPred_x3 mask,
+                                                      const HVX_Vector                 scale_indices) {
+    static_assert(hexagon::kBytesPerVector >= sizeof(_TBlock) * 6, "wrong block size/padding");
+    constexpr const uint32_t kSizeOfQs    = sizeof(_TBlock::qs);
+    constexpr const uint32_t kSizeOfScale = sizeof(_TBlock) - kSizeOfQs;
+
+    const HVX_Vector blocks = load_struct_into_vector<_TBlock, 6>(srcs);
+
+    hexagon::HVX_Vector_x5 result;
+    {
+        HVX_Vector block0 = Q6_V_vror_VR(blocks, kSizeOfScale);
+        HVX_Vector block1 = Q6_V_vror_VR(blocks, kSizeOfScale * 2);
+        HVX_Vector block2 = Q6_V_vror_VR(blocks, kSizeOfScale * 3);
+        HVX_Vector block3 = Q6_V_vror_VR(blocks, kSizeOfScale * 4);
+        HVX_Vector block4 = Q6_V_vror_VR(blocks, kSizeOfScale + sizeof(_TBlock) * 4);
+        HVX_Vector block5 = Q6_V_vror_VR(blocks, kSizeOfScale * 2 + sizeof(_TBlock) * 4);
+
+        HVX_Vector block01 = Q6_V_vmux_QVV(mask.val[0], block0, block1);
+        HVX_Vector block23 = Q6_V_vmux_QVV(mask.val[1], block2, block3);
+        result.val[0]      = Q6_V_vmux_QVV(mask.val[2], block01, block23);
+        result.val[3]      = Q6_V_vmux_QVV(mask.val[0], block4, block5);  // block45
+    }
+
+    {
+        HVX_Vector scale23 = Q6_V_vror_VR(blocks, sizeof(_TBlock) * 2);
+        HVX_Vector scale45 = Q6_V_vror_VR(blocks, sizeof(_TBlock) * 4);
+        HVX_Vector scale01 = Q6_Vb_vshuff_Vb(blocks);
+        scale23            = Q6_Vb_vshuff_Vb(scale23);
+        scale45            = Q6_Vb_vshuff_Vb(scale45);
+
+        result.val[1] = Q6_Vb_vlut32_VbVbR_nomatch(scale_indices, scale01, 0);
+        result.val[2] = Q6_Vb_vlut32_VbVbR_nomatch(scale_indices, scale23, 0);
+        result.val[4] = Q6_Vb_vlut32_VbVbR_nomatch(scale_indices, scale45, 0);
+    }
+
+    return result;
+}
+
 inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t * d, uint8_t * m) {
     // TODO: use intrinsics
     if (j < 4) {
@@ -441,72 +486,101 @@ void dequantize_row_q8_0(const void * src, hexagon::dequant_output_type * dst, s
     }
 }
 
+template <bool _IsDstAligned>
+inline void dequantize_row_q4_0_2blocks(HVX_Vector                     qs,
+                                        HVX_Vector                     scale01,
+                                        HVX_Vector                     table,
+                                        hexagon::dequant_output_type * dst_ptr) {
+    constexpr const uint32_t kSizeOfQs = sizeof(npu_device_block_q4_0::qs);
+
+    HVX_Vector     q_lo = qs;
+    HVX_Vector     q_hi = Q6_Vub_vlsr_VubR(qs, 4);
+    HVX_VectorPair qp0  = Q6_W_vshuff_VVR(q_hi, q_lo, kSizeOfQs * (1 + 2));
+    q_lo                = Q6_V_lo_W(qp0);
+    q_lo                = Q6_Vb_vshuff_Vb(q_lo);
+
+    qp0  = Q6_Wh_vlut16_VbVhR_nomatch(q_lo, table, 0);
+    q_lo = Q6_Vqf16_vmpy_VhfVhf(Q6_V_lo_W(qp0), scale01);
+    q_lo = Q6_Vhf_equals_Vqf16(q_lo);
+
+    if constexpr (_IsDstAligned) {
+        *reinterpret_cast<HVX_Vector *>(dst_ptr) = q_lo;
+    } else {
+        *reinterpret_cast<HVX_UVector *>(dst_ptr) = q_lo;
+    }
+}
+
+template <bool _IsDstAligned>
+inline void dequantize_row_q4_0_4blocks(HVX_Vector                     qs,
+                                        HVX_Vector                     scale01,
+                                        HVX_Vector                     scale23,
+                                        HVX_Vector                     table,
+                                        hexagon::dequant_output_type * dst_ptr) {
+    constexpr const uint32_t kSizeOfQs = sizeof(npu_device_block_q4_0::qs);
+
+    HVX_Vector     q_lo = qs;
+    HVX_Vector     q_hi = Q6_Vub_vlsr_VubR(qs, 4);
+    HVX_VectorPair qp0  = Q6_W_vshuff_VVR(q_hi, q_lo, kSizeOfQs * (1 + 2 + 4));
+    q_lo                = Q6_V_lo_W(qp0);
+    q_lo                = Q6_Vb_vshuff_Vb(q_lo);
+
+    qp0  = Q6_Wh_vlut16_VbVhR_nomatch(q_lo, table, 0);
+    q_lo = Q6_V_lo_W(qp0);
+    q_hi = Q6_V_hi_W(qp0);
+    q_lo = Q6_Vqf16_vmpy_VhfVhf(q_lo, scale01);
+    q_hi = Q6_Vqf16_vmpy_VhfVhf(q_hi, scale23);
+    q_lo = Q6_Vhf_equals_Vqf16(q_lo);
+    q_hi = Q6_Vhf_equals_Vqf16(q_hi);
+
+    if constexpr (_IsDstAligned) {
+        reinterpret_cast<HVX_Vector *>(dst_ptr)[0] = q_lo;
+        reinterpret_cast<HVX_Vector *>(dst_ptr)[1] = q_hi;
+    } else {
+        reinterpret_cast<HVX_UVector *>(dst_ptr)[0] = q_lo;
+        reinterpret_cast<HVX_UVector *>(dst_ptr)[1] = q_hi;
+    }
+}
+
 template <bool _IsDstAligned>
 void dequantize_row_q4_0_impl(const void * src, hexagon::dequant_output_type * dst, size_t count, HVX_Vector table) {
-    constexpr const int qk = QUANT_BLOCK_SIZE;
+    constexpr const size_t   kElemsPerVec = hexagon::kBytesPerVector / sizeof(hexagon::dequant_output_type);
+    constexpr const uint32_t kSizeOfQs    = sizeof(npu_device_block_q4_0::qs);
+    constexpr const int      qk           = QUANT_BLOCK_SIZE;
     static_assert(qk % 2 == 0, "qk must be even");
     static_assert(QUANT_BLOCK_SIZE == hexagon::kBytesPerVector / sizeof(float));
-    constexpr const uint32_t kSizeOfQs = sizeof(npu_device_block_q4_0::qs);
 
     static const auto load_masks = make_quad_block_mask<npu_device_block_q4_0>();
     alignas(hexagon::kBytesPerVector) static const HVX_Vector scale_indices =
        make_scale_load_mask<npu_device_block_q4_0>();
 
     const int    nb      = count / qk;
     const auto * src_ptr = reinterpret_cast<const npu_device_block_q4_0 *>(src);
 
     hexagon::dequant_output_type * dst_ptr = dst;  // TODO: opt for aligned access
     int i = 0;
+    for (; i + 5 < nb; i += 6) {
+        auto qs = load_hexa_block_generic(src_ptr + i, load_masks, scale_indices);
+        dequantize_row_q4_0_4blocks<_IsDstAligned>(qs.val[0], qs.val[1], qs.val[2], table, dst_ptr);
+        dequantize_row_q4_0_2blocks<_IsDstAligned>(qs.val[3], qs.val[4], table, dst_ptr + kElemsPerVec * 2);
+        dst_ptr += kElemsPerVec * 3;
+    }
+
     for (; i + 3 < nb; i += 4) {
         auto qs = load_qual_block_generic(src_ptr + i, load_masks, scale_indices);
-
-        HVX_Vector     q_lo = qs.val[0];
-        HVX_Vector     q_hi = Q6_Vub_vlsr_VubR(qs.val[0], 4);
-        HVX_VectorPair qp0  = Q6_W_vshuff_VVR(q_hi, q_lo, kSizeOfQs * (1 + 2 + 4));
-        q_lo                = Q6_Vb_vshuff_Vb(Q6_V_lo_W(qp0));
-
-        qp0  = Q6_Wh_vlut16_VbVhR_nomatch(q_lo, table, 0);
-        q_lo = Q6_V_lo_W(qp0);
-        q_hi = Q6_V_hi_W(qp0);
-        q_lo = Q6_Vqf16_vmpy_VhfVhf(q_lo, qs.val[1]);
-        q_hi = Q6_Vqf16_vmpy_VhfVhf(q_hi, qs.val[2]);
-        q_lo = Q6_Vhf_equals_Vqf16(q_lo);
-        q_hi = Q6_Vhf_equals_Vqf16(q_hi);
-
-        if constexpr (_IsDstAligned) {
-            reinterpret_cast<HVX_Vector *>(dst_ptr)[0] = q_lo;
-            reinterpret_cast<HVX_Vector *>(dst_ptr)[1] = q_hi;
-        } else {
-            reinterpret_cast<HVX_UVector *>(dst_ptr)[0] = q_lo;
-            reinterpret_cast<HVX_UVector *>(dst_ptr)[1] = q_hi;
-        }
-
-        dst_ptr += hexagon::kBytesPerVector / sizeof(hexagon::dequant_output_type) * 2;
+        dequantize_row_q4_0_4blocks<_IsDstAligned>(qs.val[0], qs.val[1], qs.val[2], table, dst_ptr);
+        dst_ptr += kElemsPerVec * 2;
     }
 
     for (; i + 1 < nb; i += 2) {
         auto qs = load_dual_block_generic(src_ptr + i, load_masks.val[0], scale_indices);
-
-        HVX_Vector     q_lo = qs.val[0];
-        HVX_Vector     q_hi = Q6_Vub_vlsr_VubR(qs.val[0], 4);
-        HVX_VectorPair qp0  = Q6_W_vshuff_VVR(q_hi, q_lo, kSizeOfQs * (1 + 2));
-        q_lo                = Q6_Vb_vshuff_Vb(Q6_V_lo_W(qp0));
-
-        qp0  = Q6_Wh_vlut16_VbVhR_nomatch(q_lo, table, 0);
-        q_lo = Q6_Vqf16_vmpy_VhfVhf(Q6_V_lo_W(qp0), qs.val[1]);
-        q_lo = Q6_Vhf_equals_Vqf16(q_lo);
-
-        if constexpr (_IsDstAligned) {
-            *reinterpret_cast<HVX_Vector *>(dst_ptr) = q_lo;
-        } else {
-            *reinterpret_cast<HVX_UVector *>(dst_ptr) = q_lo;
-        }
-
-        dst_ptr += hexagon::kBytesPerVector / sizeof(hexagon::dequant_output_type);
+        dequantize_row_q4_0_2blocks<_IsDstAligned>(qs.val[0], qs.val[1], table, dst_ptr);
+        dst_ptr += kElemsPerVec;
     }
 
     if (i < nb) {


@@ -18,6 +18,7 @@ template <typename T, int N> struct HEXAGON_pack {
 using HVX_Vector_x2 = HEXAGON_pack<HVX_Vector, 2>;
 using HVX_Vector_x3 = HEXAGON_pack<HVX_Vector, 3>;
 using HVX_Vector_x4 = HEXAGON_pack<HVX_Vector, 4>;
+using HVX_Vector_x5 = HEXAGON_pack<HVX_Vector, 5>;
 
 using HVX_VectorPair_x4 = HEXAGON_pack<HVX_VectorPair, 4>;
 using HVX_VectorPred_x3 = HEXAGON_pack<HVX_VectorPred, 3>;