// (stray file-viewer metadata captured during extraction: 1512 lines, 59 KiB, C++)
// SPDX-FileCopyrightText: Copyright 2025-2026 Arm Limited and/or its affiliates <open-source-office@arm.com>
|
|
// SPDX-License-Identifier: MIT
|
|
//
|
|
#include <arm_neon.h>

#include <assert.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>

#include <algorithm>
#include <array>
#include <atomic>
#include <cerrno>
#include <cfloat>
#include <climits>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <set>
#include <stdexcept>
#include <string>
#include <vector>

#if defined(__linux__)
#include <asm/hwcap.h>
#include <sys/auxv.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <unistd.h>
#elif defined(__APPLE__)
#include <string_view>
#include <sys/sysctl.h>
#include <sys/types.h>
#elif defined(_WIN32)
#include <windows.h>
#include <excpt.h>
#endif

#include "kleidiai.h"

#include "ggml-cpu.h"
#include "ggml-impl.h"
#include "ggml-backend-impl.h"
#include "ggml-threading.h"
#include "traits.h"

#include "kernels.h"

#include "kai_common.h"

#define GGML_COMMON_DECL_CPP
#include "ggml-common.h"
|
|
|
|
// Maximum number of kernel variants whose packed weights can coexist for one
// tensor (e.g. an SME primary plus a non-SME fallback).
static constexpr int GGML_KLEIDIAI_MAX_KERNEL_SLOTS = 2;
// Magic number and version tag marking a multi-slot packed-weight buffer.
static constexpr uint32_t GGML_KLEIDIAI_PACK_MAGIC = 0x4b4c4149; // "KLAI"
static constexpr uint16_t GGML_KLEIDIAI_PACK_VERSION = 1;
// Byte alignment applied between packed regions in weight/scratch buffers.
static constexpr size_t GGML_KLEIDIAI_PACK_ALIGN = 64;
|
|
|
|
// Process-wide backend state, filled in exactly once by init_kleidiai_context().
struct ggml_kleidiai_context {
    cpu_feature features;               // detected + policy-enabled CPU feature mask
    ggml_kleidiai_kernels * kernels_q4; // preferred Q4_0 kernel table (may be null)
    ggml_kleidiai_kernels * kernels_q8; // preferred Q8_0 kernel table (may be null)
    int sme_thread_cap; // <= 0 means “SME disabled/unknown”;
    int thread_hint; // <= 0 means “no hint”
} static ctx = { CPU_FEATURE_NONE, nullptr, nullptr, 0, -1 };
|
|
|
|
// Returns a human-readable name for the highest-priority feature bit present
// in `f` (SME > SVE > I8MM > DOTPROD), "NONE" for an empty mask, and
// "UNKNOWN" for any other combination.
static const char* cpu_feature_to_string(cpu_feature f) {
    if (f == CPU_FEATURE_NONE) {
        return "NONE";
    }

    struct feature_name { cpu_feature bit; const char *name; };
    static const feature_name priority[] = {
        { CPU_FEATURE_SME,     "SME"     },
        { CPU_FEATURE_SVE,     "SVE"     },
        { CPU_FEATURE_I8MM,    "I8MM"    },
        { CPU_FEATURE_DOTPROD, "DOTPROD" },
    };

    for (const auto & entry : priority) {
        if ((f & entry.bit) == entry.bit) {
            return entry.name;
        }
    }
    return "UNKNOWN";
}
|
|
|
|
// Best-effort count of Streaming Mode Compute Units (SMCUs) available to this
// process; returns 0 when the CPU reports no SME support.
// - Linux/aarch64: reads SMIDR_EL1 per CPU via sysfs, counting private SMCUs
//   individually and deduplicating shared SMCUs by an affinity-like id.
// - Apple/aarch64: matches known Apple-silicon brand strings (M4 family).
// - Elsewhere: conservatively assumes a single SMCU.
static size_t detect_num_smcus() {
    if (!ggml_cpu_has_sme()) {
        return 0;
    }

#if defined(__linux__) && defined(__aarch64__)
    // Linux/aarch64: Best-effort count of Streaming Mode Compute Units (SMCUs) via SMIDR_EL1 sysfs.
    size_t num_private = 0;
    std::set<uint32_t> shared_ids; // deduplicates shared SMCUs across CPUs

    // Walk cpu0, cpu1, ... until the sysfs register file stops existing.
    for (size_t cpu = 0;; ++cpu) {
        const std::string path =
            "/sys/devices/system/cpu/cpu" + std::to_string(cpu) +
            "/regs/identification/smidr_el1";

        std::ifstream file(path);
        if (!file.is_open()) {
            break; // no more CPUs exposing SMIDR_EL1
        }

        uint64_t smidr = 0;
        if (!(file >> std::hex >> smidr)) {
            continue; // unreadable entry; skip this CPU
        }

        // Arm ARM: SMIDR_EL1
        // Bits [14:13] — presumably the SMCU sharing indication; verify against the Arm ARM.
        const uint32_t sh = (uint32_t)((smidr >> 13) & 0x3);
        // Build an "affinity-like" identifier for shared SMCUs.
        // Keep the original packing logic, but isolate it here.
        const uint32_t id = (uint32_t)((smidr & 0xFFFu) | ((smidr >> 20) & 0xFFFFF000u));

        switch (sh) {
            case 0b10: // private SMCU
                ++num_private;
                break;
            case 0b11: // shared SMCU
                shared_ids.emplace(id);
                break;
            case 0b00:
                // Ambiguous / implementation-defined. Be conservative:
                // treat id==0 as private, otherwise as shared.
                if (id == 0) ++num_private;
                else shared_ids.emplace(id);
                break;
            default:
                break;
        }
    }

    return num_private + shared_ids.size();

#elif defined(__APPLE__) && defined(__aarch64__)
    // table for known M4 variants. Users can override via GGML_KLEIDIAI_SME=<n>.
    char chip_name[256] = {};
    size_t size = sizeof(chip_name);

    if (sysctlbyname("machdep.cpu.brand_string", chip_name, &size, nullptr, 0) == 0) {
        const std::string brand(chip_name);

        struct ModelSMCU { const char *match; size_t smcus; };
        // More specific names first so a plain "M4" match does not shadow
        // "M4 Pro"/"M4 Max"/"M4 Ultra".
        static const ModelSMCU table[] = {
            { "M4 Ultra", 2 },
            { "M4 Max", 2 },
            { "M4 Pro", 2 },
            { "M4", 1 },
        };

        for (const auto &e : table) {
            if (brand.find(e.match) != std::string::npos) {
                return e.smcus;
            }
        }
    }
    return 1; // unknown Apple silicon with SME: assume one SMCU

#else
    return 1; // SME present but no platform-specific detection available
#endif
}
|
|
|
|
static int parse_uint_env(const char *s, const char *name, bool *ok) {
|
|
if (!s) { *ok = false; return 0; }
|
|
char *end = nullptr;
|
|
long v = strtol(s, &end, 10);
|
|
if (end == s || *end != '\0') {
|
|
GGML_LOG_WARN("kleidiai: invalid %s='%s' (expected integer)\n", name, s);
|
|
*ok = false;
|
|
return 0;
|
|
}
|
|
if (v < 0 || v > INT_MAX) {
|
|
GGML_LOG_WARN("kleidiai: out-of-range %s='%s'\n", name, s);
|
|
*ok = false;
|
|
return 0;
|
|
}
|
|
*ok = true;
|
|
return (int)v;
|
|
}
|
|
|
|
// One-time initialization of the global kleidiai context (ctx): detects CPU
// features, applies the GGML_KLEIDIAI_SME / GGML_TOTAL_THREADS environment
// overrides, and selects the q4/q8 kernel tables. Guarded by the ggml
// critical section so concurrent first callers are safe.
static void init_kleidiai_context(void) {
    ggml_critical_section_start();
    static bool initialized = false;

    if (!initialized) {
        initialized = true;

        const char *env_sme = getenv("GGML_KLEIDIAI_SME");
        const char *env_threads = getenv("GGML_TOTAL_THREADS");

        const bool cpu_has_sme = ggml_cpu_has_sme();
        size_t detected_smcus = 0;

        // Base feature mask; SME is only OR'ed in later, per the policy below.
        // SVE is used only when the vector length matches QK8_0.
        ctx.features = (ggml_cpu_has_dotprod() ? CPU_FEATURE_DOTPROD : CPU_FEATURE_NONE) |
                       (ggml_cpu_has_matmul_int8() ? CPU_FEATURE_I8MM : CPU_FEATURE_NONE) |
                       ((ggml_cpu_has_sve() && ggml_cpu_get_sve_cnt() == QK8_0) ? CPU_FEATURE_SVE : CPU_FEATURE_NONE);

        // Optional total-thread hint, consulted when deciding whether to keep
        // a non-SME fallback packing (see kleidiai_pack_fallback_allowed()).
        if (env_threads) {
            bool ok = false;
            int hint = parse_uint_env(env_threads, "GGML_TOTAL_THREADS", &ok);
            if (ok && hint > 0) {
                ctx.thread_hint = hint;
            }
        }

        // SME policy:
        // - If CPU doesn't support SME: SME always off.
        // - Else:
        //   - env unset => auto-detect cores; enable if detected > 0.
        //   - env=0 => force off.
        //   - env>0 => force N cores (skip detection).
        int sme_cores = 0;
        bool sme_env_ok = false;
        bool sme_env_set = (env_sme != nullptr);

        if (!cpu_has_sme) {
            // Warn only when the user explicitly requested SME on hardware
            // that does not have it.
            if (sme_env_set) {
                bool ok = false;
                int req = parse_uint_env(env_sme, "GGML_KLEIDIAI_SME", &ok);
                if (ok && req > 0) {
                    GGML_LOG_WARN("kleidiai: GGML_KLEIDIAI_SME=%d but SME is not supported on this CPU; disabling SME\n", req);
                }
            }
            sme_cores = 0;
        } else {
            if (sme_env_set) {
                bool ok = false;
                int v = parse_uint_env(env_sme, "GGML_KLEIDIAI_SME", &ok);
                sme_env_ok = ok;

                if (!ok) {
                    // Unparsable override: behave as if the variable were unset.
                    GGML_LOG_WARN("kleidiai: GGML_KLEIDIAI_SME set but parsing failed; falling back to runtime SME-core detection\n");
                    detected_smcus = detect_num_smcus();
                    sme_cores = detected_smcus > 0 ? (int)detected_smcus : 0;
                } else if (v == 0) {
                    sme_cores = 0; // explicit opt-out
                } else {
                    sme_cores = v; // explicit core-count override
                }
            } else {
                detected_smcus = detect_num_smcus();
                sme_cores = detected_smcus > 0 ? (int)detected_smcus : 0;
            }

            if (!sme_env_set && sme_cores == 0) {
                GGML_LOG_WARN("kleidiai: SME supported but runtime SME-core detection returned 0; falling back to NEON\n");
            }

            if (sme_cores > 0) {
                ctx.features |= CPU_FEATURE_SME;
            }
        }

        // Kernel selection
        ctx.kernels_q4 = ggml_kleidiai_select_kernels_q4_0(ctx.features);
        ctx.kernels_q8 = ggml_kleidiai_select_kernels_q8_0(ctx.features);

        if (!ctx.kernels_q4) {
            GGML_LOG_INFO("kleidiai: no compatible q4 kernels found for CPU features mask %d\n", (int)ctx.features);
        } else {
            GGML_LOG_INFO("kleidiai: primary q4 kernel feature %s\n", cpu_feature_to_string(ctx.kernels_q4->required_cpu));
        }

        if (!ctx.kernels_q8) {
            GGML_LOG_INFO("kleidiai: no compatible q8 kernels found for CPU features mask %d\n", (int)ctx.features);
        } else {
            GGML_LOG_INFO("kleidiai: primary q8 kernel feature %s\n", cpu_feature_to_string(ctx.kernels_q8->required_cpu));
        }

        // Expose a positive SME thread cap only when SME actually got enabled.
        ctx.sme_thread_cap = (ctx.features & CPU_FEATURE_SME) ? sme_cores : 0;

        if (ctx.features & CPU_FEATURE_SME) {
            if (sme_env_set && sme_env_ok && sme_cores > 0) {
                GGML_LOG_INFO("kleidiai: SME enabled (GGML_KLEIDIAI_SME=%d override)\n", sme_cores);
            } else {
                GGML_LOG_INFO("kleidiai: SME enabled (runtime-detected SME cores=%d)\n", sme_cores);
            }
        } else {
            GGML_LOG_INFO("kleidiai: SME disabled\n");
        }
    }

    ggml_critical_section_end();
}
|
|
|
|
static inline int kleidiai_sme_thread_cap() {
|
|
return ctx.sme_thread_cap;
|
|
}
|
|
|
|
// Rounds `value` up to the next multiple of `alignment`.
// An alignment of 0 means "no alignment" and returns `value` unchanged.
static inline size_t align_up(size_t value, size_t alignment) {
    if (alignment == 0) {
        return value;
    }
    const size_t misalign = value % alignment;
    if (misalign != 0) {
        value += alignment - misalign;
    }
    return value;
}
|
|
|
|
static inline bool kleidiai_pack_fallback_allowed() {
|
|
if (ctx.sme_thread_cap <= 0) {
|
|
return false;
|
|
}
|
|
if (ctx.thread_hint <= 0) {
|
|
return true;
|
|
}
|
|
return ctx.thread_hint > ctx.sme_thread_cap;
|
|
}
|
|
|
|
// Header at the start of a multi-slot packed-weight buffer. Each slot holds
// the same weights packed for a different kernel variant (e.g. slot 0 for
// the SME layout, slot 1 for a non-SME fallback layout).
struct kleidiai_weight_header {
    uint32_t magic;      // GGML_KLEIDIAI_PACK_MAGIC when valid
    uint16_t version;    // GGML_KLEIDIAI_PACK_VERSION
    uint16_t slot_count; // populated slots, 1..GGML_KLEIDIAI_MAX_KERNEL_SLOTS
    uint64_t offsets[GGML_KLEIDIAI_MAX_KERNEL_SLOTS]; // byte offset of each slot from the header start
    uint64_t sizes[GGML_KLEIDIAI_MAX_KERNEL_SLOTS];   // byte size of each slot
};
|
|
|
|
// Reinterprets a raw weight-buffer pointer as a mutable pack header.
static inline kleidiai_weight_header * kleidiai_weight_header_from_ptr(void * data) {
    auto * header = reinterpret_cast<kleidiai_weight_header *>(data);
    return header;
}
|
|
|
|
// Reinterprets a raw weight-buffer pointer as a read-only pack header.
static inline const kleidiai_weight_header * kleidiai_weight_header_from_ptr(const void * data) {
    const auto * header = reinterpret_cast<const kleidiai_weight_header *>(data);
    return header;
}
|
|
|
|
static inline bool kleidiai_is_weight_header_valid(const kleidiai_weight_header * header) {
|
|
if (!header) {
|
|
return false;
|
|
}
|
|
if (header->magic != GGML_KLEIDIAI_PACK_MAGIC || header->version != GGML_KLEIDIAI_PACK_VERSION) {
|
|
return false;
|
|
}
|
|
if (header->slot_count == 0 || header->slot_count > GGML_KLEIDIAI_MAX_KERNEL_SLOTS) {
|
|
return false;
|
|
}
|
|
return true;
|
|
}
|
|
|
|
// Returns a mutable pointer to the packed data of `slot`, or nullptr when
// the header is invalid or the slot index is out of range.
static inline uint8_t * kleidiai_weight_slot_ptr(kleidiai_weight_header * header, int slot) {
    const bool usable = kleidiai_is_weight_header_valid(header) &&
                        slot >= 0 && slot < header->slot_count;
    if (!usable) {
        return nullptr;
    }
    return reinterpret_cast<uint8_t *>(header) + header->offsets[slot];
}
|
|
|
|
static inline const uint8_t * kleidiai_weight_slot_ptr(const kleidiai_weight_header * header, int slot) {
|
|
if (!kleidiai_is_weight_header_valid(header)) {
|
|
return nullptr;
|
|
}
|
|
if (slot < 0 || slot >= header->slot_count) {
|
|
return nullptr;
|
|
}
|
|
return reinterpret_cast<const uint8_t *>(header) + header->offsets[slot];
|
|
}
|
|
|
|
// Returns the globally selected Q4_0 kernel table (may be nullptr).
static inline ggml_kleidiai_kernels * kleidiai_primary_kernel_q4() {
    ggml_kleidiai_kernels * kernels = ctx.kernels_q4;
    return kernels;
}
|
|
|
|
// Returns the globally selected Q8_0 kernel table (may be nullptr).
static inline ggml_kleidiai_kernels * kleidiai_primary_kernel_q8() {
    ggml_kleidiai_kernels * kernels = ctx.kernels_q8;
    return kernels;
}
|
|
|
|
template <typename SelectFallback>
|
|
static int kleidiai_collect_kernel_chain_common(
|
|
ggml_kleidiai_kernels * primary,
|
|
cpu_feature features,
|
|
std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> & out,
|
|
SelectFallback select_fallback) {
|
|
int count = 0;
|
|
if (!primary) {
|
|
return 0;
|
|
}
|
|
out[count++] = primary;
|
|
|
|
if ((primary->required_cpu & CPU_FEATURE_SME) == CPU_FEATURE_SME) {
|
|
const cpu_feature fallback_mask = static_cast<cpu_feature>(features & ~CPU_FEATURE_SME);
|
|
if (fallback_mask != CPU_FEATURE_NONE) {
|
|
ggml_kleidiai_kernels * fallback = select_fallback(fallback_mask);
|
|
if (fallback && fallback != primary &&
|
|
fallback->lhs_type == primary->lhs_type &&
|
|
fallback->rhs_type == primary->rhs_type &&
|
|
fallback->op_type == primary->op_type) {
|
|
out[count++] = fallback;
|
|
}
|
|
}
|
|
}
|
|
|
|
return count;
|
|
}
|
|
|
|
static int kleidiai_collect_kernel_chain(const struct ggml_tensor * op,
|
|
std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> & out) {
|
|
ggml_kleidiai_kernels * primary = ggml_kleidiai_select_kernels(ctx.features, op);
|
|
return kleidiai_collect_kernel_chain_common(primary, ctx.features, out,
|
|
[&](cpu_feature mask) { return ggml_kleidiai_select_kernels(mask, op); });
|
|
}
|
|
|
|
// Kernel chain for Q4_0 weights (primary plus optional non-SME fallback).
static int kleidiai_collect_q4_chain(std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> & out) {
    auto select = [&](cpu_feature mask) { return ggml_kleidiai_select_kernels_q4_0(mask); };
    ggml_kleidiai_kernels * primary = kleidiai_primary_kernel_q4();
    return kleidiai_collect_kernel_chain_common(primary, ctx.features, out, select);
}
|
|
|
|
// Kernel chain for Q8_0 weights (primary plus optional non-SME fallback).
static int kleidiai_collect_q8_chain(std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> & out) {
    auto select = [&](cpu_feature mask) { return ggml_kleidiai_select_kernels_q8_0(mask); };
    ggml_kleidiai_kernels * primary = kleidiai_primary_kernel_q8();
    return kleidiai_collect_kernel_chain_common(primary, ctx.features, out, select);
}
|
|
|
|
static inline int64_t ggml_ne(const ggml_tensor * tensor, int dim) {
|
|
GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS);
|
|
return tensor->ne[dim];
|
|
}
|
|
|
|
namespace ggml::cpu::kleidiai {
|
|
|
|
// Rounds `x` down to a multiple of `y`; `y == 0` leaves `x` unchanged.
static size_t round_down(size_t x, size_t y) {
    if (y == 0) {
        return x;
    }
    return (x / y) * y;
}
|
|
|
|
static void transpose_f32kxn_f16nxk(size_t n, size_t k, float * dst, const uint16_t * src, size_t rhs_stride) {
|
|
size_t src_stride = rhs_stride / sizeof(uint16_t);
|
|
size_t dst_stride = n;
|
|
|
|
for (size_t k_idx = 0; k_idx < k; ++k_idx) {
|
|
for (size_t n_idx = 0; n_idx < n; ++n_idx) {
|
|
uint16_t v = *(src + k_idx + n_idx * src_stride);
|
|
*(dst + n_idx + k_idx * dst_stride) = kai_cast_f32_f16(v);
|
|
}
|
|
}
|
|
}
|
|
|
|
class tensor_traits : public ggml::cpu::tensor_traits {
|
|
    // Reports the scratch (wdata) size required to run `op`. The layout must
    // mirror the compute paths: one GGML_KLEIDIAI_PACK_ALIGN-aligned LHS
    // packing region per kernel slot, plus — for F16 — per-slot RHS packing
    // regions, a transposed k x n f32 copy of the RHS, and an n-float bias.
    bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
        if (op->op != GGML_OP_MUL_MAT) {
            return false;
        }

        std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;
        const int slot_count = kleidiai_collect_kernel_chain(op, kernel_chain);
        if (slot_count == 0) {
            return false;
        }

        // GEMV path when the activation matrix has a single row.
        const bool is_gemv = op->src[1]->ne[1] == 1;
        const size_t k = op->src[0]->ne[0];
        const size_t n = op->src[0]->ne[1];
        const size_t m = op->src[1]->ne[1];

        if (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q8_0) {
            // Block size of the quantization format.
            const size_t qk = (op->src[0]->type == GGML_TYPE_Q4_0) ? QK4_0 : QK8_0;

            size_t cursor = 0;
            bool any_slot = false;

            // One aligned LHS-packing region per kernel slot in the chain.
            for (int slot = 0; slot < slot_count; ++slot) {
                ggml_kleidiai_kernels * kernels = kernel_chain[slot];
                lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
                kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;

                if (!lhs_info || !lhs_info->packed_size_ex || !kernel) {
                    return false;
                }

                const size_t mr = kernel->get_mr();
                const size_t kr = kernel->get_kr();
                const size_t sr = kernel->get_sr();

                const size_t packed = lhs_info->packed_size_ex(m, k, qk, mr, kr, sr);

                cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
                cursor += packed;
                any_slot = true;
            }

            if (!any_slot) {
                return false;
            }

            size = cursor;
            return true;
        }

        if (op->src[0]->type == GGML_TYPE_F16) {
            // Broadcast ratio of LHS batches over RHS batches.
            const int64_t lhs_batch_size0 = op->src[1]->ne[2];
            const int64_t rhs_batch_size0 = op->src[0]->ne[2];
            GGML_ASSERT(rhs_batch_size0 > 0);
            const int64_t r = lhs_batch_size0 / rhs_batch_size0;

            size_t cursor = 0;
            bool any_slot = false;

            // Per-slot LHS packing regions (sized for all broadcast rows).
            for (int slot = 0; slot < slot_count; ++slot) {
                ggml_kleidiai_kernels * kernels = kernel_chain[slot];
                lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
                kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
                if (!lhs_info || !lhs_info->packed_size_ex || !kernels->rhs_info.packed_size_ex || !kernel) {
                    return false;
                }

                const size_t mr = kernel->get_mr();
                const size_t kr = kernel->get_kr();
                const size_t sr = kernel->get_sr();

                cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
                cursor += lhs_info->packed_size_ex(m * r, k, 0, mr, kr, sr);
                any_slot = true;
            }

            // Per-slot RHS packing regions.
            for (int slot = 0; slot < slot_count; ++slot) {
                ggml_kleidiai_kernels * kernels = kernel_chain[slot];
                kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
                if (!kernel || !kernels->rhs_info.packed_size_ex) {
                    return false;
                }
                cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
                cursor += kernels->rhs_info.packed_size_ex(n, k, kernel->get_nr(), kernel->get_kr(), 0);
            }

            // Transposed k x n f32 copy of the RHS.
            cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
            cursor += k * n * sizeof(float);
            // Zero-filled bias vector of n floats.
            cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
            cursor += n * sizeof(float);

            if (!any_slot) {
                return false;
            }

            size = cursor;
            return true;
        }

        return false;
    }
|
|
|
|
bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * dst) override {
|
|
if (dst->op == GGML_OP_MUL_MAT) {
|
|
if (dst->src[0]->type == GGML_TYPE_Q4_0 || dst->src[0]->type == GGML_TYPE_Q8_0) {
|
|
return compute_forward_qx(params, dst);
|
|
} else if (dst->src[0]->type == GGML_TYPE_F16) {
|
|
return compute_forward_fp16(params, dst);
|
|
}
|
|
} else if (dst->op == GGML_OP_GET_ROWS) {
|
|
if (dst->src[0]->type == GGML_TYPE_Q4_0 || dst->src[0]->type == GGML_TYPE_Q8_0) {
|
|
return compute_forward_get_rows(params, dst);
|
|
}
|
|
}
|
|
return false;
|
|
}
|
|
|
|
    // F16 x F32 matmul. Per batch: packs the LHS in parallel over rows,
    // packs the RHS (transposed to a k x n f32 matrix, with a zero bias) on
    // thread 0, synchronizes, then runs the matmul partitioned over output
    // columns. Scratch layout: [lhs_packed][rhs_packed][rhs_kxn][bias].
    bool compute_forward_fp16(ggml_compute_params * params, struct ggml_tensor * dst) {
        const ggml_tensor * src0 = dst->src[0];
        const ggml_tensor * src1 = dst->src[1];

        GGML_TENSOR_BINARY_OP_LOCALS

        ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
        if (!kernels) {
            return false;
        }

        const bool is_gemv = src1->ne[1] == 1;
        kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
        lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
        GGML_ASSERT(kernel);
        if (!kernels->rhs_info.pack_func_ex ||
            !kernel->get_lhs_offset_ex || !kernel->get_rhs_packed_offset_ex || !kernel->run_kernel_ex) {
            return false;
        }

        const int nth = params->nth;
        const int ith = params->ith;

        const int64_t lhs_batch_size0 = ne12;
        const int64_t rhs_batch_size0 = ne02;
        const int64_t batch_size = lhs_batch_size0;

        GGML_ASSERT(rhs_batch_size0 > 0);
        GGML_ASSERT(lhs_batch_size0 % rhs_batch_size0 == 0);
        const int64_t r = lhs_batch_size0 / rhs_batch_size0; // LHS batches per RHS batch

        const int64_t m_group = ne11;
        const int64_t m = m_group;
        const int64_t n = ne01;
        const int64_t k = ne00;

        const size_t lhs_stride = src1->nb[1];
        const size_t rhs_stride = src0->nb[1];
        const size_t dst_stride = dst->nb[1];

        const int64_t mr = (int64_t) kernel->get_mr();
        const int64_t nr = (int64_t) kernel->get_nr();
        const int64_t kr = (int64_t) kernel->get_kr();
        const int64_t sr = (int64_t) kernel->get_sr();

        // Scratch sub-buffer sizes.
        // NOTE(review): work_size() aligns each region to
        // GGML_KLEIDIAI_PACK_ALIGN while the offsets below do not, so the sum
        // here is <= the reserved size — confirm the kernels tolerate
        // unaligned region starts.
        const size_t lhs_packed_size = lhs_info->packed_size_ex(m, k, 0, mr, kr, sr);
        const size_t rhs_packed_size = kernels->rhs_info.packed_size_ex(n, k, nr, kr, 0);
        const size_t kxn_size = k * n * sizeof(float);
        const size_t bias_size = n * sizeof(float);

        const size_t wsize_required = lhs_packed_size + rhs_packed_size + kxn_size + bias_size;
        GGML_ASSERT(wsize_required <= params->wsize);

        uint8_t * lhs_packed = static_cast<uint8_t *>(params->wdata);
        uint8_t * rhs_packed = lhs_packed + lhs_packed_size;
        uint8_t * rhs_kxn = rhs_packed + rhs_packed_size;
        uint8_t * bias = rhs_kxn + kxn_size;

        for (int64_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
            const int64_t rhs_batch_idx = batch_idx / r;
            const uint8_t * rhs_batch_base = static_cast<const uint8_t *>(src0->data) + rhs_batch_idx * src0->nb[2];
            uint8_t * dst_batch_base = static_cast<uint8_t *>(dst->data) + batch_idx * dst->nb[2];

            // LHS packing (threaded over m, honoring mr alignment and KV groups)
            {
                const int64_t m_roundup_mr = kai_roundup(m, mr);
                const int64_t num_threads = KAI_MIN(m_roundup_mr / mr, nth);

                if (ith < num_threads) {
                    // Rows per thread rounded down to a multiple of mr; the
                    // last thread absorbs the remainder.
                    const int64_t num_m_per_thread0 = round_down((size_t)(m_roundup_mr / num_threads), (size_t)mr);
                    const int64_t num_m_per_threadN_1 = m - (num_threads - 1) * num_m_per_thread0;

                    const int64_t m_start = ith * num_m_per_thread0;
                    const int64_t m_count = (ith == num_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0;

                    // Base packed offset (aligned) and per-row stride in bytes
                    const size_t base_packed_off = lhs_info->get_packed_offset_ex(m_start, k, 0, mr, kr, sr);
                    const size_t next_block_off = lhs_info->get_packed_offset_ex(m_start + mr, k, 0, mr, kr, sr);
                    const size_t row_stride_bytes = (next_block_off - base_packed_off) / (size_t)mr;

                    int64_t remaining = m_count;
                    int64_t cur = m_start;

                    while (remaining > 0) {
                        const int64_t row_in_group = cur;
                        const int64_t avail = m_group - row_in_group;
                        const int64_t take = std::min(avail, remaining);

                        const uint8_t * lhs_batch_base = static_cast<const uint8_t *>(src1->data) + batch_idx * src1->nb[2];
                        const void * src_ptr = lhs_batch_base + (size_t)row_in_group * lhs_stride;
                        const size_t dst_off = base_packed_off + (size_t)(cur - m_start) * row_stride_bytes;
                        void * dst_ptr = lhs_packed + dst_off;

                        lhs_info->pack_func_ex(take, k, 0, mr, kr, sr, 0, src_ptr, lhs_stride, dst_ptr);

                        cur += take;
                        remaining -= take;
                    }
                }
            }

            // RHS packing (single thread), then synchronize
            if (ith == 0) {
                memset(bias, 0, (size_t)n * sizeof(float));
                // Widen the n x k f16 RHS into a dense k x n f32 matrix first.
                transpose_f32kxn_f16nxk((size_t)n, (size_t)k,
                                        reinterpret_cast<float *>(rhs_kxn),
                                        reinterpret_cast<const uint16_t *>(rhs_batch_base),
                                        rhs_stride);

                kernels->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, 0, n * sizeof(float),
                                               rhs_kxn, bias, nullptr, rhs_packed, 0, nullptr);
            }

            // All threads wait for the packed LHS/RHS before the matmul.
            ggml_barrier(params->threadpool);

            // Matmul (threaded over n)
            {
                const int64_t n_step = (int64_t) kernel->get_n_step();
                int64_t num_threads_n = KAI_MIN(n / n_step, nth);
                if (num_threads_n <= 0) {
                    num_threads_n = 1;
                }

                if (ith < num_threads_n) {
                    const int64_t num_n_per_thread0 = round_down((size_t)(n / num_threads_n), (size_t)n_step);
                    const int64_t num_n_per_threadN_1 = n - (num_threads_n - 1) * num_n_per_thread0;

                    const int64_t n_start = ith * num_n_per_thread0;
                    const int64_t n_to_process = (ith == num_threads_n - 1) ? num_n_per_threadN_1 : num_n_per_thread0;

                    // LHS packed base at row 0 (consistent with packing above)
                    const size_t lhs_packed_offset0 = lhs_info->get_packed_offset_ex(0, k, 0, mr, kr, sr);
                    const size_t rhs_packed_offset = kernel->get_rhs_packed_offset_ex(n_start, k, 0);
                    const size_t dst_offset = kernel->get_dst_offset((size_t)0, (size_t)n_start, dst_stride);

                    const void * lhs_ptr = lhs_packed + lhs_packed_offset0;
                    const void * rhs_ptr = rhs_packed + rhs_packed_offset;
                    float * dst_ptr = reinterpret_cast<float *>(dst_batch_base + dst_offset);

                    kernel->run_kernel_ex(m, n_to_process, k, 0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride, sizeof(float), -FLT_MAX, FLT_MAX);
                }
            }

            // Re-sync before the next batch reuses the shared scratch buffers.
            if (batch_idx != batch_size - 1) {
                ggml_barrier(params->threadpool);
            }
        }

        return true;
    }
|
|
|
|
bool compute_forward_qx(struct ggml_compute_params * params, struct ggml_tensor * dst) {
|
|
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0 || dst->src[0]->type == GGML_TYPE_Q8_0);
|
|
|
|
const ggml_tensor * src0 = dst->src[0];
|
|
const ggml_tensor * src1 = dst->src[1];
|
|
|
|
GGML_TENSOR_BINARY_OP_LOCALS
|
|
|
|
const kleidiai_weight_header * header = kleidiai_weight_header_from_ptr(src0->data);
|
|
const bool has_header = kleidiai_is_weight_header_valid(header);
|
|
const bool is_gemv = src1->ne[1] == 1;
|
|
std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;
|
|
const int slot_total = kleidiai_collect_kernel_chain(dst, kernel_chain);
|
|
|
|
auto weight_for_slot = [&](int slot_index, size_t & size_out) -> const uint8_t * {
|
|
if (slot_index < 0 || slot_index >= slot_total) {
|
|
return nullptr;
|
|
}
|
|
if (has_header) {
|
|
if (slot_index < header->slot_count) {
|
|
size_out = static_cast<size_t>(header->sizes[slot_index]);
|
|
return kleidiai_weight_slot_ptr(header, slot_index);
|
|
}
|
|
return nullptr;
|
|
}
|
|
if (slot_index == 0) {
|
|
size_out = ggml_nbytes(src0);
|
|
return static_cast<const uint8_t *>(src0->data);
|
|
}
|
|
return nullptr;
|
|
};
|
|
|
|
struct runtime_slot {
|
|
int slot_index;
|
|
ggml_kleidiai_kernels * kernels;
|
|
kernel_info * kernel;
|
|
lhs_packing_info * lhs_info;
|
|
size_t mr;
|
|
size_t nr;
|
|
size_t kr;
|
|
size_t sr;
|
|
size_t n_step;
|
|
size_t lhs_packed_size;
|
|
size_t lhs_offset;
|
|
size_t n_offset;
|
|
size_t n_cols;
|
|
int assigned_threads;
|
|
int thread_begin;
|
|
int thread_end;
|
|
const uint8_t * rhs_base;
|
|
};
|
|
|
|
std::array<runtime_slot, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> runtime{};
|
|
int runtime_count = 0;
|
|
|
|
for (int slot = 0; slot < slot_total && runtime_count < GGML_KLEIDIAI_MAX_KERNEL_SLOTS; ++slot) {
|
|
ggml_kleidiai_kernels * kernels = kernel_chain[slot];
|
|
kernel_info * kinfo = is_gemv ? &kernels->gemv : &kernels->gemm;
|
|
lhs_packing_info * linfo = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
|
|
if (!kinfo || !linfo || !linfo->packed_size_ex || !linfo->pack_func_ex || !linfo->get_offset ||
|
|
!kinfo->get_rhs_packed_offset_ex || !kinfo->run_kernel_ex || !kinfo->get_dst_offset) {
|
|
continue;
|
|
}
|
|
|
|
size_t rhs_size = 0;
|
|
const uint8_t * rhs_ptr = weight_for_slot(slot, rhs_size);
|
|
if (!rhs_ptr || rhs_size == 0) {
|
|
continue;
|
|
}
|
|
|
|
runtime[runtime_count] = {
|
|
slot,
|
|
kernels,
|
|
kinfo,
|
|
linfo,
|
|
kinfo->get_mr(),
|
|
kinfo->get_nr(),
|
|
kinfo->get_kr(),
|
|
kinfo->get_sr(),
|
|
kinfo->get_n_step(),
|
|
0,
|
|
0,
|
|
0,
|
|
0,
|
|
0,
|
|
0,
|
|
0,
|
|
rhs_ptr
|
|
};
|
|
++runtime_count;
|
|
}
|
|
|
|
if (runtime_count == 0) {
|
|
ggml_kleidiai_kernels * fallback = ggml_kleidiai_select_kernels(ctx.features, dst);
|
|
if (!fallback) {
|
|
return false;
|
|
}
|
|
kernel_info * kinfo = is_gemv ? &fallback->gemv : &fallback->gemm;
|
|
lhs_packing_info * linfo = is_gemv ? &fallback->gemv_lhs_info : &fallback->gemm_lhs_info;
|
|
rhs_packing_info * rinfo = &fallback->rhs_info;
|
|
if (!kinfo || !linfo || !linfo->packed_size_ex || !linfo->pack_func_ex ||
|
|
!kinfo->get_rhs_packed_offset_ex || !kinfo->run_kernel_ex || !kinfo->get_dst_offset ||
|
|
!rinfo || !rinfo->pack_func_ex || !rinfo->packed_size_ex) {
|
|
return false;
|
|
}
|
|
kernel_chain[0] = fallback;
|
|
runtime[0] = {
|
|
0,
|
|
fallback,
|
|
kinfo,
|
|
linfo,
|
|
kinfo->get_mr(),
|
|
kinfo->get_nr(),
|
|
kinfo->get_kr(),
|
|
kinfo->get_sr(),
|
|
kinfo->get_n_step(),
|
|
0,
|
|
0,
|
|
0,
|
|
0,
|
|
0,
|
|
0,
|
|
0,
|
|
nullptr
|
|
};
|
|
size_t rhs_size_fallback = 0;
|
|
const uint8_t * rhs_base = weight_for_slot(0, rhs_size_fallback);
|
|
if (!rhs_base) {
|
|
rhs_base = static_cast<const uint8_t *>(src0->data);
|
|
}
|
|
runtime[0].rhs_base = rhs_base;
|
|
runtime_count = 1;
|
|
}
|
|
|
|
const int nth_total = params->nth > 0 ? params->nth : 1;
|
|
const int ith_total = params->ith;
|
|
|
|
int sme_slot = -1;
|
|
for (int i = 0; i < runtime_count; ++i) {
|
|
if ((runtime[i].kernels->required_cpu & CPU_FEATURE_SME) == CPU_FEATURE_SME) {
|
|
sme_slot = i;
|
|
break;
|
|
}
|
|
}
|
|
|
|
const int sme_cap_limit = ctx.sme_thread_cap;
|
|
const bool use_hybrid = sme_cap_limit > 0 &&
|
|
runtime_count > 1 &&
|
|
nth_total > sme_cap_limit;
|
|
// Heuristic: disable hybrid for very small workloads where per-slot overhead dominates.
|
|
// If rows are small or average columns per thread are small, keep single-slot.
|
|
size_t min_cols_per_thread = 0;
|
|
if (runtime_count > 0 && nth_total > 0) {
|
|
min_cols_per_thread = (size_t) std::max<int64_t>(1, (int64_t)ne01 / (int64_t)nth_total);
|
|
}
|
|
const bool too_small_for_hybrid = (min_cols_per_thread < 2) || (ne11 < 128);
|
|
|
|
const bool hybrid_enabled = use_hybrid && !too_small_for_hybrid;
|
|
|
|
if (!hybrid_enabled) {
|
|
int chosen_slot = 0;
|
|
if (too_small_for_hybrid && sme_slot != -1) {
|
|
chosen_slot = sme_slot;
|
|
} else if (runtime_count > 1 && ctx.sme_thread_cap > 0 && nth_total > ctx.sme_thread_cap) {
|
|
chosen_slot = 1;
|
|
}
|
|
if (chosen_slot != 0 && chosen_slot < runtime_count) {
|
|
runtime[0] = runtime[chosen_slot];
|
|
}
|
|
runtime_count = runtime_count > 0 ? 1 : 0;
|
|
|
|
// Recompute SME slot based on the collapsed runtime[0]
|
|
sme_slot = -1;
|
|
if (runtime_count > 0 &&
|
|
(runtime[0].kernels->required_cpu & CPU_FEATURE_SME) == CPU_FEATURE_SME) {
|
|
sme_slot = 0;
|
|
}
|
|
}
|
|
|
|
int sme_cap = kleidiai_sme_thread_cap();
|
|
if (sme_cap < 0) {
|
|
sme_cap = nth_total;
|
|
}
|
|
sme_cap = std::min(sme_cap, nth_total);
|
|
|
|
int threads_remaining = nth_total;
|
|
if (sme_slot != -1) {
|
|
int sme_threads = std::min(std::max(sme_cap, 0), threads_remaining);
|
|
runtime[sme_slot].assigned_threads = sme_threads;
|
|
threads_remaining -= sme_threads;
|
|
}
|
|
|
|
int fallback_indices[GGML_KLEIDIAI_MAX_KERNEL_SLOTS];
|
|
int fallback_count = 0;
|
|
for (int i = 0; i < runtime_count; ++i) {
|
|
if (i == sme_slot) {
|
|
continue;
|
|
}
|
|
fallback_indices[fallback_count++] = i;
|
|
}
|
|
|
|
for (int fi = 0; fi < fallback_count; ++fi) {
|
|
if (threads_remaining <= 0) {
|
|
break;
|
|
}
|
|
const int slot_index = fallback_indices[fi];
|
|
const int slots_left = fallback_count - fi;
|
|
int share = (threads_remaining + slots_left - 1) / slots_left;
|
|
share = std::min(share, threads_remaining);
|
|
runtime[slot_index].assigned_threads = share;
|
|
threads_remaining -= share;
|
|
}
|
|
|
|
if (threads_remaining > 0) {
|
|
const int fallback_slot = (sme_slot != -1) ? sme_slot : 0;
|
|
runtime[fallback_slot].assigned_threads += threads_remaining;
|
|
threads_remaining = 0;
|
|
}
|
|
|
|
int thread_cursor = 0;
|
|
for (int i = 0; i < runtime_count; ++i) {
|
|
runtime[i].thread_begin = thread_cursor;
|
|
thread_cursor += runtime[i].assigned_threads;
|
|
runtime[i].thread_end = thread_cursor;
|
|
}
|
|
|
|
if (thread_cursor < nth_total && runtime_count > 0) {
|
|
runtime[runtime_count - 1].assigned_threads += nth_total - thread_cursor;
|
|
runtime[runtime_count - 1].thread_end = nth_total;
|
|
}
|
|
|
|
int local_slot = -1;
|
|
int local_ith = 0;
|
|
for (int i = 0; i < runtime_count; ++i) {
|
|
if (ith_total >= runtime[i].thread_begin && ith_total < runtime[i].thread_end) {
|
|
local_slot = i;
|
|
local_ith = ith_total - runtime[i].thread_begin;
|
|
break;
|
|
}
|
|
}
|
|
if (local_slot == -1) {
|
|
return false;
|
|
}
|
|
|
|
const size_t k = ne00;
|
|
const size_t m = ne11;
|
|
const size_t n = ne01;
|
|
|
|
size_t cursor = 0;
|
|
for (int i = 0; i < runtime_count; ++i) {
|
|
const ggml_type slot_rhs_type = runtime[i].kernels->rhs_type;
|
|
const size_t slot_pack_size_arg = slot_rhs_type == GGML_TYPE_Q4_0 ? QK4_0 :
|
|
slot_rhs_type == GGML_TYPE_Q8_0 ? QK8_0 : 0;
|
|
runtime[i].lhs_packed_size = runtime[i].lhs_info->packed_size_ex(m, k, slot_pack_size_arg, runtime[i].mr, runtime[i].kr, runtime[i].sr);
|
|
cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
|
|
runtime[i].lhs_offset = cursor;
|
|
cursor += runtime[i].lhs_packed_size;
|
|
}
|
|
|
|
GGML_ASSERT(cursor <= params->wsize);
|
|
uint8_t * scratch = static_cast<uint8_t *>(params->wdata);
|
|
|
|
size_t assigned_cols = 0;
|
|
uint64_t weighted_total = 0;
|
|
if (runtime_count > 1 && sme_slot != -1) {
|
|
for (int i = 0; i < runtime_count; ++i) {
|
|
const uint64_t weight = (i == sme_slot) ? (sme_cap << 1) : 1;
|
|
weighted_total += (uint64_t)runtime[i].assigned_threads * weight;
|
|
}
|
|
}
|
|
for (int i = 0; i < runtime_count; ++i) {
|
|
runtime[i].n_offset = assigned_cols;
|
|
if (runtime[i].assigned_threads == 0) {
|
|
runtime[i].n_cols = 0;
|
|
continue;
|
|
}
|
|
const size_t remaining_cols = n - assigned_cols;
|
|
if (remaining_cols == 0) {
|
|
runtime[i].n_cols = 0;
|
|
continue;
|
|
}
|
|
const size_t step = runtime[i].n_step ? runtime[i].n_step : 1;
|
|
size_t target = 0;
|
|
if (weighted_total > 0) {
|
|
const uint64_t weight = (i == sme_slot) ? (sme_cap << 1) : 1;
|
|
target = (size_t)(((uint64_t)n * runtime[i].assigned_threads * weight) / weighted_total);
|
|
} else {
|
|
target = (size_t)(((uint64_t)n * runtime[i].assigned_threads) / nth_total);
|
|
}
|
|
target = std::min(target, remaining_cols);
|
|
size_t aligned = round_down(target, step);
|
|
if (aligned == 0 && remaining_cols >= step) {
|
|
aligned = step;
|
|
}
|
|
runtime[i].n_cols = aligned;
|
|
assigned_cols += aligned;
|
|
}
|
|
|
|
if (assigned_cols < n) {
|
|
for (int i = runtime_count - 1; i >= 0; --i) {
|
|
if (runtime[i].assigned_threads > 0) {
|
|
runtime[i].n_cols += n - assigned_cols;
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
const size_t dst_stride = dst->nb[1];
|
|
|
|
for (int64_t batch_idx = 0; batch_idx < ne12; ++batch_idx) {
|
|
const uint8_t * lhs_batch_base = static_cast<const uint8_t *>(src1->data) + batch_idx * src1->nb[2];
|
|
uint8_t * dst_batch_base = static_cast<uint8_t *>(dst->data) + batch_idx * dst->nb[2];
|
|
|
|
if (runtime[local_slot].assigned_threads > 0) {
|
|
runtime_slot & slot = runtime[local_slot];
|
|
const ggml_type slot_rhs_type = slot.kernels->rhs_type;
|
|
const size_t slot_lhs_exec_arg = slot_rhs_type == GGML_TYPE_Q4_0 ? QK4_0 :
|
|
slot_rhs_type == GGML_TYPE_Q8_0 ? 0 : 0;
|
|
const int64_t m_roundup_mr = kai_roundup((int64_t)m, (int64_t)slot.mr);
|
|
int64_t max_threads = slot.mr ? (m_roundup_mr / (int64_t)slot.mr) : slot.assigned_threads;
|
|
max_threads = std::max<int64_t>(1, max_threads);
|
|
const int64_t use_threads = std::min<int64_t>(slot.assigned_threads, max_threads);
|
|
|
|
if (local_ith < use_threads) {
|
|
const int64_t num_m_per_thread0 = round_down((size_t)(m_roundup_mr / use_threads), slot.mr);
|
|
const int64_t num_m_per_threadN_1 = (int64_t)m - (use_threads - 1) * num_m_per_thread0;
|
|
|
|
const int64_t m_start = (int64_t)local_ith * num_m_per_thread0;
|
|
const int64_t m_count = (local_ith == use_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0;
|
|
|
|
const size_t base_packed_off = slot.lhs_info->get_packed_offset_ex(m_start, k, slot_lhs_exec_arg, slot.mr, slot.kr, slot.sr);
|
|
const size_t next_block_off = slot.lhs_info->get_packed_offset_ex(m_start + slot.mr, k, slot_lhs_exec_arg, slot.mr, slot.kr, slot.sr);
|
|
const size_t row_stride_bytes = slot.mr ? (next_block_off - base_packed_off) / slot.mr : 0;
|
|
|
|
int64_t remaining = m_count;
|
|
int64_t cur = m_start;
|
|
|
|
uint8_t * lhs_packed = scratch + slot.lhs_offset;
|
|
while (remaining > 0) {
|
|
const int64_t row_in_group = cur;
|
|
const int64_t avail = (int64_t)m - row_in_group;
|
|
const int64_t take = std::min(avail, remaining);
|
|
|
|
const size_t src_off = slot.lhs_info->get_offset(row_in_group, src1->nb[1]);
|
|
const void * src_ptr = lhs_batch_base + src_off;
|
|
const size_t dst_off = base_packed_off + (size_t)(cur - m_start) * row_stride_bytes;
|
|
void * dst_ptr = lhs_packed + dst_off;
|
|
|
|
slot.lhs_info->pack_func_ex(take, k, slot_lhs_exec_arg, slot.mr, slot.kr, slot.sr, 0, src_ptr, src1->nb[1], dst_ptr);
|
|
|
|
cur += take;
|
|
remaining -= take;
|
|
}
|
|
}
|
|
}
|
|
|
|
ggml_barrier(params->threadpool);
|
|
|
|
runtime_slot & slot = runtime[local_slot];
|
|
if (slot.n_cols > 0 && slot.assigned_threads > 0) {
|
|
int64_t active_threads = slot.assigned_threads;
|
|
const int64_t max_threads = slot.n_step ? (slot.n_cols / slot.n_step) : slot.assigned_threads;
|
|
if (max_threads > 0) {
|
|
active_threads = std::min<int64_t>(active_threads, std::max<int64_t>(1, max_threads));
|
|
}
|
|
active_threads = std::max<int64_t>(1, active_threads);
|
|
|
|
if (local_ith < active_threads) {
|
|
const size_t step = slot.n_step ? slot.n_step : 1;
|
|
const size_t chunk0 = round_down((size_t)(slot.n_cols / active_threads), step);
|
|
const size_t chunkN = slot.n_cols - (active_threads - 1) * chunk0;
|
|
const size_t local_start = (size_t)local_ith * chunk0;
|
|
const size_t cols = (local_ith == active_threads - 1) ? chunkN : chunk0;
|
|
|
|
if (cols > 0) {
|
|
const ggml_type slot_rhs_type = slot.kernels->rhs_type;
|
|
const size_t slot_lhs_exec_arg = slot_rhs_type == GGML_TYPE_Q4_0 ? QK4_0 :
|
|
slot_rhs_type == GGML_TYPE_Q8_0 ? 0 : 0;
|
|
const size_t slot_rhs_block_arg = slot_rhs_type == GGML_TYPE_Q4_0 ? QK4_0 :
|
|
slot_rhs_type == GGML_TYPE_Q8_0 ? 0 : 0;
|
|
const size_t global_start = slot.n_offset + local_start;
|
|
const size_t lhs_packed_offset = slot.lhs_info->get_packed_offset_ex(0, k, slot_lhs_exec_arg, slot.mr, slot.kr, slot.sr);
|
|
const size_t rhs_packed_offset = slot.kernel->get_rhs_packed_offset_ex(global_start, k, slot_rhs_block_arg);
|
|
const size_t dst_offset = slot.kernel->get_dst_offset(0, global_start, dst_stride);
|
|
|
|
const uint8_t * lhs_ptr = scratch + slot.lhs_offset + lhs_packed_offset;
|
|
const uint8_t * rhs_ptr = slot.rhs_base + rhs_packed_offset;
|
|
float * dst_ptr = reinterpret_cast<float *>(dst_batch_base + dst_offset);
|
|
|
|
slot.kernel->run_kernel_ex(m, cols, k, slot_rhs_block_arg,
|
|
lhs_ptr,
|
|
rhs_ptr,
|
|
dst_ptr,
|
|
dst_stride,
|
|
sizeof(float),
|
|
-FLT_MAX,
|
|
FLT_MAX);
|
|
}
|
|
}
|
|
}
|
|
|
|
if (batch_idx != ne12 - 1) {
|
|
ggml_barrier(params->threadpool);
|
|
}
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
// Gather rows from a KleidiAI-packed Q4_0/Q8_0 weight tensor (src0), selected
// by the I32 indices in src1, and write them into dst as dequantized F32 rows.
// Returns false when no usable kernel/packing info is found so the caller can
// fall back to the generic path.
bool compute_forward_get_rows(struct ggml_compute_params * params, struct ggml_tensor * dst) {
    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0 || dst->src[0]->type == GGML_TYPE_Q8_0);
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];

    GGML_TENSOR_BINARY_OP_LOCALS

    // Weights written by repack() may start with a slot header that records
    // where each kernel variant's packed image lives inside the buffer.
    const kleidiai_weight_header * header = kleidiai_weight_header_from_ptr(src0->data);
    const bool has_header = kleidiai_is_weight_header_valid(header);

    std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;
    const bool want_q8 = src0->type == GGML_TYPE_Q8_0;
    const int chain_count = want_q8 ? kleidiai_collect_q8_chain(kernel_chain)
                                    : kleidiai_collect_q4_chain(kernel_chain);

    ggml_kleidiai_kernels * kernels = nullptr;
    const uint8_t * packed_base = static_cast<const uint8_t *>(src0->data);

    // Prefer slot 0, clamped to the number of slots the header actually holds.
    if (has_header && chain_count > 0) {
        int select_slot = 0;
        if (select_slot >= header->slot_count) {
            select_slot = header->slot_count - 1;
        }
        if (select_slot >= 0 && select_slot < chain_count) {
            kernels = kernel_chain[select_slot];
            const uint8_t * slot_ptr = kleidiai_weight_slot_ptr(header, select_slot);
            if (slot_ptr) {
                packed_base = slot_ptr;
            }
        }
    }

    // Fallback: first kernel in the chain (and slot 0 data when a header exists).
    if (!kernels && chain_count > 0) {
        kernels = kernel_chain[0];
        if (has_header) {
            const uint8_t * slot_ptr = kleidiai_weight_slot_ptr(header, 0);
            if (slot_ptr) {
                packed_base = slot_ptr;
            }
        }
    }

    if (!kernels) {
        return false;
    }

    rhs_packing_info * rhs_info = &kernels->rhs_info;
    kernel_info * kernel = &kernels->gemm;
    // to_float unpacks one packed row to F32; get_nr is needed to describe the
    // packed layout below. Without either, this path cannot run.
    if (!rhs_info->to_float || !kernel->get_nr) {
        return false;
    }

    const int64_t nc = ne00;                  // elements per row
    const int64_t nr = ggml_nelements(src1);  // number of rows to gather

    // Per-type packing parameters: quantization block length and the size of
    // each stored scale element in the packed stream.
    const ggml_type rhs_type = kernels->rhs_type;
    size_t block_len = 0;
    size_t num_bytes_multiplier = 0;
    if (rhs_type == GGML_TYPE_Q4_0) {
        block_len = QK4_0;
        num_bytes_multiplier = sizeof(uint16_t);
    } else if (rhs_type == GGML_TYPE_Q8_0) {
        block_len = QK8_0;
        num_bytes_multiplier = sizeof(float);
    } else {
        return false;
    }

    const size_t block_rows = kernel->get_nr();
    const size_t kr = kernel->get_kr();

    const size_t packed_stride = rhs_info->packed_stride(nc, block_rows, kr, block_len);

    const int ith = params->ith;
    const int nth = params->nth;

    // Split the nr rows evenly across the worker threads.
    const int dr = (nr + nth - 1) / nth;
    const int ir0 = dr * ith;
    const int ir1 = MIN(ir0 + dr, nr);

    for (int64_t i = ir0; i < ir1; ++i) {
        GGML_ASSERT(src1->type == GGML_TYPE_I32);
        int64_t row_idx = ((const int32_t *)src1->data)[i];
        GGML_ASSERT(row_idx >= 0 && row_idx < src0->ne[1]);

        float *out = (float *)((char *)dst->data + i * nb1);
        rhs_info->to_float(packed_base, row_idx, nc, out, block_rows, packed_stride, kr, block_len, num_bytes_multiplier);
    }

    return true;
}
|
|
|
|
public:
|
|
int repack(struct ggml_tensor * tensor, const void * data, size_t data_size) {
|
|
GGML_ASSERT(tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q8_0);
|
|
const size_t n = tensor->ne[1];
|
|
const size_t k = tensor->ne[0];
|
|
|
|
kleidiai_weight_header * header = kleidiai_weight_header_from_ptr(tensor->data);
|
|
if (!header) {
|
|
return -1;
|
|
}
|
|
|
|
header->magic = GGML_KLEIDIAI_PACK_MAGIC;
|
|
header->version = GGML_KLEIDIAI_PACK_VERSION;
|
|
header->slot_count = 0;
|
|
|
|
uint8_t * base_ptr = static_cast<uint8_t *>(tensor->data);
|
|
size_t cursor = sizeof(kleidiai_weight_header);
|
|
cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
|
|
|
|
std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;
|
|
const bool want_q8 = tensor->type == GGML_TYPE_Q8_0;
|
|
const int slot_total = want_q8 ? kleidiai_collect_q8_chain(kernel_chain)
|
|
: kleidiai_collect_q4_chain(kernel_chain);
|
|
const bool allow_fallback = kleidiai_pack_fallback_allowed();
|
|
|
|
std::vector<int8_t> qdata;
|
|
std::vector<float> scales;
|
|
|
|
if (want_q8 && slot_total > 0) {
|
|
qdata.resize(n * k, 0);
|
|
scales.resize(n, 0.0f);
|
|
|
|
const size_t row_stride = tensor->nb[1];
|
|
const size_t k_blocks = (k + QK8_0 - 1) / QK8_0;
|
|
|
|
for (size_t row = 0; row < n; ++row) {
|
|
const auto * row_blocks = reinterpret_cast<const block_q8_0 *>(
|
|
static_cast<const uint8_t *>(data) + row * row_stride);
|
|
|
|
float max_abs = 0.0f;
|
|
for (size_t block = 0; block < k_blocks; ++block) {
|
|
const block_q8_0 & blk = row_blocks[block];
|
|
const float d = GGML_FP16_TO_FP32(blk.d);
|
|
for (size_t l = 0; l < QK8_0; ++l) {
|
|
const size_t linear_idx = block * QK8_0 + l;
|
|
if (linear_idx >= k) {
|
|
break;
|
|
}
|
|
const float value = d * static_cast<float>(blk.qs[l]);
|
|
max_abs = std::max(max_abs, std::fabs(value));
|
|
}
|
|
}
|
|
|
|
float scale = max_abs > 0.0f ? max_abs / 127.0f : 0.0f;
|
|
scales[row] = scale;
|
|
const float inv_scale = scale > 0.0f ? 1.0f / scale : 0.0f;
|
|
|
|
for (size_t block = 0; block < k_blocks; ++block) {
|
|
const block_q8_0 & blk = row_blocks[block];
|
|
const float d = GGML_FP16_TO_FP32(blk.d);
|
|
for (size_t l = 0; l < QK8_0; ++l) {
|
|
const size_t linear_idx = block * QK8_0 + l;
|
|
if (linear_idx >= k) {
|
|
break;
|
|
}
|
|
const float value = d * static_cast<float>(blk.qs[l]);
|
|
int32_t q = scale > 0.0f ? static_cast<int32_t>(std::lround(value * inv_scale)) : 0;
|
|
q = std::clamp(q, -127, 127);
|
|
qdata[row * k + linear_idx] = static_cast<int8_t>(q);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
for (int slot = 0; slot < slot_total && slot < GGML_KLEIDIAI_MAX_KERNEL_SLOTS; ++slot) {
|
|
if (!allow_fallback && slot > 0) {
|
|
break;
|
|
}
|
|
ggml_kleidiai_kernels * kernels = kernel_chain[slot];
|
|
kernel_info * kernel = &kernels->gemm;
|
|
rhs_packing_info * rhs_info = &kernels->rhs_info;
|
|
if (!rhs_info || !rhs_info->pack_func_ex || !rhs_info->packed_size_ex || !kernel) {
|
|
continue;
|
|
}
|
|
|
|
const size_t nr = kernel->get_nr();
|
|
const size_t kr = kernel->get_kr();
|
|
const size_t sr = kernel->get_sr();
|
|
const ggml_type rhs_type = kernels->rhs_type;
|
|
const size_t block_len = rhs_type == GGML_TYPE_Q8_0 ? QK8_0 :
|
|
rhs_type == GGML_TYPE_Q4_0 ? QK4_0 : 0;
|
|
if (block_len == 0) {
|
|
continue;
|
|
}
|
|
|
|
const size_t packed_size = rhs_info->packed_size_ex(n, k, nr, kr, block_len);
|
|
const size_t aligned_cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
|
|
|
|
uint8_t * dst_ptr = base_ptr + aligned_cursor;
|
|
|
|
if (rhs_type == GGML_TYPE_Q4_0) {
|
|
struct kai_rhs_pack_qs4cxs1s0_param params;
|
|
params.lhs_zero_point = 1;
|
|
params.rhs_zero_point = 8;
|
|
rhs_info->pack_func_ex(1, n, k, nr, kr, sr, QK4_0, 0,
|
|
static_cast<const uint8_t *>(data), nullptr, nullptr,
|
|
dst_ptr, 0, ¶ms);
|
|
} else if (rhs_type == GGML_TYPE_Q8_0) {
|
|
struct kai_rhs_pack_qsi8cx_params params;
|
|
params.lhs_zero_point = 1;
|
|
params.scale_multiplier = 1.0f;
|
|
rhs_info->pack_func_ex(1, n, k, nr, kr, sr, 0, 0,
|
|
qdata.data(), nullptr, scales.data(),
|
|
dst_ptr, 0, ¶ms);
|
|
} else {
|
|
continue;
|
|
}
|
|
|
|
header->offsets[header->slot_count] = aligned_cursor;
|
|
header->sizes[header->slot_count] = packed_size;
|
|
++header->slot_count;
|
|
|
|
cursor = aligned_cursor + packed_size;
|
|
}
|
|
|
|
if (header->slot_count == 0) {
|
|
header->magic = 0;
|
|
header->version = 0;
|
|
memcpy(tensor->data, data, data_size);
|
|
}
|
|
|
|
return 0;
|
|
}
|
|
};
|
|
|
|
// Hand out the process-wide, stateless traits instance; both arguments are
// intentionally ignored.
static ggml::cpu::tensor_traits * get_tensor_traits(ggml_backend_buffer_t, struct ggml_tensor *) {
    static tensor_traits shared_traits;
    return &shared_traits;
}
|
|
} // namespace ggml::cpu::kleidiai
|
|
|
|
// Attach the shared kleidiai traits object to every tensor placed in this
// buffer so the CPU backend can dispatch to the kleidiai kernels later.
static enum ggml_status ggml_backend_cpu_kleidiai_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
    tensor->extra = (void *) ggml::cpu::kleidiai::get_tensor_traits(buffer, tensor);
    return GGML_STATUS_SUCCESS;
}
|
|
|
|
// Intercept tensor uploads: instead of a raw copy, repack the incoming
// quantized weights into the KleidiAI layout via the tensor's traits object.
static void ggml_backend_cpu_kleidiai_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor,
                                                        const void * data, size_t offset, size_t size) {
    GGML_ASSERT(offset == 0);
    GGML_ASSERT(size == ggml_nbytes(tensor));

    auto * traits = static_cast<ggml::cpu::kleidiai::tensor_traits *>(tensor->extra);
    const int status = traits->repack(tensor, data, size);
    GGML_ASSERT(status == 0);

    GGML_UNUSED(buffer);
}
|
|
|
|
// Human-readable identifier for this buffer type.
static const char * ggml_backend_cpu_kleidiai_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
    GGML_UNUSED(buft);

    return "CPU_KLEIDIAI";
}
|
|
|
|
// Allocate through the plain CPU buffer type, then rebrand the buffer and
// override the tensor hooks so weights get repacked on upload.
static ggml_backend_buffer_t ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
    ggml_backend_buffer_t buf = ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
    if (!buf) {
        return nullptr;
    }

    buf->buft              = buft;
    buf->iface.init_tensor = ggml_backend_cpu_kleidiai_buffer_init_tensor;
    buf->iface.set_tensor  = ggml_backend_cpu_kleidiai_buffer_set_tensor;
    // Reads/copies of the packed data are not supported on this buffer type.
    buf->iface.get_tensor  = nullptr;
    buf->iface.cpy_tensor  = nullptr;
    return buf;
}
|
|
|
|
// Same base alignment as the default CPU buffers.
static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
    GGML_UNUSED(buft);

    return TENSOR_ALIGNMENT;
}
|
|
|
|
// Compute the allocation size needed for a tensor in this buffer type: the
// kleidiai weight header plus one aligned packed image per kernel slot.
// Non-quantized tensors, and tensors for which no slot can be packed, fall
// back to the plain ggml size. The result is never smaller than
// ggml_nbytes(tensor) so the raw-copy fallback in repack() always fits.
static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
    GGML_UNUSED(buft);

    if (tensor->type != GGML_TYPE_Q4_0 && tensor->type != GGML_TYPE_Q8_0) {
        return ggml_nbytes(tensor);
    }

    const size_t n = tensor->ne[1];
    const size_t k = tensor->ne[0];

    size_t cursor = sizeof(kleidiai_weight_header);
    cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);

    std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;
    const bool want_q8 = tensor->type == GGML_TYPE_Q8_0;
    const int slot_total = want_q8 ? kleidiai_collect_q8_chain(kernel_chain)
                                   : kleidiai_collect_q4_chain(kernel_chain);
    const bool allow_fallback = kleidiai_pack_fallback_allowed();

    size_t slot_count = 0;
    // Bound the loop by the chain array capacity as well, matching the slot
    // loop in repack() so size accounting and packing stay in lockstep.
    for (int slot = 0; slot < slot_total && slot < GGML_KLEIDIAI_MAX_KERNEL_SLOTS; ++slot) {
        if (!allow_fallback && slot > 0) {
            break;
        }
        ggml_kleidiai_kernels * kernels = kernel_chain[slot];
        if (!kernels) {
            continue;
        }
        kernel_info * kernel = &kernels->gemm;
        rhs_packing_info * rhs_info = &kernels->rhs_info;
        if (!kernel || !rhs_info || !rhs_info->packed_size_ex) {
            continue;
        }

        const ggml_type rhs_type = kernels->rhs_type;
        const size_t block_len = rhs_type == GGML_TYPE_Q4_0 ? QK4_0 :
                                 rhs_type == GGML_TYPE_Q8_0 ? QK8_0 : 0;
        if (block_len == 0) {
            continue;
        }

        // Each slot image starts on an aligned boundary, as in repack().
        cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
        cursor += rhs_info->packed_size_ex(n, k, kernel->get_nr(), kernel->get_kr(), block_len);
        ++slot_count;
    }

    if (slot_count == 0) {
        return ggml_nbytes(tensor);
    }

    return std::max(cursor, ggml_nbytes(tensor));
}
|
|
|
|
namespace ggml::cpu::kleidiai {
|
|
class extra_buffer_type : ggml::cpu::extra_buffer_type {
|
|
bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
|
|
std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;
|
|
const int slot_total = kleidiai_collect_kernel_chain(op, kernel_chain);
|
|
if ((op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) &&
|
|
(op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q8_0) &&
|
|
op->src[0]->buffer &&
|
|
(ggml_n_dims(op->src[0]) == 2) &&
|
|
op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() &&
|
|
slot_total > 0) {
|
|
if (op->src[0]->type == GGML_TYPE_Q4_0 && ctx.kernels_q4 == nullptr) {
|
|
return false;
|
|
}
|
|
if (op->src[0]->type == GGML_TYPE_Q8_0 && ctx.kernels_q8 == nullptr) {
|
|
return false;
|
|
}
|
|
if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
|
|
return false;
|
|
}
|
|
if ((op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_I32) &&
|
|
ggml_ne(op->src[1], 2) == 1 && ggml_ne(op->src[1], 3) == 1) {
|
|
return true;
|
|
}
|
|
}
|
|
return false;
|
|
}
|
|
|
|
ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {
|
|
if (op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) {
|
|
if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
|
|
return (ggml::cpu::tensor_traits *) op->src[0]->extra;
|
|
} else {
|
|
std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;
|
|
const int slot_total = kleidiai_collect_kernel_chain(op, kernel_chain);
|
|
const bool has_kernel = slot_total > 0;
|
|
if (has_kernel && op->src[1]->ne[1] > 1) {
|
|
if ((op->src[0]->nb[1] * op->src[0]->ne[1] != op->src[0]->nb[2]) ||
|
|
(op->src[1]->nb[1] * op->src[1]->ne[1] != op->src[1]->nb[2])) {
|
|
return nullptr;
|
|
}
|
|
return ggml::cpu::kleidiai::get_tensor_traits(NULL, NULL);
|
|
}
|
|
}
|
|
}
|
|
return nullptr;
|
|
}
|
|
};
|
|
} // namespace ggml::cpu::kleidiai
|
|
|
|
// Public accessor for the kleidiai buffer type singleton. Also makes sure the
// kleidiai context is initialized before the buffer type is handed out.
ggml_backend_buffer_type_t ggml_backend_cpu_kleidiai_buffer_type(void) {
    static ggml::cpu::kleidiai::extra_buffer_type buft_ctx;
    static struct ggml_backend_buffer_type buffer_type_kleidiai = {
        /* .iface = */ {
            /* .get_name         = */ ggml_backend_cpu_kleidiai_buffer_type_get_name,
            /* .alloc_buffer     = */ ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer,
            /* .get_alignment    = */ ggml_backend_cpu_kleidiai_buffer_type_get_alignment,
            /* .get_max_size     = */ nullptr, // defaults to SIZE_MAX
            /* .get_alloc_size   = */ ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size,
            /* .is_host          = */ nullptr,
        },
        /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
        /* .context = */ &buft_ctx,
    };

    init_kleidiai_context();

    return &buffer_type_kleidiai;
}
|