#include "vec_ops.hpp"
|
|
|
|
#include "util.hpp"
|
|
|
|
namespace {
|
|
|
|
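
// Dot product of two equal-type buffers that need not be 128-byte aligned.
// Whole HVX vectors are loaded and realigned on the fly with Q6_V_valign_VVR
// (keyed by the low bits of the source address); the main loop is unrolled 2x
// with separate accumulators, and partial/leftover vectors are handled after it.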
template <typename _TElem, HVX_Vector (*_MpyFunc)(HVX_Vector, HVX_Vector),
          HVX_Vector (*_AddFunc)(HVX_Vector, HVX_Vector), float (*_ReduceFunc)(HVX_Vector)>
inline float vec_dot_product_impl(const _TElem * src0, const _TElem * src1, size_t count) {
    constexpr const size_t kElementsPerVector = hexagon::kBytesPerVector / sizeof(_TElem);

    HVX_Vector * src0_vec_ptr = ((HVX_Vector *) src0);
    HVX_Vector * const src0_vec_ptr_end = ((HVX_Vector *) src0) + count / kElementsPerVector;
    HVX_Vector * src1_vec_ptr = ((HVX_Vector *) src1);
    HVX_Vector prev0 = *src0_vec_ptr++;
    HVX_Vector prev1 = *src1_vec_ptr++;
    HVX_Vector sum = Q6_V_vzero();
    HVX_Vector sum0 = Q6_V_vzero();
    HVX_Vector sum1 = Q6_V_vzero();
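
    // Main loop: two vectors per source per iteration, with two independent
    // accumulators so consecutive multiply-adds do not serialize on each other.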
    while (src0_vec_ptr_end - src0_vec_ptr > 1) {
        HVX_VectorPair curr0 = reinterpret_cast<HVX_VectorPair *>(src0_vec_ptr)[0];
        HVX_VectorPair curr1 = reinterpret_cast<HVX_VectorPair *>(src1_vec_ptr)[0];

        HVX_Vector l0 = Q6_V_valign_VVR(Q6_V_lo_W(curr0), prev0, (size_t) src0);
        HVX_Vector l1 = Q6_V_valign_VVR(Q6_V_lo_W(curr1), prev1, (size_t) src1);
        HVX_Vector h0 = Q6_V_valign_VVR(Q6_V_hi_W(curr0), Q6_V_lo_W(curr0), (size_t) src0);
        HVX_Vector h1 = Q6_V_valign_VVR(Q6_V_hi_W(curr1), Q6_V_lo_W(curr1), (size_t) src1);
        prev0 = Q6_V_hi_W(curr0);
        prev1 = Q6_V_hi_W(curr1);
        src0_vec_ptr += 2;
        src1_vec_ptr += 2;

        sum0 = _AddFunc(_MpyFunc(l0, l1), sum0);
        sum1 = _AddFunc(_MpyFunc(h0, h1), sum1);
    }
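
    // Fold the unrolled accumulators, then consume one remaining full vector
    // (the loop above exits with at most one unread vector per source).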
    sum = _AddFunc(sum0, sum1);
    if (src0_vec_ptr_end - src0_vec_ptr > 0) {
        HVX_Vector curr0 = *src0_vec_ptr++;
        HVX_Vector curr1 = *src1_vec_ptr++;
        HVX_Vector s0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0);
        HVX_Vector s1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1);
        prev0 = curr0;
        prev1 = curr1;

        sum = _AddFunc(_MpyFunc(s0, s1), sum);
    }
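
    // When a source is vector-aligned and count is an exact multiple, the data
    // consumed here already sits in prev0/prev1 and reading *src_vec_ptr could
    // run past the buffer, so the fetch is skipped (valign with an aligned
    // address returns prev unchanged).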
    const size_t leftover = count % kElementsPerVector;
    if ((src0_vec_ptr_end - ((HVX_Vector *) src0)) > 0) {
        // handle the last vector
        // see also:
        // https://github.com/UbiquitousLearning/mllm/blob/babf4410352ce8730824c87699c025a0d4ce3a6f/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/LLaMAMul.cpp#L147
        // or the Qualcomm SDK: libs\qhl_hvx\src\qhblas_hvx\qhblas_hvx_aw_vector_add_ah.c
        bool should_fetch_src0 = leftover != 0 || !hexagon::is_addr_aligned(src0_vec_ptr);
        bool should_fetch_src1 = leftover != 0 || !hexagon::is_addr_aligned(src1_vec_ptr);
        HVX_Vector curr0 = should_fetch_src0 ? *src0_vec_ptr : prev0;
        HVX_Vector curr1 = should_fetch_src1 ? *src1_vec_ptr : prev1;
        src0_vec_ptr += should_fetch_src0 ? 1 : 0;
        src1_vec_ptr += should_fetch_src1 ? 1 : 0;
        HVX_Vector s0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0);
        HVX_Vector s1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1);
        prev0 = curr0;
        prev1 = curr1;

        sum = _AddFunc(_MpyFunc(s0, s1), sum);
    }
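
    // Fewer than kElementsPerVector elements remain: only dereference the next
    // vector when the tail actually spills into it, then keep just the leftover
    // lanes of the product (valign against zero) before accumulating.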
    const size_t leftover_bytes = leftover * sizeof(_TElem);
    if (leftover > 0) {
        // handle the leftover elements
        HVX_Vector curr0 = (leftover_bytes + hexagon::unaligned_bytes(src0_vec_ptr) > hexagon::kBytesPerVector) ?
                               *src0_vec_ptr :
                               prev0;
        curr0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0);

        HVX_Vector curr1 = (leftover_bytes + hexagon::unaligned_bytes(src1_vec_ptr) > hexagon::kBytesPerVector) ?
                               *src1_vec_ptr :
                               prev1;
        curr1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1);

        sum = _AddFunc(Q6_V_valign_VVR(_MpyFunc(curr0, curr1), Q6_V_vzero(), leftover_bytes), sum);
    }

    return _ReduceFunc(sum);
}
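
// Variant for vector-aligned buffers: no realignment is needed, so the loop
// reads vector pairs directly. Elements beyond the last full vector are
// ignored; the caller is expected to pass a vector-multiple count.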
template <typename _TElem, HVX_Vector (*_MpyFunc)(HVX_Vector, HVX_Vector),
          HVX_Vector (*_AddFunc)(HVX_Vector, HVX_Vector), float (*_ReduceFunc)(HVX_Vector)>
inline float vec_dot_product_aligned_impl(const _TElem * src0, const _TElem * src1, size_t count) {
    constexpr const size_t kElementsPerVector = hexagon::kBytesPerVector / sizeof(_TElem);

    HVX_Vector * src0_vec_ptr = ((HVX_Vector *) src0);
    HVX_Vector * const src0_vec_ptr_end = ((HVX_Vector *) src0) + count / kElementsPerVector;
    HVX_Vector * src1_vec_ptr = ((HVX_Vector *) src1);
    HVX_Vector sum0 = Q6_V_vzero();
    HVX_Vector sum1 = Q6_V_vzero();

    while (src0_vec_ptr_end - src0_vec_ptr > 1) {
        HVX_VectorPair curr0 = reinterpret_cast<HVX_VectorPair *>(src0_vec_ptr)[0];
        HVX_VectorPair curr1 = reinterpret_cast<HVX_VectorPair *>(src1_vec_ptr)[0];
        src0_vec_ptr += 2;
        src1_vec_ptr += 2;

        sum0 = _AddFunc(_MpyFunc(Q6_V_lo_W(curr0), Q6_V_lo_W(curr1)), sum0);
        sum1 = _AddFunc(_MpyFunc(Q6_V_hi_W(curr0), Q6_V_hi_W(curr1)), sum1);
    }

    if (src0_vec_ptr_end - src0_vec_ptr > 0) {
        HVX_Vector curr0 = src0_vec_ptr[0];
        HVX_Vector curr1 = src1_vec_ptr[0];

        sum0 = _AddFunc(_MpyFunc(curr0, curr1), sum0);
    }

    return _ReduceFunc(_AddFunc(sum0, sum1));
}
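
// Multiply/add primitives plugged into the templates above: multiply IEEE
// sf/hf inputs into the wider qf32/qf16 accumulator formats, and add in qf form.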
inline HVX_Vector vec_mpy_qf32(HVX_Vector src0, HVX_Vector src1) {
    return Q6_Vqf32_vmpy_VsfVsf(src0, src1);
}

inline HVX_Vector vec_add_qf32(HVX_Vector sum, HVX_Vector result) {
    return Q6_Vqf32_vadd_Vqf32Vqf32(sum, result);
}

inline HVX_Vector vec_mpy_qf16(HVX_Vector src0, HVX_Vector src1) {
    return Q6_Vqf16_vmpy_VhfVhf(src0, src1);
}

inline HVX_Vector vec_add_qf16(HVX_Vector sum, HVX_Vector result) {
    return Q6_Vqf16_vadd_Vqf16Vqf16(sum, result);
}
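
// Mixed-precision dot product (e.g. fp16 * fp32): one vector of the narrow
// type expands to a vector pair of the wide type. _ExpandFunc widens the
// narrow lanes by multiplying them with a splatted fp16 constant 1.0 (kOneV).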
template <typename _TElem0, typename _TElem1, HVX_VectorPair (*_ExpandFunc)(HVX_Vector, HVX_Vector),
          HVX_Vector (*_MpyFunc)(HVX_Vector, HVX_Vector), HVX_Vector (*_AddFunc)(HVX_Vector, HVX_Vector),
          float (*_ReduceFunc)(HVX_Vector)>
inline float vec_dot_product_mixed_impl(const _TElem0 * src0, const _TElem1 * src1, size_t count) {
    static_assert(sizeof(_TElem0) < sizeof(_TElem1), "Element size mismatch: _TElem0 must be smaller than _TElem1");
    static_assert((sizeof(_TElem1) / sizeof(_TElem0)) == 2,
                  "Element size mismatch: _TElem1 must be twice the size of _TElem0");
    static_assert((sizeof(_TElem1) % sizeof(_TElem0)) == 0,
                  "Element size mismatch: _TElem1 must be a multiple of _TElem0");

    constexpr const size_t kElementsPerVector0 = hexagon::kBytesPerVector / sizeof(_TElem0);
    constexpr const size_t kElementsPerVector1 = hexagon::kBytesPerVector / sizeof(_TElem1);

    constexpr const __fp16 kOne = 1.0f;
    const HVX_Vector kOneV = Q6_Vh_vsplat_R(reinterpret_cast<const uint16_t &>(kOne));

    const _TElem0 * const src0_ptr_end = src0 + count;
    HVX_Vector * src0_vec_ptr = ((HVX_Vector *) src0);
    HVX_Vector * src1_vec_ptr = ((HVX_Vector *) src1);
    HVX_Vector * const src1_vec_ptr_end = ((HVX_Vector *) src1) + count / kElementsPerVector1;
    HVX_Vector prev0 = *src0_vec_ptr++;
    HVX_Vector prev1 = *src1_vec_ptr++;
    HVX_Vector sum = Q6_V_vzero();
    HVX_Vector sum0 = Q6_V_vzero();
    HVX_Vector sum1 = Q6_V_vzero();
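
    // One narrow src0 vector per iteration pairs with two wide src1 vectors;
    // both streams are realigned with valign as in vec_dot_product_impl.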
    while (src1_vec_ptr_end - src1_vec_ptr > 1) {
        HVX_Vector curr0 = src0_vec_ptr[0];
        HVX_VectorPair curr1 = reinterpret_cast<HVX_VectorPair *>(src1_vec_ptr)[0];

        HVX_Vector s0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0);
        HVX_Vector l1 = Q6_V_valign_VVR(Q6_V_lo_W(curr1), prev1, (size_t) src1);
        HVX_Vector h1 = Q6_V_valign_VVR(Q6_V_hi_W(curr1), Q6_V_lo_W(curr1), (size_t) src1);
        HVX_VectorPair s0_pair = _ExpandFunc(s0, kOneV);
        prev0 = curr0;
        prev1 = Q6_V_hi_W(curr1);
        src0_vec_ptr++;
        src1_vec_ptr += 2;

        sum0 = _AddFunc(_MpyFunc(Q6_V_lo_W(s0_pair), l1), sum0);
        sum1 = _AddFunc(_MpyFunc(Q6_V_hi_W(s0_pair), h1), sum1);
    }

    sum = _AddFunc(sum0, sum1);
    const size_t leftover1 = count % kElementsPerVector1;
    if ((src1_vec_ptr_end - ((HVX_Vector *) src1)) > 0) {
        // handle the last vector
        const bool should_fetch_src0 =
            reinterpret_cast<const _TElem0 *>(hexagon::align_down(src0_vec_ptr)) < src0_ptr_end;
        HVX_Vector curr0 = should_fetch_src0 ? *src0_vec_ptr : prev0;
        src0_vec_ptr += should_fetch_src0 ? 1 : 0;

        HVX_Vector s0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0);
        HVX_VectorPair s0_pair = _ExpandFunc(s0, kOneV);

        const bool has_remaining_src1_vector = src1_vec_ptr_end - src1_vec_ptr > 0;
        if (has_remaining_src1_vector) {
            HVX_Vector curr1 = *src1_vec_ptr++;
            HVX_Vector s1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1);
            prev1 = curr1;

            // the low half of s0_pair pairs with this full src1 vector;
            // the high half is consumed by the final partial vector below
            sum = _AddFunc(_MpyFunc(Q6_V_lo_W(s0_pair), s1), sum);
        }

        bool should_fetch_src1 = leftover1 != 0 || !hexagon::is_addr_aligned(src1_vec_ptr);
        HVX_Vector curr1 = should_fetch_src1 ? *src1_vec_ptr : prev1;
        src1_vec_ptr += should_fetch_src1 ? 1 : 0;
        HVX_Vector s1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1);
        prev0 = curr0;
        prev1 = curr1;

        sum = _AddFunc(_MpyFunc(has_remaining_src1_vector ? Q6_V_hi_W(s0_pair) : Q6_V_lo_W(s0_pair), s1), sum);
    }

    const size_t leftover0 = count % kElementsPerVector0;
    const size_t leftover_bytes1 = leftover1 * sizeof(_TElem1);
    if (leftover1 > 0) {
        // handle the leftover elements
        HVX_Vector curr0 =
            reinterpret_cast<const _TElem0 *>(hexagon::align_down(src0_vec_ptr)) < src0_ptr_end ? *src0_vec_ptr : prev0;
        HVX_Vector curr1 = (leftover_bytes1 + hexagon::unaligned_bytes(src1_vec_ptr) > hexagon::kBytesPerVector) ?
                               *src1_vec_ptr :
                               prev1;
        curr0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0);
        curr1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1);
        HVX_VectorPair curr0_pair = _ExpandFunc(curr0, kOneV);
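
        // when leftover0 == leftover1 the remaining narrow elements start in the
        // low half of the expanded pair; otherwise they start in the high half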
        curr0 = leftover1 == leftover0 ? Q6_V_lo_W(curr0_pair) : Q6_V_hi_W(curr0_pair);
        sum = _AddFunc(Q6_V_valign_VVR(_MpyFunc(curr0, curr1), Q6_V_vzero(), leftover_bytes1), sum);
    }

    return _ReduceFunc(sum);
}
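
// Aligned mixed-precision variant: the main loop is unrolled 4x across four
// accumulators to keep the expand/multiply/add chain busy; one remaining
// vector pair is handled after the loop, and any sub-vector tail is ignored.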
template <typename _TElem0, typename _TElem1, HVX_VectorPair (*_ExpandFunc)(HVX_Vector, HVX_Vector),
          HVX_Vector (*_MpyFunc)(HVX_Vector, HVX_Vector), HVX_Vector (*_AddFunc)(HVX_Vector, HVX_Vector),
          float (*_ReduceFunc)(HVX_Vector)>
inline float vec_dot_product_mix_aligned_impl(const _TElem0 * src0, const _TElem1 * src1, size_t count) {
    static_assert(sizeof(_TElem0) < sizeof(_TElem1), "Element size mismatch: _TElem0 must be smaller than _TElem1");
    static_assert((sizeof(_TElem1) / sizeof(_TElem0)) == 2,
                  "Element size mismatch: _TElem1 must be twice the size of _TElem0");
    static_assert((sizeof(_TElem1) % sizeof(_TElem0)) == 0,
                  "Element size mismatch: _TElem1 must be a multiple of _TElem0");

    constexpr const size_t kElementsPerVector1 = hexagon::kBytesPerVector / sizeof(_TElem1);

    constexpr const __fp16 kOne = 1.0f;
    const HVX_Vector kOneV = Q6_Vh_vsplat_R(reinterpret_cast<const uint16_t &>(kOne));

    HVX_Vector * src0_vec_ptr = ((HVX_Vector *) src0);
    HVX_Vector * src1_vec_ptr = ((HVX_Vector *) src1);
    HVX_Vector * const src1_vec_ptr_end = ((HVX_Vector *) src1) + count / kElementsPerVector1;
    HVX_Vector sum0 = Q6_V_vzero();
    HVX_Vector sum1 = Q6_V_vzero();
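
    // Each iteration expands two narrow src0 vectors into four wide vectors and
    // pairs them with four wide src1 vectors, using scoped sum2/sum3 accumulators.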
    {
        HVX_Vector sum2 = Q6_V_vzero();
        HVX_Vector sum3 = Q6_V_vzero();

        while (src1_vec_ptr_end - src1_vec_ptr > 3) {
            HVX_VectorPair curr0 = reinterpret_cast<HVX_VectorPair *>(src0_vec_ptr)[0];
            HVX_VectorPair curr10 = reinterpret_cast<HVX_VectorPair *>(src1_vec_ptr)[0];
            HVX_VectorPair curr11 = reinterpret_cast<HVX_VectorPair *>(src1_vec_ptr)[1];

            HVX_VectorPair curr00 = _ExpandFunc(Q6_V_lo_W(curr0), kOneV);
            HVX_VectorPair curr01 = _ExpandFunc(Q6_V_hi_W(curr0), kOneV);
            src0_vec_ptr += 2;
            src1_vec_ptr += 4;

            sum0 = _AddFunc(_MpyFunc(Q6_V_lo_W(curr00), Q6_V_lo_W(curr10)), sum0);
            sum1 = _AddFunc(_MpyFunc(Q6_V_hi_W(curr00), Q6_V_hi_W(curr10)), sum1);
            sum2 = _AddFunc(_MpyFunc(Q6_V_lo_W(curr01), Q6_V_lo_W(curr11)), sum2);
            sum3 = _AddFunc(_MpyFunc(Q6_V_hi_W(curr01), Q6_V_hi_W(curr11)), sum3);
        }

        sum0 = _AddFunc(sum0, sum2);
        sum1 = _AddFunc(sum1, sum3);
    }

    if (src1_vec_ptr_end - src1_vec_ptr > 1) {
        HVX_Vector curr0 = src0_vec_ptr[0];
        HVX_VectorPair curr1 = reinterpret_cast<HVX_VectorPair *>(src1_vec_ptr)[0];

        HVX_VectorPair s0_pair = _ExpandFunc(curr0, kOneV);

        sum0 = _AddFunc(_MpyFunc(Q6_V_lo_W(s0_pair), Q6_V_lo_W(curr1)), sum0);
        sum1 = _AddFunc(_MpyFunc(Q6_V_hi_W(s0_pair), Q6_V_hi_W(curr1)), sum1);
    }

    return _ReduceFunc(_AddFunc(sum0, sum1));
}

}  // namespace

namespace hexagon {
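
// Public entry points: bind element types, the qf multiply/add helpers, and the
// final reduction to the generic implementations above.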
float vec_dot_product_f32_f32(const float * src0, const float * src1, size_t count) {
    return vec_dot_product_impl<float, vec_mpy_qf32, vec_add_qf32, vec_reduction_f32_qf32>(src0, src1, count);
}

float vec_dot_product_aligned_f32_f32(const float * src0, const float * src1, size_t count) {
    return vec_dot_product_aligned_impl<float, vec_mpy_qf32, vec_add_qf32, vec_reduction_f32_qf32>(src0, src1, count);
}

float vec_dot_product_f16_f16(const npu_device_fp16_t * src0, const npu_device_fp16_t * src1, size_t count) {
    return vec_dot_product_impl<npu_device_fp16_t, vec_mpy_qf16, vec_add_qf16, vec_reduction_qf16_f32>(src0, src1,
                                                                                                       count);
}

float vec_dot_product_aligned_f16_f16(const npu_device_fp16_t * src0, const npu_device_fp16_t * src1, size_t count) {
    return vec_dot_product_aligned_impl<npu_device_fp16_t, vec_mpy_qf16, vec_add_qf16, vec_reduction_qf16_f32>(
        src0, src1, count);
}

float vec_dot_product_f16_f32(const npu_device_fp16_t * src0, const float * src1, size_t count) {
    return vec_dot_product_mixed_impl<npu_device_fp16_t, float, hvx_vsf_convert_vhf, vec_mpy_qf32, vec_add_qf32,
                                      vec_reduction_f32_qf32>(src0, src1, count);
}

float vec_dot_product_aligned_f16_f32(const npu_device_fp16_t * src0, const float * src1, size_t count) {
    return vec_dot_product_mix_aligned_impl<npu_device_fp16_t, float, hvx_vsf_convert_vhf, vec_mpy_qf32, vec_add_qf32,
                                            vec_reduction_f32_qf32>(src0, src1, count);
}

}  // namespace hexagon