feat: perf opt dma (#56)
* Add power management utilities to NPU device context and update DCVS settings
* Update DCVS settings in power_utils to use v3 API and enhance power management
* wip
* Enhance dequantization functions by adding load_dequant_table support and updating signatures for improved performance
* use lut
* wip
* fix test failure
* wip
* Refactor load_qual_block_generic to improve block handling and optimize vector operations
* Enhance load_dual_block_generic and load_qual_block_generic to accept a mask parameter for improved block handling
* Refactor flash_attn_impl to optimize mask l2 prefetch
* wip
* wip
* wip
* wip
* add log
* link against shared libraries instead of static ones
* fix swiglu
* wip
* refactor expf_fix to handle overflow for different data types
* enhance is_glu_op_supported to validate shapes for multiple sources
* wip
* refactor logging macros to use hexagon namespace and improve formatting
* fix printf format error
* wip
* refactor: update static_assert messages for block size validation and add HVX_VectorPred_x3 type alias
* rename
* feat: enhance fa with mask
* wip
* wip
* refactor: replace instances of Q6_V_vzero() with kZeroV for consistency
* wip
* wip
* wip
* fix: improve address alignment check in HVX_Vector handling
* refactor: streamline vector dot product implementations for improved readability
* refactor: q4k add hvx intrinsic impl
* refactor: enhance dequantize_row_q4_K for clarity and performance
* refactor: optimize scale mask usage in dequantization functions for improved performance
* refactor: optimize dequantize_row_q4_K for intrinsic usage and performance improvements
* refactor: move GLU operation implementation into separated file
* sync after swiglu
* wip
* wip
* wip
* feat: increase prc main thread stack size
* fix: replace hardcoded stack size with NPU_THREAD_STACK_SIZE constant
* wip
* feat: add optimized vector operations for exponential and division with overflow handling
* wip
* feat: refactor exponential function to handle overflow and underflow with improved logic
* wip
* wip
* feat: add vector loading and scaling functions for improved performance in block processing
* wip
* feat: optimize block loading by refactoring scale index handling for improved performance
* use Q6_Vb_vlut32_VbVbR_nomatch instead
* feat: enhance scale loading by adding static assertion and restructuring block handling
* wip
* feat: refactor vec_dot_product_mixed_impl for improved clarity and performance
* wip
* feat: simplify vector loading functions and improve alignment handling
* wip
* feat: enhance scale loading mask with quantization block size validation
* wip
* feat: implement make_scale_load_mask function and refactor vector handling in vec_ops
* feat: enhance load_dual_block_generic to include scale indices for improved vector loading
* revert q8 dequant
* wip
* feat: optimize dequantization functions by removing unnecessary masking and updating lookup methods
* wip
* wip
* add qurt_mutex
* Add DMA transfer class and integrate into thread pool
* Enhance DMA transfer functionality by adding support for multiple descriptors and initiating transfers in parallel
* fix dma crash
* fix failed unit tests
* wip
* use alignas
* Improve DMA transfer error handling and update descriptor completion check
* Fix VTCM cache size calculation in element-wise operations
* Add cache clean operations before DMA transfers in element-wise operations
* reduce cache clean operations
* Refactor DMA transfer functions to support 1D operations and rename for clarity
* Enhance DMA transfer functionality by adding 2D submission support and improving descriptor initialization
* Update read buffer method to support forced invalidation and remove unnecessary invalidation calls in element-wise operations
* wip
* Improve DMA transfer handling in mul_mat_gemv_impl by replacing memcpy with initiate_dma_row_transfer and adding wait_for_dma logic
* fix 2d dma
* feat: add DMA plane cache
* rename
* wip
* use memcpy for debug
* fix cache plane calc
* refactor: remove debug logging from mul_mat_impl and optimize cache handling
* rename
* fix 2d dma type
* refactor: enhance DMA transfer handling in mul_mat_gemv_impl and wait functions
* refactor: optimize DMA transfer handling in mul_mat_gemv_impl and wait functions
* wip
* wip
* move op impl into sub dir
* add log
* fix: correct pointer usage in mul_mat_gemv_impl for next plane access
* fix: improve DMA transfer error handling in mul_mat_impl and mul_mat_gemv_impl
* fix: fix crash by using the entire row bytes
* wip
* wip
* fix: prevent parallelization for scalar src1 in is_mul_mat_supported
* fix: add dimension checks for 2D DMA transfers and fallback to 1D if necessary
* wip
* fix: enable thread barrier for mul multiplication operations
* feat: add synchronization checks for tensor operations and update related functions
* wip
* fix: remove invalidation flag from get_read_buffer calls in element-wise and matrix multiplication operations
* Revert "fix: remove invalidation flag from get_read_buffer calls in element-wise and matrix multiplication operations" This reverts commit af3441e67e706b2e5122369dc160353796867dd3.
* wip
* wip
* add comment
* wip
This commit is contained in:
parent
3686cb5fea
commit
3994a9b7df
@@ -187,11 +187,15 @@ else()
    file(GLOB common_srcs "${CMAKE_CURRENT_LIST_DIR}/common/*.cpp")
    file(GLOB device_srcs "${CMAKE_CURRENT_LIST_DIR}/device/*.cpp")
    file(GLOB device_op_srcs "${CMAKE_CURRENT_LIST_DIR}/device/op/*.cpp")
    file(GLOB dma_srcs "${HEXAGON_SDK_ROOT}/addons/compute/libs/userdma/utils_lib/src/*.c")
    set(skel_srcs "${CMAKE_CURRENT_BINARY_DIR}/npu_device_skel.c")
    add_library(hexagon_npu_skel_OBJS OBJECT
        ${common_srcs}
        ${device_srcs}
        ${device_op_srcs}
        ${skel_srcs}
        ${dma_srcs}
    )

    if(CMAKE_BUILD_TYPE MATCHES "Debug|Dbg")
@@ -237,6 +241,12 @@ else()
    add_subdirectory(${HEXAGON_SDK_ROOT}/libs/qprintf qprintf_dir)
    target_include_directories(hexagon_npu_skel_OBJS PUBLIC
        ${HEXAGON_SDK_ROOT}/libs/qprintf/inc/

        # TODO: find a better way to include these
        ${HEXAGON_SDK_ROOT}/addons/compute/libs/userdma/utils_lib/api/
        ${HEXAGON_SDK_ROOT}/addons/compute/libs/userdma/utils_lib/inc/
        ${CMAKE_CURRENT_LIST_DIR}/device/
        ${CMAKE_CURRENT_LIST_DIR}/device/op/
    )

    # disable warnings for the skel
@@ -257,5 +267,3 @@ else()
    copy_binaries(hexagon_npu_skel)
endif()

# vim: set noet fenc=utf-8 ff=unix ft=cmake :
@@ -0,0 +1,184 @@
#include "dma_transfer.hpp"

#include <dma_desc.h>
#include <qurt.h>

#include <array>
#include <cstdlib>

namespace hexagon::dma {

dma_transfer::dma_transfer() {
    dma_desc_set_next(_dma_1d_desc0, 0);
    dma_desc_set_dstate(_dma_1d_desc0, DESC_DSTATE_INCOMPLETE);
    dma_desc_set_desctype(_dma_1d_desc0, DMA_DESC_TYPE_1D);
    dma_desc_set_order(_dma_1d_desc0, DESC_ORDER_ORDER);
    dma_desc_set_bypasssrc(_dma_1d_desc0, DESC_BYPASS_ON);   // for dram
    dma_desc_set_bypassdst(_dma_1d_desc0, DESC_BYPASS_OFF);  // for vtcm
    dma_desc_set_length(_dma_1d_desc0, 0);

    dma_desc_set_next(_dma_1d_desc1, 0);
    dma_desc_set_dstate(_dma_1d_desc1, DESC_DSTATE_INCOMPLETE);
    dma_desc_set_desctype(_dma_1d_desc1, DMA_DESC_TYPE_1D);
    dma_desc_set_order(_dma_1d_desc1, DESC_ORDER_ORDER);
    dma_desc_set_bypasssrc(_dma_1d_desc1, DESC_BYPASS_ON);   // for dram
    dma_desc_set_bypassdst(_dma_1d_desc1, DESC_BYPASS_OFF);  // for vtcm
    dma_desc_set_length(_dma_1d_desc1, 0);

    dma_desc_set_next(_dma_2d_desc0, 0);
    dma_desc_set_dstate(_dma_2d_desc0, DESC_DSTATE_INCOMPLETE);
    dma_desc_set_desctype(_dma_2d_desc0, DMA_DESC_TYPE_2D);
    dma_desc_set_order(_dma_2d_desc0, DESC_ORDER_ORDER);
    dma_desc_set_bypasssrc(_dma_2d_desc0, DESC_BYPASS_ON);   // for dram
    dma_desc_set_bypassdst(_dma_2d_desc0, DESC_BYPASS_OFF);  // for vtcm
    dma_desc_set_cachealloc(_dma_2d_desc0, DESC_CACHEALLOC_NONE);
    dma_desc_set_roiwidth(_dma_2d_desc0, 0);
    dma_desc_set_roiheight(_dma_2d_desc0, 0);
    dma_desc_set_srcstride(_dma_2d_desc0, 0);
    dma_desc_set_dststride(_dma_2d_desc0, 0);
    dma_desc_set_srcwidthoffset(_dma_2d_desc0, 0);
    dma_desc_set_dstwidthoffset(_dma_2d_desc0, 0);
}

dma_transfer::~dma_transfer() {
    wait();
}

bool dma_transfer::submit1d(const uint8_t * src, uint8_t * dst, size_t size) {
    constexpr size_t kMaxDmaTransferSize = DESC_LENGTH_MASK;
    if (size > kMaxDmaTransferSize) {
        // TODO: support chained descriptors for large transfers
        DEVICE_LOG_ERROR("dma_transfer::submit1d, size(%zu) is too large\n", size);
        return false;
    }

    if (!dma_transfer::is_desc_done(_dma_1d_desc0)) {
        DEVICE_LOG_ERROR("Failed to initiate DMA transfer for one or more descriptors\n");
        return false;
    }

    dma_desc_set_next(_dma_1d_desc0, 0);
    dma_desc_set_dstate(_dma_1d_desc0, DESC_DSTATE_INCOMPLETE);
    dma_desc_set_src(_dma_1d_desc0, reinterpret_cast<uint32_t>(src));
    dma_desc_set_dst(_dma_1d_desc0, reinterpret_cast<uint32_t>(dst));
    dma_desc_set_length(_dma_1d_desc0, size);

    void * buffs[] = { _dma_1d_desc0 };
    if (!submit_impl(buffs, std::size(buffs))) {
        DEVICE_LOG_ERROR("Failed to submit DMA descriptor\n");
        return false;
    }

    DEVICE_LOG_DEBUG("dma_transfer::submit1d, src(%p), dst(%p), size(%zu), desc(%p)\n", (void *) src, (void *) dst,
                     size, (void *) _dma_1d_desc0);
    return true;
}

bool dma_transfer::submit1d(const uint8_t * src0, uint8_t * dst0, const uint8_t * src1, uint8_t * dst1, size_t size) {
    constexpr size_t kMaxDmaTransferSize = DESC_LENGTH_MASK;
    if (size > kMaxDmaTransferSize) {
        // TODO: support chained descriptors for large transfers
        DEVICE_LOG_ERROR("dma_transfer::submit1d, size(%zu) is too large\n", size);
        return false;
    }

    if (!dma_transfer::is_desc_done(_dma_1d_desc0) || !dma_transfer::is_desc_done(_dma_1d_desc1)) {
        DEVICE_LOG_ERROR("Failed to initiate DMA transfer for one or more descriptors\n");
        return false;
    }

    dma_desc_set_next(_dma_1d_desc0, 0);
    dma_desc_set_dstate(_dma_1d_desc0, DESC_DSTATE_INCOMPLETE);
    dma_desc_set_src(_dma_1d_desc0, reinterpret_cast<uint32_t>(src0));
    dma_desc_set_dst(_dma_1d_desc0, reinterpret_cast<uint32_t>(dst0));
    dma_desc_set_length(_dma_1d_desc0, size);

    dma_desc_set_next(_dma_1d_desc1, 0);
    dma_desc_set_dstate(_dma_1d_desc1, DESC_DSTATE_INCOMPLETE);
    dma_desc_set_src(_dma_1d_desc1, reinterpret_cast<uint32_t>(src1));
    dma_desc_set_dst(_dma_1d_desc1, reinterpret_cast<uint32_t>(dst1));
    dma_desc_set_length(_dma_1d_desc1, size);

    void * buffs[] = { _dma_1d_desc0, _dma_1d_desc1 };
    if (!submit_impl(buffs, std::size(buffs))) {
        DEVICE_LOG_ERROR("Failed to submit DMA descriptor\n");
        return false;
    }

    DEVICE_LOG_DEBUG(
        "dma_transfer::submit1d, src0(%p), dst0(%p), src1(%p), dst1(%p), size(%zu), desc0(%p), desc1(%p)\n",
        (void *) src0, (void *) dst0, (void *) src1, (void *) dst1, size, (void *) _dma_1d_desc0,
        (void *) _dma_1d_desc1);
    return true;
}

bool dma_transfer::submit2d(const uint8_t * src,
                            uint8_t *       dst,
                            size_t          width,
                            size_t          height,
                            size_t          src_stride,
                            size_t          dst_stride) {
    // Note that the dma only supports 16-bit width and height for 2D transfer, see also: DESC_ROIWIDTH_MASK
    constexpr size_t kMaxDmaTransferSize = DESC_ROIWIDTH_MASK;
    if (width > kMaxDmaTransferSize || height > kMaxDmaTransferSize || src_stride > kMaxDmaTransferSize ||
        dst_stride > kMaxDmaTransferSize) {
        if (src_stride != dst_stride) {
            // TODO: support chained descriptors for large transfers
            DEVICE_LOG_ERROR("dma_transfer::submit2d, src_stride(%zu) or dst_stride(%zu) is too large\n", src_stride,
                             dst_stride);
            return false;
        }

        DEVICE_LOG_DEBUG("dma_transfer::submit2d, width(%zu) or height(%zu) is too large, fallback to 1D transfer\n",
                         width, height);
        return submit1d(src, dst, src_stride * height);
    }

    if (!dma_transfer::is_desc_done(_dma_2d_desc0)) {
        DEVICE_LOG_ERROR("Failed to initiate DMA transfer for one or more descriptors\n");
        return false;
    }

    dma_desc_set_next(_dma_2d_desc0, 0);
    dma_desc_set_dstate(_dma_2d_desc0, DESC_DSTATE_INCOMPLETE);
    dma_desc_set_src(_dma_2d_desc0, reinterpret_cast<uint32_t>(src));
    dma_desc_set_dst(_dma_2d_desc0, reinterpret_cast<uint32_t>(dst));
    dma_desc_set_roiwidth(_dma_2d_desc0, width);
    dma_desc_set_roiheight(_dma_2d_desc0, height);
    dma_desc_set_srcstride(_dma_2d_desc0, src_stride);
    dma_desc_set_dststride(_dma_2d_desc0, dst_stride);

    void * buffs[] = { _dma_2d_desc0 };
    if (!submit_impl(buffs, std::size(buffs))) {
        DEVICE_LOG_ERROR("Failed to submit DMA descriptor\n");
        return false;
    }

    DEVICE_LOG_DEBUG(
        "dma_transfer::submit2d, src(%p), dst(%p), width(%zu), height(%zu), src_stride(%zu), dst_stride(%zu), "
        "desc(%p)\n",
        (void *) src, (void *) dst, width, height, src_stride, dst_stride, (void *) _dma_2d_desc0);
    return true;
}

void dma_transfer::wait() {
    auto ret = dma_wait_for_idle();
    if (ret != DMA_SUCCESS) {
        DEVICE_LOG_ERROR("dma_transfer: failed to wait for DMA idle: %d\n", ret);
    }
}

bool dma_transfer::is_desc_done(uint8_t * desc) {
    return !dma_desc_get_src(desc) || dma_desc_is_done(desc) == DMA_COMPLETE;
}

bool dma_transfer::submit_impl(void ** desc_batch, int batch_len) {
    _dma_desc_mutex.lock();
    const bool succ = dma_desc_submit(desc_batch, batch_len) == DMA_SUCCESS;
    _dma_desc_mutex.unlock();
    return succ;
}

qurt_mutex dma_transfer::_dma_desc_mutex;

}  // namespace hexagon::dma
@@ -0,0 +1,45 @@
#pragma once

#include "util.hpp"

#include <dma_utils.h>

namespace hexagon::dma {

class dma_transfer {
  public:
    dma_transfer();
    ~dma_transfer();

    /**
     * Submits a 1D DMA transfer.
     *
     * Limitations:
     *   - The maximum supported transfer size is kMaxDmaTransferSize (DESC_LENGTH_MASK, 24bit).
     *   - Transfers larger than this size are not supported and will fail.
     *   - Large transfers must be split into multiple smaller transfers by the caller.
     */
    bool submit1d(const uint8_t * src, uint8_t * dst, size_t size);
    bool submit1d(const uint8_t * src0, uint8_t * dst0, const uint8_t * src1, uint8_t * dst1, size_t size);
    bool submit2d(const uint8_t * src,
                  uint8_t *       dst,
                  size_t          width,
                  size_t          height,
                  size_t          src_stride,
                  size_t          dst_stride);
    void wait();

  private:
    static bool is_desc_done(uint8_t * desc);  // TODO: should we use void * here?
    static qurt_mutex _dma_desc_mutex;

    bool submit_impl(void ** desc_batch, int batch_len);

    alignas(DMA_DESC_SIZE_1D) uint8_t _dma_1d_desc0[DMA_DESC_SIZE_1D] = {};
    alignas(DMA_DESC_SIZE_1D) uint8_t _dma_1d_desc1[DMA_DESC_SIZE_1D] = {};
    alignas(DMA_DESC_SIZE_2D) uint8_t _dma_2d_desc0[DMA_DESC_SIZE_2D] = {};

    DISABLE_COPY_AND_MOVE(dma_transfer);
};

}  // namespace hexagon::dma
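The header above documents that transfers larger than DESC_LENGTH_MASK must be split by the caller. Below is a minimal caller-side sketch of that splitting, assuming only the submit1d/wait API declared above; copy_via_dma and its chunking policy are illustrative and not part of this commit.

#include "dma_transfer.hpp"

#include <dma_desc.h>  // for DESC_LENGTH_MASK, as in dma_transfer.cpp

namespace hexagon::dma {

// Hypothetical helper (not part of this commit): copy `size` bytes into VTCM by
// splitting the request into chunks no larger than the 24-bit descriptor length
// limit, waiting for each chunk before the single 1D descriptor is reused.
inline bool copy_via_dma(dma_transfer & dma, const uint8_t * src, uint8_t * dst, size_t size) {
    constexpr size_t kMaxChunk = DESC_LENGTH_MASK;  // same limit submit1d enforces
    while (size > 0) {
        const size_t chunk = size < kMaxChunk ? size : kMaxChunk;
        if (!dma.submit1d(src, dst, chunk)) {
            return false;  // submit1d already logs the failure
        }
        dma.wait();  // block until the descriptor completes before reusing it
        src += chunk;
        dst += chunk;
        size -= chunk;
    }
    return true;
}

}  // namespace hexagon::dma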
@ -30,12 +30,8 @@ void graph::set_tensor(const npu_device_tensor_handle_t * tensors, int tensor_co
|
|||
for (int i = 0; i < tensor_count; ++i) {
|
||||
auto * tensor_obj = reinterpret_cast<tensor *>(tensors[i]);
|
||||
_tensors[i] = tensor_obj;
|
||||
DEVICE_LOG_DEBUG("graph(%p) set_tensor[%d]: %p(%p,%p), op: %s\n",
|
||||
(void *) this,
|
||||
i,
|
||||
(void *) tensor_obj,
|
||||
(void *) tensor_obj->get_src(0),
|
||||
(void *) tensor_obj->get_src(1),
|
||||
DEVICE_LOG_DEBUG("graph(%p) set_tensor[%d]: %p(%p,%p), op: %s\n", (void *) this, i, (void *) tensor_obj,
|
||||
(void *) tensor_obj->get_src(0), (void *) tensor_obj->get_src(1),
|
||||
op_get_name(tensor_obj->get_op()));
|
||||
}
|
||||
|
||||
|
|
@ -77,27 +73,28 @@ void graph::thread_pool_task(default_thread_pool * pool,
|
|||
void graph::compute_impl(default_thread_pool * pool, default_thread_pool::thread_params * thread_params) {
|
||||
hexagon::compute_params params = { thread_params, _f16_to_f32_table };
|
||||
|
||||
npu_device_tensor_op prev_op = NPU_OP_COUNT;
|
||||
|
||||
for (size_t i = 0; i < _tensor_count; ++i) {
|
||||
auto * dst = _tensors[i];
|
||||
auto op = dst->get_op();
|
||||
auto * func = get_compute_func(dst);
|
||||
if (func == nullptr) {
|
||||
if (!func) {
|
||||
DEVICE_LOG_ERROR("graph(%p) tensor[%zu] op %d not supported\n", (void *) this, i, op);
|
||||
return;
|
||||
}
|
||||
if (!func(dst, ¶ms)) {
|
||||
DEVICE_LOG_ERROR("graph(%p) tensor[%zu] op %d compute failed\n", (void *) this, i, op);
|
||||
|
||||
const bool should_sync = requires_thread_barrier(prev_op, op);
|
||||
if (pool && should_sync) {
|
||||
// For the last tensor, the thread pool will handle synchronization
|
||||
DEVICE_SCOPED_PERFORMANCE_TRACKER("[%p]sync_thread, tidx: %zu, tensor[%zu/%zu]", (void *) this,
|
||||
params.get_thread_index(), i, _tensor_count);
|
||||
pool->sync_thread();
|
||||
}
|
||||
|
||||
const bool should_sync = requires_thread_barrier(op);
|
||||
if (pool && should_sync && i < _tensor_count - 1) {
|
||||
// For the last tensor, the thread pool will handle synchronization
|
||||
DEVICE_SCOPED_PERFORMANCE_TRACKER("[%p]sync_thread, tidx: %zu, tensor[%zu/%zu]",
|
||||
(void *) this,
|
||||
params.get_thread_index(),
|
||||
i,
|
||||
_tensor_count);
|
||||
pool->sync_thread();
|
||||
prev_op = op;
|
||||
if (!func(dst, ¶ms)) {
|
||||
DEVICE_LOG_ERROR("graph(%p) tensor[%zu] op %d compute failed\n", (void *) this, i, op);
|
||||
}
|
||||
}
|
||||
}
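The new compute_impl decides the barrier from the pair (previous op, current op) before the current op runs, instead of unconditionally syncing after every op. A self-contained sketch of that policy follows; op_kind, needs_barrier and run_graph are hypothetical names, and the simplified predicate mirrors the is_*_required_sync helpers that only demand a barrier ahead of NPU_OP_MUL_MAT.

#include <cstddef>
#include <vector>

enum class op_kind { add, glu, mul_mat, none };

// Simplified policy mirroring is_glu_required_sync / is_element_wise_op_required_sync:
// a barrier is only needed when the upcoming op re-partitions rows produced by the
// previous one (here: anything feeding a matrix multiplication).
static bool needs_barrier(op_kind prev, op_kind next) {
    return prev != op_kind::none && next == op_kind::mul_mat;
}

template <typename Barrier, typename Compute>
void run_graph(const std::vector<op_kind> & ops, Barrier barrier, Compute compute) {
    op_kind prev = op_kind::none;  // sentinel: nothing computed yet
    for (size_t i = 0; i < ops.size(); ++i) {
        if (needs_barrier(prev, ops[i])) {
            barrier();  // every worker must see the previous op's output first
        }
        prev = ops[i];
        compute(i);  // per-tensor compute; error handling elided
    }
}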
|
||||
|
|
|
|||
|
|
@ -28,8 +28,7 @@ void flash_attn_impl(hexagon::tensor * out,
|
|||
|
||||
if (k->get_type() != kKvDataType || v->get_type() != k->get_type()) {
|
||||
DEVICE_LOG_ERROR("flash_attn_impl: k and v must have same type, got k: %s, v: %s\n",
|
||||
hexagon::get_type_name(k->get_type()),
|
||||
hexagon::get_type_name(v->get_type()));
|
||||
hexagon::get_type_name(k->get_type()), hexagon::get_type_name(v->get_type()));
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
@ -92,8 +91,7 @@ void flash_attn_impl(hexagon::tensor * out,
|
|||
const auto out_rows_per_batch = out->get_ne(2) * out->get_ne(1);
|
||||
uint8_t * dst_ptr = out->get_write_buffer();
|
||||
if (!dst_ptr) {
|
||||
DEVICE_LOG_ERROR("flash_attn_impl: dst_ptr is not writable, tensor: %p, type: %s\n",
|
||||
(void *) out,
|
||||
DEVICE_LOG_ERROR("flash_attn_impl: dst_ptr is not writable, tensor: %p, type: %s\n", (void *) out,
|
||||
hexagon::get_type_name(out->get_type()));
|
||||
return;
|
||||
}
|
||||
|
|
@ -114,6 +112,9 @@ void flash_attn_impl(hexagon::tensor * out,
|
|||
const auto iq2 = (ir - iq3 * rows_per_batch) / q->get_ne(1);
|
||||
const auto iq1 = (ir - iq3 * rows_per_batch - iq2 * q->get_ne(1));
|
||||
|
||||
const auto * q_data = q_ptr + (iq1 * q->get_nb(1) + iq2 * q->get_nb(2) + iq3 * q->get_nb(3));
|
||||
hexagon::l2fetch_row(q_data, row_bytes_q);
|
||||
|
||||
const uint32_t h = iq2; // head index
|
||||
const float slope =
|
||||
(max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2 * (h - n_head_log2) + 1) : 1.0f;
|
||||
|
|
@ -121,9 +122,6 @@ void flash_attn_impl(hexagon::tensor * out,
|
|||
float S = 0.0f; // sum
|
||||
float M = -INFINITY; // maximum KQ value
|
||||
|
||||
const auto * q_data = q_ptr + (iq1 * q->get_nb(1) + iq2 * q->get_nb(2) + iq3 * q->get_nb(3));
|
||||
hexagon::l2fetch_row(q_data, row_bytes_q);
|
||||
|
||||
if constexpr (is_v_f16) {
|
||||
memset(VKQ16, 0, DV * sizeof(npu_device_fp16_t));
|
||||
} else {
|
||||
|
|
@ -341,9 +339,7 @@ bool is_flash_attn_supported(const npu_device_tensor_op_spec * op_spec,
|
|||
|
||||
const auto * v = &srcs[2];
|
||||
if (v->type != k->type) { // TODO: support more v types
|
||||
DEVICE_LOG_DEBUG("[%s]v type is not the same as k: %s vs %s\n",
|
||||
op_get_name(op),
|
||||
get_type_name(v->type),
|
||||
DEVICE_LOG_DEBUG("[%s]v type is not the same as k: %s vs %s\n", op_get_name(op), get_type_name(v->type),
|
||||
get_type_name(k->type));
|
||||
return false;
|
||||
}
|
||||
|
|
@ -358,48 +354,31 @@ bool is_flash_attn_supported(const npu_device_tensor_op_spec * op_spec,
|
|||
DEVICE_LOG_DEBUG(
|
||||
"[%s]dst shape does not match q and v: dst ne: %lld, %lld, %lld, %lld, q ne: %lld, %lld, %lld, %lld, "
|
||||
"v ne: %lld, %lld, %lld, %lld\n",
|
||||
op_get_name(op),
|
||||
dst->ne[0],
|
||||
dst->ne[1],
|
||||
dst->ne[2],
|
||||
dst->ne[3],
|
||||
q->ne[0],
|
||||
q->ne[1],
|
||||
q->ne[2],
|
||||
q->ne[3],
|
||||
v->ne[0],
|
||||
v->ne[1],
|
||||
v->ne[2],
|
||||
v->ne[3]);
|
||||
op_get_name(op), dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], q->ne[0], q->ne[1], q->ne[2], q->ne[3],
|
||||
v->ne[0], v->ne[1], v->ne[2], v->ne[3]);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (is_transposed_or_permuted(dst->nb)) {
|
||||
DEVICE_LOG_DEBUG("[%s]dst cannot be transposed or permuted, nb: %zu, %zu, %zu, %zu\n",
|
||||
op_get_name(op),
|
||||
(size_t) dst->nb[0],
|
||||
(size_t) dst->nb[1],
|
||||
(size_t) dst->nb[2],
|
||||
(size_t) dst->nb[3]);
|
||||
DEVICE_LOG_DEBUG("[%s]dst cannot be transposed or permuted, nb: %zu, %zu, %zu, %zu\n", op_get_name(op),
|
||||
(size_t) dst->nb[0], (size_t) dst->nb[1], (size_t) dst->nb[2], (size_t) dst->nb[3]);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (q->ne[0] != k->ne[0]) {
|
||||
DEVICE_LOG_DEBUG(
|
||||
"[%s]q and k shapes do not match: q ne: %lld, %lld, %lld, %lld, k ne: %lld, %lld, %lld, %lld\n",
|
||||
op_get_name(op),
|
||||
q->ne[0],
|
||||
q->ne[1],
|
||||
q->ne[2],
|
||||
q->ne[3],
|
||||
k->ne[0],
|
||||
k->ne[1],
|
||||
k->ne[2],
|
||||
k->ne[3]);
|
||||
op_get_name(op), q->ne[0], q->ne[1], q->ne[2], q->ne[3], k->ne[0], k->ne[1], k->ne[2], k->ne[3]);
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool is_flash_attn_required_sync(const npu_device_tensor_op op, const npu_device_tensor_op next_op) {
|
||||
NPU_UNUSED(op);
|
||||
NPU_UNUSED(next_op);
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace hexagon
|
||||
|
|
@ -9,5 +9,6 @@ bool is_flash_attn_supported(const npu_device_tensor_op_spec * op_spec,
|
|||
const npu_device_tensor_spec * dst,
|
||||
const npu_device_tensor_spec * srcs,
|
||||
size_t src_len);
|
||||
bool is_flash_attn_required_sync(const npu_device_tensor_op op, const npu_device_tensor_op next_op);
|
||||
|
||||
} // namespace hexagon
|
||||
|
|
@ -225,4 +225,9 @@ bool is_glu_op_supported(const npu_device_tensor_op_spec * op_spec,
|
|||
return true;
|
||||
}
|
||||
|
||||
bool is_glu_required_sync(const npu_device_tensor_op op, const npu_device_tensor_op next_op) {
|
||||
NPU_UNUSED(op);
|
||||
return next_op == NPU_OP_MUL_MAT;
|
||||
}
|
||||
|
||||
} // namespace hexagon
|
||||
|
|
@ -11,5 +11,6 @@ bool is_glu_op_supported(const npu_device_tensor_op_spec * op_spec,
|
|||
const npu_device_tensor_spec * dst,
|
||||
const npu_device_tensor_spec * srcs,
|
||||
size_t src_len);
|
||||
bool is_glu_required_sync(const npu_device_tensor_op op, const npu_device_tensor_op next_op);
|
||||
|
||||
} // namespace hexagon
|
||||
|
|
@ -87,46 +87,89 @@ template <auto _RowFunc> bool element_wise_op(hexagon::tensor * out, hexagon::co
|
|||
return false;
|
||||
}
|
||||
|
||||
uint8_t * dst_ptr = out->get_write_buffer();
|
||||
if (!dst_ptr) {
|
||||
DEVICE_LOG_ERROR("element_wise_op: dst_ptr is not writable, tensor: %p, type: %s\n",
|
||||
(void *) out,
|
||||
hexagon::get_type_name(out->get_type()));
|
||||
return false;
|
||||
}
|
||||
|
||||
const uint8_t * src0_ptr = src0->get_read_buffer();
|
||||
const uint8_t * src1_ptr = src1->get_read_buffer();
|
||||
auto total_rows = out->get_ne(3) * out->get_ne(2) * out->get_ne(1);
|
||||
const auto rows_per_cube = out->get_ne(2) * out->get_ne(1);
|
||||
const auto start_end = params->get_work_slice(total_rows);
|
||||
const auto total_rows = out->get_ne(3) * out->get_ne(2) * out->get_ne(1);
|
||||
const auto start_end = params->get_work_slice(total_rows);
|
||||
if (start_end.first >= start_end.second) {
|
||||
return true;
|
||||
}
|
||||
|
||||
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER(out, params->get_thread_index());
|
||||
const auto src_row_bytes = src0->get_ne(0) * sizeof(data_type);
|
||||
const auto src_row_bytes_aligned = hexagon::get_aligned_size(src_row_bytes);
|
||||
uint8_t * src_cache_ptr = params->get_vtcm_cache(src_row_bytes_aligned * 4);
|
||||
if (!src_cache_ptr) {
|
||||
DEVICE_LOG_ERROR("element_wise_op: failed to get VTCM cache, size: %zu\n", size_t(src_row_bytes_aligned * 4));
|
||||
return false;
|
||||
}
|
||||
|
||||
const size_t valid_row_bytes = src0->get_ne(0) * sizeof(data_type);
|
||||
for (int64_t ir = start_end.first; ir < start_end.second; ++ir) {
|
||||
const auto i03 = ir / rows_per_cube;
|
||||
const auto i02 = ir / out->get_ne(1) - i03 * out->get_ne(2);
|
||||
const auto i01 = ir % out->get_ne(1); // TODO: should we use divide instead of mod?
|
||||
uint8_t * dst_ptr = out->get_write_buffer();
|
||||
if (!dst_ptr) {
|
||||
DEVICE_LOG_ERROR("element_wise_op: dst_ptr is not writable, tensor: %p, type: %s\n", (void *) out,
|
||||
hexagon::get_type_name(out->get_type()));
|
||||
return false;
|
||||
}
|
||||
|
||||
const uint8_t * src0_ptr = src0->get_read_buffer(true); // TODO: avoid invalidation
|
||||
const uint8_t * src1_ptr = src1->get_read_buffer(true); // TODO: avoid invalidation
|
||||
const auto rows_per_cube = out->get_ne(2) * out->get_ne(1);
|
||||
|
||||
uint8_t * src0_read_cache_ptr = src_cache_ptr;
|
||||
uint8_t * src0_write_cache_ptr = src_cache_ptr + src_row_bytes_aligned;
|
||||
uint8_t * src1_read_cache_ptr = src_cache_ptr + src_row_bytes_aligned * 2;
|
||||
uint8_t * src1_write_cache_ptr = src_cache_ptr + src_row_bytes_aligned * 3;
|
||||
|
||||
{
|
||||
const auto i03 = start_end.first / rows_per_cube;
|
||||
const auto i02 = start_end.first / out->get_ne(1) - i03 * out->get_ne(2);
|
||||
const auto i01 = start_end.first % out->get_ne(1); // TODO: should we use divide instead of mod?
|
||||
const auto i13 = i03 % src1->get_ne(3);
|
||||
const auto i12 = i02 % src1->get_ne(2);
|
||||
const auto i11 = i01 % src1->get_ne(1);
|
||||
|
||||
auto * src1_plane = src1_ptr + i13 * src1->get_nb(3) + i12 * src1->get_nb(2);
|
||||
auto * src0_row = src0_ptr + i03 * src0->get_nb(3) + i02 * src0->get_nb(2) + i01 * src0->get_nb(1);
|
||||
auto * src1_row = src1_plane + i11 * src1->get_nb(1);
|
||||
auto * dst_row = dst_ptr + i03 * out->get_nb(3) + i02 * out->get_nb(2) + i01 * out->get_nb(1);
|
||||
if (ir + 1 < start_end.second) {
|
||||
hexagon::l2fetch_row(src0_row + src0->get_nb(1), valid_row_bytes);
|
||||
hexagon::l2fetch_row(src1_row + src1->get_nb(1), valid_row_bytes);
|
||||
auto * src0_row = src0_ptr + i03 * src0->get_nb(3) + i02 * src0->get_nb(2) + i01 * src0->get_nb(1);
|
||||
auto * src1_row = src1_ptr + i13 * src1->get_nb(3) + i12 * src1->get_nb(2) + i11 * src1->get_nb(1);
|
||||
if (!params->initiate_dma_row_transfer(src0_row, src0_write_cache_ptr, src1_row, src1_write_cache_ptr,
|
||||
src_row_bytes)) {
|
||||
DEVICE_LOG_ERROR("element_wise_op: failed to initiate dma transfer\n");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER(out, params->get_thread_index());
|
||||
|
||||
for (int64_t ir = start_end.first; ir < start_end.second; ++ir) {
|
||||
const auto i03 = ir / rows_per_cube;
|
||||
const auto i02 = ir / out->get_ne(1) - i03 * out->get_ne(2);
|
||||
const auto i01 = ir % out->get_ne(1); // TODO: should we use divide instead of mod?
|
||||
const auto ir_next = ir + 1;
|
||||
|
||||
auto * dst_row = dst_ptr + i03 * out->get_nb(3) + i02 * out->get_nb(2) + i01 * out->get_nb(1);
|
||||
{
|
||||
std::swap(src0_read_cache_ptr, src0_write_cache_ptr);
|
||||
std::swap(src1_read_cache_ptr, src1_write_cache_ptr);
|
||||
params->wait_for_dma();
|
||||
}
|
||||
|
||||
_RowFunc(reinterpret_cast<const data_type *>(src0_row),
|
||||
reinterpret_cast<const data_type *>(src1_row),
|
||||
reinterpret_cast<data_type *>(dst_row),
|
||||
if (ir_next < start_end.second) {
|
||||
const auto i03_next = ir_next / rows_per_cube;
|
||||
const auto i02_next = ir_next / out->get_ne(1) - i03_next * out->get_ne(2);
|
||||
const auto i01_next = ir_next % out->get_ne(1);
|
||||
const auto i13_next = i03_next % src1->get_ne(3);
|
||||
const auto i12_next = i02_next % src1->get_ne(2);
|
||||
const auto i11_next = i01_next % src1->get_ne(1);
|
||||
|
||||
auto * src0_next_row =
|
||||
src0_ptr + i03_next * src0->get_nb(3) + i02_next * src0->get_nb(2) + i01_next * src0->get_nb(1);
|
||||
auto * src1_next_row =
|
||||
src1_ptr + i13_next * src1->get_nb(3) + i12_next * src1->get_nb(2) + i11_next * src1->get_nb(1);
|
||||
if (!params->initiate_dma_row_transfer(src0_next_row, src0_write_cache_ptr, src1_next_row,
|
||||
src1_write_cache_ptr, src_row_bytes)) {
|
||||
DEVICE_LOG_ERROR("element_wise_op: failed to continue DMA transfer\n");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
_RowFunc(reinterpret_cast<const data_type *>(src0_read_cache_ptr),
|
||||
reinterpret_cast<const data_type *>(src1_read_cache_ptr), reinterpret_cast<data_type *>(dst_row),
|
||||
static_cast<size_t>(out->get_ne(0)));
|
||||
}
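The loop above overlaps compute with DMA by ping-ponging two halves of the VTCM cache: while row ir is consumed from the read half, the transfer of row ir + 1 into the write half is already in flight. Here is a condensed, self-contained sketch of the same pattern under simplified assumptions; the DMA and compute hooks are passed in as callables standing in for initiate_dma_row_transfer/wait_for_dma and _RowFunc.

#include <cstddef>
#include <cstdint>
#include <utility>

template <typename Initiate, typename Wait, typename Compute>
void process_rows(const uint8_t * src, uint8_t * vtcm, size_t row_bytes, size_t row_count,
                  Initiate initiate_copy, Wait wait_copy, Compute compute_row) {
    if (row_count == 0) {
        return;
    }

    uint8_t * read_buf  = vtcm;              // half of the cache: row being consumed
    uint8_t * write_buf = vtcm + row_bytes;  // other half: row currently in flight

    initiate_copy(src, write_buf, row_bytes);  // prime the pipeline with row 0
    for (size_t ir = 0; ir < row_count; ++ir) {
        std::swap(read_buf, write_buf);
        wait_copy();  // row ir is now resident in read_buf

        if (ir + 1 < row_count) {  // overlap the next transfer with this row's compute
            initiate_copy(src + (ir + 1) * row_bytes, write_buf, row_bytes);
        }

        compute_row(read_buf, row_bytes);
    }
}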
|
||||
|
||||
|
|
@ -152,31 +195,27 @@ bool is_element_wise_op_supported(const npu_device_tensor_op_spec * op_spec,
|
|||
const auto & src0 = srcs[0];
|
||||
const auto & src1 = srcs[1];
|
||||
if (dst->type != src0.type || dst->type != src1.type) {
|
||||
DEVICE_LOG_DEBUG("[%s]src0.type and dst.type mismatch: %s vs %s\n",
|
||||
hexagon::op_get_name(op),
|
||||
hexagon::get_type_name(src0.type),
|
||||
hexagon::get_type_name(dst->type));
|
||||
DEVICE_LOG_DEBUG("[%s]src0.type and dst.type mismatch: %s vs %s\n", hexagon::op_get_name(op),
|
||||
hexagon::get_type_name(src0.type), hexagon::get_type_name(dst->type));
|
||||
return false;
|
||||
}
|
||||
|
||||
if (dst->type != NPU_DATA_TYPE_F32 && dst->type != NPU_DATA_TYPE_F16) {
|
||||
DEVICE_LOG_DEBUG(
|
||||
"[%s]unsupported data type: %s\n", hexagon::op_get_name(op), hexagon::get_type_name(dst->type));
|
||||
DEVICE_LOG_DEBUG("[%s]unsupported data type: %s\n", hexagon::op_get_name(op),
|
||||
hexagon::get_type_name(dst->type));
|
||||
return false;
|
||||
}
|
||||
|
||||
// TODO: fix FP16 add/sub
|
||||
if (dst->type == NPU_DATA_TYPE_F16 && op != NPU_OP_MUL) {
|
||||
DEVICE_LOG_DEBUG(
|
||||
"[%s]unsupported data type: %s\n", hexagon::op_get_name(op), hexagon::get_type_name(dst->type));
|
||||
DEVICE_LOG_DEBUG("[%s]unsupported data type: %s\n", hexagon::op_get_name(op),
|
||||
hexagon::get_type_name(dst->type));
|
||||
return false;
|
||||
}
|
||||
|
||||
if (src0.ne[0] != src1.ne[0]) {
|
||||
DEVICE_LOG_DEBUG("[%s]src0.ne[0] and src1.ne[0] not match: %ld vs %ld\n",
|
||||
hexagon::op_get_name(op),
|
||||
(long) src0.ne[0],
|
||||
(long) src1.ne[0]);
|
||||
DEVICE_LOG_DEBUG("[%s]src0.ne[0] and src1.ne[0] not match: %ld vs %ld\n", hexagon::op_get_name(op),
|
||||
(long) src0.ne[0], (long) src1.ne[0]);
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -188,6 +227,11 @@ bool is_element_wise_op_supported(const npu_device_tensor_op_spec * op_spec,
|
|||
return true;
|
||||
}
|
||||
|
||||
bool is_element_wise_op_required_sync(const npu_device_tensor_op op, const npu_device_tensor_op next_op) {
|
||||
NPU_UNUSED(op);
|
||||
return next_op == NPU_OP_MUL_MAT;
|
||||
}
|
||||
|
||||
void rms_norm_vec_f32(const float * src, float * dst, size_t count, float eps) {
|
||||
constexpr const size_t kElementsPerVector = hexagon::kBytesPerVector / sizeof(float);
|
||||
|
||||
|
|
@ -220,7 +264,7 @@ void rms_norm_vec_f32(const float * src, float * dst, size_t count, float eps) {
|
|||
(leftover_bytes + hexagon::unaligned_bytes(src_vec_ptr) > hexagon::kBytesPerVector) ? *src_vec_ptr : prev;
|
||||
curr = Q6_V_valign_VVR(curr, prev, (size_t) src);
|
||||
sum = Q6_Vqf32_vadd_Vqf32Vqf32(sum,
|
||||
Q6_V_valign_VVR(Q6_Vqf32_vmpy_VsfVsf(curr, curr), Q6_V_vzero(), leftover_bytes));
|
||||
Q6_V_valign_VVR(Q6_Vqf32_vmpy_VsfVsf(curr, curr), Q6_V_vzero(), leftover_bytes));
|
||||
}
|
||||
|
||||
const float mean = hexagon::vec_reduction_f32_qf32(sum) / count; // TODO: figure out how to do division in vector
|
||||
|
|
@ -245,8 +289,7 @@ template <auto _RowFunc> bool unary_op(hexagon::tensor * out, hexagon::compute_p
|
|||
|
||||
auto * dst_ptr = out->get_write_buffer();
|
||||
if (!dst_ptr) {
|
||||
DEVICE_LOG_ERROR("unary_op: dst_ptr is not writable, tensor: %p, type: %s\n",
|
||||
(void *) out,
|
||||
DEVICE_LOG_ERROR("unary_op: dst_ptr is not writable, tensor: %p, type: %s\n", (void *) out,
|
||||
hexagon::get_type_name(out->get_type()));
|
||||
return false;
|
||||
}
|
||||
|
|
@ -274,10 +317,8 @@ template <auto _RowFunc> bool unary_op(hexagon::tensor * out, hexagon::compute_p
|
|||
hexagon::l2fetch_row(src0_row + src0->get_nb(1), valid_row_bytes);
|
||||
}
|
||||
|
||||
_RowFunc(reinterpret_cast<const data_type *>(src0_row),
|
||||
reinterpret_cast<data_type *>(dst_row),
|
||||
static_cast<size_t>(out->get_ne(0)),
|
||||
param);
|
||||
_RowFunc(reinterpret_cast<const data_type *>(src0_row), reinterpret_cast<data_type *>(dst_row),
|
||||
static_cast<size_t>(out->get_ne(0)), param);
|
||||
}
|
||||
|
||||
out->release_write_buffer(); // mark the output tensor as modified
|
||||
|
|
@ -301,16 +342,14 @@ bool is_unary_op_supported(const npu_device_tensor_op_spec * op_spec,
|
|||
|
||||
const auto & src0 = srcs[0];
|
||||
if (dst->type != src0.type) {
|
||||
DEVICE_LOG_DEBUG("[%s]src0.type and dst.type mismatch: %s vs %s\n",
|
||||
hexagon::op_get_name(op),
|
||||
hexagon::get_type_name(src0.type),
|
||||
hexagon::get_type_name(dst->type));
|
||||
DEVICE_LOG_DEBUG("[%s]src0.type and dst.type mismatch: %s vs %s\n", hexagon::op_get_name(op),
|
||||
hexagon::get_type_name(src0.type), hexagon::get_type_name(dst->type));
|
||||
return false;
|
||||
}
|
||||
|
||||
if (dst->type != NPU_DATA_TYPE_F32) {
|
||||
DEVICE_LOG_DEBUG(
|
||||
"[%s]unsupported data type: %s\n", hexagon::op_get_name(op), hexagon::get_type_name(dst->type));
|
||||
DEVICE_LOG_DEBUG("[%s]unsupported data type: %s\n", hexagon::op_get_name(op),
|
||||
hexagon::get_type_name(dst->type));
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -322,70 +361,71 @@ bool is_unary_op_supported(const npu_device_tensor_op_spec * op_spec,
|
|||
return true;
|
||||
}
|
||||
|
||||
bool is_unary_op_required_sync(const npu_device_tensor_op op, const npu_device_tensor_op next_op) {
|
||||
NPU_UNUSED(op);
|
||||
return next_op == NPU_OP_MUL_MAT;
|
||||
}
|
||||
|
||||
struct op_capabilities {
|
||||
npu_device_tensor_op op;
|
||||
hexagon::op_is_supported_func_type is_supported;
|
||||
hexagon::compute_func_type compute_funcs[NPU_DATA_TYPE_COUNT];
|
||||
bool requires_thread_barrier = false;
|
||||
npu_device_tensor_op op;
|
||||
hexagon::op_is_supported_func_type is_supported;
|
||||
hexagon::op_required_sync_func_type requires_thread_barrier_func;
|
||||
hexagon::compute_func_type compute_funcs[NPU_DATA_TYPE_COUNT];
|
||||
};
|
||||
|
||||
constexpr const op_capabilities kOpCapabilities[] = {
|
||||
{
|
||||
NPU_OP_MUL_MAT, hexagon::is_mul_mat_supported,
|
||||
NPU_OP_MUL_MAT, hexagon::is_mul_mat_supported,
|
||||
hexagon::is_mul_mat_required_sync,
|
||||
{
|
||||
hexagon::mul_mat_f32, // NPU_DATA_TYPE_F32
|
||||
nullptr, // NPU_DATA_TYPE_F16
|
||||
}, true, // requires_thread_barrier
|
||||
},
|
||||
}, },
|
||||
{
|
||||
NPU_OP_ADD, is_element_wise_op_supported,
|
||||
{
|
||||
NPU_OP_ADD, is_element_wise_op_supported,
|
||||
is_element_wise_op_required_sync, {
|
||||
element_wise_op<vec_op_f32_f32<vadd_f32_f32>>, // NPU_DATA_TYPE_F32
|
||||
element_wise_op<vec_op_f16_f16<vadd_f16_f16>>, // NPU_DATA_TYPE_F16
|
||||
}, false,
|
||||
},
|
||||
}, },
|
||||
{
|
||||
NPU_OP_SUB, is_element_wise_op_supported,
|
||||
{
|
||||
is_element_wise_op_required_sync, {
|
||||
element_wise_op<vec_op_f32_f32<vsub_f32_f32>>, // NPU_DATA_TYPE_F32
|
||||
element_wise_op<vec_op_f16_f16<vsub_f16_f16>>, // NPU_DATA_TYPE_F16
|
||||
}, false,
|
||||
},
|
||||
}, },
|
||||
{
|
||||
NPU_OP_MUL, is_element_wise_op_supported,
|
||||
{
|
||||
NPU_OP_MUL, is_element_wise_op_supported,
|
||||
is_element_wise_op_required_sync, {
|
||||
element_wise_op<vec_op_f32_f32<vmul_f32_f32>>, // NPU_DATA_TYPE_F32
|
||||
element_wise_op<vec_op_f16_f16<vmul_f16_f16>>, // NPU_DATA_TYPE_F16
|
||||
}, false,
|
||||
},
|
||||
}, },
|
||||
{
|
||||
NPU_OP_RMS_NORM, is_unary_op_supported,
|
||||
{
|
||||
NPU_OP_RMS_NORM, is_unary_op_supported,
|
||||
is_unary_op_required_sync, {
|
||||
unary_op<rms_norm_vec_f32>, // NPU_DATA_TYPE_F32
|
||||
nullptr, // NPU_DATA_TYPE_F16
|
||||
}, false,
|
||||
},
|
||||
}, },
|
||||
{
|
||||
NPU_OP_FLASH_ATTN, hexagon::is_flash_attn_supported,
|
||||
hexagon::is_flash_attn_required_sync,
|
||||
{
|
||||
hexagon::flash_attn_f32, // NPU_DATA_TYPE_F32
|
||||
nullptr, // NPU_DATA_TYPE_F16
|
||||
}, true, // requires_thread_barrier
|
||||
},
|
||||
}, },
|
||||
{
|
||||
NPU_OP_ROPE, hexagon::is_rope_supported,
|
||||
NPU_OP_ROPE, hexagon::is_rope_supported,
|
||||
hexagon::is_rope_required_sync,
|
||||
{
|
||||
hexagon::rope_f32, // NPU_DATA_TYPE_F32
|
||||
nullptr, // NPU_DATA_TYPE_F16
|
||||
}, false,
|
||||
},
|
||||
}, },
|
||||
{
|
||||
NPU_OP_GLU, hexagon::is_glu_op_supported,
|
||||
NPU_OP_GLU, hexagon::is_glu_op_supported,
|
||||
hexagon::is_glu_required_sync,
|
||||
{
|
||||
hexagon::glu_f32, // NPU_DATA_TYPE_F32
|
||||
hexagon::glu_f16, // NPU_DATA_TYPE_F16
|
||||
}, true, // TODO: should we avoid using thread barrier?
|
||||
},
|
||||
}, },
|
||||
};
|
||||
|
||||
static_assert(kOpCapabilities[NPU_OP_MUL_MAT].compute_funcs[NPU_DATA_TYPE_F32] == hexagon::mul_mat_f32,
|
||||
|
|
@ -417,12 +457,13 @@ compute_func_type get_compute_func(tensor * dst) {
|
|||
return get_compute_func_impl(dst->get_op(), dst->get_type());
|
||||
}
|
||||
|
||||
bool requires_thread_barrier(npu_device_tensor_op op) {
|
||||
bool requires_thread_barrier(npu_device_tensor_op op, npu_device_tensor_op next_op) {
|
||||
if (op >= NPU_OP_COUNT) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return kOpCapabilities[op].requires_thread_barrier;
|
||||
auto requires_thread_barrier_func = kOpCapabilities[op].requires_thread_barrier_func;
|
||||
return requires_thread_barrier_func && requires_thread_barrier_func(op, next_op);
|
||||
}
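Each kOpCapabilities entry now carries a requires_thread_barrier_func predicate instead of a fixed boolean, and requires_thread_barrier simply forwards the (op, next_op) pair to it. A self-contained sketch of that table-driven lookup, with hypothetical names (sample_op, kCaps, requires_barrier are illustrative):

#include <cstddef>

enum sample_op { OP_ADD, OP_MUL_MAT, OP_COUNT };

using required_sync_fn = bool (*)(sample_op op, sample_op next_op);

struct sample_capabilities {
    sample_op        op;
    required_sync_fn requires_sync;  // nullptr means "never needs a barrier"
};

static bool add_requires_sync(sample_op, sample_op next_op) {
    return next_op == OP_MUL_MAT;  // barrier only before a consumer that re-slices rows
}

constexpr sample_capabilities kCaps[] = {
    { OP_ADD,     add_requires_sync },
    { OP_MUL_MAT, nullptr           },
};

bool requires_barrier(sample_op op, sample_op next_op) {
    if (op >= OP_COUNT) {
        return false;
    }
    const auto fn = kCaps[op].requires_sync;
    return fn && fn(op, next_op);
}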
|
||||
|
||||
bool support_op(const npu_device_tensor_op_spec * op_spec,
|
||||
|
|
@ -442,8 +483,8 @@ bool support_op(const npu_device_tensor_op_spec * op_spec,
|
|||
}
|
||||
|
||||
if (get_compute_func_impl(op, dst->type) == nullptr) {
|
||||
DEVICE_LOG_DEBUG(
|
||||
"[%s]unsupported, get_compute_func failed, type: %s\n", op_get_name(op), get_type_name(dst->type));
|
||||
DEVICE_LOG_DEBUG("[%s]unsupported, get_compute_func failed, type: %s\n", op_get_name(op),
|
||||
get_type_name(dst->type));
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -6,7 +6,7 @@ namespace hexagon {
|
|||
|
||||
compute_func_type get_compute_func(tensor * dst);
|
||||
|
||||
bool requires_thread_barrier(npu_device_tensor_op op);
|
||||
bool requires_thread_barrier(npu_device_tensor_op op, npu_device_tensor_op next_op);
|
||||
|
||||
bool support_op(const npu_device_tensor_op_spec * op_spec,
|
||||
const npu_device_tensor_spec * dst,
|
||||
|
|
@ -0,0 +1,735 @@
|
|||
#include "op_mul_mat.hpp"
|
||||
|
||||
#include "thread_pool.hpp" // TODO: remove this dependency
|
||||
#include "type_traits.hpp"
|
||||
#include "vec_ops.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename _T> struct get_data_type {};
|
||||
|
||||
template <typename _TData0, typename _TData1>
|
||||
struct get_data_type<HVX_Vector (*)(const _TData0 *, const _TData1 *, size_t)> {
|
||||
using data_type0 = _TData0;
|
||||
using data_type1 = _TData1;
|
||||
};
|
||||
|
||||
template <typename _TRet> struct convert_vector {};
|
||||
|
||||
template <> struct convert_vector<float> {
|
||||
static float convert(HVX_Vector vec) { return hexagon::get_flt0_from_fltv(Q6_Vsf_equals_Vqf32(vec)); }
|
||||
};
|
||||
|
||||
template <> struct convert_vector<npu_device_fp16_t> {
|
||||
static float convert(HVX_Vector vec) {
|
||||
HVX_Vector vect = Q6_Vhf_equals_Vqf16(vec);
|
||||
uint16_t i = (vect[0] & 0xffff);
|
||||
return reinterpret_cast<__fp16 &>(i);
|
||||
}
|
||||
};
|
||||
|
||||
template <auto _DotFunc>
|
||||
inline void batched_row_dot(const uint8_t * src0_plane,
|
||||
const size_t src0_ne0,
|
||||
const size_t src0_nb1,
|
||||
const uint8_t * src1_row,
|
||||
const size_t src1_nb1,
|
||||
float * dst_row,
|
||||
const size_t actual_row_count,
|
||||
const size_t src1_fetch_row_bytes) {
|
||||
using data_type0 = typename get_data_type<decltype(_DotFunc)>::data_type0;
|
||||
using data_type1 = typename get_data_type<decltype(_DotFunc)>::data_type1;
|
||||
|
||||
size_t i0 = 0;
|
||||
for (; i0 + 1 < actual_row_count; i0 += 2) {
|
||||
auto * src0_row = src0_plane + i0 * src0_nb1;
|
||||
|
||||
// TODO: figure out how to handle an entire row
|
||||
auto res0 = _DotFunc(reinterpret_cast<const data_type0 *>(src0_row),
|
||||
reinterpret_cast<const data_type1 *>(src1_row), src0_ne0);
|
||||
|
||||
// TODO: figure out how to handle an entire row
|
||||
auto res1 = _DotFunc(reinterpret_cast<const data_type0 *>(src0_row + src0_nb1),
|
||||
reinterpret_cast<const data_type1 *>(src1_row), src0_ne0);
|
||||
|
||||
{
|
||||
dst_row[i0] = convert_vector<data_type1>::convert(res0);
|
||||
dst_row[i0 + 1] = convert_vector<data_type1>::convert(res1);
|
||||
}
|
||||
}
|
||||
|
||||
if (src1_fetch_row_bytes > 0) {
|
||||
hexagon::l2fetch_row(src1_row + src1_nb1, src1_fetch_row_bytes);
|
||||
}
|
||||
|
||||
if (i0 < actual_row_count) {
|
||||
auto * src0_row = src0_plane + i0 * src0_nb1;
|
||||
auto res = _DotFunc(reinterpret_cast<const data_type0 *>(src0_row),
|
||||
reinterpret_cast<const data_type1 *>(src1_row), src0_ne0);
|
||||
dst_row[i0] = convert_vector<data_type1>::convert(res);
|
||||
}
|
||||
}
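batched_row_dot amortizes the cost of streaming src1 by reducing two src0 rows against the same cached src1 row per iteration (the real code also issues an l2fetch for the next src1 row between the pairwise loop and the odd tail). A scalar, self-contained stand-in for the same loop shape; rows_dot and its arguments are illustrative, not the HVX implementation.

#include <cstddef>

// rows: row_count rows of ne0 floats, stride_elems floats apart; dst: one float per row.
void rows_dot(const float * rows, size_t ne0, size_t stride_elems,
              const float * src1_row, float * dst, size_t row_count) {
    size_t i = 0;
    for (; i + 1 < row_count; i += 2) {  // two rows per iteration reuse the src1_row loads
        const float * r0   = rows + i * stride_elems;
        const float * r1   = r0 + stride_elems;
        float         acc0 = 0.0f;
        float         acc1 = 0.0f;
        for (size_t k = 0; k < ne0; ++k) {
            acc0 += r0[k] * src1_row[k];
            acc1 += r1[k] * src1_row[k];
        }
        dst[i]     = acc0;
        dst[i + 1] = acc1;
    }
    if (i < row_count) {  // odd tail row
        const float * r   = rows + i * stride_elems;
        float         acc = 0.0f;
        for (size_t k = 0; k < ne0; ++k) {
            acc += r[k] * src1_row[k];
        }
        dst[i] = acc;
    }
}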
|
||||
|
||||
template <auto _DotFunc, bool _IsSrcQuantized>
|
||||
inline void mul_mat_impl(hexagon::tensor * src0,
|
||||
hexagon::tensor * src1,
|
||||
hexagon::tensor * dst,
|
||||
hexagon::compute_params * params) {
|
||||
using data_type1 = typename get_data_type<decltype(_DotFunc)>::data_type1;
|
||||
|
||||
const auto src0_actual_row_size = hexagon::get_dequantized_row_size(src0);
|
||||
auto * dequantize_row_func = hexagon::get_type_traits(src0->get_type()).to_float;
|
||||
auto * load_dequant_table_func = hexagon::get_type_traits(src0->get_type()).load_dequant_table;
|
||||
if (_IsSrcQuantized && dequantize_row_func == nullptr) {
|
||||
DEVICE_LOG_ERROR("Unsupported quantized src0 type: %d, dequantize_row_func is null\n", src0->get_type());
|
||||
return;
|
||||
}
|
||||
|
||||
const auto r02 = src1->get_ne(2) / src0->get_ne(2);
|
||||
const auto r03 = src1->get_ne(3) / src0->get_ne(3);
|
||||
const auto total_planes = dst->get_ne(3) * dst->get_ne(2);
|
||||
|
||||
auto start_end_plane = std::pair<int64_t, int64_t>{ 0, total_planes };
|
||||
auto start_end_row = std::pair<int64_t, int64_t>{ 0, dst->get_ne(1) };
|
||||
auto start_end_element = std::pair<int64_t, int64_t>{ 0, dst->get_ne(0) };
|
||||
|
||||
if (total_planes >= params->get_thread_count()) {
|
||||
start_end_plane = params->get_work_slice(total_planes);
|
||||
} else if (dst->get_ne(0) >= params->get_thread_count()) {
|
||||
start_end_element = params->get_work_slice(dst->get_ne(0));
|
||||
} else {
|
||||
start_end_row = params->get_work_slice(dst->get_ne(1));
|
||||
}
|
||||
|
||||
if (start_end_plane.second <= start_end_plane.first || start_end_row.second <= start_end_row.first ||
|
||||
start_end_element.second <= start_end_element.first) {
|
||||
DEVICE_LOG_DEBUG(
|
||||
"mul_mat_impl: no work to do, start_end_plane: (%lld, %lld), start_end_row: (%lld, %lld), "
|
||||
"start_end_element: (%lld, %lld)\n",
|
||||
start_end_plane.first, start_end_plane.second, start_end_row.first, start_end_row.second,
|
||||
start_end_element.first, start_end_element.second);
|
||||
return;
|
||||
}
|
||||
|
||||
const uint8_t * src0_ptr = src0->get_read_buffer(true); // TODO: avoid invalidation
|
||||
|
||||
// cache the src0 plane in VTCM
|
||||
size_t src0_plane_slice_row_count = start_end_element.second - start_end_element.first;
|
||||
size_t src0_plane_cache_size = 0;
|
||||
uint8_t * src0_plane_read_cache_ptr = nullptr;
|
||||
uint8_t * src0_plane_write_cache_ptr = nullptr;
|
||||
const uint8_t * last_write_cached_plane_ptr = nullptr;
|
||||
const uint8_t * last_read_cached_plane_ptr = nullptr;
|
||||
if constexpr (_IsSrcQuantized) {
|
||||
src0_plane_slice_row_count =
|
||||
std::min(params->get_vtcm_quota_size() / src0_actual_row_size, src0_plane_slice_row_count);
|
||||
src0_plane_cache_size = src0_actual_row_size * src0_plane_slice_row_count;
|
||||
src0_plane_read_cache_ptr = params->get_vtcm_cache(src0_plane_cache_size);
|
||||
if (src0_plane_read_cache_ptr == nullptr) {
|
||||
DEVICE_LOG_ERROR(
|
||||
"mul_mat_impl: failed to get VTCM cache for src0, size: %zu, src0_plane_slice_row_count: %zu, "
|
||||
"src0_actual_row_size: %zu, will fallback to mem cache\n",
|
||||
src0_plane_cache_size, src0_plane_slice_row_count, src0_actual_row_size);
|
||||
return;
|
||||
}
|
||||
} else {
|
||||
src0_plane_slice_row_count =
|
||||
std::min(params->get_vtcm_quota_size() / (src0_actual_row_size * 2), src0_plane_slice_row_count);
|
||||
src0_plane_cache_size = src0_actual_row_size * src0_plane_slice_row_count;
|
||||
src0_plane_read_cache_ptr = params->get_vtcm_cache(src0_plane_cache_size * 2);
|
||||
if (src0_plane_read_cache_ptr == nullptr) {
|
||||
DEVICE_LOG_ERROR(
|
||||
"mul_mat_impl: failed to get VTCM cache for src0, size: %zu, src0_plane_slice_row_count: %zu, "
|
||||
"src0_actual_row_size: %zu, will fallback to mem cache\n",
|
||||
src0_plane_cache_size, src0_plane_slice_row_count, src0_actual_row_size);
|
||||
return;
|
||||
}
|
||||
|
||||
src0_plane_write_cache_ptr = src0_plane_read_cache_ptr + src0_plane_cache_size;
|
||||
|
||||
const auto i3 = start_end_plane.first / dst->get_ne(2);
|
||||
const auto i2 = start_end_plane.first - i3 * dst->get_ne(2);
|
||||
const uint8_t * src0_plane = src0_ptr + i3 / r03 * src0->get_nb(3) + i2 / r02 * src0->get_nb(2) +
|
||||
start_end_element.first * src0->get_nb(1);
|
||||
const int64_t next_row_count =
|
||||
std::min<int64_t>(src0_plane_slice_row_count,
|
||||
start_end_element.second - start_end_element.first); // number of rows in this slice
|
||||
if (!params->initiate_dma_plane_transfer(src0_plane, src0_plane_write_cache_ptr,
|
||||
src0_actual_row_size, // TODO: reduce to aligned valid_row0_bytes?
|
||||
next_row_count, src0_actual_row_size, src0_actual_row_size)) {
|
||||
DEVICE_LOG_ERROR("mul_mat_impl: failed to continue dma transfer for src0 plane\n");
|
||||
return;
|
||||
}
|
||||
|
||||
last_write_cached_plane_ptr = src0_plane;
|
||||
}
|
||||
|
||||
DEVICE_LOG_DEBUG(
|
||||
"[%d]mul_mat_impl src0_actual_row_size: %zu, src0_plane_slice_row_count: %zu, total_planes: %lld, "
|
||||
"start_end_plane: "
|
||||
"[%d,%d), start_end_row: [%d,%d), start_end_element: [%d,%d), is_quantized: %d, vtcm_mem: %p(%zu)\n",
|
||||
(int) params->get_thread_index(), src0_actual_row_size, src0_plane_slice_row_count, total_planes,
|
||||
(int) start_end_plane.first, (int) start_end_plane.second, (int) start_end_row.first,
|
||||
(int) start_end_row.second, (int) start_end_element.first, (int) start_end_element.second, _IsSrcQuantized,
|
||||
(void *) src0_plane_read_cache_ptr, params->get_vtcm_quota_size());
|
||||
|
||||
const size_t valid_row1_bytes =
|
||||
src0->get_ne(0) * sizeof(data_type1); // src0 and src1 should have the same element count in the 1st dimension
|
||||
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_WITH_MULTI_SUB_PROC(dst, params->get_thread_index(), mul_mat);
|
||||
|
||||
uint8_t * dst_ptr = dst->get_write_buffer();
|
||||
if (!dst_ptr) {
|
||||
DEVICE_LOG_ERROR("[%d]mul_mat_impl: dst_ptr is not writable, tensor: %p, type: %s\n",
|
||||
(int) params->get_thread_index(), (void *) dst, hexagon::get_type_name(dst->get_type()));
|
||||
return;
|
||||
}
|
||||
|
||||
auto dequant_table = load_dequant_table_func ? load_dequant_table_func() : HVX_Vector();
|
||||
const uint8_t * src1_ptr = src1->get_read_buffer();
|
||||
for (int64_t ip = start_end_plane.first; ip < start_end_plane.second; ip++) {
|
||||
const auto i3 = ip / dst->get_ne(2);
|
||||
const auto i2 = ip - i3 * dst->get_ne(2);
|
||||
const auto * src1_plane = src1_ptr + i3 * src1->get_nb(3) + i2 * src1->get_nb(2);
|
||||
auto * dst_plane = dst_ptr + i3 * dst->get_nb(3) + i2 * dst->get_nb(2);
|
||||
const uint8_t * src0_plane_base = src0_ptr + i3 / r03 * src0->get_nb(3) + i2 / r02 * src0->get_nb(2);
|
||||
for (int64_t col_idx = start_end_element.first; col_idx < start_end_element.second;
|
||||
col_idx += src0_plane_slice_row_count) {
|
||||
const uint8_t * src0_plane = src0_plane_base + col_idx * src0->get_nb(1);
|
||||
const int64_t actual_row_count =
|
||||
std::min<int64_t>(src0_plane_slice_row_count,
|
||||
start_end_element.second - col_idx); // number of rows in this slice
|
||||
if constexpr (_IsSrcQuantized) {
|
||||
if (last_write_cached_plane_ptr != src0_plane) {
|
||||
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 0, dequant);
|
||||
|
||||
for (int64_t ir = 0; ir < actual_row_count; ir++) {
|
||||
auto * src0_row = src0_plane + ir * src0->get_nb(1);
|
||||
if (ir + 1 < actual_row_count) {
|
||||
hexagon::l2fetch_row(src0_row + src0->get_nb(1), src0->get_nb(1));
|
||||
}
|
||||
|
||||
auto * cached_row_ptr = src0_plane_read_cache_ptr + ir * src0_actual_row_size;
|
||||
dequantize_row_func(src0_row, reinterpret_cast<hexagon::dequant_output_type *>(cached_row_ptr),
|
||||
src0->get_ne(0), dequant_table);
|
||||
}
|
||||
|
||||
last_write_cached_plane_ptr = src0_plane;
|
||||
}
|
||||
} else {
|
||||
if (last_read_cached_plane_ptr != src0_plane) {
|
||||
std::swap(src0_plane_read_cache_ptr, src0_plane_write_cache_ptr);
|
||||
last_read_cached_plane_ptr = src0_plane;
|
||||
params->wait_for_dma();
|
||||
}
|
||||
|
||||
const uint8_t * src0_next_plane = last_write_cached_plane_ptr;
|
||||
int64_t next_row_count = 0;
|
||||
if (col_idx + src0_plane_slice_row_count < start_end_element.second) {
|
||||
const auto next_col_idx = col_idx + src0_plane_slice_row_count;
|
||||
src0_next_plane = src0_plane_base + next_col_idx * src0_actual_row_size;
|
||||
next_row_count =
|
||||
std::min<int64_t>(src0_plane_slice_row_count,
|
||||
start_end_element.second - next_col_idx); // number of rows in this slice
|
||||
} else if (ip + 1 < start_end_plane.second) {
|
||||
// prefetch the next plane's first slice
|
||||
const auto ip_next = ip + 1;
|
||||
const auto i3_next = ip_next / dst->get_ne(2);
|
||||
const auto i2_next = ip_next - i3_next * dst->get_ne(2);
|
||||
const uint8_t * src0_next_plane_base =
|
||||
src0_ptr + i3_next / r03 * src0->get_nb(3) + i2_next / r02 * src0->get_nb(2);
|
||||
src0_next_plane = src0_next_plane_base + start_end_element.first * src0_actual_row_size;
|
||||
next_row_count = std::min<int64_t>(
|
||||
src0_plane_slice_row_count,
|
||||
start_end_element.second - start_end_element.first); // number of rows in this slice
|
||||
}
|
||||
|
||||
if (last_write_cached_plane_ptr != src0_next_plane) {
|
||||
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 0, dma);
|
||||
if (!params->initiate_dma_plane_transfer(
|
||||
src0_next_plane, src0_plane_write_cache_ptr,
|
||||
src0_actual_row_size, // TODO: reduce to aligned valid_row0_bytes?
|
||||
next_row_count, src0_actual_row_size, src0_actual_row_size)) {
|
||||
DEVICE_LOG_ERROR("mul_mat_impl: failed to continue dma transfer for src0 plane\n");
|
||||
return;
|
||||
}
|
||||
|
||||
last_write_cached_plane_ptr = src0_next_plane;
|
||||
}
|
||||
}
|
||||
|
||||
if (start_end_row.second > start_end_row.first) {
|
||||
hexagon::l2fetch_row(src1_plane + start_end_row.first * src1->get_nb(1), valid_row1_bytes);
|
||||
}
|
||||
|
||||
for (int64_t i1 = start_end_row.first; i1 < start_end_row.second; i1++) {
|
||||
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 1, vec_dot);
|
||||
auto * src1_row = src1_plane + i1 * src1->get_nb(1);
|
||||
auto * dst_row = reinterpret_cast<float *>(dst_plane + i1 * dst->get_nb(1)) + col_idx;
|
||||
batched_row_dot<_DotFunc>(src0_plane_read_cache_ptr, src0->get_ne(0), src0_actual_row_size, src1_row,
|
||||
src1->get_nb(1), dst_row, actual_row_count,
|
||||
(ip + 1 < start_end_plane.second) ? valid_row1_bytes : 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
dst->release_write_buffer(); // mark the output tensor as modified
|
||||
}
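The slice sizing above divides the VTCM quota by the dequantized row size, and by twice that size on the non-quantized path, which keeps a second write slice in flight for the DMA engine. A small worked sketch of that arithmetic with illustrative numbers; rows_per_slice is a hypothetical helper, not part of this commit.

#include <algorithm>
#include <cstddef>

// Example: a 256 KiB quota with 4 KiB dequantized rows gives min(64, total_rows) rows
// per slice on the quantized path, and min(32, total_rows) when double-buffered.
size_t rows_per_slice(size_t vtcm_quota, size_t row_bytes, size_t total_rows, bool double_buffered) {
    const size_t bytes_per_row = double_buffered ? row_bytes * 2 : row_bytes;
    return std::min(vtcm_quota / bytes_per_row, total_rows);
}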
|
||||
|
||||
template <auto _DotFunc, bool _IsSrcQuantized>
|
||||
inline void mul_mat_gemv_impl(hexagon::tensor * src0,
|
||||
hexagon::tensor * src1,
|
||||
hexagon::tensor * dst,
|
||||
hexagon::compute_params * params) {
|
||||
using data_type0 = typename get_data_type<decltype(_DotFunc)>::data_type0;
|
||||
using data_type1 = typename get_data_type<decltype(_DotFunc)>::data_type1;
|
||||
|
||||
const auto src0_actual_row_size = hexagon::get_dequantized_row_size(src0);
|
||||
auto * dequantize_row_func = hexagon::get_type_traits(src0->get_type()).to_float;
|
||||
auto * load_dequant_table_func = hexagon::get_type_traits(src0->get_type()).load_dequant_table;
|
||||
if (_IsSrcQuantized && dequantize_row_func == nullptr) {
|
||||
DEVICE_LOG_ERROR("Unsupported quantized src0 type: %d, dequantize_row_func is null\n", src0->get_type());
|
||||
return;
|
||||
}
|
||||
|
||||
auto start_end_element = std::pair<int64_t, int64_t>{ 0, dst->get_ne(0) };
|
||||
if (dst->get_ne(0) >= params->get_thread_count()) {
|
||||
start_end_element = params->get_work_slice(dst->get_ne(0));
|
||||
} else {
|
||||
DEVICE_LOG_ERROR("Unsupported src1 tensor shape for gemv: %s, ne: %lldx%lldx%lldx%lld\n",
|
||||
hexagon::get_type_name(src1->get_type()), src1->get_ne(0), src1->get_ne(1), src1->get_ne(2),
|
||||
src1->get_ne(3));
|
||||
return;
|
||||
}
|
||||
|
||||
if (start_end_element.second <= start_end_element.first) {
|
||||
DEVICE_LOG_DEBUG(
|
||||
"mul_mat_gemv_impl: no work to do, start_end_plane: [0, 1), start_end_row: [0, 1), "
|
||||
"start_end_element: [%lld, %lld)\n",
|
||||
start_end_element.first, start_end_element.second);
|
||||
return;
|
||||
}
|
||||
|
||||
const uint8_t * src0_ptr = src0->get_read_buffer(true); // TODO: avoid invalidation
|
||||
const size_t valid_row0_bytes = src0->get_ne(0) * sizeof(data_type0);
|
||||
|
||||
// cache the src0 plane in VTCM
|
||||
size_t src0_plane_slice_row_count = start_end_element.second - start_end_element.first;
|
||||
size_t src0_plane_cache_size = 0;
|
||||
uint8_t * src0_plane_read_cache_ptr = nullptr;
|
||||
uint8_t * src0_plane_write_cache_ptr = nullptr;
|
||||
const auto src1_actual_row_size = hexagon::get_aligned_size(src1->get_nb(1));
|
||||
uint8_t * src1_row_cache_ptr = nullptr;
|
||||
    if constexpr (_IsSrcQuantized) {
        src0_plane_slice_row_count = std::min(
            (params->get_vtcm_quota_size() - src1_actual_row_size) / src0_actual_row_size, src0_plane_slice_row_count);
        src0_plane_cache_size = src0_actual_row_size * src0_plane_slice_row_count;
        src0_plane_read_cache_ptr = params->get_vtcm_cache(src0_plane_cache_size + src1_actual_row_size);
        if (src0_plane_read_cache_ptr == nullptr) {
            DEVICE_LOG_ERROR(
                "mul_mat_gemv_impl: failed to get VTCM cache for src0, size: %zu, src0_plane_slice_row_count: %zu, "
                "src0_actual_row_size: %zu\n",
                src0_plane_cache_size, src0_plane_slice_row_count, src0_actual_row_size);
            return;
        }

        src1_row_cache_ptr = src0_plane_read_cache_ptr + src0_plane_cache_size;
    } else {
        src0_plane_slice_row_count =
            std::min((params->get_vtcm_quota_size() - src1_actual_row_size) / (src0_actual_row_size * 2),
                     src0_plane_slice_row_count);
        src0_plane_cache_size = src0_actual_row_size * src0_plane_slice_row_count;
        src0_plane_read_cache_ptr = params->get_vtcm_cache(src0_plane_cache_size * 2 + src1_actual_row_size);
        if (src0_plane_read_cache_ptr == nullptr) {
            DEVICE_LOG_ERROR("mul_mat_gemv_impl: failed to get VTCM cache for src0/src1, size: %zu\n",
                             src0_plane_cache_size * 2 + src1_actual_row_size);
            return;
        }

        src0_plane_write_cache_ptr = src0_plane_read_cache_ptr + src0_plane_cache_size;
        src1_row_cache_ptr = src0_plane_write_cache_ptr + src0_plane_cache_size;
    }
|
||||
|
||||
DEVICE_LOG_DEBUG(
|
||||
"mul_mat_gemv_impl: src0_actual_row_size: %zu, src0_plane_slice_row_count: %zu, is_quantized: %d, vtcm_mem: "
|
||||
"%p(%zu)\n",
|
||||
src0_actual_row_size, src0_plane_slice_row_count, _IsSrcQuantized, (void *) src0_plane_read_cache_ptr,
|
||||
src0_plane_cache_size);
|
||||
|
||||
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_WITH_MULTI_SUB_PROC(dst, params->get_thread_index(), mul_mat);
|
||||
|
||||
uint8_t * dst_ptr = dst->get_write_buffer();
|
||||
if (!dst_ptr) {
|
||||
DEVICE_LOG_ERROR("mul_mat_gemv_impl: dst_ptr is not writable, tensor: %p, type: %s\n", (void *) dst,
|
||||
hexagon::get_type_name(dst->get_type()));
|
||||
return;
|
||||
}
|
||||
|
||||
auto dequant_table = load_dequant_table_func ? load_dequant_table_func() : HVX_Vector();
|
||||
const uint8_t * src1_ptr = src1->get_read_buffer();
|
||||
|
||||
{
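        // Queue the src1 row transfer into VTCM first; every src0 slice is dotted against this row.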
|
||||
if (!params->initiate_dma_row_transfer(src1_ptr, src1_row_cache_ptr, src1->get_ne(0) * sizeof(data_type1))) {
|
||||
DEVICE_LOG_ERROR("mul_mat_gemv_impl: failed to initiate dma transfer for src1\n");
|
||||
return;
|
||||
}
|
||||
|
||||
if constexpr (!_IsSrcQuantized) {
|
||||
const uint8_t * src0_plane = src0_ptr + start_end_element.first * src0_actual_row_size;
|
||||
const int64_t next_row_count =
|
||||
std::min<int64_t>(src0_plane_slice_row_count,
|
||||
start_end_element.second - start_end_element.first); // number of rows in this slice
|
||||
params->wait_for_dma();
|
||||
if (!params->initiate_dma_plane_transfer(src0_plane, src0_plane_write_cache_ptr, valid_row0_bytes,
|
||||
next_row_count, src0_actual_row_size, src0_actual_row_size)) {
|
||||
DEVICE_LOG_ERROR("mul_mat_gemv_impl: failed to initiate dma transfer for src0 plane\n");
|
||||
return;
|
||||
}
|
||||
} else {
|
||||
params->wait_for_dma();
|
||||
}
|
||||
}
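    // Walk the assigned columns slice by slice. For non-quantized src0, the plane needed by the next
    // iteration is prefetched via DMA while the current read buffer is being consumed.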
|
||||
|
||||
{
|
||||
for (int64_t col_idx = start_end_element.first; col_idx < start_end_element.second;
|
||||
col_idx += src0_plane_slice_row_count) {
|
||||
const uint8_t * src0_plane = src0_ptr + col_idx * src0->get_nb(1);
|
||||
const int64_t actual_row_count =
|
||||
std::min<int64_t>(src0_plane_slice_row_count,
|
||||
start_end_element.second - col_idx); // number of rows in this slice
|
||||
if constexpr (_IsSrcQuantized) {
|
||||
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 0, dequant);
|
||||
|
||||
for (int64_t ir = 0; ir < actual_row_count; ir++) {
|
||||
auto * src0_row = src0_plane + ir * src0->get_nb(1);
|
||||
if (ir + 1 < actual_row_count) {
|
||||
hexagon::l2fetch_row(src0_row + src0->get_nb(1), src0->get_nb(1));
|
||||
}
|
||||
|
||||
auto * cached_row_ptr = src0_plane_read_cache_ptr + ir * src0_actual_row_size;
|
||||
dequantize_row_func(src0_row, reinterpret_cast<hexagon::dequant_output_type *>(cached_row_ptr),
|
||||
src0->get_ne(0), dequant_table);
|
||||
}
|
||||
} else {
|
||||
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 0, dma);
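                // Flip the ping-pong buffers: the plane whose DMA was queued last iteration becomes the
                // read buffer, then wait for that transfer to complete before it is consumed.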
|
||||
std::swap(src0_plane_read_cache_ptr, src0_plane_write_cache_ptr);
|
||||
params->wait_for_dma();
|
||||
|
||||
const auto next_col_idx = col_idx + src0_plane_slice_row_count;
|
||||
if (next_col_idx < start_end_element.second) {
|
||||
const uint8_t * src0_next_plane = src0_ptr + next_col_idx * src0_actual_row_size;
|
||||
const int64_t next_row_count =
|
||||
std::min<int64_t>(src0_plane_slice_row_count,
|
||||
start_end_element.second - next_col_idx); // number of rows in this slice
|
||||
if (!params->initiate_dma_plane_transfer(src0_next_plane, src0_plane_write_cache_ptr,
|
||||
valid_row0_bytes, next_row_count, src0_actual_row_size,
|
||||
src0_actual_row_size)) {
|
||||
DEVICE_LOG_ERROR("mul_mat_gemv_impl: failed to continue dma transfer for src0 plane\n");
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 1, vec_dot);
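                // Dot every cached src0 row against the cached src1 row in one pass and store the
                // scalar results into the corresponding segment of the dst row.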
|
||||
auto * dst_row = reinterpret_cast<float *>(dst_ptr) + col_idx;
|
||||
batched_row_dot<_DotFunc>(src0_plane_read_cache_ptr, src0->get_ne(0), src0_actual_row_size,
|
||||
src1_row_cache_ptr, src1->get_nb(1), dst_row, actual_row_count, 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
dst->release_write_buffer(); // mark the output tensor as modified
|
||||
}
|
||||
|
||||
bool is_src_cacheable(const npu_device_tensor_spec & src0, const npu_device_tensor_spec & src1) {
|
||||
const auto & src0_type_traits = hexagon::get_type_traits(src0.type);
|
||||
if (src0_type_traits.to_float == nullptr) {
|
||||
DEVICE_LOG_DEBUG("[MUL_MAT]src0.type(%s) cannot be cached, to_float is null\n",
|
||||
hexagon::get_type_name(src0.type));
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto vtcm_thread_quota_size = hexagon::default_thread_pool::get_per_thread_vtcm_quota();
|
||||
const size_t src0_type_size =
|
||||
src0_type_traits.is_quantized ? sizeof(hexagon::dequant_output_type) : src0_type_traits.type_size;
|
||||
const auto & src1_type_traits = hexagon::get_type_traits(src1.type);
|
||||
const bool is_gemv = src1.ne[1] == 1 && src1.ne[2] == 1 && src1.ne[3] == 1;
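    // Minimum VTCM footprint: one (dequantized) src0 row, plus one cached src1 row on the gemv path.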
|
||||
size_t min_cache_size = is_gemv ? (src1.ne[0] * src1_type_traits.type_size) : 0;
|
||||
min_cache_size += src0.ne[0] * src0_type_size;
|
||||
if (min_cache_size > vtcm_thread_quota_size) {
|
||||
DEVICE_LOG_DEBUG("[MUL_MAT]src0.type(%s) min_cache_size is too large: %ld, vtcm_thread_quota_size: %zu\n",
|
||||
hexagon::get_type_name(src0.type), (long) min_cache_size, vtcm_thread_quota_size);
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool is_quantized_mul_mat_supported(const npu_device_tensor_spec & src0, const npu_device_tensor_spec & src1) {
|
||||
    if (src1.type != NPU_DATA_TYPE_F32 && src1.type != NPU_DATA_TYPE_F16) {
        DEVICE_LOG_DEBUG("[MUL_MAT]src0.type(%s) and src1.type(%s) mismatch and src1 is not F32/F16\n",
                         hexagon::get_type_name(src0.type), hexagon::get_type_name(src1.type));
        return false;
    }
|
||||
|
||||
const auto type_traits = hexagon::get_type_traits(src0.type);
|
||||
    if (!type_traits.is_quantized || type_traits.to_float == nullptr) {
        DEVICE_LOG_DEBUG("[MUL_MAT]src0.type(%s) and src1.type(%s) mismatch and src0 is not a dequantizable quantized type\n",
                         hexagon::get_type_name(src0.type), hexagon::get_type_name(src1.type));
        return false;
    }
|
||||
|
||||
if (src0.ne[0] % type_traits.blck_size) {
|
||||
DEVICE_LOG_DEBUG("[MUL_MAT]src0.type(%s) ne[0] is not aligned: %ld\n", hexagon::get_type_name(src0.type),
|
||||
(long) src0.ne[0]);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!is_src_cacheable(src0, src1)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
DEVICE_LOG_DEBUG("[MUL_MAT]supported quantized src0.type(%s) and src1.type(%s)\n",
|
||||
hexagon::get_type_name(src0.type), hexagon::get_type_name(src1.type));
|
||||
return true;
|
||||
}
|
||||
|
||||
bool is_mul_mat_f16_f32_src_tensors_aligned(hexagon::tensor * src0,
|
||||
hexagon::tensor * src1,
|
||||
bool is_src0_cached,
|
||||
bool is_src1_cached) {
|
||||
const auto * src1_ptr = is_src1_cached ? nullptr : src1->get_read_buffer_as<float>();
|
||||
const auto * src0_ptr = is_src0_cached ? nullptr : src0->get_read_buffer_as<npu_device_fp16_t>();
|
||||
|
||||
if (!hexagon::is_f16_f32_dot_product_aligned(src0_ptr, src1_ptr, src0->get_ne(0))) {
|
||||
DEVICE_LOG_DEBUG("[MUL_MAT][f16_f32]src_tensors_unaligned: ne[0]: %ld\n", (long) src0->get_ne(0));
|
||||
return false;
|
||||
}
|
||||
|
||||
DEVICE_LOG_DEBUG("[MUL_MAT][f16_f32]src_tensors_aligned: ne[0]: %ld\n", (long) src0->get_ne(0));
|
||||
return true;
|
||||
}
|
||||
|
||||
bool is_mul_mat_f16_f16_src_tensors_aligned(hexagon::tensor * src0, hexagon::tensor * src1, bool is_src0_quantized) {
|
||||
const auto * src1_ptr = src1->get_read_buffer_as<npu_device_fp16_t>();
|
||||
const auto * src0_ptr = is_src0_quantized ? nullptr : src0->get_read_buffer_as<npu_device_fp16_t>();
|
||||
|
||||
if (!hexagon::is_f16_f16_dot_product_aligned(src0_ptr, src1_ptr, src0->get_ne(0))) {
|
||||
DEVICE_LOG_DEBUG("[MUL_MAT][f16_f16]src_tensors_unaligned: ne[0]: %ld\n", (long) src0->get_ne(0));
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!is_src0_quantized && !hexagon::is_size_aligned(src0->get_nb(1))) {
|
||||
DEVICE_LOG_DEBUG("[MUL_MAT]src0 tensor nb[1] is not aligned: %zu\n", src0->get_nb(1));
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!hexagon::is_size_aligned(src1->get_nb(1))) {
|
||||
DEVICE_LOG_DEBUG("[MUL_MAT]src1 tensor nb[1] is not aligned: %zu\n", src1->get_nb(1));
|
||||
return false;
|
||||
}
|
||||
|
||||
DEVICE_LOG_DEBUG("[MUL_MAT]src_tensors_aligned: ne[0]: %ld\n", (long) src0->get_ne(0));
|
||||
return true;
|
||||
}
|
||||
|
||||
bool is_mul_mat_f32_f32_src_tensors_aligned(hexagon::tensor * src0, hexagon::tensor * src1) {
|
||||
const auto * src1_ptr = src1->get_read_buffer_as<float>();
|
||||
const auto * src0_ptr = src0->get_read_buffer_as<float>();
|
||||
|
||||
if (!hexagon::is_f32_f32_dot_product_aligned(src0_ptr, src1_ptr, src0->get_ne(0))) {
|
||||
DEVICE_LOG_DEBUG("[MUL_MAT][f32_f32]src_tensors_unaligned: ne[0]: %ld\n", (long) src0->get_ne(0));
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!hexagon::is_size_aligned(src0->get_nb(1))) {
|
||||
DEVICE_LOG_DEBUG("[MUL_MAT]src0 tensor nb[1] is not aligned: %zu\n", src0->get_nb(1));
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!hexagon::is_size_aligned(src1->get_nb(1))) {
|
||||
DEVICE_LOG_DEBUG("[MUL_MAT]src1 tensor nb[1] is not aligned: %zu\n", src1->get_nb(1));
|
||||
return false;
|
||||
}
|
||||
|
||||
DEVICE_LOG_DEBUG("[MUL_MAT]src_tensors_aligned: ne[0]: %ld\n", (long) src0->get_ne(0));
|
||||
return true;
|
||||
}
|
||||
|
||||
typedef void (*mul_mat_func_type)(hexagon::tensor * src0,
|
||||
hexagon::tensor * src1,
|
||||
hexagon::tensor * dst,
|
||||
hexagon::compute_params * params);
|
||||
|
||||
constexpr const size_t kMulMatGemvBaseIndex = 2;
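// Each dispatch table below is indexed as: is_aligned (0 or 1) + (is_gemv ? kMulMatGemvBaseIndex : 0).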
|
||||
|
||||
constexpr const mul_mat_func_type kMulMatF32F32Funcs[4] = {
|
||||
// quantized and non-quantized
|
||||
mul_mat_impl<hexagon::vec_dot_product_vqf32_f32_f32, false>, // F32 * F32 unaligned
|
||||
mul_mat_impl<hexagon::vec_dot_product_aligned_vqf32_f32_f32, false>, // F32 * F32 aligned
|
||||
mul_mat_gemv_impl<hexagon::vec_dot_product_vqf32_f32_f32, false>, // F32 * F32 gemv
|
||||
mul_mat_gemv_impl<hexagon::vec_dot_product_aligned_vqf32_f32_f32, false>, // F32 * F32 gemv
|
||||
};
|
||||
|
||||
constexpr const mul_mat_func_type kMulMatF16F32QuantizedFuncs[4] = {
    // indexed by alignment flag plus gemv base index
    mul_mat_impl<hexagon::vec_dot_product_vqf32_f16_f32, true>,               // F16 * F32 quantized unaligned
    mul_mat_impl<hexagon::vec_dot_product_aligned_vqf32_f16_f32, true>,       // F16 * F32 quantized aligned
    mul_mat_gemv_impl<hexagon::vec_dot_product_vqf32_f16_f32, true>,          // F16 * F32 quantized gemv unaligned
    mul_mat_gemv_impl<hexagon::vec_dot_product_aligned_vqf32_f16_f32, true>,  // F16 * F32 quantized gemv aligned
};
|
||||
|
||||
constexpr const mul_mat_func_type kMulMatF16F32Funcs[4] = {
|
||||
// quantized and non-quantized
|
||||
mul_mat_impl<hexagon::vec_dot_product_vqf32_f16_f32, false>, // F16 * F32 unaligned
|
||||
mul_mat_impl<hexagon::vec_dot_product_aligned_vqf32_f16_f32, false>, // F16 * F32 aligned
|
||||
mul_mat_gemv_impl<hexagon::vec_dot_product_vqf32_f16_f32, false>, // F16 * F32 unaligned
|
||||
mul_mat_gemv_impl<hexagon::vec_dot_product_aligned_vqf32_f16_f32, false>, // F16 * F32 aligned
|
||||
};
|
||||
|
||||
constexpr const mul_mat_func_type kMulMatF16QuantizedFuncs[4] = {
|
||||
mul_mat_impl<hexagon::vec_dot_product_vqf16_f16_f16, true>, // F16 * F16 quantized unaligned
|
||||
mul_mat_impl<hexagon::vec_dot_product_aligned_vqf16_f16_f16, true>, // F16 * F16 quantized aligned
|
||||
mul_mat_gemv_impl<hexagon::vec_dot_product_aligned_vqf16_f16_f16, true>, // F16 * F16 quantized gemv
|
||||
mul_mat_gemv_impl<hexagon::vec_dot_product_aligned_vqf16_f16_f16, true>, // F16 * F16 quantized gemv
|
||||
};
|
||||
|
||||
constexpr const mul_mat_func_type kMulMatF16Funcs[4] = {
|
||||
mul_mat_impl<hexagon::vec_dot_product_vqf16_f16_f16, false>, // F16 * F16 unaligned
|
||||
mul_mat_impl<hexagon::vec_dot_product_aligned_vqf16_f16_f16, false>, // F16 * F16 aligned
|
||||
mul_mat_gemv_impl<hexagon::vec_dot_product_vqf16_f16_f16, false>, // F16 * F16 gemv
|
||||
mul_mat_gemv_impl<hexagon::vec_dot_product_aligned_vqf16_f16_f16, false>, // F16 * F16 gemv
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace hexagon {
|
||||
|
||||
bool mul_mat_f32(hexagon::tensor * out, compute_params * params) {
|
||||
static_assert(DEVICE_TENSOR_MAX_DIMS == 4, "mul_mat_f32 requires max dims 4");
|
||||
static_assert(std::is_same<hexagon::dequant_output_type, float>::value ||
|
||||
std::is_same<hexagon::dequant_output_type, npu_device_fp16_t>::value,
|
||||
"dequant_output_type must be float or npu_device_fp16_t");
|
||||
|
||||
if (!out) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto * src0 = out->get_src(0);
|
||||
auto * src1 = out->get_src(1);
|
||||
if (!src0 || !src1) {
|
||||
return true; // skip if no src
|
||||
}
|
||||
|
||||
const bool is_src0_quantized = is_quantized_type(src0->get_type());
|
||||
const bool is_gemv = src1->get_ne(1) == 1 && src1->get_ne(2) == 1 && src1->get_ne(3) == 1;
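    // A single src1 row turns the multiply into a matrix-vector product, dispatched to the
    // mul_mat_gemv_impl variants via kMulMatGemvBaseIndex.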
|
||||
const auto base_index = is_gemv ? kMulMatGemvBaseIndex : 0;
|
||||
switch (src1->get_type()) {
|
||||
case NPU_DATA_TYPE_F32:
|
||||
if (is_src0_quantized) {
|
||||
kMulMatF16F32QuantizedFuncs[is_mul_mat_f16_f32_src_tensors_aligned(src0, src1, true, is_gemv) +
|
||||
base_index](src0, src1, out, params);
|
||||
} else if (src0->get_type() == NPU_DATA_TYPE_F16) {
|
||||
kMulMatF16F32Funcs[is_mul_mat_f16_f32_src_tensors_aligned(src0, src1, true, is_gemv) + base_index](
|
||||
src0, src1, out, params);
|
||||
} else {
|
||||
kMulMatF32F32Funcs[is_mul_mat_f32_f32_src_tensors_aligned(src0, src1) + base_index](src0, src1, out,
|
||||
params);
|
||||
}
|
||||
return true;
|
||||
case NPU_DATA_TYPE_F16:
|
||||
if (is_src0_quantized) {
|
||||
kMulMatF16QuantizedFuncs[is_mul_mat_f16_f16_src_tensors_aligned(src0, src1, is_src0_quantized) +
|
||||
base_index](src0, src1, out, params);
|
||||
} else {
|
||||
kMulMatF16Funcs[is_mul_mat_f16_f16_src_tensors_aligned(src0, src1, is_src0_quantized) + base_index](
|
||||
src0, src1, out, params);
|
||||
}
|
||||
return true;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
DEVICE_LOG_ERROR("Unsupported src1 tensor type: %s\n", get_type_name(src1->get_type()));
|
||||
return false;
|
||||
}
|
||||
|
||||
bool is_mul_mat_supported(const npu_device_tensor_op_spec * op_spec,
|
||||
const npu_device_tensor_spec * dst,
|
||||
const npu_device_tensor_spec * srcs,
|
||||
size_t src_len) {
|
||||
const auto op = op_spec->op;
|
||||
if (op != NPU_OP_MUL_MAT) {
|
||||
DEVICE_LOG_DEBUG("op is not MUL_MAT: %d\n", op);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!dst || !srcs || src_len < 2) {
|
||||
DEVICE_LOG_DEBUG("[%s]invalid dst or srcs\n", hexagon::op_get_name(op));
|
||||
return false;
|
||||
}
|
||||
|
||||
if (dst->type != NPU_DATA_TYPE_F32) {
|
||||
DEVICE_LOG_DEBUG("[%s]dst type is not F32: %s\n", op_get_name(op), get_type_name(dst->type));
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto & src0 = srcs[0];
|
||||
const auto & src1 = srcs[1];
|
||||
if (src0.type != src1.type) {
|
||||
if (src1.type == NPU_DATA_TYPE_F32 && src0.type == NPU_DATA_TYPE_F16) {
|
||||
// F16 * F32 is supported
|
||||
DEVICE_LOG_DEBUG("[%s]src0.type(%s) and src1.type(%s) mismatch, but src0 is F16 and src1 is F32\n",
|
||||
op_get_name(op), get_type_name(src0.type), get_type_name(src1.type));
|
||||
} else {
|
||||
#ifdef GGML_HEXAGON_ENABLE_QUANTIZED_TENSORS
|
||||
if (!is_quantized_mul_mat_supported(src0, src1)) {
|
||||
return false;
|
||||
}
|
||||
#else
|
||||
DEVICE_LOG_DEBUG("[%s]src0.type(%s) and src1.type(%s) mismatch and quantized tensors are not supported\n",
|
||||
op_get_name(op), get_type_name(src0.type), get_type_name(src1.type));
|
||||
return false;
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
if (is_transposed_or_permuted(src0.nb)) {
|
||||
// TODO: fix permuted src0
|
||||
DEVICE_LOG_DEBUG("[%s]src0 is transposed or permuted, disabled\n", op_get_name(op));
|
||||
return false;
|
||||
}
|
||||
|
||||
if (src0.ne[0] != src1.ne[0] || src0.ne[1] != dst->ne[0]) {
|
||||
DEVICE_LOG_DEBUG("[%s]src0 and src1 cannot multiply: %ldx%ld vs %ldx%ld\n", op_get_name(op), (long) src0.ne[0],
|
||||
(long) src0.ne[1], (long) src1.ne[0], (long) src1.ne[1]);
|
||||
return false;
|
||||
}
|
||||
|
||||
    if (src1.ne[1] != dst->ne[1] || src1.ne[2] != dst->ne[2] || src1.ne[3] != dst->ne[3]) {
        DEVICE_LOG_DEBUG("[%s]src1 and dst dimensions do not match: %ldx%ld vs %ldx%ld\n", op_get_name(op),
                         (long) src1.ne[2], (long) src1.ne[3], (long) dst->ne[2], (long) dst->ne[3]);
        return false;
    }
|
||||
|
||||
if (src1.ne[2] % src0.ne[2] || src1.ne[3] % src0.ne[3]) {
|
||||
DEVICE_LOG_DEBUG("[%s]src0 cannot broadcast to src1: %ldx%ld vs %ldx%ld\n", op_get_name(op), (long) src0.ne[2],
|
||||
(long) src0.ne[3], (long) src1.ne[2], (long) src1.ne[3]);
|
||||
return false;
|
||||
}
|
||||
|
||||
    if (src1.ne[1] == 1 && src1.ne[2] == 1 && src1.ne[3] == 1 && dst->ne[0] < hexagon::kMaxThreadCount) {
        DEVICE_LOG_DEBUG("[%s]src1 is a single row (gemv) but dst->ne[0] is too small to parallelize: %ld\n",
                         op_get_name(op), (long) dst->ne[0]);
        return false;
    }
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool is_mul_mat_required_sync(const npu_device_tensor_op op, const npu_device_tensor_op next_op) {
|
||||
NPU_UNUSED(op);
|
||||
NPU_UNUSED(next_op);
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace hexagon
@ -12,5 +12,6 @@ bool is_mul_mat_supported(const npu_device_tensor_op_spec * op_spec,
|
|||
const npu_device_tensor_spec * dst,
|
||||
const npu_device_tensor_spec * srcs,
|
||||
size_t src_len);
|
||||
bool is_mul_mat_required_sync(const npu_device_tensor_op op, const npu_device_tensor_op next_op);
|
||||
|
||||
} // namespace hexagon
@ -417,4 +417,10 @@ bool is_rope_supported(const npu_device_tensor_op_spec * op_spec,
|
|||
return true; // ROPE operation is not supported yet
|
||||
}
|
||||
|
||||
bool is_rope_required_sync(const npu_device_tensor_op op, const npu_device_tensor_op next_op) {
|
||||
NPU_UNUSED(op);
|
||||
NPU_UNUSED(next_op);
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace hexagon
|
||||
|
|
@ -9,5 +9,6 @@ bool is_rope_supported(const npu_device_tensor_op_spec * op_spec,
|
|||
const npu_device_tensor_spec * dst,
|
||||
const npu_device_tensor_spec * srcs,
|
||||
size_t src_len);
|
||||
bool is_rope_required_sync(const npu_device_tensor_op op, const npu_device_tensor_op next_op);
|
||||
|
||||
} // namespace hexagon
|
||||
|
|
@ -53,6 +53,29 @@ struct compute_params {
|
|||
size_t get_thread_count() const { return thread_params->tcnt; }
|
||||
|
||||
size_t get_thread_index() const { return thread_params->tidx; }
|
||||
|
||||
bool initiate_dma_row_transfer(const uint8_t * src, uint8_t * dst, size_t size) {
|
||||
return thread_params->initiate_dma_row_transfer(src, dst, size);
|
||||
}
|
||||
|
||||
bool initiate_dma_row_transfer(const uint8_t * src0,
|
||||
uint8_t * dst0,
|
||||
const uint8_t * src1,
|
||||
uint8_t * dst1,
|
||||
size_t size) {
|
||||
return thread_params->initiate_dma_row_transfer(src0, dst0, src1, dst1, size);
|
||||
}
|
||||
|
||||
bool initiate_dma_plane_transfer(const uint8_t * src,
|
||||
uint8_t * dst,
|
||||
size_t width,
|
||||
size_t height,
|
||||
size_t src_stride,
|
||||
size_t dst_stride) {
|
||||
return thread_params->initiate_dma_plane_transfer(src, dst, width, height, src_stride, dst_stride);
|
||||
}
|
||||
|
||||
void wait_for_dma() { thread_params->wait_for_dma(); }
|
||||
};
|
||||
|
||||
typedef bool (*compute_func_type)(tensor * dst, compute_params * params);
|
||||
|
|
@ -60,5 +83,6 @@ typedef bool (*op_is_supported_func_type)(const npu_device_tensor_op_spec * op_s
|
|||
const npu_device_tensor_spec * dst,
|
||||
const npu_device_tensor_spec * srcs,
|
||||
size_t src_len);
|
||||
typedef bool (*op_required_sync_func_type)(const npu_device_tensor_op op, const npu_device_tensor_op next_op);
|
||||
|
||||
} // namespace hexagon
|
||||
|
|
@ -1,684 +0,0 @@
|
|||
#include "op_mul_mat.hpp"
|
||||
|
||||
#include "thread_pool.hpp" // TODO: remove this dependency
|
||||
#include "type_traits.hpp"
|
||||
#include "vec_ops.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename _T> struct get_data_type {};
|
||||
|
||||
template <typename _TData0, typename _TData1>
|
||||
struct get_data_type<HVX_Vector (*)(const _TData0 *, const _TData1 *, size_t)> {
|
||||
using data_type0 = _TData0;
|
||||
using data_type1 = _TData1;
|
||||
};
|
||||
|
||||
template <typename _TRet> struct convert_vector {};
|
||||
|
||||
template <> struct convert_vector<float> {
|
||||
static float convert(HVX_Vector vec) { return hexagon::get_flt0_from_fltv(Q6_Vsf_equals_Vqf32(vec)); }
|
||||
};
|
||||
|
||||
template <> struct convert_vector<npu_device_fp16_t> {
|
||||
static float convert(HVX_Vector vec) {
|
||||
HVX_Vector vect = Q6_Vhf_equals_Vqf16(vec);
|
||||
uint16_t i = (vect[0] & 0xffff);
|
||||
return reinterpret_cast<__fp16 &>(i);
|
||||
}
|
||||
};
|
||||
|
||||
template <auto _DotFunc, bool _ShouldCacheSrc0>
|
||||
void mul_mat_impl(hexagon::tensor * src0,
|
||||
hexagon::tensor * src1,
|
||||
hexagon::tensor * dst,
|
||||
hexagon::compute_params * params) {
|
||||
using data_type0 = typename get_data_type<decltype(_DotFunc)>::data_type0;
|
||||
using data_type1 = typename get_data_type<decltype(_DotFunc)>::data_type1;
|
||||
|
||||
const auto src0_actual_row_size = hexagon::get_dequantized_row_size(src0);
|
||||
auto * dequantize_row_func = hexagon::get_type_traits(src0->get_type()).to_float;
|
||||
auto * load_dequant_table_func = hexagon::get_type_traits(src0->get_type()).load_dequant_table;
|
||||
if (_ShouldCacheSrc0 && dequantize_row_func == nullptr) {
|
||||
DEVICE_LOG_ERROR("Unsupported quantized src0 type: %d, dequantize_row_func is null\n", src0->get_type());
|
||||
return;
|
||||
}
|
||||
|
||||
const auto r02 = src1->get_ne(2) / src0->get_ne(2);
|
||||
const auto r03 = src1->get_ne(3) / src0->get_ne(3);
|
||||
const auto total_planes = dst->get_ne(3) * dst->get_ne(2);
|
||||
|
||||
auto start_end_plane = std::pair<int64_t, int64_t>{ 0, total_planes };
|
||||
auto start_end_row = std::pair<int64_t, int64_t>{ 0, dst->get_ne(1) };
|
||||
auto start_end_element = std::pair<int64_t, int64_t>{ 0, dst->get_ne(0) };
|
||||
|
||||
if (total_planes >= params->get_thread_count()) {
|
||||
start_end_plane = params->get_work_slice(total_planes);
|
||||
} else if (dst->get_ne(0) >= params->get_thread_count()) {
|
||||
start_end_element = params->get_work_slice(dst->get_ne(0));
|
||||
} else {
|
||||
start_end_row = params->get_work_slice(dst->get_ne(1));
|
||||
}
|
||||
|
||||
if (start_end_plane.second <= start_end_plane.first || start_end_row.second <= start_end_row.first ||
|
||||
start_end_element.second <= start_end_element.first) {
|
||||
DEVICE_LOG_DEBUG(
|
||||
"mul_mat_impl: no work to do, start_end_plane: (%lld, %lld), start_end_row: (%lld, %lld), "
|
||||
"start_end_element: (%lld, %lld)\n",
|
||||
start_end_plane.first,
|
||||
start_end_plane.second,
|
||||
start_end_row.first,
|
||||
start_end_row.second,
|
||||
start_end_element.first,
|
||||
start_end_element.second);
|
||||
return;
|
||||
}
|
||||
|
||||
// cache the src0 plane in VTCM
|
||||
size_t src0_plane_slice_row_count = start_end_element.second - start_end_element.first;
|
||||
size_t src0_plane_cache_size = 0;
|
||||
uint8_t * src0_plane_cache_ptr = nullptr;
|
||||
const uint8_t * last_cached_plane_ptr = nullptr;
|
||||
if constexpr (_ShouldCacheSrc0) {
|
||||
src0_plane_slice_row_count =
|
||||
std::min(params->get_vtcm_quota_size() / src0_actual_row_size, src0_plane_slice_row_count);
|
||||
src0_plane_cache_size = src0_actual_row_size * src0_plane_slice_row_count;
|
||||
src0_plane_cache_ptr = params->get_vtcm_cache(src0_plane_cache_size);
|
||||
if (src0_plane_cache_ptr == nullptr) {
|
||||
DEVICE_LOG_ERROR(
|
||||
"mul_mat_impl: failed to get VTCM cache for src0, size: %zu, src0_plane_slice_row_count: %zu, "
|
||||
"src0_actual_row_size: %zu, will fallback to mem cache\n",
|
||||
src0_plane_cache_size,
|
||||
src0_plane_slice_row_count,
|
||||
src0_actual_row_size);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
DEVICE_LOG_DEBUG(
|
||||
"mul_mat_impl src0_actual_row_size: %zu, src0_plane_slice_row_count: %zu, is_quantized: %d, vtcm_mem: "
|
||||
"%p(%zu)\n",
|
||||
src0_actual_row_size,
|
||||
src0_plane_slice_row_count,
|
||||
_ShouldCacheSrc0,
|
||||
(void *) src0_plane_cache_ptr,
|
||||
src0_plane_cache_size);
|
||||
|
||||
const size_t valid_row0_bytes = src0->get_ne(0) * sizeof(data_type0);
|
||||
const size_t valid_row1_bytes =
|
||||
src0->get_ne(0) * sizeof(data_type1); // src0 and src1 should have the same element count in the 1st dimension
|
||||
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_WITH_MULTI_SUB_PROC(dst, params->get_thread_index(), mul_mat);
|
||||
|
||||
uint8_t * dst_ptr = dst->get_write_buffer();
|
||||
if (!dst_ptr) {
|
||||
DEVICE_LOG_ERROR("mul_mat_impl: dst_ptr is not writable, tensor: %p, type: %s\n",
|
||||
(void *) dst,
|
||||
hexagon::get_type_name(dst->get_type()));
|
||||
return;
|
||||
}
|
||||
|
||||
auto dequant_table = load_dequant_table_func ? load_dequant_table_func() : HVX_Vector();
|
||||
constexpr bool should_fetch_src0_row = !_ShouldCacheSrc0;
|
||||
const uint8_t * src0_ptr = src0->get_read_buffer();
|
||||
const uint8_t * src1_ptr = src1->get_read_buffer();
|
||||
for (int64_t ip = start_end_plane.first; ip < start_end_plane.second; ip++) {
|
||||
const auto i3 = ip / dst->get_ne(2);
|
||||
const auto i2 = ip - i3 * dst->get_ne(2);
|
||||
const auto * src1_plane = src1_ptr + i3 * src1->get_nb(3) + i2 * src1->get_nb(2);
|
||||
auto * dst_plane = dst_ptr + i3 * dst->get_nb(3) + i2 * dst->get_nb(2);
|
||||
for (int64_t col_idx = start_end_element.first; col_idx < start_end_element.second;
|
||||
col_idx += src0_plane_slice_row_count) {
|
||||
const uint8_t * src0_plane =
|
||||
src0_ptr + i3 / r03 * src0->get_nb(3) + i2 / r02 * src0->get_nb(2) + col_idx * src0->get_nb(1);
|
||||
hexagon::l2fetch_row(src0_plane, src0->get_nb(1));
|
||||
|
||||
const int64_t actual_row_count =
|
||||
std::min<int64_t>(src0_plane_slice_row_count,
|
||||
start_end_element.second - col_idx); // number of rows in this slice
|
||||
if constexpr (_ShouldCacheSrc0) {
|
||||
if (last_cached_plane_ptr != src0_plane) {
|
||||
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 0, dequant);
|
||||
|
||||
for (int64_t ir = 0; ir < actual_row_count; ir++) {
|
||||
auto * src0_row = src0_plane + ir * src0->get_nb(1);
|
||||
if (ir + 1 < actual_row_count) {
|
||||
hexagon::l2fetch_row(src0_row + src0->get_nb(1), src0->get_nb(1));
|
||||
}
|
||||
|
||||
auto * cached_row_ptr = src0_plane_cache_ptr + ir * src0_actual_row_size;
|
||||
dequantize_row_func(src0_row,
|
||||
reinterpret_cast<hexagon::dequant_output_type *>(cached_row_ptr),
|
||||
src0->get_ne(0),
|
||||
dequant_table);
|
||||
}
|
||||
|
||||
last_cached_plane_ptr = src0_plane;
|
||||
}
|
||||
|
||||
src0_plane = src0_plane_cache_ptr;
|
||||
}
|
||||
|
||||
if (start_end_row.second > start_end_row.first) {
|
||||
hexagon::l2fetch_row(src1_plane + start_end_row.first * src1->get_nb(1), valid_row1_bytes);
|
||||
}
|
||||
|
||||
for (int64_t i1 = start_end_row.first; i1 < start_end_row.second; i1++) {
|
||||
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 1, vec_dot);
|
||||
auto * src1_row = src1_plane + i1 * src1->get_nb(1);
|
||||
auto * dst_row = reinterpret_cast<float *>(dst_plane + i1 * dst->get_nb(1)) + col_idx;
|
||||
int64_t i0 = 0;
|
||||
for (; i0 + 1 < actual_row_count; i0 += 2) {
|
||||
auto * src0_row = src0_plane + i0 * src0_actual_row_size;
|
||||
if constexpr (should_fetch_src0_row) {
|
||||
hexagon::l2fetch_row(src0_row + src0_actual_row_size, valid_row0_bytes);
|
||||
}
|
||||
|
||||
// TODO: figure dst how to handle a entire row
|
||||
auto res0 = _DotFunc(reinterpret_cast<const data_type0 *>(src0_row),
|
||||
reinterpret_cast<const data_type1 *>(src1_row),
|
||||
(size_t) src0->get_ne(0));
|
||||
|
||||
if (should_fetch_src0_row && i0 + 2 < actual_row_count) {
|
||||
hexagon::l2fetch_row(src0_row + src0_actual_row_size + src0_actual_row_size, valid_row0_bytes);
|
||||
}
|
||||
|
||||
// TODO: figure dst how to handle a entire row
|
||||
auto res1 = _DotFunc(reinterpret_cast<const data_type0 *>(src0_row + src0_actual_row_size),
|
||||
reinterpret_cast<const data_type1 *>(src1_row),
|
||||
(size_t) src0->get_ne(0));
|
||||
|
||||
{
|
||||
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 2, store);
|
||||
dst_row[i0] = convert_vector<data_type1>::convert(res0);
|
||||
dst_row[i0 + 1] = convert_vector<data_type1>::convert(res1);
|
||||
}
|
||||
}
|
||||
|
||||
if (ip + 1 < start_end_plane.second) {
|
||||
hexagon::l2fetch_row(src1_row + src1->get_nb(1), valid_row1_bytes);
|
||||
}
|
||||
|
||||
if (i0 < actual_row_count) {
|
||||
auto * src0_row = src0_plane + i0 * src0_actual_row_size;
|
||||
auto res = _DotFunc(reinterpret_cast<const data_type0 *>(src0_row),
|
||||
reinterpret_cast<const data_type1 *>(src1_row),
|
||||
(size_t) src0->get_ne(0));
|
||||
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 2, store);
|
||||
dst_row[i0] = convert_vector<data_type1>::convert(res);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
dst->release_write_buffer(); // mark the output tensor as modified
|
||||
}
|
||||
|
||||
template <auto _DotFunc, bool _ShouldCacheSrc0>
|
||||
void mul_mat_gemv_impl(hexagon::tensor * src0,
|
||||
hexagon::tensor * src1,
|
||||
hexagon::tensor * dst,
|
||||
hexagon::compute_params * params) {
|
||||
using data_type0 = typename get_data_type<decltype(_DotFunc)>::data_type0;
|
||||
using data_type1 = typename get_data_type<decltype(_DotFunc)>::data_type1;
|
||||
|
||||
const auto src0_actual_row_size = hexagon::get_dequantized_row_size(src0);
|
||||
auto * dequantize_row_func = hexagon::get_type_traits(src0->get_type()).to_float;
|
||||
auto * load_dequant_table_func = hexagon::get_type_traits(src0->get_type()).load_dequant_table;
|
||||
if (_ShouldCacheSrc0 && dequantize_row_func == nullptr) {
|
||||
DEVICE_LOG_ERROR("Unsupported quantized src0 type: %d, dequantize_row_func is null\n", src0->get_type());
|
||||
return;
|
||||
}
|
||||
|
||||
auto start_end_element = std::pair<int64_t, int64_t>{ 0, dst->get_ne(0) };
|
||||
if (dst->get_ne(0) >= params->get_thread_count()) {
|
||||
start_end_element = params->get_work_slice(dst->get_ne(0));
|
||||
} else {
|
||||
DEVICE_LOG_ERROR("Unsupported src1 tensor shape for gemv: %s, ne: %lldx%lldx%lldx%lld\n",
|
||||
hexagon::get_type_name(src1->get_type()),
|
||||
src1->get_ne(0),
|
||||
src1->get_ne(1),
|
||||
src1->get_ne(2),
|
||||
src1->get_ne(3));
|
||||
return;
|
||||
}
|
||||
|
||||
if (start_end_element.second <= start_end_element.first) {
|
||||
DEVICE_LOG_DEBUG(
|
||||
"mul_mat_impl: no work to do, start_end_plane: [0, 1), start_end_row: [0, 1), "
|
||||
"start_end_element: [%lld, %lld)\n",
|
||||
start_end_element.first,
|
||||
start_end_element.second);
|
||||
return;
|
||||
}
|
||||
|
||||
// cache the src0 plane in VTCM
|
||||
size_t src0_plane_slice_row_count = start_end_element.second - start_end_element.first;
|
||||
size_t src0_plane_cache_size = 0;
|
||||
uint8_t * src0_plane_cache_ptr = nullptr;
|
||||
const auto src1_actual_row_size = hexagon::get_aligned_size(src1->get_nb(1));
|
||||
uint8_t * src1_row_cache_ptr = nullptr;
|
||||
if constexpr (_ShouldCacheSrc0) {
|
||||
src0_plane_slice_row_count = std::min(
|
||||
(params->get_vtcm_quota_size() - src1_actual_row_size) / src0_actual_row_size, src0_plane_slice_row_count);
|
||||
src0_plane_cache_size = src0_actual_row_size * src0_plane_slice_row_count;
|
||||
src0_plane_cache_ptr = params->get_vtcm_cache(src0_plane_cache_size + src1_actual_row_size);
|
||||
if (src0_plane_cache_ptr == nullptr) {
|
||||
DEVICE_LOG_ERROR(
|
||||
"mul_mat_impl: failed to get VTCM cache for src0, size: %zu, src0_plane_slice_row_count: %zu, "
|
||||
"src0_actual_row_size: %zu, will fallback to mem cache\n",
|
||||
src0_plane_cache_size,
|
||||
src0_plane_slice_row_count,
|
||||
src0_actual_row_size);
|
||||
return;
|
||||
}
|
||||
|
||||
src1_row_cache_ptr = src0_plane_cache_ptr + src0_plane_cache_size;
|
||||
} else {
|
||||
src1_row_cache_ptr = params->get_vtcm_cache(src1_actual_row_size);
|
||||
if (src1_row_cache_ptr == nullptr) {
|
||||
DEVICE_LOG_ERROR("mul_mat_impl: failed to get VTCM cache for src1, size: %zu\n", src1_actual_row_size);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
DEVICE_LOG_DEBUG(
|
||||
"mul_mat_impl src0_actual_row_size: %zu, src0_plane_slice_row_count: %zu, is_quantized: %d, vtcm_mem: "
|
||||
"%p(%zu)\n",
|
||||
src0_actual_row_size,
|
||||
src0_plane_slice_row_count,
|
||||
_ShouldCacheSrc0,
|
||||
(void *) src0_plane_cache_ptr,
|
||||
src0_plane_cache_size);
|
||||
|
||||
const size_t valid_row0_bytes = src0->get_ne(0) * sizeof(data_type0);
|
||||
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_WITH_MULTI_SUB_PROC(dst, params->get_thread_index(), mul_mat);
|
||||
|
||||
uint8_t * dst_ptr = dst->get_write_buffer();
|
||||
if (!dst_ptr) {
|
||||
DEVICE_LOG_ERROR("mul_mat_impl: dst_ptr is not writable, tensor: %p, type: %s\n",
|
||||
(void *) dst,
|
||||
hexagon::get_type_name(dst->get_type()));
|
||||
return;
|
||||
}
|
||||
|
||||
auto dequant_table = load_dequant_table_func ? load_dequant_table_func() : HVX_Vector();
|
||||
constexpr bool should_fetch_src0_row = !_ShouldCacheSrc0;
|
||||
const uint8_t * src0_ptr = src0->get_read_buffer();
|
||||
const uint8_t * src1_ptr = src1->get_read_buffer();
|
||||
|
||||
{
|
||||
memcpy(src1_row_cache_ptr, src1_ptr, src1->get_ne(0) * sizeof(data_type1));
|
||||
src1_ptr = src1_row_cache_ptr;
|
||||
}
|
||||
|
||||
{
|
||||
for (int64_t col_idx = start_end_element.first; col_idx < start_end_element.second;
|
||||
col_idx += src0_plane_slice_row_count) {
|
||||
const uint8_t * src0_plane = src0_ptr + col_idx * src0->get_nb(1);
|
||||
hexagon::l2fetch_row(src0_plane, src0->get_nb(1));
|
||||
|
||||
const int64_t actual_row_count =
|
||||
std::min<int64_t>(src0_plane_slice_row_count,
|
||||
start_end_element.second - col_idx); // number of rows in this slice
|
||||
if constexpr (_ShouldCacheSrc0) {
|
||||
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 0, dequant);
|
||||
|
||||
for (int64_t ir = 0; ir < actual_row_count; ir++) {
|
||||
auto * src0_row = src0_plane + ir * src0->get_nb(1);
|
||||
if (ir + 1 < actual_row_count) {
|
||||
hexagon::l2fetch_row(src0_row + src0->get_nb(1), src0->get_nb(1));
|
||||
}
|
||||
|
||||
auto * cached_row_ptr = src0_plane_cache_ptr + ir * src0_actual_row_size;
|
||||
dequantize_row_func(src0_row,
|
||||
reinterpret_cast<hexagon::dequant_output_type *>(cached_row_ptr),
|
||||
src0->get_ne(0),
|
||||
dequant_table);
|
||||
}
|
||||
|
||||
src0_plane = src0_plane_cache_ptr;
|
||||
}
|
||||
|
||||
{
|
||||
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 1, vec_dot);
|
||||
auto * dst_row = reinterpret_cast<float *>(dst_ptr) + col_idx;
|
||||
int64_t i0 = 0;
|
||||
for (; i0 + 1 < actual_row_count; i0 += 2) {
|
||||
auto * src0_row = src0_plane + i0 * src0_actual_row_size;
|
||||
if constexpr (should_fetch_src0_row) {
|
||||
hexagon::l2fetch_row(src0_row + src0_actual_row_size, valid_row0_bytes);
|
||||
}
|
||||
|
||||
// TODO: figure dst how to handle a entire row
|
||||
auto res0 = _DotFunc(reinterpret_cast<const data_type0 *>(src0_row),
|
||||
reinterpret_cast<const data_type1 *>(src1_ptr),
|
||||
(size_t) src0->get_ne(0));
|
||||
|
||||
if (should_fetch_src0_row && i0 + 2 < actual_row_count) {
|
||||
hexagon::l2fetch_row(src0_row + src0_actual_row_size + src0_actual_row_size, valid_row0_bytes);
|
||||
}
|
||||
|
||||
// TODO: figure dst how to handle a entire row
|
||||
auto res1 = _DotFunc(reinterpret_cast<const data_type0 *>(src0_row + src0_actual_row_size),
|
||||
reinterpret_cast<const data_type1 *>(src1_ptr),
|
||||
(size_t) src0->get_ne(0));
|
||||
|
||||
{
|
||||
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 2, store);
|
||||
dst_row[i0] = convert_vector<data_type1>::convert(res0);
|
||||
dst_row[i0 + 1] = convert_vector<data_type1>::convert(res1);
|
||||
}
|
||||
}
|
||||
|
||||
if (i0 < actual_row_count) {
|
||||
auto * src0_row = src0_plane + i0 * src0_actual_row_size;
|
||||
auto res = _DotFunc(reinterpret_cast<const data_type0 *>(src0_row),
|
||||
reinterpret_cast<const data_type1 *>(src1_ptr),
|
||||
(size_t) src0->get_ne(0));
|
||||
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 2, store);
|
||||
dst_row[i0] = convert_vector<data_type1>::convert(res);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
dst->release_write_buffer(); // mark the output tensor as modified
|
||||
}
|
||||
|
||||
bool is_src_cacheable(const npu_device_tensor_spec & src0, const npu_device_tensor_spec & src1) {
|
||||
const auto & src0_type_traits = hexagon::get_type_traits(src0.type);
|
||||
if (src0_type_traits.to_float == nullptr) {
|
||||
DEVICE_LOG_DEBUG("[MUL_MAT]src0.type(%s) cannot be cached, to_float is null\n",
|
||||
hexagon::get_type_name(src0.type));
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto vtcm_thread_quota_size = hexagon::default_thread_pool::get_per_thread_vtcm_quota();
|
||||
const size_t src0_type_size =
|
||||
src0_type_traits.is_quantized ? sizeof(hexagon::dequant_output_type) : src0_type_traits.type_size;
|
||||
const auto & src1_type_traits = hexagon::get_type_traits(src1.type);
|
||||
const bool is_gemv = src1.ne[1] == 1 && src1.ne[2] == 1 && src1.ne[3] == 1;
|
||||
size_t min_cache_size = is_gemv ? (src1.ne[0] * src1_type_traits.type_size) : 0;
|
||||
min_cache_size += src0.ne[0] * src0_type_size;
|
||||
if (min_cache_size > vtcm_thread_quota_size) {
|
||||
DEVICE_LOG_DEBUG("[MUL_MAT]src0.type(%s) min_cache_size is too large: %ld, vtcm_thread_quota_size: %zu\n",
|
||||
hexagon::get_type_name(src0.type),
|
||||
(long) min_cache_size,
|
||||
vtcm_thread_quota_size);
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool is_quantized_mul_mat_supported(const npu_device_tensor_spec & src0, const npu_device_tensor_spec & src1) {
|
||||
if (src1.type != NPU_DATA_TYPE_F32 && src1.type != NPU_DATA_TYPE_F16) {
|
||||
DEVICE_LOG_DEBUG("[MUL_MAT]src0.type(%s) and src1.type(%s) mismatch and src1 is not F32\n",
|
||||
hexagon::get_type_name(src0.type),
|
||||
hexagon::get_type_name(src1.type));
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto type_traits = hexagon::get_type_traits(src0.type);
|
||||
if (!type_traits.is_quantized || type_traits.to_float == nullptr) {
|
||||
DEVICE_LOG_DEBUG("[MUL_MAT]src0.type(%s) and src1.type(%s) mismatch and src0 is not quantized\n",
|
||||
hexagon::get_type_name(src0.type),
|
||||
hexagon::get_type_name(src1.type));
|
||||
return false;
|
||||
}
|
||||
|
||||
if (src0.ne[0] % type_traits.blck_size) {
|
||||
DEVICE_LOG_DEBUG(
|
||||
"[MUL_MAT]src0.type(%s) ne[0] is not aligned: %ld\n", hexagon::get_type_name(src0.type), (long) src0.ne[0]);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!is_src_cacheable(src0, src1)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
DEVICE_LOG_DEBUG("[MUL_MAT]supported quantized src0.type(%s) and src1.type(%s)\n",
|
||||
hexagon::get_type_name(src0.type),
|
||||
hexagon::get_type_name(src1.type));
|
||||
return true;
|
||||
}
|
||||
|
||||
bool is_mul_mat_f16_f32_src_tensors_aligned(hexagon::tensor * src0,
|
||||
hexagon::tensor * src1,
|
||||
bool is_src0_cached,
|
||||
bool is_src1_cached) {
|
||||
const auto * src1_ptr = is_src1_cached ? nullptr : src1->get_read_buffer_as<float>();
|
||||
const auto * src0_ptr = is_src0_cached ? nullptr : src0->get_read_buffer_as<npu_device_fp16_t>();
|
||||
|
||||
if (!hexagon::is_f16_f32_dot_product_aligned(src0_ptr, src1_ptr, src0->get_ne(0))) {
|
||||
DEVICE_LOG_DEBUG("[MUL_MAT]src_tensors_unaligned: ne[0]: %ld\n", (long) src0->get_ne(0));
|
||||
return false;
|
||||
}
|
||||
|
||||
DEVICE_LOG_DEBUG("[MUL_MAT]src_tensors_aligned: ne[0]: %ld\n", (long) src0->get_ne(0));
|
||||
return true;
|
||||
}
|
||||
|
||||
bool is_mul_mat_f16_f16_src_tensors_aligned(hexagon::tensor * src0, hexagon::tensor * src1, bool is_src0_quantized) {
|
||||
const auto * src1_ptr = src1->get_read_buffer_as<npu_device_fp16_t>();
|
||||
const auto * src0_ptr = is_src0_quantized ? nullptr : src0->get_read_buffer_as<npu_device_fp16_t>();
|
||||
|
||||
if (!hexagon::is_f16_f16_dot_product_aligned(src0_ptr, src1_ptr, src0->get_ne(0))) {
|
||||
DEVICE_LOG_DEBUG("[MUL_MAT]src_tensors_unaligned: ne[0]: %ld\n", (long) src0->get_ne(0));
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!is_src0_quantized && !hexagon::is_size_aligned(src0->get_nb(1))) {
|
||||
DEVICE_LOG_DEBUG("[MUL_MAT]src0 tensor nb[1] is not aligned: %zu\n", src0->get_nb(1));
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!hexagon::is_size_aligned(src1->get_nb(1))) {
|
||||
DEVICE_LOG_DEBUG("[MUL_MAT]src1 tensor nb[1] is not aligned: %zu\n", src1->get_nb(1));
|
||||
return false;
|
||||
}
|
||||
|
||||
DEVICE_LOG_DEBUG("[MUL_MAT]src_tensors_aligned: ne[0]: %ld\n", (long) src0->get_ne(0));
|
||||
return true;
|
||||
}
|
||||
|
||||
bool is_mul_mat_f32_f32_src_tensors_aligned(hexagon::tensor * src0, hexagon::tensor * src1) {
|
||||
const auto * src1_ptr = src1->get_read_buffer_as<float>();
|
||||
const auto * src0_ptr = src0->get_read_buffer_as<float>();
|
||||
|
||||
if (!hexagon::is_f32_f32_dot_product_aligned(src0_ptr, src1_ptr, src0->get_ne(0))) {
|
||||
DEVICE_LOG_DEBUG("[MUL_MAT]src_tensors_unaligned: ne[0]: %ld\n", (long) src0->get_ne(0));
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!hexagon::is_size_aligned(src0->get_nb(1))) {
|
||||
DEVICE_LOG_DEBUG("[MUL_MAT]src0 tensor nb[1] is not aligned: %zu\n", src0->get_nb(1));
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!hexagon::is_size_aligned(src1->get_nb(1))) {
|
||||
DEVICE_LOG_DEBUG("[MUL_MAT]src1 tensor nb[1] is not aligned: %zu\n", src1->get_nb(1));
|
||||
return false;
|
||||
}
|
||||
|
||||
DEVICE_LOG_DEBUG("[MUL_MAT]src_tensors_aligned: ne[0]: %ld\n", (long) src0->get_ne(0));
|
||||
return true;
|
||||
}
|
||||
|
||||
typedef void (*mul_mat_func_type)(hexagon::tensor * src0,
|
||||
hexagon::tensor * src1,
|
||||
hexagon::tensor * dst,
|
||||
hexagon::compute_params * params);
|
||||
|
||||
constexpr const size_t kMulMatGemvBaseIndex = 2;
|
||||
|
||||
constexpr const mul_mat_func_type kMulMatF32F32CachedFuncs[4] = {
|
||||
// quantized and non-quantized
|
||||
mul_mat_impl<hexagon::vec_dot_product_vqf32_f32_f32, true>, // F32 * F32 quantized unaligned
|
||||
mul_mat_impl<hexagon::vec_dot_product_aligned_vqf32_f32_f32, true>, // F32 * F32 quantized aligned
|
||||
mul_mat_gemv_impl<hexagon::vec_dot_product_aligned_vqf32_f32_f32, true>, // F32 * F32 quantized gemv
|
||||
mul_mat_gemv_impl<hexagon::vec_dot_product_aligned_vqf32_f32_f32, true>, // F32 * F32 quantized gemv
|
||||
};
|
||||
|
||||
constexpr const mul_mat_func_type kMulMatF32F32Funcs[4] = {
|
||||
// quantized and non-quantized
|
||||
mul_mat_impl<hexagon::vec_dot_product_vqf32_f32_f32, false>, // F32 * F32 quantized unaligned
|
||||
mul_mat_impl<hexagon::vec_dot_product_aligned_vqf32_f32_f32, false>, // F32 * F32 quantized aligned
|
||||
mul_mat_gemv_impl<hexagon::vec_dot_product_vqf32_f32_f32, false>, // F32 * F32 quantized gemv
|
||||
mul_mat_gemv_impl<hexagon::vec_dot_product_aligned_vqf32_f32_f32, false>, // F32 * F32 quantized gemv
|
||||
};
|
||||
|
||||
constexpr const mul_mat_func_type kMulMatF16F32Funcs[4] = {
|
||||
// quantized and non-quantized
|
||||
mul_mat_impl<hexagon::vec_dot_product_vqf32_f16_f32, true>, // F32 * F32 quantized unaligned
|
||||
mul_mat_impl<hexagon::vec_dot_product_aligned_vqf32_f16_f32, true>, // F32 * F32 quantized aligned
|
||||
mul_mat_gemv_impl<hexagon::vec_dot_product_vqf32_f16_f32, true>, // F32 * F32 quantized unaligned
|
||||
mul_mat_gemv_impl<hexagon::vec_dot_product_aligned_vqf32_f16_f32, true>, // F32 * F32 quantized aligned
|
||||
};
|
||||
|
||||
constexpr const mul_mat_func_type kMulMatF16CachedFuncs[4] = {
|
||||
mul_mat_impl<hexagon::vec_dot_product_vqf16_f16_f16, true>, // F16 * F16 quantized unaligned
|
||||
mul_mat_impl<hexagon::vec_dot_product_aligned_vqf16_f16_f16, true>, // F16 * F16 quantized aligned
|
||||
mul_mat_gemv_impl<hexagon::vec_dot_product_aligned_vqf16_f16_f16, true>, // F16 * F16 quantized gemv
|
||||
mul_mat_gemv_impl<hexagon::vec_dot_product_aligned_vqf16_f16_f16, true>, // F16 * F16 quantized gemv
|
||||
};
|
||||
|
||||
constexpr const mul_mat_func_type kMulMatF16Funcs[4] = {
|
||||
mul_mat_impl<hexagon::vec_dot_product_vqf16_f16_f16, false>, // F16 * F16 quantized unaligned
|
||||
mul_mat_impl<hexagon::vec_dot_product_aligned_vqf16_f16_f16, false>, // F16 * F16 quantized aligned
|
||||
mul_mat_gemv_impl<hexagon::vec_dot_product_vqf16_f16_f16, false>, // F16 * F16 quantized gemv
|
||||
mul_mat_gemv_impl<hexagon::vec_dot_product_aligned_vqf16_f16_f16, false>, // F16 * F16 quantized gemv
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace hexagon {
|
||||
|
||||
bool mul_mat_f32(hexagon::tensor * out, compute_params * params) {
|
||||
static_assert(DEVICE_TENSOR_MAX_DIMS == 4, "mul_mat_f32 requires max dims 4");
|
||||
static_assert(std::is_same<hexagon::dequant_output_type, float>::value ||
|
||||
std::is_same<hexagon::dequant_output_type, npu_device_fp16_t>::value,
|
||||
"dequant_output_type must be float or npu_device_fp16_t");
|
||||
|
||||
if (!out) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto * src0 = out->get_src(0);
|
||||
auto * src1 = out->get_src(1);
|
||||
if (!src0 || !src1) {
|
||||
return true; // skip if no src
|
||||
}
|
||||
|
||||
const bool is_src0_quantized = is_quantized_type(src0->get_type());
|
||||
const bool should_cache_src0 = is_src0_quantized || src1->get_ne(1) > 1;
|
||||
const bool is_gemv = src1->get_ne(1) == 1 && src1->get_ne(2) == 1 && src1->get_ne(3) == 1;
|
||||
const auto base_index = is_gemv ? kMulMatGemvBaseIndex : 0;
|
||||
switch (src1->get_type()) {
|
||||
case NPU_DATA_TYPE_F32:
|
||||
if (is_src0_quantized || src0->get_type() == NPU_DATA_TYPE_F16) {
|
||||
kMulMatF16F32Funcs[is_mul_mat_f16_f32_src_tensors_aligned(src0, src1, is_src0_quantized, is_gemv) +
|
||||
base_index](src0, src1, out, params);
|
||||
} else if (should_cache_src0) {
|
||||
kMulMatF32F32CachedFuncs[is_mul_mat_f32_f32_src_tensors_aligned(src0, src1) + base_index](
|
||||
src0, src1, out, params);
|
||||
} else {
|
||||
kMulMatF32F32Funcs[is_mul_mat_f32_f32_src_tensors_aligned(src0, src1) + base_index](
|
||||
src0, src1, out, params);
|
||||
}
|
||||
return true;
|
||||
case NPU_DATA_TYPE_F16:
|
||||
if (should_cache_src0) {
|
||||
kMulMatF16CachedFuncs[is_mul_mat_f16_f16_src_tensors_aligned(src0, src1, is_src0_quantized) +
|
||||
base_index](src0, src1, out, params);
|
||||
} else {
|
||||
kMulMatF16Funcs[is_mul_mat_f16_f16_src_tensors_aligned(src0, src1, is_src0_quantized) + base_index](
|
||||
src0, src1, out, params);
|
||||
}
|
||||
return true;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
DEVICE_LOG_ERROR("Unsupported src1 tensor type: %s\n", get_type_name(src1->get_type()));
|
||||
return false;
|
||||
}
|
||||
|
||||
bool is_mul_mat_supported(const npu_device_tensor_op_spec * op_spec,
|
||||
const npu_device_tensor_spec * dst,
|
||||
const npu_device_tensor_spec * srcs,
|
||||
size_t src_len) {
|
||||
const auto op = op_spec->op;
|
||||
if (op != NPU_OP_MUL_MAT) {
|
||||
DEVICE_LOG_DEBUG("op is not MUL_MAT: %d\n", op);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!dst || !srcs || src_len < 2) {
|
||||
DEVICE_LOG_DEBUG("[%s]invalid dst or srcs\n", hexagon::op_get_name(op));
|
||||
return false;
|
||||
}
|
||||
|
||||
if (dst->type != NPU_DATA_TYPE_F32) {
|
||||
DEVICE_LOG_DEBUG("[%s]dst type is not F32: %s\n", op_get_name(op), get_type_name(dst->type));
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto & src0 = srcs[0];
|
||||
const auto & src1 = srcs[1];
|
||||
if (src0.type != src1.type) {
|
||||
if (src1.type == NPU_DATA_TYPE_F32 && src0.type == NPU_DATA_TYPE_F16) {
|
||||
DEVICE_LOG_DEBUG("[%s]src0.type(%s) and src1.type(%s) mismatch, but src0 is F16 and src1 is F32\n",
|
||||
op_get_name(op),
|
||||
get_type_name(src0.type),
|
||||
get_type_name(src1.type));
|
||||
return true; // F16 * F32 is supported
|
||||
}
|
||||
|
||||
#ifdef GGML_HEXAGON_ENABLE_QUANTIZED_TENSORS
|
||||
if (!is_quantized_mul_mat_supported(src0, src1)) {
|
||||
return false;
|
||||
}
|
||||
#else
|
||||
DEVICE_LOG_DEBUG("[%s]src0.type(%s) and src1.type(%s) mismatch and quantized tensors are not supported\n",
|
||||
op_get_name(op),
|
||||
get_type_name(src0.type),
|
||||
get_type_name(src1.type));
|
||||
return false;
|
||||
#endif
|
||||
}
|
||||
|
||||
if (src0.ne[0] != src1.ne[0] || src0.ne[1] != dst->ne[0]) {
|
||||
DEVICE_LOG_DEBUG("[%s]src0 and src1 cannot multiply: %ldx%ld vs %ldx%ld\n",
|
||||
op_get_name(op),
|
||||
(long) src0.ne[0],
|
||||
(long) src0.ne[1],
|
||||
(long) src1.ne[0],
|
||||
(long) src1.ne[1]);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (src1.ne[1] != dst->ne[1] || src1.ne[2] != dst->ne[2] || src1.ne[3] != dst->ne[3]) {
|
||||
DEVICE_LOG_DEBUG("[%s]src1 and dst dimensions not match: %ldx%ld vs %ldx%ld\n",
|
||||
op_get_name(op),
|
||||
(long) src1.ne[2],
|
||||
(long) src1.ne[3],
|
||||
(long) dst->ne[2],
|
||||
(long) dst->ne[3]);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (src1.ne[2] % src0.ne[2] || src1.ne[3] % src0.ne[3]) {
|
||||
DEVICE_LOG_DEBUG("[%s]src0 cannot broadcast to src1: %ldx%ld vs %ldx%ld\n",
|
||||
op_get_name(op),
|
||||
(long) src0.ne[2],
|
||||
(long) src0.ne[3],
|
||||
(long) src1.ne[2],
|
||||
(long) src1.ne[3]);
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace hexagon
|
||||
|
|
@ -111,8 +111,8 @@ class tensor {
|
|||
|
||||
npu_device_tensor_data_type get_type() const { return _info.type; }
|
||||
|
||||
const uint8_t * get_read_buffer() const {
|
||||
if (!_info.is_constant && _has_modified) {
|
||||
const uint8_t * get_read_buffer(const bool force_invalidate = false) const {
|
||||
if (force_invalidate || (!_info.is_constant && _has_modified)) {
|
||||
invalidate();
|
||||
const_cast<tensor *>(this)->_has_modified = false; // TODO: avoid const_cast
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
#pragma once
|
||||
|
||||
#include "dma_transfer.hpp"
|
||||
#include "util.hpp"
|
||||
#include "vtcm_mem.hpp"
|
||||
|
||||
|
|
@ -98,6 +99,8 @@ template <size_t _ThreadCount> class thread_pool {
|
|||
|
||||
std::unique_ptr<vtcm_mem> vtcm_cache;
|
||||
|
||||
hexagon::dma::dma_transfer dma;
|
||||
|
||||
void init_vtcm_cache() { vtcm_cache = std::make_unique<vtcm_mem>(vtcm_quota_size, false); }
|
||||
|
||||
uint8_t * get_vtcm_cache(size_t size) {
|
||||
|
|
@ -113,15 +116,39 @@ template <size_t _ThreadCount> class thread_pool {
|
|||
|
||||
return vtcm_cache->get_mem();
|
||||
}
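        // Thin wrappers around the per-thread hexagon::dma::dma_transfer member so compute code goes
        // through compute_params instead of driving the DMA engine directly.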
|
||||
|
||||
bool initiate_dma_row_transfer(const uint8_t * src, uint8_t * dst, size_t size) {
|
||||
return dma.submit1d(src, dst, size);
|
||||
}
|
||||
|
||||
bool initiate_dma_row_transfer(const uint8_t * src0,
|
||||
uint8_t * dst0,
|
||||
const uint8_t * src1,
|
||||
uint8_t * dst1,
|
||||
size_t size) {
|
||||
return dma.submit1d(src0, dst0, src1, dst1, size);
|
||||
}
|
||||
|
||||
bool initiate_dma_plane_transfer(const uint8_t * src,
|
||||
uint8_t * dst,
|
||||
size_t width,
|
||||
size_t height,
|
||||
size_t src_stride,
|
||||
size_t dst_stride) {
|
||||
return dma.submit2d(src, dst, width, height, src_stride, dst_stride);
|
||||
}
|
||||
|
||||
void wait_for_dma() { dma.wait(); }
|
||||
};
|
||||
|
||||
typedef void (*task_type)(thread_pool * pool, thread_params * param, void * arg);
|
||||
|
||||
thread_pool() {
|
||||
const auto quota_size = hexagon::vtcm_mem::get_avail_block_size() / kMaxThreadCount;
|
||||
for (size_t i = 0; i < kMaxThreadCount; ++i) {
|
||||
auto & thread_param = _thread_params[i];
|
||||
thread_param.tidx = i;
|
||||
thread_param.vtcm_quota_size = hexagon::vtcm_mem::get_avail_block_size() / kMaxThreadCount;
|
||||
thread_param.vtcm_quota_size = quota_size;
|
||||
thread_param.pool = this;
|
||||
|
||||
thread_param.init_vtcm_cache();
|
||||
|
|
@ -143,7 +170,7 @@ template <size_t _ThreadCount> class thread_pool {
|
|||
_threads[i] = std::move(thread);
|
||||
}
|
||||
|
||||
DEVICE_LOG_DEBUG("thread_pool.created: %zu\n", kMaxSubThreadCount);
|
||||
DEVICE_LOG_DEBUG("thread_pool.created: %zu, vtcm_quota_size: %zu\n", kMaxSubThreadCount, quota_size);
|
||||
}
|
||||
|
||||
~thread_pool() {
|
||||
|
|
|
|||
|
|
@ -448,8 +448,8 @@ void dequantize_row_q4_0_impl(const void * src, hexagon::dequant_output_type * d
|
|||
static_assert(QUANT_BLOCK_SIZE == hexagon::kBytesPerVector / sizeof(float));
|
||||
constexpr const uint32_t kSizeOfQs = sizeof(npu_device_block_q4_0::qs);
|
||||
|
||||
static const auto load_masks = make_quad_block_mask<npu_device_block_q4_0>();
|
||||
static const HVX_Vector scale_indices __attribute__((aligned(hexagon::kBytesPerVector))) =
|
||||
static const auto load_masks = make_quad_block_mask<npu_device_block_q4_0>();
|
||||
alignas(hexagon::kBytesPerVector) static const HVX_Vector scale_indices =
|
||||
make_scale_load_mask<npu_device_block_q4_0>();
|
||||
|
||||
const int nb = count / qk;
|
||||
|
|
@ -538,10 +538,10 @@ HVX_Vector load_dequant_table_q4_0() {
|
|||
static_assert(kTableSize <= hexagon::kBytesPerVector / sizeof(__fp16), "table too large");
|
||||
|
||||
static const HVX_Vector result = []() -> HVX_Vector {
|
||||
union {
|
||||
alignas(hexagon::kBytesPerVector) union {
|
||||
HVX_Vector v;
|
||||
__fp16 f16[sizeof(HVX_Vector) / sizeof(__fp16)];
|
||||
} table __attribute__((aligned(hexagon::kBytesPerVector)));
|
||||
} table;
|
||||
|
||||
table.v = Q6_V_vzero();
|
||||
for (int i = 0; i < kTableSize; ++i) {
|
||||
|
|
@ -567,10 +567,10 @@ HVX_Vector load_dequant_table_q4_k() {
|
|||
static_assert(kTableSize <= hexagon::kBytesPerVector / sizeof(__fp16), "table too large");
|
||||
|
||||
const static HVX_Vector result = []() -> HVX_Vector {
|
||||
union {
|
||||
alignas(hexagon::kBytesPerVector) union {
|
||||
HVX_Vector v;
|
||||
__fp16 f16[sizeof(HVX_Vector) / sizeof(__fp16)];
|
||||
} table __attribute__((aligned(hexagon::kBytesPerVector)));
|
||||
} table;
|
||||
|
||||
table.v = Q6_V_vzero();
|
||||
for (int i = 0; i < kTableSize; ++i) {
|
||||
|
|
@ -591,10 +591,10 @@ void dequantize_row_q4_K(const void * src, hexagon::dequant_output_type * dst, s
|
|||
|
||||
const HVX_VectorPred scale_mask = Q6_Q_vsetq_R(hexagon::kBytesPerVector / 2);
|
||||
|
||||
union {
|
||||
alignas(hexagon::kBytesPerVector * 4) union {
|
||||
HVX_VectorPair p[2];
|
||||
HVX_Vector v[4];
|
||||
} dual_pair __attribute__((aligned(hexagon::kBytesPerVector * 4)));
|
||||
} dual_pair;
|
||||
|
||||
for (int i = 0; i < nb; i++) {
|
||||
const uint8_t * q = src_ptr[i].qs;
|
||||
|
|
@@ -674,40 +674,21 @@ void copy_row_f32(const void * src, hexagon::dequant_output_type * dst, size_t c
 }

 constexpr const hexagon::device_type_traits kDeviceTypeTraits[] = {
-    { NPU_DATA_TYPE_F32,
-      "F32", 1,
-      sizeof(float),
-      false, copy_row_f32,
-      nullptr, hexagon::type_erase_dot_func<hexagon::vec_dot_product_f32_f32>,
+    { NPU_DATA_TYPE_F32, "F32", 1, sizeof(float), false, copy_row_f32, nullptr,
+      hexagon::type_erase_dot_func<hexagon::vec_dot_product_f32_f32>,
       hexagon::type_erase_dot_func<hexagon::vec_dot_product_aligned_f32_f32>,
       hexagon::type_erase_dot_func<hexagon::is_f32_f32_dot_product_aligned> },
-    { NPU_DATA_TYPE_F16,
-      "F16", 1,
-      sizeof(npu_device_fp16_t),
-      false, copy_row_f16,
-      quantize_row_fp16, hexagon::type_erase_dot_func<hexagon::vec_dot_product_f16_f16>,
+    { NPU_DATA_TYPE_F16, "F16", 1, sizeof(npu_device_fp16_t), false, copy_row_f16, quantize_row_fp16,
+      hexagon::type_erase_dot_func<hexagon::vec_dot_product_f16_f16>,
       hexagon::type_erase_dot_func<hexagon::vec_dot_product_aligned_f16_f16>,
       hexagon::type_erase_dot_func<hexagon::is_f16_f16_dot_product_aligned> },
     { NPU_DATA_TYPE_I32, "I32", 1, sizeof(int32_t), false },
-    { NPU_DATA_TYPE_Q8_0,
-      "Q8_0", QUANT_BLOCK_SIZE,
-      sizeof(npu_device_block_q8_0),
-      true, dequantize_row_q8_0,
+    { NPU_DATA_TYPE_Q8_0, "Q8_0", QUANT_BLOCK_SIZE, sizeof(npu_device_block_q8_0), true, dequantize_row_q8_0,
       quantize_row_q8_0 },
-    { NPU_DATA_TYPE_Q4_0,
-      "Q4_0", QUANT_BLOCK_SIZE,
-      sizeof(npu_device_block_q4_0),
-      true, dequantize_row_q4_0,
-      quantize_row_q4_0, nullptr,
-      nullptr, nullptr,
-      load_dequant_table_q4_0 },
-    { NPU_DATA_TYPE_Q4_K,
-      "Q4_K", QUANT_K_BLOCK_SIZE,
-      sizeof(npu_device_block_q4_k),
-      true, dequantize_row_q4_K,
-      quantize_row_q4_K, nullptr,
-      nullptr, nullptr,
-      load_dequant_table_q4_k },
+    { NPU_DATA_TYPE_Q4_0, "Q4_0", QUANT_BLOCK_SIZE, sizeof(npu_device_block_q4_0), true, dequantize_row_q4_0,
+      quantize_row_q4_0, nullptr, nullptr, nullptr, load_dequant_table_q4_0 },
+    { NPU_DATA_TYPE_Q4_K, "Q4_K", QUANT_K_BLOCK_SIZE, sizeof(npu_device_block_q4_k), true, dequantize_row_q4_K,
+      quantize_row_q4_K, nullptr, nullptr, nullptr, load_dequant_table_q4_k },
 };

 static_assert(std::size(kDeviceTypeTraits) == NPU_DATA_TYPE_COUNT,
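
The traits entries above are compacted onto fewer lines but keep the same field order; the static_assert against NPU_DATA_TYPE_COUNT (continued past this hunk) suggests the table is indexed directly by the data-type enum. A simplified, self-contained sketch of that lookup pattern follows; the enum, struct, and get_traits accessor here are illustrative stand-ins, not the real definitions.

#include <cstddef>
#include <cstdio>

// Simplified stand-in for a per-type traits table indexed by the enum value.
enum data_type { TYPE_F32, TYPE_F16, TYPE_COUNT };

struct type_traits {
    data_type    type;
    const char * name;
    std::size_t  block_size;  // elements per quantization block (1 for plain floats)
    std::size_t  type_size;   // bytes per block
};

constexpr type_traits kTraits[] = {
    { TYPE_F32, "F32", 1, sizeof(float) },
    { TYPE_F16, "F16", 1, 2 },
};

// Keeps the table and the enum in sync, mirroring the static_assert above.
static_assert(sizeof(kTraits) / sizeof(kTraits[0]) == TYPE_COUNT,
              "traits table must cover every data type");

inline const type_traits & get_traits(data_type t) {
    return kTraits[t];  // the enum value doubles as the array index
}

int main() {
    std::printf("%s: block=%zu, bytes=%zu\n",
                get_traits(TYPE_F16).name, get_traits(TYPE_F16).block_size, get_traits(TYPE_F16).type_size);
    return 0;
}
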
@@ -55,32 +55,14 @@ inline auto make_scoped_op_perf_timer(tensor * op, size_t tidx) {
     auto * src1 = op->get_src(1);
     char buffer[512];
     if (src1 == nullptr) {
-        snprintf(buffer,
-                 sizeof(buffer),
-                 "[%s][%lldx%lldx%lldx%lld%s], tidx: %zu",
-                 op_get_name(op->get_op()),
-                 src0->get_ne(0),
-                 src0->get_ne(1),
-                 src0->get_ne(2),
-                 src0->get_ne(3),
-                 get_type_name(src0->get_type()),
+        snprintf(buffer, sizeof(buffer), "[%s][%lldx%lldx%lldx%lld%s], tidx: %zu", op_get_name(op->get_op()),
+                 src0->get_ne(0), src0->get_ne(1), src0->get_ne(2), src0->get_ne(3), get_type_name(src0->get_type()),
                  tidx);
     } else {
-        snprintf(buffer,
-                 sizeof(buffer),
-                 "[%s][%lldx%lldx%lldx%lld%s],[%lldx%lldx%lldx%lld%s], tidx: %zu",
-                 op_get_name(op->get_op()),
-                 src0->get_ne(0),
-                 src0->get_ne(1),
-                 src0->get_ne(2),
-                 src0->get_ne(3),
-                 get_type_name(src0->get_type()),
-                 src1->get_ne(0),
-                 src1->get_ne(1),
-                 src1->get_ne(2),
-                 src1->get_ne(3),
-                 get_type_name(src1->get_type()),
-                 tidx);
+        snprintf(buffer, sizeof(buffer), "[%s][%lldx%lldx%lldx%lld%s],[%lldx%lldx%lldx%lld%s], tidx: %zu",
+                 op_get_name(op->get_op()), src0->get_ne(0), src0->get_ne(1), src0->get_ne(2), src0->get_ne(3),
+                 get_type_name(src0->get_type()), src1->get_ne(0), src1->get_ne(1), src1->get_ne(2), src1->get_ne(3),
+                 get_type_name(src1->get_type()), tidx);
     }
     return npu_scoped_timer<1024>(buffer);
 }
@@ -102,8 +84,7 @@ inline auto make_scoped_op_perf_timer(tensor * op, size_t tidx) {

 #    define DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(tracker_name, idx, sub_prefix) \
         hexagon::npu_sub_process_scoped_timer< \
-            std::remove_reference_t<decltype(__npu_op_timer_##tracker_name)>::kBufferCount, \
-            idx> \
+            std::remove_reference_t<decltype(__npu_op_timer_##tracker_name)>::kBufferCount, idx> \
             __npu_op_sub_timer##sub_prefix(__npu_op_timer_##tracker_name, #sub_prefix)

 #else
@@ -6,6 +6,7 @@
 #include <HAP_farf.h>
 #include <HAP_perf.h>
 #include <HAP_power.h>
+#include <qurt.h>

 #include <cstdint>
 #include <cstdio>
@@ -95,6 +96,22 @@ inline bool is_same_shape(const npu_device_tensor_spec & src, const npu_device_t
     return is_same_shape(src.ne, dst.ne);
 }

+class qurt_mutex {
+  public:
+    qurt_mutex() { qurt_mutex_init(&_mutex); }
+
+    ~qurt_mutex() { qurt_mutex_destroy(&_mutex); }
+
+    void lock() { qurt_mutex_lock(&_mutex); }
+
+    void unlock() { qurt_mutex_unlock(&_mutex); }
+
+  private:
+    qurt_mutex_t _mutex;
+
+    DISABLE_COPY_AND_MOVE(qurt_mutex);
+};
+
 class power_utils {
   public:
     power_utils() {
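
Because the new qurt_mutex wrapper exposes lock()/unlock(), it satisfies the standard BasicLockable requirements and can be driven by std::lock_guard for RAII locking. A usage sketch, assuming the class above is visible at the call site (its enclosing namespace is not shown in this hunk); g_counter_mutex and g_counter are invented names:

#include <mutex>  // std::lock_guard

// Illustrative guarded state; not part of the change above.
static qurt_mutex g_counter_mutex;
static int        g_counter = 0;

void bump_counter() {
    // lock() on entry, unlock() automatically when the guard leaves scope,
    // even on early return or exception.
    std::lock_guard<qurt_mutex> guard(g_counter_mutex);
    ++g_counter;
}
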
@@ -469,7 +469,7 @@ inline HVX_Vector qhmath_hvx_exp_vhf(HVX_Vector sline) {

 inline HVX_VectorPair_x4 qhmath_load_div_sf_ltu() {
     /* Coefficients in float representation */
-    constexpr const float c0_coeffs[32] __attribute__((aligned(hexagon::kBytesPerVector))) = {
+    alignas(hexagon::kBytesPerVector) constexpr const float c0_coeffs[32] = {
         0.0,
         0.0,
         0.0,

@@ -503,7 +503,7 @@ inline HVX_VectorPair_x4 qhmath_load_div_sf_ltu() {
         2.0993232550771013,
         2.032425103348979,
     };
-    constexpr const float c1_coeffs[32] __attribute__((aligned(hexagon::kBytesPerVector))) = {
+    alignas(hexagon::kBytesPerVector) constexpr const float c1_coeffs[32] = {
         0.0,
         0.0,
         0.0,

@@ -537,7 +537,7 @@ inline HVX_VectorPair_x4 qhmath_load_div_sf_ltu() {
         -1.6526109646324616,
         -1.5489652830974667,
     };
-    constexpr const float c2_coeffs[32] __attribute__((aligned(hexagon::kBytesPerVector))) = {
+    alignas(hexagon::kBytesPerVector) constexpr const float c2_coeffs[32] = {
         0.0,
         0.0,
         0.0,

@@ -571,7 +571,7 @@ inline HVX_VectorPair_x4 qhmath_load_div_sf_ltu() {
         0.5781761255999769,
         0.5246475096790261,
     };
-    constexpr const float c3_coeffs[32] __attribute__((aligned(hexagon::kBytesPerVector))) = {
+    alignas(hexagon::kBytesPerVector) constexpr const float c3_coeffs[32] = {
         0.0,
         0.0,
         0.0,

@@ -750,7 +750,7 @@ inline HVX_Vector qhmath_hvx_div_vf(HVX_Vector num, HVX_Vector denom, HVX_Vector

 inline HVX_VectorPair_x4 qhmath_load_div_hf_ltu() {
     /* Coefficients in float representation */
-    constexpr const float c0_coeffs[32] __attribute__((aligned(hexagon::kBytesPerVector))) = {
+    alignas(hexagon::kBytesPerVector) constexpr const float c0_coeffs[32] = {
         0.0,
         0.0,
         0.0,

@@ -784,7 +784,7 @@ inline HVX_VectorPair_x4 qhmath_load_div_hf_ltu() {
         2.0981551828382417,
         2.0319234960945,
     };
-    constexpr const float c1_coeffs[32] __attribute__((aligned(hexagon::kBytesPerVector))) = {
+    alignas(hexagon::kBytesPerVector) constexpr const float c1_coeffs[32] = {
         0.0,
         0.0,
         0.0,

@@ -818,7 +818,7 @@ inline HVX_VectorPair_x4 qhmath_load_div_hf_ltu() {
         -1.6507730169513504,
         -1.5482028127706613,
     };
-    constexpr const float c2_coeffs[32] __attribute__((aligned(hexagon::kBytesPerVector))) = {
+    alignas(hexagon::kBytesPerVector) constexpr const float c2_coeffs[32] = {
         0.0,
         0.0,
         0.0,

@@ -852,7 +852,7 @@ inline HVX_VectorPair_x4 qhmath_load_div_hf_ltu() {
         0.5772121720355701,
         0.524261196551401,
     };
-    constexpr const float c3_coeffs[32] __attribute__((aligned(hexagon::kBytesPerVector))) = {
+    alignas(hexagon::kBytesPerVector) constexpr const float c3_coeffs[32] = {
         0.0,
         0.0,
         0.0,
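
The eight hunks above only change the alignment spelling of the c0_coeffs..c3_coeffs tables, which appear to hold per-segment coefficients of a piecewise cubic approximation used by the HVX division helpers. For reference, the scalar shape of evaluating one such cubic is Horner's rule; the HVX code performs the same arithmetic across whole vectors, and the coefficient values below are placeholders, not entries from the real tables.

#include <cstdio>

// Horner evaluation of c0 + x*(c1 + x*(c2 + x*c3)), the scalar analogue of the
// per-segment polynomial evaluated by the vectorized division routines.
static float eval_cubic(float x, float c0, float c1, float c2, float c3) {
    return c0 + x * (c1 + x * (c2 + x * c3));
}

int main() {
    // Placeholder coefficients chosen only to exercise the formula.
    std::printf("%f\n", eval_cubic(0.5f, 2.0f, -1.6f, 0.55f, -0.07f));
    return 0;
}
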
@@ -1,11 +1,11 @@
 #include "buffer.hpp"

-#include <rpcmem.h>
-
 #include "host_device.hpp"
 #include "profiler.hpp"
 #include "tensor.hpp"

+#include <rpcmem.h>
+
 namespace {

 constexpr const int kRpcMemDefaultHeapId = RPCMEM_HEAP_ID_SYSTEM;
@@ -48,14 +48,20 @@ ggml_status backend_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor
     return GGML_STATUS_SUCCESS;
 }

-void backend_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset,
-                               size_t size) {
+void backend_buffer_set_tensor(ggml_backend_buffer_t buffer,
+                               ggml_tensor * tensor,
+                               const void * data,
+                               size_t offset,
+                               size_t size) {
     GGML_UNUSED(buffer);
     memcpy((char *) tensor->data + offset, data, size);
 }

-void backend_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset,
-                               size_t size) {
+void backend_buffer_get_tensor(ggml_backend_buffer_t buffer,
+                               const ggml_tensor * tensor,
+                               void * data,
+                               size_t offset,
+                               size_t size) {
     GGML_UNUSED(buffer);
     memcpy(data, (const char *) tensor->data + offset, size);
 }
@@ -135,31 +141,33 @@ host_buffer::host_buffer(common::rpc_mem_ptr allocator, size_t size, uint32_t do
     _size(size),
     _domain_id(domain_id) {
     if (!_allocator->is_valid()) {
-        LOG_ERROR("rpc memory not initialized\n");
+        LOG_ERROR("[hexagon-npu]rpc memory not initialized\n");
         return;
     }

     if (size > _allocator->get_max_alloc_size()) {
-        LOG_ERROR("rpc memory size %zu exceeds max alloc size %zu\n", size, _allocator->get_max_alloc_size());
+        LOG_ERROR(
+            "[hexagon-npu]rpc memory size %zu exceeds max alloc size %zu\n", size, _allocator->get_max_alloc_size());
         return;
     }

     _data = _allocator->alloc(kRpcMemDefaultHeapId, kRpcMemDefaultFlags, size);
     if (!_data) {
-        LOG_ERROR("failed to allocate rpc memory, size: %d MB\n", (int) (size / (1 << 20)));
+        LOG_ERROR("[hexagon-npu]failed to allocate rpc memory, size: %d MB\n", (int) (size / (1 << 20)));
         return;
     }

-    LOG_DEBUG("create host_buffer(%p), size: %zu, domain_id: %d\n", (void *) _data, size, (int) domain_id);
+    LOG_DEBUG("[hexagon-npu]create host_buffer(%p), size: %zu, domain_id: %d\n", (void *) _data, size, (int) domain_id);
 }

 host_buffer::~host_buffer() {
-    LOG_DEBUG("destroy host_buffer(%p), size: %zu, domain_id: %d\n", (void *) _data, _size, (int) _domain_id);
+    LOG_DEBUG(
+        "[hexagon-npu]destroy host_buffer(%p), size: %zu, domain_id: %d\n", (void *) _data, _size, (int) _domain_id);
     _tensors.clear();
     if (_buffer_fd != -1) {
         auto ret = _allocator->fastrpc_munmap((int) _domain_id, _buffer_fd, nullptr, 0);
         if (ret != AEE_SUCCESS) {
-            LOG_ERROR("failed to munmap rpc memory, fd: %d, ret: %d\n", _buffer_fd, ret);
+            LOG_ERROR("[hexagon-npu]failed to munmap rpc memory, fd: %d, ret: %d\n", _buffer_fd, ret);
             return;
         }
     }

@@ -169,31 +177,37 @@ host_buffer::~host_buffer() {

 std::shared_ptr<host_tensor> host_buffer::init_tensor(ggml_tensor * tensor, remote_handle64 device_handle) {
     if (!_data) {
-        LOG_ERROR("failed to init tensor, rpc memory not initialized\n");
+        LOG_ERROR("[hexagon-npu]failed to init tensor, rpc memory not initialized\n");
         return std::shared_ptr<host_tensor>();
     }

     if (_buffer_fd == -1) {
         _buffer_fd = _allocator->to_fd(_data);
         if (_buffer_fd < 0) {
-            LOG_ERROR("failed to get fd from rpc memory\n");
+            LOG_ERROR("[hexagon-npu]failed to get fd from rpc memory\n");
             return std::shared_ptr<host_tensor>();
         }

         auto ret = _allocator->fastrpc_mmap((int) _domain_id, _buffer_fd, _data, 0, _size, FASTRPC_MAP_FD);
         if (ret != AEE_SUCCESS) {
-            LOG_ERROR("failed to mmap rpc memory, fd: %d, size: %zu, ret: %d\n", _buffer_fd, _size, ret);
+            LOG_ERROR("[hexagon-npu]failed to mmap rpc memory, fd: %d, size: %zu, ret: %d\n", _buffer_fd, _size, ret);
             return std::shared_ptr<host_tensor>();
         }

-        LOG_DEBUG("mmap rpc memory(%p), fd: %d, addr: %p, size: %zu\n", (void *) _data, _buffer_fd, _data, _size);
+        LOG_DEBUG("[hexagon-npu]mmap rpc memory(%p), fd: %d, addr: %p, size: %zu\n",
+                  (void *) _data,
+                  _buffer_fd,
+                  _data,
+                  _size);
     }

     auto tensor_object = std::make_shared<host_tensor>(
-        tensor, _buffer_fd, (uint64_t) (reinterpret_cast<uint8_t *>(tensor->data) - reinterpret_cast<uint8_t *>(_data)),
+        tensor,
+        _buffer_fd,
+        (uint64_t) (reinterpret_cast<uint8_t *>(tensor->data) - reinterpret_cast<uint8_t *>(_data)),
         device_handle);
     if (!tensor_object->is_valid()) {
-        LOG_ERROR("failed to init tensor, device handle: %p\n", (void *) device_handle);
+        LOG_ERROR("[hexagon-npu]failed to init tensor, device handle: %p\n", (void *) device_handle);
         return std::shared_ptr<host_tensor>();
     }

@@ -202,7 +216,7 @@ std::shared_ptr<host_tensor> host_buffer::init_tensor(ggml_tensor * tensor, remo
 }

 void host_buffer::clear_tensors() {
-    LOG_DEBUG("clear host_buffer(%p) tensors\n", (void *) _data);
+    LOG_DEBUG("[hexagon-npu]clear host_buffer(%p) tensors\n", (void *) _data);
     host_tensor::destroy_tensors(_tensors);
 }

@@ -230,7 +244,7 @@ size_t host_buffer_type::get_buffer_alignment() const {

 size_t host_buffer_type::get_max_buffer_size() const {
     if (!_rpc_mem) {
-        LOG_ERROR("rpc memory not initialized\n");
+        LOG_ERROR("[%s]rpc memory not initialized\n", _device->get_name());
         return 0;
     }

@@ -239,19 +253,19 @@ size_t host_buffer_type::get_max_buffer_size() const {

 ggml_backend_buffer_t host_buffer_type::allocate_buffer(size_t size) {
     if (!_rpc_mem) {
-        LOG_ERROR("rpc memory not initialized\n");
+        LOG_ERROR("[%s]rpc memory not initialized\n", _device->get_name());
         return nullptr;
     }

     if (!_device->is_device_initialized()) {
-        LOG_ERROR("device is not initialized\n");
+        LOG_ERROR("[%s]device is not initialized\n", _device->get_name());
         return nullptr;
     }

     auto * buffer = new host_buffer(_rpc_mem, size, _device->get_dsp_domain_id());
     if (!buffer->is_valid()) {
         delete buffer;
-        LOG_ERROR("Failed to allocate buffer of size %zu\n", size);
+        LOG_ERROR("[%s]Failed to allocate buffer of size %zu\n", _device->get_name(), size);
         return nullptr;
     }

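
The logging hunks above add a "[hexagon-npu]" or "[%s]" device-name tag to each message at its call site. As a purely hypothetical alternative (not what this PR does), a thin wrapper macro could splice a fixed tag into the format string once; NPU_LOG_TAG and NPU_LOG_ERROR below are invented names, and ##__VA_ARGS__ is a GNU/Clang extension:

#include <cstdio>

#define NPU_LOG_TAG "[hexagon-npu]"
// Adjacent string literals concatenate at compile time, so the tag costs nothing at runtime.
// ##__VA_ARGS__ drops the trailing comma when no arguments follow the format string.
#define NPU_LOG_ERROR(fmt, ...) std::fprintf(stderr, NPU_LOG_TAG fmt, ##__VA_ARGS__)

int main() {
    NPU_LOG_ERROR("rpc memory not initialized\n");
    NPU_LOG_ERROR("failed to munmap rpc memory, fd: %d, ret: %d\n", 3, -1);
    return 0;
}
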