feat: flash attention support for hexagon-npu (#45)
* add flash attn op
* expand src tensor size
* add flash attn sources
* add quantize row functions
* make a separate file for vec_dot
* wip
* wip
* refactor: rename quants.hpp includes and add vec_dot to type traits
* add flash_attn impl
* split vec_scale_f32
* move vec_reduction_qf32 to vec_ops
* add vec_scale_f16
* opt
* add vec_mad
* implement vec_mad_f16
* opt
* add op template
* opt
* add align version
* enable flash attn
* wip
* log print improve
* add profiler log
* wip
* wip
* add multi sub proc perf tracker
* increase log buffer
* remove sub proc pcycle
* wip
* wip
* add prefetch for vec_dot
* wip
* wip
* opt f16 vec dot
* opt f16 vecdot
* reuse vec_dot_product_impl in vec dot f32
* small opt to unblock pipeline
* opt on aligned address wip
* Revert "opt on aligned address" (reverts commit 27be1eb61a7d29d2f5fa6f90383e1b5d7fdf9b6a)
* add profiler log at thread_pool
* wip
* invalidate all...
* Reapply "opt on aligned address" (reverts commit f075a4c4586e32b7e5819c1fe7f9b6ed218b1767)
* add is_constant for tensor config
* disable align tensor opt in mul_mat
* wip
* wip
* vec_scale_impl: unroll the loop
* wip
* wip
* replace reinterpret_cast with direct pointer access for write/read buffers
* add fetch
* wip
* wip
* wip
* add log
* check tensor shape at flash_attn
* wip
* wip
* fix: update tensor type handling in flash_attn_impl
* wip
* fix: align cache size
* fix: qf16->hf
* fix: swap order of elements in vector combine for correct scaling
* fix: opt f16 scale and mad
* fix leftover fetch
* wip
* load into vector pair
* opt cache size calculation in flash_attn_impl
* refactoring: hold vtcm at thread local object
* wip
* add profiler log
* mark tensors as modified
* restrict tensor invalidation to the first thread in compute_impl
* Revert "restrict tensor invalidation to the first thread in compute_impl" (reverts commit 0a8ff2b1bcf366097c16d7437c091382eacbef8b)
* invalidate last tensor in compute_impl
* invalidate last tensor in compute function
* wip
* refactor dequantize_row_q4_0 to simplify vector alignment
* wip
* refactoring: move VTCM quota calculation to thread pool
* wip
* fix: correct condition check for HEXAGON_SDK_ROOT existence
* wip
* wip
* wip
* wip
* fix: update condition checks to match the naming
* fix: improve tensor handling checks and logging in graph and operation implementations
* wip
parent da5dc57872
commit af620a12f7
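The new flash-attention kernel added in this diff (op_flash_attn.cpp) accumulates attention with the online-softmax recurrence (ref: https://arxiv.org/pdf/2112.05682.pdf): per Q row it keeps a running maximum M, a running denominator S and a running V accumulator, rescaling the accumulator whenever a larger KQ score appears. The following is an illustrative scalar sketch of that recurrence, not part of the commit; the real kernel vectorizes the scale/mad steps with the HVX helpers vec_scale_f32/vec_mad_f32 and adds masking, logit softcap and FP16 paths on top of it.

#include <cmath>
#include <vector>

// Scalar model of the per-row accumulation in flash_attn_impl. kq[ic] is assumed to be the
// already scaled and masked KQ score for KV column ic, v[ic] the matching V row, and out
// receives softmax(KQ) * V for this Q row. Names (M, S, ms, vs) follow the kernel.
void online_softmax_accumulate(const std::vector<float> & kq,
                               const std::vector<std::vector<float>> & v,
                               std::vector<float> & out) {
    const size_t DV = v.empty() ? 0 : v[0].size();
    float M = -INFINITY;  // running maximum of the KQ scores
    float S = 0.0f;       // running softmax denominator
    out.assign(DV, 0.0f); // running V accumulator (VKQ32 in the kernel)
    for (size_t ic = 0; ic < kq.size(); ++ic) {
        const float s = kq[ic];
        float ms = 1.0f; // rescale factor for the old accumulator, expf(Mold - M)
        float vs = 1.0f; // weight of the current V row, expf(s - M)
        if (s > M) {
            ms = expf(M - s); // new maximum: rescale the previous partial sums
            M  = s;
            for (size_t d = 0; d < DV; ++d) out[d] *= ms; // vec_scale_f32 in the kernel
        } else {
            vs = expf(s - M);
        }
        for (size_t d = 0; d < DV; ++d) out[d] += vs * v[ic][d]; // vec_mad_f32 in the kernel
        S = S * ms + vs; // scale and increment the denominator with the partial sum
    }
    const float S_inv = S > 0.0f ? 1.0f / S : 0.0f;
    for (size_t d = 0; d < DV; ++d) out[d] *= S_inv; // V /= S
}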
@@ -3,6 +3,8 @@ cmake_policy(SET CMP0115 OLD)
if(DEFINED ENV{HEXAGON_SDK_ROOT})
set(HEXAGON_SDK_ROOT $ENV{HEXAGON_SDK_ROOT})
message("HEXAGON_SDK_ROOT (from environment): ${HEXAGON_SDK_ROOT}")
elseif(DEFINED HEXAGON_SDK_ROOT)
message("HEXAGON_SDK_ROOT: ${HEXAGON_SDK_ROOT}")
else()
message(FATAL_ERROR "HEXAGON_SDK_ROOT not defined")

@@ -9,10 +9,10 @@
#include "graph.hpp"
#include "hexagon_npu.h"
#include "op_impl.hpp"
#include "quants.hpp"
#include "remote.h"
#include "tensor.hpp"
#include "thread_pool.hpp"
#include "type_traits.hpp"
#include "util.hpp"

namespace {

@@ -124,21 +124,20 @@ int npu_device_close(remote_handle64 h) {
AEEResult npu_device_device_get_alignment(remote_handle64 _h, uint32_t * alignment) {
NPU_UNUSED(_h);
*alignment = sizeof(HVX_Vector);
*alignment = sizeof(HVX_VectorPair);
return AEE_SUCCESS;
}

AEEResult npu_device_device_support_op(remote_handle64 _h, const npu_device_tensor_spec * src0,
const npu_device_tensor_spec * src1, const npu_device_tensor_spec * dst,
npu_device_tensor_op op, boolean * is_supported) {
AEEResult npu_device_device_support_op(remote_handle64 _h, npu_device_tensor_op op, const npu_device_tensor_spec * dst,
const npu_device_tensor_spec * srcs, int srcsLen, boolean * is_supported) {
NPU_UNUSED(_h);

if (!src0 || !src1 || !dst || !is_supported) {
if (!srcs || srcsLen <= 0 || !dst || !is_supported) {
DEVICE_LOG_ERROR("npu_device_device_support_op: Invalid arguments");
return AEE_EINVARGS;
}

*is_supported = hexagon::support_op(*src0, *src1, *dst, op);
*is_supported = hexagon::support_op(op, dst, srcs, srcsLen);
return AEE_SUCCESS;
}

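The hunk above changes the device support query from fixed src0/src1 arguments to an array of source specs, which is what allows the four-source flash-attention op to be validated through the same entry point. A hypothetical host-side call against the new signature could look like the sketch below; the handle and the q/k/v/mask/dst spec variables are placeholders, not taken from this commit.

// Hypothetical usage sketch of the new array-based support query; handle, dst_spec and the
// q/k/v/mask specs are assumed to have been filled in by the caller beforehand.
npu_device_tensor_spec srcs[4] = { q_spec, k_spec, v_spec, mask_spec };
boolean supported = 0;
AEEResult ret = npu_device_device_support_op(handle, NPU_OP_FLASH_ATTN, &dst_spec, srcs, 4, &supported);
if (ret != AEE_SUCCESS || !supported) {
    // fall back to the CPU implementation of the op
}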
@@ -208,19 +207,20 @@ AEEResult npu_device_graph_set_tensor_with_param(remote_handle64 _h, npu_device_
int tensor_paramsLen) {
NPU_UNUSED(_h);
auto * graph = graph_from_handle(graph_handle);
if (!graph || !tensor_handles || tensor_handlesLen <= 0 || !tensor_params ||
tensor_handlesLen != tensor_paramsLen) {
if (!graph || tensor_handlesLen != tensor_paramsLen || tensor_handlesLen < 0) {
return AEE_EINVHANDLE;
}

graph->set_tensor(tensor_handles, tensor_handlesLen);
for (int i = 0; i < tensor_handlesLen; ++i) {
auto * tensor = tensor_from_handle(tensor_handles[i]);
if (tensor) {
tensor->update_config(tensor_params[i]);
if (tensor_params && tensor_handles) {
for (int i = 0; i < tensor_handlesLen; ++i) {
auto * tensor = tensor_from_handle(tensor_handles[i]);
if (tensor) {
tensor->update_config(tensor_params[i]);
}
}
}

graph->set_tensor(tensor_handles, tensor_handlesLen);
return AEE_SUCCESS;
}

@@ -10,8 +10,7 @@
namespace hexagon {

graph::graph() noexcept {
_vtcm_quota_size = hexagon::vtcm_mem::get_avail_block_size(); // TODO: move to device init?
DEVICE_LOG_DEBUG("graph(%p) created: vtcm quota size: %zu\n", (void *) this, _vtcm_quota_size);
DEVICE_LOG_DEBUG("graph(%p) created\n", (void *) this);
}

graph::~graph() noexcept {

@@ -20,9 +19,10 @@ graph::~graph() noexcept {
}

void graph::set_tensor(const npu_device_tensor_handle_t * tensors, int tensor_count) {
if (tensor_count <= 0) {
if (tensor_count <= 0 || !tensors) {
_tensors.reset();
_tensor_count = 0;
DEVICE_LOG_DEBUG("graph(%p) set_tensor: no tensors to set\n", (void *) this);
return;
}

@@ -50,21 +50,27 @@ bool graph::compute(default_thread_pool * thread_pool, const float * f16_to_f32_
DEVICE_SCOPED_PERFORMANCE_TRACKER("[%p]compute", (void *) this);
_f16_to_f32_table = f16_to_f32_table;
if (thread_pool) {
thread_pool->sync_execute(reinterpret_cast<default_thread_pool::task_type>(&graph::thread_pool_task), this);
thread_pool->sync_execute(&graph::thread_pool_task, this);
} else {
compute_impl(nullptr, 0, 1);
default_thread_pool::thread_params param = {
0, 1, nullptr, hexagon::vtcm_mem::get_avail_block_size()
}; // TODO: should have a better way to initialize thread_params

compute_impl(nullptr, &param);
}

_tensors[_tensor_count - 1]->invalidate();
_f16_to_f32_table = nullptr;
return true;
}

void graph::thread_pool_task(default_thread_pool * pool, size_t thread_idx, size_t thread_count, graph * graph) {
graph->compute_impl(pool, thread_idx, thread_count);
void graph::thread_pool_task(default_thread_pool * pool, default_thread_pool::thread_params * thread_params,
void * graph) {
reinterpret_cast<hexagon::graph *>(graph)->compute_impl(pool, thread_params);
}

void graph::compute_impl(default_thread_pool * pool, size_t thread_idx, size_t thread_count) {
hexagon::compute_params params = { thread_idx, thread_count, _vtcm_quota_size / thread_count, _f16_to_f32_table };
void graph::compute_impl(default_thread_pool * pool, default_thread_pool::thread_params * thread_params) {
hexagon::compute_params params = { thread_params, _f16_to_f32_table };

for (size_t i = 0; i < _tensor_count; ++i) {
auto * dst = _tensors[i];

@@ -78,13 +84,12 @@ void graph::compute_impl(default_thread_pool * pool, size_t thread_idx, size_t t
DEVICE_LOG_ERROR("graph(%p) tensor[%zu] op %d compute failed\n", (void *) this, i, op);
}

DEVICE_SCOPED_PERFORMANCE_TRACKER("[%p]sync_thread, tidx: %zu", (void *) this, thread_idx);

const bool should_sync = requires_thread_barrier(op);
if (pool && should_sync && i < _tensor_count - 1) {
DEVICE_SCOPED_PERFORMANCE_TRACKER("[%p]sync_thread, tidx: %zu, tensor[%zu/%zu]", (void *) this,
params.get_thread_index(), i, _tensor_count);
pool->sync_thread();
}
dst->invalidate();
}
}

@@ -20,12 +20,12 @@ class graph {
bool compute(default_thread_pool * thread_pool, const float * f16_to_f32_table);

private:
static void thread_pool_task(default_thread_pool * pool, size_t thread_idx, size_t thread_count, graph * graph);
void compute_impl(default_thread_pool * pool, size_t thread_idx, size_t thread_count);
static void thread_pool_task(default_thread_pool * pool, default_thread_pool::thread_params * thread_params,
void * graph);
void compute_impl(default_thread_pool * pool, default_thread_pool::thread_params * thread_params);

std::unique_ptr<tensor *[]> _tensors;
size_t _tensor_count = 0;
size_t _vtcm_quota_size = 0;
const float * _f16_to_f32_table = nullptr;

DISABLE_COPY_AND_MOVE(graph);

@@ -0,0 +1,321 @@
#include "op_flash_attn.hpp"

#include "type_traits.hpp"
#include "util.hpp"
#include "vec_ops.hpp"

namespace {

// TODO: use a more efficient conversion
inline float f16_to_f32(const npu_device_fp16_t src) {
return reinterpret_cast<const __fp16 &>(src);
}

// From: ggml/src/ggml-cpu/ops.cpp
void flash_attn_impl(hexagon::tensor * out, const hexagon::tensor * q, const hexagon::tensor * k,
const hexagon::tensor * v, const hexagon::tensor * mask, hexagon::compute_params * params) {
static_assert(3 <= hexagon::kMaxParamsCount, "flash_attn op params count exceeds max params count");

float scale = out->get_op_param<float>(0);
const float max_bias = out->get_op_param<float>(1);
const float logit_softcap = out->get_op_param<float>(2);

if (logit_softcap != 0) {
scale /= logit_softcap;
}

// broadcast factors
const int64_t rk2 = q->get_ne(2) / k->get_ne(2);
const int64_t rk3 = q->get_ne(3) / k->get_ne(3);
const int64_t rv2 = q->get_ne(2) / v->get_ne(2);
const int64_t rv3 = q->get_ne(3) / v->get_ne(3);

const uint32_t n_head = q->get_ne(2);
const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));

const float m0 = powf(2.0f, -(max_bias) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

const auto q_to_vec_dot = hexagon::get_type_traits(k->get_type()).from_float; // TODO: fix this
const auto kq_vec_dot = hexagon::get_type_traits(k->get_type()).vec_dot;
const auto v_to_float = hexagon::get_type_traits(v->get_type()).to_float;
if (!q_to_vec_dot || !kq_vec_dot) {
DEVICE_LOG_ERROR("flash_attn_impl: unsupported data type for q, k, or v\n");
return;
}

const int64_t total_rows = q->get_ne(1) * q->get_ne(2) * q->get_ne(3); // total number of rows in Q
const auto start_end_row = params->get_work_slice(total_rows); // work slice for this thread

const auto DK = k->get_ne(0);
const auto DV = v->get_ne(0);
const auto row_bytes_q = q->get_ne(0) * hexagon::get_type_traits(q->get_type()).type_size;
const auto row_bytes_k = DK * hexagon::get_type_traits(k->get_type()).type_size;
const auto row_bytes_v = DV * hexagon::get_type_traits(v->get_type()).type_size;

constexpr const size_t kFloatsPerVector = hexagon::kBytesPerVector / sizeof(float);
const auto aligned_dk = (DK + kFloatsPerVector - 1) / kFloatsPerVector * kFloatsPerVector;
const auto aligned_dv = (DV + kFloatsPerVector - 1) / kFloatsPerVector * kFloatsPerVector;
size_t total_cache_size = sizeof(float) * (aligned_dk + 2 * aligned_dv);
auto * cache_ptr = params->get_vtcm_cache(total_cache_size);
if (!cache_ptr) {
DEVICE_LOG_ERROR("Failed to allocate VTCM cache for flash_attn: %zu bytes\n", total_cache_size);
return;
}

// loop over n_batch and n_head
const auto rows_per_batch = q->get_ne(2) * q->get_ne(1);
const auto out_rows_per_batch = out->get_ne(2) * out->get_ne(1);
const bool is_v_f16 =
v->get_type() == NPU_DATA_TYPE_F16; // check if V is in FP16 format, otherwise it is in FP32 format
uint8_t * dst_ptr = out->get_write_buffer();
if (!dst_ptr) {
DEVICE_LOG_ERROR("flash_attn_impl: dst_ptr is not writable, tensor: %p, type: %s\n", (void *) out,
hexagon::get_type_name(out->get_type()));
return;
}

DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_WITH_MULTI_SUB_PROC(out, params->get_thread_index(), flash_attn);
const uint8_t * q_ptr = q->get_read_buffer();
const uint8_t * k_ptr = k->get_read_buffer();
const uint8_t * v_ptr = v->get_read_buffer();
const uint8_t * mask_ptr = mask ? mask->get_read_buffer() : nullptr;
for (auto ir = start_end_row.first; ir < start_end_row.second; ++ir) {
// q indices
const auto iq3 = ir / rows_per_batch;
const auto iq2 = (ir - iq3 * rows_per_batch) / q->get_ne(1);
const auto iq1 = (ir - iq3 * rows_per_batch - iq2 * q->get_ne(1));

const uint32_t h = iq2; // head index
const float slope =
(max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2 * (h - n_head_log2) + 1) : 1.0f;

float S = 0.0f; // sum
float M = -INFINITY; // maximum KQ value

float * VKQ32 = reinterpret_cast<float *>(cache_ptr); // FP32 VKQ accumulator
float * V32 = VKQ32 + aligned_dv; // (temporary) FP32 V buffer
auto * VKQ16 = reinterpret_cast<npu_device_fp16_t *>(VKQ32 + aligned_dv); // (temporary) FP16 VKQ accumulator
auto * Q_q = reinterpret_cast<npu_device_fp16_t *>(
VKQ32 + 2 * aligned_dv); // (temporary) buffer for Q converted to quantized/FP16

if (is_v_f16) {
memset(VKQ16, 0, DV * sizeof(npu_device_fp16_t));
} else {
memset(VKQ32, 0, DV * sizeof(float));
}

const npu_device_fp16_t * mp =
mask_ptr ? reinterpret_cast<const npu_device_fp16_t *>(mask_ptr + iq1 * mask->get_nb(1)) : nullptr;

// k indices
const int ik3 = iq3 / rk3;
const int ik2 = iq2 / rk2;

// v indices
const int iv3 = iq3 / rv3;
const int iv2 = iq2 / rv2;

const auto * q_data = q_ptr + (iq1 * q->get_nb(1) + iq2 * q->get_nb(2) + iq3 * q->get_nb(3));
if (iq1 < q->get_ne(1) - 1) {
hexagon::l2fetch_row(q_data + q->get_nb(1), row_bytes_q);
}

q_to_vec_dot(reinterpret_cast<const float *>(q_data), Q_q, DK, params->f16_to_f32_table);

// online softmax / attention
// loop over n_kv and n_head_kv
// ref: https://arxiv.org/pdf/2112.05682.pdf
for (int64_t ic = 0; ic < k->get_ne(1); ++ic) {
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(flash_attn, 0, loop);
float mv = mp ? (slope * f16_to_f32(mp[ic])) : 0.0f;
if (mv == -INFINITY) {
continue;
}

float s = 0.f;
{
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(flash_attn, 1, kq_dot);
const auto * k_data = k_ptr + (ic * k->get_nb(1) + ik2 * k->get_nb(2) + ik3 * k->get_nb(3));
if (ic < k->get_ne(1) - 1) {
hexagon::l2fetch_row(k_data + k->get_nb(1), row_bytes_k);
}

s = kq_vec_dot(k_data, Q_q, DK); // KQ value
s = s * scale; // scale KQ value
if (logit_softcap != 0.0f) {
s = logit_softcap * tanhf(s); // TODO: vectorize this?
}

s += mv; // apply mask
}

const float Mold = M;

float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value
float vs = 1.0f; // post-softmax KQ value, expf(s - M)

const auto * v_data = v_ptr + (ic * v->get_nb(1) + iv2 * v->get_nb(2) + iv3 * v->get_nb(3));
if (ic < v->get_ne(1)) {
hexagon::l2fetch_row(v_data, row_bytes_v);
}

if (is_v_f16) {
if (s > M) {
// s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
M = s;
ms = expf(Mold - M);

// V = V*expf(Mold - M)
hexagon::vec_scale_f16(VKQ16, ms, VKQ16, DV);
} else {
// no new maximum, ms == 1.0f, vs != 1.0f
vs = expf(s - M);
}

// V += v*expf(s - M)
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(flash_attn, 2, mad);
hexagon::vec_mad_f16(reinterpret_cast<const npu_device_fp16_t *>(v_data), vs, VKQ16, DV);
} else {
if (s > M) {
// s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
M = s;
ms = expf(Mold - M);

// V = V*expf(Mold - M)
hexagon::vec_scale_f32(VKQ32, ms, VKQ32, DV);
} else {
// no new maximum, ms == 1.0f, vs != 1.0f
vs = expf(s - M);
}

// V += v*expf(s - M)
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(flash_attn, 2, mad);
if (v_to_float) {
v_to_float(v_data, V32, DV, params->f16_to_f32_table);
hexagon::vec_mad_f32(V32, vs, VKQ32, DV);
} else {
// V is F32
hexagon::vec_mad_f32(reinterpret_cast<const float *>(v_data), vs, VKQ32, DV);
}
}

S = S * ms + vs; // scale and increment sum with partial sum
}

if (is_v_f16) {
// TODO: use a more efficient conversion
for (int64_t d = 0; d < DV; ++d) {
VKQ32[d] = f16_to_f32(VKQ16[d]);
}
}

// V /= S
const float S_inv = 1.0f / S;
hexagon::vec_scale_f32(VKQ32, S_inv, VKQ32, DV);

// dst indices
const int i1 = iq1;
const int i2 = iq2;
const int i3 = iq3;

// permute(0, 2, 1, 3)
memcpy(dst_ptr + (i3 * out_rows_per_batch + i2 + i1 * out->get_ne(1)) * out->get_nb(1), VKQ32, out->get_nb(1));
}

out->release_write_buffer(); // mark the output tensor as modified
}

} // namespace

namespace hexagon {

bool flash_attn_f32(tensor * out, compute_params * params) {
if (!out || !params) {
DEVICE_LOG_DEBUG("invalid out or params\n");
return false;
}

const auto * q = out->get_src(0);
const auto * k = out->get_src(1);
const auto * v = out->get_src(2);
const auto * mask = out->get_src(3);
if (!q || !k || !v || !mask) {
DEVICE_LOG_DEBUG("invalid src tensors: q: %p, k: %p, v: %p, mask: %p\n", (void *) q, (void *) k, (void *) v,
(void *) mask);
return false;
}

flash_attn_impl(out, q, k, v, mask, params);
return true;
}

bool is_flash_attn_supported(npu_device_tensor_op op, const npu_device_tensor_spec * dst,
const npu_device_tensor_spec * srcs, size_t src_len) {
if (op != NPU_OP_FLASH_ATTN) {
DEVICE_LOG_DEBUG("op is not NPU_OP_FLASH_ATTN: %d\n", op);
return false;
}

if (!dst || !srcs || src_len < 4) {
DEVICE_LOG_DEBUG("[%s]invalid dst or srcs\n", op_get_name(op));
return false;
}

if (dst->type != NPU_DATA_TYPE_F32) {
DEVICE_LOG_DEBUG("[%s]dst type is not F32: %s\n", op_get_name(op), get_type_name(dst->type));
return false;
}

const auto * q = &srcs[0];
if (q->type != NPU_DATA_TYPE_F32) {
DEVICE_LOG_DEBUG("[%s]q type is not F32: %s\n", op_get_name(op), get_type_name(q->type));
return false;
}

const auto * k = &srcs[1];
if (k->type != NPU_DATA_TYPE_F16) { // TODO: support more k types
DEVICE_LOG_DEBUG("[%s]k type is not F16: %s\n", op_get_name(op), get_type_name(k->type));
return false;
}

const auto * v = &srcs[2];
if (v->type != k->type) { // TODO: support more v types
DEVICE_LOG_DEBUG("[%s]v type is not the same as k: %s vs %s\n", op_get_name(op), get_type_name(v->type),
get_type_name(k->type));
return false;
}

const auto * mask = &srcs[3];
if (mask->type != NPU_DATA_TYPE_F16) {
DEVICE_LOG_DEBUG("[%s]mask type is not F16: %s\n", op_get_name(op), get_type_name(mask->type));
return false;
}

if (dst->ne[0] != v->ne[0] || dst->ne[2] != q->ne[1]) {
DEVICE_LOG_DEBUG(
"[%s]dst shape does not match q and v: dst ne: %ld, %ld, %ld, %ld, q ne: %ld, %ld, %ld, %ld, "
"v ne: %ld, %ld, %ld, %ld\n",
op_get_name(op), dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], q->ne[0], q->ne[1], q->ne[2], q->ne[3],
v->ne[0], v->ne[1], v->ne[2], v->ne[3]);
return false;
}

if (is_transposed_or_permuted(dst->nb)) {
DEVICE_LOG_DEBUG("[%s]dst cannot be transposed or permuted, nb: %zu, %zu, %zu, %zu\n", op_get_name(op),
dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3]);
return false;
}

if (q->ne[0] != k->ne[0]) {
DEVICE_LOG_DEBUG("[%s]q and k shapes do not match: q ne: %ld, %ld, %ld, %ld, k ne: %ld, %ld, %ld, %ld\n",
op_get_name(op), q->ne[0], q->ne[1], q->ne[2], q->ne[3], k->ne[0], k->ne[1], k->ne[2],
k->ne[3]);
return false;
}

return true;
}

} // namespace hexagon

@@ -0,0 +1,11 @@
#pragma once

#include "op_types.hpp"

namespace hexagon {

bool flash_attn_f32(tensor * out, compute_params * params);
bool is_flash_attn_supported(npu_device_tensor_op op, const npu_device_tensor_spec * dst,
const npu_device_tensor_spec * srcs, size_t src_len);

} // namespace hexagon

@@ -2,13 +2,12 @@
#include "op_impl.hpp"

#include <hexagon_types.h>
#include <HTP/core/intrinsics.h>

#include <type_traits>

#include "op_flash_attn.hpp"
#include "op_mul_mat.hpp"
#include "quants.hpp"
#include "type_traits.hpp"
#include "vec_ops.hpp"

namespace {

@@ -16,12 +15,12 @@ template <HVX_Vector (*_OpIntrinsic)(HVX_Vector, HVX_Vector), typename _TyData>
inline void vec_op_impl(const _TyData * src0, const _TyData * src1, size_t count, _TyData * dst) {
constexpr const size_t kElementsPerVector = hexagon::kBytesPerVector / sizeof(_TyData);

HVX_Vector * iptr0 = ((HVX_Vector *) src0);
HVX_Vector * iptr0_end = ((HVX_Vector *) src0) + (count / kElementsPerVector);
HVX_Vector * iptr1 = ((HVX_Vector *) src1);
HVX_Vector * optr = ((HVX_Vector *) dst); // framework will ensure the dst is aligned
HVX_Vector prev0 = *iptr0++;
HVX_Vector prev1 = *iptr1++;
HVX_Vector * iptr0 = ((HVX_Vector *) src0);
HVX_Vector * const iptr0_end = ((HVX_Vector *) src0) + (count / kElementsPerVector);
HVX_Vector * iptr1 = ((HVX_Vector *) src1);
HVX_Vector * optr = ((HVX_Vector *) dst); // framework will ensure the dst is aligned
HVX_Vector prev0 = *iptr0++;
HVX_Vector prev1 = *iptr1++;

while (iptr0 < iptr0_end) {
HVX_Vector curr0 = *iptr0++;

@@ -33,25 +32,25 @@ inline void vec_op_impl(const _TyData * src0, const _TyData * src1, size_t count
prev1 = curr1;
}

const size_t leftover = count % kElementsPerVector;
if ((iptr0_end - ((HVX_Vector *) src0)) > 0) {
// handle the last vector
// see also:
// https://github.com/UbiquitousLearning/mllm/blob/babf4410352ce8730824c87699c025a0d4ce3a6f/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/LLaMAMul.cpp#L147
// or qualcomm sdk libs\qhl_hvx\src\qhblas_hvx\qhblas_hvx_aw_vector_add_ah.c
bool iptr0_aligned = hexagon::is_addr_aligned(iptr0);
HVX_Vector curr0 = iptr0_aligned ? prev0 : *iptr0;
iptr0 = iptr0_aligned ? iptr0 : iptr0 + 1;
bool iptr1_aligned = hexagon::is_addr_aligned(iptr1);
HVX_Vector curr1 = iptr1_aligned ? prev1 : *iptr1;
iptr1 = iptr1_aligned ? iptr1 : iptr1 + 1;
HVX_Vector s0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0);
HVX_Vector s1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1);
*optr++ = _OpIntrinsic(s0, s1);
prev0 = curr0;
prev1 = curr1;
bool should_fetch_src0 = leftover != 0 || !hexagon::is_addr_aligned(iptr0);
bool should_fetch_src1 = leftover != 0 || !hexagon::is_addr_aligned(iptr1);
HVX_Vector curr0 = should_fetch_src0 ? *iptr0 : prev0;
HVX_Vector curr1 = should_fetch_src1 ? *iptr1 : prev1;
iptr0 += should_fetch_src0 ? 1 : 0;
iptr1 += should_fetch_src1 ? 1 : 0;
HVX_Vector s0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0);
HVX_Vector s1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1);
*optr++ = _OpIntrinsic(s0, s1);
prev0 = curr0;
prev1 = curr1;
}

const size_t leftover = count % kElementsPerVector;
const size_t leftover_bytes = leftover * sizeof(_TyData);
if (leftover > 0) {
// handle the leftover elements

@@ -136,18 +135,23 @@ template <auto _RowFunc> bool element_wise_op(hexagon::tensor * out, hexagon::co
return false;
}

const auto * src0_ptr = reinterpret_cast<const uint8_t *>(src0->get_read_buffer());
const auto * src1_ptr = reinterpret_cast<const uint8_t *>(src1->get_read_buffer());
auto * dst_ptr = reinterpret_cast<uint8_t *>(out->get_write_buffer());
auto total_rows = out->get_ne(3) * out->get_ne(2) * out->get_ne(1);
const auto rows_per_cube = out->get_ne(2) * out->get_ne(1);
const auto start_end = hexagon::get_thread_work_slice(total_rows, params->tidx, params->tcnt);
uint8_t * dst_ptr = out->get_write_buffer();
if (!dst_ptr) {
DEVICE_LOG_ERROR("element_wise_op: dst_ptr is not writable, tensor: %p, type: %s\n", (void *) out,
hexagon::get_type_name(out->get_type()));
return false;
}

const uint8_t * src0_ptr = src0->get_read_buffer();
const uint8_t * src1_ptr = src1->get_read_buffer();
auto total_rows = out->get_ne(3) * out->get_ne(2) * out->get_ne(1);
const auto rows_per_cube = out->get_ne(2) * out->get_ne(1);
const auto start_end = params->get_work_slice(total_rows);
if (start_end.first >= start_end.second) {
return true;
}

DEVICE_SCOPED_OP_PERFORMANCE_TRACKER(out, params->tidx);
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER(out, params->get_thread_index());

const size_t valid_row_bytes = src0->get_ne(0) * sizeof(data_type);
for (int64_t ir = start_end.first; ir < start_end.second; ++ir) {

@@ -171,6 +175,7 @@ template <auto _RowFunc> bool element_wise_op(hexagon::tensor * out, hexagon::co
static_cast<size_t>(out->get_ne(0)), reinterpret_cast<data_type *>(dst_row));
}

out->release_write_buffer(); // mark the output tensor as modified
return true;
}

@@ -184,27 +189,36 @@ bool is_same_shape(const npu_device_tensor_spec & src, const npu_device_tensor_s
return true;
}

bool is_element_wise_op_supported(const npu_device_tensor_spec & src0, const npu_device_tensor_spec & src1,
const npu_device_tensor_spec & dst, npu_device_tensor_op op) {
bool is_element_wise_op_supported(npu_device_tensor_op op, const npu_device_tensor_spec * dst,
const npu_device_tensor_spec * srcs, size_t src_len) {
if (op != NPU_OP_ADD && op != NPU_OP_SUB && op != NPU_OP_MUL) {
DEVICE_LOG_DEBUG("[%s]unsupported\n", hexagon::op_get_name(op));
return false;
}

if (dst.type != src0.type || dst.type != src1.type) {
DEVICE_LOG_DEBUG("[%s]src0.type and dst.type mismatch: %s vs %s\n", hexagon::op_get_name(op),
hexagon::get_type_name(src0.type), hexagon::get_type_name(dst.type));
if (!dst || !srcs || src_len < 2) {
DEVICE_LOG_DEBUG("[%s]invalid dst or srcs\n", hexagon::op_get_name(op));
return false;
}

if (dst.type != NPU_DATA_TYPE_F32 && dst.type != NPU_DATA_TYPE_F16) {
DEVICE_LOG_DEBUG("[%s]unsupported data type: %s\n", hexagon::op_get_name(op), hexagon::get_type_name(dst.type));
const auto & src0 = srcs[0];
const auto & src1 = srcs[1];
if (dst->type != src0.type || dst->type != src1.type) {
DEVICE_LOG_DEBUG("[%s]src0.type and dst.type mismatch: %s vs %s\n", hexagon::op_get_name(op),
hexagon::get_type_name(src0.type), hexagon::get_type_name(dst->type));
return false;
}

if (dst->type != NPU_DATA_TYPE_F32 && dst->type != NPU_DATA_TYPE_F16) {
DEVICE_LOG_DEBUG("[%s]unsupported data type: %s\n", hexagon::op_get_name(op),
hexagon::get_type_name(dst->type));
return false;
}

// TODO: fix FP16 add/sub
if (dst.type == NPU_DATA_TYPE_F16 && op != NPU_OP_MUL) {
DEVICE_LOG_DEBUG("[%s]unsupported data type: %s\n", hexagon::op_get_name(op), hexagon::get_type_name(dst.type));
if (dst->type == NPU_DATA_TYPE_F16 && op != NPU_OP_MUL) {
DEVICE_LOG_DEBUG("[%s]unsupported data type: %s\n", hexagon::op_get_name(op),
hexagon::get_type_name(dst->type));
return false;
}

@@ -214,7 +228,7 @@ bool is_element_wise_op_supported(const npu_device_tensor_spec & src0, const npu
return false;
}

if (!is_same_shape(src0, dst)) {
if (!is_same_shape(src0, *dst)) {
DEVICE_LOG_DEBUG("[%s]src0 and dst have different shape\n", hexagon::op_get_name(op));
return false;
}

@@ -225,10 +239,10 @@ bool is_element_wise_op_supported(const npu_device_tensor_spec & src0, const npu
void rms_norm_vec_f32(const float * src, size_t count, float eps, float * dst) {
constexpr const size_t kElementsPerVector = hexagon::kBytesPerVector / sizeof(float);

HVX_Vector * src_vec_ptr = ((HVX_Vector *) src);
HVX_Vector * src_vec_end = ((HVX_Vector *) src) + (count / kElementsPerVector);
HVX_Vector prev = *src_vec_ptr++;
HVX_Vector sum = Q6_V_vzero();
HVX_Vector * src_vec_ptr = ((HVX_Vector *) src);
HVX_Vector * const src_vec_end = ((HVX_Vector *) src) + (count / kElementsPerVector);
HVX_Vector prev = *src_vec_ptr++;
HVX_Vector sum = Q6_V_vzero();
while (src_vec_ptr < src_vec_end) {
HVX_Vector curr = *src_vec_ptr++;
HVX_Vector s0 = Q6_V_valign_VVR(curr, prev, (size_t) src);

@@ -236,17 +250,17 @@ void rms_norm_vec_f32(const float * src, size_t count, float eps, float * dst) {
prev = curr;
}

const size_t leftover = count % kElementsPerVector;
if ((src_vec_end - ((HVX_Vector *) src)) > 0) {
// handle the last vector
bool src_ptr_aligned = hexagon::is_addr_aligned(src_vec_ptr);
HVX_Vector curr = src_ptr_aligned ? prev : *src_vec_ptr;
src_vec_ptr = src_ptr_aligned ? src_vec_ptr : src_vec_ptr + 1;
HVX_Vector s0 = Q6_V_valign_VVR(curr, prev, (size_t) src);
sum = Q6_Vqf32_vadd_Vqf32Vqf32(sum, Q6_Vqf32_vmpy_VsfVsf(s0, s0));
prev = curr;
bool should_fetch_src = leftover != 0 || !hexagon::is_addr_aligned(src_vec_ptr);
HVX_Vector curr = should_fetch_src ? *src_vec_ptr : prev;
src_vec_ptr += should_fetch_src ? 1 : 0;
HVX_Vector s0 = Q6_V_valign_VVR(curr, prev, (size_t) src);
sum = Q6_Vqf32_vadd_Vqf32Vqf32(sum, Q6_Vqf32_vmpy_VsfVsf(s0, s0));
prev = curr;
}

const size_t leftover = count % kElementsPerVector;
const size_t leftover_bytes = leftover * sizeof(float);
if (leftover > 0) {
// handle the leftover elements

@@ -257,37 +271,9 @@ void rms_norm_vec_f32(const float * src, size_t count, float eps, float * dst) {
Q6_V_valign_VVR(Q6_Vqf32_vmpy_VsfVsf(curr, curr), Q6_V_vzero(), leftover_bytes));
}

const float mean = hexagon::vec_reduction_f32(sum) / count; // TODO: figure out how to do division in vector
const float scale = 1.0f / sqrtf(mean + eps); // TODO: use buildin blas sqrtf?

HVX_Vector scale_vec = Q6_V_vsplat_R(reinterpret_cast<const uint32_t &>(scale));
src_vec_ptr = ((HVX_Vector *) src);
prev = *src_vec_ptr++;
HVX_Vector * dst_vec_ptr = ((HVX_Vector *) dst); // framework will ensure the dst is aligned
while (src_vec_ptr < src_vec_end) {
HVX_Vector curr = *src_vec_ptr++;
HVX_Vector s0 = Q6_V_valign_VVR(curr, prev, (size_t) src);
*dst_vec_ptr++ = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(s0, scale_vec));
prev = curr;
}

if ((src_vec_end - ((HVX_Vector *) src)) > 0) {
// handle the last vector
bool src_ptr_aligned = hexagon::is_addr_aligned(src_vec_ptr);
HVX_Vector curr = src_ptr_aligned ? prev : *src_vec_ptr;
src_vec_ptr = src_ptr_aligned ? src_vec_ptr : src_vec_ptr + 1;
HVX_Vector s0 = Q6_V_valign_VVR(curr, prev, (size_t) src);
*dst_vec_ptr++ = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(s0, scale_vec));
prev = curr;
}

if (leftover > 0) {
// handle the leftover elements
HVX_Vector curr =
(leftover_bytes + hexagon::unaligned_bytes(src_vec_ptr) > hexagon::kBytesPerVector) ? *src_vec_ptr : prev;
curr = Q6_V_valign_VVR(curr, prev, (size_t) src);
q6op_vstu_variable_ARV(dst_vec_ptr, leftover_bytes, Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(curr, scale_vec)));
}
const float mean = hexagon::vec_reduction_qf32_f32(sum) / count; // TODO: figure out how to do division in vector
const float scale = 1.0f / sqrtf(mean + eps); // TODO: use buildin blas sqrtf?
hexagon::vec_scale_f32(src, scale, dst, count);
}

// TODO: merge with element_wise_op?

@@ -305,16 +291,22 @@ template <auto _RowFunc> bool unary_op(hexagon::tensor * out, hexagon::compute_p
return true; // skip if no src
}

const auto * src0_ptr = reinterpret_cast<const uint8_t *>(src0->get_read_buffer());
auto * dst_ptr = reinterpret_cast<uint8_t *>(out->get_write_buffer());
auto * dst_ptr = out->get_write_buffer();
if (!dst_ptr) {
DEVICE_LOG_ERROR("unary_op: dst_ptr is not writable, tensor: %p, type: %s\n", (void *) out,
hexagon::get_type_name(out->get_type()));
return false;
}

const auto * src0_ptr = src0->get_read_buffer();
auto total_rows = out->get_ne(3) * out->get_ne(2) * out->get_ne(1);
const auto rows_per_cube = out->get_ne(2) * out->get_ne(1);
const auto start_end = hexagon::get_thread_work_slice(total_rows, params->tidx, params->tcnt);
const auto start_end = params->get_work_slice(total_rows);
if (start_end.first >= start_end.second) {
return true;
}

DEVICE_SCOPED_OP_PERFORMANCE_TRACKER(out, params->tidx);
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER(out, params->get_thread_index());

const auto param = out->get_op_param<param_type>(0);
const size_t valid_row_bytes = src0->get_ne(0) * sizeof(data_type);

@@ -333,28 +325,36 @@ template <auto _RowFunc> bool unary_op(hexagon::tensor * out, hexagon::compute_p
reinterpret_cast<data_type *>(dst_row));
}

out->release_write_buffer(); // mark the output tensor as modified
return true;
}

bool is_unary_op_supported(const npu_device_tensor_spec & src0, const npu_device_tensor_spec & src1,
const npu_device_tensor_spec & dst, npu_device_tensor_op op) {
bool is_unary_op_supported(npu_device_tensor_op op, const npu_device_tensor_spec * dst,
const npu_device_tensor_spec * srcs, size_t src_len) {
if (op != NPU_OP_RMS_NORM) {
DEVICE_LOG_DEBUG("[%s]unsupported\n", hexagon::op_get_name(op));
return false;
}

if (dst.type != src0.type) {
if (!dst || !srcs || src_len < 1) {
DEVICE_LOG_DEBUG("[%s]invalid dst or srcs\n", hexagon::op_get_name(op));
return false;
}

const auto & src0 = srcs[0];
if (dst->type != src0.type) {
DEVICE_LOG_DEBUG("[%s]src0.type and dst.type mismatch: %s vs %s\n", hexagon::op_get_name(op),
hexagon::get_type_name(src0.type), hexagon::get_type_name(dst.type));
hexagon::get_type_name(src0.type), hexagon::get_type_name(dst->type));
return false;
}

if (dst.type != NPU_DATA_TYPE_F32) {
DEVICE_LOG_DEBUG("[%s]unsupported data type: %s\n", hexagon::op_get_name(op), hexagon::get_type_name(dst.type));
if (dst->type != NPU_DATA_TYPE_F32) {
DEVICE_LOG_DEBUG("[%s]unsupported data type: %s\n", hexagon::op_get_name(op),
hexagon::get_type_name(dst->type));
return false;
}

if (!is_same_shape(src0, dst)) {
if (!is_same_shape(src0, *dst)) {
DEVICE_LOG_DEBUG("[%s]src0 and dst have different shape\n", hexagon::op_get_name(op));
return false;
}

@@ -371,40 +371,47 @@ struct op_capabilities {

constexpr const op_capabilities kOpCapabilities[] = {
{
NPU_OP_MUL_MAT, hexagon::is_mul_mat_supported,
NPU_OP_MUL_MAT, hexagon::is_mul_mat_supported,
{
hexagon::mul_mat_f32, // NPU_DATA_TYPE_F32
nullptr, // NPU_DATA_TYPE_F16
}, true,
},
}, true, // requires_thread_barrier
},
{
NPU_OP_ADD, is_element_wise_op_supported,
NPU_OP_ADD, is_element_wise_op_supported,
{
element_wise_op<vec_op_f32_f32<vadd_f32_f32>>, // NPU_DATA_TYPE_F32
element_wise_op<vec_op_f16_f16<vadd_f16_f16>>, // NPU_DATA_TYPE_F16
}, false,
},
}, false, // requires_thread_barrier
},
{
NPU_OP_SUB, is_element_wise_op_supported,
NPU_OP_SUB, is_element_wise_op_supported,
{
element_wise_op<vec_op_f32_f32<vsub_f32_f32>>, // NPU_DATA_TYPE_F32
element_wise_op<vec_op_f16_f16<vsub_f16_f16>>, // NPU_DATA_TYPE_F16
}, false,
},
}, false, // requires_thread_barrier
},
{
NPU_OP_MUL, is_element_wise_op_supported,
NPU_OP_MUL, is_element_wise_op_supported,
{
element_wise_op<vec_op_f32_f32<vmul_f32_f32>>, // NPU_DATA_TYPE_F32
element_wise_op<vec_op_f16_f16<vmul_f16_f16>>, // NPU_DATA_TYPE_F16
}, false,
},
}, false, // requires_thread_barrier
},
{
NPU_OP_RMS_NORM, is_unary_op_supported,
NPU_OP_RMS_NORM, is_unary_op_supported,
{
unary_op<rms_norm_vec_f32>, // NPU_DATA_TYPE_F32
nullptr, // NPU_DATA_TYPE_F16
}, false,
},
}, false, // requires_thread_barrier
},
{
NPU_OP_FLASH_ATTN, hexagon::is_flash_attn_supported,
{
hexagon::flash_attn_f32, // NPU_DATA_TYPE_F32
nullptr, // NPU_DATA_TYPE_F16
}, true, // requires_thread_barrier
},
};

static_assert(kOpCapabilities[NPU_OP_MUL_MAT].compute_funcs[NPU_DATA_TYPE_F32] == hexagon::mul_mat_f32,

@@ -415,6 +422,8 @@ static_assert(kOpCapabilities[NPU_OP_MUL_MAT].op == NPU_OP_MUL_MAT, "kOpArray[NP
static_assert(kOpCapabilities[NPU_OP_MUL].op == NPU_OP_MUL, "kOpArray[NPU_OP_MUL].op != NPU_OP_MUL");
static_assert(kOpCapabilities[NPU_OP_RMS_NORM].op == NPU_OP_RMS_NORM,
"kOpArray[NPU_OP_RMS_NORM].op != NPU_OP_RMS_NORM");
static_assert(kOpCapabilities[NPU_OP_FLASH_ATTN].op == NPU_OP_FLASH_ATTN,
"kOpArray[NPU_OP_FLASH_ATTN].op != NPU_OP_FLASH_ATTN");

hexagon::compute_func_type get_compute_func_impl(npu_device_tensor_op op, npu_device_tensor_data_type type) {
if (op >= NPU_OP_COUNT) {

@@ -440,16 +449,16 @@ bool requires_thread_barrier(npu_device_tensor_op op) {
return kOpCapabilities[op].requires_thread_barrier;
}

bool support_op(const npu_device_tensor_spec & src0, const npu_device_tensor_spec & src1,
const npu_device_tensor_spec & dst, npu_device_tensor_op op) {
if (get_compute_func_impl(op, dst.type) == nullptr) {
bool support_op(npu_device_tensor_op op, const npu_device_tensor_spec * dst, const npu_device_tensor_spec * srcs,
size_t src_len) {
if (get_compute_func_impl(op, dst->type) == nullptr) {
DEVICE_LOG_ERROR("[%s]unsupported, get_compute_func failed\n", op_get_name(op));
return false;
}

auto is_supported_func = kOpCapabilities[op].is_supported;
if (!is_supported_func || !is_supported_func(src0, src1, dst, op)) {
DEVICE_LOG_DEBUG("[%s]unsupported, is_supported_func failed\n", op_get_name(op));
if (!is_supported_func || !is_supported_func(op, dst, srcs, src_len)) {
DEVICE_LOG_DEBUG("[%s]unsupported, is_supported_func return false\n", op_get_name(op));
return false;
}

@@ -8,7 +8,7 @@ compute_func_type get_compute_func(tensor * dst);

bool requires_thread_barrier(npu_device_tensor_op op);

bool support_op(const npu_device_tensor_spec & src0, const npu_device_tensor_spec & src1,
const npu_device_tensor_spec & dst, npu_device_tensor_op op);
bool support_op(npu_device_tensor_op op, const npu_device_tensor_spec * dst, const npu_device_tensor_spec * srcs,
size_t src_len);

} // namespace hexagon

@@ -1,167 +1,12 @@
#include "op_mul_mat.hpp"

#include <HTP/core/intrinsics.h>

#include "quants.hpp"
#include "thread_pool.hpp" // TODO: remove this dependency
#include "type_traits.hpp"
#include "vec_ops.hpp"
#include "vtcm_mem.hpp"

namespace {

inline float vec_dot_product_f32_f32(const float * src0, const float * src1, size_t count) {
constexpr const size_t kElementsPerVector = hexagon::kBytesPerVector / sizeof(float);

HVX_Vector * src0_vec_ptr = ((HVX_Vector *) src0);
HVX_Vector * src0_vec_ptr_end = ((HVX_Vector *) src0) + count / kElementsPerVector;
HVX_Vector * src1_vec_ptr = ((HVX_Vector *) src1);
HVX_Vector prev0 = *src0_vec_ptr++;
HVX_Vector prev1 = *src1_vec_ptr++;
HVX_Vector sum = Q6_V_vzero();

while (src0_vec_ptr_end - src0_vec_ptr > 1) {
HVX_Vector curr0_lo = src0_vec_ptr[0];
HVX_Vector curr0_hi = src0_vec_ptr[1];
HVX_Vector curr1_lo = src1_vec_ptr[0];
HVX_Vector curr1_hi = src1_vec_ptr[1];

HVX_Vector l0 = Q6_V_valign_VVR(curr0_lo, prev0, (size_t) src0);
HVX_Vector l1 = Q6_V_valign_VVR(curr1_lo, prev1, (size_t) src1);
HVX_Vector h0 = Q6_V_valign_VVR(curr0_hi, curr0_lo, (size_t) src0);
HVX_Vector h1 = Q6_V_valign_VVR(curr1_hi, curr1_lo, (size_t) src1);
sum = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_Vqf32_vmpy_VsfVsf(l0, l1), sum);
sum = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_Vqf32_vmpy_VsfVsf(h0, h1), sum);

prev0 = curr0_hi;
prev1 = curr1_hi;
src0_vec_ptr += 2;
src1_vec_ptr += 2;
}

if (src0_vec_ptr_end - src0_vec_ptr > 0) {
HVX_Vector curr0 = *src0_vec_ptr++;
HVX_Vector curr1 = *src1_vec_ptr++;
HVX_Vector s0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0);
HVX_Vector s1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1);
sum = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_Vqf32_vmpy_VsfVsf(s0, s1), sum);
prev0 = curr0;
prev1 = curr1;
}

if ((src0_vec_ptr_end - ((HVX_Vector *) src0)) > 0) {
// handle the last vector
// see also:
// https://github.com/UbiquitousLearning/mllm/blob/babf4410352ce8730824c87699c025a0d4ce3a6f/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/LLaMAMul.cpp#L147
// or qualcomm sdk libs\qhl_hvx\src\qhblas_hvx\qhblas_hvx_aw_vector_add_ah.c
bool iptr0_aligned = hexagon::is_addr_aligned(src0_vec_ptr);
HVX_Vector curr0 = iptr0_aligned ? prev0 : *src0_vec_ptr;
src0_vec_ptr = iptr0_aligned ? src0_vec_ptr : src0_vec_ptr + 1;
bool iptr1_aligned = hexagon::is_addr_aligned(src1_vec_ptr);
HVX_Vector curr1 = iptr1_aligned ? prev1 : *src1_vec_ptr;
src1_vec_ptr = iptr1_aligned ? src1_vec_ptr : src1_vec_ptr + 1;
HVX_Vector s0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0);
HVX_Vector s1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1);
sum = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_Vqf32_vmpy_VsfVsf(s0, s1), sum);
prev0 = curr0;
prev1 = curr1;
}

const size_t leftover = count % kElementsPerVector;
const size_t leftover_bytes = leftover * sizeof(float);
if (leftover > 0) {
// handle the leftover elements
HVX_Vector curr0 = (leftover_bytes + hexagon::unaligned_bytes(src0_vec_ptr) > hexagon::kBytesPerVector) ?
*src0_vec_ptr :
prev0;
curr0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0);

HVX_Vector curr1 = (leftover_bytes + hexagon::unaligned_bytes(src1_vec_ptr) > hexagon::kBytesPerVector) ?
*src1_vec_ptr :
prev1;
curr1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1);

sum = Q6_Vqf32_vadd_Vqf32Vqf32(
Q6_V_valign_VVR(Q6_Vqf32_vmpy_VsfVsf(curr0, curr1), Q6_V_vzero(), leftover_bytes), sum);
}

return hexagon::vec_reduction_f32(sum);
}

// TODO: merge with vec_dot_product_f32_f32?
inline float vec_dot_product_f16_f16(const npu_device_fp16_t * src0, const npu_device_fp16_t * src1, size_t count) {
constexpr const size_t kElementsPerVector = hexagon::kBytesPerVector / sizeof(npu_device_fp16_t);
constexpr const size_t kFloatsPerVector = hexagon::kBytesPerVector / sizeof(float);

HVX_Vector * src0_vec_ptr = ((HVX_Vector *) src0);
HVX_Vector * src0_vec_ptr_end = ((HVX_Vector *) src0) + (count / kElementsPerVector);
HVX_Vector * src1_vec_ptr = ((HVX_Vector *) src1);
HVX_Vector prev0 = *src0_vec_ptr++;
HVX_Vector prev1 = *src1_vec_ptr++;
HVX_Vector sum_hi = Q6_V_vzero();
HVX_Vector sum_lo = Q6_V_vzero();

while (src0_vec_ptr < src0_vec_ptr_end) {
HVX_Vector curr0 = *src0_vec_ptr++;
HVX_Vector curr1 = *src1_vec_ptr++;
HVX_Vector s0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0);
HVX_Vector s1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1);
HVX_VectorPair result = Q6_Wqf32_vmpy_VhfVhf(s0, s1);
sum_hi = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_hi_W(result), sum_hi);
sum_lo = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(result), sum_lo);
prev0 = curr0;
prev1 = curr1;
}

if ((src0_vec_ptr_end - ((HVX_Vector *) src0)) > 0) {
// handle the last vector
// see also:
// https://github.com/UbiquitousLearning/mllm/blob/babf4410352ce8730824c87699c025a0d4ce3a6f/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/LLaMAMul.cpp#L147
// or qualcomm sdk libs\qhl_hvx\src\qhblas_hvx\qhblas_hvx_aw_vector_add_ah.c
bool iptr0_aligned = hexagon::is_addr_aligned(src0_vec_ptr);
HVX_Vector curr0 = iptr0_aligned ? prev0 : *src0_vec_ptr;
src0_vec_ptr = iptr0_aligned ? src0_vec_ptr : src0_vec_ptr + 1;
bool iptr1_aligned = hexagon::is_addr_aligned(src1_vec_ptr);
HVX_Vector curr1 = iptr1_aligned ? prev1 : *src1_vec_ptr;
src1_vec_ptr = iptr1_aligned ? src1_vec_ptr : src1_vec_ptr + 1;
HVX_Vector s0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0);
HVX_Vector s1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1);
HVX_VectorPair result = Q6_Wqf32_vmpy_VhfVhf(s0, s1);
sum_hi = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_hi_W(result), sum_hi);
sum_lo = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(result), sum_lo);
prev0 = curr0;
prev1 = curr1;
}

const size_t leftover = count % kElementsPerVector;
const size_t leftover_bytes = leftover * sizeof(npu_device_fp16_t);
if (leftover > 0) {
// handle the leftover elements
HVX_Vector curr0 = (leftover_bytes + hexagon::unaligned_bytes(src0_vec_ptr) > hexagon::kBytesPerVector) ?
*src0_vec_ptr :
prev0;
curr0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0);

HVX_Vector curr1 = (leftover_bytes + hexagon::unaligned_bytes(src1_vec_ptr) > hexagon::kBytesPerVector) ?
*src1_vec_ptr :
prev1;
curr1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1);

HVX_VectorPair result = Q6_Wqf32_vmpy_VhfVhf(curr0, curr1);

// TODO: can we do this better?
if (leftover > kFloatsPerVector) {
sum_hi = Q6_Vqf32_vadd_Vqf32Vqf32(
Q6_V_valign_VVR(Q6_V_hi_W(result), Q6_V_vzero(), (leftover % kFloatsPerVector) * sizeof(float)),
sum_hi);
sum_lo = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(result), sum_lo);
} else {
sum_lo = Q6_Vqf32_vadd_Vqf32Vqf32(
Q6_V_valign_VVR(Q6_V_lo_W(result), Q6_V_vzero(), leftover * sizeof(float)), sum_lo);
}
}

return hexagon::vec_reduction_f32(Q6_Vqf32_vadd_Vqf32Vqf32(sum_hi, sum_lo));
}

template <typename T> struct get_data_type {};

template <typename _TyData> struct get_data_type<float (*)(const _TyData *, const _TyData *, size_t)> {

@@ -175,29 +20,26 @@ void mul_mat_impl(hexagon::tensor * src0, hexagon::tensor * src1, hexagon::tenso

const bool is_quantized = hexagon::is_quantized_type(src0->get_type());
const auto src0_actual_row_size = hexagon::get_dequantized_row_size(src0);
auto * dequantize_row_func = hexagon::get_type_traits(src0->get_type()).dequantize_row;
auto * dequantize_row_func = hexagon::get_type_traits(src0->get_type()).to_float;
if (is_quantized && dequantize_row_func == nullptr) {
DEVICE_LOG_ERROR("Unsupported quantized src0 type: %d, dequantize_row_func is null\n", src0->get_type());
return;
}

const auto r02 = src1->get_ne(2) / src0->get_ne(2);
const auto r03 = src1->get_ne(3) / src0->get_ne(3);
const auto * src0_ptr = reinterpret_cast<const uint8_t *>(src0->get_read_buffer());
const auto * src1_ptr = reinterpret_cast<const uint8_t *>(src1->get_read_buffer());
auto * dst_ptr = reinterpret_cast<uint8_t *>(dst->get_write_buffer());
const auto total_planes = dst->get_ne(3) * dst->get_ne(2);
const auto r02 = src1->get_ne(2) / src0->get_ne(2);
const auto r03 = src1->get_ne(3) / src0->get_ne(3);
const auto total_planes = dst->get_ne(3) * dst->get_ne(2);

auto start_end_plane = std::pair<int64_t, int64_t>{ 0, total_planes };
auto start_end_row = std::pair<int64_t, int64_t>{ 0, dst->get_ne(1) };
auto start_end_element = std::pair<int64_t, int64_t>{ 0, dst->get_ne(0) };

if (total_planes >= params->tcnt) {
start_end_plane = hexagon::get_thread_work_slice(total_planes, params->tidx, params->tcnt);
} else if (dst->get_ne(1) >= params->tcnt) {
start_end_row = hexagon::get_thread_work_slice(dst->get_ne(1), params->tidx, params->tcnt);
if (total_planes >= params->get_thread_count()) {
start_end_plane = params->get_work_slice(total_planes);
} else if (dst->get_ne(1) >= params->get_thread_count()) {
start_end_row = params->get_work_slice(dst->get_ne(1));
} else {
start_end_element = hexagon::get_thread_work_slice(dst->get_ne(0), params->tidx, params->tcnt);
start_end_element = params->get_work_slice(dst->get_ne(0));
}

if (start_end_plane.second <= start_end_plane.first || start_end_row.second <= start_end_row.first ||

@@ -218,7 +60,7 @@ void mul_mat_impl(hexagon::tensor * src0, hexagon::tensor * src1, hexagon::tenso
bool is_mem_cache = false;
if (is_quantized) {
src0_plane_slice_row_count =
std::min(params->vtcm_quota_size / src0_actual_row_size, src0_plane_slice_row_count);
std::min(params->get_vtcm_quota_size() / src0_actual_row_size, src0_plane_slice_row_count);
src0_plane_cache_size = src0_actual_row_size * src0_plane_slice_row_count;
src0_plane_cache_ptr = params->get_vtcm_cache(src0_plane_cache_size);
if (src0_plane_cache_ptr == nullptr) {

@@ -238,7 +80,17 @@ void mul_mat_impl(hexagon::tensor * src0, hexagon::tensor * src1, hexagon::tenso
src0_plane_cache_size);

const size_t valid_row_bytes = src1->get_ne(0) * sizeof(data_type);
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_WITH_SUB_PROC(dst, params->tidx, dequant);
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_WITH_SUB_PROC(dst, params->get_thread_index(), dequant);

uint8_t * dst_ptr = dst->get_write_buffer();
if (!dst_ptr) {
DEVICE_LOG_ERROR("mul_mat_impl: dst_ptr is not writable, tensor: %p, type: %s\n", (void *) dst,
hexagon::get_type_name(dst->get_type()));
return;
}

const uint8_t * src0_ptr = src0->get_read_buffer();
const uint8_t * src1_ptr = src1->get_read_buffer();
for (int64_t ip = start_end_plane.first; ip < start_end_plane.second; ip++) {
const auto i3 = ip / dst->get_ne(2);
const auto i2 = ip - i3 * dst->get_ne(2);

@@ -289,6 +141,8 @@ void mul_mat_impl(hexagon::tensor * src0, hexagon::tensor * src1, hexagon::tenso
}
}
}

dst->release_write_buffer(); // mark the output tensor as modified
}

bool is_quantized_mul_mat_supported(const npu_device_tensor_spec & src0, const npu_device_tensor_spec & src1) {

@@ -299,7 +153,7 @@ bool is_quantized_mul_mat_supported(const npu_device_tensor_spec & src0, const n
}

const auto type_traits = hexagon::get_type_traits(src0.type);
if (!type_traits.is_quantized || type_traits.dequantize_row == nullptr) {
if (!type_traits.is_quantized || type_traits.to_float == nullptr) {
DEVICE_LOG_DEBUG("[MUL_MAT]src0.type(%s) and src1.type(%s) mismatch and src0 is not quantized\n",
hexagon::get_type_name(src0.type), hexagon::get_type_name(src1.type));
return false;

@@ -311,7 +165,7 @@ bool is_quantized_mul_mat_supported(const npu_device_tensor_spec & src0, const n
return false;
}

const auto vtcm_thread_quota_size = hexagon::vtcm_mem::get_total_size() / hexagon::kMaxThreadCount;
const auto vtcm_thread_quota_size = hexagon::default_thread_pool::get_per_thread_vtcm_quota();
if (src0.ne[0] * sizeof(hexagon::dequantized_element_type) > vtcm_thread_quota_size) {
DEVICE_LOG_DEBUG("[MUL_MAT]src0.type(%s) ne[0] is too large: %ld, vtcm_thread_quota_size: %zu\n",
hexagon::get_type_name(src0.type), (long) src0.ne[0], vtcm_thread_quota_size);

@@ -339,14 +193,13 @@ bool mul_mat_f32(hexagon::tensor * out, compute_params * params) {
return true; // skip if no src
}

// TODO: array?
switch (src1->get_type()) {
case NPU_DATA_TYPE_F32:
mul_mat_impl<vec_dot_product_f32_f32>(src0, src1, out, params);
mul_mat_impl<hexagon::vec_dot_product_f32_f32>(src0, src1, out, params);
return true;

case NPU_DATA_TYPE_F16:
mul_mat_impl<vec_dot_product_f16_f16>(src0, src1, out, params);
mul_mat_impl<hexagon::vec_dot_product_f16_f16>(src0, src1, out, params);
return true;
default:
break;

@@ -356,18 +209,25 @@ bool mul_mat_f32(hexagon::tensor * out, compute_params * params) {
return false;
}

bool is_mul_mat_supported(const npu_device_tensor_spec & src0, const npu_device_tensor_spec & src1,
const npu_device_tensor_spec & dst, npu_device_tensor_op op) {
bool is_mul_mat_supported(npu_device_tensor_op op, const npu_device_tensor_spec * dst,
const npu_device_tensor_spec * srcs, size_t src_len) {
if (op != NPU_OP_MUL_MAT) {
DEVICE_LOG_DEBUG("op is not MUL_MAT: %d\n", op);
return false;
}

if (dst.type != NPU_DATA_TYPE_F32) {
DEVICE_LOG_DEBUG("[%s]dst type is not F32: %s\n", op_get_name(op), get_type_name(dst.type));
if (!dst || !srcs || src_len < 2) {
DEVICE_LOG_DEBUG("[%s]invalid dst or srcs\n", hexagon::op_get_name(op));
return false;
}

if (dst->type != NPU_DATA_TYPE_F32) {
DEVICE_LOG_DEBUG("[%s]dst type is not F32: %s\n", op_get_name(op), get_type_name(dst->type));
return false;
}

const auto & src0 = srcs[0];
const auto & src1 = srcs[1];
if (src0.type != src1.type) {
#ifdef GGML_HEXAGON_ENABLE_QUANTIZED_TENSORS
if (!is_quantized_mul_mat_supported(src0, src1)) {

@ -380,15 +240,15 @@ bool is_mul_mat_supported(const npu_device_tensor_spec & src0, const npu_device_
|
|||
#endif
|
||||
}
|
||||
|
||||
if (src0.ne[0] != src1.ne[0] || src0.ne[1] != dst.ne[0]) {
|
||||
if (src0.ne[0] != src1.ne[0] || src0.ne[1] != dst->ne[0]) {
|
||||
DEVICE_LOG_DEBUG("[%s]src0 and src1 cannot multiply: %ldx%ld vs %ldx%ld\n", op_get_name(op), (long) src0.ne[0],
|
||||
(long) src0.ne[1], (long) src1.ne[0], (long) src1.ne[1]);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (src1.ne[1] != dst.ne[1] || src1.ne[2] != dst.ne[2] || src1.ne[3] != dst.ne[3]) {
|
||||
if (src1.ne[1] != dst->ne[1] || src1.ne[2] != dst->ne[2] || src1.ne[3] != dst->ne[3]) {
|
||||
DEVICE_LOG_DEBUG("[%s]src1 and dst dimensions not match: %ldx%ld vs %ldx%ld\n", op_get_name(op),
|
||||
(long) src1.ne[2], (long) src1.ne[3], (long) dst.ne[2], (long) dst.ne[3]);
|
||||
(long) src1.ne[2], (long) src1.ne[3], (long) dst->ne[2], (long) dst->ne[3]);
|
||||
return false;
|
||||
}
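// Worked example (illustrative) of the shape convention the checks above enforce,
// using ggml-style ne[] ordering where ne[0] is the contiguous dimension:
//   src0: [K, N, B2, B3]   (K = reduction length, N = dst rows)
//   src1: [K, M, B2, B3]   (M = dst columns per plane)
//   dst : [N, M, B2, B3]
// e.g. src0 = [4096, 11008, 1, 1] and src1 = [4096, 32, 1, 1] give dst = [11008, 32, 1, 1]:
// src0.ne[0] == src1.ne[0] and src0.ne[1] == dst->ne[0] hold, and src1.ne[1..3] match
// dst->ne[1..3], so the op is accepted.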
|
||||
|
||||
|
|
|
|||
|
|
@ -7,64 +7,8 @@
|
|||
|
||||
namespace hexagon {
|
||||
|
||||
inline size_t unaligned_bytes(const void * addr) {
|
||||
return ((size_t) addr) & kAlignMask;
|
||||
}
|
||||
|
||||
inline bool is_addr_aligned(void * addr) {
|
||||
return unaligned_bytes(addr) == 0;
|
||||
}
|
||||
|
||||
inline void l2fetch(const void * p, uint32_t stride, uint32_t width, uint32_t height, uint32_t dir) {
|
||||
uint64_t control = HEXAGON_V64_CREATE_H(dir, stride, width, height);
|
||||
__asm__ __volatile__(" l2fetch(%0,%1) " : : "r"(p), "r"(control));
|
||||
}
|
||||
|
||||
inline void l2fetch_row(const uint8_t * curr_row, size_t bytes) {
|
||||
// TODO: should we use a smaller kL2FetchAheadVectors?
|
||||
int32_t l2fetch_vectors = Q6_R_min_RR(bytes / kBytesPerVector, kL2FetchAheadVectors);
|
||||
hexagon::l2fetch(curr_row, kBytesPerVector, kBytesPerVector, l2fetch_vectors, 0);
|
||||
}
|
||||
|
||||
inline float get_flt0_from_fltv(HVX_Vector vect) {
|
||||
// See also: tools\HEXAGON_Tools\8.6.07\Examples\StandAlone_Applications\QFloat\QFloat.c
|
||||
|
||||
union {
|
||||
int32_t i;
|
||||
float f;
|
||||
} cvt;
|
||||
|
||||
cvt.i = vect[0];
|
||||
return cvt.f;
|
||||
}
|
||||
|
||||
inline HVX_Vector vec_reduction_qf32(HVX_Vector sums) {
|
||||
constexpr const size_t kFloatsPerVector = hexagon::kBytesPerVector / sizeof(float);
|
||||
static_assert(kFloatsPerVector == 32 || kFloatsPerVector == 16, "kFloatsPerVector should be 16 or 32");
|
||||
|
||||
// TODO: do we have a better way to do the reduction?
|
||||
switch (kFloatsPerVector) {
|
||||
default:
|
||||
case 32:
|
||||
sums = Q6_Vqf32_vadd_Vqf32Vqf32(sums, Q6_V_vror_VR(sums, 16 * sizeof(float)));
|
||||
// fallthrough
|
||||
case 16:
|
||||
sums = Q6_Vqf32_vadd_Vqf32Vqf32(sums, Q6_V_vror_VR(sums, 8 * sizeof(float)));
|
||||
sums = Q6_Vqf32_vadd_Vqf32Vqf32(sums, Q6_V_vror_VR(sums, 4 * sizeof(float)));
|
||||
sums = Q6_Vqf32_vadd_Vqf32Vqf32(sums, Q6_V_vror_VR(sums, 2 * sizeof(float)));
|
||||
sums = Q6_Vqf32_vadd_Vqf32Vqf32(sums, Q6_V_vror_VR(sums, sizeof(float)));
|
||||
break;
|
||||
}
|
||||
|
||||
return sums;
|
||||
}
|
||||
|
||||
inline float vec_reduction_f32(HVX_Vector sums) {
|
||||
return hexagon::get_flt0_from_fltv(Q6_Vsf_equals_Vqf32(vec_reduction_qf32(sums)));
|
||||
}
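// A minimal scalar sketch (illustrative only, not used by the HVX path) of what the
// rotate-and-add reduction above computes: each Q6_V_vror_VR by half the remaining
// span folds the upper half of the partial sums onto the lower half, so
// log2(kFloatsPerVector) additions collapse all lanes into lane 0.
inline float vec_reduction_f32_scalar_sketch(float * lanes, size_t n) {
    // assumes n is a power of two, mirroring kFloatsPerVector (16 or 32)
    for (size_t span = n / 2; span > 0; span /= 2) {
        for (size_t i = 0; i < span; ++i) {
            lanes[i] += lanes[i + span];  // fold the upper half onto the lower half
        }
    }
    return lanes[0];
}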
|
||||
|
||||
bool mul_mat_f32(tensor * out, compute_params * params);
|
||||
bool is_mul_mat_supported(const npu_device_tensor_spec & src0, const npu_device_tensor_spec & src1,
|
||||
const npu_device_tensor_spec & dst, npu_device_tensor_op op);
|
||||
bool is_mul_mat_supported(npu_device_tensor_op op, const npu_device_tensor_spec * dst,
|
||||
const npu_device_tensor_spec * srcs, size_t src_len);
|
||||
|
||||
} // namespace hexagon
|
||||
|
|
|
|||
|
|
@ -9,46 +9,12 @@
|
|||
|
||||
#include "hexagon_npu.h"
|
||||
#include "tensor.hpp"
|
||||
#include "thread_pool.hpp"
|
||||
#include "util.hpp"
|
||||
#include "vtcm_mem.hpp"
|
||||
#include "vec_ops.hpp"
|
||||
|
||||
namespace hexagon {
|
||||
|
||||
struct compute_params {
|
||||
const size_t tidx;
|
||||
const size_t tcnt;
|
||||
const size_t vtcm_quota_size;
|
||||
const float * f16_to_f32_table;
|
||||
std::unique_ptr<hexagon::vtcm_mem> vtcm_cache;
|
||||
std::unique_ptr<uint8_t[]> mem_cache;
|
||||
size_t mem_cache_size = 0;
|
||||
|
||||
uint8_t * get_vtcm_cache(size_t size) {
|
||||
if (!vtcm_cache || vtcm_cache->get_size() < size) {
|
||||
vtcm_cache = std::make_unique<hexagon::vtcm_mem>(size, false);
|
||||
}
|
||||
|
||||
if (!vtcm_cache->is_valid()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return vtcm_cache->get_mem();
|
||||
}
|
||||
|
||||
uint8_t * get_mem_cache(size_t size) {
|
||||
if (!mem_cache || mem_cache_size < size) {
|
||||
mem_cache = std::make_unique<uint8_t[]>(size + 256);
|
||||
mem_cache_size = mem_cache ? size : 0;
|
||||
}
|
||||
|
||||
return mem_cache.get();
|
||||
}
|
||||
};
|
||||
|
||||
typedef bool (*compute_func_type)(tensor * dst, compute_params * params);
|
||||
typedef bool (*op_is_supported_func_type)(const npu_device_tensor_spec & src0, const npu_device_tensor_spec & src1,
|
||||
const npu_device_tensor_spec & dst, npu_device_tensor_op op);
|
||||
|
||||
inline constexpr std::pair<int64_t, int64_t> get_thread_work_slice(int64_t total, size_t tidx, size_t tcnt) {
|
||||
if (total <= 0 || tidx >= tcnt) {
|
||||
return { 0, 0 }; // No work for this thread
|
||||
|
|
@ -72,9 +38,27 @@ inline constexpr std::pair<int64_t, int64_t> get_thread_work_slice(int64_t total
|
|||
return { start, std::min(end, total) };
|
||||
}
|
||||
|
||||
constexpr const size_t kBytesPerVector = sizeof(HVX_Vector); // 128 for v73
|
||||
constexpr const size_t kAlignMask = kBytesPerVector - 1;
|
||||
constexpr const size_t kL2CacheSize = 8 * 1024; // 8KB L2 cache
|
||||
constexpr const size_t kL2FetchAheadVectors = kL2CacheSize / kBytesPerVector;
|
||||
struct compute_params {
|
||||
default_thread_pool::thread_params * const thread_params;
|
||||
const float * f16_to_f32_table;
|
||||
|
||||
uint8_t * get_vtcm_cache(size_t size) { return thread_params->get_vtcm_cache(size); }
|
||||
|
||||
uint8_t * get_mem_cache(size_t size) { return thread_params->get_mem_cache(size); }
|
||||
|
||||
std::pair<int64_t, int64_t> get_work_slice(int64_t total) const {
|
||||
return get_thread_work_slice(total, thread_params->tidx, thread_params->tcnt);
|
||||
}
|
||||
|
||||
size_t get_vtcm_quota_size() const { return thread_params->vtcm_quota_size; }
|
||||
|
||||
size_t get_thread_count() const { return thread_params->tcnt; }
|
||||
|
||||
size_t get_thread_index() const { return thread_params->tidx; }
|
||||
};
|
||||
|
||||
typedef bool (*compute_func_type)(tensor * dst, compute_params * params);
|
||||
typedef bool (*op_is_supported_func_type)(npu_device_tensor_op op, const npu_device_tensor_spec * dst,
|
||||
const npu_device_tensor_spec * srcs, size_t src_len);
|
||||
|
||||
} // namespace hexagon
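// Illustrative sketch of how a compute function is expected to use the slicing helpers
// above; example_rows_op_f32 is a hypothetical op used only for illustration, not one
// registered by this backend.
inline bool example_rows_op_f32(hexagon::tensor * dst, hexagon::compute_params * params) {
    const auto range = params->get_work_slice(dst->get_ne(1));  // split dst rows across threads
    for (int64_t r = range.first; r < range.second; ++r) {
        // each thread touches only its own [start, end) rows, so no extra locking is needed
        (void) r;
    }
    return true;
}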
|
||||
|
|
|
|||
|
|
@ -1,213 +0,0 @@
|
|||
#include "quants.hpp"
|
||||
|
||||
#include <hexagon_types.h>
|
||||
|
||||
#include <array>
|
||||
|
||||
#include "op_types.hpp" // TODO: remove this include
|
||||
|
||||
static_assert(sizeof(npu_device_block_q4_K) ==
|
||||
2 * sizeof(npu_device_fp16_t) + QUANT_K_SCALE_SIZE + QUANT_K_BLOCK_SIZE / 2,
|
||||
"wrong q4_K block size/padding");
|
||||
|
||||
static_assert(sizeof(npu_device_block_q4_0) == sizeof(npu_device_fp16_t) + QUANT_BLOCK_SIZE / 2,
|
||||
"wrong q4_0 block size/padding");
|
||||
|
||||
static_assert(sizeof(npu_device_block_q8_0) == sizeof(npu_device_fp16_t) + QUANT_BLOCK_SIZE,
|
||||
"wrong q8_0 block size/padding");
|
||||
|
||||
namespace {
|
||||
|
||||
inline HVX_Vector vmemu(const void * unaligned_ptr) {
|
||||
HVX_Vector ret = *reinterpret_cast<const HVX_UVector *>(unaligned_ptr);
|
||||
return ret;
|
||||
}
|
||||
|
||||
inline float to_float(const npu_device_fp16_t src) {
|
||||
return reinterpret_cast<const __fp16 &>(src);
|
||||
}
|
||||
|
||||
template <typename _TBlock> inline HVX_Vector load_block_generic(const _TBlock & src) {
|
||||
uint8_t buffer[hexagon::kBytesPerVector];
|
||||
|
||||
static_assert(sizeof(buffer) == sizeof(HVX_Vector), "wrong cvt size/padding");
|
||||
static_assert(sizeof(buffer) >= sizeof(src.qs), "wrong q4_0 block size/padding");
|
||||
|
||||
memcpy(&buffer[0], src.qs, sizeof(src.qs));
|
||||
return *reinterpret_cast<HVX_UVector *>(buffer);
|
||||
}
|
||||
|
||||
template <typename _TBlock> inline HVX_Vector load_dual_block_generic(const _TBlock & src1, const _TBlock & src2) {
|
||||
uint8_t buffer[hexagon::kBytesPerVector];
|
||||
|
||||
static_assert(sizeof(buffer) == sizeof(HVX_Vector), "wrong cvt size/padding");
|
||||
static_assert(sizeof(buffer) >= sizeof(src1.qs) * 2, "wrong q4_0 block size/padding");
|
||||
|
||||
memcpy(&buffer[0], src1.qs, sizeof(src1.qs));
|
||||
memcpy(&buffer[sizeof(src1.qs)], src2.qs, sizeof(src2.qs));
|
||||
return *reinterpret_cast<HVX_UVector *>(buffer);
|
||||
}
|
||||
|
||||
inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t * d, uint8_t * m) {
|
||||
if (j < 4) {
|
||||
*d = q[j] & 63;
|
||||
*m = q[j + 4] & 63;
|
||||
} else {
|
||||
*d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4);
|
||||
*m = (q[j + 4] >> 4) | ((q[j - 0] >> 6) << 4);
|
||||
}
|
||||
}
|
||||
|
||||
void dequantize_row_q8_0(const void * src, float * dst, size_t count, const float * f16_to_f32_table) {
|
||||
constexpr const int qk = QUANT_BLOCK_SIZE;
|
||||
static_assert(QUANT_BLOCK_SIZE == hexagon::kBytesPerVector / sizeof(float));
|
||||
|
||||
const int nb = count / qk;
|
||||
const auto * src_ptr = reinterpret_cast<const npu_device_block_q8_0 *>(src);
|
||||
HVX_UVector * out = ((HVX_UVector *) dst); // TODO: opt for aligned access
|
||||
|
||||
for (int i = 0; i < nb; i++) {
|
||||
const auto & src = src_ptr[i];
|
||||
HVX_Vector d = Q6_Vh_vsplat_R(src.d);
|
||||
|
||||
HVX_Vector q_lo = load_block_generic(src);
|
||||
HVX_VectorPair q = Q6_Wh_vunpack_Vb(q_lo);
|
||||
q = Q6_Wh_vunpack_Vb(Q6_V_lo_W(q));
|
||||
q_lo = Q6_Vhf_equals_Vh(Q6_V_lo_W(q));
|
||||
q = Q6_Wqf32_vmpy_VhfVhf(q_lo, d);
|
||||
out[i] = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(q));
|
||||
}
|
||||
}
|
||||
|
||||
void dequantize_row_q4_0(const void * src, float * dst, size_t count, const float * f16_to_f32_table) {
|
||||
constexpr const int qk = QUANT_BLOCK_SIZE;
|
||||
static_assert(qk % 2 == 0, "qk must be even");
|
||||
static_assert(QUANT_BLOCK_SIZE == hexagon::kBytesPerVector / sizeof(float));
|
||||
constexpr const uint32_t kSizeOfQs = sizeof(npu_device_block_q4_0::qs);
|
||||
|
||||
const int nb = count / qk;
|
||||
const auto * src_ptr = reinterpret_cast<const npu_device_block_q4_0 *>(src);
|
||||
HVX_Vector mask = Q6_Vb_vsplat_R(0x0F);
|
||||
HVX_Vector minus = Q6_Vb_vsplat_R(8);
|
||||
HVX_UVector * out = ((HVX_UVector *) dst); // TODO: opt for aligned access
|
||||
|
||||
const int loop_count = nb - (nb % 2);
|
||||
for (int i = 0; i < loop_count; i += 2) {
|
||||
const auto & src1 = src_ptr[i];
|
||||
const auto & src2 = src_ptr[i + 1];
|
||||
|
||||
HVX_Vector d1 = Q6_Vh_vsplat_R(src1.d);
|
||||
HVX_Vector d2 = Q6_Vh_vsplat_R(src2.d);
|
||||
d1 = Q6_V_valign_VVR(d1, Q6_V_vzero(), hexagon::kBytesPerVector / 2);
|
||||
d1 = Q6_V_valign_VVR(d2, d1, hexagon::kBytesPerVector / 2);
|
||||
HVX_Vector d = Q6_Vh_vshuff_Vh(d1);
|
||||
|
||||
HVX_Vector q_lo = load_dual_block_generic(src1, src2);
|
||||
HVX_Vector q_hi = Q6_Vub_vlsr_VubR(q_lo, 4);
|
||||
HVX_VectorPair q = Q6_W_vshuff_VVR(q_hi, Q6_V_vand_VV(q_lo, mask), kSizeOfQs);
|
||||
q_lo = Q6_V_valign_VVR(Q6_V_lo_W(q), Q6_V_vzero(), hexagon::kBytesPerVector / 2);
|
||||
q_lo = Q6_V_valign_VVR(Q6_V_hi_W(q), q_lo, hexagon::kBytesPerVector / 2);
|
||||
q_lo = Q6_Vb_vshuff_Vb(q_lo);
|
||||
q_lo = Q6_Vb_vsub_VbVb(q_lo, minus);
|
||||
q = Q6_Wh_vunpack_Vb(q_lo);
|
||||
q_lo = Q6_Vhf_equals_Vh(Q6_V_lo_W(q));
|
||||
q = Q6_Wqf32_vmpy_VhfVhf(q_lo, d);
|
||||
out[i] = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(q));
|
||||
out[i + 1] = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(q));
|
||||
}
|
||||
|
||||
if (loop_count < nb) {
|
||||
const auto & curr_blk = src_ptr[nb - 1];
|
||||
HVX_Vector d = Q6_Vh_vsplat_R(curr_blk.d);
|
||||
|
||||
HVX_Vector q_lo = load_block_generic(curr_blk);
|
||||
HVX_Vector q_hi = Q6_Vub_vlsr_VubR(q_lo, 4);
|
||||
q_lo = Q6_V_valign_VVR(Q6_V_vand_VV(q_lo, mask), Q6_V_vzero(), sizeof(curr_blk.qs));
|
||||
q_lo = Q6_V_valign_VVR(q_hi, q_lo, hexagon::kBytesPerVector - sizeof(curr_blk.qs));
|
||||
q_lo = Q6_Vb_vsub_VbVb(q_lo, minus);
|
||||
|
||||
HVX_VectorPair q = Q6_Wh_vunpack_Vb(q_lo);
|
||||
q = Q6_Wh_vunpack_Vb(Q6_V_lo_W(q));
|
||||
q_lo = Q6_Vhf_equals_Vh(Q6_V_lo_W(q));
|
||||
q = Q6_Wqf32_vmpy_VhfVhf(q_lo, d);
|
||||
out[nb - 1] = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(q));
|
||||
}
|
||||
}
|
||||
|
||||
void dequantize_row_q4_K(const void * src, float * dst, size_t count, const float * f16_to_f32_table) {
|
||||
const int nb = count / QUANT_K_BLOCK_SIZE;
|
||||
const auto * src_ptr = reinterpret_cast<const npu_device_block_q4_K *>(src);
|
||||
|
||||
// TODO: use intrinsics
|
||||
for (int i = 0; i < nb; i++) {
|
||||
const uint8_t * q = src_ptr[i].qs;
|
||||
|
||||
const float d = f16_to_f32_table[src_ptr[i].d];
|
||||
const float min = f16_to_f32_table[src_ptr[i].dmin];
|
||||
|
||||
int is = 0;
|
||||
uint8_t sc = 0;
|
||||
uint8_t m = 0;
|
||||
const auto * scales = src_ptr[i].scales;
|
||||
for (int j = 0; j < QUANT_K_BLOCK_SIZE; j += 64) {
|
||||
get_scale_min_k4(is + 0, scales, &sc, &m);
|
||||
const float d1 = d * sc;
|
||||
const float m1 = min * m;
|
||||
get_scale_min_k4(is + 1, scales, &sc, &m);
|
||||
const float d2 = d * sc;
|
||||
const float m2 = min * m;
|
||||
for (int l = 0; l < 32; ++l) {
|
||||
dst[0] = d1 * (q[l] & 0xF) - m1;
|
||||
dst[32] = d2 * ((q[l] >> 4) & 0xF) - m2;
|
||||
dst++;
|
||||
}
|
||||
dst += 32;
|
||||
q += 32;
|
||||
is += 2;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
constexpr const hexagon::device_type_traits kDeviceTypeTraits[] = {
|
||||
{ NPU_DATA_TYPE_F32, "F32", 1, false, nullptr },
|
||||
{ NPU_DATA_TYPE_F16, "F16", 1, false, nullptr },
|
||||
{ NPU_DATA_TYPE_Q8_0, "Q8_0", QUANT_BLOCK_SIZE, true, dequantize_row_q8_0 },
|
||||
{ NPU_DATA_TYPE_Q4_0, "Q4_0", QUANT_BLOCK_SIZE, true, dequantize_row_q4_0 },
|
||||
{ NPU_DATA_TYPE_Q4_K, "Q4_K", QUANT_K_BLOCK_SIZE, true, dequantize_row_q4_K },
|
||||
};
|
||||
|
||||
static_assert(std::size(kDeviceTypeTraits) == NPU_DATA_TYPE_COUNT,
|
||||
"kDeviceTypeTraits size mismatch with npu_device_tensor_data_type enum");
|
||||
static_assert(kDeviceTypeTraits[NPU_DATA_TYPE_F32].type == NPU_DATA_TYPE_F32,
|
||||
"kDeviceTypeTraits F32 type mismatch with npu_device_tensor_data_type enum");
|
||||
static_assert(kDeviceTypeTraits[NPU_DATA_TYPE_F16].type == NPU_DATA_TYPE_F16,
|
||||
"kDeviceTypeTraits F16 type mismatch with npu_device_tensor_data_type enum");
|
||||
static_assert(kDeviceTypeTraits[NPU_DATA_TYPE_Q8_0].type == NPU_DATA_TYPE_Q8_0,
|
||||
"kDeviceTypeTraits Q8_0 type mismatch with npu_device_tensor_data_type enum");
|
||||
static_assert(kDeviceTypeTraits[NPU_DATA_TYPE_Q4_0].type == NPU_DATA_TYPE_Q4_0,
|
||||
"kDeviceTypeTraits Q4_0 type mismatch with npu_device_tensor_data_type enum");
|
||||
static_assert(kDeviceTypeTraits[NPU_DATA_TYPE_Q4_K].type == NPU_DATA_TYPE_Q4_K,
|
||||
"kDeviceTypeTraits Q4_K type mismatch with npu_device_tensor_data_type enum");
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace hexagon {
|
||||
|
||||
bool init_f16_f32_table(float * table, size_t count) {
|
||||
constexpr const size_t kTableSize = (1U << 16);
|
||||
if (count < kTableSize) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < count; ++i) {
|
||||
table[i] = to_float(i);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
const device_type_traits & get_type_traits(npu_device_tensor_data_type type) {
|
||||
return kDeviceTypeTraits[type];
|
||||
}
|
||||
|
||||
} // namespace hexagon
|
||||
|
|
@ -3,6 +3,8 @@
|
|||
#include <HAP_mem.h>
|
||||
#include <qurt.h>
|
||||
|
||||
#include <atomic>
|
||||
|
||||
#include "hexagon_npu.h"
|
||||
#include "util.hpp"
|
||||
|
||||
|
|
@ -23,7 +25,7 @@ class tensor {
|
|||
}
|
||||
|
||||
_data = static_cast<uint8_t *>(mmap_address);
|
||||
DEVICE_LOG_INFO("tensor(%p[%ldx%ldx%ldx%ld]), fd: %d, offset: %zu, mmap_address: %p, phy_address: 0x%lx\n",
|
||||
DEVICE_LOG_INFO("tensor(%p[%ldx%ldx%ldx%ld]), fd: %d, offset: %zu, mmap_addr: %p, phy_addr: 0x%lx\n",
|
||||
(void *) this, (long) _info.ne[0], (long) _info.ne[1], (long) _info.ne[2], (long) _info.ne[3],
|
||||
_info.buffer_fd, _info.offset, (void *) mmap_address, phy_address);
|
||||
}
|
||||
|
|
@ -47,14 +49,14 @@ class tensor {
|
|||
void invalidate() const {
|
||||
if (_data) {
|
||||
qurt_mem_cache_clean((qurt_addr_t) (_data + _info.offset), (qurt_size_t) _info.size,
|
||||
QURT_MEM_CACHE_INVALIDATE, QURT_MEM_DCACHE);
|
||||
QURT_MEM_CACHE_FLUSH_INVALIDATE_ALL, QURT_MEM_DCACHE);
|
||||
}
|
||||
}
|
||||
|
||||
void update_config(const npu_device_tensor_update_config & config) {
|
||||
static_assert(sizeof(_op_params) == sizeof(config.params), "op params size mismatch");
|
||||
|
||||
_info.op = config.op;
|
||||
_op_type = config.op;
|
||||
memcpy(_op_params, config.params, sizeof(_op_params));
|
||||
for (size_t i = 0; i < DEVICE_TENSOR_MAX_SRC; ++i) {
|
||||
auto src_handle = config.src_handles[i];
|
||||
|
|
@ -76,7 +78,12 @@ class tensor {
|
|||
|
||||
const size_t get_nb(size_t index) const { return _info.nb[index]; }
|
||||
|
||||
npu_device_tensor_op get_op() const { return _info.op; }
|
||||
const bool is_permuted() const {
|
||||
// Check if the tensor is permuted by comparing the nb values
|
||||
return is_transposed_or_permuted(_info.nb);
|
||||
}
|
||||
|
||||
npu_device_tensor_op get_op() const { return _op_type; }
|
||||
|
||||
template <typename _TyParam> const _TyParam get_op_param(size_t index) const {
|
||||
static_assert(sizeof(_TyParam) <= sizeof(_op_params), "_op_param type size exceeds op params size");
|
||||
|
|
@ -95,19 +102,34 @@ class tensor {
|
|||
npu_device_tensor_data_type get_type() const { return _info.type; }
|
||||
|
||||
const uint8_t * get_read_buffer() const {
|
||||
invalidate();
|
||||
if (!_info.is_constant && _has_modified) {
|
||||
invalidate();
|
||||
const_cast<tensor *>(this)->_has_modified = false; // TODO: avoid const_cast
|
||||
}
|
||||
|
||||
return _data + _info.offset;
|
||||
}
|
||||
|
||||
uint8_t * get_write_buffer() const { return _data + _info.offset; }
|
||||
uint8_t * get_write_buffer() const {
|
||||
if (_info.is_constant) {
|
||||
DEVICE_LOG_ERROR("Attempt to write to a constant tensor: %p", (void *) this);
|
||||
return nullptr; // Do not allow writing to constant tensors
|
||||
}
|
||||
|
||||
return _data + _info.offset;
|
||||
}
|
||||
|
||||
void release_write_buffer() { _has_modified = true; }
|
||||
|
||||
bool is_valid() const { return _data != nullptr; }
|
||||
|
||||
private:
|
||||
npu_device_tensor_config _info = {};
|
||||
npu_device_tensor_op _op_type = NPU_OP_COUNT;
|
||||
int32_t _op_params[kMaxParamsCount] = {};
|
||||
tensor * _src[kMaxTensorSrc] = {};
|
||||
uint8_t * _data = nullptr;
|
||||
std::atomic_bool _has_modified = false;
|
||||
|
||||
DISABLE_COPY_AND_MOVE(tensor);
|
||||
};
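// Illustrative sketch (hypothetical call site) of the buffer discipline the accessors
// above imply, assuming a valid, non-constant hexagon::tensor * dst:
//
//   uint8_t * out = dst->get_write_buffer();   // returns nullptr for constant tensors
//   if (!out) {
//       return;                                // never write through a constant tensor
//   }
//   // ... fill out ...
//   dst->release_write_buffer();               // marks the tensor modified so the next
//                                              // get_read_buffer() re-invalidates the cache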
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@
|
|||
#include <string>
|
||||
|
||||
#include "util.hpp"
|
||||
#include "vtcm_mem.hpp"
|
||||
|
||||
namespace hexagon {
|
||||
|
||||
|
|
@ -78,28 +79,65 @@ template <size_t _stack_size> class qurt_thread {
|
|||
|
||||
using qurt_thread_ptr = std::unique_ptr<qurt_thread<kDefaultStackSize>>;
|
||||
|
||||
template <size_t _thread_count> class thread_pool {
|
||||
static_assert(_thread_count > 1, "Thread count must be greater than 1");
|
||||
constexpr const static size_t kMaxSubThreadCount = _thread_count - 1;
|
||||
template <size_t _ThreadCount> class thread_pool {
|
||||
static_assert(_ThreadCount > 1, "Thread count must be greater than 1");
|
||||
constexpr const static size_t kMaxThreadCount = _ThreadCount;
|
||||
constexpr const static size_t kMaxSubThreadCount = _ThreadCount - 1;
|
||||
|
||||
public:
|
||||
typedef qurt_thread<kDefaultStackSize> thread_type;
|
||||
typedef void (*task_type)(thread_pool * pool, size_t thread_idx, size_t thread_count, void * arg);
|
||||
|
||||
struct thread_params {
|
||||
size_t tidx;
|
||||
const size_t tcnt = kMaxThreadCount;
|
||||
thread_pool<kMaxThreadCount> * pool = nullptr;
|
||||
size_t vtcm_quota_size;
|
||||
|
||||
std::unique_ptr<vtcm_mem> vtcm_cache;
|
||||
std::unique_ptr<uint8_t[]> mem_cache;
|
||||
size_t mem_cache_size = 0;
|
||||
|
||||
uint8_t * get_vtcm_cache(size_t size) {
|
||||
if (!vtcm_cache || vtcm_cache->get_size() < size) {
|
||||
DEVICE_SCOPED_PERFORMANCE_TRACKER("[thread_params]get_vtcm_cache, size: %zu, tidx: %zu", size, tidx);
|
||||
vtcm_cache.reset(); // reset the cache to create a new one
|
||||
vtcm_cache = std::make_unique<vtcm_mem>(size, false);
|
||||
}
|
||||
|
||||
if (!vtcm_cache->is_valid()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return vtcm_cache->get_mem();
|
||||
}
|
||||
|
||||
uint8_t * get_mem_cache(size_t size) {
|
||||
if (!mem_cache || mem_cache_size < size) {
|
||||
mem_cache.reset(); // reset the cache to create a new one
|
||||
mem_cache = std::make_unique<uint8_t[]>(size + 256);
|
||||
mem_cache_size = mem_cache ? size : 0;
|
||||
}
|
||||
|
||||
return mem_cache.get();
|
||||
}
|
||||
};
|
||||
|
||||
typedef void (*task_type)(thread_pool * pool, thread_params * param, void * arg);
|
||||
|
||||
thread_pool() {
|
||||
std::string thread_name_base = "thread_pool_";
|
||||
for (size_t i = 0; i < kMaxThreadCount; ++i) {
|
||||
_thread_params[i].tidx = i;
|
||||
_thread_params[i].vtcm_quota_size = hexagon::vtcm_mem::get_avail_block_size() / kMaxThreadCount;
|
||||
_thread_params[i].pool = this;
|
||||
}
|
||||
|
||||
qurt_barrier_init(&_pending, kMaxSubThreadCount + 1);
|
||||
qurt_barrier_init(&_completed, kMaxSubThreadCount + 1);
|
||||
const auto priority = qurt_thread_get_priority(qurt_thread_get_id());
|
||||
const auto priority = qurt_thread_get_priority(qurt_thread_get_id());
|
||||
std::string thread_name_base = "thread_pool_";
|
||||
for (size_t i = 0; i < kMaxSubThreadCount; ++i) {
|
||||
auto & thread_arg = _thread_args[i];
|
||||
thread_arg.pool = this;
|
||||
thread_arg.thread_idx = i + 1;
|
||||
|
||||
auto thread = std::make_unique<thread_type>(
|
||||
thread_name_base + std::to_string(i),
|
||||
reinterpret_cast<thread_type::qurt_thread_func_type>(&thread_pool::thread_func_impl), &thread_arg,
|
||||
priority);
|
||||
thread_name_base + std::to_string(i), &thread_pool::thread_func_impl, &_thread_params[i + 1], priority);
|
||||
if (!thread->is_valid()) {
|
||||
DEVICE_LOG_ERROR("Failed to create thread: %zu", i);
|
||||
// destroy all barriers and threads in the destructor
|
||||
|
|
@ -108,6 +146,7 @@ template <size_t _thread_count> class thread_pool {
|
|||
|
||||
_threads[i] = std::move(thread);
|
||||
}
|
||||
|
||||
DEVICE_LOG_DEBUG("thread_pool.created: %zu", kMaxSubThreadCount);
|
||||
}
|
||||
|
||||
|
|
@ -130,60 +169,85 @@ template <size_t _thread_count> class thread_pool {
|
|||
return false;
|
||||
}
|
||||
|
||||
#ifdef GGML_HEXAGON_ENABLE_PERFORMANCE_TRACKING
|
||||
_task_begin_cycles = HAP_perf_get_qtimer_count();
|
||||
#endif
|
||||
|
||||
_task = task;
|
||||
_arg = arg;
|
||||
qurt_barrier_wait(&_pending);
|
||||
|
||||
task(this, 0, kMaxSubThreadCount + 1, arg);
|
||||
task(this, &_thread_params[0], arg);
|
||||
DEVICE_LOG_DEBUG("main_thread.task_completed: 0");
|
||||
|
||||
qurt_barrier_wait(&_completed);
|
||||
|
||||
_task = nullptr;
|
||||
_arg = nullptr;
|
||||
|
||||
#ifdef GGML_HEXAGON_ENABLE_PERFORMANCE_TRACKING
|
||||
_task_begin_cycles = 0;
|
||||
#endif
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void sync_thread() { qurt_barrier_wait(&_completed); }
|
||||
|
||||
private:
|
||||
struct thread_pool_arg {
|
||||
thread_pool * pool = nullptr;
|
||||
size_t thread_idx = 0;
|
||||
};
|
||||
static size_t get_per_thread_vtcm_quota() { return vtcm_mem::get_total_size() / kMaxThreadCount; }
|
||||
|
||||
static void thread_func_impl(thread_type * thread, thread_pool_arg * arg) {
|
||||
private:
|
||||
static void thread_func_impl(thread_type * thread, void * arg) {
|
||||
NPU_UNUSED(thread);
|
||||
|
||||
DEVICE_LOG_DEBUG("thread_func_impl.start: %zu", arg->thread_idx);
|
||||
auto * param = reinterpret_cast<thread_params *>(arg);
|
||||
|
||||
auto & pool = *arg->pool;
|
||||
DEVICE_LOG_DEBUG("thread_func_impl.start: %zu", param->tidx);
|
||||
|
||||
auto & pool = *(param->pool);
|
||||
for (;;) {
|
||||
qurt_barrier_wait(&pool._pending);
|
||||
if (pool._thread_exit) {
|
||||
DEVICE_LOG_DEBUG("thread_func_impl.exit: %zu", arg->thread_idx);
|
||||
DEVICE_LOG_DEBUG("thread_func_impl.exit: %zu", param->tidx);
|
||||
break;
|
||||
}
|
||||
|
||||
#ifdef GGML_HEXAGON_ENABLE_PERFORMANCE_TRACKING
|
||||
auto task_begin_cycles = pool._task_begin_cycles.load();
|
||||
DEVICE_LOG_WARN("[profiler]worker_thread, tidx: %zu, prepare: %lluus", param->tidx,
|
||||
static_cast<unsigned long long>(
|
||||
HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - task_begin_cycles)));
|
||||
#endif
|
||||
|
||||
auto task = pool._task;
|
||||
if (task) {
|
||||
task(arg->pool, arg->thread_idx, kMaxSubThreadCount + 1, pool._arg);
|
||||
task(param->pool, param, pool._arg);
|
||||
}
|
||||
|
||||
DEVICE_LOG_DEBUG("thread_func_impl.task_completed: %zu", arg->thread_idx);
|
||||
DEVICE_LOG_DEBUG("thread_func_impl.task_completed: %zu", param->tidx);
|
||||
qurt_barrier_wait(&pool._completed);
|
||||
|
||||
#ifdef GGML_HEXAGON_ENABLE_PERFORMANCE_TRACKING
|
||||
DEVICE_LOG_WARN("[profiler]worker_thread, tidx: %zu, task_end: %lluus", param->tidx,
|
||||
static_cast<unsigned long long>(
|
||||
HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - task_begin_cycles)));
|
||||
#endif
|
||||
}
|
||||
|
||||
DEVICE_LOG_DEBUG("thread_func_impl.end: %zu", arg->thread_idx);
|
||||
DEVICE_LOG_DEBUG("thread_func_impl.end: %zu", param->tidx);
|
||||
}
|
||||
|
||||
std::atomic_bool _thread_exit = false;
|
||||
std::array<qurt_thread_ptr, kMaxSubThreadCount> _threads;
|
||||
thread_pool_arg _thread_args[kMaxSubThreadCount] = {};
|
||||
qurt_barrier_t _pending = {};
|
||||
qurt_barrier_t _completed = {};
|
||||
task_type _task = nullptr;
|
||||
void * _arg = nullptr;
|
||||
qurt_barrier_t _pending = {};
|
||||
qurt_barrier_t _completed = {};
|
||||
thread_params _thread_params[kMaxThreadCount] = {};
|
||||
task_type _task = nullptr;
|
||||
void * _arg = nullptr;
|
||||
|
||||
#ifdef GGML_HEXAGON_ENABLE_PERFORMANCE_TRACKING
|
||||
std::atomic<uint64_t> _task_begin_cycles = 0;
|
||||
#endif
|
||||
|
||||
DISABLE_COPY_AND_MOVE(thread_pool);
|
||||
};
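// Usage sketch (illustrative; task and argument names are hypothetical) of the reworked
// task signature: the pool now hands each task its thread_params instead of a bare
// (thread_idx, thread_count) pair.
//
//   void example_task(hexagon::default_thread_pool * pool,
//                     hexagon::default_thread_pool::thread_params * param, void * arg) {
//       uint8_t * vtcm = param->get_vtcm_cache(16 * 1024);  // per-thread VTCM scratch
//       // ... compute using param->tidx / param->tcnt ...
//   }
//
// The main thread passes &example_task plus a user argument to the pool's execute entry
// point shown above, which runs slice 0 itself and then waits on the completion barrier.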
|
||||
|
|
|
|||
|
|
@ -0,0 +1,467 @@
|
|||
#include "type_traits.hpp"
|
||||
|
||||
#include <hexagon_types.h>
|
||||
|
||||
#include <array>
|
||||
|
||||
#include "op_types.hpp" // TODO: remove this include
|
||||
#include "vec_ops.hpp"
|
||||
|
||||
static_assert(sizeof(npu_device_block_q4_k) ==
|
||||
2 * sizeof(npu_device_fp16_t) + QUANT_K_SCALE_SIZE + QUANT_K_BLOCK_SIZE / 2,
|
||||
"wrong q4_K block size/padding");
|
||||
|
||||
static_assert(sizeof(npu_device_block_q4_0) == sizeof(npu_device_fp16_t) + QUANT_BLOCK_SIZE / 2,
|
||||
"wrong q4_0 block size/padding");
|
||||
|
||||
static_assert(sizeof(npu_device_block_q8_0) == sizeof(npu_device_fp16_t) + QUANT_BLOCK_SIZE,
|
||||
"wrong q8_0 block size/padding");
|
||||
|
||||
namespace {
|
||||
|
||||
inline float to_float(const npu_device_fp16_t src) {
|
||||
return reinterpret_cast<const __fp16 &>(src);
|
||||
}
|
||||
|
||||
inline npu_device_fp16_t to_fp16(const float src) {
|
||||
__fp16 f16_value = static_cast<__fp16>(src);
|
||||
return reinterpret_cast<const npu_device_fp16_t &>(f16_value);
|
||||
}
|
||||
|
||||
template <typename _TBlock> inline HVX_Vector load_block_generic(const _TBlock & src) {
|
||||
uint8_t buffer[hexagon::kBytesPerVector];
|
||||
|
||||
static_assert(sizeof(buffer) == sizeof(HVX_Vector), "wrong cvt size/padding");
|
||||
static_assert(sizeof(buffer) >= sizeof(src.qs), "wrong q4_0 block size/padding");
|
||||
|
||||
memcpy(&buffer[0], src.qs, sizeof(src.qs));
|
||||
return *reinterpret_cast<HVX_UVector *>(buffer);
|
||||
}
|
||||
|
||||
template <typename _TBlock> inline HVX_Vector load_dual_block_generic(const _TBlock & src1, const _TBlock & src2) {
|
||||
uint8_t buffer[hexagon::kBytesPerVector];
|
||||
|
||||
static_assert(sizeof(buffer) == sizeof(HVX_Vector), "wrong cvt size/padding");
|
||||
static_assert(sizeof(buffer) >= sizeof(src1.qs) * 2, "wrong q4_0 block size/padding");
|
||||
|
||||
memcpy(&buffer[0], src1.qs, sizeof(src1.qs));
|
||||
memcpy(&buffer[sizeof(src1.qs)], src2.qs, sizeof(src2.qs));
|
||||
return *reinterpret_cast<HVX_UVector *>(buffer);
|
||||
}
|
||||
|
||||
inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t * d, uint8_t * m) {
|
||||
if (j < 4) {
|
||||
*d = q[j] & 63;
|
||||
*m = q[j + 4] & 63;
|
||||
} else {
|
||||
*d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4);
|
||||
*m = (q[j + 4] >> 4) | ((q[j - 0] >> 6) << 4);
|
||||
}
|
||||
}
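// Worked example of the 6-bit scale/min packing decoded above (the same layout is
// produced by quantize_row_q4_K further down): for j < 4 the low 6 bits of scales[j]
// hold the scale and scales[j + 4] hold the min; for j >= 4 the low nibbles live in
// scales[j + 4] while the two high bits are parked in the top bits of scales[j - 4]
// (scale) and scales[j] (min). E.g. for j = 5 with q[9] = 0xA7:
//   d = 0x7 | ((q[1] >> 6) << 4),  m = 0xA | ((q[5] >> 6) << 4).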
|
||||
|
||||
inline int nearest_int(float fval) {
|
||||
float val = fval + 12582912.f;
|
||||
int i = reinterpret_cast<const int &>(val);
|
||||
return (i & 0x007fffff) - 0x00400000;
|
||||
}
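// Worked example of the rounding trick above: 12582912.0f is 1.5 * 2^23, so for inputs
// well inside +/-2^22 the addition pins the exponent and the float's mantissa becomes
// round-to-nearest-even(fval) + 0x00400000; masking the low 23 bits and subtracting
// 0x00400000 recovers the rounded integer.
//   fval =  3.7f -> 12582915.7 -> rounds to 12582916.0 -> (mantissa & 0x7fffff) - 0x400000 =  4
//   fval = -2.5f -> 12582909.5 -> rounds to 12582910.0 (ties-to-even)                      -> -2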
|
||||
|
||||
float make_qkx2_quants(int n, int nmax, const float * x, const float * weights, uint8_t * L, float * the_min,
|
||||
uint8_t * Laux, float rmin, float rdelta, int nstep, bool use_mad) {
|
||||
float min = x[0];
|
||||
float max = x[0];
|
||||
float sum_w = weights[0];
|
||||
float sum_x = sum_w * x[0];
|
||||
for (int i = 1; i < n; ++i) {
|
||||
if (x[i] < min) {
|
||||
min = x[i];
|
||||
}
|
||||
if (x[i] > max) {
|
||||
max = x[i];
|
||||
}
|
||||
float w = weights[i];
|
||||
sum_w += w;
|
||||
sum_x += w * x[i];
|
||||
}
|
||||
if (min > 0) {
|
||||
min = 0;
|
||||
}
|
||||
if (max == min) {
|
||||
for (int i = 0; i < n; ++i) {
|
||||
L[i] = 0;
|
||||
}
|
||||
*the_min = -min;
|
||||
return 0.f;
|
||||
}
|
||||
float iscale = nmax / (max - min);
|
||||
float scale = 1 / iscale;
|
||||
float best_mad = 0;
|
||||
for (int i = 0; i < n; ++i) {
|
||||
int l = nearest_int(iscale * (x[i] - min));
|
||||
L[i] = std::max<int>(0, std::min(nmax, l));
|
||||
float diff = scale * L[i] + min - x[i];
|
||||
diff = use_mad ? fabsf(diff) : diff * diff;
|
||||
float w = weights[i];
|
||||
best_mad += w * diff;
|
||||
}
|
||||
if (nstep < 1) {
|
||||
*the_min = -min;
|
||||
return scale;
|
||||
}
|
||||
for (int is = 0; is <= nstep; ++is) {
|
||||
iscale = (rmin + rdelta * is + nmax) / (max - min);
|
||||
float sum_l = 0, sum_l2 = 0, sum_xl = 0;
|
||||
for (int i = 0; i < n; ++i) {
|
||||
int l = nearest_int(iscale * (x[i] - min));
|
||||
l = std::max<int>(0, std::min(nmax, l));
|
||||
Laux[i] = l;
|
||||
float w = weights[i];
|
||||
sum_l += w * l;
|
||||
sum_l2 += w * l * l;
|
||||
sum_xl += w * l * x[i];
|
||||
}
|
||||
float D = sum_w * sum_l2 - sum_l * sum_l;
|
||||
if (D > 0) {
|
||||
float this_scale = (sum_w * sum_xl - sum_x * sum_l) / D;
|
||||
float this_min = (sum_l2 * sum_x - sum_l * sum_xl) / D;
|
||||
if (this_min > 0) {
|
||||
this_min = 0;
|
||||
this_scale = sum_xl / sum_l2;
|
||||
}
|
||||
float mad = 0;
|
||||
for (int i = 0; i < n; ++i) {
|
||||
float diff = this_scale * Laux[i] + this_min - x[i];
|
||||
diff = use_mad ? fabsf(diff) : diff * diff;
|
||||
float w = weights[i];
|
||||
mad += w * diff;
|
||||
}
|
||||
if (mad < best_mad) {
|
||||
for (int i = 0; i < n; ++i) {
|
||||
L[i] = Laux[i];
|
||||
}
|
||||
best_mad = mad;
|
||||
scale = this_scale;
|
||||
min = this_min;
|
||||
}
|
||||
}
|
||||
}
|
||||
*the_min = -min;
|
||||
return scale;
|
||||
}
|
||||
|
||||
void quantize_row_fp16(const float * src, void * dst, size_t count, const float * f16_to_f32_table) {
|
||||
auto * out = reinterpret_cast<npu_device_fp16_t *>(dst);
|
||||
// TODO: use hvx intrinsics for better performance
|
||||
for (size_t i = 0; i < count; i++) {
|
||||
out[i] = to_fp16(src[i]);
|
||||
}
|
||||
}
|
||||
|
||||
void quantize_row_q8_0(const float * src, void * dst, size_t count, const float * f16_to_f32_table) {
|
||||
const int nb = count / QUANT_BLOCK_SIZE;
|
||||
auto * out = reinterpret_cast<npu_device_block_q8_0 *>(dst);
|
||||
|
||||
for (int i = 0; i < nb; i++) {
|
||||
float amax = 0.0f; // absolute max
|
||||
|
||||
for (int j = 0; j < QUANT_BLOCK_SIZE; j++) {
|
||||
const float v = src[i * QUANT_BLOCK_SIZE + j];
|
||||
amax = std::max(amax, fabsf(v));
|
||||
}
|
||||
|
||||
const float d = amax / ((1 << 7) - 1);
|
||||
const float id = d ? 1.0f / d : 0.0f;
|
||||
|
||||
out[i].d = to_fp16(d);
|
||||
|
||||
for (int j = 0; j < QUANT_BLOCK_SIZE; ++j) {
|
||||
const float x0 = src[i * QUANT_BLOCK_SIZE + j] * id;
|
||||
|
||||
out[i].qs[j] = roundf(x0);
|
||||
}
|
||||
}
|
||||
}
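// Round-trip sketch (illustrative) for one 32-element block: quantize_row_q8_0 stores
// d = amax / 127 as fp16 plus 32 int8 levels qs[j] = round(x[j] / d), and
// dequantize_row_q8_0 further down reconstructs x'[j] = d * qs[j], so the per-element
// error is bounded by roughly d / 2 plus the fp16 rounding of d.
//   e.g. amax = 6.35 -> d ~= 0.05;  x = 1.23 -> qs = 25 -> x' ~= 1.25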
|
||||
|
||||
void quantize_row_q4_0(const float * src, void * dst, size_t count, const float * f16_to_f32_table) {
|
||||
constexpr const int qk = QUANT_BLOCK_SIZE;
|
||||
|
||||
const int nb = count / qk;
|
||||
auto * out = reinterpret_cast<npu_device_block_q4_0 *>(dst);
|
||||
|
||||
for (int i = 0; i < nb; i++) {
|
||||
float amax = 0.0f; // absolute max
|
||||
float max = 0.0f;
|
||||
|
||||
for (int j = 0; j < qk; j++) {
|
||||
const float v = src[i * qk + j];
|
||||
if (amax < fabsf(v)) {
|
||||
amax = fabsf(v);
|
||||
max = v;
|
||||
}
|
||||
}
|
||||
|
||||
const float d = max / -8;
|
||||
const float id = d ? 1.0f / d : 0.0f;
|
||||
|
||||
out[i].d = to_fp16(d);
|
||||
|
||||
for (int j = 0; j < qk / 2; ++j) {
|
||||
const float x0 = src[i * qk + 0 + j] * id;
|
||||
const float x1 = src[i * qk + qk / 2 + j] * id;
|
||||
|
||||
const uint8_t xi0 = std::min<int8_t>(15, (x0 + 8.5f));
|
||||
const uint8_t xi1 = std::min<int8_t>(15, (x1 + 8.5f));
|
||||
|
||||
out[i].qs[j] = xi0;
|
||||
out[i].qs[j] |= xi1 << 4;
|
||||
}
|
||||
}
|
||||
}
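// Worked example for the q4_0 scaling above: the signed extreme of the block maps to
// code 0 with d = max / -8, and the decoder below computes x' = d * (q - 8).
//   max = -4.0 (amax = 4.0) -> d = 0.5, id = 2.0
//   x = 1.2  -> x * id + 8.5 = 10.9 -> q = 10 -> x' = 0.5 * (10 - 8) =  1.0
//   x = -4.0 -> x * id + 8.5 =  0.5 -> q =  0 -> x' = 0.5 * (0 - 8)  = -4.0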
|
||||
|
||||
void quantize_row_q4_K(const float * src, void * dst, size_t count, const float * f16_to_f32_table) {
|
||||
const int nb = count / QUANT_K_BLOCK_SIZE;
|
||||
auto * out = reinterpret_cast<npu_device_block_q4_k *>(dst);
|
||||
|
||||
uint8_t L[QUANT_K_BLOCK_SIZE];
|
||||
uint8_t Laux[32];
|
||||
float weights[32];
|
||||
float mins[QUANT_K_BLOCK_SIZE / 32];
|
||||
float scales[QUANT_K_BLOCK_SIZE / 32];
|
||||
|
||||
for (int i = 0; i < nb; i++) {
|
||||
float max_scale = 0; // as we are deducting the min, scales are always positive
|
||||
float max_min = 0;
|
||||
for (int j = 0; j < QUANT_K_BLOCK_SIZE / 32; ++j) {
|
||||
//scales[j] = make_qkx1_quants(32, 15, x + 32*j, L + 32*j, &mins[j], 9, 0.5f);
|
||||
float sum_x2 = 0;
|
||||
for (int l = 0; l < 32; ++l) {
|
||||
sum_x2 += src[32 * j + l] * src[32 * j + l];
|
||||
}
|
||||
float av_x = sqrtf(sum_x2 / 32);
|
||||
for (int l = 0; l < 32; ++l) {
|
||||
weights[l] = av_x + fabsf(src[32 * j + l]);
|
||||
}
|
||||
scales[j] =
|
||||
make_qkx2_quants(32, 15, src + 32 * j, weights, L + 32 * j, &mins[j], Laux, -1.f, 0.1f, 20, false);
|
||||
float scale = scales[j];
|
||||
if (scale > max_scale) {
|
||||
max_scale = scale;
|
||||
}
|
||||
float min = mins[j];
|
||||
if (min > max_min) {
|
||||
max_min = min;
|
||||
}
|
||||
}
|
||||
|
||||
float inv_scale = max_scale > 0 ? 63.f / max_scale : 0.f;
|
||||
float inv_min = max_min > 0 ? 63.f / max_min : 0.f;
|
||||
for (int j = 0; j < QUANT_K_BLOCK_SIZE / 32; ++j) {
|
||||
uint8_t ls = nearest_int(inv_scale * scales[j]);
|
||||
uint8_t lm = nearest_int(inv_min * mins[j]);
|
||||
ls = std::min<uint8_t>(63, ls);
|
||||
lm = std::min<uint8_t>(63, lm);
|
||||
if (j < 4) {
|
||||
out[i].scales[j] = ls;
|
||||
out[i].scales[j + 4] = lm;
|
||||
} else {
|
||||
out[i].scales[j + 4] = (ls & 0xF) | ((lm & 0xF) << 4);
|
||||
out[i].scales[j - 4] |= ((ls >> 4) << 6);
|
||||
out[i].scales[j - 0] |= ((lm >> 4) << 6);
|
||||
}
|
||||
}
|
||||
out[i].d = to_fp16(max_scale / 63.f);
|
||||
out[i].dmin = to_fp16(max_min / 63.f);
|
||||
|
||||
uint8_t sc, m;
|
||||
for (int j = 0; j < QUANT_K_BLOCK_SIZE / 32; ++j) {
|
||||
get_scale_min_k4(j, out[i].scales, &sc, &m);
|
||||
const float d = f16_to_f32_table[out[i].d] * sc;
|
||||
if (!d) {
|
||||
continue;
|
||||
}
|
||||
const float dm = f16_to_f32_table[out[i].dmin] * m;
|
||||
for (int ii = 0; ii < 32; ++ii) {
|
||||
int l = nearest_int((src[32 * j + ii] + dm) / d);
|
||||
l = std::max<int>(0, std::min<int>(15, l));
|
||||
L[32 * j + ii] = l;
|
||||
}
|
||||
}
|
||||
|
||||
uint8_t * q = out[i].qs;
|
||||
for (int j = 0; j < QUANT_K_BLOCK_SIZE; j += 64) {
|
||||
for (int l = 0; l < 32; ++l) {
|
||||
q[l] = L[j + l] | (L[j + l + 32] << 4);
|
||||
}
|
||||
q += 32;
|
||||
}
|
||||
|
||||
src += QUANT_K_BLOCK_SIZE;
|
||||
}
|
||||
}
|
||||
|
||||
void dequantize_row_q8_0(const void * src, float * dst, size_t count, const float * f16_to_f32_table) {
|
||||
constexpr const int qk = QUANT_BLOCK_SIZE;
|
||||
static_assert(QUANT_BLOCK_SIZE == hexagon::kBytesPerVector / sizeof(float));
|
||||
|
||||
const int nb = count / qk;
|
||||
const auto * src_ptr = reinterpret_cast<const npu_device_block_q8_0 *>(src);
|
||||
HVX_UVector * out = ((HVX_UVector *) dst); // TODO: opt for aligned access
|
||||
|
||||
for (int i = 0; i < nb; i++) {
|
||||
const auto & src = src_ptr[i];
|
||||
HVX_Vector d = Q6_Vh_vsplat_R(src.d);
|
||||
|
||||
HVX_Vector q_lo = load_block_generic(src);
|
||||
HVX_VectorPair q = Q6_Wh_vunpack_Vb(q_lo);
|
||||
q = Q6_Wh_vunpack_Vb(Q6_V_lo_W(q));
|
||||
q_lo = Q6_Vhf_equals_Vh(Q6_V_lo_W(q));
|
||||
q = Q6_Wqf32_vmpy_VhfVhf(q_lo, d);
|
||||
out[i] = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(q));
|
||||
}
|
||||
}
|
||||
|
||||
void dequantize_row_q4_0(const void * src, float * dst, size_t count, const float * f16_to_f32_table) {
|
||||
constexpr const int qk = QUANT_BLOCK_SIZE;
|
||||
static_assert(qk % 2 == 0, "qk must be even");
|
||||
static_assert(QUANT_BLOCK_SIZE == hexagon::kBytesPerVector / sizeof(float));
|
||||
constexpr const uint32_t kSizeOfQs = sizeof(npu_device_block_q4_0::qs);
|
||||
|
||||
const int nb = count / qk;
|
||||
const auto * src_ptr = reinterpret_cast<const npu_device_block_q4_0 *>(src);
|
||||
HVX_Vector mask = Q6_Vb_vsplat_R(0x0F);
|
||||
HVX_Vector minus = Q6_Vb_vsplat_R(8);
|
||||
HVX_UVector * out = ((HVX_UVector *) dst); // TODO: opt for aligned access
|
||||
|
||||
const int loop_count = nb - (nb % 2);
|
||||
for (int i = 0; i < loop_count; i += 2) {
|
||||
const auto & src1 = src_ptr[i];
|
||||
const auto & src2 = src_ptr[i + 1];
|
||||
|
||||
HVX_Vector d1 = Q6_Vh_vsplat_R(src1.d);
|
||||
HVX_Vector d2 = Q6_Vh_vsplat_R(src2.d);
|
||||
HVX_Vector d = Q6_Vh_vshuff_Vh(Q6_V_valign_VVR(d2, d1, hexagon::kBytesPerVector / 2));
|
||||
|
||||
HVX_Vector q_lo = load_dual_block_generic(src1, src2);
|
||||
HVX_Vector q_hi = Q6_Vub_vlsr_VubR(q_lo, 4);
|
||||
HVX_VectorPair q = Q6_W_vshuff_VVR(q_hi, Q6_V_vand_VV(q_lo, mask), kSizeOfQs);
|
||||
q_lo = Q6_V_valign_VVR(Q6_V_lo_W(q), Q6_V_vzero(), hexagon::kBytesPerVector / 2);
|
||||
q_lo = Q6_V_valign_VVR(Q6_V_hi_W(q), q_lo, hexagon::kBytesPerVector / 2);
|
||||
q_lo = Q6_Vb_vshuff_Vb(q_lo);
|
||||
q_lo = Q6_Vb_vsub_VbVb(q_lo, minus);
|
||||
q = Q6_Wh_vunpack_Vb(q_lo);
|
||||
q_lo = Q6_Vhf_equals_Vh(Q6_V_lo_W(q));
|
||||
q = Q6_Wqf32_vmpy_VhfVhf(q_lo, d);
|
||||
out[i] = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(q));
|
||||
out[i + 1] = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(q));
|
||||
}
|
||||
|
||||
if (loop_count < nb) {
|
||||
const auto & curr_blk = src_ptr[nb - 1];
|
||||
HVX_Vector d = Q6_Vh_vsplat_R(curr_blk.d);
|
||||
|
||||
HVX_Vector q_lo = load_block_generic(curr_blk);
|
||||
HVX_Vector q_hi = Q6_Vub_vlsr_VubR(q_lo, 4);
|
||||
q_lo = Q6_V_valign_VVR(Q6_V_vand_VV(q_lo, mask), Q6_V_vzero(), sizeof(curr_blk.qs));
|
||||
q_lo = Q6_V_valign_VVR(q_hi, q_lo, hexagon::kBytesPerVector - sizeof(curr_blk.qs));
|
||||
q_lo = Q6_Vb_vsub_VbVb(q_lo, minus);
|
||||
|
||||
HVX_VectorPair q = Q6_Wh_vunpack_Vb(q_lo);
|
||||
q = Q6_Wh_vunpack_Vb(Q6_V_lo_W(q));
|
||||
q_lo = Q6_Vhf_equals_Vh(Q6_V_lo_W(q));
|
||||
q = Q6_Wqf32_vmpy_VhfVhf(q_lo, d);
|
||||
out[nb - 1] = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(q));
|
||||
}
|
||||
}
|
||||
|
||||
void dequantize_row_q4_K(const void * src, float * dst, size_t count, const float * f16_to_f32_table) {
|
||||
const int nb = count / QUANT_K_BLOCK_SIZE;
|
||||
const auto * src_ptr = reinterpret_cast<const npu_device_block_q4_k *>(src);
|
||||
|
||||
// TODO: use intrinsics
|
||||
for (int i = 0; i < nb; i++) {
|
||||
const uint8_t * q = src_ptr[i].qs;
|
||||
|
||||
const float d = f16_to_f32_table[src_ptr[i].d];
|
||||
const float min = f16_to_f32_table[src_ptr[i].dmin];
|
||||
|
||||
int is = 0;
|
||||
uint8_t sc = 0;
|
||||
uint8_t m = 0;
|
||||
const auto * scales = src_ptr[i].scales;
|
||||
for (int j = 0; j < QUANT_K_BLOCK_SIZE; j += 64) {
|
||||
get_scale_min_k4(is + 0, scales, &sc, &m);
|
||||
const float d1 = d * sc;
|
||||
const float m1 = min * m;
|
||||
get_scale_min_k4(is + 1, scales, &sc, &m);
|
||||
const float d2 = d * sc;
|
||||
const float m2 = min * m;
|
||||
for (int l = 0; l < 32; ++l) {
|
||||
dst[0] = d1 * (q[l] & 0xF) - m1;
|
||||
dst[32] = d2 * ((q[l] >> 4) & 0xF) - m2;
|
||||
dst++;
|
||||
}
|
||||
dst += 32;
|
||||
q += 32;
|
||||
is += 2;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename _TFunc> struct dot_func_traits {};
|
||||
|
||||
template <typename _TData> struct dot_func_traits<float (*)(_TData, _TData, size_t)> {
|
||||
using param_type = std::remove_const_t<std::remove_pointer_t<_TData>>;
|
||||
};
|
||||
|
||||
template <auto _Func> float wrap_dot_func(const void * src0, const void * src1, size_t count) {
|
||||
using param_type = typename dot_func_traits<decltype(_Func)>::param_type;
|
||||
return _Func(reinterpret_cast<const param_type *>(src0), reinterpret_cast<const param_type *>(src1), count);
|
||||
}
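// Illustrative sketch: wrap_dot_func adapts a typed dot-product kernel to the
// type-erased vec_dot_type slot of device_type_traits; the F32 entry below behaves
// like this hypothetical adapter:
//
//   float f32_dot(const void * a, const void * b, size_t n) {
//       return hexagon::vec_dot_product_f32_f32(static_cast<const float *>(a),
//                                               static_cast<const float *>(b), n);
//   }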
|
||||
|
||||
constexpr const hexagon::device_type_traits kDeviceTypeTraits[] = {
|
||||
{ NPU_DATA_TYPE_F32, "F32", 1, sizeof(float), false, nullptr, nullptr,
|
||||
wrap_dot_func<hexagon::vec_dot_product_f32_f32> },
|
||||
{ NPU_DATA_TYPE_F16, "F16", 1, sizeof(npu_device_fp16_t), false, nullptr, quantize_row_fp16,
|
||||
wrap_dot_func<hexagon::vec_dot_product_f16_f16> },
|
||||
{ NPU_DATA_TYPE_Q8_0, "Q8_0", QUANT_BLOCK_SIZE, sizeof(npu_device_block_q8_0), true, dequantize_row_q8_0,
|
||||
quantize_row_q8_0 },
|
||||
{ NPU_DATA_TYPE_Q4_0, "Q4_0", QUANT_BLOCK_SIZE, sizeof(npu_device_block_q4_0), true, dequantize_row_q4_0,
|
||||
quantize_row_q4_0 },
|
||||
{ NPU_DATA_TYPE_Q4_K, "Q4_K", QUANT_K_BLOCK_SIZE, sizeof(npu_device_block_q4_k), true, dequantize_row_q4_K,
|
||||
quantize_row_q4_K },
|
||||
};
|
||||
|
||||
static_assert(std::size(kDeviceTypeTraits) == NPU_DATA_TYPE_COUNT,
|
||||
"kDeviceTypeTraits size mismatch with npu_device_tensor_data_type enum");
|
||||
static_assert(kDeviceTypeTraits[NPU_DATA_TYPE_F32].type == NPU_DATA_TYPE_F32,
|
||||
"kDeviceTypeTraits F32 type mismatch with npu_device_tensor_data_type enum");
|
||||
static_assert(kDeviceTypeTraits[NPU_DATA_TYPE_F16].type == NPU_DATA_TYPE_F16,
|
||||
"kDeviceTypeTraits F16 type mismatch with npu_device_tensor_data_type enum");
|
||||
static_assert(kDeviceTypeTraits[NPU_DATA_TYPE_Q8_0].type == NPU_DATA_TYPE_Q8_0,
|
||||
"kDeviceTypeTraits Q8_0 type mismatch with npu_device_tensor_data_type enum");
|
||||
static_assert(kDeviceTypeTraits[NPU_DATA_TYPE_Q4_0].type == NPU_DATA_TYPE_Q4_0,
|
||||
"kDeviceTypeTraits Q4_0 type mismatch with npu_device_tensor_data_type enum");
|
||||
static_assert(kDeviceTypeTraits[NPU_DATA_TYPE_Q4_K].type == NPU_DATA_TYPE_Q4_K,
|
||||
"kDeviceTypeTraits Q4_K type mismatch with npu_device_tensor_data_type enum");
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace hexagon {
|
||||
|
||||
bool init_f16_f32_table(float * table, size_t count) {
|
||||
constexpr const size_t kTableSize = (1U << 16);
|
||||
if (count < kTableSize) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < count; ++i) {
|
||||
table[i] = to_float(i);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
const device_type_traits & get_type_traits(npu_device_tensor_data_type type) {
|
||||
return kDeviceTypeTraits[type];
|
||||
}
|
||||
|
||||
} // namespace hexagon
|
||||
|
|
@ -7,14 +7,20 @@ namespace hexagon {
|
|||
|
||||
bool init_f16_f32_table(float * table, size_t count);
|
||||
|
||||
typedef void (*quantize_row_type)(const float * src, void * dst, size_t count, const float * f16_to_f32_table);
|
||||
typedef void (*dequantize_row_type)(const void * src, float * dst, size_t count, const float * f16_to_f32_table);
|
||||
typedef float (*vec_dot_type)(const void * src0, const void * src1, size_t count);
|
||||
|
||||
struct device_type_traits {
|
||||
npu_device_tensor_data_type type;
|
||||
const char * type_name;
|
||||
int64_t blck_size;
|
||||
size_t type_size;
|
||||
bool is_quantized;
|
||||
dequantize_row_type dequantize_row;
|
||||
|
||||
dequantize_row_type to_float;
|
||||
quantize_row_type from_float;
|
||||
vec_dot_type vec_dot;
|
||||
};
|
||||
|
||||
const device_type_traits & get_type_traits(npu_device_tensor_data_type type);
|
||||
|
|
@ -44,10 +50,10 @@ inline const char * get_type_name(npu_device_tensor_data_type type) {
|
|||
#ifdef GGML_HEXAGON_ENABLE_PERFORMANCE_TRACKING
|
||||
namespace hexagon {
|
||||
|
||||
inline auto make_scoped_op_perf_timer(tensor * op, size_t tidx, const char * sub_proc_log_prefix = nullptr) {
|
||||
inline auto make_scoped_op_perf_timer(tensor * op, size_t tidx) {
|
||||
auto * src0 = op->get_src(0);
|
||||
auto * src1 = op->get_src(1);
|
||||
char buffer[512];
|
||||
char buffer[1024];
|
||||
if (src1 == nullptr) {
|
||||
snprintf(buffer, sizeof(buffer), "[%s][%lldx%lldx%lldx%lld%s], tidx: %zu", op_get_name(op->get_op()),
|
||||
src0->get_ne(0), src0->get_ne(1), src0->get_ne(2), src0->get_ne(3), get_type_name(src0->get_type()),
|
||||
|
|
@ -58,7 +64,7 @@ inline auto make_scoped_op_perf_timer(tensor * op, size_t tidx, const char * sub
|
|||
get_type_name(src0->get_type()), src1->get_ne(0), src1->get_ne(1), src1->get_ne(2), src1->get_ne(3),
|
||||
get_type_name(src1->get_type()), tidx);
|
||||
}
|
||||
return npu_scoped_timer<512>(buffer, sub_proc_log_prefix);
|
||||
return npu_scoped_timer<1024>(buffer);
|
||||
}
|
||||
|
||||
} // namespace hexagon
|
||||
|
|
@ -67,14 +73,23 @@ inline auto make_scoped_op_perf_timer(tensor * op, size_t tidx, const char * sub
|
|||
auto __npu_op_timer_##__LINE__ = hexagon::make_scoped_op_perf_timer(op, tidx)
|
||||
|
||||
# define DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_WITH_SUB_PROC(op, tidx, sub_prefix) \
|
||||
auto __npu_op_timer_##sub_prefix = hexagon::make_scoped_op_perf_timer(op, tidx, #sub_prefix)
|
||||
auto __npu_op_timer_##sub_prefix = hexagon::make_scoped_op_perf_timer(op, tidx)
|
||||
|
||||
# define DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_SUB_PROC(sub_prefix) \
|
||||
hexagon::npu_sub_process_scoped_timer<decltype(__npu_op_timer_##sub_prefix)::kBufferCount> \
|
||||
__npu_op_sub_timer##sub_prefix(__npu_op_timer_##sub_prefix)
|
||||
# define DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_SUB_PROC(sub_prefix) \
|
||||
hexagon::npu_sub_process_scoped_timer<decltype(__npu_op_timer_##sub_prefix)::kBufferCount, 0> \
|
||||
__npu_op_sub_timer##sub_prefix(__npu_op_timer_##sub_prefix, #sub_prefix)
|
||||
|
||||
# define DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_WITH_MULTI_SUB_PROC(op, tidx, tracker_name) \
|
||||
auto __npu_op_timer_##tracker_name = hexagon::make_scoped_op_perf_timer(op, tidx)
|
||||
|
||||
# define DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(tracker_name, idx, sub_prefix) \
|
||||
hexagon::npu_sub_process_scoped_timer<decltype(__npu_op_timer_##tracker_name)::kBufferCount, idx> \
|
||||
__npu_op_sub_timer##sub_prefix(__npu_op_timer_##tracker_name, #sub_prefix)
|
||||
|
||||
#else
|
||||
# define DEVICE_SCOPED_OP_PERFORMANCE_TRACKER(op, tidx) ((void) 0)
|
||||
# define DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_WITH_SUB_PROC(op, tidx, sub_prefix) ((void) 0)
|
||||
# define DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_SUB_PROC(sub_prefix) ((void) 0)
|
||||
# define DEVICE_SCOPED_OP_PERFORMANCE_TRACKER(op, tidx) ((void) 0)
|
||||
# define DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_WITH_SUB_PROC(op, tidx, sub_prefix) ((void) 0)
|
||||
# define DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_SUB_PROC(sub_prefix) ((void) 0)
|
||||
# define DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_WITH_MULTI_SUB_PROC(op, tidx, tracker_name) ((void) 0)
|
||||
# define DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(tracker_name, idx, sub_prefix) ((void) 0)
|
||||
#endif
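// Usage sketch (illustrative, names hypothetical) of the multi-sub-process tracker
// macros defined above, inside an op implementation:
//
//   DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_WITH_MULTI_SUB_PROC(dst, params->get_thread_index(), fa);
//   {
//       DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(fa, 0, load);
//       // ... load / dequantize ...
//   }
//   {
//       DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(fa, 1, vec_dot);
//       // ... dot products ...
//   }
//   // at scope exit the timer prints one line with per-sub-process counts and durations
//   // (up to kSubProcCount = 4 tracked sub-processes)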
|
||||
|
|
@ -52,11 +52,18 @@ inline constexpr const char * op_get_name(npu_device_tensor_op op) {
|
|||
return "MUL";
|
||||
case NPU_OP_RMS_NORM:
|
||||
return "RMS_NORM";
|
||||
case NPU_OP_FLASH_ATTN:
|
||||
return "FLASH_ATTN_EXT";
|
||||
default:
|
||||
return "UNKNOWN";
|
||||
}
|
||||
}
|
||||
|
||||
inline bool is_transposed_or_permuted(const npu_device_nb_type & nb) {
|
||||
// Check if the tensor is transposed or permuted
|
||||
return (nb[0] > nb[1]) || (nb[1] > nb[2]) || (nb[2] > nb[3]);
|
||||
}
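// Worked example for the check above, with ggml-style nb[] strides in bytes: a
// contiguous F32 tensor of shape [4, 3, 2, 1] has nb = {4, 16, 48, 96}, which is
// non-decreasing, so it is not reported as transposed/permuted; swapping the first two
// axes gives nb = {16, 4, 48, 96}, where nb[0] > nb[1] flags it as permuted.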
|
||||
|
||||
class power_utils {
|
||||
public:
|
||||
power_utils() {
|
||||
|
|
@ -160,16 +167,22 @@ class power_utils {
|
|||
|
||||
#ifdef GGML_HEXAGON_ENABLE_PERFORMANCE_TRACKING
|
||||
|
||||
struct sub_process_data {
|
||||
char log_prefix[32] = {};
|
||||
uint64_t proc_cycles = 0;
|
||||
uint64_t proc_pcycles = 0;
|
||||
uint64_t proc_count = 0;
|
||||
};
|
||||
|
||||
template <size_t _buffer_count> class npu_scoped_timer {
|
||||
public:
|
||||
enum { kBufferCount = _buffer_count };
|
||||
enum {
|
||||
kBufferCount = _buffer_count,
|
||||
kSubProcCount = 4,
|
||||
};
|
||||
|
||||
explicit npu_scoped_timer(const char * log_prefix, const char * sub_proc_log_prefix) {
|
||||
explicit npu_scoped_timer(const char * log_prefix) {
|
||||
strncpy(_log_prefix, log_prefix, kBufferCount - 1);
|
||||
if (sub_proc_log_prefix != nullptr) {
|
||||
strncpy(_sub_proc_log_prefix, sub_proc_log_prefix, kBufferCount - 1);
|
||||
}
|
||||
|
||||
_begin_cycles = HAP_perf_get_qtimer_count();
|
||||
_begin_pcycles = HAP_perf_get_pcycles();
|
||||
}
|
||||
|
|
@ -180,61 +193,121 @@ template <size_t _buffer_count> class npu_scoped_timer {
|
|||
|
||||
void operator=(npu_scoped_timer && other) {
|
||||
strncpy(_log_prefix, other._log_prefix, kBufferCount - 1);
|
||||
strncpy(_sub_proc_log_prefix, other._sub_proc_log_prefix, kBufferCount - 1);
|
||||
_begin_cycles = other._begin_cycles;
|
||||
_sub_proc_cycles = other._sub_proc_cycles;
|
||||
_sub_proc_count = other._sub_proc_count;
|
||||
_begin_cycles = other._begin_cycles;
|
||||
_begin_pcycles = other._begin_pcycles;
|
||||
memcpy(&_sub_proc_data, &other._sub_proc_data, sizeof(_sub_proc_data));
|
||||
}
|
||||
|
||||
void add_sub_proc_cycles(uint64_t cycles, uint64_t pcycles) {
|
||||
_sub_proc_cycles += cycles;
|
||||
_sub_proc_pcycles += pcycles;
|
||||
_sub_proc_count++;
|
||||
void add_sub_proc_cycles(size_t sub_proc_idx, const char * sub_proc_prefix, uint64_t cycles, uint64_t pcycles) {
|
||||
auto & sub_proc_data = _sub_proc_data[sub_proc_idx];
|
||||
sub_proc_data.proc_cycles += cycles;
|
||||
sub_proc_data.proc_pcycles += pcycles;
|
||||
|
||||
if (!sub_proc_data.proc_count) {
|
||||
strncpy(sub_proc_data.log_prefix, sub_proc_prefix, sizeof(sub_proc_data.log_prefix) - 1);
|
||||
}
|
||||
|
||||
sub_proc_data.proc_count++;
|
||||
}
|
||||
|
||||
void print() const {
|
||||
static_assert(kSubProcCount == 4, "Sub process count must be 4 for logging format");
|
||||
|
||||
auto total_cycles = HAP_perf_get_qtimer_count() - _begin_cycles;
|
||||
auto total_pcycles = HAP_perf_get_pcycles() - _begin_pcycles;
|
||||
auto duration = HAP_perf_qtimer_count_to_us(total_cycles);
|
||||
|
||||
if (_sub_proc_count > 0) {
|
||||
auto sub_proc_duration = HAP_perf_qtimer_count_to_us(_sub_proc_cycles);
|
||||
DEVICE_LOG_WARN("[profiler]%s, pcyc: %llu, dur: %lluus, [%s]cnt: %llu, pcyc: %llu, dur: %lluus\n",
|
||||
_log_prefix, total_pcycles, duration, _sub_proc_log_prefix, _sub_proc_count,
|
||||
_sub_proc_pcycles, sub_proc_duration);
|
||||
} else {
|
||||
DEVICE_LOG_WARN("[profiler]%s, pcyc: %llu, dur: %lluus\n", _log_prefix, total_pcycles, duration);
|
||||
        int sub_proc_count = 0;
        for (int i = kSubProcCount; i > 0; --i) {
            if (_sub_proc_data[i - 1].proc_count > 0) {
                sub_proc_count = i;
                break;
            }
        }

        auto sub_proc0_duration = HAP_perf_qtimer_count_to_us(_sub_proc_data[0].proc_cycles);
        auto sub_proc1_duration = HAP_perf_qtimer_count_to_us(_sub_proc_data[1].proc_cycles);
        auto sub_proc2_duration = HAP_perf_qtimer_count_to_us(_sub_proc_data[2].proc_cycles);
        auto sub_proc3_duration = HAP_perf_qtimer_count_to_us(_sub_proc_data[3].proc_cycles);

        switch (sub_proc_count) {
            case 4:
                DEVICE_LOG_WARN(
                    "[profiler]%s, pcyc: %llu, dur: %lluus, [%s]cnt: %llu, dur: %lluus, "
                    "[%s]cnt: %llu, dur: %lluus, [%s]cnt: %llu, dur: %lluus, "
                    "[%s]cnt: %llu, dur: %lluus\n",
                    _log_prefix, (unsigned long long) total_pcycles, (unsigned long long) duration,
                    _sub_proc_data[0].log_prefix, (unsigned long long) _sub_proc_data[0].proc_count,
                    (unsigned long long) sub_proc0_duration, _sub_proc_data[1].log_prefix,
                    (unsigned long long) _sub_proc_data[1].proc_count, (unsigned long long) sub_proc1_duration,
                    _sub_proc_data[2].log_prefix, (unsigned long long) _sub_proc_data[2].proc_count,
                    (unsigned long long) sub_proc2_duration, _sub_proc_data[3].log_prefix,
                    (unsigned long long) _sub_proc_data[3].proc_count, (unsigned long long) sub_proc3_duration);
                break;
            case 3:
                DEVICE_LOG_WARN(
                    "[profiler]%s, pcyc: %llu, dur: %lluus, [%s]cnt: %llu, dur: %lluus, "
                    "[%s]cnt: %llu, dur: %lluus, [%s]cnt: %llu, dur: %lluus\n",
                    _log_prefix, (unsigned long long) total_pcycles, (unsigned long long) duration,
                    _sub_proc_data[0].log_prefix, (unsigned long long) _sub_proc_data[0].proc_count,
                    (unsigned long long) sub_proc0_duration, _sub_proc_data[1].log_prefix,
                    (unsigned long long) _sub_proc_data[1].proc_count, (unsigned long long) sub_proc1_duration,
                    _sub_proc_data[2].log_prefix, (unsigned long long) _sub_proc_data[2].proc_count,
                    (unsigned long long) sub_proc2_duration);
                break;
            case 2:
                DEVICE_LOG_WARN(
                    "[profiler]%s, pcyc: %llu, dur: %lluus, [%s]cnt: %llu, dur: %lluus, "
                    "[%s]cnt: %llu, dur: %lluus\n",
                    _log_prefix, (unsigned long long) total_pcycles, (unsigned long long) duration,
                    _sub_proc_data[0].log_prefix, (unsigned long long) _sub_proc_data[0].proc_count,
                    (unsigned long long) sub_proc0_duration, _sub_proc_data[1].log_prefix,
                    (unsigned long long) _sub_proc_data[1].proc_count, (unsigned long long) sub_proc1_duration);
                break;
            case 1:
                DEVICE_LOG_WARN("[profiler]%s, pcyc: %llu, dur: %lluus, [%s]cnt: %llu, dur: %lluus\n", _log_prefix,
                                (unsigned long long) total_pcycles, (unsigned long long) duration,
                                _sub_proc_data[0].log_prefix, (unsigned long long) _sub_proc_data[0].proc_count,
                                (unsigned long long) sub_proc0_duration);
                break;
            default:
            case 0:
                DEVICE_LOG_WARN("[profiler]%s, pcyc: %llu, dur: %lluus\n", _log_prefix,
                                (unsigned long long) total_pcycles, (unsigned long long) duration);
                break;
        }
    }

  private:
    char     _log_prefix[kBufferCount] = {};
    char     _sub_proc_log_prefix[kBufferCount] = {};
    uint64_t _begin_cycles = 0;
    uint64_t _begin_pcycles = 0;
    uint64_t _sub_proc_cycles = 0;
    uint64_t _sub_proc_pcycles = 0;
    uint64_t _sub_proc_count = 0;
    char             _log_prefix[kBufferCount] = {};
    uint64_t         _begin_cycles = 0;
    uint64_t         _begin_pcycles = 0;
    sub_process_data _sub_proc_data[kSubProcCount] = {};

    DISABLE_COPY(npu_scoped_timer);
};

template <size_t _buffer_count> class npu_sub_process_scoped_timer {
template <size_t _buffer_count, size_t _sub_idx> class npu_sub_process_scoped_timer {
  public:
    static_assert(_sub_idx < npu_scoped_timer<_buffer_count>::kSubProcCount,
                  "Sub process index must be less than kSubProcCount");
    using npu_scoped_timer = npu_scoped_timer<_buffer_count>;

    explicit npu_sub_process_scoped_timer(npu_scoped_timer & timer) : _timer(timer) {
    explicit npu_sub_process_scoped_timer(npu_scoped_timer & timer, const char * prefix) :
        _timer(timer),
        _prefix(prefix) {
        _begin_cycles  = HAP_perf_get_qtimer_count();
        _begin_pcycles = HAP_perf_get_pcycles();
    }

    ~npu_sub_process_scoped_timer() {
        _timer.add_sub_proc_cycles(HAP_perf_get_qtimer_count() - _begin_cycles,
        _timer.add_sub_proc_cycles(_sub_idx, _prefix, HAP_perf_get_qtimer_count() - _begin_cycles,
                                   HAP_perf_get_pcycles() - _begin_pcycles);
    }

  private:
    npu_scoped_timer & _timer;
    const char *       _prefix = nullptr;
    uint64_t           _begin_cycles = 0;
    uint64_t           _begin_pcycles = 0;

@ -244,10 +317,10 @@ template <size_t _buffer_count> class npu_sub_process_scoped_timer {
inline auto make_scoped_perf_timer(const char * format, ...) {
    va_list args;
    va_start(args, format);
    char buffer[512];
    char buffer[1024];
    vsnprintf(buffer, sizeof(buffer), format, args);
    va_end(args);
    return npu_scoped_timer<512>(buffer, nullptr);
    return npu_scoped_timer<1024>(buffer);
}

#endif

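For context, a usage sketch of the profiler helpers above. The timer class names, the 1024-byte buffer size, and the (timer, prefix) constructor come from the diff; the include name, the op body, and the namespace visibility are assumptions.

    #include "profiler.hpp"  // assumed header name for the timers above

    void compute_some_op(/* ... */) {
        // One top-level timer per op invocation; the formatted string becomes _log_prefix.
        auto timer = make_scoped_perf_timer("[op]%s", "MUL_MAT");

        {
            // Sub-slot 0 accumulates into _sub_proc_data[0]; its prefix shows up as one [cnt/dur] pair.
            npu_sub_process_scoped_timer<1024, 0> vec_dot_timer(timer, "vec_dot");
            // ... vector dot-product work ...
        }
        {
            // Sub-slot 1 is reported as a second pair when its proc_count is non-zero.
            npu_sub_process_scoped_timer<1024, 1> scale_timer(timer, "vec_scale");
            // ... scaling work ...
        }
    }  // ~npu_scoped_timer() emits the DEVICE_LOG_WARN profiler line shown above
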
@ -0,0 +1,156 @@
#include "vec_ops.hpp"

#include <HTP/core/intrinsics.h>

#include "util.hpp"

namespace {

template <typename _TElem, HVX_Vector (*_MpyFunc)(HVX_Vector, HVX_Vector),
          HVX_Vector (*_AddFunc)(HVX_Vector, HVX_Vector), float (*_ReduceFunc)(HVX_Vector)>
inline float vec_dot_product_impl(const _TElem * src0, const _TElem * src1, size_t count) {
    constexpr const size_t kElementsPerVector = hexagon::kBytesPerVector / sizeof(_TElem);

    HVX_Vector *       src0_vec_ptr     = ((HVX_Vector *) src0);
    HVX_Vector * const src0_vec_ptr_end = ((HVX_Vector *) src0) + count / kElementsPerVector;
    HVX_Vector *       src1_vec_ptr     = ((HVX_Vector *) src1);
    HVX_Vector         prev0            = *src0_vec_ptr++;
    HVX_Vector         prev1            = *src1_vec_ptr++;
    HVX_Vector         sum              = Q6_V_vzero();
    HVX_Vector         sum0             = Q6_V_vzero();
    HVX_Vector         sum1             = Q6_V_vzero();

    while (src0_vec_ptr_end - src0_vec_ptr > 1) {
        HVX_VectorPair curr0 = reinterpret_cast<HVX_VectorPair *>(src0_vec_ptr)[0];
        HVX_VectorPair curr1 = reinterpret_cast<HVX_VectorPair *>(src1_vec_ptr)[0];

        HVX_Vector l0 = Q6_V_valign_VVR(Q6_V_lo_W(curr0), prev0, (size_t) src0);
        HVX_Vector l1 = Q6_V_valign_VVR(Q6_V_lo_W(curr1), prev1, (size_t) src1);
        HVX_Vector h0 = Q6_V_valign_VVR(Q6_V_hi_W(curr0), Q6_V_lo_W(curr0), (size_t) src0);
        HVX_Vector h1 = Q6_V_valign_VVR(Q6_V_hi_W(curr1), Q6_V_lo_W(curr1), (size_t) src1);
        prev0 = Q6_V_hi_W(curr0);
        prev1 = Q6_V_hi_W(curr1);
        src0_vec_ptr += 2;
        src1_vec_ptr += 2;

        sum0 = _AddFunc(_MpyFunc(l0, l1), sum0);
        sum1 = _AddFunc(_MpyFunc(h0, h1), sum1);
    }

    sum = _AddFunc(sum0, sum1);
    if (src0_vec_ptr_end - src0_vec_ptr > 0) {
        HVX_Vector curr0 = *src0_vec_ptr++;
        HVX_Vector curr1 = *src1_vec_ptr++;
        HVX_Vector s0    = Q6_V_valign_VVR(curr0, prev0, (size_t) src0);
        HVX_Vector s1    = Q6_V_valign_VVR(curr1, prev1, (size_t) src1);
        prev0            = curr0;
        prev1            = curr1;

        sum = _AddFunc(_MpyFunc(s0, s1), sum);
    }

    const size_t leftover = count % kElementsPerVector;
    if ((src0_vec_ptr_end - ((HVX_Vector *) src0)) > 0) {
        // handle the last vector
        // see also:
        //   https://github.com/UbiquitousLearning/mllm/blob/babf4410352ce8730824c87699c025a0d4ce3a6f/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/LLaMAMul.cpp#L147
        //   or qualcomm sdk libs\qhl_hvx\src\qhblas_hvx\qhblas_hvx_aw_vector_add_ah.c
        bool       should_fetch_src0 = leftover != 0 || !hexagon::is_addr_aligned(src0_vec_ptr);
        bool       should_fetch_src1 = leftover != 0 || !hexagon::is_addr_aligned(src1_vec_ptr);
        HVX_Vector curr0             = should_fetch_src0 ? *src0_vec_ptr : prev0;
        HVX_Vector curr1             = should_fetch_src1 ? *src1_vec_ptr : prev1;
        src0_vec_ptr += should_fetch_src0 ? 1 : 0;
        src1_vec_ptr += should_fetch_src1 ? 1 : 0;
        HVX_Vector s0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0);
        HVX_Vector s1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1);
        prev0         = curr0;
        prev1         = curr1;

        sum = _AddFunc(_MpyFunc(s0, s1), sum);
    }

    const size_t leftover_bytes = leftover * sizeof(_TElem);
    if (leftover > 0) {
        // handle the leftover elements
        HVX_Vector curr0 = (leftover_bytes + hexagon::unaligned_bytes(src0_vec_ptr) > hexagon::kBytesPerVector) ?
                               *src0_vec_ptr :
                               prev0;
        curr0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0);

        HVX_Vector curr1 = (leftover_bytes + hexagon::unaligned_bytes(src1_vec_ptr) > hexagon::kBytesPerVector) ?
                               *src1_vec_ptr :
                               prev1;
        curr1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1);

        sum = _AddFunc(Q6_V_valign_VVR(_MpyFunc(curr0, curr1), Q6_V_vzero(), leftover_bytes), sum);
    }

    return _ReduceFunc(sum);
}

template <typename _TElem, HVX_Vector (*_MpyFunc)(HVX_Vector, HVX_Vector),
          HVX_Vector (*_AddFunc)(HVX_Vector, HVX_Vector), float (*_ReduceFunc)(HVX_Vector)>
inline float vec_dot_product_aligned_impl(const _TElem * src0, const _TElem * src1, size_t count) {
    constexpr const size_t kElementsPerVector = hexagon::kBytesPerVector / sizeof(_TElem);

    HVX_Vector *       src0_vec_ptr     = ((HVX_Vector *) src0);
    HVX_Vector * const src0_vec_ptr_end = ((HVX_Vector *) src0) + count / kElementsPerVector;
    HVX_Vector *       src1_vec_ptr     = ((HVX_Vector *) src1);
    HVX_Vector         sum0             = Q6_V_vzero();
    HVX_Vector         sum1             = Q6_V_vzero();

    while (src0_vec_ptr_end - src0_vec_ptr > 1) {
        HVX_Vector curr0_lo = src0_vec_ptr[0];
        HVX_Vector curr0_hi = src0_vec_ptr[1];
        HVX_Vector curr1_lo = src1_vec_ptr[0];
        HVX_Vector curr1_hi = src1_vec_ptr[1];
        src0_vec_ptr += 2;
        src1_vec_ptr += 2;

        sum0 = _AddFunc(_MpyFunc(curr0_lo, curr1_lo), sum0);
        sum1 = _AddFunc(_MpyFunc(curr0_hi, curr1_hi), sum1);
    }

    return _ReduceFunc(_AddFunc(sum0, sum1));
}

inline HVX_Vector vec_mpy_qf32(HVX_Vector src0, HVX_Vector src1) {
    return Q6_Vqf32_vmpy_VsfVsf(src0, src1);
}

inline HVX_Vector vec_add_qf32(HVX_Vector sum, HVX_Vector result) {
    return Q6_Vqf32_vadd_Vqf32Vqf32(sum, result);
}

inline HVX_Vector vec_mpy_qf16(HVX_Vector src0, HVX_Vector src1) {
    return Q6_Vqf16_vmpy_VhfVhf(src0, src1);
}

inline HVX_Vector vec_add_qf16(HVX_Vector sum, HVX_Vector result) {
    return Q6_Vqf16_vadd_Vqf16Vqf16(sum, result);
}

}  // namespace

namespace hexagon {

float vec_dot_product_f32_f32(const float * src0, const float * src1, size_t count) {
    return vec_dot_product_impl<float, vec_mpy_qf32, vec_add_qf32, hexagon::vec_reduction_qf32_f32>(src0, src1, count);
}

float vec_dot_product_aligned_f32_f32(const float * src0, const float * src1, size_t count) {
    return vec_dot_product_aligned_impl<float, vec_mpy_qf32, vec_add_qf32, hexagon::vec_reduction_qf32_f32>(src0, src1,
                                                                                                            count);
}

float vec_dot_product_f16_f16(const npu_device_fp16_t * src0, const npu_device_fp16_t * src1, size_t count) {
    return vec_dot_product_impl<npu_device_fp16_t, vec_mpy_qf16, vec_add_qf16, hexagon::vec_reduction_qf16_f32>(
        src0, src1, count);
}

float vec_dot_product_aligned_f16_f16(const npu_device_fp16_t * src0, const npu_device_fp16_t * src1, size_t count) {
    return vec_dot_product_aligned_impl<npu_device_fp16_t, vec_mpy_qf16, vec_add_qf16, hexagon::vec_reduction_qf16_f32>(
        src0, src1, count);
}

}  // namespace hexagon

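As a sanity check for the HVX dot-product kernels above, a scalar reference of the same computation can be compared against vec_dot_product_f32_f32. This is a minimal sketch; the helper names and the tolerance are assumptions, and the tolerance is deliberately loose because qf32 accumulation order differs from the scalar loop.

    #include <cmath>
    #include <cstddef>

    // Scalar reference: the value vec_dot_product_f32_f32 computes, one multiply-add per element.
    static float ref_dot_f32(const float * a, const float * b, size_t count) {
        float acc = 0.0f;
        for (size_t i = 0; i < count; ++i) {
            acc += a[i] * b[i];
        }
        return acc;
    }

    // Hypothetical check against the HVX result with a relative tolerance.
    static bool check_dot_f32(const float * a, const float * b, size_t count, float hvx_result) {
        const float expected = ref_dot_f32(a, b, count);
        return std::fabs(hvx_result - expected) <= 1e-3f * (1.0f + std::fabs(expected));
    }
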
@ -0,0 +1,274 @@
#pragma once

#include <hexagon_types.h>
#include <HTP/core/intrinsics.h>

#include <cstdint>

#include "hexagon_npu.h"

namespace hexagon {

constexpr const size_t kBytesPerVector = sizeof(HVX_Vector);  // 128 for v73
constexpr const size_t kAlignMask      = kBytesPerVector - 1;

inline size_t unaligned_bytes(const void * addr) {
    return ((size_t) addr) & kAlignMask;
}

inline bool is_addr_aligned(const void * addr) {
    return unaligned_bytes(addr) == 0;
}

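A quick illustration of the masking arithmetic above; the pointer value is made up, only the 128-byte vector width comes from this header.

    #include <cstdint>
    #include <cstdio>

    int main() {
        constexpr std::size_t kBytesPerVector = 128;              // HVX vector width on v73
        constexpr std::size_t kAlignMask      = kBytesPerVector - 1;

        const std::uintptr_t addr     = 0x1000 + 40;              // hypothetical address, 40 bytes past a vector boundary
        const std::size_t    misalign = addr & kAlignMask;        // same expression as unaligned_bytes()
        std::printf("misaligned by %zu bytes, aligned: %d\n", misalign, misalign == 0);
        return 0;                                                 // prints: misaligned by 40 bytes, aligned: 0
    }
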
inline float get_flt0_from_fltv(HVX_Vector vect) {
    static_assert(sizeof(vect[0]) == sizeof(float), "vect[0] should be a float");
    int32_t i = vect[0];
    return reinterpret_cast<float &>(i);
}

inline HVX_UVector Q6_V_vmemu_R(const void * unaligned_ptr) {
    return *reinterpret_cast<const HVX_UVector *>(unaligned_ptr);
}

inline HVX_Vector Q6_V_vmem_R(const void * aligned_ptr) {
    return *reinterpret_cast<const HVX_Vector *>(aligned_ptr);
}

constexpr const size_t kL2CacheSize         = 8 * 1024;  // 8KB L2 cache
constexpr const size_t kL2FetchAheadVectors = kL2CacheSize / kBytesPerVector;

inline void l2fetch(const void * p, uint32_t stride, uint32_t width, uint32_t height, uint32_t dir) {
    uint64_t control = HEXAGON_V64_CREATE_H(dir, stride, width, height);
    __asm__ __volatile__(" l2fetch(%0,%1) " : : "r"(p), "r"(control));
}

inline void l2fetch_row(const uint8_t * row_ptr, size_t bytes) {
    // TODO: should we use small kL2FetchAheadVectors?
    int32_t l2fetch_vectors = Q6_R_min_RR(bytes / kBytesPerVector, kL2FetchAheadVectors);
    hexagon::l2fetch(row_ptr, kBytesPerVector, kBytesPerVector, l2fetch_vectors, 0);
}

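The l2fetch helpers above are typically paired with a fetch-ahead loop: prefetch the next row into L2 while the current one is being processed. A rough sketch of that pattern follows; the header name, the flat row layout, and the per-row callback are assumptions, only l2fetch_row comes from this file.

    #include <cstddef>
    #include <cstdint>

    #include "vec_ops.hpp"  // assumed name of the header above

    // Sketch: issue the prefetch for row i+1 before computing row i.
    inline void process_rows(const uint8_t * base, size_t row_bytes, size_t n_rows,
                             void (*process_row)(const uint8_t *, size_t)) {
        for (size_t i = 0; i < n_rows; ++i) {
            if (i + 1 < n_rows) {
                hexagon::l2fetch_row(base + (i + 1) * row_bytes, row_bytes);
            }
            process_row(base + i * row_bytes, row_bytes);  // placeholder for the actual per-row kernel
        }
    }
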
/*
 * This function converts a vector of IEEE float elements to a vector of qf32 elements
 * See also: libs\qfe\inc\qhmath_hvx_convert.h
 */
inline HVX_Vector qhmath_hvx_vqf32_convert_vsf(HVX_Vector vin) {
    return Q6_Vqf32_vadd_VsfVsf(vin, Q6_V_vzero());
}

/*
 * This function converts a vector of IEEE half float elements to a vector of qf16 elements
 * See also: libs\qfe\inc\qhmath_hvx_convert.h
 */
inline HVX_Vector qhmath_hvx_vqf16_convert_vhf(HVX_Vector vin) {
    return Q6_Vqf16_vadd_VhfVhf(vin, Q6_V_vzero());
}

/*
 * This function converts a pair of vectors of qf32 elements to a vector of IEEE half float elements
 * See also: libs\qfe\inc\qhmath_hvx_convert.h
 */
inline HVX_Vector qhmath_hvx_vhf_convert_vqf32(HVX_VectorPair vin_vp) {
    return Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(vin_vp));
}

/*
 * This function converts a vector of qf16 elements to a pair of vectors of qf32 elements
 * See also: libs\qfe\inc\qhmath_hvx_convert.h
 */
inline HVX_VectorPair qhmath_hvx_vqf32_convert_vqf16(HVX_Vector vxl) {
    HVX_VectorPair vxw_vp, exponent_vp;
    HVX_Vector     mantissa_mask = Q6_Vh_vsplat_R(0xffe0);
    HVX_Vector     exp_mask      = Q6_Vh_vsplat_R(0x1f);
    HVX_Vector     exp_offset    = Q6_Vh_vsplat_R(0x70);
    HVX_Vector     mant32_shift  = Q6_Vh_vsplat_R(0x10);
    HVX_Vector     reql, reqh, vxl_w, vxh_w, mantissa;
    HVX_Vector     el_exponent, eh_exponent;

    el_exponent = Q6_V_vand_VV(exp_mask, vxl);
    // Obtain the mantissa part: bits (5-15)
    mantissa = Q6_V_vand_VV(mantissa_mask, vxl);
    // Convert the qf16 biased exponent to a qf32 biased exponent:
    // new exp = exp + (127 (qf32 bias) - 15 (qf16 bias)) = exp + 112
    el_exponent = Q6_Vh_vadd_VhVh(exp_offset, el_exponent);

    vxw_vp = Q6_Ww_vunpack_Vh(mantissa);
    vxl_w  = Q6_V_lo_W(vxw_vp);
    vxh_w  = Q6_V_hi_W(vxw_vp);

    exponent_vp = Q6_Ww_vunpack_Vh(el_exponent);
    el_exponent = Q6_V_lo_W(exponent_vp);
    eh_exponent = Q6_V_hi_W(exponent_vp);
    // Convert the qf16 mantissa to a qf32 mantissa
    reql = Q6_Vw_vasl_VwVw(vxl_w, mant32_shift);
    reqh = Q6_Vw_vasl_VwVw(vxh_w, mant32_shift);
    // Add the exponent
    vxl_w = Q6_Vw_vadd_VwVw(reql, el_exponent);
    vxh_w = Q6_Vw_vadd_VwVw(reqh, eh_exponent);

    return Q6_W_vcombine_VV(vxh_w, vxl_w);
}

inline HVX_Vector vec_reduction_qf32(HVX_Vector sums) {
    constexpr const size_t kFloatsPerVector = hexagon::kBytesPerVector / sizeof(float);
    static_assert(kFloatsPerVector == 32 || kFloatsPerVector == 16, "kFloatsPerVector should be 16 or 32");

    // TODO: do we have a better way to do the reduction?
    switch (kFloatsPerVector) {
        default:
        case 32:
            sums = Q6_Vqf32_vadd_Vqf32Vqf32(sums, Q6_V_vror_VR(sums, 16 * sizeof(float)));
            // fallthrough
        case 16:
            sums = Q6_Vqf32_vadd_Vqf32Vqf32(sums, Q6_V_vror_VR(sums, 8 * sizeof(float)));
            sums = Q6_Vqf32_vadd_Vqf32Vqf32(sums, Q6_V_vror_VR(sums, 4 * sizeof(float)));
            sums = Q6_Vqf32_vadd_Vqf32Vqf32(sums, Q6_V_vror_VR(sums, 2 * sizeof(float)));
            sums = Q6_Vqf32_vadd_Vqf32Vqf32(sums, Q6_V_vror_VR(sums, sizeof(float)));
            break;
    }

    return sums;
}

inline float vec_reduction_qf32_f32(HVX_Vector sums) {
    return get_flt0_from_fltv(Q6_Vsf_equals_Vqf32(vec_reduction_qf32(sums)));
}

inline HVX_Vector vec_reduction_qf16(HVX_Vector sums) {
    constexpr const size_t kFloatsPerVector = hexagon::kBytesPerVector / sizeof(npu_device_fp16_t);
    static_assert(kFloatsPerVector == 64 || kFloatsPerVector == 32, "kFloatsPerVector should be 32 or 64");

    // TODO: do we have a better way to do the reduction?
    switch (kFloatsPerVector) {
        default:
        case 64:
            sums = Q6_Vqf16_vadd_Vqf16Vqf16(sums, Q6_V_vror_VR(sums, 32 * sizeof(npu_device_fp16_t)));
            // fallthrough
        case 32:
            sums = Q6_Vqf16_vadd_Vqf16Vqf16(sums, Q6_V_vror_VR(sums, 16 * sizeof(npu_device_fp16_t)));
            sums = Q6_Vqf16_vadd_Vqf16Vqf16(sums, Q6_V_vror_VR(sums, 8 * sizeof(npu_device_fp16_t)));
            sums = Q6_Vqf16_vadd_Vqf16Vqf16(sums, Q6_V_vror_VR(sums, 4 * sizeof(npu_device_fp16_t)));
            sums = Q6_Vqf16_vadd_Vqf16Vqf16(sums, Q6_V_vror_VR(sums, 2 * sizeof(npu_device_fp16_t)));
            sums = Q6_Vqf16_vadd_Vqf16Vqf16(sums, Q6_V_vror_VR(sums, sizeof(npu_device_fp16_t)));
            break;
    }

    return sums;
}

inline float vec_reduction_qf16_f32(HVX_Vector sums) {
    HVX_Vector vect = Q6_Vhf_equals_Vqf16(vec_reduction_qf16(sums));
    uint16_t   i    = (vect[0] & 0xffff);
    return reinterpret_cast<__fp16 &>(i);
}

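The rotate-and-add ladders above fold all vector lanes into a single sum in log2(lane count) steps. A scalar analogue of the same idea, shown here as a plain C++ sketch with a fixed 32-lane example (the HVX version uses vector rotates instead of array indexing, and every lane ends up holding the total):

    #include <array>
    #include <cstddef>

    // Tree reduction: halve the active range each step, adding the upper half onto the lower half.
    inline float reduce_lanes(std::array<float, 32> lanes) {
        for (std::size_t step = lanes.size() / 2; step > 0; step /= 2) {
            for (std::size_t i = 0; i < step; ++i) {
                lanes[i] += lanes[i + step];
            }
        }
        return lanes[0];
    }
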
inline HVX_Vector hvx_scale_f32(float scale) {
    return Q6_V_vsplat_R(reinterpret_cast<const uint32_t &>(scale));
}

template <HVX_Vector (*_Func)(HVX_Vector, HVX_UVector *, HVX_Vector), HVX_Vector (*_FuncScaleConvert)(float),
          typename _TParam>
inline void vec_scale_impl(const _TParam * src, float scale, _TParam * dst, size_t count) {
    constexpr const size_t kElementsPerVector = hexagon::kBytesPerVector / sizeof(_TParam);

    HVX_Vector *       src_vec_ptr    = ((HVX_Vector *) src);
    HVX_Vector * const src_vec_end    = ((HVX_Vector *) src) + (count / kElementsPerVector);
    HVX_UVector *      dst_vec_ptr    = ((HVX_UVector *) dst);  // TODO: opt the unaligned case?
    HVX_Vector         prev           = *src_vec_ptr++;
    const size_t       leftover       = count % kElementsPerVector;
    const size_t       leftover_bytes = leftover * sizeof(_TParam);

    HVX_Vector scale_vec = _FuncScaleConvert(scale);

    while (src_vec_end - src_vec_ptr > 1) {
        HVX_VectorPair curr = reinterpret_cast<HVX_VectorPair *>(src_vec_ptr)[0];
        src_vec_ptr += 2;

        HVX_Vector lo = Q6_V_valign_VVR(Q6_V_lo_W(curr), prev, (size_t) src);
        HVX_Vector hi = Q6_V_valign_VVR(Q6_V_hi_W(curr), Q6_V_lo_W(curr), (size_t) src);

        dst_vec_ptr[0] = _Func(lo, dst_vec_ptr, scale_vec);
        dst_vec_ptr[1] = _Func(hi, dst_vec_ptr + 1, scale_vec);

        dst_vec_ptr += 2;
        prev = Q6_V_hi_W(curr);
    }

    if (src_vec_end - src_vec_ptr > 0) {
        HVX_Vector curr = *src_vec_ptr++;
        HVX_Vector s0   = Q6_V_valign_VVR(curr, prev, (size_t) src);
        dst_vec_ptr[0]  = _Func(s0, dst_vec_ptr, scale_vec);
        dst_vec_ptr++;
        prev = curr;
    }

    if ((src_vec_end - ((HVX_Vector *) src)) > 0) {
        // handle the last vector
        bool       should_fetch_next = leftover == 0 && hexagon::is_addr_aligned(src_vec_ptr);
        HVX_Vector curr              = should_fetch_next ? prev : *src_vec_ptr;
        src_vec_ptr                  = should_fetch_next ? src_vec_ptr : src_vec_ptr + 1;
        HVX_Vector s0                = Q6_V_valign_VVR(curr, prev, (size_t) src);
        dst_vec_ptr[0]               = _Func(s0, dst_vec_ptr, scale_vec);
        dst_vec_ptr++;
        prev = curr;
    }

    if (leftover > 0) {
        // handle the leftover elements
        HVX_Vector curr =
            (leftover_bytes + hexagon::unaligned_bytes(src_vec_ptr) > hexagon::kBytesPerVector) ? *src_vec_ptr : prev;
        curr = Q6_V_valign_VVR(curr, prev, (size_t) src);
        q6op_vstu_variable_ARV(dst_vec_ptr, leftover_bytes, _Func(curr, dst_vec_ptr, scale_vec));
    }
}

inline HVX_Vector hvx_vec_scale_f32_f32(HVX_Vector src, HVX_UVector *, HVX_Vector scale_vec) {
    return Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(src, scale_vec));
}

inline HVX_Vector hvx_vec_mad_f32_f32(HVX_Vector src, HVX_UVector * dst_ptr, HVX_Vector scale_vec) {
    HVX_Vector dst = *dst_ptr;  // TODO: opt the unaligned case?
    src            = Q6_Vqf32_vmpy_VsfVsf(src, scale_vec);
    src            = Q6_Vqf32_vadd_Vqf32Vsf(src, dst);
    return Q6_Vsf_equals_Vqf32(src);
}

inline void vec_scale_f32(const float * src, float scale, float * dst, size_t count) {
    vec_scale_impl<hvx_vec_scale_f32_f32, hvx_scale_f32, float>(src, scale, dst, count);
}

inline void vec_mad_f32(const float * src, float scale, float * dst, size_t count) {
    vec_scale_impl<hvx_vec_mad_f32_f32, hvx_scale_f32, float>(src, scale, dst, count);
}

inline HVX_Vector hvx_scale_f16(float scale) {
    __fp16 f16_scale = scale;
    return Q6_Vh_vsplat_R(reinterpret_cast<const npu_device_fp16_t &>(f16_scale));
}

inline HVX_Vector hvx_vec_scale_f16_f16(HVX_Vector src, HVX_UVector *, HVX_Vector scale_vec) {
    return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(src, scale_vec));
}

inline HVX_Vector hvx_vec_mad_f16_f16(HVX_Vector src, HVX_UVector * dst_ptr, HVX_Vector scale_vec) {
    HVX_Vector dst    = *dst_ptr;  // TODO: opt the unaligned case?
    HVX_Vector scaled = Q6_Vqf16_vmpy_VhfVhf(src, scale_vec);
    HVX_Vector result = Q6_Vqf16_vadd_Vqf16Vhf(scaled, dst);
    return Q6_Vhf_equals_Vqf16(result);
}

inline void vec_scale_f16(const npu_device_fp16_t * src, float scale, npu_device_fp16_t * dst, size_t count) {
    vec_scale_impl<hvx_vec_scale_f16_f16, hvx_scale_f16, npu_device_fp16_t>(src, scale, dst, count);
}

inline void vec_mad_f16(const npu_device_fp16_t * src, float scale, npu_device_fp16_t * dst, size_t count) {
    vec_scale_impl<hvx_vec_mad_f16_f16, hvx_scale_f16, npu_device_fp16_t>(src, scale, dst, count);
}

float vec_dot_product_f32_f32(const float * src0, const float * src1, size_t count);
float vec_dot_product_aligned_f32_f32(const float * src0, const float * src1, size_t count);

float vec_dot_product_f16_f16(const npu_device_fp16_t * src0, const npu_device_fp16_t * src1, size_t count);
float vec_dot_product_aligned_f16_f16(const npu_device_fp16_t * src0, const npu_device_fp16_t * src1, size_t count);

}  // namespace hexagon

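For reference, the element-wise semantics that the vec_scale_* and vec_mad_* wrappers above implement, as a scalar sketch. This is useful as a correctness oracle only; it is not how the HVX path works internally.

    #include <cstddef>

    // dst[i] = src[i] * scale            -> what vec_scale_f32 / vec_scale_f16 compute
    inline void ref_scale_f32(const float * src, float scale, float * dst, size_t count) {
        for (size_t i = 0; i < count; ++i) {
            dst[i] = src[i] * scale;
        }
    }

    // dst[i] = dst[i] + src[i] * scale   -> what vec_mad_f32 / vec_mad_f16 compute
    inline void ref_mad_f32(const float * src, float scale, float * dst, size_t count) {
        for (size_t i = 0; i < count; ++i) {
            dst[i] += src[i] * scale;
        }
    }
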
@ -114,7 +114,9 @@ size_t backend_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
size_t backend_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
    auto * buffer_type_obj = get_buffer_type_object(buft);
    GGML_ASSERT(buffer_type_obj != nullptr);
    return buffer_type_obj->get_max_buffer_size();
    auto size = buffer_type_obj->get_max_buffer_size();
    LOG_DEBUG("[hexagon-npu][%s]max_buffer_size: %zu\n", buffer_type_obj->get_name(), size);
    return size;
}

bool backend_buffer_is_host(ggml_backend_buffer_type_t buft) {

@ -29,6 +29,8 @@ bool host_graph::update(ggml_cgraph * cgraph) {
        return false;
    }

    LOG_DEBUG("[%p]host_graph::update started\n", (void *) this);

    SCOPED_PERFORMANCE_TRACKER("[hexagon-npu][%p]update, handle(%p)", (void *) this, (void *) _graph_handle);

    _tensor_handles.clear();
@ -40,8 +42,9 @@ bool host_graph::update(ggml_cgraph * cgraph) {
        if (node->op == GGML_OP_NONE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE ||
            node->op == GGML_OP_RESHAPE) {
            // skip view-like ops
            LOG_DEBUG("node[%d]%s(%s), addr: %p, type: %s, skipped\n", i, ggml_get_name(node), ggml_op_desc(node),
                      (void *) node, ggml_type_name(node->type));
            LOG_DEBUG("node[%d]%s(%s), addr: %p, type: %s, dims: %ldx%ldx%ldx%ld, skipped\n", i, ggml_get_name(node),
                      ggml_op_desc(node), (void *) node, ggml_type_name(node->type), (long) node->ne[0],
                      (long) node->ne[1], (long) node->ne[2], (long) node->ne[3]);
            continue;
        }

@ -54,9 +57,10 @@ bool host_graph::update(ggml_cgraph * cgraph) {

        _tensor_handles.push_back(tensor_obj->get_device_tensor_handle());
        _tensor_update_configs.push_back(tensor_obj->update_hosts_params_only(node));
        LOG_DEBUG("[%p]node[%d]%s(%s), addr: %p, type: %s, tensor_handle: %p\n", (void *) this, i, ggml_get_name(node),
                  ggml_op_desc(node), (void *) node, ggml_type_name(node->type),
                  (void *) tensor_obj->get_device_tensor_handle());
        LOG_DEBUG("node[%d]%s(%s), addr: %p, type: %s, dims: %ldx%ldx%ldx%ld, tensor_handle: %p\n", i,
                  ggml_get_name(node), ggml_op_desc(node), (void *) node, ggml_type_name(node->type),
                  (long) tensor_obj->get_ne(0), (long) tensor_obj->get_ne(1), (long) tensor_obj->get_ne(2),
                  (long) tensor_obj->get_ne(3), (void *) tensor_obj->get_device_tensor_handle());
    }

    GGML_ASSERT(_tensor_handles.size() == _tensor_update_configs.size());
@ -71,7 +75,7 @@ bool host_graph::update(ggml_cgraph * cgraph) {
              (int) _tensor_update_configs.size());

    if (ret != AEE_SUCCESS) {
        LOG_ERROR("Failed to set tensors in host_graph: 0x%x\n", (int) ret);
        LOG_ERROR("[%p]failed to set tensors in host_graph: 0x%x\n", (void *) this, (int) ret);
        return false;
    }

@ -149,37 +149,17 @@ bool npu_device::supports_op_impl(const ggml_tensor * op) {
        return false;
    }

    auto * src0 = op->src[0];
    if (!src0) {
        LOG_DEBUG("[%s]Unsupported inplace op: %s\n", get_name(), ggml_op_desc(op));
        return false;
    }

    if (type_to_npu_type(src0->type) == NPU_DATA_TYPE_COUNT) {
        LOG_DEBUG("[%s]Unsupported src0 tensor type: %s\n", get_name(), ggml_type_name(src0->type));
        return false;
    }

    auto * src1 = op->src[1];
    if (src1 && type_to_npu_type(src1->type) == NPU_DATA_TYPE_COUNT) {
        LOG_DEBUG("[%s]Unsupported src1 tensor type: %s\n", get_name(), ggml_type_name(src1->type));
        return false;
    }

    auto npu_op = op_to_npu_op(op->op);
    if (npu_op == NPU_OP_COUNT) {
        LOG_DEBUG("[%s]Unsupported op: %s\n", get_name(), ggml_op_desc(op));
        return false;
    }

    if (!_device_handle && !init_device()) {
        LOG_DEBUG("[%s]NPU device initialization failed\n", get_name());
        return false;
    }

    constexpr const auto get_spec = [](const ggml_tensor * tensor) -> npu_device_tensor_spec {
    int                    i                            = 0;
    npu_device_tensor_spec srcs[DEVICE_TENSOR_MAX_SRC]  = {};
    constexpr const auto   get_spec = [](const ggml_tensor * tensor) -> npu_device_tensor_spec {
        if (!tensor) {
            return npu_device_tensor_spec{ {}, NPU_DATA_TYPE_COUNT };
            return npu_device_tensor_spec{ {}, {}, NPU_DATA_TYPE_COUNT };
        }

        static_assert(DEVICE_TENSOR_MAX_DIMS == GGML_MAX_DIMS, "tensor dimensions mismatch");
@ -188,19 +168,40 @@ bool npu_device::supports_op_impl(const ggml_tensor * op) {
        spec.ne[1] = tensor->ne[1];
        spec.ne[2] = tensor->ne[2];
        spec.ne[3] = tensor->ne[3];

        spec.nb[0] = tensor->nb[0];
        spec.nb[1] = tensor->nb[1];
        spec.nb[2] = tensor->nb[2];
        spec.nb[3] = tensor->nb[3];
        spec.type  = type_to_npu_type(tensor->type);
        return spec;
    };

    for (; i < (int) DEVICE_TENSOR_MAX_SRC && op->src[i]; ++i) {
        auto * src = op->src[i];
        if (type_to_npu_type(src->type) == NPU_DATA_TYPE_COUNT) {
            LOG_DEBUG("[%s]Unsupported src%d tensor type: %s\n", get_name(), i, ggml_type_name(src->type));
            return false;
        }

        srcs[i] = get_spec(src);
    }

    if (!_device_handle && !init_device()) {
        LOG_DEBUG("[%s]NPU device initialization failed\n", get_name());
        return false;
    }

    boolean supported = false;
    auto    src0_spec = get_spec(src0);
    auto    src1_spec = get_spec(src1);
    auto    dst_spec  = get_spec(op);
    auto    ret       = npu_device_device_support_op(_device_handle, &src0_spec, &src1_spec, &dst_spec, npu_op, &supported);
    auto    ret       = npu_device_device_support_op(_device_handle, npu_op, &dst_spec, srcs, i, &supported);
    if (ret != AEE_SUCCESS || !supported) {
#ifndef NDEBUG
        auto * src0_type = i ? ggml_type_name(op->src[0]->type) : "null";
        auto * src1_type = (i > 1) ? ggml_type_name(op->src[1]->type) : "null";
        LOG_DEBUG("[%s][%s]unsupported %s(%s,%s), ret: 0x%x, supported: %d\n", get_name(), ggml_op_name(op->op),
                  ggml_type_name(op->type), ggml_type_name(src0->type), (src1 ? ggml_type_name(src1->type) : "null"),
                  ret, supported);
                  ggml_type_name(op->type), src0_type, src1_type, ret, supported);
#endif
        return false;
    }

@ -21,15 +21,17 @@ class host_tensor {

    explicit host_tensor(ggml_tensor * tensor, int buffer_fd, uint64_t offset, remote_handle64 device_handle) :
        _device_handle(device_handle) {

        // TODO: figure out why the npu_device_tensor_config can't be larger than 100 bytes
        static_assert(sizeof(npu_device_tensor_config) < 100, "npu_device_tensor_config size too large");
        static_assert(sizeof(npu_device_tensor_config) < kMaxNpuRpcStructSize,
                      "npu_device_tensor_config size too large");

        _info.buffer_fd = buffer_fd;
        _info.offset    = offset;
        _info.type      = type_to_npu_type(tensor->type);
        _info.size      = ggml_nbytes(tensor);
        _info.buffer_fd   = buffer_fd;
        _info.offset      = offset;
        _info.type        = type_to_npu_type(tensor->type);
        _info.size        = ggml_nbytes(tensor);
        _info.is_constant = false;  // TODO: support constant tensors in the future
        // _info.op will be updated in update_params()
        _info_update.op = NPU_OP_COUNT;

        static_assert(DEVICE_TENSOR_MAX_DIMS == GGML_MAX_DIMS, "tensor dimensions mismatch");
        static_assert(sizeof(_info.ne) == sizeof(tensor->ne), "tensor ne size mismatch");
@ -46,10 +48,10 @@ class host_tensor {

        tensor->extra = this;
        _ggml_tensor  = tensor;
        LOG_DEBUG("host_tensor(%p), ggml_tensor(%p[%ldx%ldx%ldx%ld], nb[%ld][%ld][%ld][%ld], %s), handle(%p)\n",
                  (void *) this, (void *) tensor, (long) tensor->ne[0], (long) tensor->ne[1], (long) tensor->ne[2],
        LOG_DEBUG("host_tensor(%p), ggml_tensor(%s[%ldx%ldx%ldx%ld], nb[%ld][%ld][%ld][%ld], %s, %p), handle(%p)\n",
                  (void *) this, tensor->name, (long) tensor->ne[0], (long) tensor->ne[1], (long) tensor->ne[2],
                  (long) tensor->ne[3], (long) tensor->nb[0], (long) tensor->nb[1], (long) tensor->nb[2],
                  (long) tensor->nb[3], ggml_type_name(tensor->type), (void *) _device_tensor_handle);
                  (long) tensor->nb[3], ggml_type_name(tensor->type), (void *) tensor, (void *) _device_tensor_handle);
    }

    ~host_tensor() {
@ -76,11 +78,9 @@ class host_tensor {
        auto new_op         = op_to_npu_op(_ggml_tensor->op);
        bool params_changed = new_op != _info_update.op;
        if (params_changed) {
            LOG_DEBUG("host_tensor(%p) op changed: %s -> %s\n", (void *) this, get_npu_op_desc(_info.op),
                      get_npu_op_desc(new_op));
            LOG_DEBUG("host_tensor(%p) op changed: %s\n", (void *) this, get_npu_op_desc(new_op));
        }

        _info.op        = new_op;
        _info_update.op = new_op;

        if (memcmp(_info_update.params, _ggml_tensor->op_params, sizeof(_info_update.params)) != 0) {
@ -92,15 +92,16 @@ class host_tensor {
        }

        npu_device_tensor_handle_t src_tensor_handles[DEVICE_TENSOR_MAX_SRC] = {};
        for (size_t j = 0; j < DEVICE_TENSOR_MAX_SRC && _ggml_tensor->src[j]; ++j) {
            auto * src            = host_tensor::from_ggml_tensor(_ggml_tensor->src[j]);
            src_tensor_handles[j] = src->get_device_tensor_handle();
            LOG_DEBUG("host_tensor(%p) set_src[%zu]: %p\n", (void *) this, j, (void *) src);
        }

        static_assert(std::is_same<decltype(_info_update.src_handles), decltype(src_tensor_handles)>::value,
                      "src tensor handles type mismatch");

        for (size_t j = 0; j < DEVICE_TENSOR_MAX_SRC && _ggml_tensor->src[j]; ++j) {
            auto * ggml_src       = _ggml_tensor->src[j];
            auto * src            = host_tensor::from_ggml_tensor(ggml_src);
            src_tensor_handles[j] = src->get_device_tensor_handle();
            LOG_DEBUG("host_tensor(%p) set_src[%zu]: %p(%s)\n", (void *) this, j, (void *) src, ggml_src->name);
        }

        if (memcmp(_info_update.src_handles, src_tensor_handles, sizeof(_info_update.src_handles)) != 0) {
            params_changed = true;
            memcpy(_info_update.src_handles, src_tensor_handles, sizeof(_info_update.src_handles));
@ -128,14 +129,14 @@ class host_tensor {
        GGML_ASSERT(ggml_tensor == _ggml_tensor);

        auto new_op     = op_to_npu_op(_ggml_tensor->op);
        _info.op        = new_op;
        _info_update.op = new_op;
        memcpy(_info_update.params, _ggml_tensor->op_params, sizeof(_info_update.params));

        for (size_t j = 0; j < DEVICE_TENSOR_MAX_SRC && _ggml_tensor->src[j]; ++j) {
            auto * src = host_tensor::from_ggml_tensor(_ggml_tensor->src[j]);
            auto * ggml_src                = _ggml_tensor->src[j];
            auto * src                     = host_tensor::from_ggml_tensor(ggml_src);
            _info_update.src_handles[j] = src->get_device_tensor_handle();
            LOG_DEBUG("host_tensor(%p) set_src[%zu]: %p\n", (void *) this, j, (void *) src);
            LOG_DEBUG("host_tensor(%p) set_src[%zu]: %p(%s)\n", (void *) this, j, (void *) src, ggml_src->name);
        }

        LOG_DEBUG("host_tensor(%p) update_params, op: %s, params: [%x, %x, %x, %x]\n", (void *) this,
@ -146,6 +147,15 @@ class host_tensor {

    bool is_valid() const { return _device_tensor_handle != 0; }

    int64_t get_ne(size_t index) const {
        if (index >= DEVICE_TENSOR_MAX_DIMS) {
            LOG_ERROR("host_tensor(%p) get_ne: index out of bounds: %zu\n", (void *) this, index);
            return 0;
        }

        return _info.ne[index];
    }

  private:
    remote_handle64            _device_handle        = 0;
    npu_device_tensor_handle_t _device_tensor_handle = 0;

@ -6,7 +6,7 @@
#include "ggml-common.h"
#undef GGML_COMMON_DECL_CPP

static_assert(sizeof(npu_device_block_q4_K) == sizeof(block_q4_K), "npu_device_block_q4_K size mismatch");
static_assert(sizeof(npu_device_block_q4_k) == sizeof(block_q4_K), "npu_device_block_q4_k size mismatch");
static_assert(sizeof(npu_device_block_q4_0) == sizeof(block_q4_0), "npu_device_block_q4_0 size mismatch");
static_assert(sizeof(npu_device_block_q8_0) == sizeof(block_q8_0), "npu_device_block_q8_0 size mismatch");
static_assert(QUANT_K_SCALE_SIZE == K_SCALE_SIZE, "QUANT_K_SCALE_SIZE size mismatch");
@ -27,6 +27,8 @@ enum npu_device_tensor_op op_to_npu_op(ggml_op op) {
            return NPU_OP_MUL;
        case GGML_OP_RMS_NORM:
            return NPU_OP_RMS_NORM;
        case GGML_OP_FLASH_ATTN_EXT:
            return NPU_OP_FLASH_ATTN;
        default:
            return NPU_OP_COUNT;
    }
@ -44,6 +46,8 @@ const char * get_npu_op_desc(enum npu_device_tensor_op op) {
            return ggml_op_name(GGML_OP_MUL);
        case NPU_OP_RMS_NORM:
            return ggml_op_name(GGML_OP_RMS_NORM);
        case NPU_OP_FLASH_ATTN:
            return ggml_op_name(GGML_OP_FLASH_ATTN_EXT);
        default:
            return "UNKNOWN";
    }
@ -160,27 +164,65 @@ void get_op_tensor_desc(const ggml_tensor * dst, char * out, size_t max_len) {
        }
    };

    auto * src0 = dst->src[0];
    if (src0 == nullptr) {
        print_tensor(dst, out, max_len);
        return;
    }
    constexpr const auto get_src_tensor_count = [](const ggml_tensor * tensor) -> size_t {
        for (size_t i = 0; i < GGML_MAX_SRC; ++i) {
            if (!tensor->src[i]) {
                return i;
            }
        }

        return GGML_MAX_SRC;
    };

    char dst_desc[256];
    print_tensor(dst, dst_desc, sizeof(dst_desc));

    char src0_desc[256];
    print_tensor(src0, src0_desc, sizeof(src0_desc));

    auto * src1 = dst->src[1];
    if (src1 == nullptr) {
        snprintf(out, max_len, "dst: %s, src0: %s", dst_desc, src0_desc);
        return;
    switch (get_src_tensor_count(dst)) {
        case 4:
            {
                char src0_desc[256];
                print_tensor(dst->src[0], src0_desc, sizeof(src0_desc));
                char src1_desc[256];
                print_tensor(dst->src[1], src1_desc, sizeof(src1_desc));
                char src2_desc[256];
                print_tensor(dst->src[2], src2_desc, sizeof(src2_desc));
                char src3_desc[256];
                print_tensor(dst->src[3], src3_desc, sizeof(src3_desc));
                snprintf(out, max_len, "dst: %s, src0: %s, src1: %s, src2: %s, src3: %s", dst_desc, src0_desc,
                         src1_desc, src2_desc, src3_desc);
                return;
            }
        case 3:
            {
                char src0_desc[256];
                print_tensor(dst->src[0], src0_desc, sizeof(src0_desc));
                char src1_desc[256];
                print_tensor(dst->src[1], src1_desc, sizeof(src1_desc));
                char src2_desc[256];
                print_tensor(dst->src[2], src2_desc, sizeof(src2_desc));
                snprintf(out, max_len, "dst: %s, src0: %s, src1: %s, src2: %s", dst_desc, src0_desc, src1_desc,
                         src2_desc);
                return;
            }
        case 2:
            {
                char src0_desc[256];
                print_tensor(dst->src[0], src0_desc, sizeof(src0_desc));
                char src1_desc[256];
                print_tensor(dst->src[1], src1_desc, sizeof(src1_desc));
                snprintf(out, max_len, "dst: %s, src0: %s, src1: %s", dst_desc, src0_desc, src1_desc);
                return;
            }
        case 1:
            {
                char src0_desc[256];
                print_tensor(dst->src[0], src0_desc, sizeof(src0_desc));
                snprintf(out, max_len, "dst: %s, src0: %s", dst_desc, src0_desc);
                return;
            }
        default:
            snprintf(out, max_len, "dst: %s", dst_desc);
            return;
    }

    char src1_desc[256];
    print_tensor(src1, src1_desc, sizeof(src1_desc));
    snprintf(out, max_len, "dst: %s, src0: %s, src1: %s", dst_desc, src0_desc, src1_desc);
}

}  // namespace hexagon

@ -26,4 +26,6 @@ void enable_unsigned_dsp_module(common::rpc_interface_ptr rpc_interface, uint32_

void get_op_tensor_desc(const ggml_tensor * dst, char * out, size_t max_len);

constexpr const size_t kMaxNpuRpcStructSize = 100;  // TODO: figure out the actual size

}  // namespace hexagon

@ -3,7 +3,7 @@
#include "remote.idl"

const uint32_t DEVICE_TENSOR_MAX_DIMS = 4;
const uint32_t DEVICE_TENSOR_MAX_SRC = 2;
const uint32_t DEVICE_TENSOR_MAX_SRC = 4;
const uint32_t DEVICE_TENSOR_MAX_OP_PARAMS = 4;
const uint32_t QUANT_BLOCK_SIZE = 32;
const uint32_t QUANT_K_BLOCK_SIZE = 256;
@ -12,6 +12,7 @@ const uint32_t QUANT_K_SCALE_SIZE = 12;
interface npu_device : remote_handle64{

    typedef int64_t ne_type[DEVICE_TENSOR_MAX_DIMS];
    typedef uint64_t nb_type[DEVICE_TENSOR_MAX_DIMS];
    typedef uint64_t tensor_handle_t;
    typedef uint64_t graph_handle_t;

@ -22,7 +23,7 @@ interface npu_device : remote_handle64{
        uint8_t qs[QUANT_BLOCK_SIZE / 2];
    };

    struct block_q4_K {
    struct block_q4_k {
        fp16_t d;
        fp16_t dmin;
        uint8_t scales[QUANT_K_SCALE_SIZE];
@ -40,6 +41,7 @@ interface npu_device : remote_handle64{
        NPU_OP_SUB,
        NPU_OP_MUL,
        NPU_OP_RMS_NORM,
        NPU_OP_FLASH_ATTN,
        NPU_OP_COUNT
    };

@ -54,6 +56,7 @@ interface npu_device : remote_handle64{

    struct tensor_spec {
        ne_type ne;
        nb_type nb;
        tensor_data_type type;
    };

@ -65,12 +68,12 @@ interface npu_device : remote_handle64{

    struct tensor_config {
        ne_type ne;
        uint64_t nb[DEVICE_TENSOR_MAX_DIMS];
        nb_type nb;
        long buffer_fd;
        uint64_t offset;
        uint64_t size;
        tensor_data_type type;
        tensor_op op;
        boolean is_constant;
    };

    AEEResult device_get_alignment(
@ -78,10 +81,9 @@ interface npu_device : remote_handle64{
    );

    AEEResult device_support_op(
        in tensor_spec src0,
        in tensor_spec src1,
        in tensor_spec dst,
        in tensor_op op,
        in tensor_spec dst,
        in sequence<tensor_spec> srcs,
        rout boolean is_supported
    );

@ -20,6 +20,8 @@ if(GGML_QNN_ENABLE_HEXAGON_BACKEND)
    if(DEFINED ENV{QNN_SDK_PATH})
        set(HEXAGON_SDK_ROOT $ENV{HEXAGON_SDK_ROOT})
        message("found HEXAGON_SDK_ROOT, setting to ${HEXAGON_SDK_ROOT}")
    elseif(EXISTS ${HEXAGON_SDK_ROOT})
        message("HEXAGON_SDK_ROOT: ${HEXAGON_SDK_ROOT}")
    else()
        message(FATAL_ERROR "HEXAGON_SDK_ROOT not defined")
    endif()