feat: perf opt quant (#47)

* feat: add mixed precision dot product implementation and function declaration
* feat: implement mixed precision vector dot product and conversion functions
* fix: update data type handling in matrix multiplication implementation
* fix: adjust row count handling in matrix multiplication implementation for accurate slicing
* fix: optimize matrix multiplication implementation by unrolling the loop
* update performance tracking for matrix multiplication implementation
* add fetching
* wip
* fix: support F16 * F32 multiplication in is_mul_mat_supported function
* fix: improve src0 fetching logic in vec_dot_product_mixed_impl for better alignment handling
* fix test failure for row width 67
* try fix failed test
* fix: rename aligned_address to align_down for clarity in vector alignment handling
* wip
* qnn fix: update device capabilities for quantized types in qnn-lib to improve compatibility
* fix test failure at width == 193
* fix: replace zero vector initialization with previous vector in mixed dot product implementation
* wip
* fix: improve handling of last vector in mixed dot product implementation
* wip
* wip
* wip
* wip
* Enhance mul_mat_f32 function to support quantized types and improve static assertions
* rename
* Refactor dequantization functions to use npu_device_fp16_t and improve type handling
* Optimize dequantization in dequantize_row_q8_0 by replacing qf32 multiplication with qf16
* Optimize dequantization in dequantize_row_q4_0 by replacing qf32 multiplication with qf16
* Add hvx_vsf_convert_vhf function for improved vector conversion
* add perf logs
* Refactor dequantize_row_q4_0 for alignment
* Update logging in supports_op_impl and supports_op to use ggml_op_desc for better clarity
* Add support for ROPE operation in NPU capabilities and related functions
* Implement ROPE operation in tensor and op_rope, including cache initialization and correction dimension calculations
* enable ROPE by adding operation validation
* add support for the case where freq is null
* wip
* Refactor rope_f32 to improve indexing by introducing total_planes calculation
* reformat
* Refactor rope_f32 to optimize data access patterns by introducing row and plane pointers
* Add performance tracking to rope_f32 function for enhanced profiling
* Refactor rope_f32 to use a templated implementation
* Refactor rope_impl to replace loop with memcpy for improved performance
* Refactor mul_mat_impl to support quantization as a template parameter
* wip
* wip
* Refactor rope_impl to optimize plane indexing in the processing loop
* Add aligned vector dot product implementation for mixed precision types
* wip
* Enhance matrix multiplication for F32 and F16 types with alignment checks
* Optimize vec_dot_product_mix_aligned_impl for improved performance with additional vector sums
* Add alignment checks for matrix multiplication and vector dot products
* Refactor matrix multiplication to use function pointers for improved readability and maintainability
* Fix alignment check in is_dot_product_aligned to ensure correct vector size handling
* Remove unused f16_to_f32_table parameter from quantization and dequantization functions
* wip
* Add L2 fetch for src1 plane rows in matrix multiplication implementation
* wip
* Refactor hvx_vsf_convert_vhf to accept an additional parameter for flexibility in vector multiplication
* Refactor vec_dot_product_mix_aligned_impl to improve variable naming for clarity
* Refactor load_dual_block_generic and dequantize_row_q4_0 to improve performance
* Refactor vector operation functions to improve clarity and consistency in variable usage
* wip
* wip
* Refactor dequantize_row_q4_0_impl for improved clarity and performance in vector operations
* wip
* Update load_dual_block_generic to use intrinsics
* Refactor load_dual_block_generic and load_qual_block_generic for improved performance and clarity
* wip
* wip
* Optimize dequantize_row_q8_0 for improved performance by unrolling the for loop
* wip
* wip
* fix typo
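For orientation, the "mixed precision" dot product referred to throughout this log multiplies an F16 row of src0 against an F32 row of src1 and accumulates in F32. A minimal scalar sketch of the operation (the HVX implementations added in this commit, vec_dot_product_f16_f32 and its aligned variant, vectorize and unroll this; the reference below is illustrative only):

```cpp
// Scalar reference for the mixed-precision dot product; assumes npu_device_fp16_t
// holds IEEE half-precision bits, as the hexagon code does elsewhere in this diff.
inline float vec_dot_f16_f32_ref(const npu_device_fp16_t * src0, const float * src1, size_t count) {
    float sum = 0.0f;
    for (size_t i = 0; i < count; ++i) {
        sum += (float) reinterpret_cast<const __fp16 &>(src0[i]) * src1[i];
    }
    return sum;
}
```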
This commit is contained in:
parent
989772c7bc
commit
a29243e7a4
@@ -39,7 +39,6 @@ void flash_attn_impl(hexagon::tensor * out, const hexagon::tensor * q, const hex

const auto q_to_vec_dot = hexagon::get_type_traits(k->get_type()).from_float; // TODO: fix this
const auto kq_vec_dot = hexagon::get_type_traits(k->get_type()).vec_dot;
const auto v_to_float = hexagon::get_type_traits(v->get_type()).to_float;
if (!q_to_vec_dot || !kq_vec_dot) {
DEVICE_LOG_ERROR("flash_attn_impl: unsupported data type for q, k, or v\n");
return;
@@ -95,7 +94,6 @@ void flash_attn_impl(hexagon::tensor * out, const hexagon::tensor * q, const hex
float M = -INFINITY; // maximum KQ value

float * VKQ32 = reinterpret_cast<float *>(cache_ptr); // FP32 VKQ accumulator
float * V32 = VKQ32 + aligned_dv; // (temporary) FP32 V buffer
auto * VKQ16 = reinterpret_cast<npu_device_fp16_t *>(VKQ32 + aligned_dv); // (temporary) FP16 VKQ accumulator
auto * Q_q = reinterpret_cast<npu_device_fp16_t *>(
VKQ32 + 2 * aligned_dv); // (temporary) buffer for Q converted to quantized/FP16
@@ -122,7 +120,7 @@ void flash_attn_impl(hexagon::tensor * out, const hexagon::tensor * q, const hex
hexagon::l2fetch_row(q_data + q->get_nb(1), row_bytes_q);
}

q_to_vec_dot(reinterpret_cast<const float *>(q_data), Q_q, DK, params->f16_to_f32_table);
q_to_vec_dot(reinterpret_cast<const float *>(q_data), Q_q, DK);

// online softmax / attention
// loop over n_kv and n_head_kv
@@ -192,10 +190,7 @@ void flash_attn_impl(hexagon::tensor * out, const hexagon::tensor * q, const hex

// V += v*expf(s - M)
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(flash_attn, 2, mad);
if (v_to_float) {
v_to_float(v_data, V32, DV, params->f16_to_f32_table);
hexagon::vec_mad_f32(V32, vs, VKQ32, DV);
} else {
{
// V is F32
hexagon::vec_mad_f32(reinterpret_cast<const float *>(v_data), vs, VKQ32, DV);
}
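A recurring change in this commit is visible in the hunk above: the f16_to_f32_table lookup argument is dropped from the from_float/to_float callback signatures. Later hunks (for example dequantize_row_q4_K) convert with native half-precision arithmetic instead of the table. A minimal sketch of that replacement conversion, assuming npu_device_fp16_t stores IEEE half bits:

```cpp
inline float fp16_to_f32(npu_device_fp16_t h) {
    // hardware half -> float conversion; no per-call lookup table needed
    return (float) reinterpret_cast<const __fp16 &>(h);
}
```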
@@ -6,6 +6,7 @@

#include "op_flash_attn.hpp"
#include "op_mul_mat.hpp"
#include "op_rope.hpp"
#include "type_traits.hpp"
#include "vec_ops.hpp"
@@ -62,7 +63,7 @@ inline void vec_op_impl(const _TyData * src0, const _TyData * src1, size_t count
(leftover_bytes + hexagon::unaligned_bytes(iptr1) > hexagon::kBytesPerVector) ? *iptr1 : prev1;
curr1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1);

q6op_vstu_variable_ARV(optr, leftover_bytes, _OpIntrinsic(curr0, curr1));
hexagon::q6op_vstu_variable_ARV(optr, leftover_bytes, _OpIntrinsic(curr0, curr1));
}
}
@ -179,16 +180,6 @@ template <auto _RowFunc> bool element_wise_op(hexagon::tensor * out, hexagon::co
|
|||
return true;
|
||||
}
|
||||
|
||||
bool is_same_shape(const npu_device_tensor_spec & src, const npu_device_tensor_spec & dst) {
|
||||
for (size_t i = 0; i < DEVICE_TENSOR_MAX_DIMS; ++i) {
|
||||
if (src.ne[i] != dst.ne[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool is_element_wise_op_supported(npu_device_tensor_op op, const npu_device_tensor_spec * dst,
|
||||
const npu_device_tensor_spec * srcs, size_t src_len) {
|
||||
if (op != NPU_OP_ADD && op != NPU_OP_SUB && op != NPU_OP_MUL) {
|
||||
|
|
@ -228,7 +219,7 @@ bool is_element_wise_op_supported(npu_device_tensor_op op, const npu_device_tens
|
|||
return false;
|
||||
}
|
||||
|
||||
if (!is_same_shape(src0, *dst)) {
|
||||
if (!hexagon::is_same_shape(src0, *dst)) {
|
||||
DEVICE_LOG_DEBUG("[%s]src0 and dst have different shape\n", hexagon::op_get_name(op));
|
||||
return false;
|
||||
}
|
||||
|
|
@@ -271,7 +262,7 @@ void rms_norm_vec_f32(const float * src, size_t count, float eps, float * dst) {
Q6_V_valign_VVR(Q6_Vqf32_vmpy_VsfVsf(curr, curr), Q6_V_vzero(), leftover_bytes));
}

const float mean = hexagon::vec_reduction_qf32_f32(sum) / count; // TODO: figure out how to do division in vector
const float mean = hexagon::vec_reduction_f32_qf32(sum) / count; // TODO: figure out how to do division in vector
const float scale = 1.0f / sqrtf(mean + eps); // TODO: use built-in BLAS sqrtf?
hexagon::vec_scale_f32(src, scale, dst, count);
}
@ -354,7 +345,7 @@ bool is_unary_op_supported(npu_device_tensor_op op, const npu_device_tensor_spec
|
|||
return false;
|
||||
}
|
||||
|
||||
if (!is_same_shape(src0, *dst)) {
|
||||
if (!hexagon::is_same_shape(src0, *dst)) {
|
||||
DEVICE_LOG_DEBUG("[%s]src0 and dst have different shape\n", hexagon::op_get_name(op));
|
||||
return false;
|
||||
}
|
||||
|
|
@ -396,7 +387,7 @@ constexpr const op_capabilities kOpCapabilities[] = {
|
|||
{
|
||||
element_wise_op<vec_op_f32_f32<vmul_f32_f32>>, // NPU_DATA_TYPE_F32
|
||||
element_wise_op<vec_op_f16_f16<vmul_f16_f16>>, // NPU_DATA_TYPE_F16
|
||||
}, false, // requires_thread_barrier
|
||||
}, false, // requires_thread_barrier
|
||||
},
|
||||
{
|
||||
NPU_OP_RMS_NORM, is_unary_op_supported,
|
||||
|
|
@ -412,6 +403,13 @@ constexpr const op_capabilities kOpCapabilities[] = {
|
|||
nullptr, // NPU_DATA_TYPE_F16
|
||||
}, true, // requires_thread_barrier
|
||||
},
|
||||
{
|
||||
NPU_OP_ROPE, hexagon::is_rope_supported,
|
||||
{
|
||||
hexagon::rope_f32, // NPU_DATA_TYPE_F32
|
||||
nullptr, // NPU_DATA_TYPE_F16
|
||||
}, false, // requires_thread_barrier
|
||||
},
|
||||
};
|
||||
|
||||
static_assert(kOpCapabilities[NPU_OP_MUL_MAT].compute_funcs[NPU_DATA_TYPE_F32] == hexagon::mul_mat_f32,
|
||||
|
|
@ -424,6 +422,7 @@ static_assert(kOpCapabilities[NPU_OP_RMS_NORM].op == NPU_OP_RMS_NORM,
|
|||
"kOpArray[NPU_OP_RMS_NORM].op != NPU_OP_RMS_NORM");
|
||||
static_assert(kOpCapabilities[NPU_OP_FLASH_ATTN].op == NPU_OP_FLASH_ATTN,
|
||||
"kOpArray[NPU_OP_FLASH_ATTN].op != NPU_OP_FLASH_ATTN");
|
||||
static_assert(kOpCapabilities[NPU_OP_ROPE].op == NPU_OP_ROPE, "kOpArray[NPU_OP_ROPE].op != NPU_OP_ROPE");
|
||||
|
||||
hexagon::compute_func_type get_compute_func_impl(npu_device_tensor_op op, npu_device_tensor_data_type type) {
|
||||
if (op >= NPU_OP_COUNT) {
|
||||
|
|
@ -451,17 +450,18 @@ bool requires_thread_barrier(npu_device_tensor_op op) {
|
|||
|
||||
bool support_op(npu_device_tensor_op op, const npu_device_tensor_spec * dst, const npu_device_tensor_spec * srcs,
|
||||
size_t src_len) {
|
||||
if (get_compute_func_impl(op, dst->type) == nullptr) {
|
||||
DEVICE_LOG_ERROR("[%s]unsupported, get_compute_func failed\n", op_get_name(op));
|
||||
return false;
|
||||
}
|
||||
|
||||
auto is_supported_func = kOpCapabilities[op].is_supported;
|
||||
if (!is_supported_func || !is_supported_func(op, dst, srcs, src_len)) {
|
||||
DEVICE_LOG_DEBUG("[%s]unsupported, is_supported_func return false\n", op_get_name(op));
|
||||
return false;
|
||||
}
|
||||
|
||||
if (get_compute_func_impl(op, dst->type) == nullptr) {
|
||||
DEVICE_LOG_DEBUG("[%s]unsupported, get_compute_func failed, type: %s\n", op_get_name(op),
|
||||
get_type_name(dst->type));
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -9,19 +9,24 @@ namespace {
|
|||
|
||||
template <typename T> struct get_data_type {};
|
||||
|
||||
template <typename _TyData> struct get_data_type<float (*)(const _TyData *, const _TyData *, size_t)> {
|
||||
using type = _TyData;
|
||||
template <typename _TyData0, typename _TyData1>
|
||||
struct get_data_type<float (*)(const _TyData0 *, const _TyData1 *, size_t)> {
|
||||
using data_type0 = _TyData0;
|
||||
using data_type1 = _TyData1;
|
||||
};
|
||||
|
||||
template <auto _DotFunc>
|
||||
template <auto _DotFunc, bool _IsQuantized>
|
||||
void mul_mat_impl(hexagon::tensor * src0, hexagon::tensor * src1, hexagon::tensor * dst,
|
||||
hexagon::compute_params * params) {
|
||||
using data_type = typename get_data_type<decltype(_DotFunc)>::type;
|
||||
using data_type0 = typename get_data_type<decltype(_DotFunc)>::data_type0;
|
||||
using data_type1 = typename get_data_type<decltype(_DotFunc)>::data_type1;
|
||||
|
||||
static_assert(!_IsQuantized || std::is_same_v<data_type0, hexagon::dequant_target_type>,
|
||||
"data_type0 must be the same as hexagon::dequant_target_type");
|
||||
|
||||
const bool is_quantized = hexagon::is_quantized_type(src0->get_type());
|
||||
const auto src0_actual_row_size = hexagon::get_dequantized_row_size(src0);
|
||||
auto * dequantize_row_func = hexagon::get_type_traits(src0->get_type()).to_float;
|
||||
if (is_quantized && dequantize_row_func == nullptr) {
|
||||
if (_IsQuantized && dequantize_row_func == nullptr) {
|
||||
DEVICE_LOG_ERROR("Unsupported quantized src0 type: %d, dequantize_row_func is null\n", src0->get_type());
|
||||
return;
|
||||
}
|
||||
|
|
@ -36,10 +41,10 @@ void mul_mat_impl(hexagon::tensor * src0, hexagon::tensor * src1, hexagon::tenso
|
|||
|
||||
if (total_planes >= params->get_thread_count()) {
|
||||
start_end_plane = params->get_work_slice(total_planes);
|
||||
} else if (dst->get_ne(1) >= params->get_thread_count()) {
|
||||
start_end_row = params->get_work_slice(dst->get_ne(1));
|
||||
} else {
|
||||
} else if (dst->get_ne(0) >= params->get_thread_count()) {
|
||||
start_end_element = params->get_work_slice(dst->get_ne(0));
|
||||
} else {
|
||||
start_end_row = params->get_work_slice(dst->get_ne(1));
|
||||
}
|
||||
|
||||
if (start_end_plane.second <= start_end_plane.first || start_end_row.second <= start_end_row.first ||
|
||||
|
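The dispatch above hands each thread whichever dimension (planes, then rows, then elements) offers at least one unit of work per thread. get_work_slice itself is not part of this diff; a minimal even-split sketch of what such a helper could look like (name, signature, and rounding are assumptions for illustration only):

```cpp
#include <algorithm>
#include <utility>

// Hypothetical: split `total` work units across `n_threads`, returning [first, second) for `thread_idx`.
inline std::pair<int64_t, int64_t> get_work_slice_sketch(int64_t total, size_t n_threads, size_t thread_idx) {
    const int64_t per_thread = (total + (int64_t) n_threads - 1) / (int64_t) n_threads;
    const int64_t first      = std::min<int64_t>(total, (int64_t) thread_idx * per_thread);
    const int64_t second     = std::min<int64_t>(total, first + per_thread);
    return { first, second };
}
```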
|
@ -57,30 +62,29 @@ void mul_mat_impl(hexagon::tensor * src0, hexagon::tensor * src1, hexagon::tenso
|
|||
size_t src0_plane_cache_size = 0;
|
||||
uint8_t * src0_plane_cache_ptr = nullptr;
|
||||
const uint8_t * last_cached_plane_ptr = nullptr;
|
||||
bool is_mem_cache = false;
|
||||
if (is_quantized) {
|
||||
if constexpr (_IsQuantized) {
|
||||
src0_plane_slice_row_count =
|
||||
std::min(params->get_vtcm_quota_size() / src0_actual_row_size, src0_plane_slice_row_count);
|
||||
src0_plane_cache_size = src0_actual_row_size * src0_plane_slice_row_count;
|
||||
src0_plane_cache_ptr = params->get_vtcm_cache(src0_plane_cache_size);
|
||||
if (src0_plane_cache_ptr == nullptr) {
|
||||
DEVICE_LOG_DEBUG(
|
||||
DEVICE_LOG_ERROR(
|
||||
"mul_mat_impl: failed to get VTCM cache for src0, size: %zu, src0_plane_slice_row_count: %zu, "
|
||||
"src0_actual_row_size: %zu, will fallback to mem cache\n",
|
||||
src0_plane_cache_size, src0_plane_slice_row_count, src0_actual_row_size);
|
||||
src0_plane_cache_ptr = params->get_mem_cache(src0_plane_cache_size);
|
||||
is_mem_cache = true;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
DEVICE_LOG_DEBUG(
|
||||
"mul_mat_impl src0_actual_row_size: %zu, src0_plane_slice_row_count: %zu, is_quantized: %d, vtcm_mem: "
|
||||
"%p(%zu)\n",
|
||||
src0_actual_row_size, src0_plane_slice_row_count, is_quantized, (void *) src0_plane_cache_ptr,
|
||||
src0_actual_row_size, src0_plane_slice_row_count, _IsQuantized, (void *) src0_plane_cache_ptr,
|
||||
src0_plane_cache_size);
|
||||
|
||||
const size_t valid_row_bytes = src1->get_ne(0) * sizeof(data_type);
|
||||
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_WITH_SUB_PROC(dst, params->get_thread_index(), dequant);
|
||||
const size_t valid_row0_bytes = src0->get_ne(0) * sizeof(data_type0);
|
||||
const size_t valid_row1_bytes = src1->get_ne(0) * sizeof(data_type1);
|
||||
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_WITH_MULTI_SUB_PROC(dst, params->get_thread_index(), mul_mat);
|
||||
|
||||
uint8_t * dst_ptr = dst->get_write_buffer();
|
||||
if (!dst_ptr) {
|
||||
|
|
@ -89,8 +93,9 @@ void mul_mat_impl(hexagon::tensor * src0, hexagon::tensor * src1, hexagon::tenso
|
|||
return;
|
||||
}
|
||||
|
||||
const uint8_t * src0_ptr = src0->get_read_buffer();
|
||||
const uint8_t * src1_ptr = src1->get_read_buffer();
|
||||
constexpr bool should_fetch_src0_row = !_IsQuantized;
|
||||
const uint8_t * src0_ptr = src0->get_read_buffer();
|
||||
const uint8_t * src1_ptr = src1->get_read_buffer();
|
||||
for (int64_t ip = start_end_plane.first; ip < start_end_plane.second; ip++) {
|
||||
const auto i3 = ip / dst->get_ne(2);
|
||||
const auto i2 = ip - i3 * dst->get_ne(2);
|
||||
|
|
@ -98,21 +103,25 @@ void mul_mat_impl(hexagon::tensor * src0, hexagon::tensor * src1, hexagon::tenso
|
|||
auto * dst_plane = dst_ptr + i3 * dst->get_nb(3) + i2 * dst->get_nb(2);
|
||||
for (int64_t col_idx = start_end_element.first; col_idx < start_end_element.second;
|
||||
col_idx += src0_plane_slice_row_count) {
|
||||
const auto * src0_plane =
|
||||
const auto actual_row_count =
|
||||
std::min<int64_t>(src0_plane_slice_row_count,
|
||||
start_end_element.second - col_idx); // number of rows in this slice
|
||||
const uint8_t * src0_plane =
|
||||
src0_ptr + i3 / r03 * src0->get_nb(3) + i2 / r02 * src0->get_nb(2) + col_idx * src0->get_nb(1);
|
||||
if (src0_plane_cache_ptr) {
|
||||
if constexpr (_IsQuantized) {
|
||||
if (last_cached_plane_ptr != src0_plane) {
|
||||
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_SUB_PROC(dequant);
|
||||
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 0, dequant);
|
||||
|
||||
for (int64_t ir = 0; ir < (int64_t) src0_plane_slice_row_count; ir++) {
|
||||
for (int64_t ir = 0; ir < (int64_t) actual_row_count; ir++) {
|
||||
auto * src0_row = src0_plane + ir * src0->get_nb(1);
|
||||
if (ir + 1 < src0_plane_slice_row_count) {
|
||||
if (ir + 1 < actual_row_count) {
|
||||
hexagon::l2fetch_row(src0_row + src0->get_nb(1), src0->get_nb(1));
|
||||
}
|
||||
|
||||
auto * dst_row = reinterpret_cast<float *>(src0_plane_cache_ptr + ir * src0_actual_row_size);
|
||||
dequantize_row_func(src0_row, reinterpret_cast<float *>(dst_row), src0->get_ne(0),
|
||||
params->f16_to_f32_table);
|
||||
auto * dst_row = reinterpret_cast<hexagon::dequant_target_type *>(src0_plane_cache_ptr +
|
||||
ir * src0_actual_row_size);
|
||||
dequantize_row_func(src0_row, reinterpret_cast<hexagon::dequant_target_type *>(dst_row),
|
||||
src0->get_ne(0));
|
||||
}
|
||||
|
||||
last_cached_plane_ptr = src0_plane;
|
||||
|
|
@ -121,22 +130,43 @@ void mul_mat_impl(hexagon::tensor * src0, hexagon::tensor * src1, hexagon::tenso
|
|||
src0_plane = src0_plane_cache_ptr;
|
||||
}
|
||||
|
||||
if (start_end_row.second > start_end_row.first) {
|
||||
hexagon::l2fetch_row(src1_plane + start_end_row.first * src1->get_nb(1), valid_row1_bytes);
|
||||
}
|
||||
|
||||
for (int64_t i1 = start_end_row.first; i1 < start_end_row.second; i1++) {
|
||||
auto * src1_row = src1_plane + i1 * src1->get_nb(1);
|
||||
auto * dst_row = reinterpret_cast<float *>(dst_plane + i1 * dst->get_nb(1)) + col_idx;
|
||||
for (int64_t i0 = 0; i0 < (int64_t) src0_plane_slice_row_count; i0++) {
|
||||
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 1, vec_dot);
|
||||
auto * src1_row = src1_plane + i1 * src1->get_nb(1);
|
||||
auto * dst_row = reinterpret_cast<float *>(dst_plane + i1 * dst->get_nb(1)) + col_idx;
|
||||
int64_t i0 = 0;
|
||||
for (; i0 + 1 < (int64_t) actual_row_count; i0 += 2) {
|
||||
auto * src0_row = src0_plane + i0 * src0_actual_row_size;
|
||||
if (i0 + 1 < src0_plane_slice_row_count) {
|
||||
if (!src0_plane_cache_ptr || is_mem_cache) {
|
||||
hexagon::l2fetch_row(src0_row + src0_actual_row_size, valid_row_bytes);
|
||||
}
|
||||
} else if (ip + 1 < start_end_plane.second) {
|
||||
hexagon::l2fetch_row(src1_row + src1->get_nb(1), valid_row_bytes);
|
||||
if constexpr (should_fetch_src0_row) {
|
||||
hexagon::l2fetch_row(src0_row + src0_actual_row_size, valid_row0_bytes);
|
||||
}
|
||||
|
||||
// TODO: figure out how to handle an entire row
|
||||
dst_row[i0] = _DotFunc(reinterpret_cast<const data_type *>(src0_row),
|
||||
reinterpret_cast<const data_type *>(src1_row), (size_t) src0->get_ne(0));
|
||||
dst_row[i0] = _DotFunc(reinterpret_cast<const data_type0 *>(src0_row),
|
||||
reinterpret_cast<const data_type1 *>(src1_row), (size_t) src0->get_ne(0));
|
||||
|
||||
if (should_fetch_src0_row && i0 + 2 < (int64_t) actual_row_count) {
|
||||
hexagon::l2fetch_row(src0_row + src0_actual_row_size + src0_actual_row_size, valid_row0_bytes);
|
||||
}
|
||||
|
||||
// TODO: figure out how to handle an entire row
|
||||
dst_row[i0 + 1] =
|
||||
_DotFunc(reinterpret_cast<const data_type0 *>(src0_row + src0_actual_row_size),
|
||||
reinterpret_cast<const data_type1 *>(src1_row), (size_t) src0->get_ne(0));
|
||||
}
|
||||
|
||||
if (ip + 1 < start_end_plane.second) {
|
||||
hexagon::l2fetch_row(src1_row + src1->get_nb(1), valid_row1_bytes);
|
||||
}
|
||||
|
||||
if (i0 < (int64_t) actual_row_count) {
|
||||
auto * src0_row = src0_plane + i0 * src0_actual_row_size;
|
||||
dst_row[i0] = _DotFunc(reinterpret_cast<const data_type0 *>(src0_row),
|
||||
reinterpret_cast<const data_type1 *>(src1_row), (size_t) src0->get_ne(0));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
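The inner loop above overlaps compute with prefetch: while row i0 is being multiplied, the next src0 row (and, at plane boundaries, the next src1 row) is already being pulled into L2 via hexagon::l2fetch_row. The general shape of the pattern, with process_row standing in for the _DotFunc call (illustrative only):

```cpp
for (int64_t ir = 0; ir < row_count; ++ir) {
    const uint8_t * row = base + ir * row_stride;
    if (ir + 1 < row_count) {
        hexagon::l2fetch_row(row + row_stride, row_bytes);  // prefetch the next row while this one is processed
    }
    process_row(row);  // hypothetical stand-in for the dot-product call
}
```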
@ -146,7 +176,7 @@ void mul_mat_impl(hexagon::tensor * src0, hexagon::tensor * src1, hexagon::tenso
|
|||
}
|
||||
|
||||
bool is_quantized_mul_mat_supported(const npu_device_tensor_spec & src0, const npu_device_tensor_spec & src1) {
|
||||
if (src1.type != NPU_DATA_TYPE_F32) {
|
||||
if (src1.type != NPU_DATA_TYPE_F32 && src1.type != NPU_DATA_TYPE_F16) {
|
||||
DEVICE_LOG_DEBUG("[MUL_MAT]src0.type(%s) and src1.type(%s) mismatch and src1 is not F32\n",
|
||||
hexagon::get_type_name(src0.type), hexagon::get_type_name(src1.type));
|
||||
return false;
|
||||
|
|
@ -166,7 +196,7 @@ bool is_quantized_mul_mat_supported(const npu_device_tensor_spec & src0, const n
|
|||
}
|
||||
|
||||
const auto vtcm_thread_quota_size = hexagon::default_thread_pool::get_per_thread_vtcm_quota();
|
||||
if (src0.ne[0] * sizeof(hexagon::dequantized_element_type) > vtcm_thread_quota_size) {
|
||||
if (src0.ne[0] * sizeof(hexagon::dequant_target_type) > vtcm_thread_quota_size) {
|
||||
DEVICE_LOG_DEBUG("[MUL_MAT]src0.type(%s) ne[0] is too large: %ld, vtcm_thread_quota_size: %zu\n",
|
||||
hexagon::get_type_name(src0.type), (long) src0.ne[0], vtcm_thread_quota_size);
|
||||
return false;
|
||||
|
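To make the quota check above concrete: with a hypothetical src0.ne[0] of 4096 and a 2-byte dequant_target_type (npu_device_fp16_t), a single dequantized row occupies 4096 * 2 = 8192 bytes, so the per-thread VTCM quota must be at least 8 KiB for even one row of the plane cache to fit.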
|
@ -177,29 +207,113 @@ bool is_quantized_mul_mat_supported(const npu_device_tensor_spec & src0, const n
|
|||
return true;
|
||||
}
|
||||
|
||||
bool is_mul_mat_f16_f32_src_tensors_aligned(hexagon::tensor * src0, hexagon::tensor * src1, bool is_src0_quantized) {
|
||||
const auto * src1_ptr = src1->get_read_buffer_as<float>();
|
||||
const auto * src0_ptr = is_src0_quantized ?
|
||||
src1->get_read_buffer_as<npu_device_fp16_t>() :
|
||||
src0->get_read_buffer_as<npu_device_fp16_t>(); // skip src0 for quantized tensors
|
||||
|
||||
if (!hexagon::is_f16_f32_dot_product_aligned(src0_ptr, src1_ptr, src0->get_ne(0))) {
|
||||
DEVICE_LOG_DEBUG("[MUL_MAT]src_tensors_unaligned: ne[0]: %ld\n", (long) src0->get_ne(0));
|
||||
return false;
|
||||
}
|
||||
|
||||
DEVICE_LOG_DEBUG("[MUL_MAT]src_tensors_aligned: ne[0]: %ld\n", (long) src0->get_ne(0));
|
||||
return true;
|
||||
}
|
||||
|
||||
bool is_mul_mat_f16_f16_src_tensors_aligned(hexagon::tensor * src0, hexagon::tensor * src1, bool is_src0_quantized) {
|
||||
const auto * src1_ptr = src1->get_read_buffer_as<npu_device_fp16_t>();
|
||||
const auto * src0_ptr = is_src0_quantized ? src1_ptr : src0->get_read_buffer_as<npu_device_fp16_t>();
|
||||
|
||||
if (!hexagon::is_f16_f16_dot_product_aligned(src0_ptr, src1_ptr, src0->get_ne(0))) {
|
||||
DEVICE_LOG_DEBUG("[MUL_MAT]src_tensors_unaligned: ne[0]: %ld\n", (long) src0->get_ne(0));
|
||||
return false;
|
||||
}
|
||||
|
||||
DEVICE_LOG_DEBUG("[MUL_MAT]src_tensors_aligned: ne[0]: %ld\n", (long) src0->get_ne(0));
|
||||
return true;
|
||||
}
|
||||
|
||||
bool is_mul_mat_f32_f32_src_tensors_aligned(hexagon::tensor * src0, hexagon::tensor * src1) {
|
||||
const auto * src1_ptr = src1->get_read_buffer_as<float>();
|
||||
const auto * src0_ptr = src0->get_read_buffer_as<float>();
|
||||
|
||||
if (!hexagon::is_f32_f32_dot_product_aligned(src0_ptr, src1_ptr, src0->get_ne(0))) {
|
||||
DEVICE_LOG_DEBUG("[MUL_MAT]src_tensors_unaligned: ne[0]: %ld\n", (long) src0->get_ne(0));
|
||||
return false;
|
||||
}
|
||||
|
||||
DEVICE_LOG_DEBUG("[MUL_MAT]src_tensors_aligned: ne[0]: %ld\n", (long) src0->get_ne(0));
|
||||
return true;
|
||||
}
|
||||
|
||||
typedef void (*mul_mat_func_type)(hexagon::tensor * src0, hexagon::tensor * src1, hexagon::tensor * dst,
|
||||
hexagon::compute_params * params);
|
||||
|
||||
constexpr const mul_mat_func_type kMulMatF16F32Funcs[2][2] = {
|
||||
{
|
||||
// non-quantized
|
||||
mul_mat_impl<hexagon::vec_dot_product_f16_f32, false>, // F32 * F32 unaligned
|
||||
mul_mat_impl<hexagon::vec_dot_product_aligned_f16_f32, false>, // F32 * F32 aligned
|
||||
},
|
||||
{
|
||||
// quantized
|
||||
mul_mat_impl<hexagon::vec_dot_product_f16_f32, true>, // F32 * F32 quantized unaligned
|
||||
mul_mat_impl<hexagon::vec_dot_product_aligned_f16_f32, true>, // F32 * F32 quantized aligned
|
||||
},
|
||||
};
|
||||
|
||||
constexpr const mul_mat_func_type kMulMatF16Funcs[2][2] = {
|
||||
{
|
||||
// non-quantized
|
||||
mul_mat_impl<hexagon::vec_dot_product_f16_f16, false>, // F16 * F16 unaligned
|
||||
mul_mat_impl<hexagon::vec_dot_product_aligned_f16_f16, false>, // F16 * F16 aligned
|
||||
},
|
||||
{
|
||||
// quantized
|
||||
mul_mat_impl<hexagon::vec_dot_product_f16_f16, true>, // F16 * F16 quantized unaligned
|
||||
mul_mat_impl<hexagon::vec_dot_product_aligned_f16_f16, true>, // F16 * F16 quantized aligned
|
||||
},
|
||||
};
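The two 2x2 tables above let mul_mat_f32 resolve a fully specialized mul_mat_impl instantiation from two runtime booleans: the row index selects the quantized path (which dequantizes src0 into VTCM first) and the column index selects the alignment-specialized dot product. The selection later in this file amounts to (illustrative restatement of the indexing):

```cpp
const bool aligned = is_mul_mat_f16_f32_src_tensors_aligned(src0, src1, is_src0_quantized);
kMulMatF16F32Funcs[is_src0_quantized][aligned](src0, src1, out, params);
```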
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace hexagon {
|
||||
|
||||
bool mul_mat_f32(hexagon::tensor * out, compute_params * params) {
|
||||
static_assert(DEVICE_TENSOR_MAX_DIMS == 4, "mul_mat_f32 requires max dims 4");
|
||||
static_assert(std::is_same<hexagon::dequant_target_type, float>::value ||
|
||||
std::is_same<hexagon::dequant_target_type, npu_device_fp16_t>::value,
|
||||
"dequant_target_type must be float or npu_device_fp16_t");
|
||||
|
||||
if (!out) {
|
||||
return false;
|
||||
}
|
||||
|
||||
static_assert(DEVICE_TENSOR_MAX_DIMS == 4, "mul_mat_f32 requires max dims 4");
|
||||
auto * src0 = out->get_src(0);
|
||||
auto * src1 = out->get_src(1);
|
||||
if (!src0 || !src1) {
|
||||
return true; // skip if no src
|
||||
}
|
||||
|
||||
const bool is_src0_quantized = is_quantized_type(src0->get_type());
|
||||
switch (src1->get_type()) {
|
||||
case NPU_DATA_TYPE_F32:
|
||||
mul_mat_impl<hexagon::vec_dot_product_f32_f32>(src0, src1, out, params);
|
||||
if (is_src0_quantized || src0->get_type() == NPU_DATA_TYPE_F16) {
|
||||
kMulMatF16F32Funcs[is_src0_quantized][is_mul_mat_f16_f32_src_tensors_aligned(
|
||||
src0, src1, is_src0_quantized)](src0, src1, out, params);
|
||||
} else {
|
||||
if (is_mul_mat_f32_f32_src_tensors_aligned(src0, src1)) {
|
||||
mul_mat_impl<hexagon::vec_dot_product_aligned_f32_f32, false>(src0, src1, out, params);
|
||||
} else {
|
||||
mul_mat_impl<hexagon::vec_dot_product_f32_f32, false>(src0, src1, out, params);
|
||||
}
|
||||
}
|
||||
return true;
|
||||
|
||||
case NPU_DATA_TYPE_F16:
|
||||
mul_mat_impl<hexagon::vec_dot_product_f16_f16>(src0, src1, out, params);
|
||||
kMulMatF16Funcs[is_src0_quantized][is_mul_mat_f16_f16_src_tensors_aligned(src0, src1, is_src0_quantized)](
|
||||
src0, src1, out, params);
|
||||
return true;
|
||||
default:
|
||||
break;
|
||||
|
|
@ -229,6 +343,12 @@ bool is_mul_mat_supported(npu_device_tensor_op op, const npu_device_tensor_spec
|
|||
const auto & src0 = srcs[0];
|
||||
const auto & src1 = srcs[1];
|
||||
if (src0.type != src1.type) {
|
||||
if (src1.type == NPU_DATA_TYPE_F32 && src0.type == NPU_DATA_TYPE_F16) {
|
||||
DEVICE_LOG_DEBUG("[%s]src0.type(%s) and src1.type(%s) mismatch, but src0 is F16 and src1 is F32\n",
|
||||
op_get_name(op), get_type_name(src0.type), get_type_name(src1.type));
|
||||
return true; // F16 * F32 is supported
|
||||
}
|
||||
|
||||
#ifdef GGML_HEXAGON_ENABLE_QUANTIZED_TENSORS
|
||||
if (!is_quantized_mul_mat_supported(src0, src1)) {
|
||||
return false;
|
||||
|
|
|
|||
|
|
@ -0,0 +1,368 @@
|
|||
#include "op_rope.hpp"
|
||||
|
||||
#include "type_traits.hpp"
|
||||
|
||||
#ifndef M_PI
|
||||
# define M_PI (3.14159265358979323846)
|
||||
#endif
|
||||
|
||||
namespace {
|
||||
|
||||
// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
|
||||
// `corr_dim(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
|
||||
float rope_yarn_corr_dim(int n_dims, int n_ctx_orig, float n_rot, float base) {
|
||||
return n_dims * logf(n_ctx_orig / (n_rot * 2 * (float) M_PI)) / (2 * logf(base));
|
||||
}
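For reference, the correction dimension above follows from the per-dimension rotation count: dimension d completes n_ctx_orig / (2*pi*base^(2d/n_dims)) full rotations over the original context, and setting that equal to n_rot and solving for d gives the expression in the code:

```latex
\frac{n_{\text{ctx\_orig}}}{2\pi \cdot \text{base}^{2d/n_{\text{dims}}}} = n_{\text{rot}}
\quad\Longrightarrow\quad
d = \frac{n_{\text{dims}} \cdot \log\!\left(n_{\text{ctx\_orig}} / (2\pi\, n_{\text{rot}})\right)}{2\log(\text{base})}
```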
|
||||
|
||||
void rope_yarn_corr_dims(int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]) {
|
||||
// start and end correction dims
|
||||
float start = floorf(rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_fast, freq_base));
|
||||
float end = ceilf(rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_slow, freq_base));
|
||||
dims[0] = std::max<float>(0, start);
|
||||
dims[1] = std::min<float>(n_dims - 1, end);
|
||||
}
|
||||
|
||||
float rope_yarn_ramp(const float low, const float high, const int i0) {
|
||||
const float y = (i0 / 2 - low) / std::max<float>(0.001f, high - low);
|
||||
return 1 - std::min<float>(1, std::max<float>(0, y));
|
||||
}
|
||||
|
||||
// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
|
||||
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
|
||||
void rope_yarn(float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
|
||||
float * cos_theta, float * sin_theta) {
|
||||
// Get n-d rotational scaling corrected for extrapolation
|
||||
float theta_interp = freq_scale * theta_extrap;
|
||||
float theta = theta_interp;
|
||||
if (ext_factor != 0.0f) {
|
||||
float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
|
||||
theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
|
||||
|
||||
// Get n-d magnitude scaling corrected for interpolation
|
||||
mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
|
||||
}
|
||||
*cos_theta = cosf(theta) * mscale;
|
||||
*sin_theta = sinf(theta) * mscale;
|
||||
}
|
||||
|
||||
void rope_cache_init(float theta_base, float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0,
|
||||
float ext_factor, float mscale, float * cache, float sin_sign, float theta_scale) {
|
||||
// ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
|
||||
float theta = theta_base;
|
||||
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
|
||||
const float ff = freq_factors ? freq_factors[i0 / 2] : 1.0f;
|
||||
rope_yarn(theta / ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]);
|
||||
cache[i0 + 1] *= sin_sign;
|
||||
|
||||
theta *= theta_scale;
|
||||
}
|
||||
}
|
||||
|
||||
void mrope_cache_init(float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e,
|
||||
const int sections[4], bool indep_sects, float freq_scale, const float * freq_factors,
|
||||
float corr_dims[2], int64_t ne0, float ext_factor, float mscale, float * cache, float sin_sign,
|
||||
float theta_scale) {
|
||||
// ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
|
||||
float theta_t = theta_base_t;
|
||||
float theta_h = theta_base_h;
|
||||
float theta_w = theta_base_w;
|
||||
float theta_e = theta_base_e; // extra position id for vision encoder
|
||||
int sect_dims = sections[0] + sections[1] + sections[2] + sections[3];
|
||||
int sec_w = sections[1] + sections[0];
|
||||
int sec_e = sections[2] + sec_w;
|
||||
|
||||
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
|
||||
const float ff = freq_factors ? freq_factors[i0 / 2] : 1.0f;
|
||||
|
||||
int sector = (i0 / 2) % sect_dims;
|
||||
if (indep_sects) {
|
||||
// compute theta independently for each dim sections
|
||||
// (i.e. reset corresponding theta when `i0` go from one section to another)
|
||||
if (sector == 0) {
|
||||
theta_t = theta_base_t;
|
||||
} else if (sector == sections[0]) {
|
||||
theta_h = theta_base_h;
|
||||
} else if (sector == sec_w) {
|
||||
theta_w = theta_base_w;
|
||||
} else if (sector == sec_e) {
|
||||
theta_e = theta_base_e;
|
||||
}
|
||||
}
|
||||
|
||||
float theta = theta_t;
|
||||
if (sector >= sections[0] && sector < sec_w) {
|
||||
theta = theta_h;
|
||||
} else if (sector >= sec_w && sector < sec_w + sections[2]) {
|
||||
theta = theta_w;
|
||||
} else if (sector >= sec_w + sections[2]) {
|
||||
theta = theta_e;
|
||||
}
|
||||
|
||||
rope_yarn(theta / ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]);
|
||||
cache[i0 + 1] *= sin_sign;
|
||||
|
||||
theta_t *= theta_scale;
|
||||
theta_w *= theta_scale;
|
||||
theta_h *= theta_scale;
|
||||
theta_e *= theta_scale;
|
||||
}
|
||||
}
|
||||
|
||||
template <bool _IsNeoX, bool _IsMrope, bool _IsVision>
|
||||
bool rope_impl(hexagon::tensor * out, hexagon::compute_params * params) {
|
||||
const auto * src0 = out->get_src(0);
|
||||
const auto * src1 = out->get_src(1);
|
||||
const auto * src2 = out->get_src(2);
|
||||
|
||||
const int n_dims = out->get_op_param<int32_t>(1);
|
||||
const int n_ctx_orig = out->get_op_param<int32_t>(4);
|
||||
const int sections[4] = {
|
||||
out->get_op_param<int32_t>(11),
|
||||
out->get_op_param<int32_t>(12),
|
||||
out->get_op_param<int32_t>(13),
|
||||
out->get_op_param<int32_t>(14),
|
||||
};
|
||||
|
||||
const float freq_base = out->get_op_param<float>(5);
|
||||
const float freq_scale = out->get_op_param<float>(6);
|
||||
const float ext_factor = out->get_op_param<float>(7);
|
||||
const float attn_factor = out->get_op_param<float>(8);
|
||||
const float beta_fast = out->get_op_param<float>(9);
|
||||
const float beta_slow = out->get_op_param<float>(10);
|
||||
const float theta_scale = powf(freq_base, -2.0f / n_dims);
|
||||
|
||||
float corr_dims[2];
|
||||
rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
|
||||
|
||||
if (_IsMrope && sections[0] <= 0 && sections[1] <= 0 && sections[2] <= 0) {
|
||||
DEVICE_LOG_ERROR("[ROPE]invalid sections for MROPE: %d, %d, %d\n", sections[0], sections[1], sections[2]);
|
||||
return false; // invalid sections for MROPE
|
||||
}
|
||||
|
||||
if (n_dims % 2 || (_IsVision && n_dims != out->get_ne(0) / 2)) {
|
||||
DEVICE_LOG_ERROR("[ROPE]invalid n_dims for vision ROPE: %d, expected: %d\n", n_dims, out->get_ne(0) / 2);
|
||||
return false; // invalid n_dims for vision ROPE
|
||||
}
|
||||
|
||||
// cache size is (ne0 + CACHE_LINE_SIZE_F32)
|
||||
const size_t total_cache_size = hexagon::get_aligned_size(out->get_ne(0) * sizeof(float));
|
||||
auto * cache_ptr = params->get_vtcm_cache(total_cache_size);
|
||||
if (!cache_ptr) {
|
||||
DEVICE_LOG_ERROR("[ROPE]Failed to allocate VTCM cache for flash_attn: %zu bytes\n", total_cache_size);
|
||||
return false; // failed to allocate cache
|
||||
}
|
||||
|
||||
const float * freq_factors = nullptr;
|
||||
if (src2 != nullptr) {
|
||||
if (src2->get_type() != NPU_DATA_TYPE_F32 || src2->get_ne(0) < n_dims / 2) {
|
||||
DEVICE_LOG_ERROR("[ROPE]src2 type is not F32 or F16: %s\n", hexagon::get_type_name(src2->get_type()));
|
||||
return false; // unsupported src2 type
|
||||
}
|
||||
|
||||
freq_factors = src2->get_read_buffer_as<float>();
|
||||
}
|
||||
|
||||
const int64_t total_rows = out->get_ne(3) * out->get_ne(2) * out->get_ne(1);
|
||||
const auto start_end_row = params->get_work_slice(total_rows);
|
||||
const auto start_end_plane =
|
||||
std::pair<int64_t, int64_t>{ start_end_row.first / out->get_ne(1),
|
||||
(start_end_row.second + out->get_ne(1) - 1) / out->get_ne(1) };
|
||||
|
||||
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_WITH_MULTI_SUB_PROC(out, params->get_thread_index(), rope);
|
||||
|
||||
const float sin_sign = 1.0f;
|
||||
const int32_t * pos = src1->get_read_buffer_as<int32_t>();
|
||||
const uint8_t * src0_data_ptr = src0->get_read_buffer();
|
||||
uint8_t * dst_data_ptr = out->get_write_buffer();
|
||||
for (int64_t ip = start_end_plane.first; ip < start_end_plane.second; ip++) {
|
||||
int64_t i3 = ip / out->get_ne(2); // batch
|
||||
int64_t i2 = ip % out->get_ne(2); // seq-len
|
||||
float * cache = reinterpret_cast<float *>(cache_ptr);
|
||||
if constexpr (!_IsMrope) {
|
||||
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(rope, 0, cache);
|
||||
const int64_t p = pos[i2];
|
||||
rope_cache_init(p, freq_scale, freq_factors, corr_dims, out->get_ne(0), ext_factor, attn_factor, cache,
|
||||
sin_sign, theta_scale);
|
||||
} else {
|
||||
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(rope, 0, cache);
|
||||
const int64_t p_t = pos[i2];
|
||||
const int64_t p_h = pos[i2 + out->get_ne(2)];
|
||||
const int64_t p_w = pos[i2 + out->get_ne(2) * 2];
|
||||
const int64_t p_e = pos[i2 + out->get_ne(2) * 3];
|
||||
mrope_cache_init(p_t, p_h, p_w, p_e, sections, _IsVision, freq_scale, freq_factors, corr_dims,
|
||||
out->get_ne(0), ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
||||
}
|
||||
|
||||
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(rope, 1, loop);
|
||||
const uint8_t * src0_plane = src0_data_ptr + i3 * src0->get_nb(3) + i2 * src0->get_nb(2);
|
||||
uint8_t * dst_plane = dst_data_ptr + i3 * out->get_nb(3) + i2 * out->get_nb(2);
|
||||
const int64_t start_row = ip == start_end_plane.first ? (start_end_row.first % out->get_ne(1)) : 0;
|
||||
const int64_t end_row = ip == start_end_plane.second ? (start_end_row.second % out->get_ne(1)) :
|
||||
out->get_ne(1); // end row is exclusive
|
||||
for (int64_t i1 = start_row; i1 < end_row; i1++) { // attn-heads
|
||||
const uint8_t * src0_row = src0_plane + i1 * src0->get_nb(1);
|
||||
uint8_t * dst_row = dst_plane + i1 * out->get_nb(1);
|
||||
if constexpr (_IsNeoX || _IsMrope) {
|
||||
if constexpr (_IsVision) {
|
||||
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
|
||||
const int64_t ic = i0 / 2;
|
||||
|
||||
const float cos_theta = cache[i0 + 0];
|
||||
const float sin_theta = cache[i0 + 1];
|
||||
|
||||
const float * const src = (float *) (src0_row + ic * src0->get_nb(0));
|
||||
float * dst_data = (float *) (dst_row + ic * out->get_nb(0));
|
||||
|
||||
const float x0 = src[0];
|
||||
const float x1 = src[n_dims];
|
||||
|
||||
dst_data[0] = x0 * cos_theta - x1 * sin_theta;
|
||||
dst_data[n_dims] = x0 * sin_theta + x1 * cos_theta;
|
||||
}
|
||||
} else {
|
||||
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
|
||||
const int64_t ic = i0 / 2;
|
||||
|
||||
const float cos_theta = cache[i0 + 0];
|
||||
const float sin_theta = cache[i0 + 1];
|
||||
|
||||
const float * const src = (float *) (src0_row + ic * src0->get_nb(0));
|
||||
float * dst_data = (float *) (dst_row + ic * out->get_nb(0));
|
||||
|
||||
const float x0 = src[0];
|
||||
const float x1 = src[n_dims / 2];
|
||||
|
||||
dst_data[0] = x0 * cos_theta - x1 * sin_theta;
|
||||
dst_data[n_dims / 2] = x0 * sin_theta + x1 * cos_theta;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
|
||||
const float cos_theta = cache[i0 + 0];
|
||||
const float sin_theta = cache[i0 + 1];
|
||||
|
||||
const float * const src = (float *) (src0_row + i0 * src0->get_nb(0));
|
||||
float * dst_data = (float *) (dst_row + i0 * out->get_nb(0));
|
||||
|
||||
const float x0 = src[0];
|
||||
const float x1 = src[1];
|
||||
|
||||
dst_data[0] = x0 * cos_theta - x1 * sin_theta;
|
||||
dst_data[1] = x0 * sin_theta + x1 * cos_theta;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (_IsVision) {
|
||||
for (int64_t i0 = n_dims; i0 < out->get_ne(0); i0 += 2) {
|
||||
const int64_t ic = i0 / 2;
|
||||
|
||||
const float cos_theta = cache[i0 + 0];
|
||||
const float sin_theta = cache[i0 + 1];
|
||||
|
||||
const float * const src = (float *) (src0_row + ic * src0->get_nb(0));
|
||||
float * dst_data = (float *) (dst_row + ic * out->get_nb(0));
|
||||
|
||||
const float x0 = src[0];
|
||||
const float x1 = src[n_dims];
|
||||
|
||||
dst_data[0] = x0 * cos_theta - x1 * sin_theta;
|
||||
dst_data[n_dims] = x0 * sin_theta + x1 * cos_theta;
|
||||
}
|
||||
} else {
|
||||
// fill the remaining channels with data from the src tensor
|
||||
memcpy(dst_row + n_dims * out->get_nb(0), src0_row + n_dims * src0->get_nb(0),
|
||||
(out->get_ne(0) - n_dims) * sizeof(float));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
out->release_write_buffer();
|
||||
return true;
|
||||
}
|
||||
|
||||
typedef bool (*rope_impl_func)(hexagon::tensor * out, hexagon::compute_params * params);
|
||||
|
||||
constexpr const rope_impl_func kRopeImplFuncs[8] = {
|
||||
rope_impl<false, false, false>, // IsNotNeoX, IsNotMrope, IsNotVision
|
||||
rope_impl<false, false, true>, // IsNotNeoX, IsNotMrope, IsVision
|
||||
rope_impl<false, true, false>, // IsNotNeoX, IsMrope, IsNotVision
|
||||
rope_impl<false, true, true>, // IsNotNeoX, IsMrope, IsVision
|
||||
rope_impl<true, false, false>, // IsNeoX, IsNotMrope, IsNotVision
|
||||
rope_impl<true, false, true>, // IsNeoX, IsNotMrope, IsVision
|
||||
rope_impl<true, true, false>, // IsNeoX, IsMrope, IsNotVision
|
||||
rope_impl<true, true, true>, // IsNeoX, IsMrope, IsVision
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace hexagon {
|
||||
|
||||
bool rope_f32(tensor * out, compute_params * params) {
|
||||
const int mode = out->get_op_param<int32_t>(2);
|
||||
const bool is_neox = mode & NPU_ROPE_TYPE_NEOX;
|
||||
const bool is_mrope = mode & NPU_ROPE_TYPE_MROPE; // ggml_rope_multi, multimodal rotary position embedding
|
||||
const bool is_vision = mode == NPU_ROPE_TYPE_VISION;
|
||||
|
||||
size_t impl_index = is_neox ? 4 : 0;
|
||||
impl_index += is_mrope ? 2 : 0;
|
||||
impl_index += is_vision ? 1 : 0;
|
||||
|
||||
if (impl_index >= sizeof(kRopeImplFuncs) / sizeof(kRopeImplFuncs[0])) {
|
||||
DEVICE_LOG_ERROR("[ROPE]invalid impl_index: %zu\n", impl_index);
|
||||
return false; // invalid impl index
|
||||
}
|
||||
|
||||
return kRopeImplFuncs[impl_index](out, params);
|
||||
}
|
||||
|
||||
bool is_rope_supported(npu_device_tensor_op op, const npu_device_tensor_spec * dst, const npu_device_tensor_spec * srcs,
|
||||
size_t src_len) {
|
||||
if (op != NPU_OP_ROPE) {
|
||||
DEVICE_LOG_DEBUG("[%s]op is not ROPE\n", op_get_name(op));
|
||||
return false;
|
||||
}
|
||||
|
||||
if (src_len < 2 || !dst || !srcs) {
|
||||
// freq can be optional, but we require at least 2 srcs: src0 and src1
|
||||
DEVICE_LOG_DEBUG("[%s]invalid dst or srcs\n", op_get_name(op));
|
||||
return false;
|
||||
}
|
||||
|
||||
if (dst->type != NPU_DATA_TYPE_F32) {
|
||||
DEVICE_LOG_DEBUG("[%s]dst type is not F32: %s\n", op_get_name(op), get_type_name(dst->type));
|
||||
return false; // add more dst type if needed
|
||||
}
|
||||
|
||||
const auto & src0 = srcs[0];
|
||||
if (src0.type != dst->type) {
|
||||
DEVICE_LOG_DEBUG("[%s]src0 type is not the same as dst type: %s vs %s\n", op_get_name(op),
|
||||
get_type_name(src0.type), get_type_name(dst->type));
|
||||
return false; // unsupported src0 type
|
||||
}
|
||||
|
||||
const auto & src1 = srcs[1];
|
||||
if (src1.type != NPU_DATA_TYPE_I32) {
|
||||
DEVICE_LOG_DEBUG("[%s]src1 type is not I32: %s\n", op_get_name(op), get_type_name(src1.type));
|
||||
return false; // unsupported src1 type
|
||||
}
|
||||
|
||||
if (src_len > 2) {
|
||||
const auto & src2 = srcs[2];
|
||||
if (src2.type != NPU_DATA_TYPE_F32) {
|
||||
DEVICE_LOG_DEBUG("[%s]src2 type is not F32: %s\n", op_get_name(op), get_type_name(src2.type));
|
||||
return false; // unsupported src2 type
|
||||
}
|
||||
|
||||
DEVICE_LOG_DEBUG("[%s]freq is present\n", op_get_name(op));
|
||||
}
|
||||
|
||||
if (!is_same_shape(src0, *dst)) {
|
||||
DEVICE_LOG_DEBUG("[%s]src0 and dst have different shape\n", op_get_name(op));
|
||||
return false;
|
||||
}
|
||||
|
||||
// TODO: check the params for ROPE operation
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace hexagon
|
||||
|
|
@@ -0,0 +1,11 @@
#pragma once

#include "op_types.hpp"

namespace hexagon {

bool rope_f32(tensor * out, compute_params * params);
bool is_rope_supported(npu_device_tensor_op op, const npu_device_tensor_spec * dst, const npu_device_tensor_spec * srcs,
size_t src_len);

} // namespace hexagon
@ -110,6 +110,15 @@ class tensor {
|
|||
return _data + _info.offset;
|
||||
}
|
||||
|
||||
template <typename _Ty> const _Ty * get_read_buffer_as() const {
|
||||
const auto * buffer = get_read_buffer();
|
||||
if (!buffer) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return reinterpret_cast<const _Ty *>(buffer);
|
||||
}
|
||||
|
||||
uint8_t * get_write_buffer() const {
|
||||
if (_info.is_constant) {
|
||||
DEVICE_LOG_ERROR("Attempt to write to a constant tensor: %p", (void *) this);
|
||||
|
|
|
|||
|
|
@ -29,24 +29,38 @@ inline npu_device_fp16_t to_fp16(const float src) {
|
|||
}
|
||||
|
||||
template <typename _TBlock> inline HVX_Vector load_block_generic(const _TBlock & src) {
|
||||
uint8_t buffer[hexagon::kBytesPerVector];
|
||||
static_assert(hexagon::kBytesPerVector >= sizeof(_TBlock), "wrong q4_0 block size/padding");
|
||||
|
||||
static_assert(sizeof(buffer) == sizeof(HVX_Vector), "wrong cvt size/padding");
|
||||
static_assert(sizeof(buffer) >= sizeof(src.qs), "wrong q4_0 block size/padding");
|
||||
|
||||
memcpy(&buffer[0], src.qs, sizeof(src.qs));
|
||||
return *reinterpret_cast<HVX_UVector *>(buffer);
|
||||
const HVX_Vector * qs0 = reinterpret_cast<const HVX_Vector *>(src.qs);
|
||||
const HVX_Vector * qs1 = qs0 + 1;
|
||||
return Q6_V_valign_VVR(*qs1, *qs0, (size_t) src.qs);
|
||||
}
|
||||
|
||||
template <typename _TBlock> inline HVX_Vector load_dual_block_generic(const _TBlock & src1, const _TBlock & src2) {
|
||||
uint8_t buffer[hexagon::kBytesPerVector];
|
||||
template <typename _TBlock> inline HVX_Vector load_dual_block_generic(const _TBlock * srcs) {
|
||||
static_assert(hexagon::kBytesPerVector >= sizeof(_TBlock) * 2, "wrong q4_0 block size/padding");
|
||||
constexpr const uint32_t kSizeOfQs = sizeof(_TBlock::qs);
|
||||
|
||||
static_assert(sizeof(buffer) == sizeof(HVX_Vector), "wrong cvt size/padding");
|
||||
static_assert(sizeof(buffer) >= sizeof(src1.qs) * 2, "wrong q4_0 block size/padding");
|
||||
const HVX_Vector * qs0 = reinterpret_cast<const HVX_Vector *>(srcs->qs);
|
||||
const HVX_Vector * qs1 = qs0 + 1;
|
||||
HVX_Vector blocks = Q6_V_valign_VVR(*qs1, *qs0, (size_t) srcs->qs);
|
||||
HVX_Vector block1 = Q6_V_valign_VVR(Q6_V_vzero(), blocks, sizeof(_TBlock));
|
||||
return Q6_V_lo_W(Q6_W_vshuff_VVR(block1, blocks, kSizeOfQs));
|
||||
}
|
||||
|
||||
memcpy(&buffer[0], src1.qs, sizeof(src1.qs));
|
||||
memcpy(&buffer[sizeof(src1.qs)], src2.qs, sizeof(src2.qs));
|
||||
return *reinterpret_cast<HVX_UVector *>(buffer);
|
||||
template <typename _TBlock> inline HVX_Vector load_qual_block_generic(const _TBlock * srcs) {
|
||||
static_assert(hexagon::kBytesPerVector >= sizeof(_TBlock) * 4, "wrong q4_0 block size/padding");
|
||||
constexpr const uint32_t kSizeOfQs = sizeof(_TBlock::qs);
|
||||
|
||||
const HVX_Vector * qs0 = reinterpret_cast<const HVX_Vector *>(srcs->qs);
|
||||
const HVX_Vector * qs1 = qs0 + 1;
|
||||
HVX_Vector blocks = Q6_V_valign_VVR(*qs1, *qs0, (size_t) srcs->qs);
|
||||
HVX_Vector block1 = Q6_V_valign_VVR(Q6_V_vzero(), blocks, sizeof(_TBlock));
|
||||
HVX_Vector block2 = Q6_V_valign_VVR(Q6_V_vzero(), blocks, sizeof(_TBlock) * 2);
|
||||
HVX_Vector block3 = Q6_V_valign_VVR(Q6_V_vzero(), blocks, sizeof(_TBlock) * 3);
|
||||
|
||||
HVX_VectorPair qp0 = Q6_W_vshuff_VVR(block1, blocks, kSizeOfQs);
|
||||
HVX_VectorPair qp1 = Q6_W_vshuff_VVR(block3, block2, kSizeOfQs);
|
||||
return Q6_V_lo_W(Q6_W_vshuff_VVR(Q6_V_lo_W(qp1), Q6_V_lo_W(qp0), kSizeOfQs * 2));
|
||||
}
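The loaders above all use the same unaligned-load idiom: dereference the (possibly unaligned) pointer as two consecutive HVX vectors, which the hardware loads from the enclosing aligned addresses, then rotate the pair by the low bits of the address with Q6_V_valign_VVR so the bytes at the pointer land at lane 0. A standalone sketch of just that idiom, mirroring the code above (not a drop-in API):

```cpp
// Sketch: unaligned 128-byte load built from two aligned vector loads plus valign.
inline HVX_Vector load_unaligned_sketch(const uint8_t * ptr) {
    const HVX_Vector * base = reinterpret_cast<const HVX_Vector *>(ptr); // aligned vector loads ignore the low address bits
    return Q6_V_valign_VVR(base[1], base[0], (size_t) ptr);
}
```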
|
||||
|
||||
inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t * d, uint8_t * m) {
|
||||
|
|
@ -148,7 +162,7 @@ float make_qkx2_quants(int n, int nmax, const float * x, const float * weights,
|
|||
return scale;
|
||||
}
|
||||
|
||||
void quantize_row_fp16(const float * src, void * dst, size_t count, const float * f16_to_f32_table) {
|
||||
void quantize_row_fp16(const float * src, void * dst, size_t count) {
|
||||
auto * out = reinterpret_cast<npu_device_fp16_t *>(dst);
|
||||
// TODO: use hvx intrinsics for better performance
|
||||
for (size_t i = 0; i < count; i++) {
|
||||
|
|
@ -156,7 +170,7 @@ void quantize_row_fp16(const float * src, void * dst, size_t count, const float
|
|||
}
|
||||
}
|
||||
|
||||
void quantize_row_q8_0(const float * src, void * dst, size_t count, const float * f16_to_f32_table) {
|
||||
void quantize_row_q8_0(const float * src, void * dst, size_t count) {
|
||||
const int nb = count / QUANT_BLOCK_SIZE;
|
||||
auto * out = reinterpret_cast<npu_device_block_q8_0 *>(dst);
|
||||
|
||||
|
|
@ -181,7 +195,7 @@ void quantize_row_q8_0(const float * src, void * dst, size_t count, const float
|
|||
}
|
||||
}
|
||||
|
||||
void quantize_row_q4_0(const float * src, void * dst, size_t count, const float * f16_to_f32_table) {
|
||||
void quantize_row_q4_0(const float * src, void * dst, size_t count) {
|
||||
constexpr const int qk = QUANT_BLOCK_SIZE;
|
||||
|
||||
const int nb = count / qk;
|
||||
|
|
@ -217,7 +231,7 @@ void quantize_row_q4_0(const float * src, void * dst, size_t count, const float
|
|||
}
|
||||
}
|
||||
|
||||
void quantize_row_q4_K(const float * src, void * dst, size_t count, const float * f16_to_f32_table) {
|
||||
void quantize_row_q4_K(const float * src, void * dst, size_t count) {
|
||||
const int nb = count / QUANT_K_BLOCK_SIZE;
|
||||
auto * out = reinterpret_cast<npu_device_block_q4_k *>(dst);
|
||||
|
||||
|
|
@ -274,11 +288,11 @@ void quantize_row_q4_K(const float * src, void * dst, size_t count, const float
|
|||
uint8_t sc, m;
|
||||
for (int j = 0; j < QUANT_K_BLOCK_SIZE / 32; ++j) {
|
||||
get_scale_min_k4(j, out[i].scales, &sc, &m);
|
||||
const float d = f16_to_f32_table[out[i].d] * sc;
|
||||
const float d = to_float(out[i].d) * sc;
|
||||
if (!d) {
|
||||
continue;
|
||||
}
|
||||
const float dm = f16_to_f32_table[out[i].dmin] * m;
|
||||
const float dm = to_float(out[i].dmin) * m;
|
||||
for (int ii = 0; ii < 32; ++ii) {
|
||||
int l = nearest_int((src[32 * j + ii] + dm) / d);
|
||||
l = std::max<int>(0, std::min<int>(15, l));
|
||||
|
|
@ -298,90 +312,158 @@ void quantize_row_q4_K(const float * src, void * dst, size_t count, const float
|
|||
}
|
||||
}
|
||||
|
||||
void dequantize_row_q8_0(const void * src, float * dst, size_t count, const float * f16_to_f32_table) {
|
||||
void dequantize_row_q8_0(const void * src, hexagon::dequant_target_type * dst, size_t count) {
|
||||
constexpr const int qk = QUANT_BLOCK_SIZE;
|
||||
static_assert(QUANT_BLOCK_SIZE == hexagon::kBytesPerVector / sizeof(float));
|
||||
|
||||
const int nb = count / qk;
|
||||
const auto * src_ptr = reinterpret_cast<const npu_device_block_q8_0 *>(src);
|
||||
HVX_UVector * out = ((HVX_UVector *) dst); // TODO: opt for aligned access
|
||||
const int nb = count / qk;
|
||||
const auto * src_ptr = reinterpret_cast<const npu_device_block_q8_0 *>(src);
|
||||
auto * dst_ptr = ((hexagon::dequant_target_type *) dst); // TODO: opt for aligned access
|
||||
|
||||
for (int i = 0; i < nb; i++) {
|
||||
int i = 0;
|
||||
for (; i + 1 < nb; i += 2) {
|
||||
const auto & src0 = src_ptr[i];
|
||||
const auto & src1 = src_ptr[i + 1];
|
||||
|
||||
HVX_Vector scales01 =
|
||||
Q6_V_valign_VVR(Q6_Vh_vsplat_R(src1.d), Q6_Vh_vsplat_R(src0.d), hexagon::kBytesPerVector / 2);
|
||||
|
||||
HVX_Vector qs = load_dual_block_generic(src_ptr + i);
|
||||
HVX_Vector q_lo = Q6_Vhf_equals_Vh(Q6_V_lo_W(Q6_Wh_vunpack_Vb(qs)));
|
||||
HVX_Vector result = Q6_Vqf16_vmpy_VhfVhf(q_lo, scales01);
|
||||
|
||||
*reinterpret_cast<HVX_UVector *>(dst_ptr) = Q6_Vhf_equals_Vqf16(result);
|
||||
dst_ptr += qk * 2;
|
||||
}
|
||||
|
||||
if (i < nb) {
|
||||
const auto & src = src_ptr[i];
|
||||
HVX_Vector d = Q6_Vh_vsplat_R(src.d);
|
||||
|
||||
HVX_Vector q_lo = load_block_generic(src);
|
||||
HVX_VectorPair q = Q6_Wh_vunpack_Vb(q_lo);
|
||||
q = Q6_Wh_vunpack_Vb(Q6_V_lo_W(q));
|
||||
q_lo = Q6_Vhf_equals_Vh(Q6_V_lo_W(q));
|
||||
q = Q6_Wqf32_vmpy_VhfVhf(q_lo, d);
|
||||
out[i] = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(q));
|
||||
HVX_Vector scales = Q6_Vh_vsplat_R(src.d);
|
||||
|
||||
HVX_Vector q_lo = load_block_generic(src);
|
||||
q_lo = Q6_Vhf_equals_Vh(Q6_V_lo_W(Q6_Wh_vunpack_Vb(q_lo)));
|
||||
HVX_Vector result = Q6_Vqf16_vmpy_VhfVhf(q_lo, scales);
|
||||
hexagon::q6op_vstu_variable_ARV<hexagon::kBytesPerVector / 2>(
|
||||
dst_ptr,
|
||||
Q6_Vhf_equals_Vqf16(result)); // TODO: opt the store
|
||||
}
|
||||
}
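For comparison, a scalar reference of what the vectorized q8_0 path above computes (assuming the usual ggml q8_0 block layout: one fp16 scale d plus QUANT_BLOCK_SIZE signed 8-bit values):

```cpp
void dequantize_row_q8_0_ref(const npu_device_block_q8_0 * blocks, __fp16 * dst, size_t count) {
    const size_t nb = count / QUANT_BLOCK_SIZE;
    for (size_t i = 0; i < nb; ++i) {
        const __fp16 d = reinterpret_cast<const __fp16 &>(blocks[i].d);  // per-block scale
        for (size_t j = 0; j < QUANT_BLOCK_SIZE; ++j) {
            dst[i * QUANT_BLOCK_SIZE + j] = (__fp16) (d * (__fp16) blocks[i].qs[j]);
        }
    }
}
```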
|
||||
|
||||
void dequantize_row_q4_0(const void * src, float * dst, size_t count, const float * f16_to_f32_table) {
|
||||
template <bool _IsDstAligned>
|
||||
void dequantize_row_q4_0_impl(const void * src, hexagon::dequant_target_type * dst, size_t count) {
|
||||
constexpr const int qk = QUANT_BLOCK_SIZE;
|
||||
static_assert(qk % 2 == 0, "qk must be even");
|
||||
static_assert(QUANT_BLOCK_SIZE == hexagon::kBytesPerVector / sizeof(float));
|
||||
constexpr const uint32_t kSizeOfQs = sizeof(npu_device_block_q4_0::qs);
|
||||
|
||||
const int nb = count / qk;
const auto * src_ptr = reinterpret_cast<const npu_device_block_q4_0 *>(src);
HVX_Vector mask = Q6_Vb_vsplat_R(0x0F);
HVX_Vector minus = Q6_Vb_vsplat_R(8);
HVX_UVector * out = ((HVX_UVector *) dst); // TODO: opt for aligned access
const int nb = count / qk;
const auto * src_ptr = reinterpret_cast<const npu_device_block_q4_0 *>(src);
const HVX_Vector mask = Q6_Vb_vsplat_R(0x0F);
const HVX_Vector minus = Q6_Vb_vsplat_R(8);
hexagon::dequant_target_type * dst_ptr = dst; // TODO: opt for aligned access

const int loop_count = nb - (nb % 2);
for (int i = 0; i < loop_count; i += 2) {
const auto & src1 = src_ptr[i];
const auto & src2 = src_ptr[i + 1];
int i = 0;
for (; i + 3 < nb; i += 4) {
const auto & src0 = src_ptr[i];
const auto & src1 = src_ptr[i + 1];
const auto & src2 = src_ptr[i + 2];
const auto & src3 = src_ptr[i + 3];

HVX_Vector d1 = Q6_Vh_vsplat_R(src1.d);
HVX_Vector d2 = Q6_Vh_vsplat_R(src2.d);
HVX_Vector d = Q6_Vh_vshuff_Vh(Q6_V_valign_VVR(d2, d1, hexagon::kBytesPerVector / 2));
HVX_Vector scales01 =
Q6_V_valign_VVR(Q6_Vh_vsplat_R(src1.d), Q6_Vh_vsplat_R(src0.d), hexagon::kBytesPerVector / 2);
HVX_Vector scales23 =
Q6_V_valign_VVR(Q6_Vh_vsplat_R(src3.d), Q6_Vh_vsplat_R(src2.d), hexagon::kBytesPerVector / 2);

HVX_Vector q_lo = load_dual_block_generic(src1, src2);
HVX_Vector q_hi = Q6_Vub_vlsr_VubR(q_lo, 4);
HVX_VectorPair q = Q6_W_vshuff_VVR(q_hi, Q6_V_vand_VV(q_lo, mask), kSizeOfQs);
q_lo = Q6_V_valign_VVR(Q6_V_lo_W(q), Q6_V_vzero(), hexagon::kBytesPerVector / 2);
q_lo = Q6_V_valign_VVR(Q6_V_hi_W(q), q_lo, hexagon::kBytesPerVector / 2);
q_lo = Q6_Vb_vshuff_Vb(q_lo);
q_lo = Q6_Vb_vsub_VbVb(q_lo, minus);
q = Q6_Wh_vunpack_Vb(q_lo);
q_lo = Q6_Vhf_equals_Vh(Q6_V_lo_W(q));
q = Q6_Wqf32_vmpy_VhfVhf(q_lo, d);
out[i] = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(q));
out[i + 1] = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(q));
HVX_Vector qs = load_qual_block_generic(src_ptr + i);
HVX_Vector q_lo = Q6_V_vand_VV(qs, mask);
HVX_Vector q_hi = Q6_Vub_vlsr_VubR(qs, 4);
HVX_VectorPair qp0 = Q6_W_vshuff_VVR(q_hi, q_lo, kSizeOfQs * (1 + 2 + 4));
q_lo = Q6_Vb_vsub_VbVb(Q6_V_lo_W(qp0), minus);
qp0 = Q6_Wh_vunpack_Vb(q_lo);
q_lo = Q6_Vhf_equals_Vh(Q6_V_lo_W(qp0));
q_hi = Q6_Vhf_equals_Vh(Q6_V_hi_W(qp0));
q_lo = Q6_Vqf16_vmpy_VhfVhf(q_lo, scales01);
q_hi = Q6_Vqf16_vmpy_VhfVhf(q_hi, scales23);

if constexpr (_IsDstAligned) {
reinterpret_cast<HVX_Vector *>(dst_ptr)[0] = Q6_Vhf_equals_Vqf16(q_lo);
reinterpret_cast<HVX_Vector *>(dst_ptr)[1] = Q6_Vhf_equals_Vqf16(q_hi);
} else {
reinterpret_cast<HVX_UVector *>(dst_ptr)[0] = Q6_Vhf_equals_Vqf16(q_lo);
reinterpret_cast<HVX_UVector *>(dst_ptr)[1] = Q6_Vhf_equals_Vqf16(q_hi);
}

dst_ptr += hexagon::kBytesPerVector / sizeof(hexagon::dequant_target_type) * 2;
}

if (loop_count < nb) {
for (; i + 1 < nb; i += 2) {
const auto & src0 = src_ptr[i];
const auto & src1 = src_ptr[i + 1];

HVX_Vector scales01 =
Q6_V_valign_VVR(Q6_Vh_vsplat_R(src1.d), Q6_Vh_vsplat_R(src0.d), hexagon::kBytesPerVector / 2);

HVX_Vector qs = load_dual_block_generic(src_ptr + i);
HVX_Vector q_lo = Q6_V_vand_VV(qs, mask);
HVX_Vector q_hi = Q6_Vub_vlsr_VubR(qs, 4);
HVX_VectorPair qp0 = Q6_W_vshuff_VVR(q_hi, q_lo, kSizeOfQs * (1 + 2));
q_lo = Q6_Vb_vsub_VbVb(Q6_V_lo_W(qp0), minus);
qp0 = Q6_Wh_vunpack_Vb(q_lo);
q_lo = Q6_Vhf_equals_Vh(Q6_V_lo_W(qp0));
q_lo = Q6_Vqf16_vmpy_VhfVhf(q_lo, scales01);

if constexpr (_IsDstAligned) {
*reinterpret_cast<HVX_Vector *>(dst_ptr) = Q6_Vhf_equals_Vqf16(q_lo);
} else {
*reinterpret_cast<HVX_UVector *>(dst_ptr) = Q6_Vhf_equals_Vqf16(q_lo);
}

dst_ptr += hexagon::kBytesPerVector / sizeof(hexagon::dequant_target_type);
}

if (i < nb) {
const auto & curr_blk = src_ptr[nb - 1];
HVX_Vector d = Q6_Vh_vsplat_R(curr_blk.d);
HVX_Vector scales = Q6_Vh_vsplat_R(curr_blk.d);

HVX_Vector q_lo = load_block_generic(curr_blk);
HVX_Vector q_hi = Q6_Vub_vlsr_VubR(q_lo, 4);
q_lo = Q6_V_valign_VVR(Q6_V_vand_VV(q_lo, mask), Q6_V_vzero(), sizeof(curr_blk.qs));
q_lo = Q6_V_valign_VVR(q_hi, q_lo, hexagon::kBytesPerVector - sizeof(curr_blk.qs));
q_lo = Q6_Vb_vsub_VbVb(q_lo, minus);

HVX_VectorPair q = Q6_Wh_vunpack_Vb(q_lo);
q = Q6_Wh_vunpack_Vb(Q6_V_lo_W(q));
q_lo = Q6_Vhf_equals_Vh(Q6_V_lo_W(q));
q = Q6_Wqf32_vmpy_VhfVhf(q_lo, d);
out[nb - 1] = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(q));
HVX_Vector qs = load_block_generic(curr_blk);
HVX_Vector q_lo = Q6_V_vand_VV(qs, mask);
HVX_Vector q_hi = Q6_Vub_vlsr_VubR(qs, 4);
HVX_VectorPair qp0 = Q6_W_vshuff_VVR(q_hi, q_lo, kSizeOfQs);
q_lo = Q6_Vb_vsub_VbVb(Q6_V_lo_W(qp0), minus);
qp0 = Q6_Wh_vunpack_Vb(q_lo);
q_lo = Q6_Vhf_equals_Vh(Q6_V_lo_W(qp0));
q_lo = Q6_Vqf16_vmpy_VhfVhf(q_lo, scales);
if constexpr (_IsDstAligned) {
hexagon::q6op_vstu_variable_aligned<hexagon::kBytesPerVector / 2>(dst_ptr, Q6_Vhf_equals_Vqf16(q_lo));
} else {
hexagon::q6op_vstu_variable_ARV<hexagon::kBytesPerVector / 2>(
dst_ptr,
Q6_Vhf_equals_Vqf16(q_lo)); // TODO: opt the store
}
}
}

void dequantize_row_q4_K(const void * src, float * dst, size_t count, const float * f16_to_f32_table) {
void dequantize_row_q4_0(const void * src, hexagon::dequant_target_type * dst, size_t count) {
const bool dst_aligned = hexagon::is_addr_aligned(dst);
if (dst_aligned) {
dequantize_row_q4_0_impl<true>(src, dst, count);
} else {
dequantize_row_q4_0_impl<false>(src, dst, count);
}
}
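For reference, a scalar sketch of what the Q4_0 path above computes. It assumes the usual ggml Q4_0 block layout (one fp16 scale `d` plus 16 bytes `qs` holding 32 nibbles, biased by 8); the block-pair/quad handling, shuffles and qf16 math in the HVX version are purely for throughput and are expected to produce the same values up to fp16 rounding:

// Hypothetical scalar reference for one Q4_0 block (illustration only, not the shipped kernel).
inline void dequantize_block_q4_0_scalar(const npu_device_block_q4_0 & blk, __fp16 * dst) {
    const __fp16 d = reinterpret_cast<const __fp16 &>(blk.d);
    for (int j = 0; j < QUANT_BLOCK_SIZE / 2; ++j) {
        const int lo = (blk.qs[j] & 0x0F) - 8;  // low nibble -> elements [0, 16)
        const int hi = (blk.qs[j] >> 4) - 8;    // high nibble -> elements [16, 32)
        dst[j]                        = d * (__fp16) lo;
        dst[j + QUANT_BLOCK_SIZE / 2] = d * (__fp16) hi;
    }
}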

void dequantize_row_q4_K(const void * src, hexagon::dequant_target_type * dst, size_t count) {
const int nb = count / QUANT_K_BLOCK_SIZE;
const auto * src_ptr = reinterpret_cast<const npu_device_block_q4_k *>(src);
auto * dst_ptr = reinterpret_cast<__fp16 *>(dst);

// TODO: use intrinsics
for (int i = 0; i < nb; i++) {
const uint8_t * q = src_ptr[i].qs;

const float d = f16_to_f32_table[src_ptr[i].d];
const float min = f16_to_f32_table[src_ptr[i].dmin];
const __fp16 d = reinterpret_cast<const __fp16 &>(src_ptr[i].d);
const __fp16 min = reinterpret_cast<const __fp16 &>(src_ptr[i].dmin);

int is = 0;
uint8_t sc = 0;
@@ -389,17 +471,17 @@ void dequantize_row_q4_K(const void * src, float * dst, size_t count, const floa
const auto * scales = src_ptr[i].scales;
for (int j = 0; j < QUANT_K_BLOCK_SIZE; j += 64) {
get_scale_min_k4(is + 0, scales, &sc, &m);
const float d1 = d * sc;
const float m1 = min * m;
const __fp16 d1 = d * sc;
const __fp16 m1 = min * m;
get_scale_min_k4(is + 1, scales, &sc, &m);
const float d2 = d * sc;
const float m2 = min * m;
const __fp16 d2 = d * sc;
const __fp16 m2 = min * m;
for (int l = 0; l < 32; ++l) {
dst[0] = d1 * (q[l] & 0xF) - m1;
dst[32] = d2 * ((q[l] >> 4) & 0xF) - m2;
dst++;
dst_ptr[0] = d1 * (q[l] & 0xF) - m1;
dst_ptr[32] = d2 * ((q[l] >> 4) & 0xF) - m2;
dst_ptr++;
}
dst += 32;
dst_ptr += 32;
q += 32;
is += 2;
}
@@ -412,9 +494,12 @@ template <typename _TData> struct dot_func_traits<float (*)(_TData, _TData, size
using param_type = std::remove_const_t<std::remove_pointer_t<_TData>>;
};

template <auto _Func> float wrap_dot_func(const void * src0, const void * src1, size_t count) {
using param_type = typename dot_func_traits<decltype(_Func)>::param_type;
return _Func(reinterpret_cast<const param_type *>(src0), reinterpret_cast<const param_type *>(src1), count);
template <auto _DotFunc> float wrap_dot_func(const void * src0, const void * src1, size_t count) {
using param_type = typename dot_func_traits<decltype(_DotFunc)>::param_type;

auto * src0_typed = reinterpret_cast<const param_type *>(src0);
auto * src1_typed = reinterpret_cast<const param_type *>(src1);
return _DotFunc(src0_typed, src1_typed, count);
}
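A minimal usage sketch for the wrapper above, assuming the names from this diff; it shows how a typed kernel is adapted to the type-erased vec_dot_type signature stored in kDeviceTypeTraits:

// Sketch: dot_func_traits deduces param_type (here npu_device_fp16_t) from the kernel's
// signature, so the wrapper can cast the void pointers back before forwarding.
constexpr auto f16_dot = wrap_dot_func<hexagon::vec_dot_product_f16_f16>;

inline float call_f16_dot(const void * src0, const void * src1, size_t count) {
    return f16_dot(src0, src1, count);
}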

constexpr const hexagon::device_type_traits kDeviceTypeTraits[] = {
@@ -422,6 +507,7 @@ constexpr const hexagon::device_type_traits kDeviceTypeTraits[] = {
wrap_dot_func<hexagon::vec_dot_product_f32_f32> },
{ NPU_DATA_TYPE_F16, "F16", 1, sizeof(npu_device_fp16_t), false, nullptr, quantize_row_fp16,
wrap_dot_func<hexagon::vec_dot_product_f16_f16> },
{ NPU_DATA_TYPE_I32, "I32", 1, sizeof(int32_t), false, nullptr, nullptr, nullptr },
{ NPU_DATA_TYPE_Q8_0, "Q8_0", QUANT_BLOCK_SIZE, sizeof(npu_device_block_q8_0), true, dequantize_row_q8_0,
quantize_row_q8_0 },
{ NPU_DATA_TYPE_Q4_0, "Q4_0", QUANT_BLOCK_SIZE, sizeof(npu_device_block_q4_0), true, dequantize_row_q4_0,
@@ -436,6 +522,8 @@ static_assert(kDeviceTypeTraits[NPU_DATA_TYPE_F32].type == NPU_DATA_TYPE_F32,
"kDeviceTypeTraits F32 type mismatch with npu_device_tensor_data_type enum");
static_assert(kDeviceTypeTraits[NPU_DATA_TYPE_F16].type == NPU_DATA_TYPE_F16,
"kDeviceTypeTraits F16 type mismatch with npu_device_tensor_data_type enum");
static_assert(kDeviceTypeTraits[NPU_DATA_TYPE_I32].type == NPU_DATA_TYPE_I32,
"kDeviceTypeTraits I32 type mismatch with npu_device_tensor_data_type enum");
static_assert(kDeviceTypeTraits[NPU_DATA_TYPE_Q8_0].type == NPU_DATA_TYPE_Q8_0,
"kDeviceTypeTraits Q8_0 type mismatch with npu_device_tensor_data_type enum");
static_assert(kDeviceTypeTraits[NPU_DATA_TYPE_Q4_0].type == NPU_DATA_TYPE_Q4_0,

@@ -5,10 +5,12 @@

namespace hexagon {

using dequant_target_type = npu_device_fp16_t;

bool init_f16_f32_table(float * table, size_t count);

typedef void (*quantize_row_type)(const float * src, void * dst, size_t count, const float * f16_to_f32_table);
typedef void (*dequantize_row_type)(const void * src, float * dst, size_t count, const float * f16_to_f32_table);
typedef void (*quantize_row_type)(const float * src, void * dst, size_t count);
typedef void (*dequantize_row_type)(const void * src, dequant_target_type * dst, size_t count);
typedef float (*vec_dot_type)(const void * src0, const void * src1, size_t count);

struct device_type_traits {
@@ -29,15 +31,13 @@ inline bool is_quantized_type(npu_device_tensor_data_type type) {
return get_type_traits(type).is_quantized;
}

using dequantized_element_type = float;

inline size_t get_dequantized_row_size(const tensor * tensor) {
if (!is_quantized_type(tensor->get_type())) {
return tensor->get_nb(1); // for f32 and f16
}

auto row_elems_count = tensor->get_ne(0);
return row_elems_count * sizeof(dequantized_element_type); // currently only f32 is supported
return row_elems_count * sizeof(dequant_target_type); // quantized rows dequantize to dequant_target_type (fp16)
}
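A worked example for the helper above (hypothetical tensor sizes): with dequant_target_type now npu_device_fp16_t, a quantized row with ne(0) == 4096 needs 4096 * 2 = 8192 bytes of dequantization scratch, half of what the old float-based dequantized_element_type implied; non-quantized tensors still just return nb(1).

// Sketch (assumes a hexagon::tensor shaped like the accessors used above):
inline size_t row_scratch_bytes(const hexagon::tensor * t) {
    return hexagon::get_dequantized_row_size(t);  // e.g. 4096 * sizeof(npu_device_fp16_t) == 8192
}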

inline const char * get_type_name(npu_device_tensor_data_type type) {
@@ -77,14 +77,14 @@ inline auto make_scoped_op_perf_timer(tensor * op, size_t tidx) {

# define DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_SUB_PROC(sub_prefix) \
hexagon::npu_sub_process_scoped_timer<decltype(__npu_op_timer_##sub_prefix)::kBufferCount, 0> \
__npu_op_sub_timer##sub_prefix(__npu_op_timer_##sub_prefix, #sub_prefix)
__npu_op_sub_timer##sub_prefix(__npu_op_timer_##sub_prefix, #sub_prefix)

# define DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_WITH_MULTI_SUB_PROC(op, tidx, tracker_name) \
auto __npu_op_timer_##tracker_name = hexagon::make_scoped_op_perf_timer(op, tidx)

# define DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(tracker_name, idx, sub_prefix) \
hexagon::npu_sub_process_scoped_timer<decltype(__npu_op_timer_##tracker_name)::kBufferCount, idx> \
__npu_op_sub_timer##sub_prefix(__npu_op_timer_##tracker_name, #sub_prefix)
__npu_op_sub_timer##sub_prefix(__npu_op_timer_##tracker_name, #sub_prefix)

#else
# define DEVICE_SCOPED_OP_PERFORMANCE_TRACKER(op, tidx) ((void) 0)

@@ -54,6 +54,8 @@ inline constexpr const char * op_get_name(npu_device_tensor_op op) {
return "RMS_NORM";
case NPU_OP_FLASH_ATTN:
return "FLASH_ATTN_EXT";
case NPU_OP_ROPE:
return "ROPE";
default:
return "UNKNOWN";
}
@@ -64,6 +66,20 @@ inline bool is_transposed_or_permuted(const npu_device_nb_type & nb) {
return (nb[0] > nb[1]) || (nb[1] > nb[2]) || (nb[2] > nb[3]);
}

inline bool is_same_shape(const npu_device_ne_type & src, const npu_device_ne_type & dst) {
for (size_t i = 0; i < DEVICE_TENSOR_MAX_DIMS; ++i) {
if (src[i] != dst[i]) {
return false;
}
}

return true;
}

inline bool is_same_shape(const npu_device_tensor_spec & src, const npu_device_tensor_spec & dst) {
return is_same_shape(src.ne, dst.ne);
}

class power_utils {
public:
power_utils() {

@@ -1,7 +1,5 @@
#include "vec_ops.hpp"

#include <HTP/core/intrinsics.h>

#include "util.hpp"

namespace {
@@ -100,15 +98,20 @@ inline float vec_dot_product_aligned_impl(const _TElem * src0, const _TElem * sr
HVX_Vector sum1 = Q6_V_vzero();

while (src0_vec_ptr_end - src0_vec_ptr > 1) {
HVX_Vector curr0_lo = src0_vec_ptr[0];
HVX_Vector curr0_hi = src0_vec_ptr[1];
HVX_Vector curr1_lo = src1_vec_ptr[0];
HVX_Vector curr1_hi = src1_vec_ptr[1];
HVX_VectorPair curr0 = reinterpret_cast<HVX_VectorPair *>(src0_vec_ptr)[0];
HVX_VectorPair curr1 = reinterpret_cast<HVX_VectorPair *>(src1_vec_ptr)[0];
src0_vec_ptr += 2;
src1_vec_ptr += 2;

sum0 = _AddFunc(_MpyFunc(curr0_lo, curr1_lo), sum0);
sum1 = _AddFunc(_MpyFunc(curr0_hi, curr1_hi), sum1);
sum0 = _AddFunc(_MpyFunc(Q6_V_lo_W(curr0), Q6_V_lo_W(curr1)), sum0);
sum1 = _AddFunc(_MpyFunc(Q6_V_hi_W(curr0), Q6_V_hi_W(curr1)), sum1);
}

if (src0_vec_ptr_end - src0_vec_ptr > 0) {
HVX_Vector curr0 = src0_vec_ptr[0];
HVX_Vector curr1 = src1_vec_ptr[0];

sum0 = _AddFunc(_MpyFunc(curr0, curr1), sum0);
}

return _ReduceFunc(_AddFunc(sum0, sum1));
@@ -130,27 +133,189 @@ inline HVX_Vector vec_add_qf16(HVX_Vector sum, HVX_Vector result) {
return Q6_Vqf16_vadd_Vqf16Vqf16(sum, result);
}

template <typename _TElem0, typename _TElem1, HVX_VectorPair (*_ExpandFunc)(HVX_Vector, HVX_Vector),
HVX_Vector (*_MpyFunc)(HVX_Vector, HVX_Vector), HVX_Vector (*_AddFunc)(HVX_Vector, HVX_Vector),
float (*_ReduceFunc)(HVX_Vector)>
inline float vec_dot_product_mixed_impl(const _TElem0 * src0, const _TElem1 * src1, size_t count) {
static_assert(sizeof(_TElem0) < sizeof(_TElem1), "Element size mismatch: _TElem0 must be smaller than _TElem1");
static_assert((sizeof(_TElem1) / sizeof(_TElem0)) == 2,
"Element size mismatch: _TElem1 must be twice the size of _TElem0");
static_assert((sizeof(_TElem1) % sizeof(_TElem0)) == 0,
"Element size mismatch: _TElem1 must be a multiple of _TElem0");

constexpr const size_t kElementsPerVector0 = hexagon::kBytesPerVector / sizeof(_TElem0);
constexpr const size_t kElementsPerVector1 = hexagon::kBytesPerVector / sizeof(_TElem1);

constexpr const __fp16 kOne = 1.0f;
const HVX_Vector kOneV = Q6_Vh_vsplat_R(reinterpret_cast<const uint16_t &>(kOne));

const _TElem0 * const src0_ptr_end = src0 + count;
HVX_Vector * src0_vec_ptr = ((HVX_Vector *) src0);
HVX_Vector * src1_vec_ptr = ((HVX_Vector *) src1);
HVX_Vector * const src1_vec_ptr_end = ((HVX_Vector *) src1) + count / kElementsPerVector1;
HVX_Vector prev0 = *src0_vec_ptr++;
HVX_Vector prev1 = *src1_vec_ptr++;
HVX_Vector sum = Q6_V_vzero();
HVX_Vector sum0 = Q6_V_vzero();
HVX_Vector sum1 = Q6_V_vzero();

while (src1_vec_ptr_end - src1_vec_ptr > 1) {
HVX_Vector curr0 = src0_vec_ptr[0];
HVX_VectorPair curr1 = reinterpret_cast<HVX_VectorPair *>(src1_vec_ptr)[0];

HVX_Vector s0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0);
HVX_Vector l1 = Q6_V_valign_VVR(Q6_V_lo_W(curr1), prev1, (size_t) src1);
HVX_Vector h1 = Q6_V_valign_VVR(Q6_V_hi_W(curr1), Q6_V_lo_W(curr1), (size_t) src1);
HVX_VectorPair s0_pair = _ExpandFunc(s0, kOneV);
prev0 = curr0;
prev1 = Q6_V_hi_W(curr1);
src0_vec_ptr++;
src1_vec_ptr += 2;

sum0 = _AddFunc(_MpyFunc(Q6_V_lo_W(s0_pair), l1), sum0);
sum1 = _AddFunc(_MpyFunc(Q6_V_hi_W(s0_pair), h1), sum1);
}

sum = _AddFunc(sum0, sum1);
const size_t leftover1 = count % kElementsPerVector1;
if ((src1_vec_ptr_end - ((HVX_Vector *) src1)) > 0) {
// handle the last vector
const bool should_fetch_src0 =
reinterpret_cast<const _TElem0 *>(hexagon::align_down(src0_vec_ptr)) < src0_ptr_end;
HVX_Vector curr0 = should_fetch_src0 ? *src0_vec_ptr : prev0;
src0_vec_ptr += should_fetch_src0 ? 1 : 0;

HVX_Vector s0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0);
HVX_VectorPair s0_pair = _ExpandFunc(s0, kOneV);

const bool has_remaining_src1_vector = src1_vec_ptr_end - src1_vec_ptr > 0;
if (has_remaining_src1_vector) {
HVX_Vector curr1 = *src1_vec_ptr++;
HVX_Vector s1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1);
prev1 = curr1;

// should_handle_last_vector will always be true here
sum = _AddFunc(_MpyFunc(Q6_V_lo_W(s0_pair), s1), sum);
}

bool should_fetch_src1 = leftover1 != 0 || !hexagon::is_addr_aligned(src1_vec_ptr);
HVX_Vector curr1 = should_fetch_src1 ? *src1_vec_ptr : prev1;
src1_vec_ptr += should_fetch_src1 ? 1 : 0;
HVX_Vector s1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1);
prev0 = curr0;
prev1 = curr1;

sum = _AddFunc(_MpyFunc(has_remaining_src1_vector ? Q6_V_hi_W(s0_pair) : Q6_V_lo_W(s0_pair), s1), sum);
}

const size_t leftover0 = count % kElementsPerVector0;
const size_t leftover_bytes1 = leftover1 * sizeof(_TElem1);
if (leftover1 > 0) {
// handle the leftover elements
HVX_Vector curr0 =
reinterpret_cast<const _TElem0 *>(hexagon::align_down(src0_vec_ptr)) < src0_ptr_end ? *src0_vec_ptr : prev0;
HVX_Vector curr1 = (leftover_bytes1 + hexagon::unaligned_bytes(src1_vec_ptr) > hexagon::kBytesPerVector) ?
*src1_vec_ptr :
prev1;
curr0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0);
curr1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1);
HVX_VectorPair curr0_pair = _ExpandFunc(curr0, kOneV);

curr0 = leftover1 == leftover0 ? Q6_V_lo_W(curr0_pair) : Q6_V_hi_W(curr0_pair);
sum = _AddFunc(Q6_V_valign_VVR(_MpyFunc(curr0, curr1), Q6_V_vzero(), leftover_bytes1), sum);
}

return _ReduceFunc(sum);
}
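The vector gymnastics above are easier to check against a scalar reference; a sketch of the value the mixed kernel is meant to compute (illustration only; it assumes npu_device_fp16_t is bit-compatible with __fp16, as the reinterpret_casts elsewhere in this diff already rely on):

// Scalar reference for the mixed f16 x f32 dot product.
inline float vec_dot_f16_f32_scalar(const npu_device_fp16_t * src0, const float * src1, size_t count) {
    float sum = 0.0f;
    for (size_t i = 0; i < count; ++i) {
        // The HVX path widens src0 through _ExpandFunc (hvx_vsf_convert_vhf), i.e. an
        // fp16 multiply by 1.0 that yields a qf32 pair; here we simply cast.
        sum += (float) reinterpret_cast<const __fp16 *>(src0)[i] * src1[i];
    }
    return sum;
}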

template <typename _TElem0, typename _TElem1, HVX_VectorPair (*_ExpandFunc)(HVX_Vector, HVX_Vector),
HVX_Vector (*_MpyFunc)(HVX_Vector, HVX_Vector), HVX_Vector (*_AddFunc)(HVX_Vector, HVX_Vector),
float (*_ReduceFunc)(HVX_Vector)>
inline float vec_dot_product_mix_aligned_impl(const _TElem0 * src0, const _TElem1 * src1, size_t count) {
static_assert(sizeof(_TElem0) < sizeof(_TElem1), "Element size mismatch: _TElem0 must be smaller than _TElem1");
static_assert((sizeof(_TElem1) / sizeof(_TElem0)) == 2,
"Element size mismatch: _TElem1 must be twice the size of _TElem0");
static_assert((sizeof(_TElem1) % sizeof(_TElem0)) == 0,
"Element size mismatch: _TElem1 must be a multiple of _TElem0");

constexpr const size_t kElementsPerVector1 = hexagon::kBytesPerVector / sizeof(_TElem1);

constexpr const __fp16 kOne = 1.0f;
const HVX_Vector kOneV = Q6_Vh_vsplat_R(reinterpret_cast<const uint16_t &>(kOne));

HVX_Vector * src0_vec_ptr = ((HVX_Vector *) src0);
HVX_Vector * src1_vec_ptr = ((HVX_Vector *) src1);
HVX_Vector * const src1_vec_ptr_end = ((HVX_Vector *) src1) + count / kElementsPerVector1;
HVX_Vector sum0 = Q6_V_vzero();
HVX_Vector sum1 = Q6_V_vzero();

{
HVX_Vector sum2 = Q6_V_vzero();
HVX_Vector sum3 = Q6_V_vzero();

while (src1_vec_ptr_end - src1_vec_ptr > 3) {
HVX_VectorPair curr0 = reinterpret_cast<HVX_VectorPair *>(src0_vec_ptr)[0];
HVX_VectorPair curr10 = reinterpret_cast<HVX_VectorPair *>(src1_vec_ptr)[0];
HVX_VectorPair curr11 = reinterpret_cast<HVX_VectorPair *>(src1_vec_ptr)[1];

HVX_VectorPair curr00 = _ExpandFunc(Q6_V_lo_W(curr0), kOneV);
HVX_VectorPair curr01 = _ExpandFunc(Q6_V_hi_W(curr0), kOneV);
src0_vec_ptr += 2;
src1_vec_ptr += 4;

sum0 = _AddFunc(_MpyFunc(Q6_V_lo_W(curr00), Q6_V_lo_W(curr10)), sum0);
sum1 = _AddFunc(_MpyFunc(Q6_V_hi_W(curr00), Q6_V_hi_W(curr10)), sum1);
sum2 = _AddFunc(_MpyFunc(Q6_V_lo_W(curr01), Q6_V_lo_W(curr11)), sum2);
sum3 = _AddFunc(_MpyFunc(Q6_V_hi_W(curr01), Q6_V_hi_W(curr11)), sum3);
}

sum0 = _AddFunc(sum0, sum2);
sum1 = _AddFunc(sum1, sum3);
}

if (src1_vec_ptr_end - src1_vec_ptr > 1) {
HVX_Vector curr0 = src0_vec_ptr[0];
HVX_VectorPair curr1 = reinterpret_cast<HVX_VectorPair *>(src1_vec_ptr)[0];

HVX_VectorPair s0_pair = _ExpandFunc(curr0, kOneV);

sum0 = _AddFunc(_MpyFunc(Q6_V_lo_W(s0_pair), Q6_V_lo_W(curr1)), sum0);
sum1 = _AddFunc(_MpyFunc(Q6_V_hi_W(s0_pair), Q6_V_hi_W(curr1)), sum1);
}

return _ReduceFunc(_AddFunc(sum0, sum1));
}

} // namespace

namespace hexagon {

float vec_dot_product_f32_f32(const float * src0, const float * src1, size_t count) {
return vec_dot_product_impl<float, vec_mpy_qf32, vec_add_qf32, hexagon::vec_reduction_qf32_f32>(src0, src1, count);
return vec_dot_product_impl<float, vec_mpy_qf32, vec_add_qf32, vec_reduction_f32_qf32>(src0, src1, count);
}

float vec_dot_product_aligned_f32_f32(const float * src0, const float * src1, size_t count) {
return vec_dot_product_aligned_impl<float, vec_mpy_qf32, vec_add_qf32, hexagon::vec_reduction_qf32_f32>(src0, src1,
count);
return vec_dot_product_aligned_impl<float, vec_mpy_qf32, vec_add_qf32, vec_reduction_f32_qf32>(src0, src1, count);
}

float vec_dot_product_f16_f16(const npu_device_fp16_t * src0, const npu_device_fp16_t * src1, size_t count) {
return vec_dot_product_impl<npu_device_fp16_t, vec_mpy_qf16, vec_add_qf16, hexagon::vec_reduction_qf16_f32>(
src0, src1, count);
return vec_dot_product_impl<npu_device_fp16_t, vec_mpy_qf16, vec_add_qf16, vec_reduction_qf16_f32>(src0, src1,
count);
}

float vec_dot_product_aligned_f16_f16(const npu_device_fp16_t * src0, const npu_device_fp16_t * src1, size_t count) {
return vec_dot_product_aligned_impl<npu_device_fp16_t, vec_mpy_qf16, vec_add_qf16, hexagon::vec_reduction_qf16_f32>(
return vec_dot_product_aligned_impl<npu_device_fp16_t, vec_mpy_qf16, vec_add_qf16, vec_reduction_qf16_f32>(
src0, src1, count);
}

float vec_dot_product_f16_f32(const npu_device_fp16_t * src0, const float * src1, size_t count) {
return vec_dot_product_mixed_impl<npu_device_fp16_t, float, hvx_vsf_convert_vhf, vec_mpy_qf32, vec_add_qf32,
vec_reduction_f32_qf32>(src0, src1, count);
}

float vec_dot_product_aligned_f16_f32(const npu_device_fp16_t * src0, const float * src1, size_t count) {
return vec_dot_product_mix_aligned_impl<npu_device_fp16_t, float, hvx_vsf_convert_vhf, vec_mpy_qf32, vec_add_qf32,
vec_reduction_f32_qf32>(src0, src1, count);
}

} // namespace hexagon

@@ -1,7 +1,6 @@
#pragma once

#include <hexagon_types.h>
#include <HTP/core/intrinsics.h>

#include <cstdint>

@@ -12,10 +11,22 @@ namespace hexagon {
constexpr const size_t kBytesPerVector = sizeof(HVX_Vector); // 128 for v73
constexpr const size_t kAlignMask = kBytesPerVector - 1;

inline size_t get_aligned_size(size_t size) {
return (size + kAlignMask) & ~kAlignMask;
}

inline size_t unaligned_bytes(const void * addr) {
return ((size_t) addr) & kAlignMask;
}

template <typename _TyData> inline const _TyData * align_down(const _TyData * addr) {
return reinterpret_cast<const _TyData *>(reinterpret_cast<const uint8_t *>(addr) - unaligned_bytes(addr));
}

inline size_t bytes_to_vector_boundary(const void * addr) {
return kBytesPerVector - unaligned_bytes(addr);
}

inline bool is_addr_aligned(const void * addr) {
return unaligned_bytes(addr) == 0;
}
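A worked example for the alignment helpers above, assuming kBytesPerVector == 128 and a hypothetical pointer 40 bytes past an aligned base:

// Illustration only: `buf` is 128-byte aligned, `p` is not.
alignas(128) static uint8_t buf[256];
const uint8_t * p = buf + 40;
// unaligned_bytes(p)          == 40
// align_down(p)               == buf
// bytes_to_vector_boundary(p) == 88   (128 - 40)
// is_addr_aligned(p)          == false
// get_aligned_size(40)        == 128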
@@ -109,6 +120,50 @@ inline HVX_VectorPair qhmath_hvx_vqf32_convert_vqf16(HVX_Vector vxl) {
return Q6_W_vcombine_VV(vxh_w, vxl_w);
}

template <uint32_t _TyBytes> inline void q6op_vstu_variable_ARV(void * addr, HVX_Vector vin) {
vin = Q6_V_vlalign_VVR(vin, vin, (size_t) addr); //rotate as needed.
uint32_t left_off = unaligned_bytes(addr);
uint32_t right_off = left_off + _TyBytes;
HVX_VectorPred qL_not = Q6_Q_vsetq_R((size_t) addr);
HVX_VectorPred qR = Q6_Q_vsetq2_R(right_off);
if (right_off > 128) {
Q6_vmaskedstoreq_QAV(qR, (HVX_Vector *) addr + 1, vin);
qR = Q6_Q_vcmp_eq_VbVb(vin, vin); // all 1's
}
qL_not = Q6_Q_or_QQn(qL_not, qR);
Q6_vmaskedstorenq_QAV(qL_not, (HVX_Vector *) addr, vin);
}

template <uint32_t _TyBytes> inline void q6op_vstu_variable_aligned(void * addr, HVX_Vector vin) {
HVX_VectorPred qR = Q6_Q_vsetq2_R(_TyBytes);
Q6_vmaskedstorenq_QAV(qR, (HVX_Vector *) addr, vin);
}

inline void q6op_vstu_variable_ARV(void * addr, int n, HVX_Vector vin) {
vin = Q6_V_vlalign_VVR(vin, vin, (size_t) addr); //rotate as needed.
unsigned left_off = unaligned_bytes(addr);
unsigned right_off = left_off + n;
HVX_VectorPred qL_not = Q6_Q_vsetq_R((size_t) addr);
HVX_VectorPred qR = Q6_Q_vsetq2_R(right_off);
if (right_off > 128) {
Q6_vmaskedstoreq_QAV(qR, (HVX_Vector *) addr + 1, vin);
qR = Q6_Q_vcmp_eq_VbVb(vin, vin); // all 1's
}
qL_not = Q6_Q_or_QQn(qL_not, qR);
Q6_vmaskedstorenq_QAV(qL_not, (HVX_Vector *) addr, vin);
}

inline HVX_VectorPair hvx_vqf32_convert_vhf(HVX_Vector vxl) {
return qhmath_hvx_vqf32_convert_vqf16(qhmath_hvx_vqf16_convert_vhf(vxl));
}

inline HVX_VectorPair hvx_vsf_convert_vhf(HVX_Vector vxl, HVX_Vector one) {
HVX_VectorPair res = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vxl), one);
HVX_Vector vxl_w = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(res));
HVX_Vector vxh_w = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(res));
return Q6_W_vcombine_VV(vxh_w, vxl_w);
}
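The conversion above leans on the identity x * 1.0 == x: shuffling the fp16 lanes and multiplying by a splatted fp16 one with Q6_Wqf32_vmpy_VhfVhf widens every element to qf32 (the pre-shuffle compensates for the even/odd split of the widening multiply, so the two result vectors come out as contiguous halves), and the Q6_Vsf_equals_Vqf32 calls then normalize to IEEE f32. A sketch of how a caller is expected to prepare the `one` operand, mirroring the kOneV setup in vec_dot_product_mixed_impl (`v_f16` is a placeholder input vector):

// Sketch: widening one HVX vector of fp16 values to two vectors of f32.
constexpr const __fp16 kOne = 1.0f;
const HVX_Vector kOneV = Q6_Vh_vsplat_R(reinterpret_cast<const uint16_t &>(kOne));
HVX_VectorPair widened = hvx_vsf_convert_vhf(v_f16, kOneV);
HVX_Vector lo_f32 = Q6_V_lo_W(widened);  // first half of the input, as f32
HVX_Vector hi_f32 = Q6_V_hi_W(widened);  // second half of the input, as f32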

inline HVX_Vector vec_reduction_qf32(HVX_Vector sums) {
constexpr const size_t kFloatsPerVector = hexagon::kBytesPerVector / sizeof(float);
static_assert(kFloatsPerVector == 32 || kFloatsPerVector == 16, "kFloatsPerVector should be 16 or 32");
@@ -130,7 +185,7 @@ inline HVX_Vector vec_reduction_qf32(HVX_Vector sums) {
return sums;
}

inline float vec_reduction_qf32_f32(HVX_Vector sums) {
inline float vec_reduction_f32_qf32(HVX_Vector sums) {
return get_flt0_from_fltv(Q6_Vsf_equals_Vqf32(vec_reduction_qf32(sums)));
}
@@ -265,10 +320,41 @@ inline void vec_mad_f16(const npu_device_fp16_t * src, float scale, npu_device_f
vec_scale_impl<hvx_vec_mad_f16_f16, hvx_scale_f16, npu_device_fp16_t>(src, scale, dst, count);
}

template <typename _TElem0, typename _TElem1>
inline bool is_dot_product_aligned(const _TElem0 * src0, const _TElem1 * src1, size_t count) {
static_assert(sizeof(_TElem0) <= sizeof(_TElem1), "src0 should be smaller than src1");

if (!hexagon::is_addr_aligned(src0) || !hexagon::is_addr_aligned(src1)) {
return false;
}

if (count % (hexagon::kBytesPerVector / sizeof(_TElem0)) != 0) {
return false;
}

return true;
}

float vec_dot_product_f32_f32(const float * src0, const float * src1, size_t count);
float vec_dot_product_aligned_f32_f32(const float * src0, const float * src1, size_t count);

inline bool is_f32_f32_dot_product_aligned(const float * src0, const float * src1, size_t count) {
return is_dot_product_aligned<float, float>(src0, src1, count);
}

float vec_dot_product_f16_f16(const npu_device_fp16_t * src0, const npu_device_fp16_t * src1, size_t count);
float vec_dot_product_aligned_f16_f16(const npu_device_fp16_t * src0, const npu_device_fp16_t * src1, size_t count);

inline bool is_f16_f16_dot_product_aligned(const npu_device_fp16_t * src0, const npu_device_fp16_t * src1,
size_t count) {
return is_dot_product_aligned<npu_device_fp16_t, npu_device_fp16_t>(src0, src1, count);
}

float vec_dot_product_f16_f32(const npu_device_fp16_t * src0, const float * src1, size_t count);
float vec_dot_product_aligned_f16_f32(const npu_device_fp16_t * src0, const float * src1, size_t count);

inline bool is_f16_f32_dot_product_aligned(const npu_device_fp16_t * src0, const float * src1, size_t count) {
return is_dot_product_aligned<npu_device_fp16_t, float>(src0, src1, count);
}
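A short sketch of how these predicates are meant to pair with the two kernel flavours; a caller (for example the mul_mat path) can pick the aligned variant only when both pointers sit on vector boundaries and the row length fills whole vectors:

// Illustration: caller-side dispatch between the aligned and generic mixed kernels.
inline float dot_f16_f32(const npu_device_fp16_t * src0, const float * src1, size_t count) {
    if (hexagon::is_f16_f32_dot_product_aligned(src0, src1, count)) {
        return hexagon::vec_dot_product_aligned_f16_f32(src0, src1, count);
    }
    return hexagon::vec_dot_product_f16_f32(src0, src1, count);
}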

} // namespace hexagon

@@ -29,8 +29,7 @@ bool host_graph::update(ggml_cgraph * cgraph) {
return false;
}

LOG_DEBUG("[%p]host_graph::update started\n", (void *) this);

PROFILER_LOG_DEBUG("[%p]host_graph::update started\n", (void *) this);
SCOPED_PERFORMANCE_TRACKER("[hexagon-npu][%p]update, handle(%p)", (void *) this, (void *) _graph_handle);

_tensor_handles.clear();
@@ -57,10 +56,11 @@ bool host_graph::update(ggml_cgraph * cgraph) {

_tensor_handles.push_back(tensor_obj->get_device_tensor_handle());
_tensor_update_configs.push_back(tensor_obj->update_hosts_params_only(node));
LOG_DEBUG("node[%d]%s(%s), addr: %p, type: %s, dims: %ldx%ldx%ldx%ld, tensor_handle: %p\n", i,
ggml_get_name(node), ggml_op_desc(node), (void *) node, ggml_type_name(node->type),
(long) tensor_obj->get_ne(0), (long) tensor_obj->get_ne(1), (long) tensor_obj->get_ne(2),
(long) tensor_obj->get_ne(3), (void *) tensor_obj->get_device_tensor_handle());

PROFILER_LOG_DEBUG("node[%d]%s(%s), addr(%p), %s_%ldx%ldx%ldx%ld, handle(%p)\n", i, ggml_get_name(node),
ggml_op_desc(node), (void *) tensor_obj, ggml_type_name(node->type),
(long) tensor_obj->get_ne(0), (long) tensor_obj->get_ne(1), (long) tensor_obj->get_ne(2),
(long) tensor_obj->get_ne(3), (void *) tensor_obj->get_device_tensor_handle());
}

GGML_ASSERT(_tensor_handles.size() == _tensor_update_configs.size());

@@ -199,7 +199,7 @@ bool npu_device::supports_op_impl(const ggml_tensor * op) {
#ifndef NDEBUG
auto * src0_type = i ? ggml_type_name(op->src[0]->type) : "null";
auto * src1_type = (i > 1) ? ggml_type_name(op->src[1]->type) : "null";
LOG_DEBUG("[%s][%s]unsupported %s(%s,%s), ret: 0x%x, supported: %d\n", get_name(), ggml_op_name(op->op),
LOG_DEBUG("[%s][%s]unsupported %s(%s,%s), ret: 0x%x, supported: %d\n", get_name(), ggml_op_desc(op),
ggml_type_name(op->type), src0_type, src1_type, ret, supported);
#endif
return false;
@@ -275,16 +275,16 @@ bool npu_device::supports_op(const ggml_tensor * op) {
if (op->op != GGML_OP_NONE && op->op != GGML_OP_VIEW && op->op != GGML_OP_RESHAPE &&
op->op != GGML_OP_PERMUTE) {
_supported_op++;
LOG_DEBUG("[%s][%s]supported, %s, supported/unsupported: %u/%u\n", get_name(), ggml_op_name(op->op),
op_desc, _supported_op.load(), _unsupported_op.load());
LOG_DEBUG("[%s][%s][%s]supported, %s, supported/unsupported: %u/%u\n", get_name(), ggml_op_desc(op),
ggml_get_name(op), op_desc, _supported_op.load(), _unsupported_op.load());
}

return true;
}

_unsupported_op++;
LOG_DEBUG("[%s][%s]unsupported, %s, supported/unsupported: %u/%u\n", get_name(), ggml_op_name(op->op), op_desc,
_supported_op.load(), _unsupported_op.load());
LOG_DEBUG("[%s][%s][%s]unsupported, %s, supported/unsupported: %u/%u\n", get_name(), ggml_op_desc(op),
ggml_get_name(op), op_desc, _supported_op.load(), _unsupported_op.load());
return false;
}
#else

@@ -48,10 +48,14 @@ class host_tensor {

tensor->extra = this;
_ggml_tensor = tensor;
LOG_DEBUG("host_tensor(%p), ggml_tensor(%s[%ldx%ldx%ldx%ld], nb[%ld][%ld][%ld][%ld], %s, %p), handle(%p)\n",
(void *) this, tensor->name, (long) tensor->ne[0], (long) tensor->ne[1], (long) tensor->ne[2],
(long) tensor->ne[3], (long) tensor->nb[0], (long) tensor->nb[1], (long) tensor->nb[2],
(long) tensor->nb[3], ggml_type_name(tensor->type), (void *) tensor, (void *) _device_tensor_handle);

#ifndef NDEBUG
{
char desc[1024];
get_desc(desc, sizeof(desc));
LOG_DEBUG("host_tensor(%s)\n", desc);
}
#endif
}

~host_tensor() {
@@ -99,7 +103,11 @@ class host_tensor {
auto * ggml_src = _ggml_tensor->src[j];
auto * src = host_tensor::from_ggml_tensor(ggml_src);
src_tensor_handles[j] = src->get_device_tensor_handle();
LOG_DEBUG("host_tensor(%p) set_src[%zu]: %p(%s)\n", (void *) this, j, (void *) src, ggml_src->name);
#ifndef NDEBUG
char desc[1024];
src->get_desc(desc, sizeof(desc));
LOG_DEBUG("host_tensor(%p) set_src[%zu]: (%s)\n", (void *) this, j, desc);
#endif
}

if (memcmp(_info_update.src_handles, src_tensor_handles, sizeof(_info_update.src_handles)) != 0) {
@@ -136,7 +144,11 @@ class host_tensor {
auto * ggml_src = _ggml_tensor->src[j];
auto * src = host_tensor::from_ggml_tensor(ggml_src);
_info_update.src_handles[j] = src->get_device_tensor_handle();
LOG_DEBUG("host_tensor(%p) set_src[%zu]: %p(%s)\n", (void *) this, j, (void *) src, ggml_src->name);
#ifndef NDEBUG
char desc[1024];
src->get_desc(desc, sizeof(desc));
LOG_DEBUG("host_tensor(%p) set_src[%zu]: (%s)\n", (void *) this, j, desc);
#endif
}

LOG_DEBUG("host_tensor(%p) update_params, op: %s, params: [%x, %x, %x, %x]\n", (void *) this,
@@ -156,6 +168,15 @@ class host_tensor {
return _info.ne[index];
}

int get_desc(char * buffer, size_t size) const {
return snprintf(buffer, size, "%s[%ldx%ldx%ldx%ld], nb[%ld,%ld,%ld,%ld], %s, addr: %p, ggml: %p, handle:%p",
_ggml_tensor->name, (long) _ggml_tensor->ne[0], (long) _ggml_tensor->ne[1],
(long) _ggml_tensor->ne[2], (long) _ggml_tensor->ne[3], (long) _ggml_tensor->nb[0],
(long) _ggml_tensor->nb[1], (long) _ggml_tensor->nb[2], (long) _ggml_tensor->nb[3],
ggml_type_name(_ggml_tensor->type), (void *) this, (void *) _ggml_tensor,
(void *) _device_tensor_handle);
}

private:
remote_handle64 _device_handle = 0;
npu_device_tensor_handle_t _device_tensor_handle = 0;

@@ -13,6 +13,10 @@ static_assert(QUANT_K_SCALE_SIZE == K_SCALE_SIZE, "QUANT_K_SCALE_SIZE size misma
static_assert(QUANT_K_BLOCK_SIZE == QK_K, "QUANT_K_BLOCK_SIZE size mismatch");
static_assert(QUANT_BLOCK_SIZE == QK4_0, "QUANT_BLOCK_SIZE size mismatch");

static_assert(NPU_ROPE_TYPE_NEOX == GGML_ROPE_TYPE_NEOX, "NPU_ROPE_TYPE_NEOX mismatch");
static_assert(NPU_ROPE_TYPE_MROPE == GGML_ROPE_TYPE_MROPE, "NPU_ROPE_TYPE_MROPE mismatch");
static_assert(NPU_ROPE_TYPE_VISION == GGML_ROPE_TYPE_VISION, "NPU_ROPE_TYPE_VISION mismatch");

namespace hexagon {

enum npu_device_tensor_op op_to_npu_op(ggml_op op) {
@@ -29,6 +33,8 @@ enum npu_device_tensor_op op_to_npu_op(ggml_op op) {
return NPU_OP_RMS_NORM;
case GGML_OP_FLASH_ATTN_EXT:
return NPU_OP_FLASH_ATTN;
case GGML_OP_ROPE:
return NPU_OP_ROPE;
default:
return NPU_OP_COUNT;
}
@@ -48,6 +54,8 @@ const char * get_npu_op_desc(enum npu_device_tensor_op op) {
return ggml_op_name(GGML_OP_RMS_NORM);
case NPU_OP_FLASH_ATTN:
return ggml_op_name(GGML_OP_FLASH_ATTN_EXT);
case NPU_OP_ROPE:
return ggml_op_name(GGML_OP_ROPE);
default:
return "UNKNOWN";
}
@@ -59,6 +67,8 @@ enum npu_device_tensor_data_type type_to_npu_type(ggml_type type) {
return NPU_DATA_TYPE_F32;
case GGML_TYPE_F16:
return NPU_DATA_TYPE_F16;
case GGML_TYPE_I32:
return NPU_DATA_TYPE_I32;
case GGML_TYPE_Q4_K:
return NPU_DATA_TYPE_Q4_K;
case GGML_TYPE_Q4_0:

@@ -4,11 +4,15 @@

const uint32_t DEVICE_TENSOR_MAX_DIMS = 4;
const uint32_t DEVICE_TENSOR_MAX_SRC = 4;
const uint32_t DEVICE_TENSOR_MAX_OP_PARAMS = 4;
const uint32_t DEVICE_TENSOR_MAX_OP_PARAMS = 16;
const uint32_t QUANT_BLOCK_SIZE = 32;
const uint32_t QUANT_K_BLOCK_SIZE = 256;
const uint32_t QUANT_K_SCALE_SIZE = 12;

const uint32_t NPU_ROPE_TYPE_NEOX = 2;
const uint32_t NPU_ROPE_TYPE_MROPE = 8;
const uint32_t NPU_ROPE_TYPE_VISION = 24;

interface npu_device : remote_handle64{

typedef int64_t ne_type[DEVICE_TENSOR_MAX_DIMS];
@@ -42,12 +46,14 @@ interface npu_device : remote_handle64{
NPU_OP_MUL,
NPU_OP_RMS_NORM,
NPU_OP_FLASH_ATTN,
NPU_OP_ROPE,
NPU_OP_COUNT
};

enum tensor_data_type {
NPU_DATA_TYPE_F32,
NPU_DATA_TYPE_F16,
NPU_DATA_TYPE_I32,
NPU_DATA_TYPE_Q8_0,
NPU_DATA_TYPE_Q4_0,
NPU_DATA_TYPE_Q4_K,

@@ -56,6 +56,8 @@ inline scoped_timer make_scope_perf_timer(const char * format, ...) {
#ifdef GGML_HEXAGON_ENABLE_PERFORMANCE_TRACKING
# define SCOPED_PERFORMANCE_TRACKER(fmt, ...) \
auto __scoped_timer_##__LINE__ = profiler::make_scope_perf_timer(fmt, __VA_ARGS__)
# define PROFILER_LOG_DEBUG(fmt, ...) GGML_LOG_INFO("[profiler]" fmt, __VA_ARGS__)
#else
# define SCOPED_PERFORMANCE_TRACKER(fmt, ...) ((void) 0)
# define PROFILER_LOG_DEBUG(...) ((void) 0)
#endif