feat: perf opt part5 (#52)
* rename
* Refactor vector operations in vec_op_impl and vec_dot_product_impl for improved clarity and performance
* wip
* Enhance vector copy functions for improved performance and clarity in vec_ops.hpp
* wip
* wip
* wip
* Optimize vector dot product implementations for enhanced performance and efficiency
* Enhance flash attention implementation and type traits for improved vector operations and alignment checks
  # Conflicts:
  #   ggml/src/ggml-qnn/npu/device/type_traits.cpp
* remove align
* wip
* Enhance vector dot product implementation for improved performance by adding parallel processing for multiple vector pairs
* Revert "Enhance vector dot product implementation for improved performance by adding parallel processing for multiple vector pairs"
  This reverts commit 78cc24ed2285002ca29d6189fa61ba4ce24f8d16.
* Enhance flash attention implementation with type checks for tensor data types and improved constexpr usage
* wip
* opt mask calc
* Revert "opt mask calc"
  This reverts commit bb1840876692a11511d5ab7828b8a707402e30b9.
* wip
* opt mul mat caching logic to add dst cache
* Revert "opt mul mat caching logic to add dst cache"
  This reverts commit ab442fa9f763b3873c929936e4cb739cb1c83850.
* wip
* Refactor matrix multiplication implementation to include vector conversion and performance tracking
* wip
* wip
* wip
* create vec_ops.inl for more aggressive compiler inline
* wip
* refactor vector dot product implementations for improved readability and performance
* refactor vector conversion functions to use HVX_Vector_Dual for improved clarity and consistency
* wip
* wip
* wip
* implement row size caching logic and enhance type traits for F32 support
* refactor matrix multiplication functions to improve caching logic and simplify tensor alignment handling
* add vector zeroing functions for F32 and F16 types to optimize memory initialization
* Revert "add vector zeroing functions for F32 and F16 types to optimize memory initialization"
  This reverts commit e374326dc74d049e6603e393ade418d9ef2b83f3.
* wip
* refactor alignment checks in dot product function to handle null pointers
* wip
* refactor load_block_generic and related functions for improved alignment handling
* wip
* refactor flash attention implementation and introduce type-erased dot function for improved type handling
* refactor dot product implementations for improved loop handling and clarity
* refactor thread_pool constructor to pre-allocate VTCM cache for each thread
* Revert "refactor thread_pool constructor to pre-allocate VTCM cache for each thread"
  This reverts commit 00cdd3fa88d909feef44ddaa42095274b7627685.
* wip
* opt interfaces for tensor cleanup
* refactor mul_mat_impl to use aligned size for src0 row calculation
* refactor: update dequantized_row_size logic and add size alignment checks for tensors
* wip
* wip
* refactor: replace raw pointer initialization with invalid handle constants for better clarity
* wip
parent fc45ad51d2
commit 2cd429ca75
@@ -4,7 +4,6 @@
 #include <hexagon_types.h>
 
 #include <memory>
-#include <new>
 
 #include "graph.hpp"
 #include "hexagon_npu.h"
@@ -69,20 +68,28 @@ struct npu_device_context {
     }
 };
 
-inline hexagon::tensor * tensor_from_handle(npu_device_graph_handle_t h) {
+inline hexagon::tensor * tensor_from_handle(npu_device_tensor_handle_t h) {
+    if (h == npu_device_INVALID_DEVICE_TENSOR_HANDLE) {
+        return nullptr;
+    }
+
     return reinterpret_cast<hexagon::tensor *>(h);
 }
 
-inline npu_device_graph_handle_t tensor_to_handle(hexagon::tensor * tensor) {
-    return reinterpret_cast<npu_device_graph_handle_t>(tensor);
+inline npu_device_tensor_handle_t tensor_to_handle(hexagon::tensor * tensor) {
+    return reinterpret_cast<npu_device_tensor_handle_t>(tensor);
 }
 
-inline hexagon::graph * graph_from_handle(npu_device_tensor_handle_t h) {
+inline hexagon::graph * graph_from_handle(npu_device_graph_handle_t h) {
+    if (h == npu_device_INVALID_DEVICE_GRAPH_HANDLE) {
+        return nullptr;
+    }
+
     return reinterpret_cast<hexagon::graph *>(h);
 }
 
-inline npu_device_tensor_handle_t graph_to_handle(hexagon::graph * graph) {
-    return reinterpret_cast<npu_device_tensor_handle_t>(graph);
+inline npu_device_graph_handle_t graph_to_handle(hexagon::graph * graph) {
+    return reinterpret_cast<npu_device_graph_handle_t>(graph);
 }
 
 inline npu_device_context * device_context_from_handle(remote_handle64 h) {
@@ -93,12 +100,7 @@ inline npu_device_context * device_context_from_handle(remote_handle64 h) {
 
 int npu_device_open(const char * uri, remote_handle64 * h) {
     // TODO: should we have a device context here?
-    auto * context = new (std::nothrow) npu_device_context();
-    if (!context) {
-        DEVICE_LOG_ERROR("Failed to allocate memory for the npu_device_context");
-        return AEE_ENOMEMORY;
-    }
-
+    auto * context = new npu_device_context();
     if (!context->init()) {
         DEVICE_LOG_ERROR("Failed to initialize npu_device_context");
         delete context;
@@ -144,12 +146,7 @@ AEEResult npu_device_device_support_op(remote_handle64 _h, npu_device_tensor_op
 AEEResult npu_device_tensor_init(remote_handle64 _h, const npu_device_tensor_config * info,
                                  npu_device_tensor_handle_t * tensor_handle) {
     NPU_UNUSED(_h);
-    auto * tensor = new (std::nothrow) hexagon::tensor(*info);
-    if (!tensor) {
-        DEVICE_LOG_ERROR("Failed to allocate memory for the tensor");
-        return AEE_ENOMEMORY;
-    }
-
+    auto * tensor = new hexagon::tensor(*info);
     *tensor_handle = tensor_to_handle(tensor);
     return AEE_SUCCESS;
 }
@@ -177,13 +174,29 @@ AEEResult npu_device_tensor_free(remote_handle64 _h, npu_device_tensor_handle_t
     return AEE_SUCCESS;
 }
 
-AEEResult npu_device_graph_init(remote_handle64 _h, npu_device_graph_handle_t * graph_handle) {
+AEEResult npu_device_tensors_free(remote_handle64 _h, const npu_device_tensor_handle_t * tensor_handles,
+                                  int tensor_handlesLen) {
     NPU_UNUSED(_h);
-    auto * graph = new (std::nothrow) hexagon::graph();
-    if (!graph) {
-        return AEE_ENOMEMORY;
+    if (!tensor_handles || tensor_handlesLen < 0) {
+        DEVICE_LOG_ERROR("npu_device_tensors_free: Invalid arguments");
+        return AEE_EINVARGS;
     }
 
+    for (int i = 0; i < tensor_handlesLen; ++i) {
+        auto * tensor = tensor_from_handle(tensor_handles[i]);
+        if (tensor) {
+            delete tensor;
+        } else {
+            DEVICE_LOG_ERROR("npu_device_tensors_free: Invalid tensor handle at index %d", i);
+        }
+    }
+
+    return AEE_SUCCESS;
+}
+
+AEEResult npu_device_graph_init(remote_handle64 _h, npu_device_graph_handle_t * graph_handle) {
+    NPU_UNUSED(_h);
+    auto * graph = new hexagon::graph();
    *graph_handle = graph_to_handle(graph);
    return AEE_SUCCESS;
 }
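The handle helpers above now validate against explicit invalid-handle constants instead of blindly casting whatever value arrives over the RPC boundary. A minimal, self-contained sketch of the same round-trip pattern follows; the names and the sentinel value are illustrative only, not the actual npu_device API:

#include <cassert>
#include <cstdint>

// Hypothetical handle type and invalid-handle sentinel, mirroring the pattern
// the hunks above adopt for tensor and graph handles.
using device_handle_t = std::uintptr_t;
constexpr device_handle_t kInvalidHandle = 0;  // assumed sentinel value

struct object {
    int value = 42;
};

inline device_handle_t to_handle(object * obj) {
    // The handle is just the pointer value carried as an integer.
    return reinterpret_cast<device_handle_t>(obj);
}

inline object * from_handle(device_handle_t h) {
    // Reject the sentinel instead of blindly casting.
    if (h == kInvalidHandle) {
        return nullptr;
    }
    return reinterpret_cast<object *>(h);
}

int main() {
    object obj;
    assert(from_handle(to_handle(&obj)) == &obj);
    assert(from_handle(kInvalidHandle) == nullptr);
    return 0;
}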
@ -13,10 +13,19 @@ inline float f16_to_f32(const npu_device_fp16_t src) {
|
|||
}
|
||||
|
||||
// From: ggml/src/ggml-cpu/ops.cpp
|
||||
template <bool _IsKvF16>
|
||||
void flash_attn_impl(hexagon::tensor * out, const hexagon::tensor * q, const hexagon::tensor * k,
|
||||
const hexagon::tensor * v, const hexagon::tensor * mask, hexagon::compute_params * params) {
|
||||
static_assert(3 <= hexagon::kMaxParamsCount, "flash_attn op params count exceeds max params count");
|
||||
|
||||
constexpr const npu_device_tensor_data_type kKvDataType = _IsKvF16 ? NPU_DATA_TYPE_F16 : NPU_DATA_TYPE_F32;
|
||||
|
||||
if (k->get_type() != kKvDataType || v->get_type() != k->get_type()) {
|
||||
DEVICE_LOG_ERROR("flash_attn_impl: k and v must have same type, got k: %s, v: %s\n",
|
||||
hexagon::get_type_name(k->get_type()), hexagon::get_type_name(v->get_type()));
|
||||
return;
|
||||
}
|
||||
|
||||
float scale = out->get_op_param<float>(0);
|
||||
const float max_bias = out->get_op_param<float>(1);
|
||||
const float logit_softcap = out->get_op_param<float>(2);
|
||||
|
|
@ -37,9 +46,11 @@ void flash_attn_impl(hexagon::tensor * out, const hexagon::tensor * q, const hex
|
|||
const float m0 = powf(2.0f, -(max_bias) / n_head_log2);
|
||||
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
||||
|
||||
const auto q_to_vec_dot = hexagon::get_type_traits(k->get_type()).from_float; // TODO: fix this
|
||||
const auto kq_vec_dot = hexagon::get_type_traits(k->get_type()).vec_dot;
|
||||
if (!q_to_vec_dot || !kq_vec_dot) {
|
||||
const auto & k_type_traits = hexagon::get_type_traits(kKvDataType);
|
||||
const auto q_to_vec_dot = k_type_traits.from_float;
|
||||
constexpr const auto kq_vec_dot = _IsKvF16 ? hexagon::type_erase_dot_func<hexagon::vec_dot_product_f16_f16> :
|
||||
hexagon::type_erase_dot_func<hexagon::vec_dot_product_f32_f32>;
|
||||
if (!q_to_vec_dot) {
|
||||
DEVICE_LOG_ERROR("flash_attn_impl: unsupported data type for q, k, or v\n");
|
||||
return;
|
||||
}
|
||||
|
|
@ -50,12 +61,12 @@ void flash_attn_impl(hexagon::tensor * out, const hexagon::tensor * q, const hex
|
|||
const auto DK = k->get_ne(0);
|
||||
const auto DV = v->get_ne(0);
|
||||
const auto row_bytes_q = q->get_ne(0) * hexagon::get_type_traits(q->get_type()).type_size;
|
||||
const auto row_bytes_k = DK * hexagon::get_type_traits(k->get_type()).type_size;
|
||||
const auto row_bytes_k = DK * k_type_traits.type_size;
|
||||
const auto row_bytes_v = DV * hexagon::get_type_traits(v->get_type()).type_size;
|
||||
|
||||
constexpr const size_t kFloatsPerVector = hexagon::kBytesPerVector / sizeof(float);
|
||||
const auto aligned_dk = (DK + kFloatsPerVector - 1) / kFloatsPerVector * kFloatsPerVector;
|
||||
const auto aligned_dv = (DV + kFloatsPerVector - 1) / kFloatsPerVector * kFloatsPerVector;
|
||||
constexpr const size_t kFloatsPerVectorPair = hexagon::kBytesPerVector * 2 / sizeof(float);
|
||||
const auto aligned_dk = (DK + kFloatsPerVectorPair - 1) / kFloatsPerVectorPair * kFloatsPerVectorPair;
|
||||
const auto aligned_dv = (DV + kFloatsPerVectorPair - 1) / kFloatsPerVectorPair * kFloatsPerVectorPair;
|
||||
size_t total_cache_size = sizeof(float) * (aligned_dk + 2 * aligned_dv);
|
||||
auto * cache_ptr = params->get_vtcm_cache(total_cache_size);
|
||||
if (!cache_ptr) {
|
||||
|
|
@ -64,11 +75,10 @@ void flash_attn_impl(hexagon::tensor * out, const hexagon::tensor * q, const hex
|
|||
}
|
||||
|
||||
// loop over n_batch and n_head
|
||||
const auto rows_per_batch = q->get_ne(2) * q->get_ne(1);
|
||||
const auto out_rows_per_batch = out->get_ne(2) * out->get_ne(1);
|
||||
const bool is_v_f16 =
|
||||
v->get_type() == NPU_DATA_TYPE_F16; // check if V is in FP16 format, otherwise it is in FP32 format
|
||||
uint8_t * dst_ptr = out->get_write_buffer();
|
||||
constexpr bool is_v_f16 = _IsKvF16; // check if V is in FP16 format, otherwise it is in FP32 format
|
||||
const auto rows_per_batch = q->get_ne(2) * q->get_ne(1);
|
||||
const auto out_rows_per_batch = out->get_ne(2) * out->get_ne(1);
|
||||
uint8_t * dst_ptr = out->get_write_buffer();
|
||||
if (!dst_ptr) {
|
||||
DEVICE_LOG_ERROR("flash_attn_impl: dst_ptr is not writable, tensor: %p, type: %s\n", (void *) out,
|
||||
hexagon::get_type_name(out->get_type()));
|
||||
|
|
@ -80,6 +90,10 @@ void flash_attn_impl(hexagon::tensor * out, const hexagon::tensor * q, const hex
|
|||
const uint8_t * k_ptr = k->get_read_buffer();
|
||||
const uint8_t * v_ptr = v->get_read_buffer();
|
||||
const uint8_t * mask_ptr = mask ? mask->get_read_buffer() : nullptr;
|
||||
float * VKQ32 = reinterpret_cast<float *>(cache_ptr); // FP32 VKQ accumulator
|
||||
auto * VKQ16 = reinterpret_cast<npu_device_fp16_t *>(VKQ32 + aligned_dv); // (temporary) FP16 VKQ accumulator
|
||||
auto * Q_q = reinterpret_cast<npu_device_fp16_t *>(
|
||||
VKQ32 + 2 * aligned_dv); // (temporary) buffer for Q converted to quantized/FP16
|
||||
for (auto ir = start_end_row.first; ir < start_end_row.second; ++ir) {
|
||||
// q indices
|
||||
const auto iq3 = ir / rows_per_batch;
|
||||
|
|
@ -90,15 +104,13 @@ void flash_attn_impl(hexagon::tensor * out, const hexagon::tensor * q, const hex
|
|||
const float slope =
|
||||
(max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2 * (h - n_head_log2) + 1) : 1.0f;
|
||||
|
||||
float S = 0.0f; // sum
|
||||
float M = -INFINITY; // maximum KQ value
|
||||
float S = 0.0f; // sum
|
||||
float M = -INFINITY; // maximum KQ value
|
||||
|
||||
float * VKQ32 = reinterpret_cast<float *>(cache_ptr); // FP32 VKQ accumulator
|
||||
auto * VKQ16 = reinterpret_cast<npu_device_fp16_t *>(VKQ32 + aligned_dv); // (temporary) FP16 VKQ accumulator
|
||||
auto * Q_q = reinterpret_cast<npu_device_fp16_t *>(
|
||||
VKQ32 + 2 * aligned_dv); // (temporary) buffer for Q converted to quantized/FP16
|
||||
const auto * q_data = q_ptr + (iq1 * q->get_nb(1) + iq2 * q->get_nb(2) + iq3 * q->get_nb(3));
|
||||
hexagon::l2fetch_row(q_data, row_bytes_q);
|
||||
|
||||
if (is_v_f16) {
|
||||
if constexpr (is_v_f16) {
|
||||
memset(VKQ16, 0, DV * sizeof(npu_device_fp16_t));
|
||||
} else {
|
||||
memset(VKQ32, 0, DV * sizeof(float));
|
||||
|
|
@ -117,16 +129,13 @@ void flash_attn_impl(hexagon::tensor * out, const hexagon::tensor * q, const hex
|
|||
const int iv3 = iq3 / rv3;
|
||||
const int iv2 = iq2 / rv2;
|
||||
|
||||
const auto * q_data = q_ptr + (iq1 * q->get_nb(1) + iq2 * q->get_nb(2) + iq3 * q->get_nb(3));
|
||||
if (iq1 < q->get_ne(1) - 1) {
|
||||
hexagon::l2fetch_row(q_data + q->get_nb(1), row_bytes_q);
|
||||
}
|
||||
|
||||
q_to_vec_dot(reinterpret_cast<const float *>(q_data), Q_q, DK);
|
||||
|
||||
// online softmax / attention
|
||||
// loop over n_kv and n_head_kv
|
||||
// ref: https://arxiv.org/pdf/2112.05682.pdf
|
||||
const auto * k_plane_ptr = k_ptr + ik2 * k->get_nb(2) + ik3 * k->get_nb(3);
|
||||
const auto * v_plane_ptr = v_ptr + iv2 * v->get_nb(2) + iv3 * v->get_nb(3);
|
||||
for (int64_t ic = 0; ic < k->get_ne(1); ++ic) {
|
||||
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(flash_attn, 0, loop);
|
||||
float mv = mp ? (slope * f16_to_f32(mp[ic])) : 0.0f;
|
||||
|
|
@ -137,7 +146,7 @@ void flash_attn_impl(hexagon::tensor * out, const hexagon::tensor * q, const hex
|
|||
float s = 0.f;
|
||||
{
|
||||
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(flash_attn, 1, kq_dot);
|
||||
const auto * k_data = k_ptr + (ic * k->get_nb(1) + ik2 * k->get_nb(2) + ik3 * k->get_nb(3));
|
||||
const auto * k_data = k_plane_ptr + ic * k->get_nb(1);
|
||||
if (ic < k->get_ne(1) - 1) {
|
||||
hexagon::l2fetch_row(k_data + k->get_nb(1), row_bytes_k);
|
||||
}
|
||||
|
|
@ -156,12 +165,12 @@ void flash_attn_impl(hexagon::tensor * out, const hexagon::tensor * q, const hex
|
|||
float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value
|
||||
float vs = 1.0f; // post-softmax KQ value, expf(s - M)
|
||||
|
||||
const auto * v_data = v_ptr + (ic * v->get_nb(1) + iv2 * v->get_nb(2) + iv3 * v->get_nb(3));
|
||||
const auto * v_data = v_plane_ptr + ic * v->get_nb(1);
|
||||
if (ic < v->get_ne(1)) {
|
||||
hexagon::l2fetch_row(v_data, row_bytes_v);
|
||||
}
|
||||
|
||||
if (is_v_f16) {
|
||||
if constexpr (is_v_f16) {
|
||||
if (s > M) {
|
||||
// s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
|
||||
M = s;
|
||||
|
|
@ -201,7 +210,7 @@ void flash_attn_impl(hexagon::tensor * out, const hexagon::tensor * q, const hex
|
|||
S = S * ms + vs; // scale and increment sum with partial sum
|
||||
}
|
||||
|
||||
if (is_v_f16) {
|
||||
if constexpr (is_v_f16) {
|
||||
// TODO: use a more efficient conversion
|
||||
for (int64_t d = 0; d < DV; ++d) {
|
||||
VKQ32[d] = f16_to_f32(VKQ16[d]);
|
||||
|
|
@ -218,7 +227,10 @@ void flash_attn_impl(hexagon::tensor * out, const hexagon::tensor * q, const hex
|
|||
const int i3 = iq3;
|
||||
|
||||
// permute(0, 2, 1, 3)
|
||||
memcpy(dst_ptr + (i3 * out_rows_per_batch + i2 + i1 * out->get_ne(1)) * out->get_nb(1), VKQ32, out->get_nb(1));
|
||||
hexagon::vec_cpy_f32(
|
||||
reinterpret_cast<const float *>(VKQ32),
|
||||
reinterpret_cast<float *>(dst_ptr + (i3 * out_rows_per_batch + i2 + i1 * out->get_ne(1)) * out->get_nb(1)),
|
||||
out->get_ne(0));
|
||||
}
|
||||
|
||||
out->release_write_buffer(); // mark the output tensor as modified
|
||||
|
|
@@ -244,7 +256,11 @@ bool flash_attn_f32(tensor * out, compute_params * params) {
         return false;
     }
 
-    flash_attn_impl(out, q, k, v, mask, params);
+    if (k->get_type() == NPU_DATA_TYPE_F16) {
+        flash_attn_impl<true>(out, q, k, v, mask, params);
+    } else {
+        flash_attn_impl<false>(out, q, k, v, mask, params);
+    }
     return true;
 }
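flash_attn_impl is now templated on the KV data type, so the F16/F32 choice is resolved at compile time inside the kernel (via if constexpr) and only the dispatch above checks the runtime type. A compact sketch of that pattern, with made-up names:

#include <cstdio>

enum class data_type { f32, f16 };

template <bool IsKvF16>
void attn_kernel(int rows) {
    // The branch is compiled away; each instantiation contains only one path.
    if constexpr (IsKvF16) {
        printf("f16 kernel, %d rows\n", rows);
    } else {
        printf("f32 kernel, %d rows\n", rows);
    }
}

void dispatch(data_type kv_type, int rows) {
    // The runtime check happens once, outside the hot loop.
    if (kv_type == data_type::f16) {
        attn_kernel<true>(rows);
    } else {
        attn_kernel<false>(rows);
    }
}

int main() {
    dispatch(data_type::f16, 8);
    dispatch(data_type::f32, 8);
    return 0;
}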
|
||||
|
|
|
|||
|
|
@ -12,64 +12,10 @@
|
|||
|
||||
namespace {
|
||||
|
||||
template <HVX_Vector (*_OpIntrinsic)(HVX_Vector, HVX_Vector), typename _TyData>
|
||||
inline void vec_op_impl(const _TyData * src0, const _TyData * src1, size_t count, _TyData * dst) {
|
||||
constexpr const size_t kElementsPerVector = hexagon::kBytesPerVector / sizeof(_TyData);
|
||||
|
||||
HVX_Vector * iptr0 = ((HVX_Vector *) src0);
|
||||
HVX_Vector * const iptr0_end = ((HVX_Vector *) src0) + (count / kElementsPerVector);
|
||||
HVX_Vector * iptr1 = ((HVX_Vector *) src1);
|
||||
HVX_Vector * optr = ((HVX_Vector *) dst); // framework will ensure the dst is aligned
|
||||
HVX_Vector prev0 = *iptr0++;
|
||||
HVX_Vector prev1 = *iptr1++;
|
||||
|
||||
while (iptr0 < iptr0_end) {
|
||||
HVX_Vector curr0 = *iptr0++;
|
||||
HVX_Vector curr1 = *iptr1++;
|
||||
HVX_Vector s0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0);
|
||||
HVX_Vector s1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1);
|
||||
*optr++ = _OpIntrinsic(s0, s1);
|
||||
prev0 = curr0;
|
||||
prev1 = curr1;
|
||||
}
|
||||
|
||||
const size_t leftover = count % kElementsPerVector;
|
||||
if ((iptr0_end - ((HVX_Vector *) src0)) > 0) {
|
||||
// handle the last vector
|
||||
// see also:
|
||||
// https://github.com/UbiquitousLearning/mllm/blob/babf4410352ce8730824c87699c025a0d4ce3a6f/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/LLaMAMul.cpp#L147
|
||||
// or qualcomm sdk libs\qhl_hvx\src\qhblas_hvx\qhblas_hvx_aw_vector_add_ah.c
|
||||
bool should_fetch_src0 = leftover != 0 || !hexagon::is_addr_aligned(iptr0);
|
||||
bool should_fetch_src1 = leftover != 0 || !hexagon::is_addr_aligned(iptr1);
|
||||
HVX_Vector curr0 = should_fetch_src0 ? *iptr0 : prev0;
|
||||
HVX_Vector curr1 = should_fetch_src1 ? *iptr1 : prev1;
|
||||
iptr0 += should_fetch_src0 ? 1 : 0;
|
||||
iptr1 += should_fetch_src1 ? 1 : 0;
|
||||
HVX_Vector s0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0);
|
||||
HVX_Vector s1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1);
|
||||
*optr++ = _OpIntrinsic(s0, s1);
|
||||
prev0 = curr0;
|
||||
prev1 = curr1;
|
||||
}
|
||||
|
||||
const size_t leftover_bytes = leftover * sizeof(_TyData);
|
||||
if (leftover > 0) {
|
||||
// handle the leftover elements
|
||||
HVX_Vector curr0 =
|
||||
(leftover_bytes + hexagon::unaligned_bytes(iptr0) > hexagon::kBytesPerVector) ? *iptr0 : prev0;
|
||||
curr0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0);
|
||||
|
||||
HVX_Vector curr1 =
|
||||
(leftover_bytes + hexagon::unaligned_bytes(iptr1) > hexagon::kBytesPerVector) ? *iptr1 : prev1;
|
||||
curr1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1);
|
||||
|
||||
hexagon::q6op_vstu_variable_ARV(optr, leftover_bytes, _OpIntrinsic(curr0, curr1));
|
||||
}
|
||||
}
|
||||
|
||||
template <HVX_Vector (*_OpIntrinsic)(HVX_Vector, HVX_Vector)>
|
||||
template <HVX_Vector (*_OpBinaryTransform)(HVX_Vector, HVX_Vector)>
|
||||
inline void vec_op_f32_f32(const float * src0, const float * src1, size_t count, float * dst) {
|
||||
vec_op_impl<_OpIntrinsic, float>(src0, src1, count, dst);
|
||||
using namespace hexagon::vec;
|
||||
vec_trans_op_impl<_OpBinaryTransform, float>(src0, src1, count, dst);
|
||||
}
|
||||
|
||||
inline HVX_Vector vadd_f32_f32(HVX_Vector a, HVX_Vector b) {
|
||||
|
|
@ -84,10 +30,11 @@ inline HVX_Vector vmul_f32_f32(HVX_Vector a, HVX_Vector b) {
|
|||
return Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b));
|
||||
}
|
||||
|
||||
template <HVX_Vector (*_OpIntrinsic)(HVX_Vector, HVX_Vector)>
|
||||
template <HVX_Vector (*_OpBinaryTransform)(HVX_Vector, HVX_Vector)>
|
||||
inline void vec_op_f16_f16(const npu_device_fp16_t * src0, const npu_device_fp16_t * src1, size_t count,
|
||||
npu_device_fp16_t * dst) {
|
||||
vec_op_impl<_OpIntrinsic, npu_device_fp16_t>(src0, src1, count, dst);
|
||||
using namespace hexagon::vec;
|
||||
vec_trans_op_impl<_OpBinaryTransform, npu_device_fp16_t>(src0, src1, count, dst);
|
||||
}
|
||||
|
||||
inline HVX_Vector vadd_f16_f16(HVX_Vector a, HVX_Vector b) {
|
||||
|
|
@ -252,10 +199,10 @@ void rms_norm_vec_f32(const float * src, size_t count, float eps, float * dst) {
|
|||
prev = curr;
|
||||
}
|
||||
|
||||
const size_t leftover_bytes = leftover * sizeof(float);
|
||||
if (leftover > 0) {
|
||||
// handle the leftover elements
|
||||
HVX_Vector curr =
|
||||
const size_t leftover_bytes = leftover * sizeof(float);
|
||||
HVX_Vector curr =
|
||||
(leftover_bytes + hexagon::unaligned_bytes(src_vec_ptr) > hexagon::kBytesPerVector) ? *src_vec_ptr : prev;
|
||||
curr = Q6_V_valign_VVR(curr, prev, (size_t) src);
|
||||
sum = Q6_Vqf32_vadd_Vqf32Vqf32(sum,
|
||||
|
|
|
|||
|
|
@ -6,26 +6,37 @@
|
|||
|
||||
namespace {
|
||||
|
||||
template <typename T> struct get_data_type {};
|
||||
template <typename _T> struct get_data_type {};
|
||||
|
||||
template <typename _TyData0, typename _TyData1>
|
||||
struct get_data_type<float (*)(const _TyData0 *, const _TyData1 *, size_t)> {
|
||||
using data_type0 = _TyData0;
|
||||
using data_type1 = _TyData1;
|
||||
template <typename _TData0, typename _TData1>
|
||||
struct get_data_type<HVX_Vector (*)(const _TData0 *, const _TData1 *, size_t)> {
|
||||
using data_type0 = _TData0;
|
||||
using data_type1 = _TData1;
|
||||
};
|
||||
|
||||
template <auto _DotFunc, bool _IsQuantized>
|
||||
template <typename _TRet> struct convert_vector {};
|
||||
|
||||
template <> struct convert_vector<float> {
|
||||
static float convert(HVX_Vector vec) { return hexagon::get_flt0_from_fltv(Q6_Vsf_equals_Vqf32(vec)); }
|
||||
};
|
||||
|
||||
template <> struct convert_vector<npu_device_fp16_t> {
|
||||
static float convert(HVX_Vector vec) {
|
||||
HVX_Vector vect = Q6_Vhf_equals_Vqf16(vec);
|
||||
uint16_t i = (vect[0] & 0xffff);
|
||||
return reinterpret_cast<__fp16 &>(i);
|
||||
}
|
||||
};
|
||||
|
||||
template <auto _DotFunc, bool _ShouldCacheSrc0>
|
||||
void mul_mat_impl(hexagon::tensor * src0, hexagon::tensor * src1, hexagon::tensor * dst,
|
||||
hexagon::compute_params * params) {
|
||||
using data_type0 = typename get_data_type<decltype(_DotFunc)>::data_type0;
|
||||
using data_type1 = typename get_data_type<decltype(_DotFunc)>::data_type1;
|
||||
|
||||
static_assert(!_IsQuantized || std::is_same_v<data_type0, hexagon::dequant_target_type>,
|
||||
"data_type0 must be the same as hexagon::dequant_target_type");
|
||||
|
||||
const auto src0_actual_row_size = hexagon::get_dequantized_row_size(src0);
|
||||
auto * dequantize_row_func = hexagon::get_type_traits(src0->get_type()).to_float;
|
||||
if (_IsQuantized && dequantize_row_func == nullptr) {
|
||||
if (_ShouldCacheSrc0 && dequantize_row_func == nullptr) {
|
||||
DEVICE_LOG_ERROR("Unsupported quantized src0 type: %d, dequantize_row_func is null\n", src0->get_type());
|
||||
return;
|
||||
}
|
||||
|
|
@ -61,7 +72,7 @@ void mul_mat_impl(hexagon::tensor * src0, hexagon::tensor * src1, hexagon::tenso
|
|||
size_t src0_plane_cache_size = 0;
|
||||
uint8_t * src0_plane_cache_ptr = nullptr;
|
||||
const uint8_t * last_cached_plane_ptr = nullptr;
|
||||
if constexpr (_IsQuantized) {
|
||||
if constexpr (_ShouldCacheSrc0) {
|
||||
src0_plane_slice_row_count =
|
||||
std::min(params->get_vtcm_quota_size() / src0_actual_row_size, src0_plane_slice_row_count);
|
||||
src0_plane_cache_size = src0_actual_row_size * src0_plane_slice_row_count;
|
||||
|
|
@ -78,11 +89,12 @@ void mul_mat_impl(hexagon::tensor * src0, hexagon::tensor * src1, hexagon::tenso
|
|||
DEVICE_LOG_DEBUG(
|
||||
"mul_mat_impl src0_actual_row_size: %zu, src0_plane_slice_row_count: %zu, is_quantized: %d, vtcm_mem: "
|
||||
"%p(%zu)\n",
|
||||
src0_actual_row_size, src0_plane_slice_row_count, _IsQuantized, (void *) src0_plane_cache_ptr,
|
||||
src0_actual_row_size, src0_plane_slice_row_count, _ShouldCacheSrc0, (void *) src0_plane_cache_ptr,
|
||||
src0_plane_cache_size);
|
||||
|
||||
const size_t valid_row0_bytes = src0->get_ne(0) * sizeof(data_type0);
|
||||
const size_t valid_row1_bytes = src1->get_ne(0) * sizeof(data_type1);
|
||||
const size_t valid_row1_bytes =
|
||||
src0->get_ne(0) * sizeof(data_type1); // src0 and src1 should have the same element count in the 1st dimension
|
||||
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_WITH_MULTI_SUB_PROC(dst, params->get_thread_index(), mul_mat);
|
||||
|
||||
uint8_t * dst_ptr = dst->get_write_buffer();
|
||||
|
|
@ -92,7 +104,7 @@ void mul_mat_impl(hexagon::tensor * src0, hexagon::tensor * src1, hexagon::tenso
|
|||
return;
|
||||
}
|
||||
|
||||
constexpr bool should_fetch_src0_row = !_IsQuantized;
|
||||
constexpr bool should_fetch_src0_row = !_ShouldCacheSrc0;
|
||||
const uint8_t * src0_ptr = src0->get_read_buffer();
|
||||
const uint8_t * src1_ptr = src1->get_read_buffer();
|
||||
for (int64_t ip = start_end_plane.first; ip < start_end_plane.second; ip++) {
|
||||
|
|
@ -102,24 +114,24 @@ void mul_mat_impl(hexagon::tensor * src0, hexagon::tensor * src1, hexagon::tenso
|
|||
auto * dst_plane = dst_ptr + i3 * dst->get_nb(3) + i2 * dst->get_nb(2);
|
||||
for (int64_t col_idx = start_end_element.first; col_idx < start_end_element.second;
|
||||
col_idx += src0_plane_slice_row_count) {
|
||||
const auto actual_row_count =
|
||||
const int64_t actual_row_count =
|
||||
std::min<int64_t>(src0_plane_slice_row_count,
|
||||
start_end_element.second - col_idx); // number of rows in this slice
|
||||
const uint8_t * src0_plane =
|
||||
src0_ptr + i3 / r03 * src0->get_nb(3) + i2 / r02 * src0->get_nb(2) + col_idx * src0->get_nb(1);
|
||||
if constexpr (_IsQuantized) {
|
||||
if constexpr (_ShouldCacheSrc0) {
|
||||
if (last_cached_plane_ptr != src0_plane) {
|
||||
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 0, dequant);
|
||||
|
||||
for (int64_t ir = 0; ir < (int64_t) actual_row_count; ir++) {
|
||||
hexagon::l2fetch_row(src0_plane, src0->get_nb(1));
|
||||
for (int64_t ir = 0; ir < actual_row_count; ir++) {
|
||||
auto * src0_row = src0_plane + ir * src0->get_nb(1);
|
||||
if (ir + 1 < actual_row_count) {
|
||||
hexagon::l2fetch_row(src0_row + src0->get_nb(1), src0->get_nb(1));
|
||||
}
|
||||
|
||||
auto * dst_row = reinterpret_cast<hexagon::dequant_target_type *>(src0_plane_cache_ptr +
|
||||
ir * src0_actual_row_size);
|
||||
dequantize_row_func(src0_row, reinterpret_cast<hexagon::dequant_target_type *>(dst_row),
|
||||
auto * cached_row_ptr = src0_plane_cache_ptr + ir * src0_actual_row_size;
|
||||
dequantize_row_func(src0_row, reinterpret_cast<hexagon::dequant_target_type *>(cached_row_ptr),
|
||||
src0->get_ne(0));
|
||||
}
|
||||
|
||||
|
|
@ -138,34 +150,45 @@ void mul_mat_impl(hexagon::tensor * src0, hexagon::tensor * src1, hexagon::tenso
|
|||
auto * src1_row = src1_plane + i1 * src1->get_nb(1);
|
||||
auto * dst_row = reinterpret_cast<float *>(dst_plane + i1 * dst->get_nb(1)) + col_idx;
|
||||
int64_t i0 = 0;
|
||||
for (; i0 + 1 < (int64_t) actual_row_count; i0 += 2) {
|
||||
for (; i0 + 1 < actual_row_count; i0 += 2) {
|
||||
auto * src0_row = src0_plane + i0 * src0_actual_row_size;
|
||||
if constexpr (should_fetch_src0_row) {
|
||||
hexagon::l2fetch_row(src0_row + src0_actual_row_size, valid_row0_bytes);
|
||||
}
|
||||
|
||||
// TODO: figure dst how to handle a entire row
|
||||
dst_row[i0] = _DotFunc(reinterpret_cast<const data_type0 *>(src0_row),
|
||||
reinterpret_cast<const data_type1 *>(src1_row), (size_t) src0->get_ne(0));
|
||||
auto res0 = _DotFunc(reinterpret_cast<const data_type0 *>(src0_row),
|
||||
reinterpret_cast<const data_type1 *>(src1_row), (size_t) src0->get_ne(0));
|
||||
|
||||
if (should_fetch_src0_row && i0 + 2 < (int64_t) actual_row_count) {
|
||||
{
|
||||
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 2, store);
|
||||
dst_row[i0] = convert_vector<data_type1>::convert(res0);
|
||||
}
|
||||
|
||||
if (should_fetch_src0_row && i0 + 2 < actual_row_count) {
|
||||
hexagon::l2fetch_row(src0_row + src0_actual_row_size + src0_actual_row_size, valid_row0_bytes);
|
||||
}
|
||||
|
||||
// TODO: figure dst how to handle a entire row
|
||||
dst_row[i0 + 1] =
|
||||
_DotFunc(reinterpret_cast<const data_type0 *>(src0_row + src0_actual_row_size),
|
||||
reinterpret_cast<const data_type1 *>(src1_row), (size_t) src0->get_ne(0));
|
||||
auto res1 = _DotFunc(reinterpret_cast<const data_type0 *>(src0_row + src0_actual_row_size),
|
||||
reinterpret_cast<const data_type1 *>(src1_row), (size_t) src0->get_ne(0));
|
||||
|
||||
{
|
||||
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 2, store);
|
||||
dst_row[i0 + 1] = convert_vector<data_type1>::convert(res1);
|
||||
}
|
||||
}
|
||||
|
||||
if (ip + 1 < start_end_plane.second) {
|
||||
hexagon::l2fetch_row(src1_row + src1->get_nb(1), valid_row1_bytes);
|
||||
}
|
||||
|
||||
if (i0 < (int64_t) actual_row_count) {
|
||||
if (i0 < actual_row_count) {
|
||||
auto * src0_row = src0_plane + i0 * src0_actual_row_size;
|
||||
dst_row[i0] = _DotFunc(reinterpret_cast<const data_type0 *>(src0_row),
|
||||
auto res = _DotFunc(reinterpret_cast<const data_type0 *>(src0_row),
|
||||
reinterpret_cast<const data_type1 *>(src1_row), (size_t) src0->get_ne(0));
|
||||
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 2, store);
|
||||
dst_row[i0] = convert_vector<data_type1>::convert(res);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -174,6 +197,25 @@ void mul_mat_impl(hexagon::tensor * src0, hexagon::tensor * src1, hexagon::tenso
|
|||
dst->release_write_buffer(); // mark the output tensor as modified
|
||||
}
|
||||
|
||||
bool is_row_size_cacheable(const npu_device_tensor_spec & src) {
|
||||
const auto & type_traits = hexagon::get_type_traits(src.type);
|
||||
if (type_traits.to_float == nullptr) {
|
||||
DEVICE_LOG_DEBUG("[MUL_MAT]src.type(%s) cannot be cached, to_float is null\n",
|
||||
hexagon::get_type_name(src.type));
|
||||
return false;
|
||||
}
|
||||
|
||||
const size_t type_size = type_traits.is_quantized ? sizeof(hexagon::dequant_target_type) : type_traits.type_size;
|
||||
const auto vtcm_thread_quota_size = hexagon::default_thread_pool::get_per_thread_vtcm_quota();
|
||||
if (src.ne[0] * type_size > vtcm_thread_quota_size) {
|
||||
DEVICE_LOG_DEBUG("[MUL_MAT]src.type(%s) ne[0] is too large: %ld, vtcm_thread_quota_size: %zu\n",
|
||||
hexagon::get_type_name(src.type), (long) src.ne[0], vtcm_thread_quota_size);
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool is_quantized_mul_mat_supported(const npu_device_tensor_spec & src0, const npu_device_tensor_spec & src1) {
|
||||
if (src1.type != NPU_DATA_TYPE_F32 && src1.type != NPU_DATA_TYPE_F16) {
|
||||
DEVICE_LOG_DEBUG("[MUL_MAT]src0.type(%s) and src1.type(%s) mismatch and src1 is not F32\n",
|
||||
|
|
@ -194,10 +236,7 @@ bool is_quantized_mul_mat_supported(const npu_device_tensor_spec & src0, const n
|
|||
return false;
|
||||
}
|
||||
|
||||
const auto vtcm_thread_quota_size = hexagon::default_thread_pool::get_per_thread_vtcm_quota();
|
||||
if (src0.ne[0] * sizeof(hexagon::dequant_target_type) > vtcm_thread_quota_size) {
|
||||
DEVICE_LOG_DEBUG("[MUL_MAT]src0.type(%s) ne[0] is too large: %ld, vtcm_thread_quota_size: %zu\n",
|
||||
hexagon::get_type_name(src0.type), (long) src0.ne[0], vtcm_thread_quota_size);
|
||||
if (!is_row_size_cacheable(src0)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -208,9 +247,8 @@ bool is_quantized_mul_mat_supported(const npu_device_tensor_spec & src0, const n
|
|||
|
||||
bool is_mul_mat_f16_f32_src_tensors_aligned(hexagon::tensor * src0, hexagon::tensor * src1, bool is_src0_quantized) {
|
||||
const auto * src1_ptr = src1->get_read_buffer_as<float>();
|
||||
const auto * src0_ptr = is_src0_quantized ?
|
||||
src1->get_read_buffer_as<npu_device_fp16_t>() :
|
||||
src0->get_read_buffer_as<npu_device_fp16_t>(); // skip src0 for quantized tensors
|
||||
const auto * src0_ptr =
|
||||
is_src0_quantized ? nullptr : src0->get_read_buffer_as<npu_device_fp16_t>(); // skip src0 for quantized tensors
|
||||
|
||||
if (!hexagon::is_f16_f32_dot_product_aligned(src0_ptr, src1_ptr, src0->get_ne(0))) {
|
||||
DEVICE_LOG_DEBUG("[MUL_MAT]src_tensors_unaligned: ne[0]: %ld\n", (long) src0->get_ne(0));
|
||||
|
|
@ -223,13 +261,23 @@ bool is_mul_mat_f16_f32_src_tensors_aligned(hexagon::tensor * src0, hexagon::ten
|
|||
|
||||
bool is_mul_mat_f16_f16_src_tensors_aligned(hexagon::tensor * src0, hexagon::tensor * src1, bool is_src0_quantized) {
|
||||
const auto * src1_ptr = src1->get_read_buffer_as<npu_device_fp16_t>();
|
||||
const auto * src0_ptr = is_src0_quantized ? src1_ptr : src0->get_read_buffer_as<npu_device_fp16_t>();
|
||||
const auto * src0_ptr = is_src0_quantized ? nullptr : src0->get_read_buffer_as<npu_device_fp16_t>();
|
||||
|
||||
if (!hexagon::is_f16_f16_dot_product_aligned(src0_ptr, src1_ptr, src0->get_ne(0))) {
|
||||
DEVICE_LOG_DEBUG("[MUL_MAT]src_tensors_unaligned: ne[0]: %ld\n", (long) src0->get_ne(0));
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!is_src0_quantized && !hexagon::is_size_aligned(src0->get_nb(1))) {
|
||||
DEVICE_LOG_DEBUG("[MUL_MAT]src0 tensor nb[1] is not aligned: %zu\n", src0->get_nb(1));
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!hexagon::is_size_aligned(src1->get_nb(1))) {
|
||||
DEVICE_LOG_DEBUG("[MUL_MAT]src1 tensor nb[1] is not aligned: %zu\n", src1->get_nb(1));
|
||||
return false;
|
||||
}
|
||||
|
||||
DEVICE_LOG_DEBUG("[MUL_MAT]src_tensors_aligned: ne[0]: %ld\n", (long) src0->get_ne(0));
|
||||
return true;
|
||||
}
|
||||
|
|
@ -243,6 +291,16 @@ bool is_mul_mat_f32_f32_src_tensors_aligned(hexagon::tensor * src0, hexagon::ten
|
|||
return false;
|
||||
}
|
||||
|
||||
if (!hexagon::is_size_aligned(src0->get_nb(1))) {
|
||||
DEVICE_LOG_DEBUG("[MUL_MAT]src0 tensor nb[1] is not aligned: %zu\n", src0->get_nb(1));
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!hexagon::is_size_aligned(src1->get_nb(1))) {
|
||||
DEVICE_LOG_DEBUG("[MUL_MAT]src1 tensor nb[1] is not aligned: %zu\n", src1->get_nb(1));
|
||||
return false;
|
||||
}
|
||||
|
||||
DEVICE_LOG_DEBUG("[MUL_MAT]src_tensors_aligned: ne[0]: %ld\n", (long) src0->get_ne(0));
|
||||
return true;
|
||||
}
|
||||
|
|
@ -250,30 +308,32 @@ bool is_mul_mat_f32_f32_src_tensors_aligned(hexagon::tensor * src0, hexagon::ten
|
|||
typedef void (*mul_mat_func_type)(hexagon::tensor * src0, hexagon::tensor * src1, hexagon::tensor * dst,
|
||||
hexagon::compute_params * params);
|
||||
|
||||
constexpr const mul_mat_func_type kMulMatF16F32Funcs[2][2] = {
|
||||
{
|
||||
// non-quantized
|
||||
mul_mat_impl<hexagon::vec_dot_product_f16_f32, false>, // F32 * F32 unaligned
|
||||
mul_mat_impl<hexagon::vec_dot_product_aligned_f16_f32, false>, // F32 * F32 aligned
|
||||
},
|
||||
{
|
||||
// quantized
|
||||
mul_mat_impl<hexagon::vec_dot_product_f16_f32, true>, // F32 * F32 quantized unaligned
|
||||
mul_mat_impl<hexagon::vec_dot_product_aligned_f16_f32, true>, // F32 * F32 quantized aligned
|
||||
},
|
||||
constexpr const mul_mat_func_type kMulMatF32F32CachedFuncs[2] = {
|
||||
// quantized and non-quantized
|
||||
mul_mat_impl<hexagon::vec_dot_product_vqf32_f32_f32, true>, // F32 * F32 quantized unaligned
|
||||
mul_mat_impl<hexagon::vec_dot_product_aligned_vqf32_f32_f32, true>, // F32 * F32 quantized aligned
|
||||
};
|
||||
|
||||
constexpr const mul_mat_func_type kMulMatF16Funcs[2][2] = {
|
||||
{
|
||||
// non-quantized
|
||||
mul_mat_impl<hexagon::vec_dot_product_f16_f16, false>, // F16 * F16 unaligned
|
||||
mul_mat_impl<hexagon::vec_dot_product_aligned_f16_f16, false>, // F16 * F16 aligned
|
||||
},
|
||||
{
|
||||
// quantized
|
||||
mul_mat_impl<hexagon::vec_dot_product_f16_f16, true>, // F16 * F16 quantized unaligned
|
||||
mul_mat_impl<hexagon::vec_dot_product_aligned_f16_f16, true>, // F16 * F16 quantized aligned
|
||||
},
|
||||
constexpr const mul_mat_func_type kMulMatF32F32Funcs[2] = {
|
||||
// quantized and non-quantized
|
||||
mul_mat_impl<hexagon::vec_dot_product_vqf32_f32_f32, false>, // F32 * F32 quantized unaligned
|
||||
mul_mat_impl<hexagon::vec_dot_product_aligned_vqf32_f32_f32, false>, // F32 * F32 quantized aligned
|
||||
};
|
||||
|
||||
constexpr const mul_mat_func_type kMulMatF16CachedFuncs[2] = {
|
||||
mul_mat_impl<hexagon::vec_dot_product_vqf16_f16_f16, true>, // F16 * F16 quantized unaligned
|
||||
mul_mat_impl<hexagon::vec_dot_product_aligned_vqf16_f16_f16, true>, // F16 * F16 quantized aligned
|
||||
};
|
||||
|
||||
constexpr const mul_mat_func_type kMulMatF16Funcs[2] = {
|
||||
mul_mat_impl<hexagon::vec_dot_product_vqf16_f16_f16, false>, // F16 * F16 quantized unaligned
|
||||
mul_mat_impl<hexagon::vec_dot_product_aligned_vqf16_f16_f16, false>, // F16 * F16 quantized aligned
|
||||
};
|
||||
|
||||
constexpr const mul_mat_func_type kMulMatF16F32Funcs[2] = {
|
||||
// quantized and non-quantized
|
||||
mul_mat_impl<hexagon::vec_dot_product_vqf32_f16_f32, true>, // F32 * F32 quantized unaligned
|
||||
mul_mat_impl<hexagon::vec_dot_product_aligned_vqf32_f16_f32, true>, // F32 * F32 quantized aligned
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
|
@ -297,22 +357,26 @@ bool mul_mat_f32(hexagon::tensor * out, compute_params * params) {
|
|||
}
|
||||
|
||||
const bool is_src0_quantized = is_quantized_type(src0->get_type());
|
||||
const bool should_cache_src0 = is_src0_quantized || src1->get_ne(1) > 1;
|
||||
switch (src1->get_type()) {
|
||||
case NPU_DATA_TYPE_F32:
|
||||
if (is_src0_quantized || src0->get_type() == NPU_DATA_TYPE_F16) {
|
||||
kMulMatF16F32Funcs[is_src0_quantized][is_mul_mat_f16_f32_src_tensors_aligned(
|
||||
src0, src1, is_src0_quantized)](src0, src1, out, params);
|
||||
kMulMatF16F32Funcs[is_mul_mat_f16_f32_src_tensors_aligned(src0, src1, is_src0_quantized)](src0, src1,
|
||||
out, params);
|
||||
} else if (should_cache_src0) {
|
||||
kMulMatF32F32CachedFuncs[is_mul_mat_f32_f32_src_tensors_aligned(src0, src1)](src0, src1, out, params);
|
||||
} else {
|
||||
if (is_mul_mat_f32_f32_src_tensors_aligned(src0, src1)) {
|
||||
mul_mat_impl<hexagon::vec_dot_product_aligned_f32_f32, false>(src0, src1, out, params);
|
||||
} else {
|
||||
mul_mat_impl<hexagon::vec_dot_product_f32_f32, false>(src0, src1, out, params);
|
||||
}
|
||||
kMulMatF32F32Funcs[is_mul_mat_f32_f32_src_tensors_aligned(src0, src1)](src0, src1, out, params);
|
||||
}
|
||||
return true;
|
||||
case NPU_DATA_TYPE_F16:
|
||||
kMulMatF16Funcs[is_src0_quantized][is_mul_mat_f16_f16_src_tensors_aligned(src0, src1, is_src0_quantized)](
|
||||
src0, src1, out, params);
|
||||
if (should_cache_src0) {
|
||||
kMulMatF16CachedFuncs[is_mul_mat_f16_f16_src_tensors_aligned(src0, src1, is_src0_quantized)](
|
||||
src0, src1, out, params);
|
||||
} else {
|
||||
kMulMatF16Funcs[is_mul_mat_f16_f16_src_tensors_aligned(src0, src1, is_src0_quantized)](src0, src1, out,
|
||||
params);
|
||||
}
|
||||
return true;
|
||||
default:
|
||||
break;
|
||||
|
|
|
|||
|
|
@@ -270,8 +270,9 @@ bool rope_impl(hexagon::tensor * out, hexagon::compute_params * params) {
                 }
             } else {
                 // fill the remain channels with data from src tensor
-                memcpy(dst_row + n_dims * out->get_nb(0), src0_row + n_dims * src0->get_nb(0),
-                       (out->get_ne(0) - n_dims) * sizeof(float));
+                hexagon::vec_cpy_f32(reinterpret_cast<const float *>(src0_row + n_dims * src0->get_nb(0)),
+                                     reinterpret_cast<float *>(dst_row + n_dims * out->get_nb(0)),
+                                     out->get_ne(0) - n_dims);
             }
         }
     }
@@ -60,7 +60,8 @@ class tensor {
         memcpy(_op_params, config.params, sizeof(_op_params));
         for (size_t i = 0; i < DEVICE_TENSOR_MAX_SRC; ++i) {
             auto src_handle = config.src_handles[i];
-            _src[i] = (src_handle ? reinterpret_cast<tensor *>(src_handle) : nullptr);
+            _src[i] = (src_handle != npu_device_INVALID_DEVICE_TENSOR_HANDLE ? reinterpret_cast<tensor *>(src_handle) :
+                                                                               nullptr);
         }
     }
|
||||
|
||||
|
|
|
|||
|
|
@ -28,22 +28,26 @@ inline npu_device_fp16_t to_fp16(const float src) {
|
|||
return reinterpret_cast<const npu_device_fp16_t &>(f16_value);
|
||||
}
|
||||
|
||||
template <typename _TStruct, size_t _Count, auto _MemberPtr> inline HVX_Vector load_into_vector(const _TStruct * src) {
|
||||
static_assert(hexagon::kBytesPerVector >= sizeof(_TStruct) * _Count, "_TStruct too large for vector load");
|
||||
|
||||
const HVX_Vector * qs0 = reinterpret_cast<const HVX_Vector *>(&(src->*_MemberPtr));
|
||||
HVX_Vector prev = *qs0;
|
||||
HVX_Vector curr = hexagon::is_addr_aligned(qs0) ? Q6_V_vzero() : *(qs0 + 1);
|
||||
return Q6_V_valign_VVR(curr, prev, (size_t) qs0);
|
||||
}
|
||||
|
||||
template <typename _TBlock> inline HVX_Vector load_block_generic(const _TBlock & src) {
|
||||
static_assert(hexagon::kBytesPerVector >= sizeof(_TBlock), "wrong q4_0 block size/padding");
|
||||
|
||||
const HVX_Vector * qs0 = reinterpret_cast<const HVX_Vector *>(src.qs);
|
||||
const HVX_Vector * qs1 = qs0 + 1;
|
||||
return Q6_V_valign_VVR(*qs1, *qs0, (size_t) src.qs);
|
||||
return load_into_vector<_TBlock, 1, &_TBlock::qs>(&src);
|
||||
}
|
||||
|
||||
template <typename _TBlock> inline HVX_Vector load_dual_block_generic(const _TBlock * srcs) {
|
||||
static_assert(hexagon::kBytesPerVector >= sizeof(_TBlock) * 2, "wrong q4_0 block size/padding");
|
||||
constexpr const uint32_t kSizeOfQs = sizeof(_TBlock::qs);
|
||||
|
||||
const HVX_Vector * qs0 = reinterpret_cast<const HVX_Vector *>(srcs->qs);
|
||||
const HVX_Vector * qs1 = qs0 + 1;
|
||||
HVX_Vector blocks = Q6_V_valign_VVR(*qs1, *qs0, (size_t) srcs->qs);
|
||||
HVX_Vector block1 = Q6_V_valign_VVR(Q6_V_vzero(), blocks, sizeof(_TBlock));
|
||||
HVX_Vector blocks = load_into_vector<_TBlock, 2, &_TBlock::qs>(srcs);
|
||||
HVX_Vector block1 = Q6_V_valign_VVR(Q6_V_vzero(), blocks, sizeof(_TBlock));
|
||||
return Q6_V_lo_W(Q6_W_vshuff_VVR(block1, blocks, kSizeOfQs));
|
||||
}
|
||||
|
||||
|
|
@ -51,15 +55,14 @@ template <typename _TBlock> inline HVX_Vector load_qual_block_generic(const _TBl
|
|||
static_assert(hexagon::kBytesPerVector >= sizeof(_TBlock) * 4, "wrong q4_0 block size/padding");
|
||||
constexpr const uint32_t kSizeOfQs = sizeof(_TBlock::qs);
|
||||
|
||||
const HVX_Vector * qs0 = reinterpret_cast<const HVX_Vector *>(srcs->qs);
|
||||
const HVX_Vector * qs1 = qs0 + 1;
|
||||
HVX_Vector blocks = Q6_V_valign_VVR(*qs1, *qs0, (size_t) srcs->qs);
|
||||
HVX_Vector block1 = Q6_V_valign_VVR(Q6_V_vzero(), blocks, sizeof(_TBlock));
|
||||
HVX_Vector block2 = Q6_V_valign_VVR(Q6_V_vzero(), blocks, sizeof(_TBlock) * 2);
|
||||
HVX_Vector block3 = Q6_V_valign_VVR(Q6_V_vzero(), blocks, sizeof(_TBlock) * 3);
|
||||
HVX_Vector blocks = load_into_vector<_TBlock, 4, &_TBlock::qs>(srcs);
|
||||
HVX_Vector block1 = Q6_V_valign_VVR(Q6_V_vzero(), blocks, sizeof(_TBlock));
|
||||
HVX_VectorPair qp0 = Q6_W_vshuff_VVR(block1, blocks, kSizeOfQs);
|
||||
|
||||
HVX_Vector block2 = Q6_V_valign_VVR(Q6_V_vzero(), blocks, sizeof(_TBlock) * 2);
|
||||
HVX_Vector block3 = Q6_V_valign_VVR(Q6_V_vzero(), blocks, sizeof(_TBlock) * 3);
|
||||
HVX_VectorPair qp1 = Q6_W_vshuff_VVR(block3, block2, kSizeOfQs);
|
||||
|
||||
HVX_VectorPair qp0 = Q6_W_vshuff_VVR(block1, blocks, kSizeOfQs);
|
||||
HVX_VectorPair qp1 = Q6_W_vshuff_VVR(block3, block2, kSizeOfQs);
|
||||
return Q6_V_lo_W(Q6_W_vshuff_VVR(Q6_V_lo_W(qp1), Q6_V_lo_W(qp0), kSizeOfQs * 2));
|
||||
}
|
||||
|
||||
|
|
@ -381,17 +384,22 @@ void dequantize_row_q4_0_impl(const void * src, hexagon::dequant_target_type * d
|
|||
HVX_VectorPair qp0 = Q6_W_vshuff_VVR(q_hi, q_lo, kSizeOfQs * (1 + 2 + 4));
|
||||
q_lo = Q6_Vb_vsub_VbVb(Q6_V_lo_W(qp0), minus);
|
||||
qp0 = Q6_Wh_vunpack_Vb(q_lo);
|
||||
q_lo = Q6_Vhf_equals_Vh(Q6_V_lo_W(qp0));
|
||||
q_hi = Q6_Vhf_equals_Vh(Q6_V_hi_W(qp0));
|
||||
q_lo = Q6_Vqf16_vmpy_VhfVhf(q_lo, scales01);
|
||||
q_hi = Q6_Vqf16_vmpy_VhfVhf(q_hi, scales23);
|
||||
|
||||
q_lo = Q6_Vhf_equals_Vh(Q6_V_lo_W(qp0));
|
||||
q_hi = Q6_Vhf_equals_Vh(Q6_V_hi_W(qp0));
|
||||
|
||||
q_lo = Q6_Vqf16_vmpy_VhfVhf(q_lo, scales01);
|
||||
q_lo = Q6_Vhf_equals_Vqf16(q_lo);
|
||||
|
||||
q_hi = Q6_Vqf16_vmpy_VhfVhf(q_hi, scales23);
|
||||
q_hi = Q6_Vhf_equals_Vqf16(q_hi);
|
||||
|
||||
if constexpr (_IsDstAligned) {
|
||||
reinterpret_cast<HVX_Vector *>(dst_ptr)[0] = Q6_Vhf_equals_Vqf16(q_lo);
|
||||
reinterpret_cast<HVX_Vector *>(dst_ptr)[1] = Q6_Vhf_equals_Vqf16(q_hi);
|
||||
reinterpret_cast<HVX_Vector *>(dst_ptr)[0] = q_lo;
|
||||
reinterpret_cast<HVX_Vector *>(dst_ptr)[1] = q_hi;
|
||||
} else {
|
||||
reinterpret_cast<HVX_UVector *>(dst_ptr)[0] = Q6_Vhf_equals_Vqf16(q_lo);
|
||||
reinterpret_cast<HVX_UVector *>(dst_ptr)[1] = Q6_Vhf_equals_Vqf16(q_hi);
|
||||
reinterpret_cast<HVX_UVector *>(dst_ptr)[0] = q_lo;
|
||||
reinterpret_cast<HVX_UVector *>(dst_ptr)[1] = q_hi;
|
||||
}
|
||||
|
||||
dst_ptr += hexagon::kBytesPerVector / sizeof(hexagon::dequant_target_type) * 2;
|
||||
|
|
@ -412,11 +420,12 @@ void dequantize_row_q4_0_impl(const void * src, hexagon::dequant_target_type * d
|
|||
qp0 = Q6_Wh_vunpack_Vb(q_lo);
|
||||
q_lo = Q6_Vhf_equals_Vh(Q6_V_lo_W(qp0));
|
||||
q_lo = Q6_Vqf16_vmpy_VhfVhf(q_lo, scales01);
|
||||
q_lo = Q6_Vhf_equals_Vqf16(q_lo);
|
||||
|
||||
if constexpr (_IsDstAligned) {
|
||||
*reinterpret_cast<HVX_Vector *>(dst_ptr) = Q6_Vhf_equals_Vqf16(q_lo);
|
||||
*reinterpret_cast<HVX_Vector *>(dst_ptr) = q_lo;
|
||||
} else {
|
||||
*reinterpret_cast<HVX_UVector *>(dst_ptr) = Q6_Vhf_equals_Vqf16(q_lo);
|
||||
*reinterpret_cast<HVX_UVector *>(dst_ptr) = q_lo;
|
||||
}
|
||||
|
||||
dst_ptr += hexagon::kBytesPerVector / sizeof(hexagon::dequant_target_type);
|
||||
|
|
@ -434,12 +443,12 @@ void dequantize_row_q4_0_impl(const void * src, hexagon::dequant_target_type * d
|
|||
qp0 = Q6_Wh_vunpack_Vb(q_lo);
|
||||
q_lo = Q6_Vhf_equals_Vh(Q6_V_lo_W(qp0));
|
||||
q_lo = Q6_Vqf16_vmpy_VhfVhf(q_lo, scales);
|
||||
q_lo = Q6_Vhf_equals_Vqf16(q_lo);
|
||||
|
||||
if constexpr (_IsDstAligned) {
|
||||
hexagon::q6op_vstu_variable_aligned<hexagon::kBytesPerVector / 2>(dst_ptr, Q6_Vhf_equals_Vqf16(q_lo));
|
||||
hexagon::q6op_vstu_variable_aligned<hexagon::kBytesPerVector / 2>(dst_ptr, q_lo);
|
||||
} else {
|
||||
hexagon::q6op_vstu_variable_ARV<hexagon::kBytesPerVector / 2>(
|
||||
dst_ptr,
|
||||
Q6_Vhf_equals_Vqf16(q_lo)); // TODO: opt the store
|
||||
hexagon::q6op_vstu_variable_ARV<hexagon::kBytesPerVector / 2>(dst_ptr, q_lo);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -488,26 +497,24 @@ void dequantize_row_q4_K(const void * src, hexagon::dequant_target_type * dst, s
|
|||
}
|
||||
}
|
||||
|
||||
template <typename _TFunc> struct dot_func_traits {};
|
||||
void copy_row_f16(const void * src, hexagon::dequant_target_type * dst, size_t count) {
|
||||
hexagon::vec_cpy_f16(reinterpret_cast<const npu_device_fp16_t *>(src), dst, count);
|
||||
}
|
||||
|
||||
template <typename _TData> struct dot_func_traits<float (*)(_TData, _TData, size_t)> {
|
||||
using param_type = std::remove_const_t<std::remove_pointer_t<_TData>>;
|
||||
};
|
||||
|
||||
template <auto _DotFunc> float wrap_dot_func(const void * src0, const void * src1, size_t count) {
|
||||
using param_type = typename dot_func_traits<decltype(_DotFunc)>::param_type;
|
||||
|
||||
auto * src0_typed = reinterpret_cast<const param_type *>(src0);
|
||||
auto * src1_typed = reinterpret_cast<const param_type *>(src1);
|
||||
return _DotFunc(src0_typed, src1_typed, count);
|
||||
void copy_row_f32(const void * src, hexagon::dequant_target_type * dst, size_t count) {
|
||||
hexagon::vec_cpy_f32(reinterpret_cast<const float *>(src), reinterpret_cast<float *>(dst), count);
|
||||
}
|
||||
|
||||
constexpr const hexagon::device_type_traits kDeviceTypeTraits[] = {
|
||||
{ NPU_DATA_TYPE_F32, "F32", 1, sizeof(float), false, nullptr, nullptr,
|
||||
wrap_dot_func<hexagon::vec_dot_product_f32_f32> },
|
||||
{ NPU_DATA_TYPE_F16, "F16", 1, sizeof(npu_device_fp16_t), false, nullptr, quantize_row_fp16,
|
||||
wrap_dot_func<hexagon::vec_dot_product_f16_f16> },
|
||||
{ NPU_DATA_TYPE_I32, "I32", 1, sizeof(int32_t), false, nullptr, nullptr, nullptr },
|
||||
{ NPU_DATA_TYPE_F32, "F32", 1, sizeof(float), false, copy_row_f32, nullptr,
|
||||
hexagon::type_erase_dot_func<hexagon::vec_dot_product_f32_f32>,
|
||||
hexagon::type_erase_dot_func<hexagon::vec_dot_product_aligned_f32_f32>,
|
||||
hexagon::type_erase_dot_func<hexagon::is_f32_f32_dot_product_aligned> },
|
||||
{ NPU_DATA_TYPE_F16, "F16", 1, sizeof(npu_device_fp16_t), false, copy_row_f16, quantize_row_fp16,
|
||||
hexagon::type_erase_dot_func<hexagon::vec_dot_product_f16_f16>,
|
||||
hexagon::type_erase_dot_func<hexagon::vec_dot_product_aligned_f16_f16>,
|
||||
hexagon::type_erase_dot_func<hexagon::is_f16_f16_dot_product_aligned> },
|
||||
{ NPU_DATA_TYPE_I32, "I32", 1, sizeof(int32_t), false },
|
||||
{ NPU_DATA_TYPE_Q8_0, "Q8_0", QUANT_BLOCK_SIZE, sizeof(npu_device_block_q8_0), true, dequantize_row_q8_0,
|
||||
quantize_row_q8_0 },
|
||||
{ NPU_DATA_TYPE_Q4_0, "Q4_0", QUANT_BLOCK_SIZE, sizeof(npu_device_block_q4_0), true, dequantize_row_q4_0,
|
||||
|
|
@@ -552,4 +559,14 @@ const device_type_traits & get_type_traits(npu_device_tensor_data_type type) {
     return kDeviceTypeTraits[type];
 }
 
+size_t get_dequantized_row_size(const tensor * tensor) {
+    if (!is_quantized_type(tensor->get_type())) {
+        return tensor->get_nb(1);  // for f32 and f16
+    }
+
+    auto row_elems_count = tensor->get_ne(0);
+    return hexagon::get_aligned_size(
+        row_elems_count * sizeof(dequant_target_type));  // dequant_target_type is currently restricted to f32
+}
+
 } // namespace hexagon
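get_dequantized_row_size now rounds the dequantized row up with hexagon::get_aligned_size, so each cached row starts on a vector boundary. A small sketch of the usual round-up arithmetic, assuming a 128-byte HVX vector width and a stand-in helper name:

#include <cstddef>
#include <cstdio>

// Assumption: one HVX vector is 128 bytes (hexagon::kBytesPerVector).
constexpr std::size_t kAlignment = 128;

// Stand-in for hexagon::get_aligned_size: round up to the next multiple.
constexpr std::size_t aligned_size(std::size_t size) {
    return (size + kAlignment - 1) / kAlignment * kAlignment;
}

int main() {
    // e.g. a 100-element f32 row: 400 bytes rounds up to 512 bytes.
    printf("%zu\n", aligned_size(100 * sizeof(float)));
    return 0;
}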
@@ -12,6 +12,7 @@ bool init_f16_f32_table(float * table, size_t count);
 typedef void (*quantize_row_type)(const float * src, void * dst, size_t count);
 typedef void (*dequantize_row_type)(const void * src, dequant_target_type * dst, size_t count);
 typedef float (*vec_dot_type)(const void * src0, const void * src1, size_t count);
+typedef bool (*can_use_aligned_vec_dot_type)(const void * src0, const void * src1, size_t count);
 
 struct device_type_traits {
     npu_device_tensor_data_type type;
@@ -20,9 +21,11 @@ struct device_type_traits {
     size_t type_size;
     bool is_quantized;
 
-    dequantize_row_type to_float;
-    quantize_row_type from_float;
-    vec_dot_type vec_dot;
+    dequantize_row_type          to_float;
+    quantize_row_type            from_float;
+    vec_dot_type                 vec_dot;
+    vec_dot_type                 vec_dot_aligned;
+    can_use_aligned_vec_dot_type can_use_aligned_vec_dot;
 };
 
 const device_type_traits & get_type_traits(npu_device_tensor_data_type type);
@@ -31,14 +34,7 @@ inline bool is_quantized_type(npu_device_tensor_data_type type) {
     return get_type_traits(type).is_quantized;
 }
 
-inline size_t get_dequantized_row_size(const tensor * tensor) {
-    if (!is_quantized_type(tensor->get_type())) {
-        return tensor->get_nb(1);  // for f32 and f16
-    }
-
-    auto row_elems_count = tensor->get_ne(0);
-    return row_elems_count * sizeof(dequant_target_type);  // currently only f32 is supported
-}
+size_t get_dequantized_row_size(const tensor * tensor);
 
 inline const char * get_type_name(npu_device_tensor_data_type type) {
     return get_type_traits(type).type_name;
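The new vec_dot and vec_dot_aligned trait entries are filled in type_traits.cpp through hexagon::type_erase_dot_func, which adapts a typed kernel to the void-pointer signature stored in the table. A minimal sketch of how such a wrapper can work; the names here are illustrative, not the actual implementation:

#include <cstddef>
#include <cstdio>

// The type-erased signature stored in the trait table.
typedef float (*vec_dot_type)(const void * src0, const void * src1, std::size_t count);

// A typed dot product (stand-in for an HVX-optimized kernel).
static float dot_f32_f32(const float * a, const float * b, std::size_t n) {
    float sum = 0.0f;
    for (std::size_t i = 0; i < n; ++i) {
        sum += a[i] * b[i];
    }
    return sum;
}

// Wrap a typed kernel behind the void-pointer signature; the element types are
// fixed at compile time, so no runtime dispatch is added.
template <typename T0, typename T1, float (*Kernel)(const T0 *, const T1 *, std::size_t)>
static float type_erased_dot(const void * src0, const void * src1, std::size_t count) {
    return Kernel(static_cast<const T0 *>(src0), static_cast<const T1 *>(src1), count);
}

int main() {
    const float a[4] = { 1, 2, 3, 4 };
    const float b[4] = { 5, 6, 7, 8 };
    vec_dot_type dot = type_erased_dot<float, float, dot_f32_f32>;
    printf("%f\n", dot(a, b, 4));  // 70.0
    return 0;
}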
@@ -344,8 +344,10 @@ inline auto make_scoped_perf_timer(const char * format, ...) {
 } // namespace hexagon
 
 #ifdef GGML_HEXAGON_ENABLE_PERFORMANCE_TRACKING
+#    define _MAKE_VARIABLE_NAME2(name, postfix) name##postfix
+#    define _MAKE_VARIABLE_NAME(name, postfix)  _MAKE_VARIABLE_NAME2(name, postfix)
 #    define DEVICE_SCOPED_PERFORMANCE_TRACKER(fmt, ...) \
-        auto __npu_timer_##__LINE__ = hexagon::make_scoped_perf_timer(fmt, __VA_ARGS__)
+        auto _MAKE_VARIABLE_NAME(__npu_timer_, __LINE__) = hexagon::make_scoped_perf_timer(fmt, __VA_ARGS__)
 #else
 #    define DEVICE_SCOPED_PERFORMANCE_TRACKER(fmt, ...) ((void) 0)
 #endif
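The extra indirection matters because ## pastes tokens before __LINE__ is expanded, so the old macro always produced the same identifier. Routing the paste through a helper macro forces the expansion first. A small self-contained illustration of the same trick (the names below are only for demonstration):

#include <cstdio>

// Direct pasting would produce the literal token `timer___LINE__`.
#define MAKE_NAME2(name, postfix) name##postfix
#define MAKE_NAME(name, postfix)  MAKE_NAME2(name, postfix)

// Declares a distinct variable per source line, e.g. timer_11, timer_12, ...
#define SCOPED_TIMER() int MAKE_NAME(timer_, __LINE__) = __LINE__

int main() {
    SCOPED_TIMER();  // expands to something like: int timer_11 = 11;
    SCOPED_TIMER();  // a different line, so a different variable name
    printf("ok\n");
    return 0;
}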
@ -1,321 +0,0 @@
|
|||
#include "vec_ops.hpp"
|
||||
|
||||
#include "util.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename _TElem, HVX_Vector (*_MpyFunc)(HVX_Vector, HVX_Vector),
|
||||
HVX_Vector (*_AddFunc)(HVX_Vector, HVX_Vector), float (*_ReduceFunc)(HVX_Vector)>
|
||||
inline float vec_dot_product_impl(const _TElem * src0, const _TElem * src1, size_t count) {
|
||||
constexpr const size_t kElementsPerVector = hexagon::kBytesPerVector / sizeof(_TElem);
|
||||
|
||||
HVX_Vector * src0_vec_ptr = ((HVX_Vector *) src0);
|
||||
HVX_Vector * const src0_vec_ptr_end = ((HVX_Vector *) src0) + count / kElementsPerVector;
|
||||
HVX_Vector * src1_vec_ptr = ((HVX_Vector *) src1);
|
||||
HVX_Vector prev0 = *src0_vec_ptr++;
|
||||
HVX_Vector prev1 = *src1_vec_ptr++;
|
||||
HVX_Vector sum = Q6_V_vzero();
|
||||
HVX_Vector sum0 = Q6_V_vzero();
|
||||
HVX_Vector sum1 = Q6_V_vzero();
|
||||
|
||||
while (src0_vec_ptr_end - src0_vec_ptr > 1) {
|
||||
HVX_VectorPair curr0 = reinterpret_cast<HVX_VectorPair *>(src0_vec_ptr)[0];
|
||||
HVX_VectorPair curr1 = reinterpret_cast<HVX_VectorPair *>(src1_vec_ptr)[0];
|
||||
|
||||
HVX_Vector l0 = Q6_V_valign_VVR(Q6_V_lo_W(curr0), prev0, (size_t) src0);
|
||||
HVX_Vector l1 = Q6_V_valign_VVR(Q6_V_lo_W(curr1), prev1, (size_t) src1);
|
||||
HVX_Vector h0 = Q6_V_valign_VVR(Q6_V_hi_W(curr0), Q6_V_lo_W(curr0), (size_t) src0);
|
||||
HVX_Vector h1 = Q6_V_valign_VVR(Q6_V_hi_W(curr1), Q6_V_lo_W(curr1), (size_t) src1);
|
||||
prev0 = Q6_V_hi_W(curr0);
|
||||
prev1 = Q6_V_hi_W(curr1);
|
||||
src0_vec_ptr += 2;
|
||||
src1_vec_ptr += 2;
|
||||
|
||||
sum0 = _AddFunc(_MpyFunc(l0, l1), sum0);
|
||||
sum1 = _AddFunc(_MpyFunc(h0, h1), sum1);
|
||||
}
|
||||
|
||||
sum = _AddFunc(sum0, sum1);
|
||||
if (src0_vec_ptr_end - src0_vec_ptr > 0) {
|
||||
HVX_Vector curr0 = *src0_vec_ptr++;
|
||||
HVX_Vector curr1 = *src1_vec_ptr++;
|
||||
HVX_Vector s0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0);
|
||||
HVX_Vector s1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1);
|
||||
prev0 = curr0;
|
||||
prev1 = curr1;
|
||||
|
||||
sum = _AddFunc(_MpyFunc(s0, s1), sum);
|
||||
}
|
||||
|
||||
const size_t leftover = count % kElementsPerVector;
|
||||
if ((src0_vec_ptr_end - ((HVX_Vector *) src0)) > 0) {
|
||||
// handle the last vector
|
||||
// see also:
|
||||
// https://github.com/UbiquitousLearning/mllm/blob/babf4410352ce8730824c87699c025a0d4ce3a6f/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/LLaMAMul.cpp#L147
|
||||
// or qualcomm sdk libs\qhl_hvx\src\qhblas_hvx\qhblas_hvx_aw_vector_add_ah.c
|
||||
bool should_fetch_src0 = leftover != 0 || !hexagon::is_addr_aligned(src0_vec_ptr);
|
||||
bool should_fetch_src1 = leftover != 0 || !hexagon::is_addr_aligned(src1_vec_ptr);
|
||||
HVX_Vector curr0 = should_fetch_src0 ? *src0_vec_ptr : prev0;
|
||||
HVX_Vector curr1 = should_fetch_src1 ? *src1_vec_ptr : prev1;
|
||||
src0_vec_ptr += should_fetch_src0 ? 1 : 0;
|
||||
src1_vec_ptr += should_fetch_src1 ? 1 : 0;
|
||||
HVX_Vector s0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0);
|
||||
HVX_Vector s1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1);
|
||||
prev0 = curr0;
|
||||
prev1 = curr1;
|
||||
|
||||
sum = _AddFunc(_MpyFunc(s0, s1), sum);
|
||||
}
|
||||
|
||||
const size_t leftover_bytes = leftover * sizeof(_TElem);
|
||||
if (leftover > 0) {
|
||||
// handle the leftover elements
|
||||
HVX_Vector curr0 = (leftover_bytes + hexagon::unaligned_bytes(src0_vec_ptr) > hexagon::kBytesPerVector) ?
|
||||
*src0_vec_ptr :
|
||||
prev0;
|
||||
curr0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0);
|
||||
|
||||
HVX_Vector curr1 = (leftover_bytes + hexagon::unaligned_bytes(src1_vec_ptr) > hexagon::kBytesPerVector) ?
|
||||
*src1_vec_ptr :
|
||||
prev1;
|
||||
curr1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1);
|
||||
|
||||
sum = _AddFunc(Q6_V_valign_VVR(_MpyFunc(curr0, curr1), Q6_V_vzero(), leftover_bytes), sum);
|
||||
}
|
||||
|
||||
return _ReduceFunc(sum);
|
||||
}
|
||||
|
||||
template <typename _TElem, HVX_Vector (*_MpyFunc)(HVX_Vector, HVX_Vector),
|
||||
HVX_Vector (*_AddFunc)(HVX_Vector, HVX_Vector), float (*_ReduceFunc)(HVX_Vector)>
|
||||
inline float vec_dot_product_aligned_impl(const _TElem * src0, const _TElem * src1, size_t count) {
|
||||
constexpr const size_t kElementsPerVector = hexagon::kBytesPerVector / sizeof(_TElem);
|
||||
|
||||
HVX_Vector * src0_vec_ptr = ((HVX_Vector *) src0);
|
||||
HVX_Vector * const src0_vec_ptr_end = ((HVX_Vector *) src0) + count / kElementsPerVector;
|
||||
HVX_Vector * src1_vec_ptr = ((HVX_Vector *) src1);
|
||||
HVX_Vector sum0 = Q6_V_vzero();
|
||||
HVX_Vector sum1 = Q6_V_vzero();
|
||||
|
||||
while (src0_vec_ptr_end - src0_vec_ptr > 1) {
|
||||
HVX_VectorPair curr0 = reinterpret_cast<HVX_VectorPair *>(src0_vec_ptr)[0];
|
||||
HVX_VectorPair curr1 = reinterpret_cast<HVX_VectorPair *>(src1_vec_ptr)[0];
|
||||
src0_vec_ptr += 2;
|
||||
src1_vec_ptr += 2;
|
||||
|
||||
sum0 = _AddFunc(_MpyFunc(Q6_V_lo_W(curr0), Q6_V_lo_W(curr1)), sum0);
|
||||
sum1 = _AddFunc(_MpyFunc(Q6_V_hi_W(curr0), Q6_V_hi_W(curr1)), sum1);
|
||||
}
|
||||
|
||||
if (src0_vec_ptr_end - src0_vec_ptr > 0) {
|
||||
HVX_Vector curr0 = src0_vec_ptr[0];
|
||||
HVX_Vector curr1 = src1_vec_ptr[0];
|
||||
|
||||
sum0 = _AddFunc(_MpyFunc(curr0, curr1), sum0);
|
||||
}
|
||||
|
||||
return _ReduceFunc(_AddFunc(sum0, sum1));
|
||||
}
|
||||
|
||||
inline HVX_Vector vec_mpy_qf32(HVX_Vector src0, HVX_Vector src1) {
|
||||
return Q6_Vqf32_vmpy_VsfVsf(src0, src1);
|
||||
}
|
||||
|
||||
inline HVX_Vector vec_add_qf32(HVX_Vector sum, HVX_Vector result) {
|
||||
return Q6_Vqf32_vadd_Vqf32Vqf32(sum, result);
|
||||
}
|
||||
|
||||
inline HVX_Vector vec_mpy_qf16(HVX_Vector src0, HVX_Vector src1) {
|
||||
return Q6_Vqf16_vmpy_VhfVhf(src0, src1);
|
||||
}
|
||||
|
||||
inline HVX_Vector vec_add_qf16(HVX_Vector sum, HVX_Vector result) {
|
||||
return Q6_Vqf16_vadd_Vqf16Vqf16(sum, result);
|
||||
}
|
||||
|
||||
template <typename _TElem0, typename _TElem1, HVX_VectorPair (*_ExpandFunc)(HVX_Vector, HVX_Vector),
|
||||
HVX_Vector (*_MpyFunc)(HVX_Vector, HVX_Vector), HVX_Vector (*_AddFunc)(HVX_Vector, HVX_Vector),
|
||||
float (*_ReduceFunc)(HVX_Vector)>
|
||||
inline float vec_dot_product_mixed_impl(const _TElem0 * src0, const _TElem1 * src1, size_t count) {
|
||||
static_assert(sizeof(_TElem0) < sizeof(_TElem1), "Element size mismatch: _TElem0 must be smaller than _TElem1");
|
||||
static_assert((sizeof(_TElem1) / sizeof(_TElem0)) == 2,
|
||||
"Element size mismatch: _TElem1 must be twice the size of _TElem0");
|
||||
static_assert((sizeof(_TElem1) % sizeof(_TElem0)) == 0,
|
||||
"Element size mismatch: _TElem1 must be a multiple of _TElem0");
|
||||
|
||||
constexpr const size_t kElementsPerVector0 = hexagon::kBytesPerVector / sizeof(_TElem0);
|
||||
constexpr const size_t kElementsPerVector1 = hexagon::kBytesPerVector / sizeof(_TElem1);
|
||||
|
||||
constexpr const __fp16 kOne = 1.0f;
|
||||
const HVX_Vector kOneV = Q6_Vh_vsplat_R(reinterpret_cast<const uint16_t &>(kOne));
|
||||
|
||||
const _TElem0 * const src0_ptr_end = src0 + count;
|
||||
HVX_Vector * src0_vec_ptr = ((HVX_Vector *) src0);
|
||||
HVX_Vector * src1_vec_ptr = ((HVX_Vector *) src1);
|
||||
HVX_Vector * const src1_vec_ptr_end = ((HVX_Vector *) src1) + count / kElementsPerVector1;
|
||||
HVX_Vector prev0 = *src0_vec_ptr++;
|
||||
HVX_Vector prev1 = *src1_vec_ptr++;
|
||||
HVX_Vector sum = Q6_V_vzero();
|
||||
HVX_Vector sum0 = Q6_V_vzero();
|
||||
HVX_Vector sum1 = Q6_V_vzero();
|
||||
|
||||
while (src1_vec_ptr_end - src1_vec_ptr > 1) {
|
||||
HVX_Vector curr0 = src0_vec_ptr[0];
|
||||
HVX_VectorPair curr1 = reinterpret_cast<HVX_VectorPair *>(src1_vec_ptr)[0];
|
||||
|
||||
HVX_Vector s0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0);
|
||||
HVX_Vector l1 = Q6_V_valign_VVR(Q6_V_lo_W(curr1), prev1, (size_t) src1);
|
||||
HVX_Vector h1 = Q6_V_valign_VVR(Q6_V_hi_W(curr1), Q6_V_lo_W(curr1), (size_t) src1);
|
||||
HVX_VectorPair s0_pair = _ExpandFunc(s0, kOneV);
|
||||
prev0 = curr0;
|
||||
prev1 = Q6_V_hi_W(curr1);
|
||||
src0_vec_ptr++;
|
||||
src1_vec_ptr += 2;
|
||||
|
||||
sum0 = _AddFunc(_MpyFunc(Q6_V_lo_W(s0_pair), l1), sum0);
|
||||
sum1 = _AddFunc(_MpyFunc(Q6_V_hi_W(s0_pair), h1), sum1);
|
||||
}
|
||||
|
||||
sum = _AddFunc(sum0, sum1);
|
||||
const size_t leftover1 = count % kElementsPerVector1;
|
||||
if ((src1_vec_ptr_end - ((HVX_Vector *) src1)) > 0) {
|
||||
// handle the last vector
|
||||
const bool should_fetch_src0 =
|
||||
reinterpret_cast<const _TElem0 *>(hexagon::align_down(src0_vec_ptr)) < src0_ptr_end;
|
||||
HVX_Vector curr0 = should_fetch_src0 ? *src0_vec_ptr : prev0;
|
||||
src0_vec_ptr += should_fetch_src0 ? 1 : 0;
|
||||
|
||||
HVX_Vector s0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0);
|
||||
HVX_VectorPair s0_pair = _ExpandFunc(s0, kOneV);
|
||||
|
||||
const bool has_remaining_src1_vector = src1_vec_ptr_end - src1_vec_ptr > 0;
|
||||
if (has_remaining_src1_vector) {
|
||||
HVX_Vector curr1 = *src1_vec_ptr++;
|
||||
HVX_Vector s1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1);
|
||||
prev1 = curr1;
|
||||
|
||||
// should_handle_last_vector will be always true here
|
||||
sum = _AddFunc(_MpyFunc(Q6_V_lo_W(s0_pair), s1), sum);
|
||||
}
|
||||
|
||||
bool should_fetch_src1 = leftover1 != 0 || !hexagon::is_addr_aligned(src1_vec_ptr);
|
||||
HVX_Vector curr1 = should_fetch_src1 ? *src1_vec_ptr : prev1;
|
||||
src1_vec_ptr += should_fetch_src1 ? 1 : 0;
|
||||
HVX_Vector s1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1);
|
||||
prev0 = curr0;
|
||||
prev1 = curr1;
|
||||
|
||||
sum = _AddFunc(_MpyFunc(has_remaining_src1_vector ? Q6_V_hi_W(s0_pair) : Q6_V_lo_W(s0_pair), s1), sum);
|
||||
}
|
||||
|
||||
const size_t leftover0 = count % kElementsPerVector0;
|
||||
const size_t leftover_bytes1 = leftover1 * sizeof(_TElem1);
|
||||
if (leftover1 > 0) {
|
||||
// handle the leftover elements
|
||||
HVX_Vector curr0 =
|
||||
reinterpret_cast<const _TElem0 *>(hexagon::align_down(src0_vec_ptr)) < src0_ptr_end ? *src0_vec_ptr : prev0;
|
||||
HVX_Vector curr1 = (leftover_bytes1 + hexagon::unaligned_bytes(src1_vec_ptr) > hexagon::kBytesPerVector) ?
|
||||
*src1_vec_ptr :
|
||||
prev1;
|
||||
curr0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0);
|
||||
curr1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1);
|
||||
HVX_VectorPair curr0_pair = _ExpandFunc(curr0, kOneV);
|
||||
|
||||
curr0 = leftover1 == leftover0 ? Q6_V_lo_W(curr0_pair) : Q6_V_hi_W(curr0_pair);
|
||||
sum = _AddFunc(Q6_V_valign_VVR(_MpyFunc(curr0, curr1), Q6_V_vzero(), leftover_bytes1), sum);
|
||||
}
|
||||
|
||||
return _ReduceFunc(sum);
|
||||
}
|
||||
|
||||
template <typename _TElem0, typename _TElem1, HVX_VectorPair (*_ExpandFunc)(HVX_Vector, HVX_Vector),
|
||||
HVX_Vector (*_MpyFunc)(HVX_Vector, HVX_Vector), HVX_Vector (*_AddFunc)(HVX_Vector, HVX_Vector),
|
||||
float (*_ReduceFunc)(HVX_Vector)>
|
||||
inline float vec_dot_product_mix_aligned_impl(const _TElem0 * src0, const _TElem1 * src1, size_t count) {
|
||||
static_assert(sizeof(_TElem0) < sizeof(_TElem1), "Element size mismatch: _TElem0 must be smaller than _TElem1");
|
||||
static_assert((sizeof(_TElem1) / sizeof(_TElem0)) == 2,
|
||||
"Element size mismatch: _TElem1 must be twice the size of _TElem0");
|
||||
static_assert((sizeof(_TElem1) % sizeof(_TElem0)) == 0,
|
||||
"Element size mismatch: _TElem1 must be a multiple of _TElem0");
|
||||
|
||||
constexpr const size_t kElementsPerVector1 = hexagon::kBytesPerVector / sizeof(_TElem1);
|
||||
|
||||
constexpr const __fp16 kOne = 1.0f;
|
||||
const HVX_Vector kOneV = Q6_Vh_vsplat_R(reinterpret_cast<const uint16_t &>(kOne));
|
||||
|
||||
HVX_Vector * src0_vec_ptr = ((HVX_Vector *) src0);
|
||||
HVX_Vector * src1_vec_ptr = ((HVX_Vector *) src1);
|
||||
HVX_Vector * const src1_vec_ptr_end = ((HVX_Vector *) src1) + count / kElementsPerVector1;
|
||||
HVX_Vector sum0 = Q6_V_vzero();
|
||||
HVX_Vector sum1 = Q6_V_vzero();
|
||||
|
||||
{
|
||||
HVX_Vector sum2 = Q6_V_vzero();
|
||||
HVX_Vector sum3 = Q6_V_vzero();
|
||||
|
||||
while (src1_vec_ptr_end - src1_vec_ptr > 3) {
|
||||
HVX_VectorPair curr0 = reinterpret_cast<HVX_VectorPair *>(src0_vec_ptr)[0];
|
||||
HVX_VectorPair curr10 = reinterpret_cast<HVX_VectorPair *>(src1_vec_ptr)[0];
|
||||
HVX_VectorPair curr11 = reinterpret_cast<HVX_VectorPair *>(src1_vec_ptr)[1];
|
||||
|
||||
HVX_VectorPair curr00 = _ExpandFunc(Q6_V_lo_W(curr0), kOneV);
|
||||
HVX_VectorPair curr01 = _ExpandFunc(Q6_V_hi_W(curr0), kOneV);
|
||||
src0_vec_ptr += 2;
|
||||
src1_vec_ptr += 4;
|
||||
|
||||
sum0 = _AddFunc(_MpyFunc(Q6_V_lo_W(curr00), Q6_V_lo_W(curr10)), sum0);
|
||||
sum1 = _AddFunc(_MpyFunc(Q6_V_hi_W(curr00), Q6_V_hi_W(curr10)), sum1);
|
||||
sum2 = _AddFunc(_MpyFunc(Q6_V_lo_W(curr01), Q6_V_lo_W(curr11)), sum2);
|
||||
sum3 = _AddFunc(_MpyFunc(Q6_V_hi_W(curr01), Q6_V_hi_W(curr11)), sum3);
|
||||
}
|
||||
|
||||
sum0 = _AddFunc(sum0, sum2);
|
||||
sum1 = _AddFunc(sum1, sum3);
|
||||
}
|
||||
|
||||
if (src1_vec_ptr_end - src1_vec_ptr > 1) {
|
||||
HVX_Vector curr0 = src0_vec_ptr[0];
|
||||
HVX_VectorPair curr1 = reinterpret_cast<HVX_VectorPair *>(src1_vec_ptr)[0];
|
||||
|
||||
HVX_VectorPair s0_pair = _ExpandFunc(curr0, kOneV);
|
||||
|
||||
sum0 = _AddFunc(_MpyFunc(Q6_V_lo_W(s0_pair), Q6_V_lo_W(curr1)), sum0);
|
||||
sum1 = _AddFunc(_MpyFunc(Q6_V_hi_W(s0_pair), Q6_V_hi_W(curr1)), sum1);
|
||||
}
|
||||
|
||||
return _ReduceFunc(_AddFunc(sum0, sum1));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace hexagon {
|
||||
|
||||
float vec_dot_product_f32_f32(const float * src0, const float * src1, size_t count) {
|
||||
return vec_dot_product_impl<float, vec_mpy_qf32, vec_add_qf32, vec_reduction_f32_qf32>(src0, src1, count);
|
||||
}
|
||||
|
||||
float vec_dot_product_aligned_f32_f32(const float * src0, const float * src1, size_t count) {
|
||||
return vec_dot_product_aligned_impl<float, vec_mpy_qf32, vec_add_qf32, vec_reduction_f32_qf32>(src0, src1, count);
|
||||
}
|
||||
|
||||
float vec_dot_product_f16_f16(const npu_device_fp16_t * src0, const npu_device_fp16_t * src1, size_t count) {
|
||||
return vec_dot_product_impl<npu_device_fp16_t, vec_mpy_qf16, vec_add_qf16, vec_reduction_qf16_f32>(src0, src1,
|
||||
count);
|
||||
}
|
||||
|
||||
float vec_dot_product_aligned_f16_f16(const npu_device_fp16_t * src0, const npu_device_fp16_t * src1, size_t count) {
|
||||
return vec_dot_product_aligned_impl<npu_device_fp16_t, vec_mpy_qf16, vec_add_qf16, vec_reduction_qf16_f32>(
|
||||
src0, src1, count);
|
||||
}
|
||||
|
||||
float vec_dot_product_f16_f32(const npu_device_fp16_t * src0, const float * src1, size_t count) {
|
||||
return vec_dot_product_mixed_impl<npu_device_fp16_t, float, hvx_vsf_convert_vhf, vec_mpy_qf32, vec_add_qf32,
|
||||
vec_reduction_f32_qf32>(src0, src1, count);
|
||||
}
|
||||
|
||||
float vec_dot_product_aligned_f16_f32(const npu_device_fp16_t * src0, const float * src1, size_t count) {
|
||||
return vec_dot_product_mix_aligned_impl<npu_device_fp16_t, float, hvx_vsf_convert_vhf, vec_mpy_qf32, vec_add_qf32,
|
||||
vec_reduction_f32_qf32>(src0, src1, count);
|
||||
}
|
||||
|
||||
} // namespace hexagon
|
||||
|
|
@@ -31,6 +31,10 @@ inline bool is_addr_aligned(const void * addr) {
|
|||
return unaligned_bytes(addr) == 0;
|
||||
}
|
||||
|
||||
inline bool is_size_aligned(size_t size) {
|
||||
return (size & kAlignMask) == 0;
|
||||
}
|
||||
|
||||
inline float get_flt0_from_fltv(HVX_Vector vect) {
|
||||
static_assert(sizeof(vect[0]) == sizeof(float), "vect[0] should be a float");
|
||||
int32_t i = vect[0];
|
||||
|
|
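A minimal call-site sketch for the new is_size_aligned helper added in this hunk. The helper below and the assumption that kAlignMask corresponds to the HVX vector width are mine, not part of the patch:

inline bool can_use_aligned_row_kernel(const void * row, size_t row_bytes) {
    // Hypothetical helper: the aligned kernels assume both the base address and
    // the row size are multiples of the HVX vector width.
    return hexagon::is_addr_aligned(row) && hexagon::is_size_aligned(row_bytes);
}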
@@ -157,31 +161,25 @@ inline HVX_VectorPair hvx_vqf32_convert_vhf(HVX_Vector vxl) {
|
|||
return qhmath_hvx_vqf32_convert_vqf16(qhmath_hvx_vqf16_convert_vhf(vxl));
|
||||
}
|
||||
|
||||
inline HVX_VectorPair hvx_vsf_convert_vhf(HVX_Vector vxl, HVX_Vector one) {
|
||||
HVX_VectorPair res = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vxl), one);
|
||||
HVX_Vector vxl_w = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(res));
|
||||
HVX_Vector vxh_w = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(res));
|
||||
return Q6_W_vcombine_VV(vxh_w, vxl_w);
|
||||
using HVX_Vector_Dual = std::pair<HVX_Vector, HVX_Vector>;
|
||||
|
||||
inline HVX_Vector_Dual hvx_vsf_convert_vhf(HVX_Vector vxl, HVX_Vector one) {
|
||||
HVX_VectorPair res = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vxl), one);
|
||||
return {
|
||||
Q6_Vsf_equals_Vqf32(Q6_V_lo_W(res)),
|
||||
Q6_Vsf_equals_Vqf32(Q6_V_hi_W(res)),
|
||||
};
|
||||
}
|
||||
|
||||
inline HVX_Vector vec_reduction_qf32(HVX_Vector sums) {
|
||||
constexpr const size_t kFloatsPerVector = hexagon::kBytesPerVector / sizeof(float);
|
||||
static_assert(kFloatsPerVector == 32 || kFloatsPerVector == 16, "kFloatsPerVector should be 16 or 32");
|
||||
|
||||
// TODO: do we have a better way to do the reduction?
|
||||
switch (kFloatsPerVector) {
|
||||
default:
|
||||
case 32:
|
||||
sums = Q6_Vqf32_vadd_Vqf32Vqf32(sums, Q6_V_vror_VR(sums, 16 * sizeof(float)));
|
||||
// fallthrough
|
||||
case 16:
|
||||
sums = Q6_Vqf32_vadd_Vqf32Vqf32(sums, Q6_V_vror_VR(sums, 8 * sizeof(float)));
|
||||
sums = Q6_Vqf32_vadd_Vqf32Vqf32(sums, Q6_V_vror_VR(sums, 4 * sizeof(float)));
|
||||
sums = Q6_Vqf32_vadd_Vqf32Vqf32(sums, Q6_V_vror_VR(sums, 2 * sizeof(float)));
|
||||
sums = Q6_Vqf32_vadd_Vqf32Vqf32(sums, Q6_V_vror_VR(sums, sizeof(float)));
|
||||
break;
|
||||
}
|
||||
static_assert(kFloatsPerVector == 32, "kFloatsPerVector should be 32");
|
||||
|
||||
sums = Q6_Vqf32_vadd_Vqf32Vqf32(sums, Q6_V_vror_VR(sums, 16 * sizeof(float)));
|
||||
sums = Q6_Vqf32_vadd_Vqf32Vqf32(sums, Q6_V_vror_VR(sums, 8 * sizeof(float)));
|
||||
sums = Q6_Vqf32_vadd_Vqf32Vqf32(sums, Q6_V_vror_VR(sums, 4 * sizeof(float)));
|
||||
sums = Q6_Vqf32_vadd_Vqf32Vqf32(sums, Q6_V_vror_VR(sums, 2 * sizeof(float)));
|
||||
sums = Q6_Vqf32_vadd_Vqf32Vqf32(sums, Q6_V_vror_VR(sums, sizeof(float)));
|
||||
return sums;
|
||||
}
|
||||
|
||||
|
|
@@ -191,23 +189,14 @@ inline float vec_reduction_f32_qf32(HVX_Vector sums) {
|
|||
|
||||
inline HVX_Vector vec_reduction_qf16(HVX_Vector sums) {
|
||||
constexpr const size_t kFloatsPerVector = hexagon::kBytesPerVector / sizeof(npu_device_fp16_t);
|
||||
static_assert(kFloatsPerVector == 64 || kFloatsPerVector == 32, "kFloatsPerVector should be 32 or 64");
|
||||
|
||||
// TODO: do we have a better way to do the reduction?
|
||||
switch (kFloatsPerVector) {
|
||||
default:
|
||||
case 64:
|
||||
sums = Q6_Vqf16_vadd_Vqf16Vqf16(sums, Q6_V_vror_VR(sums, 32 * sizeof(npu_device_fp16_t)));
|
||||
// fallthrough
|
||||
case 32:
|
||||
sums = Q6_Vqf16_vadd_Vqf16Vqf16(sums, Q6_V_vror_VR(sums, 16 * sizeof(npu_device_fp16_t)));
|
||||
sums = Q6_Vqf16_vadd_Vqf16Vqf16(sums, Q6_V_vror_VR(sums, 8 * sizeof(npu_device_fp16_t)));
|
||||
sums = Q6_Vqf16_vadd_Vqf16Vqf16(sums, Q6_V_vror_VR(sums, 4 * sizeof(npu_device_fp16_t)));
|
||||
sums = Q6_Vqf16_vadd_Vqf16Vqf16(sums, Q6_V_vror_VR(sums, 2 * sizeof(npu_device_fp16_t)));
|
||||
sums = Q6_Vqf16_vadd_Vqf16Vqf16(sums, Q6_V_vror_VR(sums, sizeof(npu_device_fp16_t)));
|
||||
break;
|
||||
}
|
||||
static_assert(kFloatsPerVector == 64, "kFloatsPerVector should be 64");
|
||||
|
||||
sums = Q6_Vqf16_vadd_Vqf16Vqf16(sums, Q6_V_vror_VR(sums, 32 * sizeof(npu_device_fp16_t)));
|
||||
sums = Q6_Vqf16_vadd_Vqf16Vqf16(sums, Q6_V_vror_VR(sums, 16 * sizeof(npu_device_fp16_t)));
|
||||
sums = Q6_Vqf16_vadd_Vqf16Vqf16(sums, Q6_V_vror_VR(sums, 8 * sizeof(npu_device_fp16_t)));
|
||||
sums = Q6_Vqf16_vadd_Vqf16Vqf16(sums, Q6_V_vror_VR(sums, 4 * sizeof(npu_device_fp16_t)));
|
||||
sums = Q6_Vqf16_vadd_Vqf16Vqf16(sums, Q6_V_vror_VR(sums, 2 * sizeof(npu_device_fp16_t)));
|
||||
sums = Q6_Vqf16_vadd_Vqf16Vqf16(sums, Q6_V_vror_VR(sums, sizeof(npu_device_fp16_t)));
|
||||
return sums;
|
||||
}
|
||||
|
||||
|
|
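For orientation, a scalar model of the rotate-and-add idiom used by vec_reduction_qf32 and vec_reduction_qf16 above (illustrative only; this function is not part of the patch):

#include <cstddef>
#include <vector>

// Models one HVX vector as n lanes (n = 32 for qf32, 64 for qf16; n must be a
// power of two). Each step rotates by `step` lanes (Q6_V_vror_VR) and adds
// lane-wise, so after log2(n) steps every lane holds the full sum; lane 0 is
// then extracted by the caller (get_flt0_from_fltv).
static float reduce_model(const float * lanes, std::size_t n) {
    std::vector<float> v(lanes, lanes + n);
    for (std::size_t step = n / 2; step >= 1; step /= 2) {
        std::vector<float> rotated(n);
        for (std::size_t i = 0; i < n; ++i) {
            rotated[i] = v[(i + step) % n];  // rotate by `step` lanes
        }
        for (std::size_t i = 0; i < n; ++i) {
            v[i] += rotated[i];              // lane-wise add
        }
    }
    return v[0];
}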
@@ -221,62 +210,6 @@ inline HVX_Vector hvx_scale_f32(float scale) {
|
|||
return Q6_V_vsplat_R(reinterpret_cast<const uint32_t &>(scale));
|
||||
}
|
||||
|
||||
template <HVX_Vector (*_Func)(HVX_Vector, HVX_UVector *, HVX_Vector), HVX_Vector (*_FuncScaleConvert)(float),
|
||||
typename _TParam>
|
||||
inline void vec_scale_impl(const _TParam * src, float scale, _TParam * dst, size_t count) {
|
||||
constexpr const size_t kElementsPerVector = hexagon::kBytesPerVector / sizeof(_TParam);
|
||||
|
||||
HVX_Vector * src_vec_ptr = ((HVX_Vector *) src);
|
||||
HVX_Vector * const src_vec_end = ((HVX_Vector *) src) + (count / kElementsPerVector);
|
||||
HVX_UVector * dst_vec_ptr = ((HVX_UVector *) dst); // TODO: opt the unaligned case?
|
||||
HVX_Vector prev = *src_vec_ptr++;
|
||||
const size_t leftover = count % kElementsPerVector;
|
||||
const size_t leftover_bytes = leftover * sizeof(_TParam);
|
||||
|
||||
HVX_Vector scale_vec = _FuncScaleConvert(scale);
|
||||
|
||||
while (src_vec_end - src_vec_ptr > 1) {
|
||||
HVX_VectorPair curr = reinterpret_cast<HVX_VectorPair *>(src_vec_ptr)[0];
|
||||
src_vec_ptr += 2;
|
||||
|
||||
HVX_Vector lo = Q6_V_valign_VVR(Q6_V_lo_W(curr), prev, (size_t) src);
|
||||
HVX_Vector hi = Q6_V_valign_VVR(Q6_V_hi_W(curr), Q6_V_lo_W(curr), (size_t) src);
|
||||
|
||||
dst_vec_ptr[0] = _Func(lo, dst_vec_ptr, scale_vec);
|
||||
dst_vec_ptr[1] = _Func(hi, dst_vec_ptr + 1, scale_vec);
|
||||
|
||||
dst_vec_ptr += 2;
|
||||
prev = Q6_V_hi_W(curr);
|
||||
}
|
||||
|
||||
if (src_vec_end - src_vec_ptr > 0) {
|
||||
HVX_Vector curr = *src_vec_ptr++;
|
||||
HVX_Vector s0 = Q6_V_valign_VVR(curr, prev, (size_t) src);
|
||||
dst_vec_ptr[0] = _Func(s0, dst_vec_ptr, scale_vec);
|
||||
dst_vec_ptr++;
|
||||
prev = curr;
|
||||
}
|
||||
|
||||
if ((src_vec_end - ((HVX_Vector *) src)) > 0) {
|
||||
// handle the last vector
|
||||
bool should_fetch_next = leftover == 0 && hexagon::is_addr_aligned(src_vec_ptr);
|
||||
HVX_Vector curr = should_fetch_next ? prev : *src_vec_ptr;
|
||||
src_vec_ptr = should_fetch_next ? src_vec_ptr : src_vec_ptr + 1;
|
||||
HVX_Vector s0 = Q6_V_valign_VVR(curr, prev, (size_t) src);
|
||||
dst_vec_ptr[0] = _Func(s0, dst_vec_ptr, scale_vec);
|
||||
dst_vec_ptr++;
|
||||
prev = curr;
|
||||
}
|
||||
|
||||
if (leftover > 0) {
|
||||
// handle the leftover elements
|
||||
HVX_Vector curr =
|
||||
(leftover_bytes + hexagon::unaligned_bytes(src_vec_ptr) > hexagon::kBytesPerVector) ? *src_vec_ptr : prev;
|
||||
curr = Q6_V_valign_VVR(curr, prev, (size_t) src);
|
||||
q6op_vstu_variable_ARV(dst_vec_ptr, leftover_bytes, _Func(curr, dst_vec_ptr, scale_vec));
|
||||
}
|
||||
}
|
||||
|
||||
inline HVX_Vector hvx_vec_scale_f32_f32(HVX_Vector src, HVX_UVector *, HVX_Vector scale_vec) {
|
||||
return Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(src, scale_vec));
|
||||
}
|
||||
|
|
@@ -288,14 +221,6 @@ inline HVX_Vector hvx_vec_mad_f32_f32(HVX_Vector src, HVX_UVector * dst_ptr, HVX
|
|||
return Q6_Vsf_equals_Vqf32(src);
|
||||
}
|
||||
|
||||
inline void vec_scale_f32(const float * src, float scale, float * dst, size_t count) {
|
||||
vec_scale_impl<hvx_vec_scale_f32_f32, hvx_scale_f32, float>(src, scale, dst, count);
|
||||
}
|
||||
|
||||
inline void vec_mad_f32(const float * src, float scale, float * dst, size_t count) {
|
||||
vec_scale_impl<hvx_vec_mad_f32_f32, hvx_scale_f32, float>(src, scale, dst, count);
|
||||
}
|
||||
|
||||
inline HVX_Vector hvx_scale_f16(float scale) {
|
||||
__fp16 f16_scale = scale;
|
||||
return Q6_Vh_vsplat_R(reinterpret_cast<const npu_device_fp16_t &>(f16_scale));
|
||||
|
|
@@ -312,19 +237,65 @@ inline HVX_Vector hvx_vec_mad_f16_f16(HVX_Vector src, HVX_UVector * dst_ptr, HVX
|
|||
return Q6_Vhf_equals_Vqf16(result);
|
||||
}
|
||||
|
||||
inline HVX_Vector hvx_nop(float scale) {
|
||||
return HVX_Vector();
|
||||
}
|
||||
|
||||
inline HVX_Vector hvx_passthru(HVX_Vector src, HVX_UVector *, HVX_Vector) {
|
||||
return src;
|
||||
}
|
||||
|
||||
} // namespace hexagon
|
||||
|
||||
#include "vec_ops.inl"
|
||||
|
||||
namespace hexagon {
|
||||
|
||||
inline void vec_scale_f32(const float * src, float scale, float * dst, size_t count) {
|
||||
using namespace hexagon::vec;
|
||||
vec_scale_impl<hvx_vec_scale_f32_f32, hvx_scale_f32, float>(src, scale, dst, count);
|
||||
}
|
||||
|
||||
inline void vec_mad_f32(const float * src, float scale, float * dst, size_t count) {
|
||||
using namespace hexagon::vec;
|
||||
vec_scale_impl<hvx_vec_mad_f32_f32, hvx_scale_f32, float>(src, scale, dst, count);
|
||||
}
|
||||
|
||||
inline void vec_cpy_f32(const float * src, float * dst, size_t count) {
|
||||
using namespace hexagon::vec;
|
||||
vec_scale_impl<hvx_passthru, hvx_nop, float>(src, 0, dst, count);
|
||||
}
|
||||
|
||||
inline void vec_zero_f32(float * src, size_t count) {
|
||||
using namespace hexagon::vec;
|
||||
vec_zero_impl<float>(src, count);
|
||||
}
|
||||
|
||||
inline void vec_scale_f16(const npu_device_fp16_t * src, float scale, npu_device_fp16_t * dst, size_t count) {
|
||||
using namespace hexagon::vec;
|
||||
vec_scale_impl<hvx_vec_scale_f16_f16, hvx_scale_f16, npu_device_fp16_t>(src, scale, dst, count);
|
||||
}
|
||||
|
||||
inline void vec_mad_f16(const npu_device_fp16_t * src, float scale, npu_device_fp16_t * dst, size_t count) {
|
||||
using namespace hexagon::vec;
|
||||
vec_scale_impl<hvx_vec_mad_f16_f16, hvx_scale_f16, npu_device_fp16_t>(src, scale, dst, count);
|
||||
}
|
||||
|
||||
inline void vec_cpy_f16(const npu_device_fp16_t * src, npu_device_fp16_t * dst, size_t count) {
|
||||
using namespace hexagon::vec;
|
||||
vec_scale_impl<hvx_passthru, hvx_nop, npu_device_fp16_t>(src, 0, dst, count);
|
||||
}
|
||||
|
||||
inline void vec_zero_f16(npu_device_fp16_t * src, size_t count) {
|
||||
using namespace hexagon::vec;
|
||||
vec_zero_impl<npu_device_fp16_t>(src, count);
|
||||
}
|
||||
|
||||
template <typename _TElem0, typename _TElem1>
|
||||
inline bool is_dot_product_aligned(const _TElem0 * src0, const _TElem1 * src1, size_t count) {
|
||||
static_assert(sizeof(_TElem0) <= sizeof(_TElem1), "src0 should be smaller than src1");
|
||||
|
||||
if (!hexagon::is_addr_aligned(src0) || !hexagon::is_addr_aligned(src1)) {
|
||||
if ((src0 && !hexagon::is_addr_aligned(src0)) || (src1 && !hexagon::is_addr_aligned(src1))) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@@ -335,26 +306,107 @@ inline bool is_dot_product_aligned(const _TElem0 * src0, const _TElem1 * src1, s
|
|||
return true;
|
||||
}
|
||||
|
||||
float vec_dot_product_f32_f32(const float * src0, const float * src1, size_t count);
|
||||
float vec_dot_product_aligned_f32_f32(const float * src0, const float * src1, size_t count);
|
||||
inline HVX_Vector vec_dot_product_vqf32_f32_f32(const float * src0, const float * src1, size_t count) {
|
||||
using namespace hexagon::vec;
|
||||
return vec_dot_product_impl<float, HVX_Vector, vec_mpy_qf32, vec_add_qf32, vec_reduction_qf32>(src0, src1, count);
|
||||
}
|
||||
|
||||
inline HVX_Vector vec_dot_product_aligned_vqf32_f32_f32(const float * src0, const float * src1, size_t count) {
|
||||
using namespace hexagon::vec;
|
||||
return vec_dot_product_aligned_impl<float, HVX_Vector, vec_mpy_qf32, vec_add_qf32, vec_reduction_qf32>(src0, src1,
|
||||
count);
|
||||
}
|
||||
|
||||
inline float vec_dot_product_f32_f32(const float * src0, const float * src1, size_t count) {
|
||||
using namespace hexagon::vec;
|
||||
return vec_dot_product_impl<float, float, vec_mpy_qf32, vec_add_qf32, vec_reduction_f32_qf32>(src0, src1, count);
|
||||
}
|
||||
|
||||
inline float vec_dot_product_aligned_f32_f32(const float * src0, const float * src1, size_t count) {
|
||||
using namespace hexagon::vec;
|
||||
return vec_dot_product_aligned_impl<float, float, vec_mpy_qf32, vec_add_qf32, vec_reduction_f32_qf32>(src0, src1,
|
||||
count);
|
||||
}
|
||||
|
||||
inline bool is_f32_f32_dot_product_aligned(const float * src0, const float * src1, size_t count) {
|
||||
return is_dot_product_aligned<float, float>(src0, src1, count);
|
||||
}
|
||||
|
||||
float vec_dot_product_f16_f16(const npu_device_fp16_t * src0, const npu_device_fp16_t * src1, size_t count);
|
||||
float vec_dot_product_aligned_f16_f16(const npu_device_fp16_t * src0, const npu_device_fp16_t * src1, size_t count);
|
||||
inline HVX_Vector vec_dot_product_vqf16_f16_f16(const npu_device_fp16_t * src0, const npu_device_fp16_t * src1,
|
||||
size_t count) {
|
||||
using namespace hexagon::vec;
|
||||
return vec_dot_product_impl<npu_device_fp16_t, HVX_Vector, vec_mpy_qf16, vec_add_qf16, vec_reduction_qf16>(
|
||||
src0, src1, count);
|
||||
}
|
||||
|
||||
inline HVX_Vector vec_dot_product_aligned_vqf16_f16_f16(const npu_device_fp16_t * src0, const npu_device_fp16_t * src1,
|
||||
size_t count) {
|
||||
using namespace hexagon::vec;
|
||||
return vec_dot_product_aligned_impl<npu_device_fp16_t, HVX_Vector, vec_mpy_qf16, vec_add_qf16, vec_reduction_qf16>(
|
||||
src0, src1, count);
|
||||
}
|
||||
|
||||
inline float vec_dot_product_f16_f16(const npu_device_fp16_t * src0, const npu_device_fp16_t * src1, size_t count) {
|
||||
using namespace hexagon::vec;
|
||||
return vec_dot_product_impl<npu_device_fp16_t, float, vec_mpy_qf16, vec_add_qf16, vec_reduction_qf16_f32>(
|
||||
src0, src1, count);
|
||||
}
|
||||
|
||||
inline float vec_dot_product_aligned_f16_f16(const npu_device_fp16_t * src0, const npu_device_fp16_t * src1,
|
||||
size_t count) {
|
||||
using namespace hexagon::vec;
|
||||
return vec_dot_product_aligned_impl<npu_device_fp16_t, float, vec_mpy_qf16, vec_add_qf16, vec_reduction_qf16_f32>(
|
||||
src0, src1, count);
|
||||
}
|
||||
|
||||
inline bool is_f16_f16_dot_product_aligned(const npu_device_fp16_t * src0, const npu_device_fp16_t * src1,
|
||||
size_t count) {
|
||||
return is_dot_product_aligned<npu_device_fp16_t, npu_device_fp16_t>(src0, src1, count);
|
||||
}
|
||||
|
||||
float vec_dot_product_f16_f32(const npu_device_fp16_t * src0, const float * src1, size_t count);
|
||||
float vec_dot_product_aligned_f16_f32(const npu_device_fp16_t * src0, const float * src1, size_t count);
|
||||
inline HVX_Vector vec_dot_product_vqf32_f16_f32(const npu_device_fp16_t * src0, const float * src1, size_t count) {
|
||||
using namespace hexagon::vec;
|
||||
return vec_dot_product_mixed_impl<npu_device_fp16_t, float, HVX_Vector, hvx_vsf_convert_vhf, vec_mpy_qf32,
|
||||
vec_add_qf32, vec_reduction_qf32>(src0, src1, count);
|
||||
}
|
||||
|
||||
inline HVX_Vector vec_dot_product_aligned_vqf32_f16_f32(const npu_device_fp16_t * src0, const float * src1,
|
||||
size_t count) {
|
||||
using namespace hexagon::vec;
|
||||
return vec_dot_product_mix_aligned_impl<npu_device_fp16_t, float, HVX_Vector, hvx_vsf_convert_vhf, vec_mpy_qf32,
|
||||
vec_add_qf32, vec_reduction_qf32>(src0, src1, count);
|
||||
}
|
||||
|
||||
inline float vec_dot_product_f16_f32(const npu_device_fp16_t * src0, const float * src1, size_t count) {
|
||||
using namespace hexagon::vec;
|
||||
return vec_dot_product_mixed_impl<npu_device_fp16_t, float, float, hvx_vsf_convert_vhf, vec_mpy_qf32, vec_add_qf32,
|
||||
vec_reduction_f32_qf32>(src0, src1, count);
|
||||
}
|
||||
|
||||
inline float vec_dot_product_aligned_f16_f32(const npu_device_fp16_t * src0, const float * src1, size_t count) {
|
||||
using namespace hexagon::vec;
|
||||
return vec_dot_product_mix_aligned_impl<npu_device_fp16_t, float, float, hvx_vsf_convert_vhf, vec_mpy_qf32,
|
||||
vec_add_qf32, vec_reduction_f32_qf32>(src0, src1, count);
|
||||
}
|
||||
|
||||
inline bool is_f16_f32_dot_product_aligned(const npu_device_fp16_t * src0, const float * src1, size_t count) {
|
||||
return is_dot_product_aligned<npu_device_fp16_t, float>(src0, src1, count);
|
||||
}
|
||||
|
||||
template <typename _TFunc> struct dot_func_traits {};
|
||||
|
||||
template <typename _TData, typename _TReturn> struct dot_func_traits<_TReturn (*)(_TData, _TData, size_t)> {
|
||||
using param_type = std::remove_const_t<std::remove_pointer_t<_TData>>;
|
||||
using return_type = _TReturn;
|
||||
};
|
||||
|
||||
template <auto _DotFunc, typename _TReturn = typename dot_func_traits<decltype(_DotFunc)>::return_type>
|
||||
_TReturn type_erase_dot_func(const void * src0, const void * src1, size_t count) {
|
||||
using param_type = typename dot_func_traits<decltype(_DotFunc)>::param_type;
|
||||
|
||||
auto * src0_typed = reinterpret_cast<const param_type *>(src0);
|
||||
auto * src1_typed = reinterpret_cast<const param_type *>(src1);
|
||||
return _DotFunc(src0_typed, src1_typed, count);
|
||||
}
|
||||
|
||||
} // namespace hexagon
|
||||
|
|
|
|||
|
|
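A usage sketch for the type-erased dispatch added above (illustrative only; the table and the call site are assumptions, and only kernels whose two inputs share a type match dot_func_traits). It assumes this header is included:

// Hypothetical dispatch table built from the same-typed wrappers in this header.
using erased_dot_fn = float (*)(const void *, const void *, size_t);

constexpr erased_dot_fn kRowDotFuncs[] = {
    hexagon::type_erase_dot_func<hexagon::vec_dot_product_f32_f32>,  // F32 rows
    hexagon::type_erase_dot_func<hexagon::vec_dot_product_f16_f16>,  // F16 rows
};

// Example call (indices and pointers are hypothetical):
//   float d = kRowDotFuncs[type_index](src0_row, src1_row, elem_count);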
@@ -0,0 +1,499 @@
|
|||
#pragma once
|
||||
|
||||
#include <hexagon_types.h>
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
#include "hexagon_npu.h"
|
||||
|
||||
namespace hexagon::vec {
|
||||
|
||||
template <typename _TElem, typename _TRet, HVX_Vector (*_MpyFunc)(HVX_Vector, HVX_Vector),
|
||||
HVX_Vector (*_AddFunc)(HVX_Vector, HVX_Vector), _TRet (*_ReduceFunc)(HVX_Vector)>
|
||||
inline _TRet vec_dot_product_impl(const _TElem * src0, const _TElem * src1, size_t count) {
|
||||
constexpr const size_t kElementsPerVector = hexagon::kBytesPerVector / sizeof(_TElem);
|
||||
|
||||
HVX_Vector * src0_vec_ptr = ((HVX_Vector *) src0);
|
||||
HVX_Vector * const src0_vec_ptr_end = ((HVX_Vector *) src0) + count / kElementsPerVector;
|
||||
HVX_Vector * src1_vec_ptr = ((HVX_Vector *) src1);
|
||||
HVX_Vector prev0 = *src0_vec_ptr++;
|
||||
HVX_Vector prev1 = *src1_vec_ptr++;
|
||||
HVX_Vector sum = Q6_V_vzero();
|
||||
|
||||
if (src0_vec_ptr_end - src0_vec_ptr > 1) {
|
||||
HVX_Vector sum0 = Q6_V_vzero();
|
||||
HVX_Vector sum1 = Q6_V_vzero();
|
||||
|
||||
do {
|
||||
HVX_VectorPair curr0 = reinterpret_cast<HVX_VectorPair *>(src0_vec_ptr)[0];
|
||||
HVX_VectorPair curr1 = reinterpret_cast<HVX_VectorPair *>(src1_vec_ptr)[0];
|
||||
|
||||
HVX_Vector l0 = Q6_V_valign_VVR(Q6_V_lo_W(curr0), prev0, (size_t) src0);
|
||||
HVX_Vector l1 = Q6_V_valign_VVR(Q6_V_lo_W(curr1), prev1, (size_t) src1);
|
||||
sum0 = _AddFunc(_MpyFunc(l0, l1), sum0);
|
||||
|
||||
HVX_Vector h0 = Q6_V_valign_VVR(Q6_V_hi_W(curr0), Q6_V_lo_W(curr0), (size_t) src0);
|
||||
HVX_Vector h1 = Q6_V_valign_VVR(Q6_V_hi_W(curr1), Q6_V_lo_W(curr1), (size_t) src1);
|
||||
sum1 = _AddFunc(_MpyFunc(h0, h1), sum1);
|
||||
|
||||
prev0 = Q6_V_hi_W(curr0);
|
||||
prev1 = Q6_V_hi_W(curr1);
|
||||
src0_vec_ptr += 2;
|
||||
src1_vec_ptr += 2;
|
||||
} while (src0_vec_ptr_end - src0_vec_ptr > 1);
|
||||
|
||||
sum = _AddFunc(sum0, sum1);
|
||||
}
|
||||
|
||||
if (src0_vec_ptr_end - src0_vec_ptr > 0) {
|
||||
HVX_Vector curr0 = *src0_vec_ptr++;
|
||||
HVX_Vector curr1 = *src1_vec_ptr++;
|
||||
HVX_Vector s0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0);
|
||||
HVX_Vector s1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1);
|
||||
prev0 = curr0;
|
||||
prev1 = curr1;
|
||||
|
||||
sum = _AddFunc(_MpyFunc(s0, s1), sum);
|
||||
}
|
||||
|
||||
const size_t leftover = count % kElementsPerVector;
|
||||
if ((src0_vec_ptr_end - ((HVX_Vector *) src0)) > 0) {
|
||||
// handle the last vector
|
||||
// see also:
|
||||
// https://github.com/UbiquitousLearning/mllm/blob/babf4410352ce8730824c87699c025a0d4ce3a6f/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/LLaMAMul.cpp#L147
|
||||
// or qualcomm sdk libs\qhl_hvx\src\qhblas_hvx\qhblas_hvx_aw_vector_add_ah.c
|
||||
bool should_fetch_src0 = leftover != 0 || !hexagon::is_addr_aligned(src0_vec_ptr);
|
||||
bool should_fetch_src1 = leftover != 0 || !hexagon::is_addr_aligned(src1_vec_ptr);
|
||||
HVX_Vector curr0 = should_fetch_src0 ? *src0_vec_ptr : prev0;
|
||||
HVX_Vector curr1 = should_fetch_src1 ? *src1_vec_ptr : prev1;
|
||||
src0_vec_ptr += should_fetch_src0 ? 1 : 0;
|
||||
src1_vec_ptr += should_fetch_src1 ? 1 : 0;
|
||||
HVX_Vector s0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0);
|
||||
HVX_Vector s1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1);
|
||||
prev0 = curr0;
|
||||
prev1 = curr1;
|
||||
|
||||
sum = _AddFunc(_MpyFunc(s0, s1), sum);
|
||||
}
|
||||
|
||||
if (leftover > 0) {
|
||||
// handle the leftover elements
|
||||
const size_t leftover_bytes = leftover * sizeof(_TElem);
|
||||
HVX_Vector curr0 = (leftover_bytes + hexagon::unaligned_bytes(src0_vec_ptr) > hexagon::kBytesPerVector) ?
|
||||
*src0_vec_ptr :
|
||||
prev0;
|
||||
curr0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0);
|
||||
|
||||
HVX_Vector curr1 = (leftover_bytes + hexagon::unaligned_bytes(src1_vec_ptr) > hexagon::kBytesPerVector) ?
|
||||
*src1_vec_ptr :
|
||||
prev1;
|
||||
curr1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1);
|
||||
|
||||
sum = _AddFunc(Q6_V_valign_VVR(_MpyFunc(curr0, curr1), Q6_V_vzero(), leftover_bytes), sum);
|
||||
}
|
||||
|
||||
return _ReduceFunc(sum);
|
||||
}
|
||||
|
||||
template <typename _TElem, typename _TRet, HVX_Vector (*_MpyFunc)(HVX_Vector, HVX_Vector),
|
||||
HVX_Vector (*_AddFunc)(HVX_Vector, HVX_Vector), _TRet (*_ReduceFunc)(HVX_Vector)>
|
||||
inline _TRet vec_dot_product_aligned_impl(const _TElem * src0, const _TElem * src1, size_t count) {
|
||||
constexpr const size_t kElementsPerVector = hexagon::kBytesPerVector / sizeof(_TElem);
|
||||
|
||||
HVX_Vector * src0_vec_ptr = ((HVX_Vector *) src0);
|
||||
HVX_Vector * const src0_vec_ptr_end = ((HVX_Vector *) src0) + count / kElementsPerVector;
|
||||
HVX_Vector * src1_vec_ptr = ((HVX_Vector *) src1);
|
||||
HVX_Vector sum = Q6_V_vzero();
|
||||
|
||||
{
|
||||
HVX_Vector sum0 = Q6_V_vzero();
|
||||
HVX_Vector sum1 = Q6_V_vzero();
|
||||
if (src0_vec_ptr_end - src0_vec_ptr > 3) {
|
||||
HVX_Vector sum2 = Q6_V_vzero();
|
||||
HVX_Vector sum3 = Q6_V_vzero();
|
||||
|
||||
do {
|
||||
HVX_VectorPair curr00 = reinterpret_cast<HVX_VectorPair *>(src0_vec_ptr)[0];
|
||||
HVX_VectorPair curr10 = reinterpret_cast<HVX_VectorPair *>(src1_vec_ptr)[0];
|
||||
sum0 = _AddFunc(_MpyFunc(Q6_V_lo_W(curr00), Q6_V_lo_W(curr10)), sum0);
|
||||
sum1 = _AddFunc(_MpyFunc(Q6_V_hi_W(curr00), Q6_V_hi_W(curr10)), sum1);
|
||||
|
||||
HVX_VectorPair curr01 = reinterpret_cast<HVX_VectorPair *>(src0_vec_ptr)[1];
|
||||
HVX_VectorPair curr11 = reinterpret_cast<HVX_VectorPair *>(src1_vec_ptr)[1];
|
||||
sum2 = _AddFunc(_MpyFunc(Q6_V_lo_W(curr01), Q6_V_lo_W(curr11)), sum2);
|
||||
sum3 = _AddFunc(_MpyFunc(Q6_V_hi_W(curr01), Q6_V_hi_W(curr11)), sum3);
|
||||
|
||||
src0_vec_ptr += 4;
|
||||
src1_vec_ptr += 4;
|
||||
} while (src0_vec_ptr_end - src0_vec_ptr > 3);
|
||||
|
||||
sum0 = _AddFunc(sum2, sum0);
|
||||
sum1 = _AddFunc(sum3, sum1);
|
||||
}
|
||||
|
||||
if (src0_vec_ptr_end - src0_vec_ptr > 1) {
|
||||
HVX_VectorPair curr0 = reinterpret_cast<HVX_VectorPair *>(src0_vec_ptr)[0];
|
||||
HVX_VectorPair curr1 = reinterpret_cast<HVX_VectorPair *>(src1_vec_ptr)[0];
|
||||
src0_vec_ptr += 2;
|
||||
src1_vec_ptr += 2;
|
||||
|
||||
sum0 = _AddFunc(_MpyFunc(Q6_V_lo_W(curr0), Q6_V_lo_W(curr1)), sum0);
|
||||
sum1 = _AddFunc(_MpyFunc(Q6_V_hi_W(curr0), Q6_V_hi_W(curr1)), sum1);
|
||||
}
|
||||
|
||||
sum = _AddFunc(sum0, sum1);
|
||||
}
|
||||
|
||||
if (src0_vec_ptr_end - src0_vec_ptr > 0) {
|
||||
HVX_Vector curr0 = src0_vec_ptr[0];
|
||||
HVX_Vector curr1 = src1_vec_ptr[0];
|
||||
|
||||
sum = _AddFunc(_MpyFunc(curr0, curr1), sum);
|
||||
}
|
||||
|
||||
return _ReduceFunc(sum);
|
||||
}
|
||||
|
||||
inline HVX_Vector vec_mpy_qf32(HVX_Vector src0, HVX_Vector src1) {
|
||||
return Q6_Vqf32_vmpy_VsfVsf(src0, src1);
|
||||
}
|
||||
|
||||
inline HVX_Vector vec_add_qf32(HVX_Vector sum, HVX_Vector result) {
|
||||
return Q6_Vqf32_vadd_Vqf32Vqf32(sum, result);
|
||||
}
|
||||
|
||||
inline HVX_Vector vec_mpy_qf16(HVX_Vector src0, HVX_Vector src1) {
|
||||
return Q6_Vqf16_vmpy_VhfVhf(src0, src1);
|
||||
}
|
||||
|
||||
inline HVX_Vector vec_add_qf16(HVX_Vector sum, HVX_Vector result) {
|
||||
return Q6_Vqf16_vadd_Vqf16Vqf16(sum, result);
|
||||
}
|
||||
|
||||
template <typename _TElem0, typename _TElem1, typename _TRet, HVX_Vector_Dual (*_ExpandFunc)(HVX_Vector, HVX_Vector),
|
||||
HVX_Vector (*_MpyFunc)(HVX_Vector, HVX_Vector), HVX_Vector (*_AddFunc)(HVX_Vector, HVX_Vector),
|
||||
_TRet (*_ReduceFunc)(HVX_Vector)>
|
||||
inline _TRet vec_dot_product_mixed_impl(const _TElem0 * src0, const _TElem1 * src1, size_t count) {
|
||||
static_assert(sizeof(_TElem0) < sizeof(_TElem1), "Element size mismatch: _TElem0 must be smaller than _TElem1");
|
||||
static_assert((sizeof(_TElem1) / sizeof(_TElem0)) == 2,
|
||||
"Element size mismatch: _TElem1 must be twice the size of _TElem0");
|
||||
static_assert((sizeof(_TElem1) % sizeof(_TElem0)) == 0,
|
||||
"Element size mismatch: _TElem1 must be a multiple of _TElem0");
|
||||
|
||||
constexpr const size_t kElementsPerVector0 = hexagon::kBytesPerVector / sizeof(_TElem0);
|
||||
constexpr const size_t kElementsPerVector1 = hexagon::kBytesPerVector / sizeof(_TElem1);
|
||||
|
||||
constexpr const __fp16 kOne = 1.0f;
|
||||
const HVX_Vector kOneV = Q6_Vh_vsplat_R(reinterpret_cast<const uint16_t &>(kOne));
|
||||
|
||||
const _TElem0 * const src0_ptr_end = src0 + count;
|
||||
HVX_Vector * src0_vec_ptr = ((HVX_Vector *) src0);
|
||||
HVX_Vector * src1_vec_ptr = ((HVX_Vector *) src1);
|
||||
HVX_Vector * const src1_vec_ptr_end = ((HVX_Vector *) src1) + count / kElementsPerVector1;
|
||||
HVX_Vector prev0 = *src0_vec_ptr++;
|
||||
HVX_Vector prev1 = *src1_vec_ptr++;
|
||||
HVX_Vector sum = Q6_V_vzero();
|
||||
|
||||
if (src1_vec_ptr_end - src1_vec_ptr > 1) {
|
||||
HVX_Vector sum0 = Q6_V_vzero();
|
||||
HVX_Vector sum1 = Q6_V_vzero();
|
||||
|
||||
do {
|
||||
HVX_Vector curr0 = src0_vec_ptr[0];
|
||||
HVX_VectorPair curr1 = reinterpret_cast<HVX_VectorPair *>(src1_vec_ptr)[0];
|
||||
|
||||
HVX_Vector s0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0);
|
||||
HVX_Vector_Dual s0_pair = _ExpandFunc(s0, kOneV);
|
||||
|
||||
HVX_Vector l1 = Q6_V_valign_VVR(Q6_V_lo_W(curr1), prev1, (size_t) src1);
|
||||
sum0 = _AddFunc(_MpyFunc(s0_pair.first, l1), sum0);
|
||||
|
||||
HVX_Vector h1 = Q6_V_valign_VVR(Q6_V_hi_W(curr1), Q6_V_lo_W(curr1), (size_t) src1);
|
||||
sum1 = _AddFunc(_MpyFunc(s0_pair.second, h1), sum1);
|
||||
|
||||
prev0 = curr0;
|
||||
prev1 = Q6_V_hi_W(curr1);
|
||||
src0_vec_ptr++;
|
||||
src1_vec_ptr += 2;
|
||||
} while (src1_vec_ptr_end - src1_vec_ptr > 1);
|
||||
|
||||
sum = _AddFunc(sum0, sum1);
|
||||
}
|
||||
|
||||
const size_t leftover1 = count % kElementsPerVector1;
|
||||
if ((src1_vec_ptr_end - ((HVX_Vector *) src1)) > 0) {
|
||||
// handle the last vector
|
||||
const bool should_fetch_src0 =
|
||||
reinterpret_cast<const _TElem0 *>(hexagon::align_down(src0_vec_ptr)) < src0_ptr_end;
|
||||
HVX_Vector curr0 = should_fetch_src0 ? *src0_vec_ptr : prev0;
|
||||
src0_vec_ptr += should_fetch_src0 ? 1 : 0;
|
||||
|
||||
HVX_Vector s0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0);
|
||||
HVX_Vector_Dual s0_pair = _ExpandFunc(s0, kOneV);
|
||||
|
||||
const bool has_remaining_src1_vector = src1_vec_ptr_end - src1_vec_ptr > 0;
|
||||
if (has_remaining_src1_vector) {
|
||||
HVX_Vector curr1 = *src1_vec_ptr++;
|
||||
HVX_Vector s1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1);
|
||||
sum = _AddFunc(_MpyFunc(s0_pair.first, s1), sum);
|
||||
prev1 = curr1;
|
||||
}
|
||||
|
||||
bool should_fetch_src1 = leftover1 != 0 || !hexagon::is_addr_aligned(src1_vec_ptr);
|
||||
HVX_Vector curr1 = should_fetch_src1 ? *src1_vec_ptr : prev1;
|
||||
src1_vec_ptr += should_fetch_src1 ? 1 : 0;
|
||||
HVX_Vector s1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1);
|
||||
prev0 = curr0;
|
||||
prev1 = curr1;
|
||||
|
||||
sum = _AddFunc(_MpyFunc(has_remaining_src1_vector ? s0_pair.second : s0_pair.first, s1), sum);
|
||||
}
|
||||
|
||||
if (leftover1 > 0) {
|
||||
// handle the leftover elements
|
||||
const size_t leftover0 = count % kElementsPerVector0;
|
||||
const size_t leftover_bytes1 = leftover1 * sizeof(_TElem1);
|
||||
HVX_Vector curr0 =
|
||||
reinterpret_cast<const _TElem0 *>(hexagon::align_down(src0_vec_ptr)) < src0_ptr_end ? *src0_vec_ptr : prev0;
|
||||
curr0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0);
|
||||
|
||||
HVX_Vector curr1 = (leftover_bytes1 + hexagon::unaligned_bytes(src1_vec_ptr) > hexagon::kBytesPerVector) ?
|
||||
*src1_vec_ptr :
|
||||
prev1;
|
||||
curr1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1);
|
||||
|
||||
HVX_Vector_Dual curr0_pair = _ExpandFunc(curr0, kOneV);
|
||||
|
||||
curr0 = leftover1 == leftover0 ? curr0_pair.first : curr0_pair.second;
|
||||
sum = _AddFunc(Q6_V_valign_VVR(_MpyFunc(curr0, curr1), Q6_V_vzero(), leftover_bytes1), sum);
|
||||
}
|
||||
|
||||
return _ReduceFunc(sum);
|
||||
}
|
||||
|
||||
template <typename _TElem0, typename _TElem1, typename _TRet, HVX_Vector_Dual (*_ExpandFunc)(HVX_Vector, HVX_Vector),
|
||||
HVX_Vector (*_MpyFunc)(HVX_Vector, HVX_Vector), HVX_Vector (*_AddFunc)(HVX_Vector, HVX_Vector),
|
||||
_TRet (*_ReduceFunc)(HVX_Vector)>
|
||||
inline _TRet vec_dot_product_mix_aligned_impl(const _TElem0 * src0, const _TElem1 * src1, size_t count) {
|
||||
static_assert(sizeof(_TElem0) < sizeof(_TElem1), "Element size mismatch: _TElem0 must be smaller than _TElem1");
|
||||
static_assert((sizeof(_TElem1) / sizeof(_TElem0)) == 2,
|
||||
"Element size mismatch: _TElem1 must be twice the size of _TElem0");
|
||||
static_assert((sizeof(_TElem1) % sizeof(_TElem0)) == 0,
|
||||
"Element size mismatch: _TElem1 must be a multiple of _TElem0");
|
||||
|
||||
constexpr const size_t kElementsPerVector1 = hexagon::kBytesPerVector / sizeof(_TElem1);
|
||||
|
||||
constexpr const __fp16 kOne = 1.0f;
|
||||
const HVX_Vector kOneV = Q6_Vh_vsplat_R(reinterpret_cast<const uint16_t &>(kOne));
|
||||
|
||||
HVX_Vector * src0_vec_ptr = ((HVX_Vector *) src0);
|
||||
HVX_Vector * src1_vec_ptr = ((HVX_Vector *) src1);
|
||||
HVX_Vector * const src1_vec_ptr_end = ((HVX_Vector *) src1) + count / kElementsPerVector1;
|
||||
HVX_Vector sum0 = Q6_V_vzero();
|
||||
HVX_Vector sum1 = Q6_V_vzero();
|
||||
|
||||
if (src1_vec_ptr_end - src1_vec_ptr > 3) {
|
||||
HVX_Vector sum2 = Q6_V_vzero();
|
||||
HVX_Vector sum3 = Q6_V_vzero();
|
||||
|
||||
do {
|
||||
HVX_VectorPair curr0 = reinterpret_cast<HVX_VectorPair *>(src0_vec_ptr)[0];
|
||||
HVX_Vector_Dual curr00 = _ExpandFunc(Q6_V_lo_W(curr0), kOneV);
|
||||
HVX_VectorPair curr10 = reinterpret_cast<HVX_VectorPair *>(src1_vec_ptr)[0];
|
||||
sum0 = _AddFunc(_MpyFunc(curr00.first, Q6_V_lo_W(curr10)), sum0);
|
||||
sum1 = _AddFunc(_MpyFunc(curr00.second, Q6_V_hi_W(curr10)), sum1);
|
||||
|
||||
HVX_Vector_Dual curr01 = _ExpandFunc(Q6_V_hi_W(curr0), kOneV);
|
||||
HVX_VectorPair curr11 = reinterpret_cast<HVX_VectorPair *>(src1_vec_ptr)[1];
|
||||
sum2 = _AddFunc(_MpyFunc(curr01.first, Q6_V_lo_W(curr11)), sum2);
|
||||
sum3 = _AddFunc(_MpyFunc(curr01.second, Q6_V_hi_W(curr11)), sum3);
|
||||
|
||||
src0_vec_ptr += 2;
|
||||
src1_vec_ptr += 4;
|
||||
} while (src1_vec_ptr_end - src1_vec_ptr > 3);
|
||||
|
||||
sum0 = _AddFunc(sum0, sum2);
|
||||
sum1 = _AddFunc(sum1, sum3);
|
||||
}
|
||||
|
||||
if (src1_vec_ptr_end - src1_vec_ptr > 1) {
|
||||
HVX_Vector curr0 = src0_vec_ptr[0];
|
||||
HVX_Vector_Dual s0_pair = _ExpandFunc(curr0, kOneV);
|
||||
|
||||
HVX_VectorPair curr1 = reinterpret_cast<HVX_VectorPair *>(src1_vec_ptr)[0];
|
||||
sum0 = _AddFunc(_MpyFunc(s0_pair.first, Q6_V_lo_W(curr1)), sum0);
|
||||
sum1 = _AddFunc(_MpyFunc(s0_pair.second, Q6_V_hi_W(curr1)), sum1);
|
||||
}
|
||||
|
||||
return _ReduceFunc(_AddFunc(sum0, sum1));
|
||||
}
|
||||
|
||||
template <HVX_Vector (*_Func)(HVX_Vector, HVX_UVector *, HVX_Vector), HVX_Vector (*_FuncScaleConvert)(float),
|
||||
typename _TParam>
|
||||
inline void vec_scale_impl(const _TParam * src, float scale, _TParam * dst, size_t count) {
|
||||
constexpr const size_t kElementsPerVector = hexagon::kBytesPerVector / sizeof(_TParam);
|
||||
|
||||
HVX_Vector * src_vec_ptr = ((HVX_Vector *) src);
|
||||
HVX_Vector * const src_vec_end = ((HVX_Vector *) src) + (count / kElementsPerVector);
|
||||
HVX_UVector * dst_vec_ptr = ((HVX_UVector *) dst); // TODO: opt the unaligned case?
|
||||
HVX_Vector prev = *src_vec_ptr++;
|
||||
const size_t leftover = count % kElementsPerVector;
|
||||
|
||||
HVX_Vector scale_vec = _FuncScaleConvert(scale);
|
||||
|
||||
while (src_vec_end - src_vec_ptr > 1) {
|
||||
HVX_VectorPair curr = reinterpret_cast<HVX_VectorPair *>(src_vec_ptr)[0];
|
||||
src_vec_ptr += 2;
|
||||
|
||||
HVX_Vector lo = Q6_V_valign_VVR(Q6_V_lo_W(curr), prev, (size_t) src);
|
||||
dst_vec_ptr[0] = _Func(lo, dst_vec_ptr, scale_vec);
|
||||
|
||||
HVX_Vector hi = Q6_V_valign_VVR(Q6_V_hi_W(curr), Q6_V_lo_W(curr), (size_t) src);
|
||||
dst_vec_ptr[1] = _Func(hi, dst_vec_ptr + 1, scale_vec);
|
||||
|
||||
dst_vec_ptr += 2;
|
||||
prev = Q6_V_hi_W(curr);
|
||||
}
|
||||
|
||||
if (src_vec_end - src_vec_ptr > 0) {
|
||||
HVX_Vector curr = *src_vec_ptr++;
|
||||
HVX_Vector s0 = Q6_V_valign_VVR(curr, prev, (size_t) src);
|
||||
dst_vec_ptr[0] = _Func(s0, dst_vec_ptr, scale_vec);
|
||||
dst_vec_ptr++;
|
||||
prev = curr;
|
||||
}
|
||||
|
||||
if ((src_vec_end - ((HVX_Vector *) src)) > 0) {
|
||||
// handle the last vector
|
||||
bool should_fetch_next = leftover == 0 && hexagon::is_addr_aligned(src_vec_ptr);
|
||||
HVX_Vector curr = should_fetch_next ? prev : *src_vec_ptr;
|
||||
src_vec_ptr = should_fetch_next ? src_vec_ptr : src_vec_ptr + 1;
|
||||
HVX_Vector s0 = Q6_V_valign_VVR(curr, prev, (size_t) src);
|
||||
dst_vec_ptr[0] = _Func(s0, dst_vec_ptr, scale_vec);
|
||||
dst_vec_ptr++;
|
||||
prev = curr;
|
||||
}
|
||||
|
||||
if (leftover > 0) {
|
||||
// handle the leftover elements
|
||||
const size_t leftover_bytes = leftover * sizeof(_TParam);
|
||||
HVX_Vector curr =
|
||||
(leftover_bytes + hexagon::unaligned_bytes(src_vec_ptr) > hexagon::kBytesPerVector) ? *src_vec_ptr : prev;
|
||||
curr = Q6_V_valign_VVR(curr, prev, (size_t) src);
|
||||
q6op_vstu_variable_ARV(dst_vec_ptr, leftover_bytes, _Func(curr, dst_vec_ptr, scale_vec));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename _TData> inline void vec_zero_impl(_TData * src, size_t count) {
|
||||
constexpr const size_t kElementsPerVector = hexagon::kBytesPerVector / sizeof(_TData);
|
||||
|
||||
HVX_UVector * src_vec_ptr = ((HVX_UVector *) src);
|
||||
HVX_UVector * const src_vec_end = ((HVX_UVector *) src) + (count / kElementsPerVector);
|
||||
|
||||
while (src_vec_end - src_vec_ptr > 1) {
|
||||
src_vec_ptr[0] = Q6_V_vzero();
|
||||
src_vec_ptr[1] = Q6_V_vzero();
|
||||
src_vec_ptr += 2;
|
||||
}
|
||||
|
||||
if (src_vec_end - src_vec_ptr > 0) {
|
||||
src_vec_ptr[0] = Q6_V_vzero();
|
||||
src_vec_ptr++;
|
||||
}
|
||||
|
||||
const size_t leftover = count % kElementsPerVector;
|
||||
if (leftover > 0) {
|
||||
// handle the leftover elements
|
||||
const size_t leftover_bytes = leftover * sizeof(_TData);
|
||||
q6op_vstu_variable_ARV(src_vec_ptr, leftover_bytes, Q6_V_vzero());
|
||||
}
|
||||
}
|
||||
|
||||
template <HVX_Vector (*_OpBinaryTransform)(HVX_Vector, HVX_Vector), typename _TyData>
|
||||
inline void vec_trans_op_impl(const _TyData * src0, const _TyData * src1, size_t count, _TyData * dst) {
|
||||
constexpr const size_t kElementsPerVector = hexagon::kBytesPerVector / sizeof(_TyData);
|
||||
|
||||
HVX_Vector * src0_vec_ptr = ((HVX_Vector *) src0);
|
||||
HVX_Vector * const src0_vec_ptr_end = ((HVX_Vector *) src0) + count / kElementsPerVector;
|
||||
HVX_Vector * src1_vec_ptr = ((HVX_Vector *) src1);
|
||||
HVX_Vector * dst_vec_ptr = ((HVX_Vector *) dst); // framework will ensure the dst is aligned
|
||||
HVX_Vector prev0 = *src0_vec_ptr++;
|
||||
HVX_Vector prev1 = *src1_vec_ptr++;
|
||||
|
||||
{
|
||||
while (src0_vec_ptr_end - src0_vec_ptr > 1) {
|
||||
HVX_VectorPair curr0 = reinterpret_cast<HVX_VectorPair *>(src0_vec_ptr)[0];
|
||||
HVX_VectorPair curr1 = reinterpret_cast<HVX_VectorPair *>(src1_vec_ptr)[0];
|
||||
|
||||
HVX_Vector l0 = Q6_V_valign_VVR(Q6_V_lo_W(curr0), prev0, (size_t) src0);
|
||||
HVX_Vector l1 = Q6_V_valign_VVR(Q6_V_lo_W(curr1), prev1, (size_t) src1);
|
||||
dst_vec_ptr[0] = _OpBinaryTransform(l0, l1);
|
||||
|
||||
HVX_Vector h0 = Q6_V_valign_VVR(Q6_V_hi_W(curr0), Q6_V_lo_W(curr0), (size_t) src0);
|
||||
HVX_Vector h1 = Q6_V_valign_VVR(Q6_V_hi_W(curr1), Q6_V_lo_W(curr1), (size_t) src1);
|
||||
dst_vec_ptr[1] = _OpBinaryTransform(h0, h1);
|
||||
|
||||
prev0 = Q6_V_hi_W(curr0);
|
||||
prev1 = Q6_V_hi_W(curr1);
|
||||
src0_vec_ptr += 2;
|
||||
src1_vec_ptr += 2;
|
||||
dst_vec_ptr += 2;
|
||||
}
|
||||
}
|
||||
|
||||
if (src0_vec_ptr_end - src0_vec_ptr > 0) {
|
||||
HVX_Vector curr0 = *src0_vec_ptr++;
|
||||
HVX_Vector s0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0);
|
||||
|
||||
HVX_Vector curr1 = *src1_vec_ptr++;
|
||||
HVX_Vector s1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1);
|
||||
|
||||
dst_vec_ptr[0] = _OpBinaryTransform(s0, s1);
|
||||
|
||||
prev0 = curr0;
|
||||
prev1 = curr1;
|
||||
dst_vec_ptr++;
|
||||
}
|
||||
|
||||
const size_t leftover = count % kElementsPerVector;
|
||||
if ((src0_vec_ptr_end - ((HVX_Vector *) src0)) > 0) {
|
||||
// handle the last vector
|
||||
// see also:
|
||||
// https://github.com/UbiquitousLearning/mllm/blob/babf4410352ce8730824c87699c025a0d4ce3a6f/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/LLaMAMul.cpp#L147
|
||||
// or qualcomm sdk libs\qhl_hvx\src\qhblas_hvx\qhblas_hvx_aw_vector_add_ah.c
|
||||
bool should_fetch_src0 = leftover != 0 || !hexagon::is_addr_aligned(src0_vec_ptr);
|
||||
bool should_fetch_src1 = leftover != 0 || !hexagon::is_addr_aligned(src1_vec_ptr);
|
||||
|
||||
HVX_Vector curr0 = should_fetch_src0 ? *src0_vec_ptr : prev0;
|
||||
HVX_Vector s0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0);
|
||||
|
||||
HVX_Vector curr1 = should_fetch_src1 ? *src1_vec_ptr : prev1;
|
||||
HVX_Vector s1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1);
|
||||
|
||||
dst_vec_ptr[0] = _OpBinaryTransform(s0, s1);
|
||||
|
||||
src0_vec_ptr += should_fetch_src0 ? 1 : 0;
|
||||
src1_vec_ptr += should_fetch_src1 ? 1 : 0;
|
||||
prev0 = curr0;
|
||||
prev1 = curr1;
|
||||
dst_vec_ptr++;
|
||||
}
|
||||
|
||||
if (leftover > 0) {
|
||||
// handle the leftover elements
|
||||
const size_t leftover_bytes = leftover * sizeof(_TyData);
|
||||
HVX_Vector curr0 = (leftover_bytes + hexagon::unaligned_bytes(src0_vec_ptr) > hexagon::kBytesPerVector) ?
|
||||
*src0_vec_ptr :
|
||||
prev0;
|
||||
curr0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0);
|
||||
|
||||
HVX_Vector curr1 = (leftover_bytes + hexagon::unaligned_bytes(src1_vec_ptr) > hexagon::kBytesPerVector) ?
|
||||
*src1_vec_ptr :
|
||||
prev1;
|
||||
curr1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1);
|
||||
|
||||
q6op_vstu_variable_ARV(dst_vec_ptr, leftover_bytes, _OpBinaryTransform(curr0, curr1));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace hexagon::vec
|
||||
|
|
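A short sketch of how the primitives in vec_ops.inl are typically reached through the inline wrappers declared in vec_ops.hpp (the helper and buffer names below are hypothetical, not part of the patch; vec_ops.hpp is assumed to be included):

inline void prepare_rows(const float * src_row, float * dst_row, float * scratch, size_t n) {
    hexagon::vec_scale_f32(src_row, 0.5f, dst_row, n);  // dst_row[i] = src_row[i] * 0.5f
    hexagon::vec_mad_f32(src_row, 2.0f, dst_row, n);    // multiply-accumulate counterpart (the hvx_vec_mad_* callbacks also read dst_ptr)
    hexagon::vec_zero_f32(scratch, n);                  // unaligned-safe zeroing; the tail uses a masked store
}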
@@ -3,6 +3,7 @@
|
|||
#include <rpcmem.h>
|
||||
|
||||
#include "host_device.hpp"
|
||||
#include "profiler.hpp"
|
||||
#include "tensor.hpp"
|
||||
|
||||
namespace {
|
||||
|
|
@@ -78,6 +79,8 @@ void backend_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
|
|||
void backend_buffer_reset(ggml_backend_buffer_t buffer) {
|
||||
auto * buffer_obj = get_buffer_object(buffer);
|
||||
GGML_ASSERT(buffer_obj != nullptr);
|
||||
|
||||
SCOPED_PERFORMANCE_TRACKER("[hexagon-npu][%p]backend_buffer_reset", (void *) buffer_obj);
|
||||
buffer_obj->clear_tensors();
|
||||
}
|
||||
|
||||
|
|
@@ -199,8 +202,8 @@ std::shared_ptr<host_tensor> host_buffer::init_tensor(ggml_tensor * tensor, remo
|
|||
}
|
||||
|
||||
void host_buffer::clear_tensors() {
|
||||
_tensors.clear();
|
||||
LOG_DEBUG("clear host_buffer(%p) tensors\n", (void *) _data);
|
||||
host_tensor::destroy_tensors(_tensors);
|
||||
}
|
||||
|
||||
host_buffer_type::host_buffer_type(ggml_backend_dev_t dev, const std::string & name, common::rpc_mem_ptr rpc_mem) :
|
||||
|
|
|
|||
|
|
@@ -57,10 +57,10 @@ bool host_graph::update(ggml_cgraph * cgraph) {
|
|||
_tensor_handles.push_back(tensor_obj->get_device_tensor_handle());
|
||||
_tensor_update_configs.push_back(tensor_obj->update_hosts_params_only(node));
|
||||
|
||||
PROFILER_LOG_DEBUG("node[%d]%s(%s), addr(%p), %s_%ldx%ldx%ldx%ld, handle(%p)\n", i, ggml_get_name(node),
|
||||
ggml_op_desc(node), (void *) tensor_obj, ggml_type_name(node->type),
|
||||
(long) tensor_obj->get_ne(0), (long) tensor_obj->get_ne(1), (long) tensor_obj->get_ne(2),
|
||||
(long) tensor_obj->get_ne(3), (void *) tensor_obj->get_device_tensor_handle());
|
||||
PROFILER_LOG_DEBUG("node[%d]%s(%s), addr(%p), %ldx%ldx%ldx%ld%s, handle(%p)\n", i, ggml_get_name(node),
|
||||
ggml_op_desc(node), (void *) tensor_obj, (long) tensor_obj->get_ne(0),
|
||||
(long) tensor_obj->get_ne(1), (long) tensor_obj->get_ne(2), (long) tensor_obj->get_ne(3),
|
||||
ggml_type_name(node->type), (void *) tensor_obj->get_device_tensor_handle());
|
||||
}
|
||||
|
||||
GGML_ASSERT(_tensor_handles.size() == _tensor_update_configs.size());
|
||||
|
|
|
|||
|
|
@@ -22,7 +22,7 @@ class host_graph {
|
|||
|
||||
private:
|
||||
remote_handle64 _device_handle = 0;
|
||||
npu_device_graph_handle_t _graph_handle = 0;
|
||||
npu_device_graph_handle_t _graph_handle = npu_device_INVALID_DEVICE_GRAPH_HANDLE;
|
||||
std::vector<npu_device_tensor_handle_t> _tensor_handles;
|
||||
std::vector<npu_device_tensor_update_config> _tensor_update_configs;
|
||||
|
||||
|
|
|
|||
|
|
@@ -1,6 +1,8 @@
|
|||
#pragma once
|
||||
|
||||
#include <list>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
#include "common.hpp"
|
||||
#include "ggml-impl.h"
|
||||
|
|
@@ -42,7 +44,7 @@ class host_tensor {
|
|||
auto status = npu_device_tensor_init(_device_handle, &_info, &_device_tensor_handle);
|
||||
if (status != AEE_SUCCESS) {
|
||||
LOG_ERROR("Failed to init tensor: %d", (int) status);
|
||||
_device_tensor_handle = 0;
|
||||
_device_tensor_handle = npu_device_INVALID_DEVICE_TENSOR_HANDLE;
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
@@ -66,6 +68,27 @@ class host_tensor {
|
|||
}
|
||||
}
|
||||
|
||||
static void destroy_tensors(std::list<std::shared_ptr<host_tensor>> & tensors) {
|
||||
std::vector<npu_device_tensor_handle_t> handles;
|
||||
|
||||
handles.reserve(tensors.size());
|
||||
remote_handle64 device_handle = 0;
|
||||
|
||||
for (auto tensor : tensors) {
|
||||
if (tensor && tensor->_device_tensor_handle != npu_device_INVALID_DEVICE_TENSOR_HANDLE) {
|
||||
handles.push_back(tensor->_device_tensor_handle);
|
||||
tensor->_device_tensor_handle = npu_device_INVALID_DEVICE_TENSOR_HANDLE; // prevent double free
|
||||
device_handle = tensor->_device_handle;
|
||||
}
|
||||
}
|
||||
|
||||
if (!handles.empty()) {
|
||||
npu_device_tensors_free(device_handle, handles.data(), handles.size());
|
||||
}
|
||||
|
||||
tensors.clear();
|
||||
}
|
||||
|
||||
npu_device_tensor_handle_t get_device_tensor_handle() const { return _device_tensor_handle; }
|
||||
|
||||
void update_params(ggml_tensor * ggml_tensor) {
|
||||
|
|
@@ -157,7 +180,7 @@ class host_tensor {
|
|||
return _info_update;
|
||||
}
|
||||
|
||||
bool is_valid() const { return _device_tensor_handle != 0; }
|
||||
bool is_valid() const { return _device_tensor_handle != npu_device_INVALID_DEVICE_TENSOR_HANDLE; }
|
||||
|
||||
int64_t get_ne(size_t index) const {
|
||||
if (index >= DEVICE_TENSOR_MAX_DIMS) {
|
||||
|
|
|
|||
|
|
@@ -20,6 +20,9 @@ interface npu_device : remote_handle64{
|
|||
typedef uint64_t tensor_handle_t;
|
||||
typedef uint64_t graph_handle_t;
|
||||
|
||||
const graph_handle_t INVALID_DEVICE_GRAPH_HANDLE = 0;
|
||||
const tensor_handle_t INVALID_DEVICE_TENSOR_HANDLE = 0;
|
||||
|
||||
typedef uint16_t fp16_t;
|
||||
|
||||
struct block_q4_0 {
|
||||
|
|
@@ -107,6 +110,10 @@ interface npu_device : remote_handle64{
|
|||
in tensor_handle_t tensor_handle
|
||||
);
|
||||
|
||||
AEEResult tensors_free(
|
||||
in sequence<tensor_handle_t> tensor_handles
|
||||
);
|
||||
|
||||
AEEResult graph_init(
|
||||
rout graph_handle_t graph_handle
|
||||
);
|
||||