// llama.cpp/ggml/src/ggml-qnn/npu/device/op/op_glu.cpp

#include "op_glu.hpp"
#include "type_traits.hpp"
#include "util.hpp"
namespace {
// Trait that pulls apart a GLU row-function pointer signature
// void(const T*, const T*, T*, size_t, Param):
//   type       — the element type T of the source/destination rows
//   param_type — Param with cv-qualifiers and references stripped
template <typename T> struct get_data_type {};

template <typename _TyData, typename _TyParam>
struct get_data_type<void (*)(const _TyData *, const _TyData *, _TyData *, size_t, _TyParam)> {
    using type       = _TyData;
    using param_type = std::remove_cv_t<std::remove_reference_t<_TyParam>>;
};
// Placeholder coefficient loader for row kernels that take no real
// coefficient; it exists purely so the glu_impl template machinery
// type-checks, and always yields a constant zero.
inline float dummy_load_coeff() {
    return 0.0f;
}
// expf wrapper that saturates to +inf once the result would overflow fp16.
// fp16 max is 65504, so any input at or above log(65504) ≈ 11.0898664 maps
// to infinity instead of relying on expf's own overflow behavior.
inline float expf_f16_guard_inf(float x) {
    constexpr float kF16ExpOverflow = 11.0898664f;  // log(65504)
    return (x < kF16ExpOverflow) ? std::expf(x) : std::numeric_limits<float>::infinity();
}
// Scalar SwiGLU row kernel for fp16: dst[i] = silu(src0[i]) * src1[i].
// The coeff parameter only exists to satisfy the shared row-function
// signature used by glu_impl; this scalar path ignores it.
// TODO: use simd version, for some input hexagon intrinsics will generate nan instead of inf.
inline void glu_vec_op_f16_f16(const __fp16 * src0, const __fp16 * src1, __fp16 * dst, size_t count, float coeff) {
    (void) coeff;
    // Fix: index in size_t — the previous uint32_t counter would wrap and
    // never terminate if count exceeded UINT32_MAX.
    for (size_t i = 0; i < count; ++i) {
        const float x = src0[i];
        const float g = src1[i];
        // SiLU computed in float for precision; expf is clamped to +inf so
        // that x / (1 + inf) -> 0 instead of producing NaN for large -x.
        dst[i] = (x / (1.0f + expf_f16_guard_inf(-x))) * g;
    }
}
// Vectorized SwiGLU row kernel for fp32: forwards to the shared HVX
// element-wise transform, passing the precomputed division coefficients.
inline void glu_vec_op_f32_f32(const float * src0, const float * src1, float * dst, size_t count,
                               hexagon::HVX_VectorPair_x4 coeff) {
    hexagon::vec::vec_trans_with_param_impl<float, hexagon::HVX_VectorPair_x4, hexagon::vec_swiglu_f32_f32>(
        src0, src1, dst, count, coeff);
}
// Shared GLU row-loop driver, parameterized by the per-row kernel
// (_GluRowFunc) and the loader for its coefficient argument (_CoeffLoadFunc).
// Slices the flattened row range across worker threads, prefetches the next
// row pair into L2 while computing the current one, and invokes the kernel
// once per output row. Returns false on invalid arguments or an unwritable
// destination; true otherwise (including the nothing-to-do cases).
template <auto _GluRowFunc, auto _CoeffLoadFunc>
bool glu_impl(hexagon::tensor * out, hexagon::compute_params * params) {
    using data_type  = typename get_data_type<decltype(_GluRowFunc)>::type;
    using param_type = typename get_data_type<decltype(_GluRowFunc)>::param_type;

    static_assert(DEVICE_TENSOR_MAX_DIMS == 4, "element_wise_op requires max dims 4");
    static_assert(std::is_same_v<param_type, decltype(_CoeffLoadFunc())>,
                  "GluRowFunc must have the same param type as CoeffLoadFunc");

    if (!out) {
        return false;
    }

    // Two layouts: with an explicit src1, it supplies the gate; without one,
    // each src0 row packs value+gate side by side and is split in half below.
    const bool has_src1 = out->get_src(1) != nullptr;
    auto *     src0     = out->get_src(0);
    auto *     src1     = has_src1 ? out->get_src(1) : src0;
    if (!src0 || !src1) {
        return true;  // skip if no src
    }

    const auto total_cols = has_src1 ? src0->get_ne(0) : src0->get_ne(0) / 2;
    if (out->get_ne(0) != total_cols) {
        DEVICE_LOG_ERROR(
            "[hexagon-npu][GLU]out.ne[0] (%ld) != total_cols (%d)\n", (long) out->get_ne(0), (int) total_cols);
        return false;
    }

    // Flatten dims 3/2/1 into a single row count, then take this thread's
    // contiguous slice of it.
    auto       total_rows    = out->get_ne(3) * out->get_ne(2) * out->get_ne(1);
    const auto rows_per_cube = out->get_ne(2) * out->get_ne(1);
    const auto start_end     = params->get_work_slice(total_rows);
    if (start_end.first >= start_end.second) {
        return true;
    }

    uint8_t * dst_ptr = out->get_write_buffer();
    if (!dst_ptr) {
        DEVICE_LOG_ERROR("[hexagon-npu][GLU]glu_impl: dst_ptr is not writable, tensor: %p, type: %s\n",
                         (void *) out,
                         hexagon::get_type_name(out->get_type()));
        return false;
    }

    // op_param[1] != 0 means the value/gate operands are swapped.
    const int32_t   swapped  = out->get_op_param<int32_t>(1);
    const uint8_t * src0_ptr = src0->get_read_buffer();
    // Single-source case: the second half of each src0 row serves as src1.
    const uint8_t * src1_ptr = has_src1 ? src1->get_read_buffer() : (src0_ptr + total_cols * sizeof(data_type));
    if (swapped) {
        std::swap(src0_ptr, src1_ptr);
    }

    DEVICE_SCOPED_OP_PERFORMANCE_TRACKER(out, params->get_thread_index());

    auto         coeff           = _CoeffLoadFunc();
    const size_t valid_row_bytes = src0->get_ne(0) * sizeof(data_type);
    for (int64_t ir = start_end.first; ir < start_end.second; ++ir) {
        // Decompose the flat row index into (i03, i02, i01) output coords.
        const auto i03 = ir / rows_per_cube;
        const auto i02 = ir / out->get_ne(1) - i03 * out->get_ne(2);
        const auto i01 = ir % out->get_ne(1);  // TODO: should we use divide instead of mod?

        // Modulo against src1 extents so a smaller gate tensor broadcasts.
        const auto i13 = i03 % src1->get_ne(3);
        const auto i12 = i02 % src1->get_ne(2);
        const auto i11 = i01 % src1->get_ne(1);

        auto * src1_plane = src1_ptr + i13 * src1->get_nb(3) + i12 * src1->get_nb(2);
        auto * src0_row   = src0_ptr + i03 * src0->get_nb(3) + i02 * src0->get_nb(2) + i01 * src0->get_nb(1);
        auto * src1_row   = src1_plane + i11 * src1->get_nb(1);
        auto * dst_row    = dst_ptr + i03 * out->get_nb(3) + i02 * out->get_nb(2) + i01 * out->get_nb(1);
        if (ir + 1 < start_end.second) {
            // Prefetch next iteration's rows into L2 to hide memory latency.
            hexagon::l2fetch_row(src0_row + src0->get_nb(1), valid_row_bytes);
            hexagon::l2fetch_row(src1_row + src1->get_nb(1), valid_row_bytes);
        }

        _GluRowFunc(reinterpret_cast<const data_type *>(src0_row),
                    reinterpret_cast<const data_type *>(src1_row),
                    reinterpret_cast<data_type *>(dst_row),
                    static_cast<size_t>(total_cols),
                    coeff);
    }

    out->release_write_buffer();  // mark the output tensor as modified
    return true;
}
// Validates the GLU op type (only SWIGLU is supported) and the tensor data
// type, then dispatches to the typed glu_impl instantiation.
// Returns false on any validation failure.
template <npu_device_tensor_data_type _DataType>
bool glu_compute(hexagon::tensor * out, hexagon::compute_params * params) {
    using namespace hexagon::vec::math;
    if (!out) {
        // Fix: guard before the get_op_param() calls below dereference out
        // (glu_impl's own null check would come too late).
        return false;
    }
    if (out->get_op_param<int32_t>(0) != NPU_GLU_OP_SWIGLU) {
        DEVICE_LOG_ERROR("Invalid GLU op type: %d\n", (int) out->get_op_param<int32_t>(0));
        return false;
    }
    if (out->get_type() != _DataType) {
        DEVICE_LOG_ERROR("GLU op type mismatch: %s vs %s\n",
                         hexagon::get_type_name(out->get_type()),
                         hexagon::get_type_name(_DataType));
        return false;
    }
    if constexpr (_DataType == NPU_DATA_TYPE_F32) {
        return glu_impl<glu_vec_op_f32_f32, qhmath_load_div_sf_ltu>(out, params);
    } else if constexpr (_DataType == NPU_DATA_TYPE_F16) {
        return glu_impl<glu_vec_op_f16_f16, dummy_load_coeff>(out, params);
    }
    // Unreachable for the instantiations above. Fix: report failure instead
    // of returning true after logging an error (consistent with other paths).
    DEVICE_LOG_ERROR("Unsupported GLU data type: %s\n", hexagon::get_type_name(out->get_type()));
    return false;
}
} // namespace
namespace hexagon {
// Public entry point for fp32 GLU: forwards to the typed dispatcher.
bool glu_f32(hexagon::tensor * out, hexagon::compute_params * params) {
    constexpr auto kType = npu_device_tensor_data_type::NPU_DATA_TYPE_F32;
    return glu_compute<kType>(out, params);
}
// Public entry point for fp16 GLU: forwards to the typed dispatcher.
bool glu_f16(hexagon::tensor * out, hexagon::compute_params * params) {
    constexpr auto kType = npu_device_tensor_data_type::NPU_DATA_TYPE_F16;
    return glu_compute<kType>(out, params);
}
// Checks whether the requested GLU op can run on this device.
// Supported: SWIGLU only, F32/F16 data, with either two same-shape sources
// (explicit gate) or a single source whose dim 0 is twice the output's.
// Returns false (with a debug log) on the first violated constraint.
bool is_glu_op_supported(const npu_device_tensor_op_spec * op_spec,
                         const npu_device_tensor_spec *    dst,
                         const npu_device_tensor_spec *    srcs,
                         size_t                            src_len) {
    if (!op_spec) {
        // Fix: op_spec was previously dereferenced unconditionally even
        // though dst/srcs are null-checked below; guard it the same way.
        DEVICE_LOG_DEBUG("[GLU]invalid op_spec\n");
        return false;
    }

    const auto op = op_spec->op;
    if (op != NPU_OP_GLU) {
        DEVICE_LOG_DEBUG("[%s]unsupported\n", hexagon::op_get_name(op));
        return false;
    }

    if (op_spec->params[0] != NPU_GLU_OP_SWIGLU) {
        DEVICE_LOG_DEBUG("[%s]unsupported GLU op type: %d\n", hexagon::op_get_name(op), (int) op_spec->params[0]);
        return false;
    }

    if (!dst || !srcs || src_len < 1) {
        DEVICE_LOG_DEBUG("[%s]invalid dst or srcs\n", hexagon::op_get_name(op));
        return false;
    }

    const auto & src0 = srcs[0];
    if (dst->type != src0.type) {
        DEVICE_LOG_DEBUG("[%s]src0.type and dst.type mismatch: %s vs %s\n",
                         hexagon::op_get_name(op),
                         hexagon::get_type_name(src0.type),
                         hexagon::get_type_name(dst->type));
        return false;
    }

    if (dst->type != NPU_DATA_TYPE_F32 && dst->type != NPU_DATA_TYPE_F16) {
        DEVICE_LOG_DEBUG(
            "[%s]unsupported data type: %s\n", hexagon::op_get_name(op), hexagon::get_type_name(dst->type));
        return false;
    }

    if (src_len > 1) {
        // Two-source GLU: the gate tensor must match dst (and src0) exactly.
        if (!hexagon::is_same_shape(src0, *dst) || !hexagon::is_same_shape(srcs[1], *dst)) {
            DEVICE_LOG_DEBUG("[%s]src0, src1 and dst have different shape\n", hexagon::op_get_name(op));
            return false;  // src0 and src1 have the same shape as dst
        }
    } else {
        // Single-source GLU: src0 rows pack value+gate, so dim 0 is doubled.
        static_assert(DEVICE_TENSOR_MAX_DIMS == 4, "GLU requires max dims 4");
        if (src0.ne[0] / 2 != dst->ne[0] || src0.ne[1] != dst->ne[1] || src0.ne[2] != dst->ne[2] ||
            src0.ne[3] != dst->ne[3]) {
            DEVICE_LOG_DEBUG("[%s]src0 and dst have different shape: src0.ne[0]: %ld, dst.ne[0]: %ld\n",
                             hexagon::op_get_name(op),
                             (long) src0.ne[0],
                             (long) dst->ne[0]);
            return false;
        }
    }

    return true;
}
// A sync is required only when the GLU output feeds a matmul; the current
// op itself does not matter.
bool is_glu_required_sync(const npu_device_tensor_op op, const npu_device_tensor_op next_op) {
    NPU_UNUSED(op);
    return NPU_OP_MUL_MAT == next_op;
}
} // namespace hexagon