Merge branch 'master' into xsn/server_model_management_v1_2

This commit is contained in:
Xuan Son Nguyen 2025-11-26 16:21:57 +01:00
commit becc602612
52 changed files with 2578 additions and 659 deletions

View File

@ -2,10 +2,8 @@
# multiplie collaborators per item can be specified # multiplie collaborators per item can be specified
/.devops/*.Dockerfile @ngxson /.devops/*.Dockerfile @ngxson
/.github/actions/ @slaren @CISC /.github/actions/ @CISC
/.github/workflows/ @CISC /.github/workflows/ @CISC
/.github/workflows/release.yml @slaren
/.github/workflows/winget.yml @slaren
/ci/ @ggerganov /ci/ @ggerganov
/cmake/ @ggerganov /cmake/ @ggerganov
/common/CMakeLists.txt @ggerganov /common/CMakeLists.txt @ggerganov
@ -40,21 +38,14 @@
/examples/passkey/ @ggerganov /examples/passkey/ @ggerganov
/examples/retrieval/ @ggerganov /examples/retrieval/ @ggerganov
/examples/save-load-state/ @ggerganov /examples/save-load-state/ @ggerganov
/examples/simple-chat/ @slaren
/examples/simple/ @slaren
/examples/speculative-simple/ @ggerganov /examples/speculative-simple/ @ggerganov
/examples/speculative/ @ggerganov /examples/speculative/ @ggerganov
/ggml/cmake/ @ggerganov /ggml/cmake/ @ggerganov
/ggml/include/ @ggerganov @slaren /ggml/include/ @ggerganov
/ggml/src/ggml-alloc.c @slaren /ggml/src/ggml-common.h @ggerganov
/ggml/src/ggml-backend* @slaren /ggml/src/ggml-cpu/ @ggerganov
/ggml/src/ggml-blas/ @slaren
/ggml/src/ggml-common.h @ggerganov @slaren
/ggml/src/ggml-cpu/ @ggerganov @slaren
/ggml/src/ggml-cpu/spacemit/ @alex-spacemit /ggml/src/ggml-cpu/spacemit/ @alex-spacemit
/ggml/src/ggml-cuda/common.cuh @slaren
/ggml/src/ggml-cuda/fattn* @JohannesGaessler /ggml/src/ggml-cuda/fattn* @JohannesGaessler
/ggml/src/ggml-cuda/ggml-cuda.cu @slaren
/ggml/src/ggml-cuda/mmf.* @JohannesGaessler @am17an /ggml/src/ggml-cuda/mmf.* @JohannesGaessler @am17an
/ggml/src/ggml-cuda/mmq.* @JohannesGaessler /ggml/src/ggml-cuda/mmq.* @JohannesGaessler
/ggml/src/ggml-cuda/mmvf.* @JohannesGaessler /ggml/src/ggml-cuda/mmvf.* @JohannesGaessler
@ -62,19 +53,19 @@
/ggml/src/ggml-cuda/fattn-wmma* @IMbackK /ggml/src/ggml-cuda/fattn-wmma* @IMbackK
/ggml/src/ggml-hip/ @IMbackK /ggml/src/ggml-hip/ @IMbackK
/ggml/src/ggml-cuda/vendors/hip.h @IMbackK /ggml/src/ggml-cuda/vendors/hip.h @IMbackK
/ggml/src/ggml-impl.h @ggerganov @slaren /ggml/src/ggml-impl.h @ggerganov
/ggml/src/ggml-metal/ @ggerganov /ggml/src/ggml-metal/ @ggerganov
/ggml/src/ggml-opencl/ @lhez @max-krasnyansky /ggml/src/ggml-opencl/ @lhez @max-krasnyansky
/ggml/src/ggml-hexagon/ @max-krasnyansky @lhez /ggml/src/ggml-hexagon/ @max-krasnyansky @lhez
/ggml/src/ggml-opt.cpp @JohannesGaessler /ggml/src/ggml-opt.cpp @JohannesGaessler
/ggml/src/ggml-quants.* @ggerganov /ggml/src/ggml-quants.* @ggerganov
/ggml/src/ggml-rpc/ @rgerganov /ggml/src/ggml-rpc/ @rgerganov
/ggml/src/ggml-threading.* @ggerganov @slaren /ggml/src/ggml-threading.* @ggerganov
/ggml/src/ggml-vulkan/ @0cc4m /ggml/src/ggml-vulkan/ @0cc4m
/ggml/src/ggml-webgpu/ @reeselevine /ggml/src/ggml-webgpu/ @reeselevine
/ggml/src/ggml-zdnn/ @taronaeo @Andreas-Krebbel @AlekseiNikiforovIBM /ggml/src/ggml-zdnn/ @taronaeo @Andreas-Krebbel @AlekseiNikiforovIBM
/ggml/src/ggml.c @ggerganov @slaren /ggml/src/ggml.c @ggerganov
/ggml/src/ggml.cpp @ggerganov @slaren /ggml/src/ggml.cpp @ggerganov
/ggml/src/gguf.cpp @JohannesGaessler @Green-Sky /ggml/src/gguf.cpp @JohannesGaessler @Green-Sky
/gguf-py/ @CISC /gguf-py/ @CISC
/media/ @ggerganov /media/ @ggerganov
@ -86,15 +77,11 @@
/src/llama-arch.* @CISC /src/llama-arch.* @CISC
/src/llama-chat.* @ngxson /src/llama-chat.* @ngxson
/src/llama-graph.* @CISC /src/llama-graph.* @CISC
/src/llama-model-loader.* @slaren
/src/llama-model.* @CISC /src/llama-model.* @CISC
/src/llama-vocab.* @CISC /src/llama-vocab.* @CISC
/src/models/ @CISC /src/models/ @CISC
/tests/ @ggerganov /tests/ @ggerganov
/tests/test-backend-ops.cpp @slaren
/tests/test-thread-safety.cpp @slaren
/tools/batched-bench/ @ggerganov /tools/batched-bench/ @ggerganov
/tools/llama-bench/ @slaren
/tools/main/ @ggerganov /tools/main/ @ggerganov
/tools/mtmd/ @ngxson /tools/mtmd/ @ngxson
/tools/perplexity/ @ggerganov /tools/perplexity/ @ggerganov
@ -106,8 +93,6 @@
/tools/tokenize/ @ggerganov /tools/tokenize/ @ggerganov
/tools/tts/ @ggerganov /tools/tts/ @ggerganov
/vendor/ @ggerganov /vendor/ @ggerganov
/.clang-format @slaren
/.clang-tidy @slaren
/AUTHORS @ggerganov /AUTHORS @ggerganov
/CMakeLists.txt @ggerganov /CMakeLists.txt @ggerganov
/CONTRIBUTING.md @ggerganov /CONTRIBUTING.md @ggerganov

View File

@ -1237,6 +1237,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
[](common_params & params, const std::string & value) { [](common_params & params, const std::string & value) {
const auto sampler_names = string_split<std::string>(value, ';'); const auto sampler_names = string_split<std::string>(value, ';');
params.sampling.samplers = common_sampler_types_from_names(sampler_names, true); params.sampling.samplers = common_sampler_types_from_names(sampler_names, true);
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_SAMPLERS;
} }
).set_sparam()); ).set_sparam());
add_opt(common_arg( add_opt(common_arg(
@ -1266,6 +1267,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
[](common_params & params, const std::string & value) { [](common_params & params, const std::string & value) {
params.sampling.temp = std::stof(value); params.sampling.temp = std::stof(value);
params.sampling.temp = std::max(params.sampling.temp, 0.0f); params.sampling.temp = std::max(params.sampling.temp, 0.0f);
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TEMP;
} }
).set_sparam()); ).set_sparam());
add_opt(common_arg( add_opt(common_arg(
@ -1273,6 +1275,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
string_format("top-k sampling (default: %d, 0 = disabled)", params.sampling.top_k), string_format("top-k sampling (default: %d, 0 = disabled)", params.sampling.top_k),
[](common_params & params, int value) { [](common_params & params, int value) {
params.sampling.top_k = value; params.sampling.top_k = value;
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_K;
} }
).set_sparam()); ).set_sparam());
add_opt(common_arg( add_opt(common_arg(
@ -1280,6 +1283,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
string_format("top-p sampling (default: %.1f, 1.0 = disabled)", (double)params.sampling.top_p), string_format("top-p sampling (default: %.1f, 1.0 = disabled)", (double)params.sampling.top_p),
[](common_params & params, const std::string & value) { [](common_params & params, const std::string & value) {
params.sampling.top_p = std::stof(value); params.sampling.top_p = std::stof(value);
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_P;
} }
).set_sparam()); ).set_sparam());
add_opt(common_arg( add_opt(common_arg(
@ -1287,6 +1291,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
string_format("min-p sampling (default: %.1f, 0.0 = disabled)", (double)params.sampling.min_p), string_format("min-p sampling (default: %.1f, 0.0 = disabled)", (double)params.sampling.min_p),
[](common_params & params, const std::string & value) { [](common_params & params, const std::string & value) {
params.sampling.min_p = std::stof(value); params.sampling.min_p = std::stof(value);
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIN_P;
} }
).set_sparam()); ).set_sparam());
add_opt(common_arg( add_opt(common_arg(
@ -1301,6 +1306,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
string_format("xtc probability (default: %.1f, 0.0 = disabled)", (double)params.sampling.xtc_probability), string_format("xtc probability (default: %.1f, 0.0 = disabled)", (double)params.sampling.xtc_probability),
[](common_params & params, const std::string & value) { [](common_params & params, const std::string & value) {
params.sampling.xtc_probability = std::stof(value); params.sampling.xtc_probability = std::stof(value);
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_PROBABILITY;
} }
).set_sparam()); ).set_sparam());
add_opt(common_arg( add_opt(common_arg(
@ -1308,6 +1314,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
string_format("xtc threshold (default: %.1f, 1.0 = disabled)", (double)params.sampling.xtc_threshold), string_format("xtc threshold (default: %.1f, 1.0 = disabled)", (double)params.sampling.xtc_threshold),
[](common_params & params, const std::string & value) { [](common_params & params, const std::string & value) {
params.sampling.xtc_threshold = std::stof(value); params.sampling.xtc_threshold = std::stof(value);
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_THRESHOLD;
} }
).set_sparam()); ).set_sparam());
add_opt(common_arg( add_opt(common_arg(
@ -1326,6 +1333,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
} }
params.sampling.penalty_last_n = value; params.sampling.penalty_last_n = value;
params.sampling.n_prev = std::max(params.sampling.n_prev, params.sampling.penalty_last_n); params.sampling.n_prev = std::max(params.sampling.n_prev, params.sampling.penalty_last_n);
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_LAST_N;
} }
).set_sparam()); ).set_sparam());
add_opt(common_arg( add_opt(common_arg(
@ -1333,6 +1341,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
string_format("penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)", (double)params.sampling.penalty_repeat), string_format("penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)", (double)params.sampling.penalty_repeat),
[](common_params & params, const std::string & value) { [](common_params & params, const std::string & value) {
params.sampling.penalty_repeat = std::stof(value); params.sampling.penalty_repeat = std::stof(value);
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_REPEAT;
} }
).set_sparam()); ).set_sparam());
add_opt(common_arg( add_opt(common_arg(
@ -1430,6 +1439,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
"(default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)", params.sampling.mirostat), "(default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)", params.sampling.mirostat),
[](common_params & params, int value) { [](common_params & params, int value) {
params.sampling.mirostat = value; params.sampling.mirostat = value;
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT;
} }
).set_sparam()); ).set_sparam());
add_opt(common_arg( add_opt(common_arg(
@ -1437,6 +1447,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
string_format("Mirostat learning rate, parameter eta (default: %.1f)", (double)params.sampling.mirostat_eta), string_format("Mirostat learning rate, parameter eta (default: %.1f)", (double)params.sampling.mirostat_eta),
[](common_params & params, const std::string & value) { [](common_params & params, const std::string & value) {
params.sampling.mirostat_eta = std::stof(value); params.sampling.mirostat_eta = std::stof(value);
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA;
} }
).set_sparam()); ).set_sparam());
add_opt(common_arg( add_opt(common_arg(
@ -1444,6 +1455,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
string_format("Mirostat target entropy, parameter tau (default: %.1f)", (double)params.sampling.mirostat_tau), string_format("Mirostat target entropy, parameter tau (default: %.1f)", (double)params.sampling.mirostat_tau),
[](common_params & params, const std::string & value) { [](common_params & params, const std::string & value) {
params.sampling.mirostat_tau = std::stof(value); params.sampling.mirostat_tau = std::stof(value);
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_TAU;
} }
).set_sparam()); ).set_sparam());
add_opt(common_arg( add_opt(common_arg(

View File

@ -8,6 +8,7 @@
#include "common.h" #include "common.h"
#include "log.h" #include "log.h"
#include "llama.h" #include "llama.h"
#include "sampling.h"
#include <algorithm> #include <algorithm>
#include <cinttypes> #include <cinttypes>
@ -957,6 +958,58 @@ std::vector<common_file_info> fs_list(const std::string & path, bool include_dir
// Model utils // Model utils
// //
static inline void common_init_sampler_from_model(
const llama_model * model,
common_params_sampling & sparams) {
const uint64_t config = sparams.user_sampling_config;
auto get_int32 = [&](const char * key, int32_t & dst, uint64_t user_config) {
if (config & user_config) return;
char buf[64] = {0};
if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) {
char * end = nullptr;
int32_t v = strtol(buf, &end, 10);
if (end && end != buf) dst = v;
}
};
auto get_float = [&](const char * key, float & dst, uint64_t user_config) {
if (config & user_config) return;
char buf[128] = {0};
if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) {
char * end = nullptr;
float v = strtof(buf, &end);
if (end && end != buf) dst = v;
}
};
// Sampling sequence
if (!(config & common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_SAMPLERS)) {
char buf[512] = {0};
if (llama_model_meta_val_str(model, llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_SEQUENCE), buf, sizeof(buf)) > 0) {
const std::vector<std::string> sampler_names = string_split<std::string>(std::string(buf), ';');
if (!sampler_names.empty()) {
sparams.samplers = common_sampler_types_from_names(sampler_names, true);
}
}
}
get_int32(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_TOP_K), sparams.top_k, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_K);
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_TOP_P), sparams.top_p, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_P);
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIN_P), sparams.min_p, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIN_P);
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_XTC_PROBABILITY), sparams.xtc_probability, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_PROBABILITY);
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_XTC_THRESHOLD), sparams.xtc_threshold, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_THRESHOLD);
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_TEMP), sparams.temp, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TEMP);
get_int32(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_LAST_N), sparams.penalty_last_n, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_LAST_N);
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_REPEAT), sparams.penalty_repeat, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_REPEAT);
get_int32(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT), sparams.mirostat, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT);
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_TAU), sparams.mirostat_tau, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_TAU);
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_ETA), sparams.mirostat_eta, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA);
}
struct common_init_result common_init_from_params(common_params & params) { struct common_init_result common_init_from_params(common_params & params) {
common_init_result iparams; common_init_result iparams;
auto mparams = common_model_params_to_llama(params); auto mparams = common_model_params_to_llama(params);
@ -968,6 +1021,8 @@ struct common_init_result common_init_from_params(common_params & params) {
return iparams; return iparams;
} }
common_init_sampler_from_model(model, params.sampling);
const llama_vocab * vocab = llama_model_get_vocab(model); const llama_vocab * vocab = llama_model_get_vocab(model);
auto cparams = common_context_params_to_llama(params); auto cparams = common_context_params_to_llama(params);

View File

@ -138,6 +138,22 @@ struct common_grammar_trigger {
llama_token token = LLAMA_TOKEN_NULL; llama_token token = LLAMA_TOKEN_NULL;
}; };
enum common_params_sampling_config : uint64_t {
COMMON_PARAMS_SAMPLING_CONFIG_SAMPLERS = 1 << 0,
COMMON_PARAMS_SAMPLING_CONFIG_TOP_K = 1 << 1,
COMMON_PARAMS_SAMPLING_CONFIG_TOP_P = 1 << 2,
COMMON_PARAMS_SAMPLING_CONFIG_MIN_P = 1 << 3,
COMMON_PARAMS_SAMPLING_CONFIG_XTC_PROBABILITY = 1 << 4,
COMMON_PARAMS_SAMPLING_CONFIG_XTC_THRESHOLD = 1 << 5,
COMMON_PARAMS_SAMPLING_CONFIG_TEMP = 1 << 6,
COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_LAST_N = 1 << 7,
COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_REPEAT = 1 << 8,
COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT = 1 << 9,
COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_TAU = 1 << 10,
COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA = 1 << 11,
};
// sampling parameters // sampling parameters
struct common_params_sampling { struct common_params_sampling {
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler
@ -170,6 +186,8 @@ struct common_params_sampling {
bool no_perf = false; // disable performance metrics bool no_perf = false; // disable performance metrics
bool timing_per_token = false; bool timing_per_token = false;
uint64_t user_sampling_config = 0; // bitfield to track user-specified samplers
std::vector<std::string> dry_sequence_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY std::vector<std::string> dry_sequence_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY

View File

@ -565,7 +565,7 @@ class ModelBase:
gguf.MODEL_TENSOR.ALTUP_PREDICT_COEF, gguf.MODEL_TENSOR.ALTUP_PREDICT_COEF,
) )
) )
or not new_name.endswith(".weight") or new_name[-7:] not in (".weight", ".lora_a", ".lora_b")
): ):
data_qtype = gguf.GGMLQuantizationType.F32 data_qtype = gguf.GGMLQuantizationType.F32
@ -10061,6 +10061,25 @@ class LazyTorchTensor(gguf.LazyBase):
torch.uint8: np.uint8, torch.uint8: np.uint8,
} }
# only used when byteswapping data. Only correct size is needed
_dtype_byteswap_map: dict[torch.dtype, type] = {
torch.float64: np.float64,
torch.float32: np.float32,
torch.bfloat16: np.float16,
torch.float16: np.float16,
torch.int64: np.int64,
torch.uint64: np.uint64,
torch.int32: np.int32,
torch.uint32: np.uint32,
torch.int16: np.int16,
torch.uint16: np.uint16,
torch.int8: np.int8,
torch.uint8: np.uint8,
torch.bool: np.uint8,
torch.float8_e4m3fn: np.uint8,
torch.float8_e5m2: np.uint8,
}
# used for safetensors slices # used for safetensors slices
# ref: https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/src/lib.rs#L1046 # ref: https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/src/lib.rs#L1046
# TODO: uncomment U64, U32, and U16, ref: https://github.com/pytorch/pytorch/issues/58734 # TODO: uncomment U64, U32, and U16, ref: https://github.com/pytorch/pytorch/issues/58734
@ -10104,8 +10123,14 @@ class LazyTorchTensor(gguf.LazyBase):
@classmethod @classmethod
def from_local_tensor(cls, t: gguf.utility.LocalTensor) -> Tensor: def from_local_tensor(cls, t: gguf.utility.LocalTensor) -> Tensor:
def load_tensor(tensor: gguf.utility.LocalTensor) -> Tensor: def load_tensor(tensor: gguf.utility.LocalTensor) -> Tensor:
def byteswap_tensor(tensor: np.ndarray, dtype: type) -> np.ndarray:
if sys.byteorder == 'big':
# switch data back to big endian
tensor = tensor.view(dtype).byteswap(inplace=False)
return tensor
dtype = cls._dtype_str_map[tensor.dtype] dtype = cls._dtype_str_map[tensor.dtype]
return torch.from_numpy(tensor.mmap_bytes()).view(dtype).reshape(tensor.shape) numpy_dtype = cls._dtype_byteswap_map[dtype]
return torch.from_numpy(byteswap_tensor(tensor.mmap_bytes(), numpy_dtype)).view(dtype).reshape(tensor.shape)
dtype = cls._dtype_str_map[t.dtype] dtype = cls._dtype_str_map[t.dtype]
shape = t.shape shape = t.shape
lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(t,), func=lambda r: load_tensor(r)) lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(t,), func=lambda r: load_tensor(r))
@ -10113,10 +10138,16 @@ class LazyTorchTensor(gguf.LazyBase):
@classmethod @classmethod
def from_remote_tensor(cls, remote_tensor: gguf.utility.RemoteTensor): def from_remote_tensor(cls, remote_tensor: gguf.utility.RemoteTensor):
def byteswap_tensor(tensor: np.ndarray, dtype: type) -> np.ndarray:
if sys.byteorder == 'big':
# switch data back to big endian
tensor = tensor.view(dtype).byteswap(inplace=False)
return tensor
dtype = cls._dtype_str_map[remote_tensor.dtype] dtype = cls._dtype_str_map[remote_tensor.dtype]
numpy_dtype = cls._dtype_byteswap_map[dtype]
shape = remote_tensor.shape shape = remote_tensor.shape
meta = cls.meta_with_dtype_and_shape(dtype, shape) meta = cls.meta_with_dtype_and_shape(dtype, shape)
lazy = cls(meta=meta, args=(remote_tensor,), func=lambda r: torch.frombuffer(r.data(), dtype=dtype).reshape(shape)) lazy = cls(meta=meta, args=(remote_tensor,), func=lambda r: torch.from_numpy(byteswap_tensor(np.frombuffer(r.data(), dtype=numpy_dtype), numpy_dtype)).view(dtype).reshape(shape))
return cast(torch.Tensor, lazy) return cast(torch.Tensor, lazy)
@classmethod @classmethod

View File

@ -242,7 +242,7 @@ def parse_args() -> argparse.Namespace:
help="path to write to; default: based on input. {ftype} will be replaced by the outtype.", help="path to write to; default: based on input. {ftype} will be replaced by the outtype.",
) )
parser.add_argument( parser.add_argument(
"--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "auto"], default="f16", "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "auto"], default="f32",
help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type", help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type",
) )
parser.add_argument( parser.add_argument(

View File

@ -530,6 +530,7 @@ extern "C" {
GGML_OP_ARANGE, GGML_OP_ARANGE,
GGML_OP_TIMESTEP_EMBEDDING, GGML_OP_TIMESTEP_EMBEDDING,
GGML_OP_ARGSORT, GGML_OP_ARGSORT,
GGML_OP_TOP_K,
GGML_OP_LEAKY_RELU, GGML_OP_LEAKY_RELU,
GGML_OP_TRI, GGML_OP_TRI,
GGML_OP_FILL, GGML_OP_FILL,
@ -2258,18 +2259,25 @@ extern "C" {
struct ggml_tensor * a, struct ggml_tensor * a,
enum ggml_sort_order order); enum ggml_sort_order order);
// similar to ggml_top_k but implemented as `argsort` + `view`
GGML_API struct ggml_tensor * ggml_argsort_top_k(
struct ggml_context * ctx,
struct ggml_tensor * a,
int k);
// top k elements per row
// note: the resulting top k indices are in no particular order
GGML_API struct ggml_tensor * ggml_top_k(
struct ggml_context * ctx,
struct ggml_tensor * a,
int k);
GGML_API struct ggml_tensor * ggml_arange( GGML_API struct ggml_tensor * ggml_arange(
struct ggml_context * ctx, struct ggml_context * ctx,
float start, float start,
float stop, float stop,
float step); float step);
// top k elements per row
GGML_API struct ggml_tensor * ggml_top_k(
struct ggml_context * ctx,
struct ggml_tensor * a,
int k);
#define GGML_KQ_MASK_PAD 64 #define GGML_KQ_MASK_PAD 64
// q: [n_embd_k, n_batch, n_head, ne3 ] // q: [n_embd_k, n_batch, n_head, ne3 ]

View File

@ -42,6 +42,7 @@
#include <aclnnop/aclnn_exp.h> #include <aclnnop/aclnn_exp.h>
#include <aclnnop/aclnn_fill_scalar.h> #include <aclnnop/aclnn_fill_scalar.h>
#include <aclnnop/aclnn_fused_infer_attention_score_v2.h> #include <aclnnop/aclnn_fused_infer_attention_score_v2.h>
#include <aclnnop/aclnn_ger.h>
#include <aclnnop/aclnn_group_norm.h> #include <aclnnop/aclnn_group_norm.h>
#include <aclnnop/aclnn_grouped_matmul_v3.h> #include <aclnnop/aclnn_grouped_matmul_v3.h>
#include <aclnnop/aclnn_gt_scalar.h> #include <aclnnop/aclnn_gt_scalar.h>
@ -2206,78 +2207,120 @@ static void aclnn_index_fill_tensor(ggml_backend_cann_context & ctx,
} }
/** /**
* @brief Initializes and caches sine/cosine positional encoding values * @brief Initializes and caches all intermediate tensors required for RoPE
* (used in RoPE, Rotary Position Embedding) for attention layers. * (Rotary Position Embedding), including support for Yarn, mRoPE,
* i-mRoPE, Neox repeat strategy, independent sectors, frequency factors
* and multi-section rotary groups.
* *
* This function computes and caches the sin/cos values of * This function computes and caches the per-dimension θ coefficients used for
* θ = position * theta_scale for RoPE encoding. The cache is shared * Q/K rotary embedding. The cache is shared across layers, and recomputed only
* across attention layers, and only the first attention layer will * when any dependent parameter changes.
* trigger initialization. The cache includes repeated sin/cos values
* with different repeat methods depending on the @param is_neox flag.
* *
* Steps performed by this function: * The function now supports:
* 1. Identify whether the target tensor belongs to Q/K in attention * - Yarn RoPE extrapolation (via @param corr_dims and @param ext_factor)
* and restrict computation to the first layer only. * - Per-dimension independent sector exponent rules (indep_sects + sections[])
* 2. Initialize the theta scale array (arange power freq scaling). * - Multi-section RoPE (mRoPE) index mapping (mrope_used + is_imrope)
* 3. Allocate sin/cos caches if the max prompt length increases. * - Frequency factor division (src2)
* 4. Compute θ = position * theta_scale. * - Neox / normal repeat expansion modes
* 5. Compute sin(θ), cos(θ) and optionally scale by attn_factor.
* 6. Expand sin/cos values by repeat or repeat_interleave depending
* on whether @param is_neox is enabled.
* *
* @param ctx The CANN backend context, holding memory pool, * @param ctx CANN backend context, containing memory pool,
* stream, and persistent buffers for rope init/cache. * cached buffers, and runtime stream.
* @param dst The destination ggml_tensor whose computation * @param dst Destination ggml_tensor whose computation
* depends on the RoPE values (usually Qcur/Kcur). * depends on RoPE (typically Qcur or Kcur).
* @param theta_scale Scalar exponent base for computing theta scale values. * @param corr_dims [low, high] Yarn correction range.
* @param freq_scale Frequency scaling factor, applied to theta scale. * @param ext_factor Yarn extrapolation strength. 0 = disabled.
* @param attn_factor Attention scaling factor, applied to sin/cos. * @param theta_scale Base multiplier for per-dimension θ exponent.
* @param is_neox Whether to use Neox-style repeat strategy * @param freq_scale Global frequency scaling factor.
* (dim expansion vs repeat_interleave). * @param attn_factor Optional scaling applied to sin/cos (if needed).
* @param is_neox Whether to use Neox-style dimension interleave.
* @param sections 4-way sector sizes for independent-section RoPE
* and multi-section mRoPE (t/h/w/e).
* @param mrope_used Whether to enable multi-section rotary embedding.
* @param is_imrope Whether to apply interleaved mRoPE rules.
* @param indep_sects Whether each dimension runs independent exponent
* resets based on @p sections.
*/ */
static void aclnn_cache_init(ggml_backend_cann_context & ctx, static void aclnn_rope_cache_init(ggml_backend_cann_context & ctx,
ggml_tensor * dst, ggml_tensor * dst,
float * corr_dims, float * corr_dims,
float ext_factor, float ext_factor,
float theta_scale, float theta_scale,
float freq_scale, float freq_scale,
float attn_factor, float attn_factor,
bool is_neox) { bool is_neox,
int sections[4],
bool mrope_used,
bool is_imrope,
bool indep_sects) {
ggml_tensor * src0 = dst->src[0]; // input ggml_tensor * src0 = dst->src[0]; // input
ggml_tensor * src1 = dst->src[1]; // position ggml_tensor * src1 = dst->src[1]; // position
ggml_tensor * src2 = dst->src[2]; // freq_factors ggml_tensor * src2 = dst->src[2]; // freq_factors
if (src2 == nullptr && ctx.rope_cache.cached && ctx.rope_cache.ext_factor == ext_factor && int64_t theta_scale_length = src0->ne[0] / 2;
ctx.rope_cache.theta_scale == theta_scale && ctx.rope_cache.freq_scale == freq_scale && int64_t position_length = dst->ne[2];
ctx.rope_cache.attn_factor == attn_factor && ctx.rope_cache.is_neox == is_neox) {
// TODO: check theta_scale_length and position_length.
if (src2 == nullptr && ctx.rope_cache.cached &&
ctx.rope_cache.equal(theta_scale_length, position_length, ext_factor, theta_scale, freq_scale, attn_factor,
is_neox, indep_sects, mrope_used, is_imrope, sections)) {
// use cache. // use cache.
return; return;
} }
int64_t theta_scale_length = src0->ne[0] / 2; // Step0: calculate tensor shape.
int64_t theta_scale_ne[] = { theta_scale_length, 1, 1, 1 }; int64_t theta_scale_ne[] = { theta_scale_length, 1, 1, 1 };
size_t theta_scale_nb[] = { sizeof(float), sizeof(float), sizeof(float), theta_scale_length * sizeof(float) }; size_t theta_scale_nb[] = { sizeof(float), theta_scale_length * sizeof(float), theta_scale_length * sizeof(float),
theta_scale_length * sizeof(float) };
GGML_ASSERT(src1->type == GGML_TYPE_I32); GGML_ASSERT(src1->type == GGML_TYPE_I32);
int64_t position_length = src1->ne[0]; int64_t position_ne[] = { 1, 1, position_length, 1 };
int64_t position_ne[] = { 1, 1, position_length, 1 }; size_t position_nb[] = { sizeof(int32_t), sizeof(int32_t), sizeof(int32_t), sizeof(int32_t) * position_length };
size_t position_nb[] = { sizeof(int32_t), sizeof(int32_t), sizeof(int32_t), sizeof(int32_t) * position_length };
int64_t theta_ne[] = { theta_scale_length, 1, position_length, 1 }; int64_t cache_ne[] = { theta_scale_length, 1, position_length, 1 };
size_t theta_nb[GGML_MAX_DIMS]; size_t cache_nb[GGML_MAX_DIMS];
theta_nb[0] = sizeof(float); cache_nb[0] = sizeof(float);
for (int i = 1; i < GGML_MAX_DIMS; i++) { for (int i = 1; i < GGML_MAX_DIMS; i++) {
theta_nb[i] = theta_nb[i - 1] * theta_ne[i - 1]; cache_nb[i] = cache_nb[i - 1] * cache_ne[i - 1];
} }
// theta_scale arange, [0,1,...,ne00/2 - 1] // Step1: Compute the coefficient of theta. During the cache_init process, aside from
// (1) multiplying by the position,
// (2) dividing by freq_factors,
// (3) computing the sine and cosine,
// the other parameters used in the computation generally do not change in most scenarios.
// Therefore, we can first compute this part of the result and then cache it.
// Step1.1: prepare theta_scale exponent. if this exponent updated, should update theta_scale_tensor.
acl_tensor_ptr acl_theta_scale_tensor; acl_tensor_ptr acl_theta_scale_tensor;
// cache theta scale bool theta_scale_updated = false;
if (ctx.rope_cache.theta_scale_length != theta_scale_length || if (ctx.rope_cache.theta_scale_length != theta_scale_length || ctx.rope_cache.theta_scale != theta_scale ||
// theta_scale and freq_scale should not change during the current token inference process, ctx.rope_cache.indep_sects != indep_sects) {
// so we can directly use == here instead of comparing the absolute difference. theta_scale_updated = true;
ctx.rope_cache.theta_scale != theta_scale || ctx.rope_cache.freq_scale != freq_scale) { if (ctx.rope_cache.theta_scale_exp_host != nullptr) {
ctx.rope_cache.theta_scale_length = theta_scale_length; free(ctx.rope_cache.theta_scale_exp_host);
}
ctx.rope_cache.theta_scale_exp_host = (float *) malloc(theta_scale_length * sizeof(float));
GGML_ASSERT(ctx.rope_cache.theta_scale_exp_host != nullptr);
if (!indep_sects) {
ctx.rope_cache.theta_scale_exp_host[0] = 1;
for (int i = 1; i < theta_scale_length; i++) {
ctx.rope_cache.theta_scale_exp_host[i] = ctx.rope_cache.theta_scale_exp_host[i - 1] * theta_scale;
}
} else {
int sect_dims = sections[0] + sections[1] + sections[2] + sections[3];
int sec_w = sections[1] + sections[0];
int sec_e = sections[2] + sec_w;
ctx.rope_cache.theta_scale_exp_host[0] = 1;
for (int i = 1; i < theta_scale_length; i++) {
int sector = i % sect_dims;
if (sector == 0 || sector == sections[0] || sector == sec_w || sector == sec_e) {
ctx.rope_cache.theta_scale_exp_host[i] = 1;
continue;
}
ctx.rope_cache.theta_scale_exp_host[i] = ctx.rope_cache.theta_scale_exp_host[i - 1] * theta_scale;
}
}
if (ctx.rope_cache.theta_scale_cache != nullptr) { if (ctx.rope_cache.theta_scale_cache != nullptr) {
ACL_CHECK(aclrtFree(ctx.rope_cache.theta_scale_cache)); ACL_CHECK(aclrtFree(ctx.rope_cache.theta_scale_cache));
@ -2285,74 +2328,138 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx,
ACL_CHECK(aclrtMalloc(&ctx.rope_cache.theta_scale_cache, theta_scale_length * sizeof(float), ACL_CHECK(aclrtMalloc(&ctx.rope_cache.theta_scale_cache, theta_scale_length * sizeof(float),
ACL_MEM_MALLOC_HUGE_FIRST)); ACL_MEM_MALLOC_HUGE_FIRST));
ACL_CHECK(aclrtMemcpyAsync(ctx.rope_cache.theta_scale_cache, theta_scale_length * sizeof(float),
ctx.rope_cache.theta_scale_exp_host, theta_scale_length * sizeof(float),
ACL_MEMCPY_HOST_TO_DEVICE, ctx.stream()));
acl_theta_scale_tensor = ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float), acl_theta_scale_tensor = ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float),
theta_scale_ne, theta_scale_nb, 1); theta_scale_ne, theta_scale_nb, 1);
}
float start = 0; // Step1.2: prepare rope_yarn_ramp, if this part updated, should update theta_scale_tensor.
float step = 1; bool yarn_ramp_tensor_updated = false;
float stop = theta_scale_length; ggml_cann_pool_alloc yarn_ramp_allocator(ctx.pool());
float n_elements = theta_scale_length; acl_tensor_ptr acl_yarn_ramp_tensor;
aclnn_arange(ctx, acl_theta_scale_tensor.get(), start, stop, step, n_elements); if (ext_factor != 0 &&
// TODO: check more parameter.
(ctx.rope_cache.theta_scale_length != theta_scale_length || ctx.rope_cache.freq_scale != freq_scale)) {
yarn_ramp_tensor_updated = true;
ggml_cann_pool_alloc yarn_ramp_allocator(ctx.pool()); // -rope_yarn_ramp
acl_tensor_ptr acl_yarn_ramp_tensor; // const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
if (ext_factor != 0) { // return MIN(1, MAX(0, y)) - 1;
// -rope_yarn_ramp yarn_ramp_allocator.alloc(theta_scale_length * sizeof(float));
// const float y = (i0 / 2 - low) / MAX(0.001f, high - low); void * yarn_ramp_buffer = yarn_ramp_allocator.get();
// return MIN(1, MAX(0, y)) - 1; acl_yarn_ramp_tensor =
yarn_ramp_allocator.alloc(theta_scale_length * sizeof(float)); ggml_cann_create_tensor(yarn_ramp_buffer, ACL_FLOAT, sizeof(float), theta_scale_ne, theta_scale_nb, 1);
void * yarn_ramp_buffer = yarn_ramp_allocator.get(); float zero_value = 0, one_value = 1;
acl_yarn_ramp_tensor = float denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]);
ggml_cann_create_tensor(yarn_ramp_buffer, ACL_FLOAT, sizeof(float), theta_scale_ne, theta_scale_nb, 1); acl_scalar_ptr low = ggml_cann_create_scalar(&corr_dims[0], aclDataType::ACL_FLOAT);
float zero_value = 0, one_value = 1; acl_scalar_ptr zero = ggml_cann_create_scalar(&zero_value, aclDataType::ACL_FLOAT);
float denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]); acl_scalar_ptr one = ggml_cann_create_scalar(&one_value, aclDataType::ACL_FLOAT);
acl_scalar_ptr low = ggml_cann_create_scalar(&corr_dims[0], aclDataType::ACL_FLOAT); acl_scalar_ptr denom_safe = ggml_cann_create_scalar(&denom_safe_value, aclDataType::ACL_FLOAT);
acl_scalar_ptr zero = ggml_cann_create_scalar(&zero_value, aclDataType::ACL_FLOAT); acl_scalar_ptr ext_factor_sc = ggml_cann_create_scalar(&ext_factor, aclDataType::ACL_FLOAT);
acl_scalar_ptr one = ggml_cann_create_scalar(&one_value, aclDataType::ACL_FLOAT);
acl_scalar_ptr denom_safe = ggml_cann_create_scalar(&denom_safe_value, aclDataType::ACL_FLOAT);
acl_scalar_ptr ext_factor_sc = ggml_cann_create_scalar(&ext_factor, aclDataType::ACL_FLOAT);
GGML_CANN_CALL_ACLNN_OP(ctx, Subs, acl_theta_scale_tensor.get(), low.get(), one.get(), aclnn_arange(ctx, acl_yarn_ramp_tensor.get(), 0, theta_scale_length, 1, theta_scale_length);
acl_yarn_ramp_tensor.get()); GGML_CANN_CALL_ACLNN_OP(ctx, InplaceSubs, acl_yarn_ramp_tensor.get(), low.get(), one.get());
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceDivs, acl_yarn_ramp_tensor.get(), denom_safe.get()); GGML_CANN_CALL_ACLNN_OP(ctx, InplaceDivs, acl_yarn_ramp_tensor.get(), denom_safe.get());
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceThreshold, acl_yarn_ramp_tensor.get(), zero.get(), zero.get()); GGML_CANN_CALL_ACLNN_OP(ctx, InplaceThreshold, acl_yarn_ramp_tensor.get(), zero.get(), zero.get());
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceClampMax, acl_yarn_ramp_tensor.get(), one.get()); GGML_CANN_CALL_ACLNN_OP(ctx, InplaceClampMax, acl_yarn_ramp_tensor.get(), one.get());
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceSubs, acl_yarn_ramp_tensor.get(), one.get(), one.get()); GGML_CANN_CALL_ACLNN_OP(ctx, InplaceSubs, acl_yarn_ramp_tensor.get(), one.get(), one.get());
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor.get(), ext_factor_sc.get()); GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor.get(), ext_factor_sc.get());
// theta_interp = freq_scale * theta_extrap; // theta_interp = freq_scale * theta_extrap;
// theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; // theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
// theta = freq_scale * theta_extrap * (1 - ramp_mix) + theta_extrap * ramp_mix; // theta = freq_scale * theta_extrap * (1 - ramp_mix) + theta_extrap * ramp_mix;
// theta = freq_scale * theta_extrap - freq_scale * theta_extrap * ramp_mix + theta_extrap * ramp_mix; // theta = freq_scale * theta_extrap - freq_scale * theta_extrap * ramp_mix + theta_extrap * ramp_mix;
// theta = theta_extrap * (freq_scale - freq_scale * ramp_mix + ramp_mix); // theta = theta_extrap * (freq_scale - freq_scale * ramp_mix + ramp_mix);
// //
// we cache (freq_scale - freq_scale * ramp_mix + ramp_mix), Considering that the rope_yarn_ramp here is the inverse // we cache (freq_scale - freq_scale * ramp_mix + ramp_mix), Considering that the rope_yarn_ramp here is the inverse
// cache freq_scale + (freq_scale - 1) * ramp_mix // cache freq_scale + (freq_scale - 1) * ramp_mix
float freq_scale_1 = freq_scale - 1; float freq_scale_1 = freq_scale - 1;
acl_scalar_ptr freq_scale_sc = ggml_cann_create_scalar(&freq_scale, aclDataType::ACL_FLOAT); acl_scalar_ptr freq_scale_sc = ggml_cann_create_scalar(&freq_scale, aclDataType::ACL_FLOAT);
acl_scalar_ptr freq_scale_1_sc = ggml_cann_create_scalar(&freq_scale_1, aclDataType::ACL_FLOAT); acl_scalar_ptr freq_scale_1_sc = ggml_cann_create_scalar(&freq_scale_1, aclDataType::ACL_FLOAT);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor.get(), freq_scale_1_sc.get()); GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor.get(), freq_scale_1_sc.get());
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdds, acl_yarn_ramp_tensor.get(), freq_scale_sc.get(), one.get()); GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdds, acl_yarn_ramp_tensor.get(), freq_scale_sc.get(), one.get());
} }
// power // Step 1.3: update theta_scale_tensor according to ext_factor or freq_scale.
acl_scalar_ptr acl_theta_scale = ggml_cann_create_scalar(&theta_scale, aclDataType::ACL_FLOAT); if (ext_factor != 0) {
GGML_CANN_CALL_ACLNN_OP(ctx, PowScalarTensor, acl_theta_scale.get(), acl_theta_scale_tensor.get(), if (theta_scale_updated || yarn_ramp_tensor_updated) {
acl_theta_scale_tensor.get()); theta_scale_updated = true;
if (ext_factor != 0) {
aclnn_mul(ctx, acl_theta_scale_tensor.get(), acl_yarn_ramp_tensor.get()); aclnn_mul(ctx, acl_theta_scale_tensor.get(), acl_yarn_ramp_tensor.get());
} else if (freq_scale != 1) {
aclnn_muls(ctx, acl_theta_scale_tensor.get(), freq_scale, nullptr, true);
} }
} else { } else {
// use cache if (freq_scale != 1 && (ctx.rope_cache.freq_scale != freq_scale || theta_scale_updated)) {
theta_scale_updated = true;
aclnn_muls(ctx, acl_theta_scale_tensor.get(), freq_scale, nullptr, true);
}
}
// Nothing changed, use cache.
if (!theta_scale_updated) {
acl_theta_scale_tensor = ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float), acl_theta_scale_tensor = ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float),
theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS); theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
} }
// Step 1.4: prepare select index if mrope
acl_tensor_ptr position_select_index_tensor;
if (mrope_used) {
if (ctx.rope_cache.sections[0] != sections[0] || ctx.rope_cache.sections[1] != sections[1] ||
ctx.rope_cache.sections[2] != sections[2] || ctx.rope_cache.sections[3] != sections[3] ||
ctx.rope_cache.theta_scale_length != theta_scale_length || ctx.rope_cache.is_imrope != is_imrope) {
if (ctx.rope_cache.position_select_index_host != nullptr) {
free(ctx.rope_cache.position_select_index_host);
}
ctx.rope_cache.position_select_index_host = (int *) malloc(theta_scale_length * sizeof(int));
GGML_ASSERT(ctx.rope_cache.position_select_index_host != nullptr);
int sect_dims = sections[0] + sections[1] + sections[2] + sections[3];
int sec_w = sections[1] + sections[0];
int sec_e = sections[2] + sec_w;
// t,h,w,e
for (int i = 0; i < theta_scale_length; i++) {
int sector = i % sect_dims;
if (is_imrope) { // qwen3vl apply interleaved mrope
if (sector % 3 == 1 && sector < 3 * sections[1]) {
ctx.rope_cache.position_select_index_host[i] = 1;
} else if (sector % 3 == 2 && sector < 3 * sections[2]) {
ctx.rope_cache.position_select_index_host[i] = 2;
} else if (sector % 3 == 0 && sector < 3 * sections[0]) {
ctx.rope_cache.position_select_index_host[i] = 0;
} else {
ctx.rope_cache.position_select_index_host[i] = 3;
}
} else {
if (sector >= sections[0] && sector < sec_w) {
ctx.rope_cache.position_select_index_host[i] = 1;
} else if (sector >= sec_w && sector < sec_e) {
ctx.rope_cache.position_select_index_host[i] = 2;
} else if (sector >= sec_e) {
ctx.rope_cache.position_select_index_host[i] = 3;
} else {
ctx.rope_cache.position_select_index_host[i] = 0;
}
}
}
if (ctx.rope_cache.position_select_index != nullptr) {
ACL_CHECK(aclrtFree(ctx.rope_cache.position_select_index));
}
ACL_CHECK(aclrtMalloc(&ctx.rope_cache.position_select_index, theta_scale_length * sizeof(int),
ACL_MEM_MALLOC_HUGE_FIRST));
ACL_CHECK(aclrtMemcpyAsync(ctx.rope_cache.position_select_index, theta_scale_length * sizeof(int),
ctx.rope_cache.position_select_index_host, theta_scale_length * sizeof(int),
ACL_MEMCPY_HOST_TO_DEVICE, ctx.stream()));
}
position_select_index_tensor = ggml_cann_create_tensor(ctx.rope_cache.position_select_index, ACL_INT32,
sizeof(int), theta_scale_ne, theta_scale_nb, 1);
}
// Step2: divide by freq_factors
ggml_cann_pool_alloc freq_fac_res_allocator(ctx.pool()); ggml_cann_pool_alloc freq_fac_res_allocator(ctx.pool());
// freq_factors
if (src2) { if (src2) {
freq_fac_res_allocator.alloc(theta_scale_length * sizeof(float)); freq_fac_res_allocator.alloc(theta_scale_length * sizeof(float));
void * freq_fac_res_ptr = freq_fac_res_allocator.get(); void * freq_fac_res_ptr = freq_fac_res_allocator.get();
@ -2365,6 +2472,85 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx,
std::swap(acl_theta_scale_tensor, acl_freq_fac_res_tensor); std::swap(acl_theta_scale_tensor, acl_freq_fac_res_tensor);
} }
// Step3: prepare position_tensor
acl_tensor_ptr acl_position_tensor;
ggml_cann_pool_alloc mrope_position_acllocator(ctx.pool());
if (mrope_used) {
// Step3.1: select current position;
// position :
// pos1: [[0, 1 ,2 ,3 ],
// pos2: [4, 5 ,6 ,7 ],
// pos3: [8, 9 ,10,11],
// pos4: [12,13,14,15] ]
//
// select index = [0, 1, 2, 2, 1, 0]
//
// selected_tensor:
// [[0, 1 ,2 ,3 ],
// [4, 5 ,6 ,7 ],
// [8, 9 ,10,11],
// [8, 9 ,10,11],
// [4, 5 ,6 ,7 ],
// [0, 1 ,2 ,3 ]]
//
// transpose, from [seq_len:dims] to [dims:seq_len]
// [0, 4, 8 ,8 ,4, 0],
// [1, 5, 9, 9, 5, 1],
// [2, 6, 10,10,6 ,2],
// [3, 7, 11,11,7 3 ]]
//
// multipy by theta_scale_tensor
// [theta_scale^0, theta_scale^1, ..., theta_scale ^ n]
int64_t mrope_position_ne[] = { position_length, 4 };
size_t mrope_position_nb[] = { sizeof(int), position_length * sizeof(int) };
acl_tensor_ptr mrope_position =
ggml_cann_create_tensor(src1->data, ggml_cann_type_mapping(src1->type), ggml_type_size(src1->type),
mrope_position_ne, mrope_position_nb, 2);
// selected position tensor's shape is a transpose of cache tensor.
int64_t selected_position_ne[] = { position_length, theta_scale_length };
size_t selected_position_nb[] = { sizeof(float), position_length * sizeof(float) };
mrope_position_acllocator.alloc(theta_scale_length * position_length * sizeof(float));
void * mrope_position_buffer = mrope_position_acllocator.get();
acl_position_tensor =
ggml_cann_create_tensor(mrope_position_buffer, ggml_cann_type_mapping(src1->type),
ggml_type_size(src1->type), selected_position_ne, selected_position_nb, 2);
GGML_CANN_CALL_ACLNN_OP(ctx, IndexSelect, mrope_position.get(), 0, position_select_index_tensor.get(),
acl_position_tensor.get());
// transpose
int64_t transposed_ne[] = { position_length, 1, theta_scale_length, 1 };
size_t transposed_nb[GGML_MAX_DIMS];
transposed_nb[0] = sizeof(float);
for (int i = 1; i < GGML_MAX_DIMS; i++) {
transposed_nb[i] = transposed_nb[i - 1] * transposed_ne[i - 1];
}
std::swap(transposed_ne[0], transposed_ne[2]);
std::swap(transposed_nb[0], transposed_nb[2]);
acl_position_tensor =
ggml_cann_create_tensor(mrope_position_buffer, ggml_cann_type_mapping(src1->type),
ggml_type_size(src1->type), transposed_ne, transposed_nb, GGML_MAX_DIMS);
} else {
// auto bcast.
acl_position_tensor =
ggml_cann_create_tensor(src1->data, ggml_cann_type_mapping(src1->type), ggml_type_size(src1->type),
position_ne, position_nb, GGML_MAX_DIMS);
}
// Step4: multiply by the position
int64_t theta_length = theta_scale_length * position_length;
ggml_cann_pool_alloc theta_allocator(ctx.pool(), theta_length * sizeof(float));
void * theta_buffer = theta_allocator.get();
acl_tensor_ptr acl_theta_tensor =
ggml_cann_create_tensor(theta_buffer, ACL_FLOAT, sizeof(float), cache_ne, cache_nb, GGML_MAX_DIMS);
aclnn_mul(ctx, acl_position_tensor.get(), acl_theta_scale_tensor.get(), acl_theta_tensor.get());
// Step5: calculate sin cos.
// init sin_repeat && cos_repeat, only to accelerate first layer on each device // init sin_repeat && cos_repeat, only to accelerate first layer on each device
if (position_length > ctx.rope_cache.position_length) { if (position_length > ctx.rope_cache.position_length) {
ctx.rope_cache.position_length = position_length; ctx.rope_cache.position_length = position_length;
@ -2381,44 +2567,30 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx,
aclrtMalloc(&ctx.rope_cache.cos_cache, repeat_theta_length * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST)); aclrtMalloc(&ctx.rope_cache.cos_cache, repeat_theta_length * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST));
} }
// position
acl_tensor_ptr acl_position_tensor =
ggml_cann_create_tensor(src1->data, ggml_cann_type_mapping(src1->type), ggml_type_size(src1->type), position_ne,
position_nb, GGML_MAX_DIMS);
// power * position
int64_t theta_length = theta_scale_length * position_length;
ggml_cann_pool_alloc theta_allocator(ctx.pool(), theta_length * sizeof(float));
void * theta_buffer = theta_allocator.get();
acl_tensor_ptr acl_theta_tensor =
ggml_cann_create_tensor(theta_buffer, ACL_FLOAT, sizeof(float), theta_ne, theta_nb, GGML_MAX_DIMS);
aclnn_mul(ctx, acl_position_tensor.get(), acl_theta_scale_tensor.get(), acl_theta_tensor.get());
// sin/cos // sin/cos
ggml_cann_pool_alloc sin_allocator(ctx.pool(), theta_length * sizeof(float)); ggml_cann_pool_alloc sin_allocator(ctx.pool(), theta_length * sizeof(float));
void * sin_buffer = sin_allocator.get(); void * sin_buffer = sin_allocator.get();
acl_tensor_ptr acl_sin_tensor = acl_tensor_ptr acl_sin_tensor =
ggml_cann_create_tensor(sin_buffer, ACL_FLOAT, sizeof(float), theta_ne, theta_nb, GGML_MAX_DIMS, ACL_FORMAT_ND); ggml_cann_create_tensor(sin_buffer, ACL_FLOAT, sizeof(float), cache_ne, cache_nb, GGML_MAX_DIMS, ACL_FORMAT_ND);
aclnn_sin(ctx, acl_theta_tensor.get(), acl_sin_tensor.get()); aclnn_sin(ctx, acl_theta_tensor.get(), acl_sin_tensor.get());
ggml_cann_pool_alloc cos_allocator(ctx.pool(), theta_length * sizeof(float)); ggml_cann_pool_alloc cos_allocator(ctx.pool(), theta_length * sizeof(float));
void * cos_buffer = cos_allocator.get(); void * cos_buffer = cos_allocator.get();
acl_tensor_ptr acl_cos_tensor = acl_tensor_ptr acl_cos_tensor =
ggml_cann_create_tensor(cos_buffer, ACL_FLOAT, sizeof(float), theta_ne, theta_nb, GGML_MAX_DIMS, ACL_FORMAT_ND); ggml_cann_create_tensor(cos_buffer, ACL_FLOAT, sizeof(float), cache_ne, cache_nb, GGML_MAX_DIMS, ACL_FORMAT_ND);
aclnn_cos(ctx, acl_theta_tensor.get(), acl_cos_tensor.get()); aclnn_cos(ctx, acl_theta_tensor.get(), acl_cos_tensor.get());
if (ext_factor != 0) { if (ext_factor != 0) {
attn_factor *= 1.0f + 0.1f * logf(1.0f / freq_scale); attn_factor *= 1.0f + 0.1f * logf(1.0f / freq_scale);
} }
// attn_factor // Step 5: multiply by attn_factor
if (attn_factor != 1) { if (attn_factor != 1) {
aclnn_muls(ctx, acl_sin_tensor.get(), attn_factor, nullptr, true); aclnn_muls(ctx, acl_sin_tensor.get(), attn_factor, nullptr, true);
aclnn_muls(ctx, acl_cos_tensor.get(), attn_factor, nullptr, true); aclnn_muls(ctx, acl_cos_tensor.get(), attn_factor, nullptr, true);
} }
int64_t sin_reshape_ne[4] = { src0->ne[0], 1, src0->ne[2], 1 }; int64_t sin_reshape_ne[4] = { src0->ne[0], 1, dst->ne[2], 1 };
size_t sin_reshape_nb[GGML_MAX_DIMS]; size_t sin_reshape_nb[GGML_MAX_DIMS];
sin_reshape_nb[0] = sizeof(float); sin_reshape_nb[0] = sizeof(float);
for (int i = 1; i < GGML_MAX_DIMS; i++) { for (int i = 1; i < GGML_MAX_DIMS; i++) {
@ -2429,8 +2601,9 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx,
acl_tensor_ptr acl_cos_repeat_tensor = ggml_cann_create_tensor(ctx.rope_cache.cos_cache, ACL_FLOAT, sizeof(float), acl_tensor_ptr acl_cos_repeat_tensor = ggml_cann_create_tensor(ctx.rope_cache.cos_cache, ACL_FLOAT, sizeof(float),
sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS); sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
// repeat // Step 6: repeat
if (is_neox) { if (is_neox) {
// [sinθ1, sinθ1, sinθ2, sinθ2, ..., sinθn, sinθn]
int64_t repeatsArray[] = { 1, 1, 1, 2 }; int64_t repeatsArray[] = { 1, 1, 1, 2 };
aclnn_repeat(ctx, acl_sin_tensor.get(), acl_sin_repeat_tensor.get(), repeatsArray); aclnn_repeat(ctx, acl_sin_tensor.get(), acl_sin_repeat_tensor.get(), repeatsArray);
aclnn_repeat(ctx, acl_cos_tensor.get(), acl_cos_repeat_tensor.get(), repeatsArray); aclnn_repeat(ctx, acl_cos_tensor.get(), acl_cos_repeat_tensor.get(), repeatsArray);
@ -2438,17 +2611,15 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx,
int64_t num_repeats = 2; int64_t num_repeats = 2;
int64_t dim = 3; int64_t dim = 3;
int64_t output_size = theta_scale_length * num_repeats; int64_t output_size = theta_scale_length * num_repeats;
// [sinθ1, sinθ2, ..., sinθn, sinθ1, sinθ2, ..., sinθn]
aclnn_repeat_interleave(ctx, acl_sin_tensor.get(), acl_sin_repeat_tensor.get(), dim, num_repeats, output_size); aclnn_repeat_interleave(ctx, acl_sin_tensor.get(), acl_sin_repeat_tensor.get(), dim, num_repeats, output_size);
aclnn_repeat_interleave(ctx, acl_cos_tensor.get(), acl_cos_repeat_tensor.get(), dim, num_repeats, output_size); aclnn_repeat_interleave(ctx, acl_cos_tensor.get(), acl_cos_repeat_tensor.get(), dim, num_repeats, output_size);
} }
// Other layers use cache except first layer. // Update cached value.
ctx.rope_cache.cached = true; ctx.rope_cache.cached = true;
ctx.rope_cache.ext_factor = ext_factor; ctx.rope_cache.set(theta_scale_length, position_length, ext_factor, theta_scale, freq_scale, attn_factor, is_neox,
ctx.rope_cache.theta_scale = theta_scale; indep_sects, mrope_used, is_imrope, sections);
ctx.rope_cache.freq_scale = freq_scale;
ctx.rope_cache.attn_factor = attn_factor;
ctx.rope_cache.is_neox = is_neox;
} }
#ifdef __cplusplus #ifdef __cplusplus
@ -2474,6 +2645,7 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
// param // param
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
int sections[4];
// const int n_past = ((int32_t *) dst->op_params)[0]; // const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1]; const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2]; const int mode = ((int32_t *) dst->op_params)[2];
@ -2482,12 +2654,13 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
GGML_TENSOR_UNARY_OP_LOCALS GGML_TENSOR_UNARY_OP_LOCALS
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
memcpy(&sections, (int32_t *) dst->op_params + 11, sizeof(int)*4);
// TODO: n_dims <= ne0 // TODO: n_dims <= ne0
GGML_ASSERT(n_dims == ne0); GGML_ASSERT(n_dims == ne0);
@ -2498,10 +2671,25 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
float corr_dims[2]; float corr_dims[2];
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE; // qwen3vl apply interleaved mrope
const bool mrope_used = mode & GGML_ROPE_TYPE_MROPE; // ggml_rope_multi, note: also true for vision (24 & 8 == true) and for imrope
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
if (mrope_used) {
GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
}
if (is_vision) {
GGML_ASSERT(n_dims == ne0/2);
}
if (is_imrope || mrope_used) {
is_neox = true;
}
// init ctx.rope_cos/rope_sin cache // init ctx.rope_cos/rope_sin cache
aclnn_cache_init(ctx, dst, corr_dims, ext_factor, theta_scale, freq_scale, attn_factor, is_neox); aclnn_rope_cache_init(ctx, dst, corr_dims, ext_factor, theta_scale, freq_scale, attn_factor, is_neox, sections, mrope_used, is_imrope, is_vision);
int64_t sin_reshape_ne[4] = { ne00, 1, ne02, 1 }; int64_t sin_reshape_ne[4] = { ne00, 1, ne02, 1 };
size_t sin_reshape_nb[GGML_MAX_DIMS]; size_t sin_reshape_nb[GGML_MAX_DIMS];
@ -2657,8 +2845,7 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
return; return;
#endif #endif
// ggml_mode = 0 --> aclnn_model = 1 int64_t acl_mode = is_neox ? 0 : 1;
int64_t acl_mode = mode == 0 ? 1 : mode;
switch (src0->type) { switch (src0->type) {
case GGML_TYPE_F32: case GGML_TYPE_F32:
@ -3236,3 +3423,64 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context & ctx, ggml_tensor * dst
GGML_ABORT("Function is not implemented."); GGML_ABORT("Function is not implemented.");
} }
} }
static void ggml_cann_out_prod_fp(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
ggml_tensor * src0 = dst->src[0]; // weight
ggml_tensor * src1 = dst->src[1]; // input
GGML_TENSOR_BINARY_OP_LOCALS
acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, acl_dst.get());
const int64_t dps2 = ne2 / ne02;
const int64_t dps3 = ne3 / ne03;
for (int64_t i3 = 0; i3 < ne3; i3++) {
for (int64_t i2 = 0; i2 < ne2; i2++) {
const int64_t i02 = i2 / dps2;
const int64_t i03 = i3 / dps3;
const int64_t i12 = i2;
const int64_t i13 = i3;
acl_tensor_ptr accumulator =
ggml_cann_create_tensor((char *) dst->data + i2 * nb2 + i3 * nb3, ggml_cann_type_mapping(dst->type),
ggml_type_size(dst->type), dst->ne, dst->nb, 2);
// The outer product needs to be accumulated in this dimension.
for (int64_t i1 = 0; i1 < ne11; i1++) {
acl_tensor_ptr acl_input = ggml_cann_create_tensor(
(char *) src1->data + i1 * nb11 + i12 * nb12 + i13 * nb13, ggml_cann_type_mapping(src0->type),
ggml_type_size(src0->type), src1->ne, src1->nb, 1);
acl_tensor_ptr acl_weight = ggml_cann_create_tensor(
(char *) src0->data + i1 * nb01 + i02 * nb02 + i03 * nb03, ggml_cann_type_mapping(src0->type),
ggml_type_size(src0->type), src0->ne, src0->nb, 1);
ggml_cann_pool_alloc output_allocator(ctx.pool());
void * output_buffer = output_allocator.alloc(ggml_nbytes(dst));
acl_tensor_ptr acl_out = ggml_cann_create_tensor(output_buffer, ggml_cann_type_mapping(dst->type),
ggml_type_size(dst->type), dst->ne, dst->nb, 2);
GGML_CANN_CALL_ACLNN_OP(ctx, Ger, acl_input.get(), acl_weight.get(), acl_out.get());
float alpha_value = 1.0f;
aclScalar * alpha = aclCreateScalar(&alpha_value, ACL_FLOAT);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdd, accumulator.get(), acl_out.get(), alpha);
}
}
}
}
void ggml_cann_out_prod(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
ggml_tensor * src0 = dst->src[0];
const enum ggml_type type = src0->type;
switch (type) {
case GGML_TYPE_F32:
case GGML_TYPE_F16:
ggml_cann_out_prod_fp(ctx, dst);
break;
default:
GGML_ABORT("Unsupport type for GGML_OP_OUT_PROD");
break;
}
}

View File

@ -1125,3 +1125,23 @@ void ggml_cann_op_unary_gated(std::function<void(ggml_backend_cann_context &, ac
} while (0) } while (0)
#endif // CANN_ACLNN_OPS #endif // CANN_ACLNN_OPS
/**
* @brief Performs outer product operation on two ggml tensors using the CANN backend.
*
* @details This function computes the outer product of two input tensors (src0 and src1)
* and stores the result in the destination tensor. The outer product operation is defined as:
* dst[i,j,k,l] = sum_m (src0[i,m,k,l] * src1[j,m,k,l])
*
* The function supports multiple data types including F32, F16. For floating-point
* types, it uses batch matrix multiplication for efficient computation.
*
* The implementation handles 4D tensor broadcasting and batch processing automatically.
*
* @param ctx The CANN backend context for operation execution and memory management.
* @param dst The destination ggml_tensor where the outer product result will be stored.
* The input tensors are assumed to be `dst->src[0]` and `dst->src[1]`.
*
* @see GGML_CANN_CALL_ACLNN_OP for CANN operator invocation
*/
void ggml_cann_out_prod(ggml_backend_cann_context & ctx, ggml_tensor * dst);

View File

@ -300,30 +300,92 @@ struct ggml_cann_graph_lru_cache {
struct ggml_cann_rope_cache { struct ggml_cann_rope_cache {
~ggml_cann_rope_cache() { ~ggml_cann_rope_cache() {
if (theta_scale_cache != nullptr) { if (theta_scale_cache) {
ACL_CHECK(aclrtFree(theta_scale_cache)); ACL_CHECK(aclrtFree(theta_scale_cache));
} }
if (sin_cache != nullptr) { if (sin_cache) {
ACL_CHECK(aclrtFree(sin_cache)); ACL_CHECK(aclrtFree(sin_cache));
} }
if (cos_cache != nullptr) { if (cos_cache) {
ACL_CHECK(aclrtFree(cos_cache)); ACL_CHECK(aclrtFree(cos_cache));
} }
if (position_select_index) {
ACL_CHECK(aclrtFree(position_select_index));
}
if (theta_scale_exp_host) {
free(theta_scale_exp_host);
}
if(position_select_index_host) {
free(position_select_index_host);
}
} }
void * theta_scale_cache = nullptr; bool equal(int64_t theta_scale_length,
int64_t theta_scale_length = 0; int64_t position_length,
float ext_factor,
float theta_scale,
float freq_scale,
float attn_factor,
bool is_neox,
bool indep_sects,
bool mrope_used,
bool is_imrope,
int sections[4]) {
return this->theta_scale_length == theta_scale_length && this->position_length == position_length &&
this->ext_factor == ext_factor && this->theta_scale == theta_scale && this->freq_scale == freq_scale &&
this->attn_factor == attn_factor && this->is_neox == is_neox && this->indep_sects == indep_sects &&
this->mrope_used == mrope_used && this->is_imrope == is_imrope && this->sections[0] == sections[0] &&
this->sections[1] == sections[1] && this->sections[2] == sections[2] && this->sections[3] == sections[3];
}
void set(int64_t theta_scale_length,
int64_t position_length,
float ext_factor,
float theta_scale,
float freq_scale,
float attn_factor,
bool is_neox,
bool indep_sects,
bool mrope_used,
bool is_imrope,
int sections[4]) {
this->theta_scale_length = theta_scale_length;
this->position_length = position_length;
this->ext_factor = ext_factor;
this->theta_scale = theta_scale;
this->freq_scale = freq_scale;
this->attn_factor = attn_factor;
this->is_neox = is_neox;
this->indep_sects = indep_sects;
this->mrope_used = mrope_used;
this->is_imrope = is_imrope;
this->sections[0] = sections[0];
this->sections[1] = sections[1];
this->sections[2] = sections[2];
this->sections[3] = sections[3];
}
// memory cache, prepare before inferencing.
void * theta_scale_cache = nullptr;
float * theta_scale_exp_host = nullptr;
int * position_select_index_host = nullptr;
void * position_select_index = nullptr;
// sin/cos cache, used only to accelerate first layer on each device // sin/cos cache, used only to accelerate first layer on each device
void * sin_cache = nullptr; void * sin_cache = nullptr;
void * cos_cache = nullptr; void * cos_cache = nullptr;
int64_t position_length = 0;
// Properties to check before reusing the sincos cache // Properties to check before reusing the sincos cache
bool cached = false; int64_t theta_scale_length = 0;
float ext_factor = 0.0f; int64_t position_length = 0;
float theta_scale = 0.0f; bool cached = false;
float freq_scale = 0.0f; float ext_factor = 0.0f;
float attn_factor = 0.0f; float theta_scale = 0.0f;
bool is_neox = false; float freq_scale = 0.0f;
float attn_factor = 0.0f;
bool is_neox = false;
bool indep_sects = false;
bool mrope_used = false;
int sections[4] = { 0, 0, 0, 0 };
bool is_imrope = false;
}; };
struct ggml_cann_tensor_cache { struct ggml_cann_tensor_cache {

View File

@ -1886,6 +1886,9 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct gg
case GGML_OP_FLASH_ATTN_EXT: case GGML_OP_FLASH_ATTN_EXT:
ggml_cann_flash_attn_ext(ctx, dst); ggml_cann_flash_attn_ext(ctx, dst);
break; break;
case GGML_OP_OUT_PROD:
ggml_cann_out_prod(ctx, dst);
break;
default: default:
return false; return false;
} }
@ -2477,13 +2480,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
return false; return false;
} }
const int mode = ((const int32_t *) op->op_params)[2];
if (mode & GGML_ROPE_TYPE_MROPE) {
return false;
}
if (mode & GGML_ROPE_TYPE_VISION) {
return false;
}
if (op->src[0]->ne[0] > 896) { if (op->src[0]->ne[0] > 896) {
return false; return false;
} }
@ -2563,6 +2559,16 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
case GGML_OP_PAD_REFLECT_1D: case GGML_OP_PAD_REFLECT_1D:
case GGML_OP_COUNT_EQUAL: case GGML_OP_COUNT_EQUAL:
return true; return true;
case GGML_OP_OUT_PROD:
{
switch (op->src[0]->type) {
case GGML_TYPE_F16:
case GGML_TYPE_F32:
return true;
default:
return false;
}
}
case GGML_OP_CONV_TRANSPOSE_1D: case GGML_OP_CONV_TRANSPOSE_1D:
// TODO: ((weightL - 1) * dilationW - padLeft)=1336 should not be larger than 255. // TODO: ((weightL - 1) * dilationW - padLeft)=1336 should not be larger than 255.
return (op->src[0]->ne[0] - 1) <= 255; return (op->src[0]->ne[0] - 1) <= 255;

View File

@ -224,7 +224,8 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
include(CheckCXXSourceCompiles) include(CheckCXXSourceCompiles)
set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS}) set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS})
set(CMAKE_REQUIRED_FLAGS "${ARCH_FLAGS}") string(REPLACE ";" " " ARCH_FLAGS_STR "${ARCH_FLAGS}")
set(CMAKE_REQUIRED_FLAGS "${ARCH_FLAGS_STR}")
foreach(feature DOTPROD SVE MATMUL_INT8 FMA FP16_VECTOR_ARITHMETIC SME) foreach(feature DOTPROD SVE MATMUL_INT8 FMA FP16_VECTOR_ARITHMETIC SME)
set(ARM_FEATURE "HAVE_${feature}") set(ARM_FEATURE "HAVE_${feature}")
check_cxx_source_compiles( check_cxx_source_compiles(

View File

@ -1927,6 +1927,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{ {
ggml_compute_forward_argsort(params, tensor); ggml_compute_forward_argsort(params, tensor);
} break; } break;
case GGML_OP_TOP_K:
{
ggml_compute_forward_top_k(params, tensor);
} break;
case GGML_OP_LEAKY_RELU: case GGML_OP_LEAKY_RELU:
{ {
ggml_compute_forward_leaky_relu(params, tensor); ggml_compute_forward_leaky_relu(params, tensor);
@ -2311,6 +2315,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_OP_ARANGE: case GGML_OP_ARANGE:
case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_ARGSORT: case GGML_OP_ARGSORT:
case GGML_OP_TOP_K:
case GGML_OP_FLASH_ATTN_EXT: case GGML_OP_FLASH_ATTN_EXT:
case GGML_OP_FLASH_ATTN_BACK: case GGML_OP_FLASH_ATTN_BACK:
case GGML_OP_SSM_CONV: case GGML_OP_SSM_CONV:
@ -2834,6 +2839,10 @@ struct ggml_cplan ggml_graph_plan(
cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02*ne03; cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02*ne03;
cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12; cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12;
} break; } break;
case GGML_OP_TOP_K:
{
cur += sizeof(int32_t)*node->src[0]->ne[0]*n_tasks;
} break;
case GGML_OP_FLASH_ATTN_EXT: case GGML_OP_FLASH_ATTN_EXT:
{ {
const int64_t ne10 = node->src[1]->ne[0]; // DK const int64_t ne10 = node->src[1]->ne[0]; // DK

View File

@ -7794,7 +7794,7 @@ void ggml_compute_forward_timestep_embedding(
// ggml_compute_forward_argsort // ggml_compute_forward_argsort
template<enum ggml_sort_order order> template<enum ggml_sort_order order>
struct argsort_cmp { struct cmp_argsort {
const float * data; const float * data;
bool operator()(int32_t a, int32_t b) const { bool operator()(int32_t a, int32_t b) const {
if constexpr (order == GGML_SORT_ORDER_ASC) { if constexpr (order == GGML_SORT_ORDER_ASC) {
@ -7833,11 +7833,11 @@ static void ggml_compute_forward_argsort_f32(
switch (order) { switch (order) {
case GGML_SORT_ORDER_ASC: case GGML_SORT_ORDER_ASC:
std::sort(dst_data, dst_data + ne0, argsort_cmp<GGML_SORT_ORDER_ASC>{src_data}); std::sort(dst_data, dst_data + ne0, cmp_argsort<GGML_SORT_ORDER_ASC>{src_data});
break; break;
case GGML_SORT_ORDER_DESC: case GGML_SORT_ORDER_DESC:
std::sort(dst_data, dst_data + ne0, argsort_cmp<GGML_SORT_ORDER_DESC>{src_data}); std::sort(dst_data, dst_data + ne0, cmp_argsort<GGML_SORT_ORDER_DESC>{src_data});
break; break;
default: default:
@ -7864,6 +7864,72 @@ void ggml_compute_forward_argsort(
} }
} }
// ggml_compute_forward_top_k
struct cmp_top_k {
const float * data;
bool operator()(int32_t a, int32_t b) const {
return data[a] > data[b];
}
};
static void ggml_compute_forward_top_k_f32(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
GGML_TENSOR_UNARY_OP_LOCALS
GGML_ASSERT(nb0 == sizeof(float));
const int ith = params->ith;
const int nth = params->nth;
const int64_t nr = ggml_nrows(src0);
const int top_k = ne0;
int32_t * tmp = (int32_t *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
for (int64_t i = ith; i < nr; i += nth) {
const float * src_data = (float *)((char *) src0->data + i*nb01);
for (int64_t j = 0; j < ne00; j++) {
tmp[j] = j;
}
std::partial_sort(tmp, tmp + top_k, tmp + ne00, cmp_top_k{src_data});
int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
std::copy(tmp, tmp + top_k, dst_data);
// emphasize that the order is not important
if (top_k > 1) {
std::swap(dst_data[0], dst_data[1]);
}
}
}
void ggml_compute_forward_top_k(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_top_k_f32(params, dst);
} break;
default:
{
GGML_ABORT("fatal error");
}
}
}
// ggml_compute_forward_flash_attn_ext // ggml_compute_forward_flash_attn_ext
static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(

View File

@ -81,6 +81,7 @@ void ggml_compute_forward_roll(const struct ggml_compute_params * params, struct
void ggml_compute_forward_arange(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_arange(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_timestep_embedding(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_timestep_embedding(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_argsort(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_argsort(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_top_k(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_leaky_relu(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_leaky_relu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_fill(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_fill(const struct ggml_compute_params * params, struct ggml_tensor * dst);

View File

@ -397,119 +397,118 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const
} }
inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y, const ggml_fp16_t * GGML_RESTRICT x, const float v) { inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y, const ggml_fp16_t * GGML_RESTRICT x, const float v) {
#if defined(GGML_SIMD) #if defined(GGML_SIMD) && defined(__ARM_FEATURE_SVE)
#if defined(__ARM_FEATURE_SVE) const int sve_register_length = svcntb() * 8;
const int sve_register_length = svcntb() * 8; const int ggml_f16_epr = sve_register_length / 16;
const int ggml_f16_epr = sve_register_length / 16; const int ggml_f16_step = 8 * ggml_f16_epr;
const int ggml_f16_step = 8 * ggml_f16_epr;
GGML_F16x_VEC vx = GGML_F16x_VEC_SET1(v); GGML_F16x_VEC vx = GGML_F16x_VEC_SET1(v);
const int np= (n & ~(ggml_f16_step - 1)); int np = (n & ~(ggml_f16_step - 1));
svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8; svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8;
svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8; svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8;
for (int i = 0; i < np; i += ggml_f16_step) { for (int i = 0; i < np; i += ggml_f16_step) {
ax1 = GGML_F16x_VEC_LOAD(x + i + 0 * ggml_f16_epr, 0); ax1 = GGML_F16x_VEC_LOAD(x + i + 0 * ggml_f16_epr, 0);
ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0); ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0);
ay1 = GGML_F16x_VEC_FMA(ay1, ax1, vx); ay1 = GGML_F16x_VEC_FMA(ay1, ax1, vx);
GGML_F16x_VEC_STORE(y + i + 0 * ggml_f16_epr, ay1, 0); GGML_F16x_VEC_STORE(y + i + 0 * ggml_f16_epr, ay1, 0);
ax2 = GGML_F16x_VEC_LOAD(x + i + 1 * ggml_f16_epr, 1); ax2 = GGML_F16x_VEC_LOAD(x + i + 1 * ggml_f16_epr, 1);
ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1); ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1);
ay2 = GGML_F16x_VEC_FMA(ay2, ax2, vx); ay2 = GGML_F16x_VEC_FMA(ay2, ax2, vx);
GGML_F16x_VEC_STORE(y + i + 1 * ggml_f16_epr, ay2, 1); GGML_F16x_VEC_STORE(y + i + 1 * ggml_f16_epr, ay2, 1);
ax3 = GGML_F16x_VEC_LOAD(x + i + 2 * ggml_f16_epr, 2); ax3 = GGML_F16x_VEC_LOAD(x + i + 2 * ggml_f16_epr, 2);
ay3 = GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2); ay3 = GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2);
ay3 = GGML_F16x_VEC_FMA(ay3, ax3, vx); ay3 = GGML_F16x_VEC_FMA(ay3, ax3, vx);
GGML_F16x_VEC_STORE(y + i + 2 * ggml_f16_epr, ay3, 2); GGML_F16x_VEC_STORE(y + i + 2 * ggml_f16_epr, ay3, 2);
ax4 = GGML_F16x_VEC_LOAD(x + i + 3 * ggml_f16_epr, 3); ax4 = GGML_F16x_VEC_LOAD(x + i + 3 * ggml_f16_epr, 3);
ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3); ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3);
ay4 = GGML_F16x_VEC_FMA(ay4, ax4, vx); ay4 = GGML_F16x_VEC_FMA(ay4, ax4, vx);
GGML_F16x_VEC_STORE(y + i + 3 * ggml_f16_epr, ay4, 3); GGML_F16x_VEC_STORE(y + i + 3 * ggml_f16_epr, ay4, 3);
ax5 = GGML_F16x_VEC_LOAD(x + i + 4 * ggml_f16_epr, 4); ax5 = GGML_F16x_VEC_LOAD(x + i + 4 * ggml_f16_epr, 4);
ay5 = GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4); ay5 = GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4);
ay5 = GGML_F16x_VEC_FMA(ay5, ax5, vx); ay5 = GGML_F16x_VEC_FMA(ay5, ax5, vx);
GGML_F16x_VEC_STORE(y + i + 4 * ggml_f16_epr, ay5, 4); GGML_F16x_VEC_STORE(y + i + 4 * ggml_f16_epr, ay5, 4);
ax6 = GGML_F16x_VEC_LOAD(x + i + 5 * ggml_f16_epr, 5); ax6 = GGML_F16x_VEC_LOAD(x + i + 5 * ggml_f16_epr, 5);
ay6 = GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5); ay6 = GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5);
ay6 = GGML_F16x_VEC_FMA(ay6, ax6, vx); ay6 = GGML_F16x_VEC_FMA(ay6, ax6, vx);
GGML_F16x_VEC_STORE(y + i + 5 * ggml_f16_epr, ay6, 5); GGML_F16x_VEC_STORE(y + i + 5 * ggml_f16_epr, ay6, 5);
ax7 = GGML_F16x_VEC_LOAD(x + i + 6 * ggml_f16_epr, 6); ax7 = GGML_F16x_VEC_LOAD(x + i + 6 * ggml_f16_epr, 6);
ay7 = GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6); ay7 = GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6);
ay7 = GGML_F16x_VEC_FMA(ay7, ax7, vx); ay7 = GGML_F16x_VEC_FMA(ay7, ax7, vx);
GGML_F16x_VEC_STORE(y + i + 6 * ggml_f16_epr, ay7, 6); GGML_F16x_VEC_STORE(y + i + 6 * ggml_f16_epr, ay7, 6);
ax8 = GGML_F16x_VEC_LOAD(x + i + 7 * ggml_f16_epr, 7); ax8 = GGML_F16x_VEC_LOAD(x + i + 7 * ggml_f16_epr, 7);
ay8 = GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7); ay8 = GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7);
ay8 = GGML_F16x_VEC_FMA(ay8, ax8, vx); ay8 = GGML_F16x_VEC_FMA(ay8, ax8, vx);
GGML_F16x_VEC_STORE(y + i + 7 * ggml_f16_epr, ay8, 7); GGML_F16x_VEC_STORE(y + i + 7 * ggml_f16_epr, ay8, 7);
}
const int np2 = (n & ~(ggml_f16_epr - 1));
for (int k = np; k < np2; k += ggml_f16_epr) {
svfloat16_t rx = GGML_F16x_VEC_LOAD(x + k, 0);
svfloat16_t ry = GGML_F16x_VEC_LOAD(y + k, 0);
ry = GGML_F16x_VEC_FMA(ry, rx, vx);
GGML_F16x_VEC_STORE(y + k, ry, 0);
}
if (np2 < n) {
svbool_t pg = svwhilelt_b16(np2, n);
svfloat16_t hx = svld1_f16(pg, (const __fp16 *)(x + np2));
svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2));
hy = svmad_f16_x(pg, hx, vx, hy);
svst1_f16(pg, (__fp16 *)(y + np2), hy);
}
np = n;
#elif defined(__riscv_zvfh) // implies __riscv_v_intrinsic
const int np = n;
_Float16 hv = (_Float16)v;
for (int i = 0, avl; i < n; i += avl) {
avl = __riscv_vsetvl_e16m8(n - i);
vfloat16m8_t ax = __riscv_vle16_v_f16m8((const _Float16 *)&x[i], avl);
vfloat16m8_t ay = __riscv_vle16_v_f16m8((_Float16 *)&y[i], avl);
vfloat16m8_t ny = __riscv_vfmadd_vf_f16m8(ax, hv, ay, avl);
__riscv_vse16_v_f16m8((_Float16 *)&y[i], ny, avl);
}
#elif defined(GGML_SIMD)
const int np = (n & ~(GGML_F16_STEP - 1));
GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
GGML_F16_VEC ax[GGML_F16_ARR];
GGML_F16_VEC ay[GGML_F16_ARR];
for (int i = 0; i < np; i += GGML_F16_STEP) {
for (int j = 0; j < GGML_F16_ARR; j++) {
ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx);
GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
} }
const int np2 = (n & ~(ggml_f16_epr - 1)); }
for (int k = np; k < np2; k += ggml_f16_epr) {
svfloat16_t rx = GGML_F16x_VEC_LOAD(x + k, 0);
svfloat16_t ry = GGML_F16x_VEC_LOAD(y + k, 0);
ry = GGML_F16x_VEC_FMA(ry, rx, vx);
GGML_F16x_VEC_STORE(y + k, ry, 0);
}
if (np2 < n) {
svbool_t pg = svwhilelt_b16(np2, n);
svfloat16_t hx = svld1_f16(pg, (const __fp16 *)(x + np2));
svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2));
hy = svmad_f16_x(pg, hx, vx, hy);
svst1_f16(pg, (__fp16 *)(y + np2), hy);
}
#elif defined(__riscv_v_intrinsic)
// todo: RVV impl
// scalar
for (int i = 0; i < n; ++i) {
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v);
}
#else
const int np = (n & ~(GGML_F16_STEP - 1));
GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
GGML_F16_VEC ax[GGML_F16_ARR];
GGML_F16_VEC ay[GGML_F16_ARR];
for (int i = 0; i < np; i += GGML_F16_STEP) {
for (int j = 0; j < GGML_F16_ARR; j++) {
ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx);
GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
}
}
// leftovers
for (int i = np; i < n; ++i) {
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v);
}
#endif
#else #else
// scalar const int np = 0;
for (int i = 0; i < n; ++i) { #endif
// leftovers
for (int i = np; i < n; ++i) {
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v); y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v);
} }
#endif
} }
// xs and vs are byte strides of x and v // xs and vs are byte strides of x and v

View File

@ -73,34 +73,7 @@ namespace ggml_cuda_mma {
static constexpr int I = I_; static constexpr int I = I_;
static constexpr int J = J_; static constexpr int J = J_;
#if defined(GGML_USE_HIP) #if defined(AMD_MFMA_AVAILABLE)
#if defined(RDNA4)
static constexpr int ne = I * J / 32;
T x[ne] = {0};
static constexpr __device__ bool supported() {
if (I == 16 && J == 16) return true;
return false;
}
static __device__ __forceinline__ int get_i(const int l) {
if constexpr (I == 16 && J == 16) {
return 8 * (threadIdx.x / 16) + l;
} else {
NO_DEVICE_CODE;
return -1;
}
}
static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 16 && J == 16) {
return threadIdx.x % 16;
} else {
NO_DEVICE_CODE;
return -1;
}
}
#else
static constexpr int ne = I * J / 64; static constexpr int ne = I * J / 64;
T x[ne] = {0}; T x[ne] = {0};
@ -146,7 +119,6 @@ namespace ggml_cuda_mma {
return -1; return -1;
} }
} }
#endif // defined(RDNA4)
#elif __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA #elif __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
static constexpr int ne = I * J / 32; static constexpr int ne = I * J / 32;
T x[ne] = {0}; T x[ne] = {0};
@ -177,6 +149,34 @@ namespace ggml_cuda_mma {
return -1; return -1;
} }
} }
#elif defined(AMD_WMMA_AVAILABLE)
#if defined(RDNA4)
static constexpr int ne = I * J / 32;
T x[ne] = {0};
static constexpr __device__ bool supported() {
if (I == 16 && J == 16) return true;
return false;
}
static __device__ __forceinline__ int get_i(const int l) {
if constexpr (I == 16 && J == 16) {
return 8 * (threadIdx.x / 16) + l;
} else {
NO_DEVICE_CODE;
return -1;
}
}
static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 16 && J == 16) {
return threadIdx.x % 16;
} else {
NO_DEVICE_CODE;
return -1;
}
}
#endif
#else #else
static constexpr int ne = I * J / 32; static constexpr int ne = I * J / 32;
T x[ne] = {0}; T x[ne] = {0};
@ -437,7 +437,29 @@ namespace ggml_cuda_mma {
xi[0] = xs[0]; xi[0] = xs[0];
} }
#elif defined(AMD_WMMA_AVAILABLE) #elif defined(AMD_WMMA_AVAILABLE)
ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0)); if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
} else if constexpr (std::is_same_v<T, int>) {
if constexpr (I == 16 && J == 4) {
int64_t * xi = (int64_t *) t.x;
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
xi[0] = xs[0];
}else if constexpr (I == 16 && J == 8) {
int64_t * xi = (int64_t *) t.x;
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I));
xi[0] = xs[0];
const int64_t * xs1 = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I) + 2);
xi[1] = xs1[0];
}else{
NO_DEVICE_CODE;
}
} else {
NO_DEVICE_CODE;
}
#else #else
#pragma unroll #pragma unroll
for (int l = 0; l < t.ne; ++l) { for (int l = 0; l < t.ne; ++l) {
@ -772,6 +794,36 @@ namespace ggml_cuda_mma {
acc[0], acc[0],
0, 0, 0); 0, 0, 0);
#endif // defined(CDNA3) #endif // defined(CDNA3)
#elif defined(AMD_WMMA_AVAILABLE)
using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
int32x2_t * a_vec = (int32x2_t *) A.x;
int32x2_t * b_vec = (int32x2_t *) B.x;
using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
int32x8_t * acc = (int32x8_t *) D.x;
#if defined(RDNA4)
acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
true,
a_vec[0],
true,
b_vec[0],
acc[0],
true
);
acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
true,
a_vec[1],
true,
b_vec[1],
acc[0],
true
);
#endif // defined(RDNA4)
#else #else
GGML_UNUSED_VARS(D, A, B); GGML_UNUSED_VARS(D, A, B);
NO_DEVICE_CODE; NO_DEVICE_CODE;
@ -798,6 +850,7 @@ namespace ggml_cuda_mma {
acc[0], acc[0],
0, 0, 0); 0, 0, 0);
#endif // defined(CDNA3) #endif // defined(CDNA3)
#else #else
GGML_UNUSED_VARS(D, A, B); GGML_UNUSED_VARS(D, A, B);
NO_DEVICE_CODE; NO_DEVICE_CODE;
@ -842,4 +895,31 @@ namespace ggml_cuda_mma {
mma(D16[1], A16[1], B); mma(D16[1], A16[1], B);
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
} }
static __device__ __forceinline__ void mma(
tile<16, 16, int> & D, const tile<16, 4, int> & A, const tile<16, 4, int> & B) {
#if defined(AMD_WMMA_AVAILABLE)
using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
int32x2_t * a_vec = (int32x2_t *) A.x;
int32x2_t * b_vec = (int32x2_t *) B.x;
using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
int32x8_t * acc = (int32x8_t *) D.x;
acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
true,
a_vec[0],
true,
b_vec[0],
acc[0],
false
);
#else
GGML_UNUSED(D);
GGML_UNUSED(A);
GGML_UNUSED(B);
NO_DEVICE_CODE;
#endif
}
} }

View File

@ -306,5 +306,11 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
return false; return false;
} }
return (!GGML_CUDA_CC_IS_RDNA4(cc) && !GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; if (amd_wmma_available(cc)) {
if (GGML_CUDA_CC_IS_RDNA4(cc)) {
return true;
}
}
return (!GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
} }

File diff suppressed because it is too large Load Diff

View File

@ -1009,6 +1009,64 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort_merge(ggml_metal_l
return res; return res;
} }
// note: reuse the argsort kernel for top_k
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k(ggml_metal_library_t lib, const ggml_tensor * op) {
assert(op->op == GGML_OP_TOP_K);
char base[256];
char name[256];
// note: the top_k kernel is always descending order
ggml_sort_order order = GGML_SORT_ORDER_DESC;
const char * order_str = "undefined";
switch (order) {
case GGML_SORT_ORDER_ASC: order_str = "asc"; break;
case GGML_SORT_ORDER_DESC: order_str = "desc"; break;
default: GGML_ABORT("fatal error");
};
snprintf(base, 256, "kernel_argsort_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str);
snprintf(name, 256, "%s", base);
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
if (res) {
return res;
}
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
return res;
}
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k_merge(ggml_metal_library_t lib, const ggml_tensor * op) {
assert(op->op == GGML_OP_TOP_K);
char base[256];
char name[256];
ggml_sort_order order = GGML_SORT_ORDER_DESC;
const char * order_str = "undefined";
switch (order) {
case GGML_SORT_ORDER_ASC: order_str = "asc"; break;
case GGML_SORT_ORDER_DESC: order_str = "desc"; break;
default: GGML_ABORT("fatal error");
};
snprintf(base, 256, "kernel_argsort_merge_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str);
snprintf(name, 256, "%s", base);
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
if (res) {
return res;
}
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
return res;
}
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad( ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
ggml_metal_library_t lib, ggml_metal_library_t lib,
const struct ggml_tensor * op, const struct ggml_tensor * op,

View File

@ -128,6 +128,8 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id (ggml_me
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort_merge (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);

View File

@ -905,6 +905,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_OP_LEAKY_RELU: case GGML_OP_LEAKY_RELU:
return op->src[0]->type == GGML_TYPE_F32; return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_ARGSORT: case GGML_OP_ARGSORT:
case GGML_OP_TOP_K:
case GGML_OP_ARANGE: case GGML_OP_ARANGE:
return true; return true;
case GGML_OP_FLASH_ATTN_EXT: case GGML_OP_FLASH_ATTN_EXT:

View File

@ -832,14 +832,19 @@ typedef struct {
} ggml_metal_kargs_leaky_relu; } ggml_metal_kargs_leaky_relu;
typedef struct { typedef struct {
int64_t ne00; int32_t ne00;
int64_t ne01; int32_t ne01;
int64_t ne02; int32_t ne02;
int64_t ne03; int32_t ne03;
uint64_t nb00; uint64_t nb00;
uint64_t nb01; uint64_t nb01;
uint64_t nb02; uint64_t nb02;
uint64_t nb03; uint64_t nb03;
int32_t ne0;
int32_t ne1;
int32_t ne2;
int32_t ne3;
int32_t top_k;
} ggml_metal_kargs_argsort; } ggml_metal_kargs_argsort;
typedef struct { typedef struct {
@ -851,6 +856,11 @@ typedef struct {
uint64_t nb01; uint64_t nb01;
uint64_t nb02; uint64_t nb02;
uint64_t nb03; uint64_t nb03;
int32_t ne0;
int32_t ne1;
int32_t ne2;
int32_t ne3;
int32_t top_k;
int32_t len; int32_t len;
} ggml_metal_kargs_argsort_merge; } ggml_metal_kargs_argsort_merge;

View File

@ -406,6 +406,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
{ {
n_fuse = ggml_metal_op_argsort(ctx, idx); n_fuse = ggml_metal_op_argsort(ctx, idx);
} break; } break;
case GGML_OP_TOP_K:
{
n_fuse = ggml_metal_op_top_k(ctx, idx);
} break;
case GGML_OP_LEAKY_RELU: case GGML_OP_LEAKY_RELU:
{ {
n_fuse = ggml_metal_op_leaky_relu(ctx, idx); n_fuse = ggml_metal_op_leaky_relu(ctx, idx);
@ -3678,14 +3682,19 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
} }
ggml_metal_kargs_argsort args = { ggml_metal_kargs_argsort args = {
/*.ne00 =*/ ne00, /*.ne00 =*/ ne00,
/*.ne01 =*/ ne01, /*.ne01 =*/ ne01,
/*.ne02 =*/ ne02, /*.ne02 =*/ ne02,
/*.ne03 =*/ ne03, /*.ne03 =*/ ne03,
/*.nb00 =*/ nb00, /*.nb00 =*/ nb00,
/*.nb01 =*/ nb01, /*.nb01 =*/ nb01,
/*.nb02 =*/ nb02, /*.nb02 =*/ nb02,
/*.nb03 =*/ nb03, /*.nb03 =*/ nb03,
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.ne3 =*/ ne3,
/*.top_k =*/ nth,
}; };
ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_pipeline(enc, pipeline);
@ -3705,15 +3714,20 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
ggml_metal_op_concurrency_reset(ctx); ggml_metal_op_concurrency_reset(ctx);
ggml_metal_kargs_argsort_merge args_merge = { ggml_metal_kargs_argsort_merge args_merge = {
.ne00 = ne00, /*.ne00 =*/ ne00,
.ne01 = ne01, /*.ne01 =*/ ne01,
.ne02 = ne02, /*.ne02 =*/ ne02,
.ne03 = ne03, /*.ne03 =*/ ne03,
.nb00 = nb00, /*.nb00 =*/ nb00,
.nb01 = nb01, /*.nb01 =*/ nb01,
.nb02 = nb02, /*.nb02 =*/ nb02,
.nb03 = nb03, /*.nb03 =*/ nb03,
.len = len, /*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.ne3 =*/ ne3,
/*.top_k =*/ ne00,
/*.len =*/ len,
}; };
// merges per row // merges per row
@ -3737,6 +3751,118 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
return 1; return 1;
} }
int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_top_k(lib, op);
// bitonic sort requires the number of elements to be power of 2
int nth = 1;
while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
nth *= 2;
}
// blocks per row
const int npr = (ne00 + nth - 1)/nth;
const size_t smem = GGML_PAD(nth*sizeof(int32_t), 16);
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
ggml_metal_buffer_id bid_tmp = bid_dst;
bid_tmp.offs += sizeof(int32_t)*ggml_nelements(op->src[0]);
if ((int) ceil(std::log(npr) / std::log(2)) % 2 == 1) {
std::swap(bid_dst, bid_tmp);
}
const int top_k = ne0;
ggml_metal_kargs_argsort args = {
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.ne3 =*/ ne3,
/*.top_k =*/ std::min(nth, top_k), // for each block, keep just the top_k indices
};
if (npr > 1) {
args.ne0 = (npr - 1)*args.top_k + std::min(ne00 - (npr - 1)*nth, args.top_k);
}
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1);
ggml_metal_pipeline_t pipeline_merge = ggml_metal_library_get_pipeline_top_k_merge(lib, op);
int len = args.top_k;
while (len < args.ne0) {
ggml_metal_op_concurrency_reset(ctx);
// merges per row
const int nm = (args.ne0 + 2*len - 1) / (2*len);
const int nth = std::min(512, std::min(len, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_merge)));
ggml_metal_kargs_argsort_merge args_merge = {
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne0 =*/ args.ne0,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.ne3 =*/ ne3,
/*.top_k =*/ nm == 1 ? top_k : args.ne0, // the final merge outputs top_k elements
/*.len =*/ len,
};
ggml_metal_encoder_set_pipeline(enc, pipeline_merge);
ggml_metal_encoder_set_bytes (enc, &args_merge, sizeof(args_merge), 0);
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
ggml_metal_encoder_set_buffer (enc, bid_tmp, 3);
ggml_metal_encoder_dispatch_threadgroups(enc, nm*ne01, ne02, ne03, nth, 1, 1);
std::swap(bid_dst, bid_tmp);
len <<= 1;
}
return 1;
}
int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) { int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx); ggml_tensor * op = ctx->node(idx);

View File

@ -81,6 +81,7 @@ int ggml_metal_op_arange (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx); int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx);
int ggml_metal_op_argmax (ggml_metal_op_t ctx, int idx); int ggml_metal_op_argmax (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_argsort (ggml_metal_op_t ctx, int idx); int ggml_metal_op_argsort (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_top_k (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_leaky_relu (ggml_metal_op_t ctx, int idx); int ggml_metal_op_leaky_relu (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_opt_step_adamw (ggml_metal_op_t ctx, int idx); int ggml_metal_op_opt_step_adamw (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_opt_step_sgd (ggml_metal_op_t ctx, int idx); int ggml_metal_op_opt_step_sgd (ggml_metal_op_t ctx, int idx);

View File

@ -202,6 +202,10 @@ static size_t ggml_backend_metal_buffer_type_get_alloc_size(ggml_backend_buffer_
{ {
res *= 2; res *= 2;
} break; } break;
case GGML_OP_TOP_K:
{
res = 2*sizeof(int32_t)*ggml_nelements(tensor->src[0]);
} break;
default: default:
break; break;
} }

View File

@ -4670,11 +4670,12 @@ kernel void kernel_argsort_f32_i32(
ushort3 ntg[[threads_per_threadgroup]]) { ushort3 ntg[[threads_per_threadgroup]]) {
// bitonic sort // bitonic sort
const int col = tpitg[0]; const int col = tpitg[0];
const int ib = tgpig[0] / args.ne01;
const int i00 = (tgpig[0]/args.ne01)*ntg.x; const int i00 = ib*ntg.x;
const int i01 = tgpig[0]%args.ne01; const int i01 = tgpig[0] % args.ne01;
const int i02 = tgpig[1]; const int i02 = tgpig[1];
const int i03 = tgpig[2]; const int i03 = tgpig[2];
device const float * src0_row = (device const float *) (src0 + args.nb01*i01 + args.nb02*i02 + args.nb03*i03); device const float * src0_row = (device const float *) (src0 + args.nb01*i01 + args.nb02*i02 + args.nb03*i03);
@ -4710,9 +4711,11 @@ kernel void kernel_argsort_f32_i32(
} }
} }
const int64_t i0 = ib*args.top_k;
// copy the result to dst without the padding // copy the result to dst without the padding
if (i00 + col < args.ne00) { if (i0 + col < args.ne0 && col < args.top_k) {
dst += i00 + args.ne00*i01 + args.ne00*args.ne01*i02 + args.ne00*args.ne01*args.ne02*i03; dst += i0 + args.ne0*i01 + args.ne0*args.ne1*i02 + args.ne0*args.ne1*args.ne2*i03;
dst[col] = shmem_i32[col]; dst[col] = shmem_i32[col];
} }
@ -4747,22 +4750,22 @@ kernel void kernel_argsort_merge_f32_i32(
const int start = im * (2 * args.len); const int start = im * (2 * args.len);
const int len0 = MIN(args.len, MAX(0, args.ne00 - (int)(start))); const int len0 = MIN(args.len, MAX(0, args.ne0 - (int)(start)));
const int len1 = MIN(args.len, MAX(0, args.ne00 - (int)(start + args.len))); const int len1 = MIN(args.len, MAX(0, args.ne0 - (int)(start + args.len)));
const int total = len0 + len1; const int total = len0 + len1;
device const int32_t * tmp0 = tmp + start device const int32_t * tmp0 = tmp + start
+ i01*args.ne00 + i01*args.ne0
+ i02*args.ne00*args.ne01 + i02*args.ne0*args.ne01
+ i03*args.ne00*args.ne01*args.ne02; + i03*args.ne0*args.ne01*args.ne02;
device const int32_t * tmp1 = tmp0 + args.len; device const int32_t * tmp1 = tmp0 + args.len;
dst += start dst += start
+ i01*args.ne00 + i01*args.top_k
+ i02*args.ne00*args.ne01 + i02*args.top_k*args.ne01
+ i03*args.ne00*args.ne01*args.ne02; + i03*args.top_k*args.ne01*args.ne02;
device const float * src0_row = (device const float *)(src0 device const float * src0_row = (device const float *)(src0
+ args.nb01*i01 + args.nb01*i01
@ -4776,7 +4779,11 @@ kernel void kernel_argsort_merge_f32_i32(
const int chunk = (total + ntg.x - 1) / ntg.x; const int chunk = (total + ntg.x - 1) / ntg.x;
const int k0 = tpitg.x * chunk; const int k0 = tpitg.x * chunk;
const int k1 = min(k0 + chunk, total); const int k1 = MIN(MIN(k0 + chunk, total), args.top_k);
if (k0 >= args.top_k) {
return;
}
if (k0 >= total) { if (k0 >= total) {
return; return;

View File

@ -705,6 +705,7 @@ struct vk_device_struct {
vk_pipeline pipeline_argsort_f32[num_argsort_pipelines]; vk_pipeline pipeline_argsort_f32[num_argsort_pipelines];
vk_pipeline pipeline_argsort_large_f32[num_argsort_pipelines]; vk_pipeline pipeline_argsort_large_f32[num_argsort_pipelines];
vk_pipeline pipeline_sum_rows_f32; vk_pipeline pipeline_sum_rows_f32;
vk_pipeline pipeline_cumsum_f32;
vk_pipeline pipeline_argmax_f32; vk_pipeline pipeline_argmax_f32;
vk_pipeline pipeline_count_equal_i32; vk_pipeline pipeline_count_equal_i32;
vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16; vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
@ -1629,6 +1630,22 @@ class vk_perf_logger {
timings[name].push_back(time); timings[name].push_back(time);
return; return;
} }
if (node->op == GGML_OP_FLASH_ATTN_EXT) {
const ggml_tensor * dst = node;
const ggml_tensor * q = node->src[0];
const ggml_tensor * k = node->src[1];
const ggml_tensor * v = node->src[2];
const ggml_tensor * m = node->src[3];
std::stringstream name;
name << ggml_op_name(node->op) <<
" dst(" << dst->ne[0] << "," << dst->ne[1] << "," << dst->ne[2] << "," << dst->ne[3] << "), " <<
" q(" << q->ne[0] << "," << q->ne[1] << "," << q->ne[2] << "," << q->ne[3] << "), " <<
" k(" << k->ne[0] << "," << k->ne[1] << "," << k->ne[2] << "," << k->ne[3] << "), " <<
" v(" << v->ne[0] << "," << v->ne[1] << "," << v->ne[2] << "," << v->ne[3] << "), " <<
" m(" << (m?m->ne[0]:0) << "," << (m?m->ne[1]:0) << "," << (m?m->ne[2]:0) << "," << (m?m->ne[3]:0) << ")";
timings[name.str()].push_back(time);
return;
}
timings[ggml_op_name(node->op)].push_back(time); timings[ggml_op_name(node->op)].push_back(time);
} }
private: private:
@ -2485,9 +2502,11 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events
static constexpr uint32_t flash_attention_num_small_rows = 32; static constexpr uint32_t flash_attention_num_small_rows = 32;
static constexpr uint32_t scalar_flash_attention_num_small_rows = 1; static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
static uint32_t get_fa_scalar_num_large_rows(uint32_t hsv) { static uint32_t get_fa_scalar_num_large_rows(uint32_t hsk, uint32_t hsv) {
if (hsv >= 192) { if (hsv >= 192) {
return 2; return 2;
} else if ((hsv | hsk) & 8) {
return 4;
} else { } else {
return 8; return 8;
} }
@ -2519,9 +2538,9 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint3
if ((hsv | hsk) & 8) { if ((hsv | hsk) & 8) {
// HSV/HSK not being a multiple of 16 makes D_split smaller, which makes cols_per_iter // HSV/HSK not being a multiple of 16 makes D_split smaller, which makes cols_per_iter
// larger, and Bc needs to be >= cols_per_thread. 64 is large enough, 32 is not. // larger, and Bc needs to be >= cols_per_thread. 64 is large enough, 32 is not.
return {get_fa_scalar_num_large_rows(hsv), 64}; return {get_fa_scalar_num_large_rows(hsk, hsv), 64};
} else { } else {
return {get_fa_scalar_num_large_rows(hsv), 32}; return {get_fa_scalar_num_large_rows(hsk, hsv), 32};
} }
} }
} }
@ -3950,6 +3969,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
ggml_vk_create_pipeline(device, device->pipeline_cumsum_f32, "cumsum_f32", cumsum_f32_len, cumsum_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { 128, device->subgroup_size }, 1, true, true, device->subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1); ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
#define IM2COL(bda) \ #define IM2COL(bda) \
@ -7724,7 +7745,7 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con
// Needs to be kept up to date on shader changes // Needs to be kept up to date on shader changes
GGML_UNUSED(hsv); GGML_UNUSED(hsv);
const uint32_t wg_size = scalar_flash_attention_workgroup_size; const uint32_t wg_size = scalar_flash_attention_workgroup_size;
const uint32_t Br = get_fa_scalar_num_large_rows(hsv); const uint32_t Br = get_fa_scalar_num_large_rows(hsk, hsv);
const uint32_t Bc = scalar_flash_attention_Bc; const uint32_t Bc = scalar_flash_attention_Bc;
const uint32_t tmpsh = wg_size * sizeof(float); const uint32_t tmpsh = wg_size * sizeof(float);
@ -7855,7 +7876,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
case FA_SCALAR: case FA_SCALAR:
case FA_COOPMAT1: case FA_COOPMAT1:
// We may switch from coopmat1 to scalar, so use the scalar limit for both // We may switch from coopmat1 to scalar, so use the scalar limit for both
max_gqa = get_fa_scalar_num_large_rows(HSV); max_gqa = get_fa_scalar_num_large_rows(HSK, HSV);
break; break;
case FA_COOPMAT2: case FA_COOPMAT2:
max_gqa = get_fa_num_small_rows(FA_COOPMAT2); max_gqa = get_fa_num_small_rows(FA_COOPMAT2);
@ -8439,6 +8460,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_sum_rows_f32; return ctx->device->pipeline_sum_rows_f32;
} }
return nullptr; return nullptr;
case GGML_OP_CUMSUM:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_cumsum_f32;
}
return nullptr;
case GGML_OP_ARGMAX: case GGML_OP_ARGMAX:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) { if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
return ctx->device->pipeline_argmax_f32; return ctx->device->pipeline_argmax_f32;
@ -8803,6 +8829,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX:
case GGML_OP_SOFT_MAX_BACK: case GGML_OP_SOFT_MAX_BACK:
case GGML_OP_SUM_ROWS: case GGML_OP_SUM_ROWS:
case GGML_OP_CUMSUM:
case GGML_OP_MEAN: case GGML_OP_MEAN:
case GGML_OP_ARGMAX: case GGML_OP_ARGMAX:
{ {
@ -10132,6 +10159,11 @@ static void ggml_vk_mean(ggml_backend_vk_context * ctx, vk_context& subctx, cons
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_MEAN, p); ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_MEAN, p);
} }
static void ggml_vk_cumsum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]);
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_CUMSUM, p);
}
static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], 0.0f, 0.0f }); ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], 0.0f, 0.0f });
} }
@ -11731,6 +11763,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_SUM_ROWS: case GGML_OP_SUM_ROWS:
ggml_vk_sum_rows(ctx, compute_ctx, src0, node); ggml_vk_sum_rows(ctx, compute_ctx, src0, node);
break;
case GGML_OP_CUMSUM:
ggml_vk_cumsum(ctx, compute_ctx, src0, node);
break; break;
case GGML_OP_MEAN: case GGML_OP_MEAN:
ggml_vk_mean(ctx, compute_ctx, src0, node); ggml_vk_mean(ctx, compute_ctx, src0, node);
@ -13768,6 +13804,15 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_OP_SUM_ROWS: case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN: case GGML_OP_MEAN:
return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]); return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]);
case GGML_OP_CUMSUM:
{
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
auto device = ggml_vk_get_device(ctx->device);
if (device->subgroup_arithmetic && device->subgroup_require_full_support) {
return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]);
}
return false;
}
case GGML_OP_ARGMAX: case GGML_OP_ARGMAX:
case GGML_OP_COUNT_EQUAL: case GGML_OP_COUNT_EQUAL:
case GGML_OP_IM2COL: case GGML_OP_IM2COL:
@ -14418,6 +14463,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
tensor_clone = ggml_sum(ggml_ctx, src_clone[0]); tensor_clone = ggml_sum(ggml_ctx, src_clone[0]);
} else if (tensor->op == GGML_OP_SUM_ROWS) { } else if (tensor->op == GGML_OP_SUM_ROWS) {
tensor_clone = ggml_sum_rows(ggml_ctx, src_clone[0]); tensor_clone = ggml_sum_rows(ggml_ctx, src_clone[0]);
} else if (tensor->op == GGML_OP_CUMSUM) {
tensor_clone = ggml_cumsum(ggml_ctx, src_clone[0]);
} else if (tensor->op == GGML_OP_MEAN) { } else if (tensor->op == GGML_OP_MEAN) {
tensor_clone = ggml_mean(ggml_ctx, src_clone[0]); tensor_clone = ggml_mean(ggml_ctx, src_clone[0]);
} else if (tensor->op == GGML_OP_ARGMAX) { } else if (tensor->op == GGML_OP_ARGMAX) {

View File

@ -0,0 +1,69 @@
#version 450
#include "types.glsl"
#include "sum_rows.glsl"
#extension GL_EXT_control_flow_attributes : enable
#extension GL_KHR_shader_subgroup_arithmetic : enable
#extension GL_KHR_shader_subgroup_basic : enable
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
layout (constant_id = 0) const uint BLOCK_SIZE = 128;
layout (constant_id = 1) const uint SUBGROUP_SIZE = 32;
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
shared FLOAT_TYPE partial[BLOCK_SIZE / SUBGROUP_SIZE];
shared FLOAT_TYPE last_sum;
void main() {
const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
const uint tid = gl_LocalInvocationID.x;
const uint i03 = fastdiv(row, p.ne0_12mp, p.ne0_12L);
const uint i03_offset = i03 * p.ne01*p.ne02;
const uint i02 = fastdiv(row - i03_offset, p.ne0_1mp, p.ne0_1L);
const uint i01 = row - i03_offset - i02*p.ne01;
const uint src_idx = get_aoffset() + i01 * p.nb01 + i02 * p.nb02 + i03 * p.nb03;
const uint dst_idx = get_doffset() + i01 * p.nb11 + i02 * p.nb12 + i03 * p.nb13;
uint subgroup_id = tid / SUBGROUP_SIZE;
if (tid == 0) {
last_sum = 0;
}
uint col = tid;
uint num_iter = CEIL_DIV(p.n_cols, BLOCK_SIZE);
for (int i = 0; i < num_iter; ++i) {
FLOAT_TYPE v = 0;
if (col < p.n_cols) {
v = FLOAT_TYPE(data_a[src_idx + col]);
}
v = subgroupInclusiveAdd(v);
// Store the largest partial sum for each subgroup, then add the partials for all
// lower subgroups and the final partial sum from the previous iteration.
if (gl_SubgroupInvocationID == SUBGROUP_SIZE - 1) {
partial[subgroup_id] = v;
}
barrier();
for (int j = 0; j < subgroup_id; ++j) {
v += partial[j];
}
v += last_sum;
barrier();
if (tid == BLOCK_SIZE - 1) {
last_sum = v;
}
if (col < p.n_cols) {
data_d[dst_idx + col] = D_TYPE(v);
}
col += BLOCK_SIZE;
}
}

View File

@ -1,6 +1,7 @@
#version 450 #version 450
#include "types.glsl" #include "types.glsl"
#include "sum_rows.glsl"
#extension GL_EXT_control_flow_attributes : enable #extension GL_EXT_control_flow_attributes : enable
@ -11,30 +12,6 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
layout (constant_id = 0) const uint BLOCK_SIZE = 32; layout (constant_id = 0) const uint BLOCK_SIZE = 32;
layout (push_constant) uniform parameter
{
uint n_cols;
uint ne01, ne02;
uint nb01, nb02, nb03;
uint nb11, nb12, nb13;
float weight;
uint misalign_offsets;
uint ne0_12mp, ne0_12L;
uint ne0_1mp, ne0_1L;
} p;
uint get_aoffset() { return p.misalign_offsets >> 16; }
uint get_doffset() { return p.misalign_offsets & 0xFFFF; }
// see init_fastdiv_values in ggml-vulkan.cpp
uint fastdiv(uint n, uint mp, uint L) {
uint msbs, lsbs;
// msbs = mulhi(n, mp)
umulExtended(n, mp, msbs, lsbs);
return (msbs + n) >> L;
}
shared FLOAT_TYPE tmp[BLOCK_SIZE]; shared FLOAT_TYPE tmp[BLOCK_SIZE];
void main() { void main() {

View File

@ -0,0 +1,25 @@
// vk_op_sum_rows_push_constants
layout (push_constant) uniform parameter
{
uint n_cols;
uint ne01, ne02;
uint nb01, nb02, nb03;
uint nb11, nb12, nb13;
float weight;
uint misalign_offsets;
uint ne0_12mp, ne0_12L;
uint ne0_1mp, ne0_1L;
} p;
uint get_aoffset() { return p.misalign_offsets >> 16; }
uint get_doffset() { return p.misalign_offsets & 0xFFFF; }
// see init_fastdiv_values in ggml-vulkan.cpp
uint fastdiv(uint n, uint mp, uint L) {
uint msbs, lsbs;
// msbs = mulhi(n, mp)
umulExtended(n, mp, msbs, lsbs);
return (msbs + n) >> L;
}

View File

@ -916,6 +916,7 @@ void process_shaders() {
string_to_spv("argmax_f32", "argmax.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "int"}})); string_to_spv("argmax_f32", "argmax.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "int"}}));
string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}})); string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}}));
string_to_spv("cumsum_f32", "cumsum.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
for (std::string dim_str : {"", "_3d"}) { for (std::string dim_str : {"", "_3d"}) {
for (bool bda : {false, true}) { for (bool bda : {false, true}) {

View File

@ -990,6 +990,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"ARANGE", "ARANGE",
"TIMESTEP_EMBEDDING", "TIMESTEP_EMBEDDING",
"ARGSORT", "ARGSORT",
"TOP_K",
"LEAKY_RELU", "LEAKY_RELU",
"TRI", "TRI",
"FILL", "FILL",
@ -1023,7 +1024,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"GLU", "GLU",
}; };
static_assert(GGML_OP_COUNT == 94, "GGML_OP_COUNT != 94"); static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none", "none",
@ -1098,6 +1099,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"arange(start, stop, step)", "arange(start, stop, step)",
"timestep_embedding(timesteps, dim, max_period)", "timestep_embedding(timesteps, dim, max_period)",
"argsort(x)", "argsort(x)",
"top_k(x)",
"leaky_relu(x)", "leaky_relu(x)",
"tri(x)", "tri(x)",
"fill(x, c)", "fill(x, c)",
@ -1131,7 +1133,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"glu(x)", "glu(x)",
}; };
static_assert(GGML_OP_COUNT == 94, "GGML_OP_COUNT != 94"); static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95");
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@ -5036,28 +5038,6 @@ struct ggml_tensor * ggml_roll(
return result; return result;
} }
// ggml_arange
struct ggml_tensor * ggml_arange(
struct ggml_context * ctx,
float start,
float stop,
float step) {
GGML_ASSERT(stop > start);
const int64_t steps = (int64_t) ceilf((stop - start) / step);
struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, steps);
ggml_set_op_params_f32(result, 0, start);
ggml_set_op_params_f32(result, 1, stop);
ggml_set_op_params_f32(result, 2, step);
result->op = GGML_OP_ARANGE;
return result;
}
// ggml_timestep_embedding // ggml_timestep_embedding
struct ggml_tensor * ggml_timestep_embedding( struct ggml_tensor * ggml_timestep_embedding(
@ -5139,6 +5119,7 @@ struct ggml_tensor * ggml_argsort(
struct ggml_tensor * a, struct ggml_tensor * a,
enum ggml_sort_order order) { enum ggml_sort_order order) {
GGML_ASSERT(a->ne[0] <= INT32_MAX); GGML_ASSERT(a->ne[0] <= INT32_MAX);
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_I32, GGML_MAX_DIMS, a->ne); struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_I32, GGML_MAX_DIMS, a->ne);
ggml_set_op_params_i32(result, 0, (int32_t) order); ggml_set_op_params_i32(result, 0, (int32_t) order);
@ -5149,6 +5130,24 @@ struct ggml_tensor * ggml_argsort(
return result; return result;
} }
// ggml_argsort_top_k
struct ggml_tensor * ggml_argsort_top_k(
struct ggml_context * ctx,
struct ggml_tensor * a,
int k) {
GGML_ASSERT(a->ne[0] >= k);
struct ggml_tensor * result = ggml_argsort(ctx, a, GGML_SORT_ORDER_DESC);
result = ggml_view_4d(ctx, result,
k, result->ne[1], result->ne[2], result->ne[3],
result->nb[1], result->nb[2], result->nb[3],
0);
return result;
}
// ggml_top_k // ggml_top_k
struct ggml_tensor * ggml_top_k( struct ggml_tensor * ggml_top_k(
@ -5157,12 +5156,32 @@ struct ggml_tensor * ggml_top_k(
int k) { int k) {
GGML_ASSERT(a->ne[0] >= k); GGML_ASSERT(a->ne[0] >= k);
struct ggml_tensor * result = ggml_argsort(ctx, a, GGML_SORT_ORDER_DESC); struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_I32, k, a->ne[1], a->ne[2], a->ne[3]);
result = ggml_view_4d(ctx, result, result->op = GGML_OP_TOP_K;
k, result->ne[1], result->ne[2], result->ne[3], result->src[0] = a;
result->nb[1], result->nb[2], result->nb[3],
0); return result;
}
// ggml_arange
struct ggml_tensor * ggml_arange(
struct ggml_context * ctx,
float start,
float stop,
float step) {
GGML_ASSERT(stop > start);
const int64_t steps = (int64_t) ceilf((stop - start) / step);
struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, steps);
ggml_set_op_params_f32(result, 0, start);
ggml_set_op_params_f32(result, 1, stop);
ggml_set_op_params_f32(result, 2, step);
result->op = GGML_OP_ARANGE;
return result; return result;
} }

View File

@ -25,6 +25,20 @@ class Keys:
ALIGNMENT = "general.alignment" ALIGNMENT = "general.alignment"
FILE_TYPE = "general.file_type" FILE_TYPE = "general.file_type"
# Recommended Sampler Parameters
SAMPLING_SEQUENCE = "general.sampling.sequence"
SAMPLING_TOP_K = "general.sampling.top_k"
SAMPLING_TOP_P = "general.sampling.top_p"
SAMPLING_MIN_P = "general.sampling.min_p"
SAMPLING_XTC_PROBABILITY = "general.sampling.xtc_probability"
SAMPLING_XTC_THRESHOLD = "general.sampling.xtc_threshold"
SAMPLING_TEMP = "general.sampling.temp"
SAMPLING_PENALTY_LAST_N = "general.sampling.penalty_last_n"
SAMPLING_PENALTY_REPEAT = "general.sampling.penalty_repeat"
SAMPLING_MIROSTAT = "general.sampling.mirostat"
SAMPLING_MIROSTAT_TAU = "general.sampling.mirostat_tau"
SAMPLING_MIROSTAT_ETA = "general.sampling.mirostat_eta"
# Authorship Metadata # Authorship Metadata
NAME = "general.name" NAME = "general.name"
AUTHOR = "general.author" AUTHOR = "general.author"

View File

@ -4,6 +4,7 @@ import logging
import os import os
import shutil import shutil
import struct import struct
import sys
import tempfile import tempfile
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum, auto from enum import Enum, auto
@ -372,8 +373,10 @@ class GGUFWriter:
self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None, self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None,
raw_dtype: GGMLQuantizationType | None = None, raw_dtype: GGMLQuantizationType | None = None,
) -> None: ) -> None:
if self.endianess == GGUFEndian.BIG: if (self.endianess == GGUFEndian.BIG and sys.byteorder != 'big') or \
tensor.byteswap(inplace=True) (self.endianess == GGUFEndian.LITTLE and sys.byteorder != 'little'):
# Don't byteswap inplace since lazy copies cannot handle it
tensor = tensor.byteswap(inplace=False)
if self.use_temp_file and self.temp_file is None: if self.use_temp_file and self.temp_file is None:
fp = tempfile.SpooledTemporaryFile(mode="w+b", max_size=256 * 1024 * 1024) fp = tempfile.SpooledTemporaryFile(mode="w+b", max_size=256 * 1024 * 1024)
fp.seek(0) fp.seek(0)
@ -399,8 +402,10 @@ class GGUFWriter:
raise ValueError(f'Expected output file to contain tensor info or weights, got {self.state}') raise ValueError(f'Expected output file to contain tensor info or weights, got {self.state}')
assert self.fout is not None assert self.fout is not None
if self.endianess == GGUFEndian.BIG: if (self.endianess == GGUFEndian.BIG and sys.byteorder != 'big') or \
tensor.byteswap(inplace=True) (self.endianess == GGUFEndian.LITTLE and sys.byteorder != 'little'):
# Don't byteswap inplace since lazy copies cannot handle it
tensor = tensor.byteswap(inplace=False)
file_id = -1 file_id = -1
for i, tensors in enumerate(self.tensors): for i, tensors in enumerate(self.tensors):
@ -496,6 +501,42 @@ class GGUFWriter:
def add_file_type(self, ftype: int) -> None: def add_file_type(self, ftype: int) -> None:
self.add_uint32(Keys.General.FILE_TYPE, ftype) self.add_uint32(Keys.General.FILE_TYPE, ftype)
def add_sampling_sequence(self, sequence: str) -> None:
self.add_string(Keys.General.SAMPLING_SEQUENCE, sequence)
def add_sampling_top_k(self, top_k: int) -> None:
self.add_int32(Keys.General.SAMPLING_TOP_K, top_k)
def add_sampling_top_p(self, top_p: float) -> None:
self.add_float32(Keys.General.SAMPLING_TOP_P, top_p)
def add_sampling_min_p(self, min_p: float) -> None:
self.add_float32(Keys.General.SAMPLING_MIN_P, min_p)
def add_sampling_xtc_probability(self, xtc_probability: float) -> None:
self.add_float32(Keys.General.SAMPLING_XTC_PROBABILITY, xtc_probability)
def add_sampling_xtc_threshold(self, xtc_threshold: float) -> None:
self.add_float32(Keys.General.SAMPLING_XTC_THRESHOLD, xtc_threshold)
def add_sampling_temp(self, temp: float) -> None:
self.add_float32(Keys.General.SAMPLING_TEMP, temp)
def add_sampling_penalty_last_n(self, penalty_last_n: int) -> None:
self.add_int32(Keys.General.SAMPLING_PENALTY_LAST_N, penalty_last_n)
def add_sampling_penalty_repeat(self, penalty_repeat: float) -> None:
self.add_float32(Keys.General.SAMPLING_PENALTY_REPEAT, penalty_repeat)
def add_sampling_mirostat(self, mirostat: int) -> None:
self.add_int32(Keys.General.SAMPLING_MIROSTAT, mirostat)
def add_sampling_mirostat_tau(self, mirostat_tau: float) -> None:
self.add_float32(Keys.General.SAMPLING_MIROSTAT_TAU, mirostat_tau)
def add_sampling_mirostat_eta(self, mirostat_eta: float) -> None:
self.add_float32(Keys.General.SAMPLING_MIROSTAT_ETA, mirostat_eta)
def add_name(self, name: str) -> None: def add_name(self, name: str) -> None:
self.add_string(Keys.General.NAME, name) self.add_string(Keys.General.NAME, name)

View File

@ -17,6 +17,20 @@ logger = logging.getLogger("metadata")
@dataclass @dataclass
class Metadata: class Metadata:
# Recommended Sampler Parameters to be written to GGUF KV Store
sampling_sequence: Optional[str] = None
sampling_top_k: Optional[int] = None
sampling_top_p: Optional[float] = None
sampling_min_p: Optional[float] = None
sampling_xtc_probability: Optional[float] = None
sampling_xtc_threshold: Optional[float] = None
sampling_temp: Optional[float] = None
sampling_penalty_last_n: Optional[int] = None
sampling_penalty_repeat: Optional[float] = None
sampling_mirostat: Optional[int] = None
sampling_mirostat_tau: Optional[float] = None
sampling_mirostat_eta: Optional[float] = None
# Authorship Metadata to be written to GGUF KV Store # Authorship Metadata to be written to GGUF KV Store
name: Optional[str] = None name: Optional[str] = None
author: Optional[str] = None author: Optional[str] = None
@ -54,15 +68,43 @@ class Metadata:
model_card = Metadata.load_model_card(model_path) model_card = Metadata.load_model_card(model_path)
hf_params = Metadata.load_hf_parameters(model_path) hf_params = Metadata.load_hf_parameters(model_path)
gen_config = Metadata.load_generation_config(model_path)
# TODO: load adapter_config.json when possible, it usually contains the base model of the LoRA adapter # TODO: load adapter_config.json when possible, it usually contains the base model of the LoRA adapter
# heuristics # heuristics
metadata = Metadata.apply_metadata_heuristic(metadata, model_card, hf_params, model_path, total_params) metadata = Metadata.apply_metadata_heuristic(metadata, model_card, hf_params, model_path, total_params)
if gen_config:
metadata.sampling_sequence = gen_config.get("sequence", metadata.sampling_sequence)
metadata.sampling_top_k = gen_config.get("top_k", metadata.sampling_top_k)
metadata.sampling_top_p = gen_config.get("top_p", metadata.sampling_top_p)
metadata.sampling_min_p = gen_config.get("min_p", metadata.sampling_min_p)
metadata.sampling_xtc_probability = gen_config.get("xtc_probability", metadata.sampling_xtc_probability)
metadata.sampling_xtc_threshold = gen_config.get("xtc_threshold", metadata.sampling_xtc_threshold)
metadata.sampling_temp = gen_config.get("temperature", metadata.sampling_temp)
metadata.sampling_penalty_last_n = gen_config.get("penalty_last_n", metadata.sampling_penalty_last_n)
metadata.sampling_penalty_repeat = gen_config.get("penalty_repeat", metadata.sampling_penalty_repeat)
metadata.sampling_mirostat = gen_config.get("mirostat", metadata.sampling_mirostat)
metadata.sampling_mirostat_tau = gen_config.get("mirostat_tau", metadata.sampling_mirostat_tau)
metadata.sampling_mirostat_eta = gen_config.get("mirostat_eta", metadata.sampling_mirostat_eta)
# Metadata Override File Provided # Metadata Override File Provided
# This is based on LLM_KV_NAMES mapping in llama.cpp # This is based on LLM_KV_NAMES mapping in llama.cpp
metadata_override = Metadata.load_metadata_override(metadata_override_path) metadata_override = Metadata.load_metadata_override(metadata_override_path)
metadata.sampling_sequence = metadata_override.get(Keys.General.SAMPLING_SEQUENCE, metadata.sampling_sequence)
metadata.sampling_top_k = metadata_override.get(Keys.General.SAMPLING_TOP_K, metadata.sampling_top_k)
metadata.sampling_top_p = metadata_override.get(Keys.General.SAMPLING_TOP_P, metadata.sampling_top_p)
metadata.sampling_min_p = metadata_override.get(Keys.General.SAMPLING_MIN_P, metadata.sampling_min_p)
metadata.sampling_xtc_probability = metadata_override.get(Keys.General.SAMPLING_XTC_PROBABILITY, metadata.sampling_xtc_probability)
metadata.sampling_xtc_threshold = metadata_override.get(Keys.General.SAMPLING_XTC_THRESHOLD, metadata.sampling_xtc_threshold)
metadata.sampling_temp = metadata_override.get(Keys.General.SAMPLING_TEMP, metadata.sampling_temp)
metadata.sampling_penalty_last_n = metadata_override.get(Keys.General.SAMPLING_PENALTY_LAST_N, metadata.sampling_penalty_last_n)
metadata.sampling_penalty_repeat = metadata_override.get(Keys.General.SAMPLING_PENALTY_REPEAT, metadata.sampling_penalty_repeat)
metadata.sampling_mirostat = metadata_override.get(Keys.General.SAMPLING_MIROSTAT, metadata.sampling_mirostat)
metadata.sampling_mirostat_tau = metadata_override.get(Keys.General.SAMPLING_MIROSTAT_TAU, metadata.sampling_mirostat_tau)
metadata.sampling_mirostat_eta = metadata_override.get(Keys.General.SAMPLING_MIROSTAT_ETA, metadata.sampling_mirostat_eta)
metadata.name = metadata_override.get(Keys.General.NAME, metadata.name) metadata.name = metadata_override.get(Keys.General.NAME, metadata.name)
metadata.author = metadata_override.get(Keys.General.AUTHOR, metadata.author) metadata.author = metadata_override.get(Keys.General.AUTHOR, metadata.author)
metadata.version = metadata_override.get(Keys.General.VERSION, metadata.version) metadata.version = metadata_override.get(Keys.General.VERSION, metadata.version)
@ -172,6 +214,23 @@ class Metadata:
with open(config_path, "r", encoding="utf-8") as f: with open(config_path, "r", encoding="utf-8") as f:
return json.load(f) return json.load(f)
@staticmethod
def load_generation_config(model_path: Optional[Path] = None) -> dict[str, Any]:
if model_path is None or not model_path.is_dir():
return {}
generation_config_path = model_path / "generation_config.json"
if not generation_config_path.is_file():
return {}
try:
with open(generation_config_path, "r", encoding="utf-8") as f:
return json.load(f)
except (json.JSONDecodeError, IOError):
# not all models have valid generation_config.json
return {}
@staticmethod @staticmethod
def id_to_title(string): def id_to_title(string):
# Convert capitalization into title form unless acronym or version number # Convert capitalization into title form unless acronym or version number
@ -546,6 +605,32 @@ class Metadata:
def set_gguf_meta_model(self, gguf_writer: gguf.GGUFWriter): def set_gguf_meta_model(self, gguf_writer: gguf.GGUFWriter):
assert self.name is not None assert self.name is not None
if self.sampling_sequence is not None:
gguf_writer.add_sampling_sequence(self.sampling_sequence)
if self.sampling_top_k is not None:
gguf_writer.add_sampling_top_k(self.sampling_top_k)
if self.sampling_top_p is not None:
gguf_writer.add_sampling_top_p(self.sampling_top_p)
if self.sampling_min_p is not None:
gguf_writer.add_sampling_min_p(self.sampling_min_p)
if self.sampling_xtc_probability is not None:
gguf_writer.add_sampling_xtc_probability(self.sampling_xtc_probability)
if self.sampling_xtc_threshold is not None:
gguf_writer.add_sampling_xtc_threshold(self.sampling_xtc_threshold)
if self.sampling_temp is not None:
gguf_writer.add_sampling_temp(self.sampling_temp)
if self.sampling_penalty_last_n is not None:
gguf_writer.add_sampling_penalty_last_n(self.sampling_penalty_last_n)
if self.sampling_penalty_repeat is not None:
gguf_writer.add_sampling_penalty_repeat(self.sampling_penalty_repeat)
if self.sampling_mirostat is not None:
gguf_writer.add_sampling_mirostat(self.sampling_mirostat)
if self.sampling_mirostat_tau is not None:
gguf_writer.add_sampling_mirostat_tau(self.sampling_mirostat_tau)
if self.sampling_mirostat_eta is not None:
gguf_writer.add_sampling_mirostat_eta(self.sampling_mirostat_eta)
gguf_writer.add_name(self.name) gguf_writer.add_name(self.name)
if self.author is not None: if self.author is not None:

View File

@ -246,6 +246,21 @@ extern "C" {
LLAMA_KV_OVERRIDE_TYPE_STR, LLAMA_KV_OVERRIDE_TYPE_STR,
}; };
enum llama_model_meta_key {
LLAMA_MODEL_META_KEY_SAMPLING_SEQUENCE,
LLAMA_MODEL_META_KEY_SAMPLING_TOP_K,
LLAMA_MODEL_META_KEY_SAMPLING_TOP_P,
LLAMA_MODEL_META_KEY_SAMPLING_MIN_P,
LLAMA_MODEL_META_KEY_SAMPLING_XTC_PROBABILITY,
LLAMA_MODEL_META_KEY_SAMPLING_XTC_THRESHOLD,
LLAMA_MODEL_META_KEY_SAMPLING_TEMP,
LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_LAST_N,
LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_REPEAT,
LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT,
LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_TAU,
LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_ETA,
};
struct llama_model_kv_override { struct llama_model_kv_override {
enum llama_model_kv_override_type tag; enum llama_model_kv_override_type tag;
@ -518,6 +533,9 @@ extern "C" {
// Get the number of metadata key/value pairs // Get the number of metadata key/value pairs
LLAMA_API int32_t llama_model_meta_count(const struct llama_model * model); LLAMA_API int32_t llama_model_meta_count(const struct llama_model * model);
// Get sampling metadata key name. Returns nullptr if the key is invalid
LLAMA_API const char * llama_model_meta_key_str(enum llama_model_meta_key key);
// Get metadata key name by index // Get metadata key name by index
LLAMA_API int32_t llama_model_meta_key_by_index(const struct llama_model * model, int32_t i, char * buf, size_t buf_size); LLAMA_API int32_t llama_model_meta_key_by_index(const struct llama_model * model, int32_t i, char * buf, size_t buf_size);

View File

@ -16,7 +16,7 @@ vendor = {
# "https://github.com/mackron/miniaudio/raw/refs/tags/0.11.23/miniaudio.h": "vendor/miniaudio/miniaudio.h", # "https://github.com/mackron/miniaudio/raw/refs/tags/0.11.23/miniaudio.h": "vendor/miniaudio/miniaudio.h",
"https://github.com/mackron/miniaudio/raw/669ed3e844524fcd883231b13095baee9f6de304/miniaudio.h": "vendor/miniaudio/miniaudio.h", "https://github.com/mackron/miniaudio/raw/669ed3e844524fcd883231b13095baee9f6de304/miniaudio.h": "vendor/miniaudio/miniaudio.h",
"https://raw.githubusercontent.com/yhirose/cpp-httplib/refs/tags/v0.27.0/httplib.h": "vendor/cpp-httplib/httplib.h", "https://raw.githubusercontent.com/yhirose/cpp-httplib/refs/tags/v0.28.0/httplib.h": "vendor/cpp-httplib/httplib.h",
"https://raw.githubusercontent.com/sheredom/subprocess.h/b49c56e9fe214488493021017bf3954b91c7c1f5/subprocess.h": "vendor/sheredom/subprocess.h", "https://raw.githubusercontent.com/sheredom/subprocess.h/b49c56e9fe214488493021017bf3954b91c7c1f5/subprocess.h": "vendor/sheredom/subprocess.h",
} }

View File

@ -114,19 +114,31 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
}; };
static const std::map<llm_kv, const char *> LLM_KV_NAMES = { static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_GENERAL_TYPE, "general.type" }, { LLM_KV_GENERAL_TYPE, "general.type" },
{ LLM_KV_GENERAL_ARCHITECTURE, "general.architecture" }, { LLM_KV_GENERAL_ARCHITECTURE, "general.architecture" },
{ LLM_KV_GENERAL_QUANTIZATION_VERSION, "general.quantization_version" }, { LLM_KV_GENERAL_QUANTIZATION_VERSION, "general.quantization_version" },
{ LLM_KV_GENERAL_ALIGNMENT, "general.alignment" }, { LLM_KV_GENERAL_ALIGNMENT, "general.alignment" },
{ LLM_KV_GENERAL_FILE_TYPE, "general.file_type" }, { LLM_KV_GENERAL_FILE_TYPE, "general.file_type" },
{ LLM_KV_GENERAL_NAME, "general.name" }, { LLM_KV_GENERAL_SAMPLING_SEQUENCE, "general.sampling.sequence" },
{ LLM_KV_GENERAL_AUTHOR, "general.author" }, { LLM_KV_GENERAL_SAMPLING_TOP_K, "general.sampling.top_k" },
{ LLM_KV_GENERAL_VERSION, "general.version" }, { LLM_KV_GENERAL_SAMPLING_TOP_P, "general.sampling.top_p" },
{ LLM_KV_GENERAL_URL, "general.url" }, { LLM_KV_GENERAL_SAMPLING_MIN_P, "general.sampling.min_p" },
{ LLM_KV_GENERAL_DESCRIPTION, "general.description" }, { LLM_KV_GENERAL_SAMPLING_XTC_PROBABILITY, "general.sampling.xtc_probability" },
{ LLM_KV_GENERAL_LICENSE, "general.license" }, { LLM_KV_GENERAL_SAMPLING_XTC_THRESHOLD, "general.sampling.xtc_threshold" },
{ LLM_KV_GENERAL_SOURCE_URL, "general.source.url" }, { LLM_KV_GENERAL_SAMPLING_TEMP, "general.sampling.temp" },
{ LLM_KV_GENERAL_SOURCE_HF_REPO, "general.source.huggingface.repository" }, { LLM_KV_GENERAL_SAMPLING_PENALTY_LAST_N, "general.sampling.penalty_last_n" },
{ LLM_KV_GENERAL_SAMPLING_PENALTY_REPEAT, "general.sampling.penalty_repeat" },
{ LLM_KV_GENERAL_SAMPLING_MIROSTAT, "general.sampling.mirostat" },
{ LLM_KV_GENERAL_SAMPLING_MIROSTAT_TAU, "general.sampling.mirostat_tau" },
{ LLM_KV_GENERAL_SAMPLING_MIROSTAT_ETA, "general.sampling.mirostat_eta" },
{ LLM_KV_GENERAL_NAME, "general.name" },
{ LLM_KV_GENERAL_AUTHOR, "general.author" },
{ LLM_KV_GENERAL_VERSION, "general.version" },
{ LLM_KV_GENERAL_URL, "general.url" },
{ LLM_KV_GENERAL_DESCRIPTION, "general.description" },
{ LLM_KV_GENERAL_LICENSE, "general.license" },
{ LLM_KV_GENERAL_SOURCE_URL, "general.source.url" },
{ LLM_KV_GENERAL_SOURCE_HF_REPO, "general.source.huggingface.repository" },
{ LLM_KV_VOCAB_SIZE, "%s.vocab_size" }, { LLM_KV_VOCAB_SIZE, "%s.vocab_size" },
{ LLM_KV_CONTEXT_LENGTH, "%s.context_length" }, { LLM_KV_CONTEXT_LENGTH, "%s.context_length" },

View File

@ -123,6 +123,18 @@ enum llm_kv {
LLM_KV_GENERAL_QUANTIZATION_VERSION, LLM_KV_GENERAL_QUANTIZATION_VERSION,
LLM_KV_GENERAL_ALIGNMENT, LLM_KV_GENERAL_ALIGNMENT,
LLM_KV_GENERAL_FILE_TYPE, LLM_KV_GENERAL_FILE_TYPE,
LLM_KV_GENERAL_SAMPLING_SEQUENCE,
LLM_KV_GENERAL_SAMPLING_TOP_K,
LLM_KV_GENERAL_SAMPLING_TOP_P,
LLM_KV_GENERAL_SAMPLING_MIN_P,
LLM_KV_GENERAL_SAMPLING_XTC_PROBABILITY,
LLM_KV_GENERAL_SAMPLING_XTC_THRESHOLD,
LLM_KV_GENERAL_SAMPLING_TEMP,
LLM_KV_GENERAL_SAMPLING_PENALTY_LAST_N,
LLM_KV_GENERAL_SAMPLING_PENALTY_REPEAT,
LLM_KV_GENERAL_SAMPLING_MIROSTAT,
LLM_KV_GENERAL_SAMPLING_MIROSTAT_TAU,
LLM_KV_GENERAL_SAMPLING_MIROSTAT_ETA,
LLM_KV_GENERAL_NAME, LLM_KV_GENERAL_NAME,
LLM_KV_GENERAL_AUTHOR, LLM_KV_GENERAL_AUTHOR,
LLM_KV_GENERAL_VERSION, LLM_KV_GENERAL_VERSION,

View File

@ -1248,7 +1248,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
// make the outputs have the same order they had in the user-provided batch // make the outputs have the same order they had in the user-provided batch
// note: this is mostly relevant for recurrent models atm // note: this is mostly relevant for recurrent models atm
if (!sorted_output) { if (!sorted_output && n_outputs > 1) {
GGML_ASSERT((size_t) n_outputs == out_ids.size()); GGML_ASSERT((size_t) n_outputs == out_ids.size());
// TODO: is there something more efficient which also minimizes swaps? // TODO: is there something more efficient which also minimizes swaps?

View File

@ -961,14 +961,14 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
// organize experts into n_expert_groups // organize experts into n_expert_groups
ggml_tensor * selection_groups = ggml_reshape_3d(ctx0, selection_probs, n_exp_per_group, hparams.n_expert_groups, n_tokens); // [n_exp_per_group, n_expert_groups, n_tokens] ggml_tensor * selection_groups = ggml_reshape_3d(ctx0, selection_probs, n_exp_per_group, hparams.n_expert_groups, n_tokens); // [n_exp_per_group, n_expert_groups, n_tokens]
ggml_tensor * group_scores = ggml_top_k(ctx0, selection_groups, 2); // [2, n_expert_groups, n_tokens] ggml_tensor * group_scores = ggml_argsort_top_k(ctx0, selection_groups, 2); // [2, n_expert_groups, n_tokens]
group_scores = ggml_get_rows(ctx0, ggml_reshape_4d(ctx0, selection_groups, 1, selection_groups->ne[0], selection_groups->ne[1], selection_groups->ne[2]), group_scores); // [1, 2, n_expert_groups, n_tokens] group_scores = ggml_get_rows(ctx0, ggml_reshape_4d(ctx0, selection_groups, 1, selection_groups->ne[0], selection_groups->ne[1], selection_groups->ne[2]), group_scores); // [1, 2, n_expert_groups, n_tokens]
// get top n_group_used expert groups // get top n_group_used expert groups
group_scores = ggml_sum_rows(ctx0, ggml_reshape_3d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2], group_scores->ne[3])); // [1, n_expert_groups, n_tokens] group_scores = ggml_sum_rows(ctx0, ggml_reshape_3d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2], group_scores->ne[3])); // [1, n_expert_groups, n_tokens]
group_scores = ggml_reshape_2d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2]); // [n_expert_groups, n_tokens] group_scores = ggml_reshape_2d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2]); // [n_expert_groups, n_tokens]
ggml_tensor * expert_groups = ggml_top_k(ctx0, group_scores, hparams.n_group_used); // [n_group_used, n_tokens] ggml_tensor * expert_groups = ggml_argsort_top_k(ctx0, group_scores, hparams.n_group_used); // [n_group_used, n_tokens]
cb(expert_groups, "ffn_moe_group_topk", il); cb(expert_groups, "ffn_moe_group_topk", il);
// mask out the other groups // mask out the other groups
@ -979,7 +979,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
} }
// select experts // select experts
ggml_tensor * selected_experts = ggml_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens] ggml_tensor * selected_experts = ggml_argsort_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
cb(selected_experts->src[0], "ffn_moe_argsort", il); cb(selected_experts->src[0], "ffn_moe_argsort", il);
cb(selected_experts, "ffn_moe_topk", il); cb(selected_experts, "ffn_moe_topk", il);

View File

@ -7687,6 +7687,24 @@ int32_t llama_model_meta_count(const llama_model * model) {
return (int)model->gguf_kv.size(); return (int)model->gguf_kv.size();
} }
const char * llama_model_meta_key_str(llama_model_meta_key key) {
switch (key) {
case LLAMA_MODEL_META_KEY_SAMPLING_SEQUENCE: return "general.sampling.sequence";
case LLAMA_MODEL_META_KEY_SAMPLING_TOP_K: return "general.sampling.top_k";
case LLAMA_MODEL_META_KEY_SAMPLING_TOP_P: return "general.sampling.top_p";
case LLAMA_MODEL_META_KEY_SAMPLING_MIN_P: return "general.sampling.min_p";
case LLAMA_MODEL_META_KEY_SAMPLING_XTC_PROBABILITY: return "general.sampling.xtc_probability";
case LLAMA_MODEL_META_KEY_SAMPLING_XTC_THRESHOLD: return "general.sampling.xtc_threshold";
case LLAMA_MODEL_META_KEY_SAMPLING_TEMP: return "general.sampling.temp";
case LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_LAST_N: return "general.sampling.penalty_last_n";
case LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_REPEAT: return "general.sampling.penalty_repeat";
case LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT: return "general.sampling.mirostat";
case LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_TAU: return "general.sampling.mirostat_tau";
case LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_ETA: return "general.sampling.mirostat_eta";
default: return nullptr;
}
}
int32_t llama_model_meta_key_by_index(const llama_model * model, int i, char * buf, size_t buf_size) { int32_t llama_model_meta_key_by_index(const llama_model * model, int i, char * buf, size_t buf_size) {
if (i < 0 || i >= (int)model->gguf_kv.size()) { if (i < 0 || i >= (int)model->gguf_kv.size()) {
if (buf_size > 0) { if (buf_size > 0) {

View File

@ -39,6 +39,7 @@
#include <string_view> #include <string_view>
#include <thread> #include <thread>
#include <vector> #include <vector>
#include <unordered_map>
static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) { static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) {
size_t nels = ggml_nelements(tensor); size_t nels = ggml_nelements(tensor);
@ -269,6 +270,34 @@ static double nmse(const float * a, const float * b, size_t n) {
return mse_a_b / mse_a_0; return mse_a_b / mse_a_0;
} }
// difference between 2 integer sets (Jaccard distance, 0 - no difference, 1 - no overlap)
static double jdst(const int32_t * a, const int32_t * b, size_t n) {
std::unordered_map<int32_t, size_t> set_a;
std::unordered_map<int32_t, size_t> set_b;
for (size_t i = 0; i < n; ++i) {
set_a[a[i]]++;
set_b[b[i]]++;
}
size_t diff = 0;
for (const auto & p : set_a) {
const int64_t na = p.second;
const int64_t nb = set_b.find(p.first) != set_b.end() ? set_b.at(p.first) : 0;
diff += std::abs(na - nb);
}
for (const auto & p : set_b) {
if (set_a.find(p.first) == set_a.end()) {
diff += p.second;
}
}
return (double) diff / (2*n);
}
// maximum absolute asymmetry between a and b // maximum absolute asymmetry between a and b
// asymmetry: (a - b) / (a + b) // asymmetry: (a - b) / (a + b)
// This is more stable than relative error if one of the values fluctuates towards zero. // This is more stable than relative error if one of the values fluctuates towards zero.
@ -1051,6 +1080,14 @@ struct test_case {
return 1e-4; return 1e-4;
} }
virtual double max_err() {
return max_nmse_err();
}
virtual double err(const float * a, const float * b, size_t n) {
return nmse(a, b, n);
}
virtual float grad_eps() { virtual float grad_eps() {
return 1e-1f; return 1e-1f;
} }
@ -1257,16 +1294,16 @@ struct test_case {
// compare // compare
struct callback_userdata { struct callback_userdata {
bool ok; bool ok;
double max_err; test_case * tc;
ggml_backend_t backend1; ggml_backend_t backend1;
ggml_backend_t backend2; ggml_backend_t backend2;
}; };
callback_userdata ud { callback_userdata ud {
true, true,
max_nmse_err(), this,
backend1, backend1,
backend2 backend2,
}; };
auto callback = [](int index, ggml_tensor * t1, ggml_tensor * t2, void * user_data) -> bool { auto callback = [](int index, ggml_tensor * t1, ggml_tensor * t2, void * user_data) -> bool {
@ -1314,9 +1351,9 @@ struct test_case {
} }
} }
double err = nmse(f1.data(), f2.data(), f1.size()); double err = ud->tc->err(f1.data(), f2.data(), f1.size());
if (err > ud->max_err) { if (err > ud->tc->max_err()) {
printf("[%s] NMSE = %.9f > %.9f ", ggml_op_desc(t1), err, ud->max_err); printf("[%s] ERR = %.9f > %.9f ", ggml_op_desc(t1), err, ud->tc->max_err());
//for (int i = 0; i < (int) f1.size(); i++) { //for (int i = 0; i < (int) f1.size(); i++) {
// printf("%5d %9.6f %9.6f, diff = %9.6f\n", i, f1[i], f2[i], f1[i] - f2[i]); // printf("%5d %9.6f %9.6f, diff = %9.6f\n", i, f1[i], f2[i], f1[i] - f2[i]);
//} //}
@ -4943,7 +4980,71 @@ struct test_argsort : public test_case {
} }
}; };
struct test_topk_moe: public test_case { // GGML_OP_TOP_K
struct test_top_k : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne;
const int k;
std::string vars() override {
return VARS_TO_STR3(type, ne, k);
}
test_top_k(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {16, 10, 10, 10},
int k = 4)
: type(type), ne(ne), k(k) {}
double max_err() override {
return 0.0;
}
double err(const float * a, const float * b, size_t n) override {
std::vector<int32_t> ia(n);
std::vector<int32_t> ib(n);
double diff = 0.0f;
for (size_t i = 0; i < n; i++) {
ia[i] = (int32_t) a[i];
ib[i] = (int32_t) b[i];
// penalize the result if the data is not integer valued
diff += std::fabs(a[i] - ia[i]);
diff += std::fabs(b[i] - ib[i]);
}
return diff + jdst(ia.data(), ib.data(), n);
}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_name(a, "a");
ggml_tensor * out = ggml_top_k(ctx, a, k);
ggml_set_name(out, "out");
return out;
}
void initialize_tensors(ggml_context * ctx) override {
std::random_device rd;
std::default_random_engine rng(rd());
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
// initialize with unique values to avoid ties
for (int64_t r = 0; r < ggml_nrows(t); r++) {
std::vector<float> data(t->ne[0]);
for (int i = 0; i < t->ne[0]; i++) {
data[i] = i;
}
std::shuffle(data.begin(), data.end(), rng);
ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(float));
}
}
}
};
struct test_topk_moe : public test_case {
const std::array<int64_t, 4> ne; const std::array<int64_t, 4> ne;
const int n_expert_used; const int n_expert_used;
const bool with_norm; const bool with_norm;
@ -4976,7 +5077,7 @@ struct test_topk_moe: public test_case {
ggml_tensor * logits = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne.data()); ggml_tensor * logits = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne.data());
ggml_tensor * probs = delayed_softmax ? logits : ggml_soft_max(ctx, logits); ggml_tensor * probs = delayed_softmax ? logits : ggml_soft_max(ctx, logits);
ggml_tensor * selected_experts = ggml_top_k(ctx, probs, n_expert_used); // [n_expert_used, n_tokens] ggml_tensor * selected_experts = ggml_argsort_top_k(ctx, probs, n_expert_used); // [n_expert_used, n_tokens]
ggml_tensor * out = ggml_get_rows(ctx, ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens] ggml_tensor * out = ggml_get_rows(ctx, ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
@ -7534,6 +7635,23 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2, 8, 8192, 1}, order)); // bailingmoe2 (group selection) test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2, 8, 8192, 1}, order)); // bailingmoe2 (group selection)
} }
for (int k : {1, 2, 3, 7, 15}) {
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {16, 10, 10, 10}, k));
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {60, 10, 10, 10}, k));
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {1023, 2, 1, 3}, k));
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {1024, 2, 1, 3}, k));
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {1025, 2, 1, 3}, k));
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {16384, 1, 1, 1}, k));
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {2047, 2, 1, 3}, k));
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {2048, 2, 1, 3}, k));
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {2049, 2, 1, 3}, k));
}
// exhaustive top_k tests
//for (int i = 1; i < 9999; ++i) {
// test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {i, 2, 1, 3}, rand() % i + 1));
//}
for (ggml_scale_mode mode : {GGML_SCALE_MODE_NEAREST, GGML_SCALE_MODE_BILINEAR, GGML_SCALE_MODE_BICUBIC}) { for (ggml_scale_mode mode : {GGML_SCALE_MODE_NEAREST, GGML_SCALE_MODE_BILINEAR, GGML_SCALE_MODE_BICUBIC}) {
test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, {512, 512, 3, 2}, 2, mode)); test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, {512, 512, 3, 2}, 2, mode));
test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, {512, 512, 3, 2}, 2, mode, true)); test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, {512, 512, 3, 2}, 2, mode, true));
@ -7859,6 +7977,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
} }
} }
// Qwen3-VL-8B https://github.com/ggml-org/llama.cpp/issues/17012
test_cases.emplace_back(new test_flash_attn_ext(72, 72, 16, {1, 1}, 5776, 5776, false, false, 0, 0, GGML_PREC_F32, GGML_TYPE_F16));
for (int kv : { 4096, 8192, 16384, }) { for (int kv : { 4096, 8192, 16384, }) {
for (int hs : { 64, 128, }) { for (int hs : { 64, 128, }) {
for (int nr : { 1, 4, }) { for (int nr : { 1, 4, }) {
@ -7911,6 +8032,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
} }
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {65000, 16, 1, 1})); test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {65000, 16, 1, 1}));
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {65000, 16, 1, 1}, 40));
return test_cases; return test_cases;
} }

Binary file not shown.

View File

@ -8,6 +8,7 @@
import rehypeKatex from 'rehype-katex'; import rehypeKatex from 'rehype-katex';
import rehypeStringify from 'rehype-stringify'; import rehypeStringify from 'rehype-stringify';
import { copyCodeToClipboard } from '$lib/utils/copy'; import { copyCodeToClipboard } from '$lib/utils/copy';
import { rehypeRestoreTableHtml } from '$lib/markdown/table-html-restorer';
import { preprocessLaTeX } from '$lib/utils/latex-protection'; import { preprocessLaTeX } from '$lib/utils/latex-protection';
import { browser } from '$app/environment'; import { browser } from '$app/environment';
import '$styles/katex-custom.scss'; import '$styles/katex-custom.scss';
@ -60,6 +61,7 @@
.use(remarkRehype) // Convert Markdown AST to rehype .use(remarkRehype) // Convert Markdown AST to rehype
.use(rehypeKatex) // Render math using KaTeX .use(rehypeKatex) // Render math using KaTeX
.use(rehypeHighlight) // Add syntax highlighting .use(rehypeHighlight) // Add syntax highlighting
.use(rehypeRestoreTableHtml) // Restore limited HTML (e.g., <br>, <ul>) inside Markdown tables
.use(rehypeStringify); // Convert to HTML string .use(rehypeStringify); // Convert to HTML string
}); });

View File

@ -0,0 +1,20 @@
/**
* Matches <br>, <br/>, <br /> tags (case-insensitive).
* Used to detect line breaks in table cell text content.
*/
export const BR_PATTERN = /<br\s*\/?\s*>/gi;
/**
* Matches a complete <ul>...</ul> block.
* Captures the inner content (group 1) for further <li> extraction.
* Case-insensitive, allows multiline content.
*/
export const LIST_PATTERN = /^<ul>([\s\S]*)<\/ul>$/i;
/**
* Matches individual <li>...</li> elements within a list.
* Captures the inner content (group 1) of each list item.
* Non-greedy to handle multiple consecutive items.
* Case-insensitive, allows multiline content.
*/
export const LI_PATTERN = /<li>([\s\S]*?)<\/li>/gi;

View File

@ -0,0 +1,181 @@
/**
* Rehype plugin to restore limited HTML elements inside Markdown table cells.
*
* ## Problem
* The remark/rehype pipeline neutralizes inline HTML as literal text
* (remarkLiteralHtml) so that XML/HTML snippets in LLM responses display
* as-is instead of being rendered. This causes <br> and <ul> markup in
* table cells to show as plain text.
*
* ## Solution
* This plugin traverses the HAST post-conversion, parses whitelisted HTML
* patterns from text nodes, and replaces them with actual HAST element nodes
* that will be rendered as real HTML.
*
* ## Supported HTML
* - `<br>` / `<br/>` / `<br />` - Line breaks (inline)
* - `<ul><li>...</li></ul>` - Unordered lists (block)
*
* ## Key Implementation Details
*
* ### 1. Sibling Combination (Critical)
* The Markdown pipeline may fragment content across multiple text nodes and `<br>`
* elements. For example, `<ul><li>a</li></ul>` might arrive as:
* - Text: `"<ul>"`
* - Element: `<br>`
* - Text: `"<li>a</li></ul>"`
*
* We must combine consecutive text nodes and `<br>` elements into a single string
* before attempting to parse list markup. Without this, list detection fails.
*
* ### 2. visitParents for Deep Traversal
* Table cell content may be wrapped in intermediate elements (e.g., `<p>` tags).
* Using `visitParents` instead of direct child iteration ensures we find text
* nodes at any depth within the cell.
*
* ### 3. Reference Comparison for No-Op Detection
* When checking if `<br>` expansion changed anything, we compare:
* `expanded.length !== 1 || expanded[0] !== textNode`
*
* This catches both cases:
* - Multiple nodes created (text was split)
* - Single NEW node created (original had only `<br>`, now it's an element)
*
* A simple `length > 1` check would miss the single `<br>` case.
*
* ### 4. Strict List Validation
* `parseList()` rejects malformed markup by checking for garbage text between
* `<li>` elements. This prevents creating broken DOM from partial matches like
* `<ul>garbage<li>a</li></ul>`.
*
* ### 5. Newline Substitution for `<br>` in Combined String
* When combining siblings, existing `<br>` elements become `\n` in the combined
* string. This allows list content to span visual lines while still being parsed
* as a single unit.
*
* @example
* // Input Markdown:
* // | Feature | Notes |
* // |---------|-------|
* // | Multi-line | First<br>Second |
* // | List | <ul><li>A</li><li>B</li></ul> |
* //
* // Without this plugin: <br> and <ul> render as literal text
* // With this plugin: <br> becomes line break, <ul> becomes actual list
*/
import type { Plugin } from 'unified';
import type { Element, ElementContent, Root, Text } from 'hast';
import { visit } from 'unist-util-visit';
import { visitParents } from 'unist-util-visit-parents';
import { BR_PATTERN, LIST_PATTERN, LI_PATTERN } from '$lib/constants/table-html-restorer';
/**
* Expands text containing `<br>` tags into an array of text nodes and br elements.
*/
function expandBrTags(value: string): ElementContent[] {
const matches = [...value.matchAll(BR_PATTERN)];
if (!matches.length) return [{ type: 'text', value } as Text];
const result: ElementContent[] = [];
let cursor = 0;
for (const m of matches) {
if (m.index! > cursor) {
result.push({ type: 'text', value: value.slice(cursor, m.index) } as Text);
}
result.push({ type: 'element', tagName: 'br', properties: {}, children: [] } as Element);
cursor = m.index! + m[0].length;
}
if (cursor < value.length) {
result.push({ type: 'text', value: value.slice(cursor) } as Text);
}
return result;
}
/**
* Parses a `<ul><li>...</li></ul>` string into a HAST element.
* Returns null if the markup is malformed or contains unexpected content.
*/
function parseList(value: string): Element | null {
const match = value.trim().match(LIST_PATTERN);
if (!match) return null;
const body = match[1];
const items: ElementContent[] = [];
let cursor = 0;
for (const liMatch of body.matchAll(LI_PATTERN)) {
// Reject if there's non-whitespace between list items
if (body.slice(cursor, liMatch.index!).trim()) return null;
items.push({
type: 'element',
tagName: 'li',
properties: {},
children: expandBrTags(liMatch[1] ?? '')
} as Element);
cursor = liMatch.index! + liMatch[0].length;
}
// Reject if no items found or trailing garbage exists
if (!items.length || body.slice(cursor).trim()) return null;
return { type: 'element', tagName: 'ul', properties: {}, children: items } as Element;
}
/**
* Processes a single table cell, restoring HTML elements from text content.
*/
function processCell(cell: Element) {
visitParents(cell, 'text', (textNode: Text, ancestors) => {
const parent = ancestors[ancestors.length - 1];
if (!parent || parent.type !== 'element') return;
const parentEl = parent as Element;
const siblings = parentEl.children as ElementContent[];
const startIndex = siblings.indexOf(textNode as ElementContent);
if (startIndex === -1) return;
// Combine consecutive text nodes and <br> elements into one string
let combined = '';
let endIndex = startIndex;
for (let i = startIndex; i < siblings.length; i++) {
const sib = siblings[i];
if (sib.type === 'text') {
combined += (sib as Text).value;
endIndex = i;
} else if (sib.type === 'element' && (sib as Element).tagName === 'br') {
combined += '\n';
endIndex = i;
} else {
break;
}
}
// Try parsing as list first (replaces entire combined range)
const list = parseList(combined);
if (list) {
siblings.splice(startIndex, endIndex - startIndex + 1, list);
return;
}
// Otherwise, just expand <br> tags in this text node
const expanded = expandBrTags(textNode.value);
if (expanded.length !== 1 || expanded[0] !== textNode) {
siblings.splice(startIndex, 1, ...expanded);
}
});
}
export const rehypeRestoreTableHtml: Plugin<[], Root> = () => (tree) => {
visit(tree, 'element', (node: Element) => {
if (node.tagName === 'td' || node.tagName === 'th') {
processCell(node);
}
});
};

View File

@ -31,13 +31,16 @@ if (LLAMA_BUILD_BORINGSSL)
message(STATUS "Fetching BoringSSL version ${BORINGSSL_VERSION}") message(STATUS "Fetching BoringSSL version ${BORINGSSL_VERSION}")
include(FetchContent) set(BORINGSSL_ARGS
FetchContent_Declare(
boringssl
GIT_REPOSITORY ${BORINGSSL_GIT} GIT_REPOSITORY ${BORINGSSL_GIT}
GIT_TAG ${BORINGSSL_VERSION} GIT_TAG ${BORINGSSL_VERSION}
PATCH_COMMAND ${CMAKE_COMMAND} -P "${CMAKE_CURRENT_SOURCE_DIR}/patch-boringssl.cmake"
) )
if(CMAKE_VERSION VERSION_GREATER_EQUAL 3.28)
list(APPEND BORINGSSL_ARGS EXCLUDE_FROM_ALL)
endif()
include(FetchContent)
FetchContent_Declare(boringssl ${BORINGSSL_ARGS})
set(SAVED_BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS}) set(SAVED_BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS})
set(SAVED_BUILD_TESTING ${BUILD_TESTING}) set(SAVED_BUILD_TESTING ${BUILD_TESTING})
@ -45,7 +48,15 @@ if (LLAMA_BUILD_BORINGSSL)
set(BUILD_SHARED_LIBS OFF) set(BUILD_SHARED_LIBS OFF)
set(BUILD_TESTING OFF) set(BUILD_TESTING OFF)
FetchContent_MakeAvailable(boringssl) if(CMAKE_VERSION VERSION_GREATER_EQUAL 3.28)
FetchContent_MakeAvailable(boringssl)
else()
FetchContent_GetProperties(boringssl)
if(NOT boringssl_POPULATED)
FetchContent_Populate(boringssl)
add_subdirectory(${boringssl_SOURCE_DIR} ${boringssl_BINARY_DIR} EXCLUDE_FROM_ALL)
endif()
endif()
set(BUILD_SHARED_LIBS ${SAVED_BUILD_SHARED_LIBS}) set(BUILD_SHARED_LIBS ${SAVED_BUILD_SHARED_LIBS})
set(BUILD_TESTING ${SAVED_BUILD_TESTING}) set(BUILD_TESTING ${SAVED_BUILD_TESTING})

View File

@ -1087,22 +1087,30 @@ int getaddrinfo_with_timeout(const char *node, const char *service,
// Fallback implementation using thread-based timeout for other Unix systems // Fallback implementation using thread-based timeout for other Unix systems
struct GetAddrInfoState { struct GetAddrInfoState {
~GetAddrInfoState() {
if (info) { freeaddrinfo(info); }
}
std::mutex mutex; std::mutex mutex;
std::condition_variable result_cv; std::condition_variable result_cv;
bool completed = false; bool completed = false;
int result = EAI_SYSTEM; int result = EAI_SYSTEM;
std::string node = node; std::string node;
std::string service = service; std::string service;
struct addrinfo hints = hints; struct addrinfo hints;
struct addrinfo *info = nullptr; struct addrinfo *info = nullptr;
}; };
// Allocate on the heap, so the resolver thread can keep using the data. // Allocate on the heap, so the resolver thread can keep using the data.
auto state = std::make_shared<GetAddrInfoState>(); auto state = std::make_shared<GetAddrInfoState>();
state->node = node;
state->service = service;
state->hints = *hints;
std::thread resolve_thread([=]() { std::thread resolve_thread([state]() {
auto thread_result = getaddrinfo( auto thread_result =
state->node.c_str(), state->service.c_str(), hints, &state->info); getaddrinfo(state->node.c_str(), state->service.c_str(), &state->hints,
&state->info);
std::lock_guard<std::mutex> lock(state->mutex); std::lock_guard<std::mutex> lock(state->mutex);
state->result = thread_result; state->result = thread_result;
@ -1120,6 +1128,7 @@ int getaddrinfo_with_timeout(const char *node, const char *service,
// Operation completed within timeout // Operation completed within timeout
resolve_thread.join(); resolve_thread.join();
*res = state->info; *res = state->info;
state->info = nullptr; // Pass ownership to caller
return state->result; return state->result;
} else { } else {
// Timeout occurred // Timeout occurred
@ -4970,7 +4979,8 @@ bool Server::write_response_core(Stream &strm, bool close_connection,
if (need_apply_ranges) { apply_ranges(req, res, content_type, boundary); } if (need_apply_ranges) { apply_ranges(req, res, content_type, boundary); }
// Prepare additional headers // Prepare additional headers
if (close_connection || req.get_header_value("Connection") == "close") { if (close_connection || req.get_header_value("Connection") == "close" ||
400 <= res.status) { // Don't leave connections open after errors
res.set_header("Connection", "close"); res.set_header("Connection", "close");
} else { } else {
std::string s = "timeout="; std::string s = "timeout=";
@ -5173,7 +5183,11 @@ bool Server::read_content_core(
size_t /*len*/) { return receiver(buf, n); }; size_t /*len*/) { return receiver(buf, n); };
} }
if (req.method == "DELETE" && !req.has_header("Content-Length")) { // RFC 7230 Section 3.3.3: If this is a request message and none of the above
// are true (no Transfer-Encoding and no Content-Length), then the message
// body length is zero (no message body is present).
if (!req.has_header("Content-Length") &&
!detail::is_chunked_transfer_encoding(req.headers)) {
return true; return true;
} }
@ -5681,8 +5695,6 @@ Server::process_request(Stream &strm, const std::string &remote_addr,
// Check if the request URI doesn't exceed the limit // Check if the request URI doesn't exceed the limit
if (req.target.size() > CPPHTTPLIB_REQUEST_URI_MAX_LENGTH) { if (req.target.size() > CPPHTTPLIB_REQUEST_URI_MAX_LENGTH) {
Headers dummy;
detail::read_headers(strm, dummy);
res.status = StatusCode::UriTooLong_414; res.status = StatusCode::UriTooLong_414;
output_error_log(Error::ExceedUriMaxLength, &req); output_error_log(Error::ExceedUriMaxLength, &req);
return write_response(strm, close_connection, req, res); return write_response(strm, close_connection, req, res);
@ -6666,11 +6678,13 @@ bool ClientImpl::write_request(Stream &strm, Request &req,
return true; return true;
} }
std::unique_ptr<Response> ClientImpl::send_with_content_provider( std::unique_ptr<Response>
ClientImpl::send_with_content_provider_and_receiver(
Request &req, const char *body, size_t content_length, Request &req, const char *body, size_t content_length,
ContentProvider content_provider, ContentProvider content_provider,
ContentProviderWithoutLength content_provider_without_length, ContentProviderWithoutLength content_provider_without_length,
const std::string &content_type, Error &error) { const std::string &content_type, ContentReceiver content_receiver,
Error &error) {
if (!content_type.empty()) { req.set_header("Content-Type", content_type); } if (!content_type.empty()) { req.set_header("Content-Type", content_type); }
#ifdef CPPHTTPLIB_ZLIB_SUPPORT #ifdef CPPHTTPLIB_ZLIB_SUPPORT
@ -6743,15 +6757,24 @@ std::unique_ptr<Response> ClientImpl::send_with_content_provider(
} }
} }
if (content_receiver) {
req.content_receiver =
[content_receiver](const char *data, size_t data_length,
size_t /*offset*/, size_t /*total_length*/) {
return content_receiver(data, data_length);
};
}
auto res = detail::make_unique<Response>(); auto res = detail::make_unique<Response>();
return send(req, *res, error) ? std::move(res) : nullptr; return send(req, *res, error) ? std::move(res) : nullptr;
} }
Result ClientImpl::send_with_content_provider( Result ClientImpl::send_with_content_provider_and_receiver(
const std::string &method, const std::string &path, const Headers &headers, const std::string &method, const std::string &path, const Headers &headers,
const char *body, size_t content_length, ContentProvider content_provider, const char *body, size_t content_length, ContentProvider content_provider,
ContentProviderWithoutLength content_provider_without_length, ContentProviderWithoutLength content_provider_without_length,
const std::string &content_type, UploadProgress progress) { const std::string &content_type, ContentReceiver content_receiver,
UploadProgress progress) {
Request req; Request req;
req.method = method; req.method = method;
req.headers = headers; req.headers = headers;
@ -6763,9 +6786,10 @@ Result ClientImpl::send_with_content_provider(
auto error = Error::Success; auto error = Error::Success;
auto res = send_with_content_provider( auto res = send_with_content_provider_and_receiver(
req, body, content_length, std::move(content_provider), req, body, content_length, std::move(content_provider),
std::move(content_provider_without_length), content_type, error); std::move(content_provider_without_length), content_type,
std::move(content_receiver), error);
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
return Result{std::move(res), error, std::move(req.headers), last_ssl_error_, return Result{std::move(res), error, std::move(req.headers), last_ssl_error_,
@ -7094,6 +7118,15 @@ Result ClientImpl::Post(const std::string &path, size_t content_length,
content_type, progress); content_type, progress);
} }
Result ClientImpl::Post(const std::string &path, size_t content_length,
ContentProvider content_provider,
const std::string &content_type,
ContentReceiver content_receiver,
UploadProgress progress) {
return Post(path, Headers(), content_length, std::move(content_provider),
content_type, std::move(content_receiver), progress);
}
Result ClientImpl::Post(const std::string &path, Result ClientImpl::Post(const std::string &path,
ContentProviderWithoutLength content_provider, ContentProviderWithoutLength content_provider,
const std::string &content_type, const std::string &content_type,
@ -7102,6 +7135,15 @@ Result ClientImpl::Post(const std::string &path,
progress); progress);
} }
Result ClientImpl::Post(const std::string &path,
ContentProviderWithoutLength content_provider,
const std::string &content_type,
ContentReceiver content_receiver,
UploadProgress progress) {
return Post(path, Headers(), std::move(content_provider), content_type,
std::move(content_receiver), progress);
}
Result ClientImpl::Post(const std::string &path, const Headers &headers, Result ClientImpl::Post(const std::string &path, const Headers &headers,
const Params &params) { const Params &params) {
auto query = detail::params_to_query_str(params); auto query = detail::params_to_query_str(params);
@ -7142,17 +7184,18 @@ Result ClientImpl::Post(const std::string &path, const Headers &headers,
const char *body, size_t content_length, const char *body, size_t content_length,
const std::string &content_type, const std::string &content_type,
UploadProgress progress) { UploadProgress progress) {
return send_with_content_provider("POST", path, headers, body, content_length, return send_with_content_provider_and_receiver(
nullptr, nullptr, content_type, progress); "POST", path, headers, body, content_length, nullptr, nullptr,
content_type, nullptr, progress);
} }
Result ClientImpl::Post(const std::string &path, const Headers &headers, Result ClientImpl::Post(const std::string &path, const Headers &headers,
const std::string &body, const std::string &body,
const std::string &content_type, const std::string &content_type,
UploadProgress progress) { UploadProgress progress) {
return send_with_content_provider("POST", path, headers, body.data(), return send_with_content_provider_and_receiver(
body.size(), nullptr, nullptr, content_type, "POST", path, headers, body.data(), body.size(), nullptr, nullptr,
progress); content_type, nullptr, progress);
} }
Result ClientImpl::Post(const std::string &path, const Headers &headers, Result ClientImpl::Post(const std::string &path, const Headers &headers,
@ -7160,18 +7203,40 @@ Result ClientImpl::Post(const std::string &path, const Headers &headers,
ContentProvider content_provider, ContentProvider content_provider,
const std::string &content_type, const std::string &content_type,
UploadProgress progress) { UploadProgress progress) {
return send_with_content_provider("POST", path, headers, nullptr, return send_with_content_provider_and_receiver(
content_length, std::move(content_provider), "POST", path, headers, nullptr, content_length,
nullptr, content_type, progress); std::move(content_provider), nullptr, content_type, nullptr, progress);
}
Result ClientImpl::Post(const std::string &path, const Headers &headers,
size_t content_length,
ContentProvider content_provider,
const std::string &content_type,
ContentReceiver content_receiver,
DownloadProgress progress) {
return send_with_content_provider_and_receiver(
"POST", path, headers, nullptr, content_length,
std::move(content_provider), nullptr, content_type,
std::move(content_receiver), std::move(progress));
} }
Result ClientImpl::Post(const std::string &path, const Headers &headers, Result ClientImpl::Post(const std::string &path, const Headers &headers,
ContentProviderWithoutLength content_provider, ContentProviderWithoutLength content_provider,
const std::string &content_type, const std::string &content_type,
UploadProgress progress) { UploadProgress progress) {
return send_with_content_provider("POST", path, headers, nullptr, 0, nullptr, return send_with_content_provider_and_receiver(
std::move(content_provider), content_type, "POST", path, headers, nullptr, 0, nullptr, std::move(content_provider),
progress); content_type, nullptr, progress);
}
Result ClientImpl::Post(const std::string &path, const Headers &headers,
ContentProviderWithoutLength content_provider,
const std::string &content_type,
ContentReceiver content_receiver,
DownloadProgress progress) {
return send_with_content_provider_and_receiver(
"POST", path, headers, nullptr, 0, nullptr, std::move(content_provider),
content_type, std::move(content_receiver), std::move(progress));
} }
Result ClientImpl::Post(const std::string &path, const Headers &headers, Result ClientImpl::Post(const std::string &path, const Headers &headers,
@ -7181,10 +7246,10 @@ Result ClientImpl::Post(const std::string &path, const Headers &headers,
const auto &boundary = detail::make_multipart_data_boundary(); const auto &boundary = detail::make_multipart_data_boundary();
const auto &content_type = const auto &content_type =
detail::serialize_multipart_formdata_get_content_type(boundary); detail::serialize_multipart_formdata_get_content_type(boundary);
return send_with_content_provider( return send_with_content_provider_and_receiver(
"POST", path, headers, nullptr, 0, nullptr, "POST", path, headers, nullptr, 0, nullptr,
get_multipart_content_provider(boundary, items, provider_items), get_multipart_content_provider(boundary, items, provider_items),
content_type, progress); content_type, nullptr, progress);
} }
Result ClientImpl::Post(const std::string &path, const Headers &headers, Result ClientImpl::Post(const std::string &path, const Headers &headers,
@ -7246,6 +7311,15 @@ Result ClientImpl::Put(const std::string &path, size_t content_length,
content_type, progress); content_type, progress);
} }
Result ClientImpl::Put(const std::string &path, size_t content_length,
ContentProvider content_provider,
const std::string &content_type,
ContentReceiver content_receiver,
UploadProgress progress) {
return Put(path, Headers(), content_length, std::move(content_provider),
content_type, std::move(content_receiver), progress);
}
Result ClientImpl::Put(const std::string &path, Result ClientImpl::Put(const std::string &path,
ContentProviderWithoutLength content_provider, ContentProviderWithoutLength content_provider,
const std::string &content_type, const std::string &content_type,
@ -7254,6 +7328,15 @@ Result ClientImpl::Put(const std::string &path,
progress); progress);
} }
Result ClientImpl::Put(const std::string &path,
ContentProviderWithoutLength content_provider,
const std::string &content_type,
ContentReceiver content_receiver,
UploadProgress progress) {
return Put(path, Headers(), std::move(content_provider), content_type,
std::move(content_receiver), progress);
}
Result ClientImpl::Put(const std::string &path, const Headers &headers, Result ClientImpl::Put(const std::string &path, const Headers &headers,
const Params &params) { const Params &params) {
auto query = detail::params_to_query_str(params); auto query = detail::params_to_query_str(params);
@ -7294,17 +7377,18 @@ Result ClientImpl::Put(const std::string &path, const Headers &headers,
const char *body, size_t content_length, const char *body, size_t content_length,
const std::string &content_type, const std::string &content_type,
UploadProgress progress) { UploadProgress progress) {
return send_with_content_provider("PUT", path, headers, body, content_length, return send_with_content_provider_and_receiver(
nullptr, nullptr, content_type, progress); "PUT", path, headers, body, content_length, nullptr, nullptr,
content_type, nullptr, progress);
} }
Result ClientImpl::Put(const std::string &path, const Headers &headers, Result ClientImpl::Put(const std::string &path, const Headers &headers,
const std::string &body, const std::string &body,
const std::string &content_type, const std::string &content_type,
UploadProgress progress) { UploadProgress progress) {
return send_with_content_provider("PUT", path, headers, body.data(), return send_with_content_provider_and_receiver(
body.size(), nullptr, nullptr, content_type, "PUT", path, headers, body.data(), body.size(), nullptr, nullptr,
progress); content_type, nullptr, progress);
} }
Result ClientImpl::Put(const std::string &path, const Headers &headers, Result ClientImpl::Put(const std::string &path, const Headers &headers,
@ -7312,18 +7396,40 @@ Result ClientImpl::Put(const std::string &path, const Headers &headers,
ContentProvider content_provider, ContentProvider content_provider,
const std::string &content_type, const std::string &content_type,
UploadProgress progress) { UploadProgress progress) {
return send_with_content_provider("PUT", path, headers, nullptr, return send_with_content_provider_and_receiver(
content_length, std::move(content_provider), "PUT", path, headers, nullptr, content_length,
nullptr, content_type, progress); std::move(content_provider), nullptr, content_type, nullptr, progress);
}
Result ClientImpl::Put(const std::string &path, const Headers &headers,
size_t content_length,
ContentProvider content_provider,
const std::string &content_type,
ContentReceiver content_receiver,
UploadProgress progress) {
return send_with_content_provider_and_receiver(
"PUT", path, headers, nullptr, content_length,
std::move(content_provider), nullptr, content_type,
std::move(content_receiver), progress);
} }
Result ClientImpl::Put(const std::string &path, const Headers &headers, Result ClientImpl::Put(const std::string &path, const Headers &headers,
ContentProviderWithoutLength content_provider, ContentProviderWithoutLength content_provider,
const std::string &content_type, const std::string &content_type,
UploadProgress progress) { UploadProgress progress) {
return send_with_content_provider("PUT", path, headers, nullptr, 0, nullptr, return send_with_content_provider_and_receiver(
std::move(content_provider), content_type, "PUT", path, headers, nullptr, 0, nullptr, std::move(content_provider),
progress); content_type, nullptr, progress);
}
Result ClientImpl::Put(const std::string &path, const Headers &headers,
ContentProviderWithoutLength content_provider,
const std::string &content_type,
ContentReceiver content_receiver,
UploadProgress progress) {
return send_with_content_provider_and_receiver(
"PUT", path, headers, nullptr, 0, nullptr, std::move(content_provider),
content_type, std::move(content_receiver), progress);
} }
Result ClientImpl::Put(const std::string &path, const Headers &headers, Result ClientImpl::Put(const std::string &path, const Headers &headers,
@ -7333,10 +7439,10 @@ Result ClientImpl::Put(const std::string &path, const Headers &headers,
const auto &boundary = detail::make_multipart_data_boundary(); const auto &boundary = detail::make_multipart_data_boundary();
const auto &content_type = const auto &content_type =
detail::serialize_multipart_formdata_get_content_type(boundary); detail::serialize_multipart_formdata_get_content_type(boundary);
return send_with_content_provider( return send_with_content_provider_and_receiver(
"PUT", path, headers, nullptr, 0, nullptr, "PUT", path, headers, nullptr, 0, nullptr,
get_multipart_content_provider(boundary, items, provider_items), get_multipart_content_provider(boundary, items, provider_items),
content_type, progress); content_type, nullptr, progress);
} }
Result ClientImpl::Put(const std::string &path, const Headers &headers, Result ClientImpl::Put(const std::string &path, const Headers &headers,
@ -7400,6 +7506,15 @@ Result ClientImpl::Patch(const std::string &path, size_t content_length,
content_type, progress); content_type, progress);
} }
Result ClientImpl::Patch(const std::string &path, size_t content_length,
ContentProvider content_provider,
const std::string &content_type,
ContentReceiver content_receiver,
UploadProgress progress) {
return Patch(path, Headers(), content_length, std::move(content_provider),
content_type, std::move(content_receiver), progress);
}
Result ClientImpl::Patch(const std::string &path, Result ClientImpl::Patch(const std::string &path,
ContentProviderWithoutLength content_provider, ContentProviderWithoutLength content_provider,
const std::string &content_type, const std::string &content_type,
@ -7408,6 +7523,15 @@ Result ClientImpl::Patch(const std::string &path,
progress); progress);
} }
Result ClientImpl::Patch(const std::string &path,
ContentProviderWithoutLength content_provider,
const std::string &content_type,
ContentReceiver content_receiver,
UploadProgress progress) {
return Patch(path, Headers(), std::move(content_provider), content_type,
std::move(content_receiver), progress);
}
Result ClientImpl::Patch(const std::string &path, const Headers &headers, Result ClientImpl::Patch(const std::string &path, const Headers &headers,
const Params &params) { const Params &params) {
auto query = detail::params_to_query_str(params); auto query = detail::params_to_query_str(params);
@ -7448,18 +7572,18 @@ Result ClientImpl::Patch(const std::string &path, const Headers &headers,
const char *body, size_t content_length, const char *body, size_t content_length,
const std::string &content_type, const std::string &content_type,
UploadProgress progress) { UploadProgress progress) {
return send_with_content_provider("PATCH", path, headers, body, return send_with_content_provider_and_receiver(
content_length, nullptr, nullptr, "PATCH", path, headers, body, content_length, nullptr, nullptr,
content_type, progress); content_type, nullptr, progress);
} }
Result ClientImpl::Patch(const std::string &path, const Headers &headers, Result ClientImpl::Patch(const std::string &path, const Headers &headers,
const std::string &body, const std::string &body,
const std::string &content_type, const std::string &content_type,
UploadProgress progress) { UploadProgress progress) {
return send_with_content_provider("PATCH", path, headers, body.data(), return send_with_content_provider_and_receiver(
body.size(), nullptr, nullptr, content_type, "PATCH", path, headers, body.data(), body.size(), nullptr, nullptr,
progress); content_type, nullptr, progress);
} }
Result ClientImpl::Patch(const std::string &path, const Headers &headers, Result ClientImpl::Patch(const std::string &path, const Headers &headers,
@ -7467,18 +7591,40 @@ Result ClientImpl::Patch(const std::string &path, const Headers &headers,
ContentProvider content_provider, ContentProvider content_provider,
const std::string &content_type, const std::string &content_type,
UploadProgress progress) { UploadProgress progress) {
return send_with_content_provider("PATCH", path, headers, nullptr, return send_with_content_provider_and_receiver(
content_length, std::move(content_provider), "PATCH", path, headers, nullptr, content_length,
nullptr, content_type, progress); std::move(content_provider), nullptr, content_type, nullptr, progress);
}
Result ClientImpl::Patch(const std::string &path, const Headers &headers,
size_t content_length,
ContentProvider content_provider,
const std::string &content_type,
ContentReceiver content_receiver,
UploadProgress progress) {
return send_with_content_provider_and_receiver(
"PATCH", path, headers, nullptr, content_length,
std::move(content_provider), nullptr, content_type,
std::move(content_receiver), progress);
} }
Result ClientImpl::Patch(const std::string &path, const Headers &headers, Result ClientImpl::Patch(const std::string &path, const Headers &headers,
ContentProviderWithoutLength content_provider, ContentProviderWithoutLength content_provider,
const std::string &content_type, const std::string &content_type,
UploadProgress progress) { UploadProgress progress) {
return send_with_content_provider("PATCH", path, headers, nullptr, 0, nullptr, return send_with_content_provider_and_receiver(
std::move(content_provider), content_type, "PATCH", path, headers, nullptr, 0, nullptr, std::move(content_provider),
progress); content_type, nullptr, progress);
}
Result ClientImpl::Patch(const std::string &path, const Headers &headers,
ContentProviderWithoutLength content_provider,
const std::string &content_type,
ContentReceiver content_receiver,
UploadProgress progress) {
return send_with_content_provider_and_receiver(
"PATCH", path, headers, nullptr, 0, nullptr, std::move(content_provider),
content_type, std::move(content_receiver), progress);
} }
Result ClientImpl::Patch(const std::string &path, const Headers &headers, Result ClientImpl::Patch(const std::string &path, const Headers &headers,
@ -7488,10 +7634,10 @@ Result ClientImpl::Patch(const std::string &path, const Headers &headers,
const auto &boundary = detail::make_multipart_data_boundary(); const auto &boundary = detail::make_multipart_data_boundary();
const auto &content_type = const auto &content_type =
detail::serialize_multipart_formdata_get_content_type(boundary); detail::serialize_multipart_formdata_get_content_type(boundary);
return send_with_content_provider( return send_with_content_provider_and_receiver(
"PATCH", path, headers, nullptr, 0, nullptr, "PATCH", path, headers, nullptr, 0, nullptr,
get_multipart_content_provider(boundary, items, provider_items), get_multipart_content_provider(boundary, items, provider_items),
content_type, progress); content_type, nullptr, progress);
} }
Result ClientImpl::Patch(const std::string &path, const Headers &headers, Result ClientImpl::Patch(const std::string &path, const Headers &headers,
@ -8883,12 +9029,28 @@ Result Client::Post(const std::string &path, size_t content_length,
return cli_->Post(path, content_length, std::move(content_provider), return cli_->Post(path, content_length, std::move(content_provider),
content_type, progress); content_type, progress);
} }
Result Client::Post(const std::string &path, size_t content_length,
ContentProvider content_provider,
const std::string &content_type,
ContentReceiver content_receiver,
UploadProgress progress) {
return cli_->Post(path, content_length, std::move(content_provider),
content_type, std::move(content_receiver), progress);
}
Result Client::Post(const std::string &path, Result Client::Post(const std::string &path,
ContentProviderWithoutLength content_provider, ContentProviderWithoutLength content_provider,
const std::string &content_type, const std::string &content_type,
UploadProgress progress) { UploadProgress progress) {
return cli_->Post(path, std::move(content_provider), content_type, progress); return cli_->Post(path, std::move(content_provider), content_type, progress);
} }
Result Client::Post(const std::string &path,
ContentProviderWithoutLength content_provider,
const std::string &content_type,
ContentReceiver content_receiver,
UploadProgress progress) {
return cli_->Post(path, std::move(content_provider), content_type,
std::move(content_receiver), progress);
}
Result Client::Post(const std::string &path, const Headers &headers, Result Client::Post(const std::string &path, const Headers &headers,
size_t content_length, size_t content_length,
ContentProvider content_provider, ContentProvider content_provider,
@ -8897,6 +9059,15 @@ Result Client::Post(const std::string &path, const Headers &headers,
return cli_->Post(path, headers, content_length, std::move(content_provider), return cli_->Post(path, headers, content_length, std::move(content_provider),
content_type, progress); content_type, progress);
} }
Result Client::Post(const std::string &path, const Headers &headers,
size_t content_length,
ContentProvider content_provider,
const std::string &content_type,
ContentReceiver content_receiver,
DownloadProgress progress) {
return cli_->Post(path, headers, content_length, std::move(content_provider),
content_type, std::move(content_receiver), progress);
}
Result Client::Post(const std::string &path, const Headers &headers, Result Client::Post(const std::string &path, const Headers &headers,
ContentProviderWithoutLength content_provider, ContentProviderWithoutLength content_provider,
const std::string &content_type, const std::string &content_type,
@ -8904,6 +9075,14 @@ Result Client::Post(const std::string &path, const Headers &headers,
return cli_->Post(path, headers, std::move(content_provider), content_type, return cli_->Post(path, headers, std::move(content_provider), content_type,
progress); progress);
} }
Result Client::Post(const std::string &path, const Headers &headers,
ContentProviderWithoutLength content_provider,
const std::string &content_type,
ContentReceiver content_receiver,
DownloadProgress progress) {
return cli_->Post(path, headers, std::move(content_provider), content_type,
std::move(content_receiver), progress);
}
Result Client::Post(const std::string &path, const Params &params) { Result Client::Post(const std::string &path, const Params &params) {
return cli_->Post(path, params); return cli_->Post(path, params);
} }
@ -8938,8 +9117,8 @@ Result Client::Post(const std::string &path, const Headers &headers,
const std::string &content_type, const std::string &content_type,
ContentReceiver content_receiver, ContentReceiver content_receiver,
DownloadProgress progress) { DownloadProgress progress) {
return cli_->Post(path, headers, body, content_type, content_receiver, return cli_->Post(path, headers, body, content_type,
progress); std::move(content_receiver), progress);
} }
Result Client::Put(const std::string &path) { return cli_->Put(path); } Result Client::Put(const std::string &path) { return cli_->Put(path); }
@ -8976,12 +9155,28 @@ Result Client::Put(const std::string &path, size_t content_length,
return cli_->Put(path, content_length, std::move(content_provider), return cli_->Put(path, content_length, std::move(content_provider),
content_type, progress); content_type, progress);
} }
Result Client::Put(const std::string &path, size_t content_length,
ContentProvider content_provider,
const std::string &content_type,
ContentReceiver content_receiver,
UploadProgress progress) {
return cli_->Put(path, content_length, std::move(content_provider),
content_type, std::move(content_receiver), progress);
}
Result Client::Put(const std::string &path, Result Client::Put(const std::string &path,
ContentProviderWithoutLength content_provider, ContentProviderWithoutLength content_provider,
const std::string &content_type, const std::string &content_type,
UploadProgress progress) { UploadProgress progress) {
return cli_->Put(path, std::move(content_provider), content_type, progress); return cli_->Put(path, std::move(content_provider), content_type, progress);
} }
Result Client::Put(const std::string &path,
ContentProviderWithoutLength content_provider,
const std::string &content_type,
ContentReceiver content_receiver,
UploadProgress progress) {
return cli_->Put(path, std::move(content_provider), content_type,
std::move(content_receiver), progress);
}
Result Client::Put(const std::string &path, const Headers &headers, Result Client::Put(const std::string &path, const Headers &headers,
size_t content_length, size_t content_length,
ContentProvider content_provider, ContentProvider content_provider,
@ -8990,6 +9185,15 @@ Result Client::Put(const std::string &path, const Headers &headers,
return cli_->Put(path, headers, content_length, std::move(content_provider), return cli_->Put(path, headers, content_length, std::move(content_provider),
content_type, progress); content_type, progress);
} }
Result Client::Put(const std::string &path, const Headers &headers,
size_t content_length,
ContentProvider content_provider,
const std::string &content_type,
ContentReceiver content_receiver,
UploadProgress progress) {
return cli_->Put(path, headers, content_length, std::move(content_provider),
content_type, std::move(content_receiver), progress);
}
Result Client::Put(const std::string &path, const Headers &headers, Result Client::Put(const std::string &path, const Headers &headers,
ContentProviderWithoutLength content_provider, ContentProviderWithoutLength content_provider,
const std::string &content_type, const std::string &content_type,
@ -8997,6 +9201,14 @@ Result Client::Put(const std::string &path, const Headers &headers,
return cli_->Put(path, headers, std::move(content_provider), content_type, return cli_->Put(path, headers, std::move(content_provider), content_type,
progress); progress);
} }
Result Client::Put(const std::string &path, const Headers &headers,
ContentProviderWithoutLength content_provider,
const std::string &content_type,
ContentReceiver content_receiver,
UploadProgress progress) {
return cli_->Put(path, headers, std::move(content_provider), content_type,
std::move(content_receiver), progress);
}
Result Client::Put(const std::string &path, const Params &params) { Result Client::Put(const std::string &path, const Params &params) {
return cli_->Put(path, params); return cli_->Put(path, params);
} }
@ -9072,12 +9284,28 @@ Result Client::Patch(const std::string &path, size_t content_length,
return cli_->Patch(path, content_length, std::move(content_provider), return cli_->Patch(path, content_length, std::move(content_provider),
content_type, progress); content_type, progress);
} }
Result Client::Patch(const std::string &path, size_t content_length,
ContentProvider content_provider,
const std::string &content_type,
ContentReceiver content_receiver,
UploadProgress progress) {
return cli_->Patch(path, content_length, std::move(content_provider),
content_type, std::move(content_receiver), progress);
}
Result Client::Patch(const std::string &path, Result Client::Patch(const std::string &path,
ContentProviderWithoutLength content_provider, ContentProviderWithoutLength content_provider,
const std::string &content_type, const std::string &content_type,
UploadProgress progress) { UploadProgress progress) {
return cli_->Patch(path, std::move(content_provider), content_type, progress); return cli_->Patch(path, std::move(content_provider), content_type, progress);
} }
Result Client::Patch(const std::string &path,
ContentProviderWithoutLength content_provider,
const std::string &content_type,
ContentReceiver content_receiver,
UploadProgress progress) {
return cli_->Patch(path, std::move(content_provider), content_type,
std::move(content_receiver), progress);
}
Result Client::Patch(const std::string &path, const Headers &headers, Result Client::Patch(const std::string &path, const Headers &headers,
size_t content_length, size_t content_length,
ContentProvider content_provider, ContentProvider content_provider,
@ -9086,6 +9314,15 @@ Result Client::Patch(const std::string &path, const Headers &headers,
return cli_->Patch(path, headers, content_length, std::move(content_provider), return cli_->Patch(path, headers, content_length, std::move(content_provider),
content_type, progress); content_type, progress);
} }
Result Client::Patch(const std::string &path, const Headers &headers,
size_t content_length,
ContentProvider content_provider,
const std::string &content_type,
ContentReceiver content_receiver,
UploadProgress progress) {
return cli_->Patch(path, headers, content_length, std::move(content_provider),
content_type, std::move(content_receiver), progress);
}
Result Client::Patch(const std::string &path, const Headers &headers, Result Client::Patch(const std::string &path, const Headers &headers,
ContentProviderWithoutLength content_provider, ContentProviderWithoutLength content_provider,
const std::string &content_type, const std::string &content_type,
@ -9093,6 +9330,14 @@ Result Client::Patch(const std::string &path, const Headers &headers,
return cli_->Patch(path, headers, std::move(content_provider), content_type, return cli_->Patch(path, headers, std::move(content_provider), content_type,
progress); progress);
} }
Result Client::Patch(const std::string &path, const Headers &headers,
ContentProviderWithoutLength content_provider,
const std::string &content_type,
ContentReceiver content_receiver,
UploadProgress progress) {
return cli_->Patch(path, headers, std::move(content_provider), content_type,
std::move(content_receiver), progress);
}
Result Client::Patch(const std::string &path, const Params &params) { Result Client::Patch(const std::string &path, const Params &params) {
return cli_->Patch(path, params); return cli_->Patch(path, params);
} }

View File

@ -8,8 +8,8 @@
#ifndef CPPHTTPLIB_HTTPLIB_H #ifndef CPPHTTPLIB_HTTPLIB_H
#define CPPHTTPLIB_HTTPLIB_H #define CPPHTTPLIB_HTTPLIB_H
#define CPPHTTPLIB_VERSION "0.27.0" #define CPPHTTPLIB_VERSION "0.28.0"
#define CPPHTTPLIB_VERSION_NUM "0x001B00" #define CPPHTTPLIB_VERSION_NUM "0x001C00"
/* /*
* Platform compatibility check * Platform compatibility check
@ -257,6 +257,7 @@ using socklen_t = int;
#include <netinet/in.h> #include <netinet/in.h>
#ifdef __linux__ #ifdef __linux__
#include <resolv.h> #include <resolv.h>
#undef _res // Undefine _res macro to avoid conflicts with user code (#2278)
#endif #endif
#include <csignal> #include <csignal>
#include <netinet/tcp.h> #include <netinet/tcp.h>
@ -1421,14 +1422,18 @@ public:
Result Post(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); Result Post(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
Result Post(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); Result Post(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
Result Post(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); Result Post(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Post(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Post(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); Result Post(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Post(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Post(const std::string &path, const Params &params); Result Post(const std::string &path, const Params &params);
Result Post(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr); Result Post(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr);
Result Post(const std::string &path, const Headers &headers); Result Post(const std::string &path, const Headers &headers);
Result Post(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); Result Post(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
Result Post(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); Result Post(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
Result Post(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); Result Post(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Post(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, DownloadProgress progress = nullptr);
Result Post(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); Result Post(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Post(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, DownloadProgress progress = nullptr);
Result Post(const std::string &path, const Headers &headers, const Params &params); Result Post(const std::string &path, const Headers &headers, const Params &params);
Result Post(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr); Result Post(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr);
Result Post(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr); Result Post(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr);
@ -1439,14 +1444,18 @@ public:
Result Put(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); Result Put(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
Result Put(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); Result Put(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
Result Put(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); Result Put(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Put(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Put(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); Result Put(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Put(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Put(const std::string &path, const Params &params); Result Put(const std::string &path, const Params &params);
Result Put(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr); Result Put(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr);
Result Put(const std::string &path, const Headers &headers); Result Put(const std::string &path, const Headers &headers);
Result Put(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); Result Put(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
Result Put(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); Result Put(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
Result Put(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); Result Put(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Put(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Put(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); Result Put(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Put(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Put(const std::string &path, const Headers &headers, const Params &params); Result Put(const std::string &path, const Headers &headers, const Params &params);
Result Put(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr); Result Put(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr);
Result Put(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr); Result Put(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr);
@ -1457,14 +1466,18 @@ public:
Result Patch(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); Result Patch(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); Result Patch(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
Result Patch(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); Result Patch(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Patch(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Patch(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); Result Patch(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Patch(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Params &params); Result Patch(const std::string &path, const Params &params);
Result Patch(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr); Result Patch(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Headers &headers, UploadProgress progress = nullptr); Result Patch(const std::string &path, const Headers &headers, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); Result Patch(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); Result Patch(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); Result Patch(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); Result Patch(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Headers &headers, const Params &params); Result Patch(const std::string &path, const Headers &headers, const Params &params);
Result Patch(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr); Result Patch(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr); Result Patch(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr);
@ -1712,17 +1725,19 @@ private:
template <typename ClientType> void setup_redirect_client(ClientType &client); template <typename ClientType> void setup_redirect_client(ClientType &client);
bool handle_request(Stream &strm, Request &req, Response &res, bool handle_request(Stream &strm, Request &req, Response &res,
bool close_connection, Error &error); bool close_connection, Error &error);
std::unique_ptr<Response> send_with_content_provider( std::unique_ptr<Response> send_with_content_provider_and_receiver(
Request &req, const char *body, size_t content_length, Request &req, const char *body, size_t content_length,
ContentProvider content_provider, ContentProvider content_provider,
ContentProviderWithoutLength content_provider_without_length, ContentProviderWithoutLength content_provider_without_length,
const std::string &content_type, Error &error); const std::string &content_type, ContentReceiver content_receiver,
Result send_with_content_provider( Error &error);
Result send_with_content_provider_and_receiver(
const std::string &method, const std::string &path, const std::string &method, const std::string &path,
const Headers &headers, const char *body, size_t content_length, const Headers &headers, const char *body, size_t content_length,
ContentProvider content_provider, ContentProvider content_provider,
ContentProviderWithoutLength content_provider_without_length, ContentProviderWithoutLength content_provider_without_length,
const std::string &content_type, UploadProgress progress); const std::string &content_type, ContentReceiver content_receiver,
UploadProgress progress);
ContentProviderWithoutLength get_multipart_content_provider( ContentProviderWithoutLength get_multipart_content_provider(
const std::string &boundary, const UploadFormDataItems &items, const std::string &boundary, const UploadFormDataItems &items,
const FormDataProviderItems &provider_items) const; const FormDataProviderItems &provider_items) const;
@ -1775,14 +1790,18 @@ public:
Result Post(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); Result Post(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
Result Post(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); Result Post(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
Result Post(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); Result Post(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Post(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Post(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); Result Post(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Post(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Post(const std::string &path, const Params &params); Result Post(const std::string &path, const Params &params);
Result Post(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr); Result Post(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr);
Result Post(const std::string &path, const Headers &headers); Result Post(const std::string &path, const Headers &headers);
Result Post(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); Result Post(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
Result Post(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); Result Post(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
Result Post(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); Result Post(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Post(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, DownloadProgress progress = nullptr);
Result Post(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); Result Post(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Post(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, DownloadProgress progress = nullptr);
Result Post(const std::string &path, const Headers &headers, const Params &params); Result Post(const std::string &path, const Headers &headers, const Params &params);
Result Post(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr); Result Post(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr);
Result Post(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr); Result Post(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr);
@ -1793,14 +1812,18 @@ public:
Result Put(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); Result Put(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
Result Put(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); Result Put(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
Result Put(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); Result Put(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Put(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Put(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); Result Put(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Put(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Put(const std::string &path, const Params &params); Result Put(const std::string &path, const Params &params);
Result Put(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr); Result Put(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr);
Result Put(const std::string &path, const Headers &headers); Result Put(const std::string &path, const Headers &headers);
Result Put(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); Result Put(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
Result Put(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); Result Put(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
Result Put(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); Result Put(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Put(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Put(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); Result Put(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Put(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Put(const std::string &path, const Headers &headers, const Params &params); Result Put(const std::string &path, const Headers &headers, const Params &params);
Result Put(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr); Result Put(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr);
Result Put(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr); Result Put(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr);
@ -1811,14 +1834,18 @@ public:
Result Patch(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); Result Patch(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); Result Patch(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
Result Patch(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); Result Patch(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Patch(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Patch(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); Result Patch(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Patch(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Params &params); Result Patch(const std::string &path, const Params &params);
Result Patch(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr); Result Patch(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Headers &headers); Result Patch(const std::string &path, const Headers &headers);
Result Patch(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); Result Patch(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); Result Patch(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); Result Patch(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); Result Patch(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Headers &headers, const Params &params); Result Patch(const std::string &path, const Headers &headers, const Params &params);
Result Patch(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr); Result Patch(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr); Result Patch(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr);

View File

@ -1,6 +0,0 @@
# Remove bssl
file(READ "CMakeLists.txt" content)
string(REPLACE "add_executable(bssl" "#add_executable(bssl" content "${content}")
string(REPLACE "target_link_libraries(bssl" "#target_link_libraries(bssl" content "${content}")
string(REPLACE "install(TARGETS bssl" "#install(TARGETS bssl" content "${content}")
file(WRITE "CMakeLists.txt" "${content}")