Merge branch 'master' into xsn/server_model_management_v1_2

This commit is contained in:
Xuan Son Nguyen 2025-11-26 16:21:57 +01:00
commit becc602612
52 changed files with 2578 additions and 659 deletions

View File

@ -2,10 +2,8 @@
# multiplie collaborators per item can be specified
/.devops/*.Dockerfile @ngxson
/.github/actions/ @slaren @CISC
/.github/actions/ @CISC
/.github/workflows/ @CISC
/.github/workflows/release.yml @slaren
/.github/workflows/winget.yml @slaren
/ci/ @ggerganov
/cmake/ @ggerganov
/common/CMakeLists.txt @ggerganov
@ -40,21 +38,14 @@
/examples/passkey/ @ggerganov
/examples/retrieval/ @ggerganov
/examples/save-load-state/ @ggerganov
/examples/simple-chat/ @slaren
/examples/simple/ @slaren
/examples/speculative-simple/ @ggerganov
/examples/speculative/ @ggerganov
/ggml/cmake/ @ggerganov
/ggml/include/ @ggerganov @slaren
/ggml/src/ggml-alloc.c @slaren
/ggml/src/ggml-backend* @slaren
/ggml/src/ggml-blas/ @slaren
/ggml/src/ggml-common.h @ggerganov @slaren
/ggml/src/ggml-cpu/ @ggerganov @slaren
/ggml/include/ @ggerganov
/ggml/src/ggml-common.h @ggerganov
/ggml/src/ggml-cpu/ @ggerganov
/ggml/src/ggml-cpu/spacemit/ @alex-spacemit
/ggml/src/ggml-cuda/common.cuh @slaren
/ggml/src/ggml-cuda/fattn* @JohannesGaessler
/ggml/src/ggml-cuda/ggml-cuda.cu @slaren
/ggml/src/ggml-cuda/mmf.* @JohannesGaessler @am17an
/ggml/src/ggml-cuda/mmq.* @JohannesGaessler
/ggml/src/ggml-cuda/mmvf.* @JohannesGaessler
@ -62,19 +53,19 @@
/ggml/src/ggml-cuda/fattn-wmma* @IMbackK
/ggml/src/ggml-hip/ @IMbackK
/ggml/src/ggml-cuda/vendors/hip.h @IMbackK
/ggml/src/ggml-impl.h @ggerganov @slaren
/ggml/src/ggml-impl.h @ggerganov
/ggml/src/ggml-metal/ @ggerganov
/ggml/src/ggml-opencl/ @lhez @max-krasnyansky
/ggml/src/ggml-hexagon/ @max-krasnyansky @lhez
/ggml/src/ggml-opt.cpp @JohannesGaessler
/ggml/src/ggml-quants.* @ggerganov
/ggml/src/ggml-rpc/ @rgerganov
/ggml/src/ggml-threading.* @ggerganov @slaren
/ggml/src/ggml-threading.* @ggerganov
/ggml/src/ggml-vulkan/ @0cc4m
/ggml/src/ggml-webgpu/ @reeselevine
/ggml/src/ggml-zdnn/ @taronaeo @Andreas-Krebbel @AlekseiNikiforovIBM
/ggml/src/ggml.c @ggerganov @slaren
/ggml/src/ggml.cpp @ggerganov @slaren
/ggml/src/ggml.c @ggerganov
/ggml/src/ggml.cpp @ggerganov
/ggml/src/gguf.cpp @JohannesGaessler @Green-Sky
/gguf-py/ @CISC
/media/ @ggerganov
@ -86,15 +77,11 @@
/src/llama-arch.* @CISC
/src/llama-chat.* @ngxson
/src/llama-graph.* @CISC
/src/llama-model-loader.* @slaren
/src/llama-model.* @CISC
/src/llama-vocab.* @CISC
/src/models/ @CISC
/tests/ @ggerganov
/tests/test-backend-ops.cpp @slaren
/tests/test-thread-safety.cpp @slaren
/tools/batched-bench/ @ggerganov
/tools/llama-bench/ @slaren
/tools/main/ @ggerganov
/tools/mtmd/ @ngxson
/tools/perplexity/ @ggerganov
@ -106,8 +93,6 @@
/tools/tokenize/ @ggerganov
/tools/tts/ @ggerganov
/vendor/ @ggerganov
/.clang-format @slaren
/.clang-tidy @slaren
/AUTHORS @ggerganov
/CMakeLists.txt @ggerganov
/CONTRIBUTING.md @ggerganov

View File

@ -1237,6 +1237,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
[](common_params & params, const std::string & value) {
const auto sampler_names = string_split<std::string>(value, ';');
params.sampling.samplers = common_sampler_types_from_names(sampler_names, true);
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_SAMPLERS;
}
).set_sparam());
add_opt(common_arg(
@ -1266,6 +1267,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
[](common_params & params, const std::string & value) {
params.sampling.temp = std::stof(value);
params.sampling.temp = std::max(params.sampling.temp, 0.0f);
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TEMP;
}
).set_sparam());
add_opt(common_arg(
@ -1273,6 +1275,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
string_format("top-k sampling (default: %d, 0 = disabled)", params.sampling.top_k),
[](common_params & params, int value) {
params.sampling.top_k = value;
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_K;
}
).set_sparam());
add_opt(common_arg(
@ -1280,6 +1283,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
string_format("top-p sampling (default: %.1f, 1.0 = disabled)", (double)params.sampling.top_p),
[](common_params & params, const std::string & value) {
params.sampling.top_p = std::stof(value);
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_P;
}
).set_sparam());
add_opt(common_arg(
@ -1287,6 +1291,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
string_format("min-p sampling (default: %.1f, 0.0 = disabled)", (double)params.sampling.min_p),
[](common_params & params, const std::string & value) {
params.sampling.min_p = std::stof(value);
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIN_P;
}
).set_sparam());
add_opt(common_arg(
@ -1301,6 +1306,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
string_format("xtc probability (default: %.1f, 0.0 = disabled)", (double)params.sampling.xtc_probability),
[](common_params & params, const std::string & value) {
params.sampling.xtc_probability = std::stof(value);
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_PROBABILITY;
}
).set_sparam());
add_opt(common_arg(
@ -1308,6 +1314,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
string_format("xtc threshold (default: %.1f, 1.0 = disabled)", (double)params.sampling.xtc_threshold),
[](common_params & params, const std::string & value) {
params.sampling.xtc_threshold = std::stof(value);
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_THRESHOLD;
}
).set_sparam());
add_opt(common_arg(
@ -1326,6 +1333,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
params.sampling.penalty_last_n = value;
params.sampling.n_prev = std::max(params.sampling.n_prev, params.sampling.penalty_last_n);
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_LAST_N;
}
).set_sparam());
add_opt(common_arg(
@ -1333,6 +1341,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
string_format("penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)", (double)params.sampling.penalty_repeat),
[](common_params & params, const std::string & value) {
params.sampling.penalty_repeat = std::stof(value);
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_REPEAT;
}
).set_sparam());
add_opt(common_arg(
@ -1430,6 +1439,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
"(default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)", params.sampling.mirostat),
[](common_params & params, int value) {
params.sampling.mirostat = value;
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT;
}
).set_sparam());
add_opt(common_arg(
@ -1437,6 +1447,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
string_format("Mirostat learning rate, parameter eta (default: %.1f)", (double)params.sampling.mirostat_eta),
[](common_params & params, const std::string & value) {
params.sampling.mirostat_eta = std::stof(value);
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA;
}
).set_sparam());
add_opt(common_arg(
@ -1444,6 +1455,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
string_format("Mirostat target entropy, parameter tau (default: %.1f)", (double)params.sampling.mirostat_tau),
[](common_params & params, const std::string & value) {
params.sampling.mirostat_tau = std::stof(value);
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_TAU;
}
).set_sparam());
add_opt(common_arg(

View File

@ -8,6 +8,7 @@
#include "common.h"
#include "log.h"
#include "llama.h"
#include "sampling.h"
#include <algorithm>
#include <cinttypes>
@ -957,6 +958,58 @@ std::vector<common_file_info> fs_list(const std::string & path, bool include_dir
// Model utils
//
static inline void common_init_sampler_from_model(
const llama_model * model,
common_params_sampling & sparams) {
const uint64_t config = sparams.user_sampling_config;
auto get_int32 = [&](const char * key, int32_t & dst, uint64_t user_config) {
if (config & user_config) return;
char buf[64] = {0};
if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) {
char * end = nullptr;
int32_t v = strtol(buf, &end, 10);
if (end && end != buf) dst = v;
}
};
auto get_float = [&](const char * key, float & dst, uint64_t user_config) {
if (config & user_config) return;
char buf[128] = {0};
if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) {
char * end = nullptr;
float v = strtof(buf, &end);
if (end && end != buf) dst = v;
}
};
// Sampling sequence
if (!(config & common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_SAMPLERS)) {
char buf[512] = {0};
if (llama_model_meta_val_str(model, llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_SEQUENCE), buf, sizeof(buf)) > 0) {
const std::vector<std::string> sampler_names = string_split<std::string>(std::string(buf), ';');
if (!sampler_names.empty()) {
sparams.samplers = common_sampler_types_from_names(sampler_names, true);
}
}
}
get_int32(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_TOP_K), sparams.top_k, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_K);
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_TOP_P), sparams.top_p, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_P);
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIN_P), sparams.min_p, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIN_P);
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_XTC_PROBABILITY), sparams.xtc_probability, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_PROBABILITY);
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_XTC_THRESHOLD), sparams.xtc_threshold, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_THRESHOLD);
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_TEMP), sparams.temp, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TEMP);
get_int32(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_LAST_N), sparams.penalty_last_n, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_LAST_N);
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_REPEAT), sparams.penalty_repeat, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_REPEAT);
get_int32(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT), sparams.mirostat, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT);
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_TAU), sparams.mirostat_tau, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_TAU);
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_ETA), sparams.mirostat_eta, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA);
}
struct common_init_result common_init_from_params(common_params & params) {
common_init_result iparams;
auto mparams = common_model_params_to_llama(params);
@ -968,6 +1021,8 @@ struct common_init_result common_init_from_params(common_params & params) {
return iparams;
}
common_init_sampler_from_model(model, params.sampling);
const llama_vocab * vocab = llama_model_get_vocab(model);
auto cparams = common_context_params_to_llama(params);

View File

@ -138,6 +138,22 @@ struct common_grammar_trigger {
llama_token token = LLAMA_TOKEN_NULL;
};
enum common_params_sampling_config : uint64_t {
COMMON_PARAMS_SAMPLING_CONFIG_SAMPLERS = 1 << 0,
COMMON_PARAMS_SAMPLING_CONFIG_TOP_K = 1 << 1,
COMMON_PARAMS_SAMPLING_CONFIG_TOP_P = 1 << 2,
COMMON_PARAMS_SAMPLING_CONFIG_MIN_P = 1 << 3,
COMMON_PARAMS_SAMPLING_CONFIG_XTC_PROBABILITY = 1 << 4,
COMMON_PARAMS_SAMPLING_CONFIG_XTC_THRESHOLD = 1 << 5,
COMMON_PARAMS_SAMPLING_CONFIG_TEMP = 1 << 6,
COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_LAST_N = 1 << 7,
COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_REPEAT = 1 << 8,
COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT = 1 << 9,
COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_TAU = 1 << 10,
COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA = 1 << 11,
};
// sampling parameters
struct common_params_sampling {
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler
@ -170,6 +186,8 @@ struct common_params_sampling {
bool no_perf = false; // disable performance metrics
bool timing_per_token = false;
uint64_t user_sampling_config = 0; // bitfield to track user-specified samplers
std::vector<std::string> dry_sequence_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY
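
A minimal standalone sketch (plain C++; the names are hypothetical, this is not the llama.cpp API) of the precedence rule the user_sampling_config bitfield above enables: a bit is set whenever the user passes the corresponding option on the command line, and a model-embedded default is adopted only if that bit is still clear.

#include <cstdint>
#include <cstdio>

enum user_cfg : uint64_t {
    CFG_TEMP  = 1 << 0,
    CFG_TOP_K = 1 << 1,
};

struct sampling_params {
    float    temp  = 0.8f;
    int32_t  top_k = 40;
    uint64_t user  = 0; // bits set for options the user specified explicitly
};

// adopt a model-provided default only if the user did not set the parameter
static void apply_model_default_temp(sampling_params & sp, float model_temp) {
    if (sp.user & CFG_TEMP) {
        return; // explicit CLI value wins over GGUF metadata
    }
    sp.temp = model_temp;
}

int main() {
    sampling_params sp;
    sp.temp  = 0.2f;
    sp.user |= CFG_TEMP;                  // simulate "--temp 0.2" on the command line
    apply_model_default_temp(sp, 1.0f);   // model metadata suggests temp = 1.0
    printf("temp = %.1f, top_k = %d\n", sp.temp, sp.top_k); // temp stays 0.2
    return 0;
}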

View File

@ -565,7 +565,7 @@ class ModelBase:
gguf.MODEL_TENSOR.ALTUP_PREDICT_COEF,
)
)
or not new_name.endswith(".weight")
or new_name[-7:] not in (".weight", ".lora_a", ".lora_b")
):
data_qtype = gguf.GGMLQuantizationType.F32
@ -10061,6 +10061,25 @@ class LazyTorchTensor(gguf.LazyBase):
torch.uint8: np.uint8,
}
# only used when byteswapping data; only the correct element size is needed
_dtype_byteswap_map: dict[torch.dtype, type] = {
torch.float64: np.float64,
torch.float32: np.float32,
torch.bfloat16: np.float16,
torch.float16: np.float16,
torch.int64: np.int64,
torch.uint64: np.uint64,
torch.int32: np.int32,
torch.uint32: np.uint32,
torch.int16: np.int16,
torch.uint16: np.uint16,
torch.int8: np.int8,
torch.uint8: np.uint8,
torch.bool: np.uint8,
torch.float8_e4m3fn: np.uint8,
torch.float8_e5m2: np.uint8,
}
# used for safetensors slices
# ref: https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/src/lib.rs#L1046
# TODO: uncomment U64, U32, and U16, ref: https://github.com/pytorch/pytorch/issues/58734
@ -10104,8 +10123,14 @@ class LazyTorchTensor(gguf.LazyBase):
@classmethod
def from_local_tensor(cls, t: gguf.utility.LocalTensor) -> Tensor:
def load_tensor(tensor: gguf.utility.LocalTensor) -> Tensor:
def byteswap_tensor(tensor: np.ndarray, dtype: type) -> np.ndarray:
if sys.byteorder == 'big':
# switch data back to big endian
tensor = tensor.view(dtype).byteswap(inplace=False)
return tensor
dtype = cls._dtype_str_map[tensor.dtype]
return torch.from_numpy(tensor.mmap_bytes()).view(dtype).reshape(tensor.shape)
numpy_dtype = cls._dtype_byteswap_map[dtype]
return torch.from_numpy(byteswap_tensor(tensor.mmap_bytes(), numpy_dtype)).view(dtype).reshape(tensor.shape)
dtype = cls._dtype_str_map[t.dtype]
shape = t.shape
lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(t,), func=lambda r: load_tensor(r))
@ -10113,10 +10138,16 @@ class LazyTorchTensor(gguf.LazyBase):
@classmethod
def from_remote_tensor(cls, remote_tensor: gguf.utility.RemoteTensor):
def byteswap_tensor(tensor: np.ndarray, dtype: type) -> np.ndarray:
if sys.byteorder == 'big':
# switch data back to big endian
tensor = tensor.view(dtype).byteswap(inplace=False)
return tensor
dtype = cls._dtype_str_map[remote_tensor.dtype]
numpy_dtype = cls._dtype_byteswap_map[dtype]
shape = remote_tensor.shape
meta = cls.meta_with_dtype_and_shape(dtype, shape)
lazy = cls(meta=meta, args=(remote_tensor,), func=lambda r: torch.frombuffer(r.data(), dtype=dtype).reshape(shape))
lazy = cls(meta=meta, args=(remote_tensor,), func=lambda r: torch.from_numpy(byteswap_tensor(np.frombuffer(r.data(), dtype=numpy_dtype), numpy_dtype)).view(dtype).reshape(shape))
return cast(torch.Tensor, lazy)
@classmethod
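
A minimal standalone C++ sketch (illustrative only, separate from the converter code above) of why the byteswap is keyed on sys.byteorder and on element width: tensor data is stored little-endian on disk, so a big-endian host must reverse the bytes of each element after loading, and only the element size (2 bytes for f16/bf16) matters for that swap.

#include <cstddef>
#include <cstdint>
#include <cstdio>

// reverse the bytes of every 2-byte element in place (f16 and bf16 both have width 2)
static void byteswap16(uint8_t * data, size_t n_elems) {
    for (size_t i = 0; i < n_elems; i++) {
        uint8_t tmp     = data[2 * i];
        data[2 * i]     = data[2 * i + 1];
        data[2 * i + 1] = tmp;
    }
}

int main() {
    // little-endian f16 bytes for 1.0 (0x3C00) and 2.0 (0x4000), as stored on disk
    uint8_t raw[4] = { 0x00, 0x3C, 0x00, 0x40 };
    // a little-endian host can use these bytes directly; a big-endian host
    // (the case guarded above) must swap each element first
    byteswap16(raw, 2);
    printf("%02x %02x %02x %02x\n", raw[0], raw[1], raw[2], raw[3]); // 3c 00 40 00
    return 0;
}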

View File

@ -242,7 +242,7 @@ def parse_args() -> argparse.Namespace:
help="path to write to; default: based on input. {ftype} will be replaced by the outtype.",
)
parser.add_argument(
"--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "auto"], default="f16",
"--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "auto"], default="f32",
help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type",
)
parser.add_argument(

View File

@ -530,6 +530,7 @@ extern "C" {
GGML_OP_ARANGE,
GGML_OP_TIMESTEP_EMBEDDING,
GGML_OP_ARGSORT,
GGML_OP_TOP_K,
GGML_OP_LEAKY_RELU,
GGML_OP_TRI,
GGML_OP_FILL,
@ -2258,18 +2259,25 @@ extern "C" {
struct ggml_tensor * a,
enum ggml_sort_order order);
// similar to ggml_top_k but implemented as `argsort` + `view`
GGML_API struct ggml_tensor * ggml_argsort_top_k(
struct ggml_context * ctx,
struct ggml_tensor * a,
int k);
// top k elements per row
// note: the resulting top k indices are in no particular order
GGML_API struct ggml_tensor * ggml_top_k(
struct ggml_context * ctx,
struct ggml_tensor * a,
int k);
GGML_API struct ggml_tensor * ggml_arange(
struct ggml_context * ctx,
float start,
float stop,
float step);
// top k elements per row
GGML_API struct ggml_tensor * ggml_top_k(
struct ggml_context * ctx,
struct ggml_tensor * a,
int k);
#define GGML_KQ_MASK_PAD 64
// q: [n_embd_k, n_batch, n_head, ne3 ]
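
A minimal standalone sketch (plain C++, not the ggml API) of the contract difference between ggml_argsort_top_k and ggml_top_k declared above: the argsort-based variant yields the k best indices in sorted order, while the dedicated top-k only guarantees the set of k best indices, in no particular order.

#include <algorithm>
#include <cstdio>
#include <numeric>
#include <vector>

int main() {
    std::vector<float> x = { 0.1f, 0.9f, 0.4f, 0.7f };
    const int k = 2;
    auto better = [&](int a, int b) { return x[a] > x[b]; };

    std::vector<int> sorted_topk(x.size());
    std::iota(sorted_topk.begin(), sorted_topk.end(), 0);
    // like ggml_argsort_top_k: the first k indices come out in descending value order
    std::partial_sort(sorted_topk.begin(), sorted_topk.begin() + k, sorted_topk.end(), better);

    std::vector<int> unordered_topk(x.size());
    std::iota(unordered_topk.begin(), unordered_topk.end(), 0);
    // like ggml_top_k: the first k indices are the k best, in unspecified order
    std::nth_element(unordered_topk.begin(), unordered_topk.begin() + (k - 1), unordered_topk.end(), better);

    printf("sorted:    %d %d\n", sorted_topk[0], sorted_topk[1]);       // 1 3
    printf("unordered: %d %d\n", unordered_topk[0], unordered_topk[1]); // same set, any order
    return 0;
}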

View File

@ -42,6 +42,7 @@
#include <aclnnop/aclnn_exp.h>
#include <aclnnop/aclnn_fill_scalar.h>
#include <aclnnop/aclnn_fused_infer_attention_score_v2.h>
#include <aclnnop/aclnn_ger.h>
#include <aclnnop/aclnn_group_norm.h>
#include <aclnnop/aclnn_grouped_matmul_v3.h>
#include <aclnnop/aclnn_gt_scalar.h>
@ -2206,78 +2207,120 @@ static void aclnn_index_fill_tensor(ggml_backend_cann_context & ctx,
}
/**
* @brief Initializes and caches sine/cosine positional encoding values
* (used in RoPE, Rotary Position Embedding) for attention layers.
* @brief Initializes and caches all intermediate tensors required for RoPE
* (Rotary Position Embedding), including support for Yarn, mRoPE,
* i-mRoPE, Neox repeat strategy, independent sectors, frequency factors
* and multi-section rotary groups.
*
* This function computes and caches the sin/cos values of
* θ = position * theta_scale for RoPE encoding. The cache is shared
* across attention layers, and only the first attention layer will
* trigger initialization. The cache includes repeated sin/cos values
* with different repeat methods depending on the @param is_neox flag.
* This function computes and caches the per-dimension θ coefficients used for
* Q/K rotary embedding. The cache is shared across layers, and recomputed only
* when any dependent parameter changes.
*
* Steps performed by this function:
* 1. Identify whether the target tensor belongs to Q/K in attention
* and restrict computation to the first layer only.
* 2. Initialize the theta scale array (arange power freq scaling).
* 3. Allocate sin/cos caches if the max prompt length increases.
* 4. Compute θ = position * theta_scale.
* 5. Compute sin(θ), cos(θ) and optionally scale by attn_factor.
* 6. Expand sin/cos values by repeat or repeat_interleave depending
* on whether @param is_neox is enabled.
* The function now supports:
* - Yarn RoPE extrapolation (via @param corr_dims and @param ext_factor)
* - Per-dimension independent sector exponent rules (indep_sects + sections[])
* - Multi-section RoPE (mRoPE) index mapping (mrope_used + is_imrope)
* - Frequency factor division (src2)
* - Neox / normal repeat expansion modes
*
* @param ctx The CANN backend context, holding memory pool,
* stream, and persistent buffers for rope init/cache.
* @param dst The destination ggml_tensor whose computation
* depends on the RoPE values (usually Qcur/Kcur).
* @param theta_scale Scalar exponent base for computing theta scale values.
* @param freq_scale Frequency scaling factor, applied to theta scale.
* @param attn_factor Attention scaling factor, applied to sin/cos.
* @param is_neox Whether to use Neox-style repeat strategy
* (dim expansion vs repeat_interleave).
* @param ctx CANN backend context, containing memory pool,
* cached buffers, and runtime stream.
* @param dst Destination ggml_tensor whose computation
* depends on RoPE (typically Qcur or Kcur).
* @param corr_dims [low, high] Yarn correction range.
* @param ext_factor Yarn extrapolation strength. 0 = disabled.
* @param theta_scale Base multiplier for per-dimension θ exponent.
* @param freq_scale Global frequency scaling factor.
* @param attn_factor Optional scaling applied to sin/cos (if needed).
* @param is_neox Whether to use Neox-style dimension interleave.
* @param sections 4-way sector sizes for independent-section RoPE
* and multi-section mRoPE (t/h/w/e).
* @param mrope_used Whether to enable multi-section rotary embedding.
* @param is_imrope Whether to apply interleaved mRoPE rules.
* @param indep_sects Whether each dimension runs independent exponent
* resets based on @p sections.
*/
static void aclnn_cache_init(ggml_backend_cann_context & ctx,
static void aclnn_rope_cache_init(ggml_backend_cann_context & ctx,
ggml_tensor * dst,
float * corr_dims,
float ext_factor,
float theta_scale,
float freq_scale,
float attn_factor,
bool is_neox) {
bool is_neox,
int sections[4],
bool mrope_used,
bool is_imrope,
bool indep_sects) {
ggml_tensor * src0 = dst->src[0]; // input
ggml_tensor * src1 = dst->src[1]; // position
ggml_tensor * src2 = dst->src[2]; // freq_factors
if (src2 == nullptr && ctx.rope_cache.cached && ctx.rope_cache.ext_factor == ext_factor &&
ctx.rope_cache.theta_scale == theta_scale && ctx.rope_cache.freq_scale == freq_scale &&
ctx.rope_cache.attn_factor == attn_factor && ctx.rope_cache.is_neox == is_neox) {
int64_t theta_scale_length = src0->ne[0] / 2;
int64_t position_length = dst->ne[2];
// TODO: check theta_scale_length and position_length.
if (src2 == nullptr && ctx.rope_cache.cached &&
ctx.rope_cache.equal(theta_scale_length, position_length, ext_factor, theta_scale, freq_scale, attn_factor,
is_neox, indep_sects, mrope_used, is_imrope, sections)) {
// use cache.
return;
}
int64_t theta_scale_length = src0->ne[0] / 2;
// Step0: calculate tensor shape.
int64_t theta_scale_ne[] = { theta_scale_length, 1, 1, 1 };
size_t theta_scale_nb[] = { sizeof(float), sizeof(float), sizeof(float), theta_scale_length * sizeof(float) };
size_t theta_scale_nb[] = { sizeof(float), theta_scale_length * sizeof(float), theta_scale_length * sizeof(float),
theta_scale_length * sizeof(float) };
GGML_ASSERT(src1->type == GGML_TYPE_I32);
int64_t position_length = src1->ne[0];
int64_t position_ne[] = { 1, 1, position_length, 1 };
size_t position_nb[] = { sizeof(int32_t), sizeof(int32_t), sizeof(int32_t), sizeof(int32_t) * position_length };
int64_t theta_ne[] = { theta_scale_length, 1, position_length, 1 };
size_t theta_nb[GGML_MAX_DIMS];
theta_nb[0] = sizeof(float);
int64_t cache_ne[] = { theta_scale_length, 1, position_length, 1 };
size_t cache_nb[GGML_MAX_DIMS];
cache_nb[0] = sizeof(float);
for (int i = 1; i < GGML_MAX_DIMS; i++) {
theta_nb[i] = theta_nb[i - 1] * theta_ne[i - 1];
cache_nb[i] = cache_nb[i - 1] * cache_ne[i - 1];
}
// theta_scale arange, [0,1,...,ne00/2 - 1]
// Step1: Compute the coefficient of theta. During the cache_init process, aside from
// (1) multiplying by the position,
// (2) dividing by freq_factors,
// (3) computing the sine and cosine,
// the other parameters used in the computation generally do not change in most scenarios.
// Therefore, we can first compute this part of the result and then cache it.
// Step 1.1: prepare the theta_scale exponent; if this exponent changes, theta_scale_tensor must be updated as well.
acl_tensor_ptr acl_theta_scale_tensor;
// cache theta scale
if (ctx.rope_cache.theta_scale_length != theta_scale_length ||
// theta_scale and freq_scale should not change during the current token inference process,
// so we can directly use == here instead of comparing the absolute difference.
ctx.rope_cache.theta_scale != theta_scale || ctx.rope_cache.freq_scale != freq_scale) {
ctx.rope_cache.theta_scale_length = theta_scale_length;
bool theta_scale_updated = false;
if (ctx.rope_cache.theta_scale_length != theta_scale_length || ctx.rope_cache.theta_scale != theta_scale ||
ctx.rope_cache.indep_sects != indep_sects) {
theta_scale_updated = true;
if (ctx.rope_cache.theta_scale_exp_host != nullptr) {
free(ctx.rope_cache.theta_scale_exp_host);
}
ctx.rope_cache.theta_scale_exp_host = (float *) malloc(theta_scale_length * sizeof(float));
GGML_ASSERT(ctx.rope_cache.theta_scale_exp_host != nullptr);
if (!indep_sects) {
ctx.rope_cache.theta_scale_exp_host[0] = 1;
for (int i = 1; i < theta_scale_length; i++) {
ctx.rope_cache.theta_scale_exp_host[i] = ctx.rope_cache.theta_scale_exp_host[i - 1] * theta_scale;
}
} else {
int sect_dims = sections[0] + sections[1] + sections[2] + sections[3];
int sec_w = sections[1] + sections[0];
int sec_e = sections[2] + sec_w;
ctx.rope_cache.theta_scale_exp_host[0] = 1;
for (int i = 1; i < theta_scale_length; i++) {
int sector = i % sect_dims;
if (sector == 0 || sector == sections[0] || sector == sec_w || sector == sec_e) {
ctx.rope_cache.theta_scale_exp_host[i] = 1;
continue;
}
ctx.rope_cache.theta_scale_exp_host[i] = ctx.rope_cache.theta_scale_exp_host[i - 1] * theta_scale;
}
}
if (ctx.rope_cache.theta_scale_cache != nullptr) {
ACL_CHECK(aclrtFree(ctx.rope_cache.theta_scale_cache));
@ -2285,18 +2328,23 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx,
ACL_CHECK(aclrtMalloc(&ctx.rope_cache.theta_scale_cache, theta_scale_length * sizeof(float),
ACL_MEM_MALLOC_HUGE_FIRST));
ACL_CHECK(aclrtMemcpyAsync(ctx.rope_cache.theta_scale_cache, theta_scale_length * sizeof(float),
ctx.rope_cache.theta_scale_exp_host, theta_scale_length * sizeof(float),
ACL_MEMCPY_HOST_TO_DEVICE, ctx.stream()));
acl_theta_scale_tensor = ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float),
theta_scale_ne, theta_scale_nb, 1);
}
float start = 0;
float step = 1;
float stop = theta_scale_length;
float n_elements = theta_scale_length;
aclnn_arange(ctx, acl_theta_scale_tensor.get(), start, stop, step, n_elements);
// Step 1.2: prepare rope_yarn_ramp; if this part changes, theta_scale_tensor must be updated as well.
bool yarn_ramp_tensor_updated = false;
ggml_cann_pool_alloc yarn_ramp_allocator(ctx.pool());
acl_tensor_ptr acl_yarn_ramp_tensor;
if (ext_factor != 0) {
if (ext_factor != 0 &&
// TODO: check more parameters.
(ctx.rope_cache.theta_scale_length != theta_scale_length || ctx.rope_cache.freq_scale != freq_scale)) {
yarn_ramp_tensor_updated = true;
// -rope_yarn_ramp
// const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
// return MIN(1, MAX(0, y)) - 1;
@ -2312,8 +2360,8 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx,
acl_scalar_ptr denom_safe = ggml_cann_create_scalar(&denom_safe_value, aclDataType::ACL_FLOAT);
acl_scalar_ptr ext_factor_sc = ggml_cann_create_scalar(&ext_factor, aclDataType::ACL_FLOAT);
GGML_CANN_CALL_ACLNN_OP(ctx, Subs, acl_theta_scale_tensor.get(), low.get(), one.get(),
acl_yarn_ramp_tensor.get());
aclnn_arange(ctx, acl_yarn_ramp_tensor.get(), 0, theta_scale_length, 1, theta_scale_length);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceSubs, acl_yarn_ramp_tensor.get(), low.get(), one.get());
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceDivs, acl_yarn_ramp_tensor.get(), denom_safe.get());
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceThreshold, acl_yarn_ramp_tensor.get(), zero.get(), zero.get());
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceClampMax, acl_yarn_ramp_tensor.get(), one.get());
@ -2335,24 +2383,83 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx,
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdds, acl_yarn_ramp_tensor.get(), freq_scale_sc.get(), one.get());
}
// power
acl_scalar_ptr acl_theta_scale = ggml_cann_create_scalar(&theta_scale, aclDataType::ACL_FLOAT);
GGML_CANN_CALL_ACLNN_OP(ctx, PowScalarTensor, acl_theta_scale.get(), acl_theta_scale_tensor.get(),
acl_theta_scale_tensor.get());
// Step 1.3: update theta_scale_tensor according to ext_factor or freq_scale.
if (ext_factor != 0) {
if (theta_scale_updated || yarn_ramp_tensor_updated) {
theta_scale_updated = true;
aclnn_mul(ctx, acl_theta_scale_tensor.get(), acl_yarn_ramp_tensor.get());
} else if (freq_scale != 1) {
aclnn_muls(ctx, acl_theta_scale_tensor.get(), freq_scale, nullptr, true);
}
} else {
// use cache
if (freq_scale != 1 && (ctx.rope_cache.freq_scale != freq_scale || theta_scale_updated)) {
theta_scale_updated = true;
aclnn_muls(ctx, acl_theta_scale_tensor.get(), freq_scale, nullptr, true);
}
}
// Nothing changed, use cache.
if (!theta_scale_updated) {
acl_theta_scale_tensor = ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float),
theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
}
// Step 1.4: prepare select index if mrope
acl_tensor_ptr position_select_index_tensor;
if (mrope_used) {
if (ctx.rope_cache.sections[0] != sections[0] || ctx.rope_cache.sections[1] != sections[1] ||
ctx.rope_cache.sections[2] != sections[2] || ctx.rope_cache.sections[3] != sections[3] ||
ctx.rope_cache.theta_scale_length != theta_scale_length || ctx.rope_cache.is_imrope != is_imrope) {
if (ctx.rope_cache.position_select_index_host != nullptr) {
free(ctx.rope_cache.position_select_index_host);
}
ctx.rope_cache.position_select_index_host = (int *) malloc(theta_scale_length * sizeof(int));
GGML_ASSERT(ctx.rope_cache.position_select_index_host != nullptr);
int sect_dims = sections[0] + sections[1] + sections[2] + sections[3];
int sec_w = sections[1] + sections[0];
int sec_e = sections[2] + sec_w;
// t,h,w,e
for (int i = 0; i < theta_scale_length; i++) {
int sector = i % sect_dims;
if (is_imrope) { // qwen3vl apply interleaved mrope
if (sector % 3 == 1 && sector < 3 * sections[1]) {
ctx.rope_cache.position_select_index_host[i] = 1;
} else if (sector % 3 == 2 && sector < 3 * sections[2]) {
ctx.rope_cache.position_select_index_host[i] = 2;
} else if (sector % 3 == 0 && sector < 3 * sections[0]) {
ctx.rope_cache.position_select_index_host[i] = 0;
} else {
ctx.rope_cache.position_select_index_host[i] = 3;
}
} else {
if (sector >= sections[0] && sector < sec_w) {
ctx.rope_cache.position_select_index_host[i] = 1;
} else if (sector >= sec_w && sector < sec_e) {
ctx.rope_cache.position_select_index_host[i] = 2;
} else if (sector >= sec_e) {
ctx.rope_cache.position_select_index_host[i] = 3;
} else {
ctx.rope_cache.position_select_index_host[i] = 0;
}
}
}
if (ctx.rope_cache.position_select_index != nullptr) {
ACL_CHECK(aclrtFree(ctx.rope_cache.position_select_index));
}
ACL_CHECK(aclrtMalloc(&ctx.rope_cache.position_select_index, theta_scale_length * sizeof(int),
ACL_MEM_MALLOC_HUGE_FIRST));
ACL_CHECK(aclrtMemcpyAsync(ctx.rope_cache.position_select_index, theta_scale_length * sizeof(int),
ctx.rope_cache.position_select_index_host, theta_scale_length * sizeof(int),
ACL_MEMCPY_HOST_TO_DEVICE, ctx.stream()));
}
position_select_index_tensor = ggml_cann_create_tensor(ctx.rope_cache.position_select_index, ACL_INT32,
sizeof(int), theta_scale_ne, theta_scale_nb, 1);
}
// Step2: divide by freq_factors
ggml_cann_pool_alloc freq_fac_res_allocator(ctx.pool());
// freq_factors
if (src2) {
freq_fac_res_allocator.alloc(theta_scale_length * sizeof(float));
void * freq_fac_res_ptr = freq_fac_res_allocator.get();
@ -2365,6 +2472,85 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx,
std::swap(acl_theta_scale_tensor, acl_freq_fac_res_tensor);
}
// Step3: prepare position_tensor
acl_tensor_ptr acl_position_tensor;
ggml_cann_pool_alloc mrope_position_acllocator(ctx.pool());
if (mrope_used) {
// Step 3.1: select the current position rows;
// position:
// pos1: [[0, 1, 2, 3 ],
// pos2:  [4, 5, 6, 7 ],
// pos3:  [8, 9, 10,11],
// pos4:  [12,13,14,15]]
//
// select index = [0, 1, 2, 2, 1, 0]
//
// selected_tensor:
// [[0, 1, 2, 3 ],
//  [4, 5, 6, 7 ],
//  [8, 9, 10,11],
//  [8, 9, 10,11],
//  [4, 5, 6, 7 ],
//  [0, 1, 2, 3 ]]
//
// transpose, from [seq_len:dims] to [dims:seq_len]:
// [[0, 4, 8, 8, 4, 0],
//  [1, 5, 9, 9, 5, 1],
//  [2, 6, 10,10, 6, 2],
//  [3, 7, 11,11, 7, 3]]
//
// multiply by theta_scale_tensor:
// [theta_scale^0, theta_scale^1, ..., theta_scale^n]
int64_t mrope_position_ne[] = { position_length, 4 };
size_t mrope_position_nb[] = { sizeof(int), position_length * sizeof(int) };
acl_tensor_ptr mrope_position =
ggml_cann_create_tensor(src1->data, ggml_cann_type_mapping(src1->type), ggml_type_size(src1->type),
mrope_position_ne, mrope_position_nb, 2);
// selected position tensor's shape is a transpose of cache tensor.
int64_t selected_position_ne[] = { position_length, theta_scale_length };
size_t selected_position_nb[] = { sizeof(float), position_length * sizeof(float) };
mrope_position_acllocator.alloc(theta_scale_length * position_length * sizeof(float));
void * mrope_position_buffer = mrope_position_acllocator.get();
acl_position_tensor =
ggml_cann_create_tensor(mrope_position_buffer, ggml_cann_type_mapping(src1->type),
ggml_type_size(src1->type), selected_position_ne, selected_position_nb, 2);
GGML_CANN_CALL_ACLNN_OP(ctx, IndexSelect, mrope_position.get(), 0, position_select_index_tensor.get(),
acl_position_tensor.get());
// transpose
int64_t transposed_ne[] = { position_length, 1, theta_scale_length, 1 };
size_t transposed_nb[GGML_MAX_DIMS];
transposed_nb[0] = sizeof(float);
for (int i = 1; i < GGML_MAX_DIMS; i++) {
transposed_nb[i] = transposed_nb[i - 1] * transposed_ne[i - 1];
}
std::swap(transposed_ne[0], transposed_ne[2]);
std::swap(transposed_nb[0], transposed_nb[2]);
acl_position_tensor =
ggml_cann_create_tensor(mrope_position_buffer, ggml_cann_type_mapping(src1->type),
ggml_type_size(src1->type), transposed_ne, transposed_nb, GGML_MAX_DIMS);
} else {
// auto bcast.
acl_position_tensor =
ggml_cann_create_tensor(src1->data, ggml_cann_type_mapping(src1->type), ggml_type_size(src1->type),
position_ne, position_nb, GGML_MAX_DIMS);
}
// Step4: multiply by the position
int64_t theta_length = theta_scale_length * position_length;
ggml_cann_pool_alloc theta_allocator(ctx.pool(), theta_length * sizeof(float));
void * theta_buffer = theta_allocator.get();
acl_tensor_ptr acl_theta_tensor =
ggml_cann_create_tensor(theta_buffer, ACL_FLOAT, sizeof(float), cache_ne, cache_nb, GGML_MAX_DIMS);
aclnn_mul(ctx, acl_position_tensor.get(), acl_theta_scale_tensor.get(), acl_theta_tensor.get());
// Step5: calculate sin cos.
// init sin_repeat && cos_repeat, only to accelerate first layer on each device
if (position_length > ctx.rope_cache.position_length) {
ctx.rope_cache.position_length = position_length;
@ -2381,44 +2567,30 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx,
aclrtMalloc(&ctx.rope_cache.cos_cache, repeat_theta_length * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST));
}
// position
acl_tensor_ptr acl_position_tensor =
ggml_cann_create_tensor(src1->data, ggml_cann_type_mapping(src1->type), ggml_type_size(src1->type), position_ne,
position_nb, GGML_MAX_DIMS);
// power * position
int64_t theta_length = theta_scale_length * position_length;
ggml_cann_pool_alloc theta_allocator(ctx.pool(), theta_length * sizeof(float));
void * theta_buffer = theta_allocator.get();
acl_tensor_ptr acl_theta_tensor =
ggml_cann_create_tensor(theta_buffer, ACL_FLOAT, sizeof(float), theta_ne, theta_nb, GGML_MAX_DIMS);
aclnn_mul(ctx, acl_position_tensor.get(), acl_theta_scale_tensor.get(), acl_theta_tensor.get());
// sin/cos
ggml_cann_pool_alloc sin_allocator(ctx.pool(), theta_length * sizeof(float));
void * sin_buffer = sin_allocator.get();
acl_tensor_ptr acl_sin_tensor =
ggml_cann_create_tensor(sin_buffer, ACL_FLOAT, sizeof(float), theta_ne, theta_nb, GGML_MAX_DIMS, ACL_FORMAT_ND);
ggml_cann_create_tensor(sin_buffer, ACL_FLOAT, sizeof(float), cache_ne, cache_nb, GGML_MAX_DIMS, ACL_FORMAT_ND);
aclnn_sin(ctx, acl_theta_tensor.get(), acl_sin_tensor.get());
ggml_cann_pool_alloc cos_allocator(ctx.pool(), theta_length * sizeof(float));
void * cos_buffer = cos_allocator.get();
acl_tensor_ptr acl_cos_tensor =
ggml_cann_create_tensor(cos_buffer, ACL_FLOAT, sizeof(float), theta_ne, theta_nb, GGML_MAX_DIMS, ACL_FORMAT_ND);
ggml_cann_create_tensor(cos_buffer, ACL_FLOAT, sizeof(float), cache_ne, cache_nb, GGML_MAX_DIMS, ACL_FORMAT_ND);
aclnn_cos(ctx, acl_theta_tensor.get(), acl_cos_tensor.get());
if (ext_factor != 0) {
attn_factor *= 1.0f + 0.1f * logf(1.0f / freq_scale);
}
// attn_factor
// Step 5: multiply by attn_factor
if (attn_factor != 1) {
aclnn_muls(ctx, acl_sin_tensor.get(), attn_factor, nullptr, true);
aclnn_muls(ctx, acl_cos_tensor.get(), attn_factor, nullptr, true);
}
int64_t sin_reshape_ne[4] = { src0->ne[0], 1, src0->ne[2], 1 };
int64_t sin_reshape_ne[4] = { src0->ne[0], 1, dst->ne[2], 1 };
size_t sin_reshape_nb[GGML_MAX_DIMS];
sin_reshape_nb[0] = sizeof(float);
for (int i = 1; i < GGML_MAX_DIMS; i++) {
@ -2429,8 +2601,9 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx,
acl_tensor_ptr acl_cos_repeat_tensor = ggml_cann_create_tensor(ctx.rope_cache.cos_cache, ACL_FLOAT, sizeof(float),
sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
// repeat
// Step 6: repeat
if (is_neox) {
// [sinθ1, sinθ1, sinθ2, sinθ2, ..., sinθn, sinθn]
int64_t repeatsArray[] = { 1, 1, 1, 2 };
aclnn_repeat(ctx, acl_sin_tensor.get(), acl_sin_repeat_tensor.get(), repeatsArray);
aclnn_repeat(ctx, acl_cos_tensor.get(), acl_cos_repeat_tensor.get(), repeatsArray);
@ -2438,17 +2611,15 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx,
int64_t num_repeats = 2;
int64_t dim = 3;
int64_t output_size = theta_scale_length * num_repeats;
// [sinθ1, sinθ2, ..., sinθn, sinθ1, sinθ2, ..., sinθn]
aclnn_repeat_interleave(ctx, acl_sin_tensor.get(), acl_sin_repeat_tensor.get(), dim, num_repeats, output_size);
aclnn_repeat_interleave(ctx, acl_cos_tensor.get(), acl_cos_repeat_tensor.get(), dim, num_repeats, output_size);
}
// Other layers use cache except first layer.
// Update cached value.
ctx.rope_cache.cached = true;
ctx.rope_cache.ext_factor = ext_factor;
ctx.rope_cache.theta_scale = theta_scale;
ctx.rope_cache.freq_scale = freq_scale;
ctx.rope_cache.attn_factor = attn_factor;
ctx.rope_cache.is_neox = is_neox;
ctx.rope_cache.set(theta_scale_length, position_length, ext_factor, theta_scale, freq_scale, attn_factor, is_neox,
indep_sects, mrope_used, is_imrope, sections);
}
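
A minimal standalone sketch (plain C++, no ACL calls; the sections values are made up) of the per-dimension theta coefficients that aclnn_rope_cache_init precomputes on the host, including the exponent reset at section boundaries used when indep_sects is enabled:

#include <cstdio>
#include <vector>

int main() {
    const int   half_dims   = 8;              // src0->ne[0] / 2
    const float theta_scale = 0.5f;           // illustrative value only
    const int   sections[4] = { 3, 3, 2, 0 }; // hypothetical t/h/w/e sector sizes
    const bool  indep_sects = true;

    const int sect_dims = sections[0] + sections[1] + sections[2] + sections[3];
    const int sec_w     = sections[0] + sections[1];
    const int sec_e     = sec_w + sections[2];

    std::vector<float> coef(half_dims);
    coef[0] = 1.0f;
    for (int i = 1; i < half_dims; i++) {
        const int sector = i % sect_dims;
        // with independent sectors the exponent restarts at every section boundary
        if (indep_sects && (sector == 0 || sector == sections[0] || sector == sec_w || sector == sec_e)) {
            coef[i] = 1.0f;
        } else {
            coef[i] = coef[i - 1] * theta_scale;
        }
    }

    // theta for a position p and dimension i is then p * freq_scale * coef[i];
    // the sin/cos of those values is what the device-side cache stores for reuse.
    for (int i = 0; i < half_dims; i++) {
        printf("%g ", coef[i]); // 1 0.5 0.25 1 0.5 0.25 1 0.5
    }
    printf("\n");
    return 0;
}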
#ifdef __cplusplus
@ -2474,6 +2645,7 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
// param
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
int sections[4];
// const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2];
@ -2488,6 +2660,7 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
memcpy(&sections, (int32_t *) dst->op_params + 11, sizeof(int)*4);
// TODO: n_dims <= ne0
GGML_ASSERT(n_dims == ne0);
@ -2498,10 +2671,25 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
float corr_dims[2];
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE; // qwen3vl apply interleaved mrope
const bool mrope_used = mode & GGML_ROPE_TYPE_MROPE; // ggml_rope_multi, note: also true for vision (24 & 8 == true) and for imrope
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
if (mrope_used) {
GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
}
if (is_vision) {
GGML_ASSERT(n_dims == ne0/2);
}
if (is_imrope || mrope_used) {
is_neox = true;
}
// init ctx.rope_cos/rope_sin cache
aclnn_cache_init(ctx, dst, corr_dims, ext_factor, theta_scale, freq_scale, attn_factor, is_neox);
aclnn_rope_cache_init(ctx, dst, corr_dims, ext_factor, theta_scale, freq_scale, attn_factor, is_neox, sections, mrope_used, is_imrope, is_vision);
int64_t sin_reshape_ne[4] = { ne00, 1, ne02, 1 };
size_t sin_reshape_nb[GGML_MAX_DIMS];
@ -2657,8 +2845,7 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
return;
#endif
// ggml_mode = 0 --> aclnn_model = 1
int64_t acl_mode = mode == 0 ? 1 : mode;
int64_t acl_mode = is_neox ? 0 : 1;
switch (src0->type) {
case GGML_TYPE_F32:
@ -3236,3 +3423,64 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context & ctx, ggml_tensor * dst
GGML_ABORT("Function is not implemented.");
}
}
static void ggml_cann_out_prod_fp(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
ggml_tensor * src0 = dst->src[0]; // weight
ggml_tensor * src1 = dst->src[1]; // input
GGML_TENSOR_BINARY_OP_LOCALS
acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, acl_dst.get());
const int64_t dps2 = ne2 / ne02;
const int64_t dps3 = ne3 / ne03;
for (int64_t i3 = 0; i3 < ne3; i3++) {
for (int64_t i2 = 0; i2 < ne2; i2++) {
const int64_t i02 = i2 / dps2;
const int64_t i03 = i3 / dps3;
const int64_t i12 = i2;
const int64_t i13 = i3;
acl_tensor_ptr accumulator =
ggml_cann_create_tensor((char *) dst->data + i2 * nb2 + i3 * nb3, ggml_cann_type_mapping(dst->type),
ggml_type_size(dst->type), dst->ne, dst->nb, 2);
// The outer product needs to be accumulated in this dimension.
for (int64_t i1 = 0; i1 < ne11; i1++) {
acl_tensor_ptr acl_input = ggml_cann_create_tensor(
(char *) src1->data + i1 * nb11 + i12 * nb12 + i13 * nb13, ggml_cann_type_mapping(src0->type),
ggml_type_size(src0->type), src1->ne, src1->nb, 1);
acl_tensor_ptr acl_weight = ggml_cann_create_tensor(
(char *) src0->data + i1 * nb01 + i02 * nb02 + i03 * nb03, ggml_cann_type_mapping(src0->type),
ggml_type_size(src0->type), src0->ne, src0->nb, 1);
ggml_cann_pool_alloc output_allocator(ctx.pool());
void * output_buffer = output_allocator.alloc(ggml_nbytes(dst));
acl_tensor_ptr acl_out = ggml_cann_create_tensor(output_buffer, ggml_cann_type_mapping(dst->type),
ggml_type_size(dst->type), dst->ne, dst->nb, 2);
GGML_CANN_CALL_ACLNN_OP(ctx, Ger, acl_input.get(), acl_weight.get(), acl_out.get());
float alpha_value = 1.0f;
aclScalar * alpha = aclCreateScalar(&alpha_value, ACL_FLOAT);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdd, accumulator.get(), acl_out.get(), alpha);
}
}
}
}
void ggml_cann_out_prod(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
ggml_tensor * src0 = dst->src[0];
const enum ggml_type type = src0->type;
switch (type) {
case GGML_TYPE_F32:
case GGML_TYPE_F16:
ggml_cann_out_prod_fp(ctx, dst);
break;
default:
GGML_ABORT("Unsupported type for GGML_OP_OUT_PROD");
break;
}
}

View File

@ -1125,3 +1125,23 @@ void ggml_cann_op_unary_gated(std::function<void(ggml_backend_cann_context &, ac
} while (0)
#endif // CANN_ACLNN_OPS
/**
* @brief Performs outer product operation on two ggml tensors using the CANN backend.
*
* @details This function computes the outer product of two input tensors (src0 and src1)
* and stores the result in the destination tensor. The outer product operation is defined as:
* dst[i,j,k,l] = sum_m (src0[i,m,k,l] * src1[j,m,k,l])
*
* The function supports F32 and F16 inputs. For floating-point types, the result
* is built by accumulating per-row outer products (aclnn Ger) over the shared dimension.
*
* The implementation handles 4D tensor broadcasting and batch processing automatically.
*
* @param ctx The CANN backend context for operation execution and memory management.
* @param dst The destination ggml_tensor where the outer product result will be stored.
* The input tensors are assumed to be `dst->src[0]` and `dst->src[1]`.
*
* @see GGML_CANN_CALL_ACLNN_OP for CANN operator invocation
*/
void ggml_cann_out_prod(ggml_backend_cann_context & ctx, ggml_tensor * dst);
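
A minimal reference sketch (plain C++, 2-D case with toy data) of the accumulation described by the formula above, dst[i,j] = sum_m src0[i,m] * src1[j,m]:

#include <cstdio>

int main() {
    const int I = 2, J = 3, M = 2;
    float src0[I][M] = { { 1, 2 }, { 3, 4 } };           // "weight"
    float src1[J][M] = { { 1, 0 }, { 0, 1 }, { 1, 1 } }; // "input"
    float dst[I][J]  = {};

    // accumulate one rank-1 outer product per m, matching the sum over m above
    for (int m = 0; m < M; m++) {
        for (int i = 0; i < I; i++) {
            for (int j = 0; j < J; j++) {
                dst[i][j] += src0[i][m] * src1[j][m];
            }
        }
    }

    for (int i = 0; i < I; i++) {
        printf("%g %g %g\n", dst[i][0], dst[i][1], dst[i][2]); // 1 2 3 / 3 4 7
    }
    return 0;
}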

View File

@ -300,30 +300,92 @@ struct ggml_cann_graph_lru_cache {
struct ggml_cann_rope_cache {
~ggml_cann_rope_cache() {
if (theta_scale_cache != nullptr) {
if (theta_scale_cache) {
ACL_CHECK(aclrtFree(theta_scale_cache));
}
if (sin_cache != nullptr) {
if (sin_cache) {
ACL_CHECK(aclrtFree(sin_cache));
}
if (cos_cache != nullptr) {
if (cos_cache) {
ACL_CHECK(aclrtFree(cos_cache));
}
if (position_select_index) {
ACL_CHECK(aclrtFree(position_select_index));
}
if (theta_scale_exp_host) {
free(theta_scale_exp_host);
}
if (position_select_index_host) {
free(position_select_index_host);
}
}
bool equal(int64_t theta_scale_length,
int64_t position_length,
float ext_factor,
float theta_scale,
float freq_scale,
float attn_factor,
bool is_neox,
bool indep_sects,
bool mrope_used,
bool is_imrope,
int sections[4]) {
return this->theta_scale_length == theta_scale_length && this->position_length == position_length &&
this->ext_factor == ext_factor && this->theta_scale == theta_scale && this->freq_scale == freq_scale &&
this->attn_factor == attn_factor && this->is_neox == is_neox && this->indep_sects == indep_sects &&
this->mrope_used == mrope_used && this->is_imrope == is_imrope && this->sections[0] == sections[0] &&
this->sections[1] == sections[1] && this->sections[2] == sections[2] && this->sections[3] == sections[3];
}
void set(int64_t theta_scale_length,
int64_t position_length,
float ext_factor,
float theta_scale,
float freq_scale,
float attn_factor,
bool is_neox,
bool indep_sects,
bool mrope_used,
bool is_imrope,
int sections[4]) {
this->theta_scale_length = theta_scale_length;
this->position_length = position_length;
this->ext_factor = ext_factor;
this->theta_scale = theta_scale;
this->freq_scale = freq_scale;
this->attn_factor = attn_factor;
this->is_neox = is_neox;
this->indep_sects = indep_sects;
this->mrope_used = mrope_used;
this->is_imrope = is_imrope;
this->sections[0] = sections[0];
this->sections[1] = sections[1];
this->sections[2] = sections[2];
this->sections[3] = sections[3];
}
// memory cache, prepared before inference.
void * theta_scale_cache = nullptr;
int64_t theta_scale_length = 0;
float * theta_scale_exp_host = nullptr;
int * position_select_index_host = nullptr;
void * position_select_index = nullptr;
// sin/cos cache, used only to accelerate first layer on each device
void * sin_cache = nullptr;
void * cos_cache = nullptr;
int64_t position_length = 0;
// Properties to check before reusing the sincos cache
int64_t theta_scale_length = 0;
int64_t position_length = 0;
bool cached = false;
float ext_factor = 0.0f;
float theta_scale = 0.0f;
float freq_scale = 0.0f;
float attn_factor = 0.0f;
bool is_neox = false;
bool indep_sects = false;
bool mrope_used = false;
int sections[4] = { 0, 0, 0, 0 };
bool is_imrope = false;
};
struct ggml_cann_tensor_cache {

View File

@ -1886,6 +1886,9 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct gg
case GGML_OP_FLASH_ATTN_EXT:
ggml_cann_flash_attn_ext(ctx, dst);
break;
case GGML_OP_OUT_PROD:
ggml_cann_out_prod(ctx, dst);
break;
default:
return false;
}
@ -2477,13 +2480,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
return false;
}
const int mode = ((const int32_t *) op->op_params)[2];
if (mode & GGML_ROPE_TYPE_MROPE) {
return false;
}
if (mode & GGML_ROPE_TYPE_VISION) {
return false;
}
if (op->src[0]->ne[0] > 896) {
return false;
}
@ -2563,6 +2559,16 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
case GGML_OP_PAD_REFLECT_1D:
case GGML_OP_COUNT_EQUAL:
return true;
case GGML_OP_OUT_PROD:
{
switch (op->src[0]->type) {
case GGML_TYPE_F16:
case GGML_TYPE_F32:
return true;
default:
return false;
}
}
case GGML_OP_CONV_TRANSPOSE_1D:
// TODO: ((weightL - 1) * dilationW - padLeft)=1336 should not be larger than 255.
return (op->src[0]->ne[0] - 1) <= 255;

View File

@ -224,7 +224,8 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
include(CheckCXXSourceCompiles)
set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS})
set(CMAKE_REQUIRED_FLAGS "${ARCH_FLAGS}")
string(REPLACE ";" " " ARCH_FLAGS_STR "${ARCH_FLAGS}")
set(CMAKE_REQUIRED_FLAGS "${ARCH_FLAGS_STR}")
foreach(feature DOTPROD SVE MATMUL_INT8 FMA FP16_VECTOR_ARITHMETIC SME)
set(ARM_FEATURE "HAVE_${feature}")
check_cxx_source_compiles(

View File

@ -1927,6 +1927,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_argsort(params, tensor);
} break;
case GGML_OP_TOP_K:
{
ggml_compute_forward_top_k(params, tensor);
} break;
case GGML_OP_LEAKY_RELU:
{
ggml_compute_forward_leaky_relu(params, tensor);
@ -2311,6 +2315,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_OP_ARANGE:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_ARGSORT:
case GGML_OP_TOP_K:
case GGML_OP_FLASH_ATTN_EXT:
case GGML_OP_FLASH_ATTN_BACK:
case GGML_OP_SSM_CONV:
@ -2834,6 +2839,10 @@ struct ggml_cplan ggml_graph_plan(
cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02*ne03;
cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12;
} break;
case GGML_OP_TOP_K:
{
cur += sizeof(int32_t)*node->src[0]->ne[0]*n_tasks;
} break;
case GGML_OP_FLASH_ATTN_EXT:
{
const int64_t ne10 = node->src[1]->ne[0]; // DK

View File

@ -7794,7 +7794,7 @@ void ggml_compute_forward_timestep_embedding(
// ggml_compute_forward_argsort
template<enum ggml_sort_order order>
struct argsort_cmp {
struct cmp_argsort {
const float * data;
bool operator()(int32_t a, int32_t b) const {
if constexpr (order == GGML_SORT_ORDER_ASC) {
@ -7833,11 +7833,11 @@ static void ggml_compute_forward_argsort_f32(
switch (order) {
case GGML_SORT_ORDER_ASC:
std::sort(dst_data, dst_data + ne0, argsort_cmp<GGML_SORT_ORDER_ASC>{src_data});
std::sort(dst_data, dst_data + ne0, cmp_argsort<GGML_SORT_ORDER_ASC>{src_data});
break;
case GGML_SORT_ORDER_DESC:
std::sort(dst_data, dst_data + ne0, argsort_cmp<GGML_SORT_ORDER_DESC>{src_data});
std::sort(dst_data, dst_data + ne0, cmp_argsort<GGML_SORT_ORDER_DESC>{src_data});
break;
default:
@ -7864,6 +7864,72 @@ void ggml_compute_forward_argsort(
}
}
// ggml_compute_forward_top_k
struct cmp_top_k {
const float * data;
bool operator()(int32_t a, int32_t b) const {
return data[a] > data[b];
}
};
static void ggml_compute_forward_top_k_f32(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
GGML_TENSOR_UNARY_OP_LOCALS
GGML_ASSERT(nb0 == sizeof(float));
const int ith = params->ith;
const int nth = params->nth;
const int64_t nr = ggml_nrows(src0);
const int top_k = ne0;
int32_t * tmp = (int32_t *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
for (int64_t i = ith; i < nr; i += nth) {
const float * src_data = (float *)((char *) src0->data + i*nb01);
for (int64_t j = 0; j < ne00; j++) {
tmp[j] = j;
}
std::partial_sort(tmp, tmp + top_k, tmp + ne00, cmp_top_k{src_data});
int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
std::copy(tmp, tmp + top_k, dst_data);
// emphasize that the order is not important
if (top_k > 1) {
std::swap(dst_data[0], dst_data[1]);
}
}
}
void ggml_compute_forward_top_k(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_top_k_f32(params, dst);
} break;
default:
{
GGML_ABORT("fatal error");
}
}
}
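
A minimal standalone sketch (plain C++ with std::thread; it omits the cache-line padding of the real scratch layout) of the work distribution used above: thread ith of nth handles rows ith, ith+nth, ..., and each thread gets its own ne00-sized index buffer inside wdata.

#include <algorithm>
#include <cstdio>
#include <numeric>
#include <thread>
#include <vector>

int main() {
    const int nth = 2, nr = 4, ne00 = 6, k = 3;
    std::vector<float> src(nr * ne00);
    std::iota(src.begin(), src.end(), 0.0f); // toy row-major data
    std::vector<int> dst(nr * k);
    std::vector<int> scratch(nth * ne00);    // one index slab per thread ("wdata")

    auto worker = [&](int ith) {
        int * tmp = scratch.data() + ith * ne00;  // this thread's slab
        for (int i = ith; i < nr; i += nth) {     // striped rows
            const float * row = src.data() + i * ne00;
            std::iota(tmp, tmp + ne00, 0);
            std::partial_sort(tmp, tmp + k, tmp + ne00,
                              [&](int a, int b) { return row[a] > row[b]; });
            std::copy(tmp, tmp + k, dst.data() + i * k);
        }
    };

    std::vector<std::thread> threads;
    for (int t = 0; t < nth; t++) {
        threads.emplace_back(worker, t);
    }
    for (auto & th : threads) {
        th.join();
    }
    for (int i = 0; i < nr; i++) {
        printf("row %d -> %d %d %d\n", i, dst[i * k], dst[i * k + 1], dst[i * k + 2]);
    }
    return 0;
}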
// ggml_compute_forward_flash_attn_ext
static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(

View File

@ -81,6 +81,7 @@ void ggml_compute_forward_roll(const struct ggml_compute_params * params, struct
void ggml_compute_forward_arange(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_timestep_embedding(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_argsort(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_top_k(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_leaky_relu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_fill(const struct ggml_compute_params * params, struct ggml_tensor * dst);

View File

@ -397,15 +397,14 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const
}
inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y, const ggml_fp16_t * GGML_RESTRICT x, const float v) {
#if defined(GGML_SIMD)
#if defined(__ARM_FEATURE_SVE)
#if defined(GGML_SIMD) && defined(__ARM_FEATURE_SVE)
const int sve_register_length = svcntb() * 8;
const int ggml_f16_epr = sve_register_length / 16;
const int ggml_f16_step = 8 * ggml_f16_epr;
GGML_F16x_VEC vx = GGML_F16x_VEC_SET1(v);
const int np= (n & ~(ggml_f16_step - 1));
int np = (n & ~(ggml_f16_step - 1));
svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8;
svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8;
@ -474,14 +473,18 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y,
hy = svmad_f16_x(pg, hx, vx, hy);
svst1_f16(pg, (__fp16 *)(y + np2), hy);
}
#elif defined(__riscv_v_intrinsic)
// todo: RVV impl
// scalar
for (int i = 0; i < n; ++i) {
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v);
np = n;
#elif defined(__riscv_zvfh) // implies __riscv_v_intrinsic
const int np = n;
_Float16 hv = (_Float16)v;
for (int i = 0, avl; i < n; i += avl) {
avl = __riscv_vsetvl_e16m8(n - i);
vfloat16m8_t ax = __riscv_vle16_v_f16m8((const _Float16 *)&x[i], avl);
vfloat16m8_t ay = __riscv_vle16_v_f16m8((_Float16 *)&y[i], avl);
vfloat16m8_t ny = __riscv_vfmadd_vf_f16m8(ax, hv, ay, avl);
__riscv_vse16_v_f16m8((_Float16 *)&y[i], ny, avl);
}
#else
#elif defined(GGML_SIMD)
const int np = (n & ~(GGML_F16_STEP - 1));
GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
@ -498,18 +501,14 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y,
GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
}
}
#else
const int np = 0;
#endif
// leftovers
for (int i = np; i < n; ++i) {
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v);
}
#endif
#else
// scalar
for (int i = 0; i < n; ++i) {
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v);
}
#endif
}
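
A minimal scalar sketch (plain C++; the step of 8 stands in for the SIMD width) of the np / leftover structure used in the function above: the vectorized body covers the largest multiple of the step, the scalar tail covers the rest, and a path that already processed every element simply sets np = n so the tail loop does nothing.

#include <cstdio>

// y[i] += x[i] * v, split into a "vector" body and a scalar tail
static void mad_f32(int n, float * y, const float * x, float v) {
    const int step = 8;                  // stand-in for the vector step
    const int np   = n & ~(step - 1);    // largest multiple of the step
    for (int i = 0; i < np; i += step) { // stand-in for the SIMD body
        for (int j = 0; j < step; j++) {
            y[i + j] += x[i + j] * v;
        }
    }
    for (int i = np; i < n; i++) {       // leftovers handled scalar
        y[i] += x[i] * v;
    }
}

int main() {
    float x[11], y[11];
    for (int i = 0; i < 11; i++) { x[i] = 1.0f; y[i] = (float) i; }
    mad_f32(11, y, x, 0.5f);
    printf("%g %g\n", y[0], y[10]); // 0.5 10.5
    return 0;
}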
// xs and vs are byte strides of x and v

View File

@ -73,34 +73,7 @@ namespace ggml_cuda_mma {
static constexpr int I = I_;
static constexpr int J = J_;
#if defined(GGML_USE_HIP)
#if defined(RDNA4)
static constexpr int ne = I * J / 32;
T x[ne] = {0};
static constexpr __device__ bool supported() {
if (I == 16 && J == 16) return true;
return false;
}
static __device__ __forceinline__ int get_i(const int l) {
if constexpr (I == 16 && J == 16) {
return 8 * (threadIdx.x / 16) + l;
} else {
NO_DEVICE_CODE;
return -1;
}
}
static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 16 && J == 16) {
return threadIdx.x % 16;
} else {
NO_DEVICE_CODE;
return -1;
}
}
#else
#if defined(AMD_MFMA_AVAILABLE)
static constexpr int ne = I * J / 64;
T x[ne] = {0};
@ -146,7 +119,6 @@ namespace ggml_cuda_mma {
return -1;
}
}
#endif // defined(RDNA4)
#elif __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
static constexpr int ne = I * J / 32;
T x[ne] = {0};
@ -177,6 +149,34 @@ namespace ggml_cuda_mma {
return -1;
}
}
#elif defined(AMD_WMMA_AVAILABLE)
#if defined(RDNA4)
static constexpr int ne = I * J / 32;
T x[ne] = {0};
static constexpr __device__ bool supported() {
if (I == 16 && J == 16) return true;
return false;
}
static __device__ __forceinline__ int get_i(const int l) {
if constexpr (I == 16 && J == 16) {
return 8 * (threadIdx.x / 16) + l;
} else {
NO_DEVICE_CODE;
return -1;
}
}
static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 16 && J == 16) {
return threadIdx.x % 16;
} else {
NO_DEVICE_CODE;
return -1;
}
}
#endif
#else
static constexpr int ne = I * J / 32;
T x[ne] = {0};
@ -437,7 +437,29 @@ namespace ggml_cuda_mma {
xi[0] = xs[0];
}
#elif defined(AMD_WMMA_AVAILABLE)
if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
} else if constexpr (std::is_same_v<T, int>) {
if constexpr (I == 16 && J == 4) {
int64_t * xi = (int64_t *) t.x;
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
xi[0] = xs[0];
} else if constexpr (I == 16 && J == 8) {
int64_t * xi = (int64_t *) t.x;
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I));
xi[0] = xs[0];
const int64_t * xs1 = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I) + 2);
xi[1] = xs1[0];
} else {
NO_DEVICE_CODE;
}
} else {
NO_DEVICE_CODE;
}
#else
#pragma unroll
for (int l = 0; l < t.ne; ++l) {
@ -772,6 +794,36 @@ namespace ggml_cuda_mma {
acc[0],
0, 0, 0);
#endif // defined(CDNA3)
#elif defined(AMD_WMMA_AVAILABLE)
using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
int32x2_t * a_vec = (int32x2_t *) A.x;
int32x2_t * b_vec = (int32x2_t *) B.x;
using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
int32x8_t * acc = (int32x8_t *) D.x;
#if defined(RDNA4)
acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
true,
a_vec[0],
true,
b_vec[0],
acc[0],
true
);
acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
true,
a_vec[1],
true,
b_vec[1],
acc[0],
true
);
#endif // defined(RDNA4)
#else
GGML_UNUSED_VARS(D, A, B);
NO_DEVICE_CODE;
@ -798,6 +850,7 @@ namespace ggml_cuda_mma {
acc[0],
0, 0, 0);
#endif // defined(CDNA3)
#else
GGML_UNUSED_VARS(D, A, B);
NO_DEVICE_CODE;
@ -842,4 +895,31 @@ namespace ggml_cuda_mma {
mma(D16[1], A16[1], B);
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
}
static __device__ __forceinline__ void mma(
tile<16, 16, int> & D, const tile<16, 4, int> & A, const tile<16, 4, int> & B) {
#if defined(AMD_WMMA_AVAILABLE)
using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
int32x2_t * a_vec = (int32x2_t *) A.x;
int32x2_t * b_vec = (int32x2_t *) B.x;
using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
int32x8_t * acc = (int32x8_t *) D.x;
acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
true,
a_vec[0],
true,
b_vec[0],
acc[0],
false
);
#else
GGML_UNUSED(D);
GGML_UNUSED(A);
GGML_UNUSED(B);
NO_DEVICE_CODE;
#endif
}
}

View File

@ -306,5 +306,11 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
return false;
}
return (!GGML_CUDA_CC_IS_RDNA4(cc) && !GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
if (amd_wmma_available(cc)) {
if (GGML_CUDA_CC_IS_RDNA4(cc)) {
return true;
}
}
return (!GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
}

File diff suppressed because it is too large

View File

@ -1009,6 +1009,64 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort_merge(ggml_metal_l
return res;
}
// note: reuse the argsort kernel for top_k
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k(ggml_metal_library_t lib, const ggml_tensor * op) {
assert(op->op == GGML_OP_TOP_K);
char base[256];
char name[256];
// note: the top_k kernel always uses descending order
ggml_sort_order order = GGML_SORT_ORDER_DESC;
const char * order_str = "undefined";
switch (order) {
case GGML_SORT_ORDER_ASC: order_str = "asc"; break;
case GGML_SORT_ORDER_DESC: order_str = "desc"; break;
default: GGML_ABORT("fatal error");
};
snprintf(base, 256, "kernel_argsort_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str);
snprintf(name, 256, "%s", base);
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
if (res) {
return res;
}
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
return res;
}
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k_merge(ggml_metal_library_t lib, const ggml_tensor * op) {
assert(op->op == GGML_OP_TOP_K);
char base[256];
char name[256];
ggml_sort_order order = GGML_SORT_ORDER_DESC;
const char * order_str = "undefined";
switch (order) {
case GGML_SORT_ORDER_ASC: order_str = "asc"; break;
case GGML_SORT_ORDER_DESC: order_str = "desc"; break;
default: GGML_ABORT("fatal error");
};
snprintf(base, 256, "kernel_argsort_merge_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str);
snprintf(name, 256, "%s", base);
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
if (res) {
return res;
}
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
return res;
}
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
ggml_metal_library_t lib,
const struct ggml_tensor * op,

View File

@ -128,6 +128,8 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id (ggml_me
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);

View File

@ -905,6 +905,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_OP_LEAKY_RELU:
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_ARGSORT:
case GGML_OP_TOP_K:
case GGML_OP_ARANGE:
return true;
case GGML_OP_FLASH_ATTN_EXT:

View File

@ -832,14 +832,19 @@ typedef struct {
} ggml_metal_kargs_leaky_relu;
typedef struct {
int64_t ne00;
int64_t ne01;
int64_t ne02;
int64_t ne03;
int32_t ne00;
int32_t ne01;
int32_t ne02;
int32_t ne03;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne0;
int32_t ne1;
int32_t ne2;
int32_t ne3;
int32_t top_k;
} ggml_metal_kargs_argsort;
typedef struct {
@ -851,6 +856,11 @@ typedef struct {
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne0;
int32_t ne1;
int32_t ne2;
int32_t ne3;
int32_t top_k;
int32_t len;
} ggml_metal_kargs_argsort_merge;

View File

@ -406,6 +406,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
{
n_fuse = ggml_metal_op_argsort(ctx, idx);
} break;
case GGML_OP_TOP_K:
{
n_fuse = ggml_metal_op_top_k(ctx, idx);
} break;
case GGML_OP_LEAKY_RELU:
{
n_fuse = ggml_metal_op_leaky_relu(ctx, idx);
@ -3686,6 +3690,11 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.ne3 =*/ ne3,
/*.top_k =*/ nth,
};
ggml_metal_encoder_set_pipeline(enc, pipeline);
@ -3705,15 +3714,20 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
ggml_metal_op_concurrency_reset(ctx);
ggml_metal_kargs_argsort_merge args_merge = {
.ne00 = ne00,
.ne01 = ne01,
.ne02 = ne02,
.ne03 = ne03,
.nb00 = nb00,
.nb01 = nb01,
.nb02 = nb02,
.nb03 = nb03,
.len = len,
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.ne3 =*/ ne3,
/*.top_k =*/ ne00,
/*.len =*/ len,
};
// merges per row
@ -3737,6 +3751,118 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
return 1;
}
int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_top_k(lib, op);
// bitonic sort requires the number of elements to be a power of 2
int nth = 1;
while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
nth *= 2;
}
// blocks per row
const int npr = (ne00 + nth - 1)/nth;
const size_t smem = GGML_PAD(nth*sizeof(int32_t), 16);
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
ggml_metal_buffer_id bid_tmp = bid_dst;
bid_tmp.offs += sizeof(int32_t)*ggml_nelements(op->src[0]);
if ((int) ceil(std::log(npr) / std::log(2)) % 2 == 1) {
std::swap(bid_dst, bid_tmp);
}
const int top_k = ne0;
ggml_metal_kargs_argsort args = {
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.ne3 =*/ ne3,
/*.top_k =*/ std::min(nth, top_k), // for each block, keep just the top_k indices
};
if (npr > 1) {
args.ne0 = (npr - 1)*args.top_k + std::min(ne00 - (npr - 1)*nth, args.top_k);
}
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1);
ggml_metal_pipeline_t pipeline_merge = ggml_metal_library_get_pipeline_top_k_merge(lib, op);
int len = args.top_k;
while (len < args.ne0) {
ggml_metal_op_concurrency_reset(ctx);
// merges per row
const int nm = (args.ne0 + 2*len - 1) / (2*len);
const int nth = std::min(512, std::min(len, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_merge)));
ggml_metal_kargs_argsort_merge args_merge = {
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne0 =*/ args.ne0,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.ne3 =*/ ne3,
/*.top_k =*/ nm == 1 ? top_k : args.ne0, // the final merge outputs top_k elements
/*.len =*/ len,
};
ggml_metal_encoder_set_pipeline(enc, pipeline_merge);
ggml_metal_encoder_set_bytes (enc, &args_merge, sizeof(args_merge), 0);
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
ggml_metal_encoder_set_buffer (enc, bid_tmp, 3);
ggml_metal_encoder_dispatch_threadgroups(enc, nm*ne01, ne02, ne03, nth, 1, 1);
std::swap(bid_dst, bid_tmp);
len <<= 1;
}
return 1;
}
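Reviewer note: the scheduling above is easier to follow on the host side. Each of the npr blocks keeps at most min(nth, k) candidates, the merge passes double len until the compacted row (args.ne0) is covered, and the number of passes decides whether the result ends up in dst or in the scratch buffer — which is what the ceil(log2(npr)) parity swap accounts for. A minimal illustrative C++ sketch of that schedule (not part of this patch; it only mirrors the arithmetic above):
// --- illustrative sketch, not part of this patch ---
#include <algorithm>
#include <cstdio>
// ne00: source row length, nth: threadgroup size (power of 2), k: requested top-k
static void top_k_schedule(int ne00, int nth, int k) {
    const int npr  = (ne00 + nth - 1)/nth; // blocks per row
    const int keep = std::min(nth, k);     // candidates kept per block
    // length of the compacted row that the merge passes operate on
    int ne0 = k;
    if (npr > 1) {
        ne0 = (npr - 1)*keep + std::min(ne00 - (npr - 1)*nth, keep);
    }
    // each merge pass doubles the sorted run length and ping-pongs dst <-> tmp,
    // so an odd number of passes requires starting in the scratch buffer
    int passes = 0;
    for (int len = keep; len < ne0; len <<= 1) {
        passes++;
    }
    printf("ne00=%d nth=%d k=%d -> npr=%d ne0=%d merge_passes=%d\n", ne00, nth, k, npr, ne0, passes);
}
int main() {
    top_k_schedule(2049, 1024, 7); // 3 blocks -> compacted row of 15, 2 merge passes
    top_k_schedule(  16,   16, 4); // single block -> no merge pass
    return 0;
}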
int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);

View File

@ -81,6 +81,7 @@ int ggml_metal_op_arange (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx);
int ggml_metal_op_argmax (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_argsort (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_top_k (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_leaky_relu (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_opt_step_adamw (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_opt_step_sgd (ggml_metal_op_t ctx, int idx);

View File

@ -202,6 +202,10 @@ static size_t ggml_backend_metal_buffer_type_get_alloc_size(ggml_backend_buffer_
{
res *= 2;
} break;
case GGML_OP_TOP_K:
{
res = 2*sizeof(int32_t)*ggml_nelements(tensor->src[0]);
} break;
default:
break;
}

View File

@ -4670,8 +4670,9 @@ kernel void kernel_argsort_f32_i32(
ushort3 ntg[[threads_per_threadgroup]]) {
// bitonic sort
const int col = tpitg[0];
const int ib = tgpig[0] / args.ne01;
const int i00 = (tgpig[0]/args.ne01)*ntg.x;
const int i00 = ib*ntg.x;
const int i01 = tgpig[0] % args.ne01;
const int i02 = tgpig[1];
const int i03 = tgpig[2];
@ -4710,9 +4711,11 @@ kernel void kernel_argsort_f32_i32(
}
}
const int64_t i0 = ib*args.top_k;
// copy the result to dst without the padding
if (i00 + col < args.ne00) {
dst += i00 + args.ne00*i01 + args.ne00*args.ne01*i02 + args.ne00*args.ne01*args.ne02*i03;
if (i0 + col < args.ne0 && col < args.top_k) {
dst += i0 + args.ne0*i01 + args.ne0*args.ne1*i02 + args.ne0*args.ne1*args.ne2*i03;
dst[col] = shmem_i32[col];
}
@ -4747,22 +4750,22 @@ kernel void kernel_argsort_merge_f32_i32(
const int start = im * (2 * args.len);
const int len0 = MIN(args.len, MAX(0, args.ne00 - (int)(start)));
const int len1 = MIN(args.len, MAX(0, args.ne00 - (int)(start + args.len)));
const int len0 = MIN(args.len, MAX(0, args.ne0 - (int)(start)));
const int len1 = MIN(args.len, MAX(0, args.ne0 - (int)(start + args.len)));
const int total = len0 + len1;
device const int32_t * tmp0 = tmp + start
+ i01*args.ne00
+ i02*args.ne00*args.ne01
+ i03*args.ne00*args.ne01*args.ne02;
+ i01*args.ne0
+ i02*args.ne0*args.ne01
+ i03*args.ne0*args.ne01*args.ne02;
device const int32_t * tmp1 = tmp0 + args.len;
dst += start
+ i01*args.ne00
+ i02*args.ne00*args.ne01
+ i03*args.ne00*args.ne01*args.ne02;
+ i01*args.top_k
+ i02*args.top_k*args.ne01
+ i03*args.top_k*args.ne01*args.ne02;
device const float * src0_row = (device const float *)(src0
+ args.nb01*i01
@ -4776,7 +4779,11 @@ kernel void kernel_argsort_merge_f32_i32(
const int chunk = (total + ntg.x - 1) / ntg.x;
const int k0 = tpitg.x * chunk;
const int k1 = min(k0 + chunk, total);
const int k1 = MIN(MIN(k0 + chunk, total), args.top_k);
if (k0 >= args.top_k) {
return;
}
if (k0 >= total) {
return;

View File

@ -705,6 +705,7 @@ struct vk_device_struct {
vk_pipeline pipeline_argsort_f32[num_argsort_pipelines];
vk_pipeline pipeline_argsort_large_f32[num_argsort_pipelines];
vk_pipeline pipeline_sum_rows_f32;
vk_pipeline pipeline_cumsum_f32;
vk_pipeline pipeline_argmax_f32;
vk_pipeline pipeline_count_equal_i32;
vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
@ -1629,6 +1630,22 @@ class vk_perf_logger {
timings[name].push_back(time);
return;
}
if (node->op == GGML_OP_FLASH_ATTN_EXT) {
const ggml_tensor * dst = node;
const ggml_tensor * q = node->src[0];
const ggml_tensor * k = node->src[1];
const ggml_tensor * v = node->src[2];
const ggml_tensor * m = node->src[3];
std::stringstream name;
name << ggml_op_name(node->op) <<
" dst(" << dst->ne[0] << "," << dst->ne[1] << "," << dst->ne[2] << "," << dst->ne[3] << "), " <<
" q(" << q->ne[0] << "," << q->ne[1] << "," << q->ne[2] << "," << q->ne[3] << "), " <<
" k(" << k->ne[0] << "," << k->ne[1] << "," << k->ne[2] << "," << k->ne[3] << "), " <<
" v(" << v->ne[0] << "," << v->ne[1] << "," << v->ne[2] << "," << v->ne[3] << "), " <<
" m(" << (m?m->ne[0]:0) << "," << (m?m->ne[1]:0) << "," << (m?m->ne[2]:0) << "," << (m?m->ne[3]:0) << ")";
timings[name.str()].push_back(time);
return;
}
timings[ggml_op_name(node->op)].push_back(time);
}
private:
@ -2485,9 +2502,11 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events
static constexpr uint32_t flash_attention_num_small_rows = 32;
static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
static uint32_t get_fa_scalar_num_large_rows(uint32_t hsv) {
static uint32_t get_fa_scalar_num_large_rows(uint32_t hsk, uint32_t hsv) {
if (hsv >= 192) {
return 2;
} else if ((hsv | hsk) & 8) {
return 4;
} else {
return 8;
}
@ -2519,9 +2538,9 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint3
if ((hsv | hsk) & 8) {
// HSV/HSK not being a multiple of 16 makes D_split smaller, which makes cols_per_iter
// larger, and Bc needs to be >= cols_per_thread. 64 is large enough, 32 is not.
return {get_fa_scalar_num_large_rows(hsv), 64};
return {get_fa_scalar_num_large_rows(hsk, hsv), 64};
} else {
return {get_fa_scalar_num_large_rows(hsv), 32};
return {get_fa_scalar_num_large_rows(hsk, hsv), 32};
}
}
}
@ -3950,6 +3969,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
ggml_vk_create_pipeline(device, device->pipeline_cumsum_f32, "cumsum_f32", cumsum_f32_len, cumsum_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { 128, device->subgroup_size }, 1, true, true, device->subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
#define IM2COL(bda) \
@ -7724,7 +7745,7 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con
// Needs to be kept up to date on shader changes
GGML_UNUSED(hsv);
const uint32_t wg_size = scalar_flash_attention_workgroup_size;
const uint32_t Br = get_fa_scalar_num_large_rows(hsv);
const uint32_t Br = get_fa_scalar_num_large_rows(hsk, hsv);
const uint32_t Bc = scalar_flash_attention_Bc;
const uint32_t tmpsh = wg_size * sizeof(float);
@ -7855,7 +7876,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
case FA_SCALAR:
case FA_COOPMAT1:
// We may switch from coopmat1 to scalar, so use the scalar limit for both
max_gqa = get_fa_scalar_num_large_rows(HSV);
max_gqa = get_fa_scalar_num_large_rows(HSK, HSV);
break;
case FA_COOPMAT2:
max_gqa = get_fa_num_small_rows(FA_COOPMAT2);
@ -8439,6 +8460,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_sum_rows_f32;
}
return nullptr;
case GGML_OP_CUMSUM:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_cumsum_f32;
}
return nullptr;
case GGML_OP_ARGMAX:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
return ctx->device->pipeline_argmax_f32;
@ -8803,6 +8829,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
case GGML_OP_SOFT_MAX:
case GGML_OP_SOFT_MAX_BACK:
case GGML_OP_SUM_ROWS:
case GGML_OP_CUMSUM:
case GGML_OP_MEAN:
case GGML_OP_ARGMAX:
{
@ -10132,6 +10159,11 @@ static void ggml_vk_mean(ggml_backend_vk_context * ctx, vk_context& subctx, cons
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_MEAN, p);
}
static void ggml_vk_cumsum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]);
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_CUMSUM, p);
}
static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], 0.0f, 0.0f });
}
@ -11731,6 +11763,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_SUM_ROWS:
ggml_vk_sum_rows(ctx, compute_ctx, src0, node);
break;
case GGML_OP_CUMSUM:
ggml_vk_cumsum(ctx, compute_ctx, src0, node);
break;
case GGML_OP_MEAN:
ggml_vk_mean(ctx, compute_ctx, src0, node);
@ -13768,6 +13804,15 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:
return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]);
case GGML_OP_CUMSUM:
{
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
auto device = ggml_vk_get_device(ctx->device);
if (device->subgroup_arithmetic && device->subgroup_require_full_support) {
return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]);
}
return false;
}
case GGML_OP_ARGMAX:
case GGML_OP_COUNT_EQUAL:
case GGML_OP_IM2COL:
@ -14418,6 +14463,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
tensor_clone = ggml_sum(ggml_ctx, src_clone[0]);
} else if (tensor->op == GGML_OP_SUM_ROWS) {
tensor_clone = ggml_sum_rows(ggml_ctx, src_clone[0]);
} else if (tensor->op == GGML_OP_CUMSUM) {
tensor_clone = ggml_cumsum(ggml_ctx, src_clone[0]);
} else if (tensor->op == GGML_OP_MEAN) {
tensor_clone = ggml_mean(ggml_ctx, src_clone[0]);
} else if (tensor->op == GGML_OP_ARGMAX) {

View File

@ -0,0 +1,69 @@
#version 450
#include "types.glsl"
#include "sum_rows.glsl"
#extension GL_EXT_control_flow_attributes : enable
#extension GL_KHR_shader_subgroup_arithmetic : enable
#extension GL_KHR_shader_subgroup_basic : enable
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
layout (constant_id = 0) const uint BLOCK_SIZE = 128;
layout (constant_id = 1) const uint SUBGROUP_SIZE = 32;
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
shared FLOAT_TYPE partial[BLOCK_SIZE / SUBGROUP_SIZE];
shared FLOAT_TYPE last_sum;
void main() {
const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
const uint tid = gl_LocalInvocationID.x;
const uint i03 = fastdiv(row, p.ne0_12mp, p.ne0_12L);
const uint i03_offset = i03 * p.ne01*p.ne02;
const uint i02 = fastdiv(row - i03_offset, p.ne0_1mp, p.ne0_1L);
const uint i01 = row - i03_offset - i02*p.ne01;
const uint src_idx = get_aoffset() + i01 * p.nb01 + i02 * p.nb02 + i03 * p.nb03;
const uint dst_idx = get_doffset() + i01 * p.nb11 + i02 * p.nb12 + i03 * p.nb13;
uint subgroup_id = tid / SUBGROUP_SIZE;
if (tid == 0) {
last_sum = 0;
}
uint col = tid;
uint num_iter = CEIL_DIV(p.n_cols, BLOCK_SIZE);
for (int i = 0; i < num_iter; ++i) {
FLOAT_TYPE v = 0;
if (col < p.n_cols) {
v = FLOAT_TYPE(data_a[src_idx + col]);
}
v = subgroupInclusiveAdd(v);
// Store the largest partial sum for each subgroup, then add the partials for all
// lower subgroups and the final partial sum from the previous iteration.
if (gl_SubgroupInvocationID == SUBGROUP_SIZE - 1) {
partial[subgroup_id] = v;
}
barrier();
for (int j = 0; j < subgroup_id; ++j) {
v += partial[j];
}
v += last_sum;
barrier();
if (tid == BLOCK_SIZE - 1) {
last_sum = v;
}
if (col < p.n_cols) {
data_d[dst_idx + col] = D_TYPE(v);
}
col += BLOCK_SIZE;
}
}
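Reviewer note: the new shader computes a row-wise inclusive prefix sum in BLOCK_SIZE-wide chunks — a subgroup inclusive add per chunk, per-subgroup totals combined through shared memory, and last_sum carrying the running total into the next chunk. A sequential C++ reference of the same blocked scan (illustrative only, not part of this patch) makes the carry logic explicit:
// --- illustrative sketch, not part of this patch ---
#include <algorithm>
#include <cstdio>
#include <vector>
// blocked inclusive prefix sum over one row; `last_sum` plays the role of the
// shared variable carried between iterations of the shader loop above
static std::vector<float> cumsum_row(const std::vector<float> & src, size_t block = 128) {
    std::vector<float> dst(src.size());
    float last_sum = 0.0f;
    for (size_t base = 0; base < src.size(); base += block) {
        float running = last_sum;
        const size_t end = std::min(base + block, src.size());
        for (size_t col = base; col < end; ++col) {
            running += src[col]; // inclusive: the element itself is included
            dst[col] = running;
        }
        last_sum = running;      // carried into the next block of columns
    }
    return dst;
}
int main() {
    for (float v : cumsum_row({1, 2, 3, 4, 5}, 2)) {
        printf("%g ", v);        // prints: 1 3 6 10 15
    }
    printf("\n");
    return 0;
}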

View File

@ -1,6 +1,7 @@
#version 450
#include "types.glsl"
#include "sum_rows.glsl"
#extension GL_EXT_control_flow_attributes : enable
@ -11,30 +12,6 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
layout (constant_id = 0) const uint BLOCK_SIZE = 32;
layout (push_constant) uniform parameter
{
uint n_cols;
uint ne01, ne02;
uint nb01, nb02, nb03;
uint nb11, nb12, nb13;
float weight;
uint misalign_offsets;
uint ne0_12mp, ne0_12L;
uint ne0_1mp, ne0_1L;
} p;
uint get_aoffset() { return p.misalign_offsets >> 16; }
uint get_doffset() { return p.misalign_offsets & 0xFFFF; }
// see init_fastdiv_values in ggml-vulkan.cpp
uint fastdiv(uint n, uint mp, uint L) {
uint msbs, lsbs;
// msbs = mulhi(n, mp)
umulExtended(n, mp, msbs, lsbs);
return (msbs + n) >> L;
}
shared FLOAT_TYPE tmp[BLOCK_SIZE];
void main() {

View File

@ -0,0 +1,25 @@
// vk_op_sum_rows_push_constants
layout (push_constant) uniform parameter
{
uint n_cols;
uint ne01, ne02;
uint nb01, nb02, nb03;
uint nb11, nb12, nb13;
float weight;
uint misalign_offsets;
uint ne0_12mp, ne0_12L;
uint ne0_1mp, ne0_1L;
} p;
uint get_aoffset() { return p.misalign_offsets >> 16; }
uint get_doffset() { return p.misalign_offsets & 0xFFFF; }
// see init_fastdiv_values in ggml-vulkan.cpp
uint fastdiv(uint n, uint mp, uint L) {
uint msbs, lsbs;
// msbs = mulhi(n, mp)
umulExtended(n, mp, msbs, lsbs);
return (msbs + n) >> L;
}
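Reviewer note: fastdiv() above divides by a runtime constant using a multiply-high and a shift; the (mp, L) pair is precomputed on the host by init_fastdiv_values. A hedged C++ sketch of the standard construction (illustrative; the real host code may differ in details, and like the shader it relies on n staying well below 2^31 so the 32-bit add cannot overflow):
// --- illustrative sketch, not part of this patch ---
#include <cstdint>
#include <cstdio>
// derive (mp, L) so that n / d == ((mulhi(n, mp) + n) >> L) for the index ranges used
static void init_fastdiv(uint32_t d, uint32_t & mp, uint32_t & L) {
    // L = ceil(log2(d))
    L = 0;
    while (L < 32 && (uint32_t(1) << L) < d) {
        L++;
    }
    mp = uint32_t((uint64_t(1) << 32)*((uint64_t(1) << L) - d)/d + 1);
}
static uint32_t fastdiv(uint32_t n, uint32_t mp, uint32_t L) {
    const uint32_t msbs = uint32_t((uint64_t(n)*mp) >> 32); // mulhi(n, mp)
    return (msbs + n) >> L;
}
int main() {
    uint32_t mp, L;
    init_fastdiv(12, mp, L);
    for (uint32_t n : {0u, 11u, 12u, 1000u, 123456789u}) {
        printf("%u / 12 = %u (fastdiv: %u)\n", n, n/12, fastdiv(n, mp, L));
    }
    return 0;
}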

View File

@ -916,6 +916,7 @@ void process_shaders() {
string_to_spv("argmax_f32", "argmax.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "int"}}));
string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}}));
string_to_spv("cumsum_f32", "cumsum.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
for (std::string dim_str : {"", "_3d"}) {
for (bool bda : {false, true}) {

View File

@ -990,6 +990,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"ARANGE",
"TIMESTEP_EMBEDDING",
"ARGSORT",
"TOP_K",
"LEAKY_RELU",
"TRI",
"FILL",
@ -1023,7 +1024,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"GLU",
};
static_assert(GGML_OP_COUNT == 94, "GGML_OP_COUNT != 94");
static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@ -1098,6 +1099,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"arange(start, stop, step)",
"timestep_embedding(timesteps, dim, max_period)",
"argsort(x)",
"top_k(x)",
"leaky_relu(x)",
"tri(x)",
"fill(x, c)",
@ -1131,7 +1133,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"glu(x)",
};
static_assert(GGML_OP_COUNT == 94, "GGML_OP_COUNT != 94");
static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95");
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@ -5036,28 +5038,6 @@ struct ggml_tensor * ggml_roll(
return result;
}
// ggml_arange
struct ggml_tensor * ggml_arange(
struct ggml_context * ctx,
float start,
float stop,
float step) {
GGML_ASSERT(stop > start);
const int64_t steps = (int64_t) ceilf((stop - start) / step);
struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, steps);
ggml_set_op_params_f32(result, 0, start);
ggml_set_op_params_f32(result, 1, stop);
ggml_set_op_params_f32(result, 2, step);
result->op = GGML_OP_ARANGE;
return result;
}
// ggml_timestep_embedding
struct ggml_tensor * ggml_timestep_embedding(
@ -5139,6 +5119,7 @@ struct ggml_tensor * ggml_argsort(
struct ggml_tensor * a,
enum ggml_sort_order order) {
GGML_ASSERT(a->ne[0] <= INT32_MAX);
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_I32, GGML_MAX_DIMS, a->ne);
ggml_set_op_params_i32(result, 0, (int32_t) order);
@ -5149,6 +5130,24 @@ struct ggml_tensor * ggml_argsort(
return result;
}
// ggml_argsort_top_k
struct ggml_tensor * ggml_argsort_top_k(
struct ggml_context * ctx,
struct ggml_tensor * a,
int k) {
GGML_ASSERT(a->ne[0] >= k);
struct ggml_tensor * result = ggml_argsort(ctx, a, GGML_SORT_ORDER_DESC);
result = ggml_view_4d(ctx, result,
k, result->ne[1], result->ne[2], result->ne[3],
result->nb[1], result->nb[2], result->nb[3],
0);
return result;
}
// ggml_top_k
struct ggml_tensor * ggml_top_k(
@ -5157,12 +5156,32 @@ struct ggml_tensor * ggml_top_k(
int k) {
GGML_ASSERT(a->ne[0] >= k);
struct ggml_tensor * result = ggml_argsort(ctx, a, GGML_SORT_ORDER_DESC);
struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_I32, k, a->ne[1], a->ne[2], a->ne[3]);
result = ggml_view_4d(ctx, result,
k, result->ne[1], result->ne[2], result->ne[3],
result->nb[1], result->nb[2], result->nb[3],
0);
result->op = GGML_OP_TOP_K;
result->src[0] = a;
return result;
}
// ggml_arange
struct ggml_tensor * ggml_arange(
struct ggml_context * ctx,
float start,
float stop,
float step) {
GGML_ASSERT(stop > start);
const int64_t steps = (int64_t) ceilf((stop - start) / step);
struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, steps);
ggml_set_op_params_f32(result, 0, start);
ggml_set_op_params_f32(result, 1, stop);
ggml_set_op_params_f32(result, 2, step);
result->op = GGML_OP_ARANGE;
return result;
}
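Reviewer note: after this change ggml_argsort_top_k() keeps the old construction (a view over a full descending argsort, so the k indices come back ordered by value), while ggml_top_k() emits a dedicated GGML_OP_TOP_K node whose backends only need to return the indices of the k largest values — the new tests compare them as a set. A hedged usage sketch (illustrative only; graph execution is backend-specific and omitted):
// --- illustrative sketch, not part of this patch ---
#include <cstdio>
#include "ggml.h"
int main() {
    ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ nullptr,
        /*.no_alloc   =*/ false,
    };
    ggml_context * ctx = ggml_init(params);
    ggml_tensor * scores = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 128, 4);
    // dedicated op: backends return the indices of the k largest values per row
    ggml_tensor * topk = ggml_top_k(ctx, scores, 8);                // [8, 4], GGML_TYPE_I32
    // argsort-based variant: a k-wide view over a full descending argsort,
    // so the indices are ordered (used e.g. for expert-group selection)
    ggml_tensor * topk_sorted = ggml_argsort_top_k(ctx, scores, 8); // [8, 4], GGML_TYPE_I32
    printf("top_k:         [%lld, %lld]\n", (long long) topk->ne[0],        (long long) topk->ne[1]);
    printf("argsort_top_k: [%lld, %lld]\n", (long long) topk_sorted->ne[0], (long long) topk_sorted->ne[1]);
    ggml_free(ctx);
    return 0;
}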

View File

@ -25,6 +25,20 @@ class Keys:
ALIGNMENT = "general.alignment"
FILE_TYPE = "general.file_type"
# Recommended Sampler Parameters
SAMPLING_SEQUENCE = "general.sampling.sequence"
SAMPLING_TOP_K = "general.sampling.top_k"
SAMPLING_TOP_P = "general.sampling.top_p"
SAMPLING_MIN_P = "general.sampling.min_p"
SAMPLING_XTC_PROBABILITY = "general.sampling.xtc_probability"
SAMPLING_XTC_THRESHOLD = "general.sampling.xtc_threshold"
SAMPLING_TEMP = "general.sampling.temp"
SAMPLING_PENALTY_LAST_N = "general.sampling.penalty_last_n"
SAMPLING_PENALTY_REPEAT = "general.sampling.penalty_repeat"
SAMPLING_MIROSTAT = "general.sampling.mirostat"
SAMPLING_MIROSTAT_TAU = "general.sampling.mirostat_tau"
SAMPLING_MIROSTAT_ETA = "general.sampling.mirostat_eta"
# Authorship Metadata
NAME = "general.name"
AUTHOR = "general.author"

View File

@ -4,6 +4,7 @@ import logging
import os
import shutil
import struct
import sys
import tempfile
from dataclasses import dataclass
from enum import Enum, auto
@ -372,8 +373,10 @@ class GGUFWriter:
self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None,
raw_dtype: GGMLQuantizationType | None = None,
) -> None:
if self.endianess == GGUFEndian.BIG:
tensor.byteswap(inplace=True)
if (self.endianess == GGUFEndian.BIG and sys.byteorder != 'big') or \
(self.endianess == GGUFEndian.LITTLE and sys.byteorder != 'little'):
# Don't byteswap inplace since lazy copies cannot handle it
tensor = tensor.byteswap(inplace=False)
if self.use_temp_file and self.temp_file is None:
fp = tempfile.SpooledTemporaryFile(mode="w+b", max_size=256 * 1024 * 1024)
fp.seek(0)
@ -399,8 +402,10 @@ class GGUFWriter:
raise ValueError(f'Expected output file to contain tensor info or weights, got {self.state}')
assert self.fout is not None
if self.endianess == GGUFEndian.BIG:
tensor.byteswap(inplace=True)
if (self.endianess == GGUFEndian.BIG and sys.byteorder != 'big') or \
(self.endianess == GGUFEndian.LITTLE and sys.byteorder != 'little'):
# Don't byteswap inplace since lazy copies cannot handle it
tensor = tensor.byteswap(inplace=False)
file_id = -1
for i, tensors in enumerate(self.tensors):
@ -496,6 +501,42 @@ class GGUFWriter:
def add_file_type(self, ftype: int) -> None:
self.add_uint32(Keys.General.FILE_TYPE, ftype)
def add_sampling_sequence(self, sequence: str) -> None:
self.add_string(Keys.General.SAMPLING_SEQUENCE, sequence)
def add_sampling_top_k(self, top_k: int) -> None:
self.add_int32(Keys.General.SAMPLING_TOP_K, top_k)
def add_sampling_top_p(self, top_p: float) -> None:
self.add_float32(Keys.General.SAMPLING_TOP_P, top_p)
def add_sampling_min_p(self, min_p: float) -> None:
self.add_float32(Keys.General.SAMPLING_MIN_P, min_p)
def add_sampling_xtc_probability(self, xtc_probability: float) -> None:
self.add_float32(Keys.General.SAMPLING_XTC_PROBABILITY, xtc_probability)
def add_sampling_xtc_threshold(self, xtc_threshold: float) -> None:
self.add_float32(Keys.General.SAMPLING_XTC_THRESHOLD, xtc_threshold)
def add_sampling_temp(self, temp: float) -> None:
self.add_float32(Keys.General.SAMPLING_TEMP, temp)
def add_sampling_penalty_last_n(self, penalty_last_n: int) -> None:
self.add_int32(Keys.General.SAMPLING_PENALTY_LAST_N, penalty_last_n)
def add_sampling_penalty_repeat(self, penalty_repeat: float) -> None:
self.add_float32(Keys.General.SAMPLING_PENALTY_REPEAT, penalty_repeat)
def add_sampling_mirostat(self, mirostat: int) -> None:
self.add_int32(Keys.General.SAMPLING_MIROSTAT, mirostat)
def add_sampling_mirostat_tau(self, mirostat_tau: float) -> None:
self.add_float32(Keys.General.SAMPLING_MIROSTAT_TAU, mirostat_tau)
def add_sampling_mirostat_eta(self, mirostat_eta: float) -> None:
self.add_float32(Keys.General.SAMPLING_MIROSTAT_ETA, mirostat_eta)
def add_name(self, name: str) -> None:
self.add_string(Keys.General.NAME, name)

View File

@ -17,6 +17,20 @@ logger = logging.getLogger("metadata")
@dataclass
class Metadata:
# Recommended Sampler Parameters to be written to GGUF KV Store
sampling_sequence: Optional[str] = None
sampling_top_k: Optional[int] = None
sampling_top_p: Optional[float] = None
sampling_min_p: Optional[float] = None
sampling_xtc_probability: Optional[float] = None
sampling_xtc_threshold: Optional[float] = None
sampling_temp: Optional[float] = None
sampling_penalty_last_n: Optional[int] = None
sampling_penalty_repeat: Optional[float] = None
sampling_mirostat: Optional[int] = None
sampling_mirostat_tau: Optional[float] = None
sampling_mirostat_eta: Optional[float] = None
# Authorship Metadata to be written to GGUF KV Store
name: Optional[str] = None
author: Optional[str] = None
@ -54,15 +68,43 @@ class Metadata:
model_card = Metadata.load_model_card(model_path)
hf_params = Metadata.load_hf_parameters(model_path)
gen_config = Metadata.load_generation_config(model_path)
# TODO: load adapter_config.json when possible, it usually contains the base model of the LoRA adapter
# heuristics
metadata = Metadata.apply_metadata_heuristic(metadata, model_card, hf_params, model_path, total_params)
if gen_config:
metadata.sampling_sequence = gen_config.get("sequence", metadata.sampling_sequence)
metadata.sampling_top_k = gen_config.get("top_k", metadata.sampling_top_k)
metadata.sampling_top_p = gen_config.get("top_p", metadata.sampling_top_p)
metadata.sampling_min_p = gen_config.get("min_p", metadata.sampling_min_p)
metadata.sampling_xtc_probability = gen_config.get("xtc_probability", metadata.sampling_xtc_probability)
metadata.sampling_xtc_threshold = gen_config.get("xtc_threshold", metadata.sampling_xtc_threshold)
metadata.sampling_temp = gen_config.get("temperature", metadata.sampling_temp)
metadata.sampling_penalty_last_n = gen_config.get("penalty_last_n", metadata.sampling_penalty_last_n)
metadata.sampling_penalty_repeat = gen_config.get("penalty_repeat", metadata.sampling_penalty_repeat)
metadata.sampling_mirostat = gen_config.get("mirostat", metadata.sampling_mirostat)
metadata.sampling_mirostat_tau = gen_config.get("mirostat_tau", metadata.sampling_mirostat_tau)
metadata.sampling_mirostat_eta = gen_config.get("mirostat_eta", metadata.sampling_mirostat_eta)
# Metadata Override File Provided
# This is based on LLM_KV_NAMES mapping in llama.cpp
metadata_override = Metadata.load_metadata_override(metadata_override_path)
metadata.sampling_sequence = metadata_override.get(Keys.General.SAMPLING_SEQUENCE, metadata.sampling_sequence)
metadata.sampling_top_k = metadata_override.get(Keys.General.SAMPLING_TOP_K, metadata.sampling_top_k)
metadata.sampling_top_p = metadata_override.get(Keys.General.SAMPLING_TOP_P, metadata.sampling_top_p)
metadata.sampling_min_p = metadata_override.get(Keys.General.SAMPLING_MIN_P, metadata.sampling_min_p)
metadata.sampling_xtc_probability = metadata_override.get(Keys.General.SAMPLING_XTC_PROBABILITY, metadata.sampling_xtc_probability)
metadata.sampling_xtc_threshold = metadata_override.get(Keys.General.SAMPLING_XTC_THRESHOLD, metadata.sampling_xtc_threshold)
metadata.sampling_temp = metadata_override.get(Keys.General.SAMPLING_TEMP, metadata.sampling_temp)
metadata.sampling_penalty_last_n = metadata_override.get(Keys.General.SAMPLING_PENALTY_LAST_N, metadata.sampling_penalty_last_n)
metadata.sampling_penalty_repeat = metadata_override.get(Keys.General.SAMPLING_PENALTY_REPEAT, metadata.sampling_penalty_repeat)
metadata.sampling_mirostat = metadata_override.get(Keys.General.SAMPLING_MIROSTAT, metadata.sampling_mirostat)
metadata.sampling_mirostat_tau = metadata_override.get(Keys.General.SAMPLING_MIROSTAT_TAU, metadata.sampling_mirostat_tau)
metadata.sampling_mirostat_eta = metadata_override.get(Keys.General.SAMPLING_MIROSTAT_ETA, metadata.sampling_mirostat_eta)
metadata.name = metadata_override.get(Keys.General.NAME, metadata.name)
metadata.author = metadata_override.get(Keys.General.AUTHOR, metadata.author)
metadata.version = metadata_override.get(Keys.General.VERSION, metadata.version)
@ -172,6 +214,23 @@ class Metadata:
with open(config_path, "r", encoding="utf-8") as f:
return json.load(f)
@staticmethod
def load_generation_config(model_path: Optional[Path] = None) -> dict[str, Any]:
if model_path is None or not model_path.is_dir():
return {}
generation_config_path = model_path / "generation_config.json"
if not generation_config_path.is_file():
return {}
try:
with open(generation_config_path, "r", encoding="utf-8") as f:
return json.load(f)
except (json.JSONDecodeError, IOError):
# not all models have valid generation_config.json
return {}
@staticmethod
def id_to_title(string):
# Convert capitalization into title form unless acronym or version number
@ -546,6 +605,32 @@ class Metadata:
def set_gguf_meta_model(self, gguf_writer: gguf.GGUFWriter):
assert self.name is not None
if self.sampling_sequence is not None:
gguf_writer.add_sampling_sequence(self.sampling_sequence)
if self.sampling_top_k is not None:
gguf_writer.add_sampling_top_k(self.sampling_top_k)
if self.sampling_top_p is not None:
gguf_writer.add_sampling_top_p(self.sampling_top_p)
if self.sampling_min_p is not None:
gguf_writer.add_sampling_min_p(self.sampling_min_p)
if self.sampling_xtc_probability is not None:
gguf_writer.add_sampling_xtc_probability(self.sampling_xtc_probability)
if self.sampling_xtc_threshold is not None:
gguf_writer.add_sampling_xtc_threshold(self.sampling_xtc_threshold)
if self.sampling_temp is not None:
gguf_writer.add_sampling_temp(self.sampling_temp)
if self.sampling_penalty_last_n is not None:
gguf_writer.add_sampling_penalty_last_n(self.sampling_penalty_last_n)
if self.sampling_penalty_repeat is not None:
gguf_writer.add_sampling_penalty_repeat(self.sampling_penalty_repeat)
if self.sampling_mirostat is not None:
gguf_writer.add_sampling_mirostat(self.sampling_mirostat)
if self.sampling_mirostat_tau is not None:
gguf_writer.add_sampling_mirostat_tau(self.sampling_mirostat_tau)
if self.sampling_mirostat_eta is not None:
gguf_writer.add_sampling_mirostat_eta(self.sampling_mirostat_eta)
gguf_writer.add_name(self.name)
if self.author is not None:

View File

@ -246,6 +246,21 @@ extern "C" {
LLAMA_KV_OVERRIDE_TYPE_STR,
};
enum llama_model_meta_key {
LLAMA_MODEL_META_KEY_SAMPLING_SEQUENCE,
LLAMA_MODEL_META_KEY_SAMPLING_TOP_K,
LLAMA_MODEL_META_KEY_SAMPLING_TOP_P,
LLAMA_MODEL_META_KEY_SAMPLING_MIN_P,
LLAMA_MODEL_META_KEY_SAMPLING_XTC_PROBABILITY,
LLAMA_MODEL_META_KEY_SAMPLING_XTC_THRESHOLD,
LLAMA_MODEL_META_KEY_SAMPLING_TEMP,
LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_LAST_N,
LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_REPEAT,
LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT,
LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_TAU,
LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_ETA,
};
struct llama_model_kv_override {
enum llama_model_kv_override_type tag;
@ -518,6 +533,9 @@ extern "C" {
// Get the number of metadata key/value pairs
LLAMA_API int32_t llama_model_meta_count(const struct llama_model * model);
// Get sampling metadata key name. Returns nullptr if the key is invalid
LLAMA_API const char * llama_model_meta_key_str(enum llama_model_meta_key key);
// Get metadata key name by index
LLAMA_API int32_t llama_model_meta_key_by_index(const struct llama_model * model, int32_t i, char * buf, size_t buf_size);

View File

@ -16,7 +16,7 @@ vendor = {
# "https://github.com/mackron/miniaudio/raw/refs/tags/0.11.23/miniaudio.h": "vendor/miniaudio/miniaudio.h",
"https://github.com/mackron/miniaudio/raw/669ed3e844524fcd883231b13095baee9f6de304/miniaudio.h": "vendor/miniaudio/miniaudio.h",
"https://raw.githubusercontent.com/yhirose/cpp-httplib/refs/tags/v0.27.0/httplib.h": "vendor/cpp-httplib/httplib.h",
"https://raw.githubusercontent.com/yhirose/cpp-httplib/refs/tags/v0.28.0/httplib.h": "vendor/cpp-httplib/httplib.h",
"https://raw.githubusercontent.com/sheredom/subprocess.h/b49c56e9fe214488493021017bf3954b91c7c1f5/subprocess.h": "vendor/sheredom/subprocess.h",
}

View File

@ -119,6 +119,18 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_GENERAL_QUANTIZATION_VERSION, "general.quantization_version" },
{ LLM_KV_GENERAL_ALIGNMENT, "general.alignment" },
{ LLM_KV_GENERAL_FILE_TYPE, "general.file_type" },
{ LLM_KV_GENERAL_SAMPLING_SEQUENCE, "general.sampling.sequence" },
{ LLM_KV_GENERAL_SAMPLING_TOP_K, "general.sampling.top_k" },
{ LLM_KV_GENERAL_SAMPLING_TOP_P, "general.sampling.top_p" },
{ LLM_KV_GENERAL_SAMPLING_MIN_P, "general.sampling.min_p" },
{ LLM_KV_GENERAL_SAMPLING_XTC_PROBABILITY, "general.sampling.xtc_probability" },
{ LLM_KV_GENERAL_SAMPLING_XTC_THRESHOLD, "general.sampling.xtc_threshold" },
{ LLM_KV_GENERAL_SAMPLING_TEMP, "general.sampling.temp" },
{ LLM_KV_GENERAL_SAMPLING_PENALTY_LAST_N, "general.sampling.penalty_last_n" },
{ LLM_KV_GENERAL_SAMPLING_PENALTY_REPEAT, "general.sampling.penalty_repeat" },
{ LLM_KV_GENERAL_SAMPLING_MIROSTAT, "general.sampling.mirostat" },
{ LLM_KV_GENERAL_SAMPLING_MIROSTAT_TAU, "general.sampling.mirostat_tau" },
{ LLM_KV_GENERAL_SAMPLING_MIROSTAT_ETA, "general.sampling.mirostat_eta" },
{ LLM_KV_GENERAL_NAME, "general.name" },
{ LLM_KV_GENERAL_AUTHOR, "general.author" },
{ LLM_KV_GENERAL_VERSION, "general.version" },

View File

@ -123,6 +123,18 @@ enum llm_kv {
LLM_KV_GENERAL_QUANTIZATION_VERSION,
LLM_KV_GENERAL_ALIGNMENT,
LLM_KV_GENERAL_FILE_TYPE,
LLM_KV_GENERAL_SAMPLING_SEQUENCE,
LLM_KV_GENERAL_SAMPLING_TOP_K,
LLM_KV_GENERAL_SAMPLING_TOP_P,
LLM_KV_GENERAL_SAMPLING_MIN_P,
LLM_KV_GENERAL_SAMPLING_XTC_PROBABILITY,
LLM_KV_GENERAL_SAMPLING_XTC_THRESHOLD,
LLM_KV_GENERAL_SAMPLING_TEMP,
LLM_KV_GENERAL_SAMPLING_PENALTY_LAST_N,
LLM_KV_GENERAL_SAMPLING_PENALTY_REPEAT,
LLM_KV_GENERAL_SAMPLING_MIROSTAT,
LLM_KV_GENERAL_SAMPLING_MIROSTAT_TAU,
LLM_KV_GENERAL_SAMPLING_MIROSTAT_ETA,
LLM_KV_GENERAL_NAME,
LLM_KV_GENERAL_AUTHOR,
LLM_KV_GENERAL_VERSION,

View File

@ -1248,7 +1248,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
// make the outputs have the same order they had in the user-provided batch
// note: this is mostly relevant for recurrent models atm
if (!sorted_output) {
if (!sorted_output && n_outputs > 1) {
GGML_ASSERT((size_t) n_outputs == out_ids.size());
// TODO: is there something more efficient which also minimizes swaps?

View File

@ -961,14 +961,14 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
// organize experts into n_expert_groups
ggml_tensor * selection_groups = ggml_reshape_3d(ctx0, selection_probs, n_exp_per_group, hparams.n_expert_groups, n_tokens); // [n_exp_per_group, n_expert_groups, n_tokens]
ggml_tensor * group_scores = ggml_top_k(ctx0, selection_groups, 2); // [2, n_expert_groups, n_tokens]
ggml_tensor * group_scores = ggml_argsort_top_k(ctx0, selection_groups, 2); // [2, n_expert_groups, n_tokens]
group_scores = ggml_get_rows(ctx0, ggml_reshape_4d(ctx0, selection_groups, 1, selection_groups->ne[0], selection_groups->ne[1], selection_groups->ne[2]), group_scores); // [1, 2, n_expert_groups, n_tokens]
// get top n_group_used expert groups
group_scores = ggml_sum_rows(ctx0, ggml_reshape_3d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2], group_scores->ne[3])); // [1, n_expert_groups, n_tokens]
group_scores = ggml_reshape_2d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2]); // [n_expert_groups, n_tokens]
ggml_tensor * expert_groups = ggml_top_k(ctx0, group_scores, hparams.n_group_used); // [n_group_used, n_tokens]
ggml_tensor * expert_groups = ggml_argsort_top_k(ctx0, group_scores, hparams.n_group_used); // [n_group_used, n_tokens]
cb(expert_groups, "ffn_moe_group_topk", il);
// mask out the other groups
@ -979,7 +979,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
}
// select experts
ggml_tensor * selected_experts = ggml_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
ggml_tensor * selected_experts = ggml_argsort_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
cb(selected_experts->src[0], "ffn_moe_argsort", il);
cb(selected_experts, "ffn_moe_topk", il);

View File

@ -7687,6 +7687,24 @@ int32_t llama_model_meta_count(const llama_model * model) {
return (int)model->gguf_kv.size();
}
const char * llama_model_meta_key_str(llama_model_meta_key key) {
switch (key) {
case LLAMA_MODEL_META_KEY_SAMPLING_SEQUENCE: return "general.sampling.sequence";
case LLAMA_MODEL_META_KEY_SAMPLING_TOP_K: return "general.sampling.top_k";
case LLAMA_MODEL_META_KEY_SAMPLING_TOP_P: return "general.sampling.top_p";
case LLAMA_MODEL_META_KEY_SAMPLING_MIN_P: return "general.sampling.min_p";
case LLAMA_MODEL_META_KEY_SAMPLING_XTC_PROBABILITY: return "general.sampling.xtc_probability";
case LLAMA_MODEL_META_KEY_SAMPLING_XTC_THRESHOLD: return "general.sampling.xtc_threshold";
case LLAMA_MODEL_META_KEY_SAMPLING_TEMP: return "general.sampling.temp";
case LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_LAST_N: return "general.sampling.penalty_last_n";
case LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_REPEAT: return "general.sampling.penalty_repeat";
case LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT: return "general.sampling.mirostat";
case LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_TAU: return "general.sampling.mirostat_tau";
case LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_ETA: return "general.sampling.mirostat_eta";
default: return nullptr;
}
}
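Reviewer note: a likely consumption pattern is to pair the new llama_model_meta_key_str() with the existing llama_model_meta_val_str() lookup. A hedged sketch (illustrative only; get_recommended_temp is a made-up helper name, not part of this patch):
// --- illustrative sketch, not part of this patch ---
#include <cstdlib>
#include "llama.h"
// read the recommended sampling temperature stored by the convert script,
// falling back to a caller-provided default when the model has no such metadata
static float get_recommended_temp(const llama_model * model, float fallback) {
    const char * key = llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_TEMP);
    char buf[64];
    if (key && llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) {
        return std::strtof(buf, nullptr);
    }
    return fallback;
}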
int32_t llama_model_meta_key_by_index(const llama_model * model, int i, char * buf, size_t buf_size) {
if (i < 0 || i >= (int)model->gguf_kv.size()) {
if (buf_size > 0) {

View File

@ -39,6 +39,7 @@
#include <string_view>
#include <thread>
#include <vector>
#include <unordered_map>
static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) {
size_t nels = ggml_nelements(tensor);
@ -269,6 +270,34 @@ static double nmse(const float * a, const float * b, size_t n) {
return mse_a_b / mse_a_0;
}
// difference between 2 integer sets (Jaccard distance, 0 - no difference, 1 - no overlap)
static double jdst(const int32_t * a, const int32_t * b, size_t n) {
std::unordered_map<int32_t, size_t> set_a;
std::unordered_map<int32_t, size_t> set_b;
for (size_t i = 0; i < n; ++i) {
set_a[a[i]]++;
set_b[b[i]]++;
}
size_t diff = 0;
for (const auto & p : set_a) {
const int64_t na = p.second;
const int64_t nb = set_b.find(p.first) != set_b.end() ? set_b.at(p.first) : 0;
diff += std::abs(na - nb);
}
for (const auto & p : set_b) {
if (set_a.find(p.first) == set_a.end()) {
diff += p.second;
}
}
return (double) diff / (2*n);
}
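Reviewer note: a quick worked example of the metric (stated as an expectation, not taken from the patch): for a = {1, 2, 3} and b = {1, 2, 4} the symmetric multiset difference is {3, 4}, so the result is 2/(2*3) = 1/3; identical inputs give 0 and disjoint inputs give 1. A self-contained check (illustrative only, using a local copy of the function above):
// --- illustrative sketch, not part of this patch ---
#include <cassert>
#include <cstdint>
#include <cstdlib>
#include <unordered_map>
// local copy of the metric above, used only to illustrate its behavior
static double jdst_ref(const int32_t * a, const int32_t * b, size_t n) {
    std::unordered_map<int32_t, size_t> set_a, set_b;
    for (size_t i = 0; i < n; ++i) { set_a[a[i]]++; set_b[b[i]]++; }
    size_t diff = 0;
    for (const auto & p : set_a) {
        const int64_t na = p.second;
        const int64_t nb = set_b.count(p.first) ? (int64_t) set_b.at(p.first) : 0;
        diff += (size_t) std::abs(na - nb);
    }
    for (const auto & p : set_b) {
        if (set_a.count(p.first) == 0) { diff += p.second; }
    }
    return (double) diff / (2*n);
}
int main() {
    const int32_t a[3] = {1, 2, 3};
    const int32_t b[3] = {1, 2, 4};
    assert(jdst_ref(a, b, 3) > 0.33 && jdst_ref(a, b, 3) < 0.34); // = 1/3
    assert(jdst_ref(a, a, 3) == 0.0);                             // identical sets
    return 0;
}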
// maximum absolute asymmetry between a and b
// asymmetry: (a - b) / (a + b)
// This is more stable than relative error if one of the values fluctuates towards zero.
@ -1051,6 +1080,14 @@ struct test_case {
return 1e-4;
}
virtual double max_err() {
return max_nmse_err();
}
virtual double err(const float * a, const float * b, size_t n) {
return nmse(a, b, n);
}
virtual float grad_eps() {
return 1e-1f;
}
@ -1257,16 +1294,16 @@ struct test_case {
// compare
struct callback_userdata {
bool ok;
double max_err;
test_case * tc;
ggml_backend_t backend1;
ggml_backend_t backend2;
};
callback_userdata ud {
true,
max_nmse_err(),
this,
backend1,
backend2
backend2,
};
auto callback = [](int index, ggml_tensor * t1, ggml_tensor * t2, void * user_data) -> bool {
@ -1314,9 +1351,9 @@ struct test_case {
}
}
double err = nmse(f1.data(), f2.data(), f1.size());
if (err > ud->max_err) {
printf("[%s] NMSE = %.9f > %.9f ", ggml_op_desc(t1), err, ud->max_err);
double err = ud->tc->err(f1.data(), f2.data(), f1.size());
if (err > ud->tc->max_err()) {
printf("[%s] ERR = %.9f > %.9f ", ggml_op_desc(t1), err, ud->tc->max_err());
//for (int i = 0; i < (int) f1.size(); i++) {
// printf("%5d %9.6f %9.6f, diff = %9.6f\n", i, f1[i], f2[i], f1[i] - f2[i]);
//}
@ -4943,6 +4980,70 @@ struct test_argsort : public test_case {
}
};
// GGML_OP_TOP_K
struct test_top_k : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne;
const int k;
std::string vars() override {
return VARS_TO_STR3(type, ne, k);
}
test_top_k(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {16, 10, 10, 10},
int k = 4)
: type(type), ne(ne), k(k) {}
double max_err() override {
return 0.0;
}
double err(const float * a, const float * b, size_t n) override {
std::vector<int32_t> ia(n);
std::vector<int32_t> ib(n);
double diff = 0.0f;
for (size_t i = 0; i < n; i++) {
ia[i] = (int32_t) a[i];
ib[i] = (int32_t) b[i];
// penalize the result if the data is not integer valued
diff += std::fabs(a[i] - ia[i]);
diff += std::fabs(b[i] - ib[i]);
}
return diff + jdst(ia.data(), ib.data(), n);
}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_name(a, "a");
ggml_tensor * out = ggml_top_k(ctx, a, k);
ggml_set_name(out, "out");
return out;
}
void initialize_tensors(ggml_context * ctx) override {
std::random_device rd;
std::default_random_engine rng(rd());
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
// initialize with unique values to avoid ties
for (int64_t r = 0; r < ggml_nrows(t); r++) {
std::vector<float> data(t->ne[0]);
for (int i = 0; i < t->ne[0]; i++) {
data[i] = i;
}
std::shuffle(data.begin(), data.end(), rng);
ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(float));
}
}
}
};
struct test_topk_moe : public test_case {
const std::array<int64_t, 4> ne;
const int n_expert_used;
@ -4976,7 +5077,7 @@ struct test_topk_moe: public test_case {
ggml_tensor * logits = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne.data());
ggml_tensor * probs = delayed_softmax ? logits : ggml_soft_max(ctx, logits);
ggml_tensor * selected_experts = ggml_top_k(ctx, probs, n_expert_used); // [n_expert_used, n_tokens]
ggml_tensor * selected_experts = ggml_argsort_top_k(ctx, probs, n_expert_used); // [n_expert_used, n_tokens]
ggml_tensor * out = ggml_get_rows(ctx, ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
@ -7534,6 +7635,23 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2, 8, 8192, 1}, order)); // bailingmoe2 (group selection)
}
for (int k : {1, 2, 3, 7, 15}) {
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {16, 10, 10, 10}, k));
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {60, 10, 10, 10}, k));
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {1023, 2, 1, 3}, k));
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {1024, 2, 1, 3}, k));
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {1025, 2, 1, 3}, k));
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {16384, 1, 1, 1}, k));
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {2047, 2, 1, 3}, k));
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {2048, 2, 1, 3}, k));
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {2049, 2, 1, 3}, k));
}
// exhaustive top_k tests
//for (int i = 1; i < 9999; ++i) {
// test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {i, 2, 1, 3}, rand() % i + 1));
//}
for (ggml_scale_mode mode : {GGML_SCALE_MODE_NEAREST, GGML_SCALE_MODE_BILINEAR, GGML_SCALE_MODE_BICUBIC}) {
test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, {512, 512, 3, 2}, 2, mode));
test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, {512, 512, 3, 2}, 2, mode, true));
@ -7859,6 +7977,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
}
}
// Qwen3-VL-8B https://github.com/ggml-org/llama.cpp/issues/17012
test_cases.emplace_back(new test_flash_attn_ext(72, 72, 16, {1, 1}, 5776, 5776, false, false, 0, 0, GGML_PREC_F32, GGML_TYPE_F16));
for (int kv : { 4096, 8192, 16384, }) {
for (int hs : { 64, 128, }) {
for (int nr : { 1, 4, }) {
@ -7911,6 +8032,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
}
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {65000, 16, 1, 1}));
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {65000, 16, 1, 1}, 40));
return test_cases;
}

Binary file not shown.

View File

@ -8,6 +8,7 @@
import rehypeKatex from 'rehype-katex';
import rehypeStringify from 'rehype-stringify';
import { copyCodeToClipboard } from '$lib/utils/copy';
import { rehypeRestoreTableHtml } from '$lib/markdown/table-html-restorer';
import { preprocessLaTeX } from '$lib/utils/latex-protection';
import { browser } from '$app/environment';
import '$styles/katex-custom.scss';
@ -60,6 +61,7 @@
.use(remarkRehype) // Convert Markdown AST to rehype
.use(rehypeKatex) // Render math using KaTeX
.use(rehypeHighlight) // Add syntax highlighting
.use(rehypeRestoreTableHtml) // Restore limited HTML (e.g., <br>, <ul>) inside Markdown tables
.use(rehypeStringify); // Convert to HTML string
});

View File

@ -0,0 +1,20 @@
/**
* Matches <br>, <br/>, <br /> tags (case-insensitive).
* Used to detect line breaks in table cell text content.
*/
export const BR_PATTERN = /<br\s*\/?\s*>/gi;
/**
* Matches a complete <ul>...</ul> block.
* Captures the inner content (group 1) for further <li> extraction.
* Case-insensitive, allows multiline content.
*/
export const LIST_PATTERN = /^<ul>([\s\S]*)<\/ul>$/i;
/**
* Matches individual <li>...</li> elements within a list.
* Captures the inner content (group 1) of each list item.
* Non-greedy to handle multiple consecutive items.
* Case-insensitive, allows multiline content.
*/
export const LI_PATTERN = /<li>([\s\S]*?)<\/li>/gi;

View File

@ -0,0 +1,181 @@
/**
* Rehype plugin to restore limited HTML elements inside Markdown table cells.
*
* ## Problem
* The remark/rehype pipeline neutralizes inline HTML as literal text
* (remarkLiteralHtml) so that XML/HTML snippets in LLM responses display
* as-is instead of being rendered. This causes <br> and <ul> markup in
* table cells to show as plain text.
*
* ## Solution
* This plugin traverses the HAST post-conversion, parses whitelisted HTML
* patterns from text nodes, and replaces them with actual HAST element nodes
* that will be rendered as real HTML.
*
* ## Supported HTML
* - `<br>` / `<br/>` / `<br />` - Line breaks (inline)
* - `<ul><li>...</li></ul>` - Unordered lists (block)
*
* ## Key Implementation Details
*
* ### 1. Sibling Combination (Critical)
* The Markdown pipeline may fragment content across multiple text nodes and `<br>`
* elements. For example, `<ul><li>a</li></ul>` might arrive as:
* - Text: `"<ul>"`
* - Element: `<br>`
* - Text: `"<li>a</li></ul>"`
*
* We must combine consecutive text nodes and `<br>` elements into a single string
* before attempting to parse list markup. Without this, list detection fails.
*
* ### 2. visitParents for Deep Traversal
* Table cell content may be wrapped in intermediate elements (e.g., `<p>` tags).
* Using `visitParents` instead of direct child iteration ensures we find text
* nodes at any depth within the cell.
*
* ### 3. Reference Comparison for No-Op Detection
* When checking if `<br>` expansion changed anything, we compare:
* `expanded.length !== 1 || expanded[0] !== textNode`
*
* This catches both cases:
* - Multiple nodes created (text was split)
* - Single NEW node created (original had only `<br>`, now it's an element)
*
* A simple `length > 1` check would miss the single `<br>` case.
*
* ### 4. Strict List Validation
* `parseList()` rejects malformed markup by checking for garbage text between
* `<li>` elements. This prevents creating broken DOM from partial matches like
* `<ul>garbage<li>a</li></ul>`.
*
* ### 5. Newline Substitution for `<br>` in Combined String
* When combining siblings, existing `<br>` elements become `\n` in the combined
* string. This allows list content to span visual lines while still being parsed
* as a single unit.
*
* @example
* // Input Markdown:
* // | Feature | Notes |
* // |---------|-------|
* // | Multi-line | First<br>Second |
* // | List | <ul><li>A</li><li>B</li></ul> |
* //
* // Without this plugin: <br> and <ul> render as literal text
* // With this plugin: <br> becomes line break, <ul> becomes actual list
*/
import type { Plugin } from 'unified';
import type { Element, ElementContent, Root, Text } from 'hast';
import { visit } from 'unist-util-visit';
import { visitParents } from 'unist-util-visit-parents';
import { BR_PATTERN, LIST_PATTERN, LI_PATTERN } from '$lib/constants/table-html-restorer';
/**
* Expands text containing `<br>` tags into an array of text nodes and br elements.
*/
function expandBrTags(value: string): ElementContent[] {
const matches = [...value.matchAll(BR_PATTERN)];
if (!matches.length) return [{ type: 'text', value } as Text];
const result: ElementContent[] = [];
let cursor = 0;
for (const m of matches) {
if (m.index! > cursor) {
result.push({ type: 'text', value: value.slice(cursor, m.index) } as Text);
}
result.push({ type: 'element', tagName: 'br', properties: {}, children: [] } as Element);
cursor = m.index! + m[0].length;
}
if (cursor < value.length) {
result.push({ type: 'text', value: value.slice(cursor) } as Text);
}
return result;
}
/**
* Parses a `<ul><li>...</li></ul>` string into a HAST element.
* Returns null if the markup is malformed or contains unexpected content.
*/
function parseList(value: string): Element | null {
const match = value.trim().match(LIST_PATTERN);
if (!match) return null;
const body = match[1];
const items: ElementContent[] = [];
let cursor = 0;
for (const liMatch of body.matchAll(LI_PATTERN)) {
// Reject if there's non-whitespace between list items
if (body.slice(cursor, liMatch.index!).trim()) return null;
items.push({
type: 'element',
tagName: 'li',
properties: {},
children: expandBrTags(liMatch[1] ?? '')
} as Element);
cursor = liMatch.index! + liMatch[0].length;
}
// Reject if no items found or trailing garbage exists
if (!items.length || body.slice(cursor).trim()) return null;
return { type: 'element', tagName: 'ul', properties: {}, children: items } as Element;
}
/**
* Processes a single table cell, restoring HTML elements from text content.
*/
function processCell(cell: Element) {
visitParents(cell, 'text', (textNode: Text, ancestors) => {
const parent = ancestors[ancestors.length - 1];
if (!parent || parent.type !== 'element') return;
const parentEl = parent as Element;
const siblings = parentEl.children as ElementContent[];
const startIndex = siblings.indexOf(textNode as ElementContent);
if (startIndex === -1) return;
// Combine consecutive text nodes and <br> elements into one string
let combined = '';
let endIndex = startIndex;
for (let i = startIndex; i < siblings.length; i++) {
const sib = siblings[i];
if (sib.type === 'text') {
combined += (sib as Text).value;
endIndex = i;
} else if (sib.type === 'element' && (sib as Element).tagName === 'br') {
combined += '\n';
endIndex = i;
} else {
break;
}
}
// Try parsing as list first (replaces entire combined range)
const list = parseList(combined);
if (list) {
siblings.splice(startIndex, endIndex - startIndex + 1, list);
return;
}
// Otherwise, just expand <br> tags in this text node
const expanded = expandBrTags(textNode.value);
if (expanded.length !== 1 || expanded[0] !== textNode) {
siblings.splice(startIndex, 1, ...expanded);
}
});
}
export const rehypeRestoreTableHtml: Plugin<[], Root> = () => (tree) => {
visit(tree, 'element', (node: Element) => {
if (node.tagName === 'td' || node.tagName === 'th') {
processCell(node);
}
});
};

View File

@ -31,13 +31,16 @@ if (LLAMA_BUILD_BORINGSSL)
message(STATUS "Fetching BoringSSL version ${BORINGSSL_VERSION}")
include(FetchContent)
FetchContent_Declare(
boringssl
set(BORINGSSL_ARGS
GIT_REPOSITORY ${BORINGSSL_GIT}
GIT_TAG ${BORINGSSL_VERSION}
PATCH_COMMAND ${CMAKE_COMMAND} -P "${CMAKE_CURRENT_SOURCE_DIR}/patch-boringssl.cmake"
)
if(CMAKE_VERSION VERSION_GREATER_EQUAL 3.28)
list(APPEND BORINGSSL_ARGS EXCLUDE_FROM_ALL)
endif()
include(FetchContent)
FetchContent_Declare(boringssl ${BORINGSSL_ARGS})
set(SAVED_BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS})
set(SAVED_BUILD_TESTING ${BUILD_TESTING})
@ -45,7 +48,15 @@ if (LLAMA_BUILD_BORINGSSL)
set(BUILD_SHARED_LIBS OFF)
set(BUILD_TESTING OFF)
if(CMAKE_VERSION VERSION_GREATER_EQUAL 3.28)
FetchContent_MakeAvailable(boringssl)
else()
FetchContent_GetProperties(boringssl)
if(NOT boringssl_POPULATED)
FetchContent_Populate(boringssl)
add_subdirectory(${boringssl_SOURCE_DIR} ${boringssl_BINARY_DIR} EXCLUDE_FROM_ALL)
endif()
endif()
set(BUILD_SHARED_LIBS ${SAVED_BUILD_SHARED_LIBS})
set(BUILD_TESTING ${SAVED_BUILD_TESTING})

View File

@ -1087,22 +1087,30 @@ int getaddrinfo_with_timeout(const char *node, const char *service,
// Fallback implementation using thread-based timeout for other Unix systems
struct GetAddrInfoState {
~GetAddrInfoState() {
if (info) { freeaddrinfo(info); }
}
std::mutex mutex;
std::condition_variable result_cv;
bool completed = false;
int result = EAI_SYSTEM;
std::string node = node;
std::string service = service;
struct addrinfo hints = hints;
std::string node;
std::string service;
struct addrinfo hints;
struct addrinfo *info = nullptr;
};
// Allocate on the heap, so the resolver thread can keep using the data.
auto state = std::make_shared<GetAddrInfoState>();
state->node = node;
state->service = service;
state->hints = *hints;
std::thread resolve_thread([=]() {
auto thread_result = getaddrinfo(
state->node.c_str(), state->service.c_str(), hints, &state->info);
std::thread resolve_thread([state]() {
auto thread_result =
getaddrinfo(state->node.c_str(), state->service.c_str(), &state->hints,
&state->info);
std::lock_guard<std::mutex> lock(state->mutex);
state->result = thread_result;
@ -1120,6 +1128,7 @@ int getaddrinfo_with_timeout(const char *node, const char *service,
// Operation completed within timeout
resolve_thread.join();
*res = state->info;
state->info = nullptr; // Pass ownership to caller
return state->result;
} else {
// Timeout occurred
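The hunks above are easier to follow as one piece. Below is a minimal, hypothetical sketch of the same heap-allocated-state idiom (names such as `ResolveState` and `resolve_with_timeout` are invented, and the detach-on-timeout branch is assumed rather than shown in this diff): on success the caller takes ownership of `info`; on timeout the resolver thread is abandoned and the state's destructor frees any result that arrives too late.

```cpp
#include <chrono>
#include <condition_variable>
#include <memory>
#include <mutex>
#include <netdb.h>
#include <string>
#include <thread>

// Hypothetical, self-contained sketch of the pattern above: the state lives on
// the heap behind a shared_ptr so a resolver thread that outlives the caller
// (after a timeout) still has valid data to work with.
struct ResolveState {
    ~ResolveState() {
        if (info) { freeaddrinfo(info); } // frees a late result if the caller gave up
    }
    std::mutex mutex;
    std::condition_variable cv;
    bool completed = false;
    int result = EAI_SYSTEM;
    std::string node, service;
    addrinfo hints{};
    addrinfo *info = nullptr;
};

inline int resolve_with_timeout(const char *node, const char *service,
                                const addrinfo *hints, addrinfo **res,
                                std::chrono::milliseconds timeout) {
    auto state = std::make_shared<ResolveState>();
    state->node = node;
    state->service = service;
    state->hints = *hints;

    std::thread worker([state]() {
        int r = getaddrinfo(state->node.c_str(), state->service.c_str(),
                            &state->hints, &state->info);
        std::lock_guard<std::mutex> lock(state->mutex);
        state->result = r;
        state->completed = true;
        state->cv.notify_one();
    });

    std::unique_lock<std::mutex> lock(state->mutex);
    if (state->cv.wait_for(lock, timeout, [&] { return state->completed; })) {
        worker.join();            // resolver finished in time
        *res = state->info;
        state->info = nullptr;    // ownership passes to the caller
        return state->result;
    }
    worker.detach();              // give up; the detached thread only touches `state`
    return EAI_AGAIN;
}
```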
@ -4970,7 +4979,8 @@ bool Server::write_response_core(Stream &strm, bool close_connection,
if (need_apply_ranges) { apply_ranges(req, res, content_type, boundary); }
// Prepare additional headers
if (close_connection || req.get_header_value("Connection") == "close") {
if (close_connection || req.get_header_value("Connection") == "close" ||
400 <= res.status) { // Don't leave connections open after errors
res.set_header("Connection", "close");
} else {
std::string s = "timeout=";
@ -5173,7 +5183,11 @@ bool Server::read_content_core(
size_t /*len*/) { return receiver(buf, n); };
}
if (req.method == "DELETE" && !req.has_header("Content-Length")) {
// RFC 7230 Section 3.3.3: If this is a request message and none of the above
// are true (no Transfer-Encoding and no Content-Length), then the message
// body length is zero (no message body is present).
if (!req.has_header("Content-Length") &&
!detail::is_chunked_transfer_encoding(req.headers)) {
return true;
}
@ -5681,8 +5695,6 @@ Server::process_request(Stream &strm, const std::string &remote_addr,
// Check if the request URI doesn't exceed the limit
if (req.target.size() > CPPHTTPLIB_REQUEST_URI_MAX_LENGTH) {
Headers dummy;
detail::read_headers(strm, dummy);
res.status = StatusCode::UriTooLong_414;
output_error_log(Error::ExceedUriMaxLength, &req);
return write_response(strm, close_connection, req, res);
@ -6666,11 +6678,13 @@ bool ClientImpl::write_request(Stream &strm, Request &req,
return true;
}
std::unique_ptr<Response> ClientImpl::send_with_content_provider(
std::unique_ptr<Response>
ClientImpl::send_with_content_provider_and_receiver(
Request &req, const char *body, size_t content_length,
ContentProvider content_provider,
ContentProviderWithoutLength content_provider_without_length,
const std::string &content_type, Error &error) {
const std::string &content_type, ContentReceiver content_receiver,
Error &error) {
if (!content_type.empty()) { req.set_header("Content-Type", content_type); }
#ifdef CPPHTTPLIB_ZLIB_SUPPORT
@ -6743,15 +6757,24 @@ std::unique_ptr<Response> ClientImpl::send_with_content_provider(
}
}
if (content_receiver) {
req.content_receiver =
[content_receiver](const char *data, size_t data_length,
size_t /*offset*/, size_t /*total_length*/) {
return content_receiver(data, data_length);
};
}
auto res = detail::make_unique<Response>();
return send(req, *res, error) ? std::move(res) : nullptr;
}
Result ClientImpl::send_with_content_provider(
Result ClientImpl::send_with_content_provider_and_receiver(
const std::string &method, const std::string &path, const Headers &headers,
const char *body, size_t content_length, ContentProvider content_provider,
ContentProviderWithoutLength content_provider_without_length,
const std::string &content_type, UploadProgress progress) {
const std::string &content_type, ContentReceiver content_receiver,
UploadProgress progress) {
Request req;
req.method = method;
req.headers = headers;
@ -6763,9 +6786,10 @@ Result ClientImpl::send_with_content_provider(
auto error = Error::Success;
auto res = send_with_content_provider(
auto res = send_with_content_provider_and_receiver(
req, body, content_length, std::move(content_provider),
std::move(content_provider_without_length), content_type, error);
std::move(content_provider_without_length), content_type,
std::move(content_receiver), error);
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
return Result{std::move(res), error, std::move(req.headers), last_ssl_error_,
@ -7094,6 +7118,15 @@ Result ClientImpl::Post(const std::string &path, size_t content_length,
content_type, progress);
}
Result ClientImpl::Post(const std::string &path, size_t content_length,
ContentProvider content_provider,
const std::string &content_type,
ContentReceiver content_receiver,
UploadProgress progress) {
return Post(path, Headers(), content_length, std::move(content_provider),
content_type, std::move(content_receiver), progress);
}
Result ClientImpl::Post(const std::string &path,
ContentProviderWithoutLength content_provider,
const std::string &content_type,
@ -7102,6 +7135,15 @@ Result ClientImpl::Post(const std::string &path,
progress);
}
Result ClientImpl::Post(const std::string &path,
ContentProviderWithoutLength content_provider,
const std::string &content_type,
ContentReceiver content_receiver,
UploadProgress progress) {
return Post(path, Headers(), std::move(content_provider), content_type,
std::move(content_receiver), progress);
}
Result ClientImpl::Post(const std::string &path, const Headers &headers,
const Params &params) {
auto query = detail::params_to_query_str(params);
@ -7142,17 +7184,18 @@ Result ClientImpl::Post(const std::string &path, const Headers &headers,
const char *body, size_t content_length,
const std::string &content_type,
UploadProgress progress) {
return send_with_content_provider("POST", path, headers, body, content_length,
nullptr, nullptr, content_type, progress);
return send_with_content_provider_and_receiver(
"POST", path, headers, body, content_length, nullptr, nullptr,
content_type, nullptr, progress);
}
Result ClientImpl::Post(const std::string &path, const Headers &headers,
const std::string &body,
const std::string &content_type,
UploadProgress progress) {
return send_with_content_provider("POST", path, headers, body.data(),
body.size(), nullptr, nullptr, content_type,
progress);
return send_with_content_provider_and_receiver(
"POST", path, headers, body.data(), body.size(), nullptr, nullptr,
content_type, nullptr, progress);
}
Result ClientImpl::Post(const std::string &path, const Headers &headers,
@ -7160,18 +7203,40 @@ Result ClientImpl::Post(const std::string &path, const Headers &headers,
ContentProvider content_provider,
const std::string &content_type,
UploadProgress progress) {
return send_with_content_provider("POST", path, headers, nullptr,
content_length, std::move(content_provider),
nullptr, content_type, progress);
return send_with_content_provider_and_receiver(
"POST", path, headers, nullptr, content_length,
std::move(content_provider), nullptr, content_type, nullptr, progress);
}
Result ClientImpl::Post(const std::string &path, const Headers &headers,
size_t content_length,
ContentProvider content_provider,
const std::string &content_type,
ContentReceiver content_receiver,
DownloadProgress progress) {
return send_with_content_provider_and_receiver(
"POST", path, headers, nullptr, content_length,
std::move(content_provider), nullptr, content_type,
std::move(content_receiver), std::move(progress));
}
Result ClientImpl::Post(const std::string &path, const Headers &headers,
ContentProviderWithoutLength content_provider,
const std::string &content_type,
UploadProgress progress) {
return send_with_content_provider("POST", path, headers, nullptr, 0, nullptr,
std::move(content_provider), content_type,
progress);
return send_with_content_provider_and_receiver(
"POST", path, headers, nullptr, 0, nullptr, std::move(content_provider),
content_type, nullptr, progress);
}
Result ClientImpl::Post(const std::string &path, const Headers &headers,
ContentProviderWithoutLength content_provider,
const std::string &content_type,
ContentReceiver content_receiver,
DownloadProgress progress) {
return send_with_content_provider_and_receiver(
"POST", path, headers, nullptr, 0, nullptr, std::move(content_provider),
content_type, std::move(content_receiver), std::move(progress));
}
Result ClientImpl::Post(const std::string &path, const Headers &headers,
@ -7181,10 +7246,10 @@ Result ClientImpl::Post(const std::string &path, const Headers &headers,
const auto &boundary = detail::make_multipart_data_boundary();
const auto &content_type =
detail::serialize_multipart_formdata_get_content_type(boundary);
return send_with_content_provider(
return send_with_content_provider_and_receiver(
"POST", path, headers, nullptr, 0, nullptr,
get_multipart_content_provider(boundary, items, provider_items),
content_type, progress);
content_type, nullptr, progress);
}
Result ClientImpl::Post(const std::string &path, const Headers &headers,
@ -7246,6 +7311,15 @@ Result ClientImpl::Put(const std::string &path, size_t content_length,
content_type, progress);
}
Result ClientImpl::Put(const std::string &path, size_t content_length,
ContentProvider content_provider,
const std::string &content_type,
ContentReceiver content_receiver,
UploadProgress progress) {
return Put(path, Headers(), content_length, std::move(content_provider),
content_type, std::move(content_receiver), progress);
}
Result ClientImpl::Put(const std::string &path,
ContentProviderWithoutLength content_provider,
const std::string &content_type,
@ -7254,6 +7328,15 @@ Result ClientImpl::Put(const std::string &path,
progress);
}
Result ClientImpl::Put(const std::string &path,
ContentProviderWithoutLength content_provider,
const std::string &content_type,
ContentReceiver content_receiver,
UploadProgress progress) {
return Put(path, Headers(), std::move(content_provider), content_type,
std::move(content_receiver), progress);
}
Result ClientImpl::Put(const std::string &path, const Headers &headers,
const Params &params) {
auto query = detail::params_to_query_str(params);
@ -7294,17 +7377,18 @@ Result ClientImpl::Put(const std::string &path, const Headers &headers,
const char *body, size_t content_length,
const std::string &content_type,
UploadProgress progress) {
return send_with_content_provider("PUT", path, headers, body, content_length,
nullptr, nullptr, content_type, progress);
return send_with_content_provider_and_receiver(
"PUT", path, headers, body, content_length, nullptr, nullptr,
content_type, nullptr, progress);
}
Result ClientImpl::Put(const std::string &path, const Headers &headers,
const std::string &body,
const std::string &content_type,
UploadProgress progress) {
return send_with_content_provider("PUT", path, headers, body.data(),
body.size(), nullptr, nullptr, content_type,
progress);
return send_with_content_provider_and_receiver(
"PUT", path, headers, body.data(), body.size(), nullptr, nullptr,
content_type, nullptr, progress);
}
Result ClientImpl::Put(const std::string &path, const Headers &headers,
@ -7312,18 +7396,40 @@ Result ClientImpl::Put(const std::string &path, const Headers &headers,
ContentProvider content_provider,
const std::string &content_type,
UploadProgress progress) {
return send_with_content_provider("PUT", path, headers, nullptr,
content_length, std::move(content_provider),
nullptr, content_type, progress);
return send_with_content_provider_and_receiver(
"PUT", path, headers, nullptr, content_length,
std::move(content_provider), nullptr, content_type, nullptr, progress);
}
Result ClientImpl::Put(const std::string &path, const Headers &headers,
size_t content_length,
ContentProvider content_provider,
const std::string &content_type,
ContentReceiver content_receiver,
UploadProgress progress) {
return send_with_content_provider_and_receiver(
"PUT", path, headers, nullptr, content_length,
std::move(content_provider), nullptr, content_type,
std::move(content_receiver), progress);
}
Result ClientImpl::Put(const std::string &path, const Headers &headers,
ContentProviderWithoutLength content_provider,
const std::string &content_type,
UploadProgress progress) {
return send_with_content_provider("PUT", path, headers, nullptr, 0, nullptr,
std::move(content_provider), content_type,
progress);
return send_with_content_provider_and_receiver(
"PUT", path, headers, nullptr, 0, nullptr, std::move(content_provider),
content_type, nullptr, progress);
}
Result ClientImpl::Put(const std::string &path, const Headers &headers,
ContentProviderWithoutLength content_provider,
const std::string &content_type,
ContentReceiver content_receiver,
UploadProgress progress) {
return send_with_content_provider_and_receiver(
"PUT", path, headers, nullptr, 0, nullptr, std::move(content_provider),
content_type, std::move(content_receiver), progress);
}
Result ClientImpl::Put(const std::string &path, const Headers &headers,
@ -7333,10 +7439,10 @@ Result ClientImpl::Put(const std::string &path, const Headers &headers,
const auto &boundary = detail::make_multipart_data_boundary();
const auto &content_type =
detail::serialize_multipart_formdata_get_content_type(boundary);
return send_with_content_provider(
return send_with_content_provider_and_receiver(
"PUT", path, headers, nullptr, 0, nullptr,
get_multipart_content_provider(boundary, items, provider_items),
content_type, progress);
content_type, nullptr, progress);
}
Result ClientImpl::Put(const std::string &path, const Headers &headers,
@ -7400,6 +7506,15 @@ Result ClientImpl::Patch(const std::string &path, size_t content_length,
content_type, progress);
}
Result ClientImpl::Patch(const std::string &path, size_t content_length,
ContentProvider content_provider,
const std::string &content_type,
ContentReceiver content_receiver,
UploadProgress progress) {
return Patch(path, Headers(), content_length, std::move(content_provider),
content_type, std::move(content_receiver), progress);
}
Result ClientImpl::Patch(const std::string &path,
ContentProviderWithoutLength content_provider,
const std::string &content_type,
@ -7408,6 +7523,15 @@ Result ClientImpl::Patch(const std::string &path,
progress);
}
Result ClientImpl::Patch(const std::string &path,
ContentProviderWithoutLength content_provider,
const std::string &content_type,
ContentReceiver content_receiver,
UploadProgress progress) {
return Patch(path, Headers(), std::move(content_provider), content_type,
std::move(content_receiver), progress);
}
Result ClientImpl::Patch(const std::string &path, const Headers &headers,
const Params &params) {
auto query = detail::params_to_query_str(params);
@ -7448,18 +7572,18 @@ Result ClientImpl::Patch(const std::string &path, const Headers &headers,
const char *body, size_t content_length,
const std::string &content_type,
UploadProgress progress) {
return send_with_content_provider("PATCH", path, headers, body,
content_length, nullptr, nullptr,
content_type, progress);
return send_with_content_provider_and_receiver(
"PATCH", path, headers, body, content_length, nullptr, nullptr,
content_type, nullptr, progress);
}
Result ClientImpl::Patch(const std::string &path, const Headers &headers,
const std::string &body,
const std::string &content_type,
UploadProgress progress) {
return send_with_content_provider("PATCH", path, headers, body.data(),
body.size(), nullptr, nullptr, content_type,
progress);
return send_with_content_provider_and_receiver(
"PATCH", path, headers, body.data(), body.size(), nullptr, nullptr,
content_type, nullptr, progress);
}
Result ClientImpl::Patch(const std::string &path, const Headers &headers,
@ -7467,18 +7591,40 @@ Result ClientImpl::Patch(const std::string &path, const Headers &headers,
ContentProvider content_provider,
const std::string &content_type,
UploadProgress progress) {
return send_with_content_provider("PATCH", path, headers, nullptr,
content_length, std::move(content_provider),
nullptr, content_type, progress);
return send_with_content_provider_and_receiver(
"PATCH", path, headers, nullptr, content_length,
std::move(content_provider), nullptr, content_type, nullptr, progress);
}
Result ClientImpl::Patch(const std::string &path, const Headers &headers,
size_t content_length,
ContentProvider content_provider,
const std::string &content_type,
ContentReceiver content_receiver,
UploadProgress progress) {
return send_with_content_provider_and_receiver(
"PATCH", path, headers, nullptr, content_length,
std::move(content_provider), nullptr, content_type,
std::move(content_receiver), progress);
}
Result ClientImpl::Patch(const std::string &path, const Headers &headers,
ContentProviderWithoutLength content_provider,
const std::string &content_type,
UploadProgress progress) {
return send_with_content_provider("PATCH", path, headers, nullptr, 0, nullptr,
std::move(content_provider), content_type,
progress);
return send_with_content_provider_and_receiver(
"PATCH", path, headers, nullptr, 0, nullptr, std::move(content_provider),
content_type, nullptr, progress);
}
Result ClientImpl::Patch(const std::string &path, const Headers &headers,
ContentProviderWithoutLength content_provider,
const std::string &content_type,
ContentReceiver content_receiver,
UploadProgress progress) {
return send_with_content_provider_and_receiver(
"PATCH", path, headers, nullptr, 0, nullptr, std::move(content_provider),
content_type, std::move(content_receiver), progress);
}
Result ClientImpl::Patch(const std::string &path, const Headers &headers,
@ -7488,10 +7634,10 @@ Result ClientImpl::Patch(const std::string &path, const Headers &headers,
const auto &boundary = detail::make_multipart_data_boundary();
const auto &content_type =
detail::serialize_multipart_formdata_get_content_type(boundary);
return send_with_content_provider(
return send_with_content_provider_and_receiver(
"PATCH", path, headers, nullptr, 0, nullptr,
get_multipart_content_provider(boundary, items, provider_items),
content_type, progress);
content_type, nullptr, progress);
}
Result ClientImpl::Patch(const std::string &path, const Headers &headers,
@ -8883,12 +9029,28 @@ Result Client::Post(const std::string &path, size_t content_length,
return cli_->Post(path, content_length, std::move(content_provider),
content_type, progress);
}
Result Client::Post(const std::string &path, size_t content_length,
ContentProvider content_provider,
const std::string &content_type,
ContentReceiver content_receiver,
UploadProgress progress) {
return cli_->Post(path, content_length, std::move(content_provider),
content_type, std::move(content_receiver), progress);
}
Result Client::Post(const std::string &path,
ContentProviderWithoutLength content_provider,
const std::string &content_type,
UploadProgress progress) {
return cli_->Post(path, std::move(content_provider), content_type, progress);
}
Result Client::Post(const std::string &path,
ContentProviderWithoutLength content_provider,
const std::string &content_type,
ContentReceiver content_receiver,
UploadProgress progress) {
return cli_->Post(path, std::move(content_provider), content_type,
std::move(content_receiver), progress);
}
Result Client::Post(const std::string &path, const Headers &headers,
size_t content_length,
ContentProvider content_provider,
@ -8897,6 +9059,15 @@ Result Client::Post(const std::string &path, const Headers &headers,
return cli_->Post(path, headers, content_length, std::move(content_provider),
content_type, progress);
}
Result Client::Post(const std::string &path, const Headers &headers,
size_t content_length,
ContentProvider content_provider,
const std::string &content_type,
ContentReceiver content_receiver,
DownloadProgress progress) {
return cli_->Post(path, headers, content_length, std::move(content_provider),
content_type, std::move(content_receiver), progress);
}
Result Client::Post(const std::string &path, const Headers &headers,
ContentProviderWithoutLength content_provider,
const std::string &content_type,
@ -8904,6 +9075,14 @@ Result Client::Post(const std::string &path, const Headers &headers,
return cli_->Post(path, headers, std::move(content_provider), content_type,
progress);
}
Result Client::Post(const std::string &path, const Headers &headers,
ContentProviderWithoutLength content_provider,
const std::string &content_type,
ContentReceiver content_receiver,
DownloadProgress progress) {
return cli_->Post(path, headers, std::move(content_provider), content_type,
std::move(content_receiver), progress);
}
Result Client::Post(const std::string &path, const Params &params) {
return cli_->Post(path, params);
}
@ -8938,8 +9117,8 @@ Result Client::Post(const std::string &path, const Headers &headers,
const std::string &content_type,
ContentReceiver content_receiver,
DownloadProgress progress) {
return cli_->Post(path, headers, body, content_type, content_receiver,
progress);
return cli_->Post(path, headers, body, content_type,
std::move(content_receiver), progress);
}
Result Client::Put(const std::string &path) { return cli_->Put(path); }
@ -8976,12 +9155,28 @@ Result Client::Put(const std::string &path, size_t content_length,
return cli_->Put(path, content_length, std::move(content_provider),
content_type, progress);
}
Result Client::Put(const std::string &path, size_t content_length,
ContentProvider content_provider,
const std::string &content_type,
ContentReceiver content_receiver,
UploadProgress progress) {
return cli_->Put(path, content_length, std::move(content_provider),
content_type, std::move(content_receiver), progress);
}
Result Client::Put(const std::string &path,
ContentProviderWithoutLength content_provider,
const std::string &content_type,
UploadProgress progress) {
return cli_->Put(path, std::move(content_provider), content_type, progress);
}
Result Client::Put(const std::string &path,
ContentProviderWithoutLength content_provider,
const std::string &content_type,
ContentReceiver content_receiver,
UploadProgress progress) {
return cli_->Put(path, std::move(content_provider), content_type,
std::move(content_receiver), progress);
}
Result Client::Put(const std::string &path, const Headers &headers,
size_t content_length,
ContentProvider content_provider,
@ -8990,6 +9185,15 @@ Result Client::Put(const std::string &path, const Headers &headers,
return cli_->Put(path, headers, content_length, std::move(content_provider),
content_type, progress);
}
Result Client::Put(const std::string &path, const Headers &headers,
size_t content_length,
ContentProvider content_provider,
const std::string &content_type,
ContentReceiver content_receiver,
UploadProgress progress) {
return cli_->Put(path, headers, content_length, std::move(content_provider),
content_type, std::move(content_receiver), progress);
}
Result Client::Put(const std::string &path, const Headers &headers,
ContentProviderWithoutLength content_provider,
const std::string &content_type,
@ -8997,6 +9201,14 @@ Result Client::Put(const std::string &path, const Headers &headers,
return cli_->Put(path, headers, std::move(content_provider), content_type,
progress);
}
Result Client::Put(const std::string &path, const Headers &headers,
ContentProviderWithoutLength content_provider,
const std::string &content_type,
ContentReceiver content_receiver,
UploadProgress progress) {
return cli_->Put(path, headers, std::move(content_provider), content_type,
std::move(content_receiver), progress);
}
Result Client::Put(const std::string &path, const Params &params) {
return cli_->Put(path, params);
}
@ -9072,12 +9284,28 @@ Result Client::Patch(const std::string &path, size_t content_length,
return cli_->Patch(path, content_length, std::move(content_provider),
content_type, progress);
}
Result Client::Patch(const std::string &path, size_t content_length,
ContentProvider content_provider,
const std::string &content_type,
ContentReceiver content_receiver,
UploadProgress progress) {
return cli_->Patch(path, content_length, std::move(content_provider),
content_type, std::move(content_receiver), progress);
}
Result Client::Patch(const std::string &path,
ContentProviderWithoutLength content_provider,
const std::string &content_type,
UploadProgress progress) {
return cli_->Patch(path, std::move(content_provider), content_type, progress);
}
Result Client::Patch(const std::string &path,
ContentProviderWithoutLength content_provider,
const std::string &content_type,
ContentReceiver content_receiver,
UploadProgress progress) {
return cli_->Patch(path, std::move(content_provider), content_type,
std::move(content_receiver), progress);
}
Result Client::Patch(const std::string &path, const Headers &headers,
size_t content_length,
ContentProvider content_provider,
@ -9086,6 +9314,15 @@ Result Client::Patch(const std::string &path, const Headers &headers,
return cli_->Patch(path, headers, content_length, std::move(content_provider),
content_type, progress);
}
Result Client::Patch(const std::string &path, const Headers &headers,
size_t content_length,
ContentProvider content_provider,
const std::string &content_type,
ContentReceiver content_receiver,
UploadProgress progress) {
return cli_->Patch(path, headers, content_length, std::move(content_provider),
content_type, std::move(content_receiver), progress);
}
Result Client::Patch(const std::string &path, const Headers &headers,
ContentProviderWithoutLength content_provider,
const std::string &content_type,
@ -9093,6 +9330,14 @@ Result Client::Patch(const std::string &path, const Headers &headers,
return cli_->Patch(path, headers, std::move(content_provider), content_type,
progress);
}
Result Client::Patch(const std::string &path, const Headers &headers,
ContentProviderWithoutLength content_provider,
const std::string &content_type,
ContentReceiver content_receiver,
UploadProgress progress) {
return cli_->Patch(path, headers, std::move(content_provider), content_type,
std::move(content_receiver), progress);
}
Result Client::Patch(const std::string &path, const Params &params) {
return cli_->Patch(path, params);
}
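For reference, the new provider-plus-receiver overloads added above can be exercised roughly as follows. This is an illustrative sketch, not code from this patch: the host, path, payload, and the usual `DataSink`-based `ContentProvider` signature are assumptions.

```cpp
#include <string>
#include "httplib.h"

// Illustrative only: streams a request body up via a ContentProvider while
// collecting the response body via a ContentReceiver, using one of the Post
// overloads introduced above. Host, path and payload are placeholders.
int main() {
  httplib::Client cli("localhost", 8080);

  const std::string payload(1 << 20, 'x'); // 1 MiB of dummy data
  std::string response_body;

  auto res = cli.Post(
      "/upload", payload.size(),
      [&](size_t offset, size_t length, httplib::DataSink &sink) {
        // Upload side: feed the requested slice of the payload.
        return sink.write(payload.data() + offset, length);
      },
      "application/octet-stream",
      [&](const char *data, size_t data_length) {
        // Download side: accumulate the response as it streams in.
        response_body.append(data, data_length);
        return true; // keep receiving
      });

  return res && res->status == 200 ? 0 : 1;
}
```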

View File

@ -8,8 +8,8 @@
#ifndef CPPHTTPLIB_HTTPLIB_H
#define CPPHTTPLIB_HTTPLIB_H
#define CPPHTTPLIB_VERSION "0.27.0"
#define CPPHTTPLIB_VERSION_NUM "0x001B00"
#define CPPHTTPLIB_VERSION "0.28.0"
#define CPPHTTPLIB_VERSION_NUM "0x001C00"
/*
* Platform compatibility check
@ -257,6 +257,7 @@ using socklen_t = int;
#include <netinet/in.h>
#ifdef __linux__
#include <resolv.h>
#undef _res // Undefine _res macro to avoid conflicts with user code (#2278)
#endif
#include <csignal>
#include <netinet/tcp.h>
@ -1421,14 +1422,18 @@ public:
Result Post(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
Result Post(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
Result Post(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Post(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Post(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Post(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Post(const std::string &path, const Params &params);
Result Post(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr);
Result Post(const std::string &path, const Headers &headers);
Result Post(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
Result Post(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
Result Post(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Post(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, DownloadProgress progress = nullptr);
Result Post(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Post(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, DownloadProgress progress = nullptr);
Result Post(const std::string &path, const Headers &headers, const Params &params);
Result Post(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr);
Result Post(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr);
@ -1439,14 +1444,18 @@ public:
Result Put(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
Result Put(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
Result Put(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Put(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Put(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Put(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Put(const std::string &path, const Params &params);
Result Put(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr);
Result Put(const std::string &path, const Headers &headers);
Result Put(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
Result Put(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
Result Put(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Put(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Put(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Put(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Put(const std::string &path, const Headers &headers, const Params &params);
Result Put(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr);
Result Put(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr);
@ -1457,14 +1466,18 @@ public:
Result Patch(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
Result Patch(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Patch(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Patch(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Patch(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Params &params);
Result Patch(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Headers &headers, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Headers &headers, const Params &params);
Result Patch(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr);
@ -1712,17 +1725,19 @@ private:
template <typename ClientType> void setup_redirect_client(ClientType &client);
bool handle_request(Stream &strm, Request &req, Response &res,
bool close_connection, Error &error);
std::unique_ptr<Response> send_with_content_provider(
std::unique_ptr<Response> send_with_content_provider_and_receiver(
Request &req, const char *body, size_t content_length,
ContentProvider content_provider,
ContentProviderWithoutLength content_provider_without_length,
const std::string &content_type, Error &error);
Result send_with_content_provider(
const std::string &content_type, ContentReceiver content_receiver,
Error &error);
Result send_with_content_provider_and_receiver(
const std::string &method, const std::string &path,
const Headers &headers, const char *body, size_t content_length,
ContentProvider content_provider,
ContentProviderWithoutLength content_provider_without_length,
const std::string &content_type, UploadProgress progress);
const std::string &content_type, ContentReceiver content_receiver,
UploadProgress progress);
ContentProviderWithoutLength get_multipart_content_provider(
const std::string &boundary, const UploadFormDataItems &items,
const FormDataProviderItems &provider_items) const;
@ -1775,14 +1790,18 @@ public:
Result Post(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
Result Post(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
Result Post(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Post(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Post(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Post(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Post(const std::string &path, const Params &params);
Result Post(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr);
Result Post(const std::string &path, const Headers &headers);
Result Post(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
Result Post(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
Result Post(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Post(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, DownloadProgress progress = nullptr);
Result Post(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Post(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, DownloadProgress progress = nullptr);
Result Post(const std::string &path, const Headers &headers, const Params &params);
Result Post(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr);
Result Post(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr);
@ -1793,14 +1812,18 @@ public:
Result Put(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
Result Put(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
Result Put(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Put(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Put(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Put(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Put(const std::string &path, const Params &params);
Result Put(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr);
Result Put(const std::string &path, const Headers &headers);
Result Put(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
Result Put(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
Result Put(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Put(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Put(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Put(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Put(const std::string &path, const Headers &headers, const Params &params);
Result Put(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr);
Result Put(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr);
@ -1811,14 +1834,18 @@ public:
Result Patch(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
Result Patch(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Patch(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Patch(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Patch(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Params &params);
Result Patch(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Headers &headers);
Result Patch(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Headers &headers, const Params &params);
Result Patch(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr);

View File

@ -1,6 +0,0 @@
# Remove bssl
file(READ "CMakeLists.txt" content)
string(REPLACE "add_executable(bssl" "#add_executable(bssl" content "${content}")
string(REPLACE "target_link_libraries(bssl" "#target_link_libraries(bssl" content "${content}")
string(REPLACE "install(TARGETS bssl" "#install(TARGETS bssl" content "${content}")
file(WRITE "CMakeLists.txt" "${content}")