diff --git a/CODEOWNERS b/CODEOWNERS index 908d13a35b..6ef6c0489f 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -2,10 +2,8 @@ # multiplie collaborators per item can be specified /.devops/*.Dockerfile @ngxson -/.github/actions/ @slaren @CISC +/.github/actions/ @CISC /.github/workflows/ @CISC -/.github/workflows/release.yml @slaren -/.github/workflows/winget.yml @slaren /ci/ @ggerganov /cmake/ @ggerganov /common/CMakeLists.txt @ggerganov @@ -40,21 +38,14 @@ /examples/passkey/ @ggerganov /examples/retrieval/ @ggerganov /examples/save-load-state/ @ggerganov -/examples/simple-chat/ @slaren -/examples/simple/ @slaren /examples/speculative-simple/ @ggerganov /examples/speculative/ @ggerganov /ggml/cmake/ @ggerganov -/ggml/include/ @ggerganov @slaren -/ggml/src/ggml-alloc.c @slaren -/ggml/src/ggml-backend* @slaren -/ggml/src/ggml-blas/ @slaren -/ggml/src/ggml-common.h @ggerganov @slaren -/ggml/src/ggml-cpu/ @ggerganov @slaren +/ggml/include/ @ggerganov +/ggml/src/ggml-common.h @ggerganov +/ggml/src/ggml-cpu/ @ggerganov /ggml/src/ggml-cpu/spacemit/ @alex-spacemit -/ggml/src/ggml-cuda/common.cuh @slaren /ggml/src/ggml-cuda/fattn* @JohannesGaessler -/ggml/src/ggml-cuda/ggml-cuda.cu @slaren /ggml/src/ggml-cuda/mmf.* @JohannesGaessler @am17an /ggml/src/ggml-cuda/mmq.* @JohannesGaessler /ggml/src/ggml-cuda/mmvf.* @JohannesGaessler @@ -62,19 +53,19 @@ /ggml/src/ggml-cuda/fattn-wmma* @IMbackK /ggml/src/ggml-hip/ @IMbackK /ggml/src/ggml-cuda/vendors/hip.h @IMbackK -/ggml/src/ggml-impl.h @ggerganov @slaren +/ggml/src/ggml-impl.h @ggerganov /ggml/src/ggml-metal/ @ggerganov /ggml/src/ggml-opencl/ @lhez @max-krasnyansky /ggml/src/ggml-hexagon/ @max-krasnyansky @lhez /ggml/src/ggml-opt.cpp @JohannesGaessler /ggml/src/ggml-quants.* @ggerganov /ggml/src/ggml-rpc/ @rgerganov -/ggml/src/ggml-threading.* @ggerganov @slaren +/ggml/src/ggml-threading.* @ggerganov /ggml/src/ggml-vulkan/ @0cc4m /ggml/src/ggml-webgpu/ @reeselevine /ggml/src/ggml-zdnn/ @taronaeo @Andreas-Krebbel @AlekseiNikiforovIBM -/ggml/src/ggml.c @ggerganov @slaren -/ggml/src/ggml.cpp @ggerganov @slaren +/ggml/src/ggml.c @ggerganov +/ggml/src/ggml.cpp @ggerganov /ggml/src/gguf.cpp @JohannesGaessler @Green-Sky /gguf-py/ @CISC /media/ @ggerganov @@ -86,15 +77,11 @@ /src/llama-arch.* @CISC /src/llama-chat.* @ngxson /src/llama-graph.* @CISC -/src/llama-model-loader.* @slaren /src/llama-model.* @CISC /src/llama-vocab.* @CISC /src/models/ @CISC /tests/ @ggerganov -/tests/test-backend-ops.cpp @slaren -/tests/test-thread-safety.cpp @slaren /tools/batched-bench/ @ggerganov -/tools/llama-bench/ @slaren /tools/main/ @ggerganov /tools/mtmd/ @ngxson /tools/perplexity/ @ggerganov @@ -106,8 +93,6 @@ /tools/tokenize/ @ggerganov /tools/tts/ @ggerganov /vendor/ @ggerganov -/.clang-format @slaren -/.clang-tidy @slaren /AUTHORS @ggerganov /CMakeLists.txt @ggerganov /CONTRIBUTING.md @ggerganov diff --git a/common/arg.cpp b/common/arg.cpp index 84db9ca77c..d9c838fcd0 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1237,6 +1237,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params, const std::string & value) { const auto sampler_names = string_split(value, ';'); params.sampling.samplers = common_sampler_types_from_names(sampler_names, true); + params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_SAMPLERS; } ).set_sparam()); add_opt(common_arg( @@ -1266,6 +1267,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & 
params, const std::string & value) { params.sampling.temp = std::stof(value); params.sampling.temp = std::max(params.sampling.temp, 0.0f); + params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TEMP; } ).set_sparam()); add_opt(common_arg( @@ -1273,6 +1275,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex string_format("top-k sampling (default: %d, 0 = disabled)", params.sampling.top_k), [](common_params & params, int value) { params.sampling.top_k = value; + params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_K; } ).set_sparam()); add_opt(common_arg( @@ -1280,6 +1283,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex string_format("top-p sampling (default: %.1f, 1.0 = disabled)", (double)params.sampling.top_p), [](common_params & params, const std::string & value) { params.sampling.top_p = std::stof(value); + params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_P; } ).set_sparam()); add_opt(common_arg( @@ -1287,6 +1291,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex string_format("min-p sampling (default: %.1f, 0.0 = disabled)", (double)params.sampling.min_p), [](common_params & params, const std::string & value) { params.sampling.min_p = std::stof(value); + params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIN_P; } ).set_sparam()); add_opt(common_arg( @@ -1301,6 +1306,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex string_format("xtc probability (default: %.1f, 0.0 = disabled)", (double)params.sampling.xtc_probability), [](common_params & params, const std::string & value) { params.sampling.xtc_probability = std::stof(value); + params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_PROBABILITY; } ).set_sparam()); add_opt(common_arg( @@ -1308,6 +1314,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex string_format("xtc threshold (default: %.1f, 1.0 = disabled)", (double)params.sampling.xtc_threshold), [](common_params & params, const std::string & value) { params.sampling.xtc_threshold = std::stof(value); + params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_THRESHOLD; } ).set_sparam()); add_opt(common_arg( @@ -1326,6 +1333,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex } params.sampling.penalty_last_n = value; params.sampling.n_prev = std::max(params.sampling.n_prev, params.sampling.penalty_last_n); + params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_LAST_N; } ).set_sparam()); add_opt(common_arg( @@ -1333,6 +1341,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex string_format("penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)", (double)params.sampling.penalty_repeat), [](common_params & params, const std::string & value) { params.sampling.penalty_repeat = std::stof(value); + params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_REPEAT; } ).set_sparam()); add_opt(common_arg( @@ -1430,6 +1439,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex "(default: %d, 0 = disabled, 1 = 
Mirostat, 2 = Mirostat 2.0)", params.sampling.mirostat), [](common_params & params, int value) { params.sampling.mirostat = value; + params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT; } ).set_sparam()); add_opt(common_arg( @@ -1437,6 +1447,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex string_format("Mirostat learning rate, parameter eta (default: %.1f)", (double)params.sampling.mirostat_eta), [](common_params & params, const std::string & value) { params.sampling.mirostat_eta = std::stof(value); + params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA; } ).set_sparam()); add_opt(common_arg( @@ -1444,6 +1455,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex string_format("Mirostat target entropy, parameter tau (default: %.1f)", (double)params.sampling.mirostat_tau), [](common_params & params, const std::string & value) { params.sampling.mirostat_tau = std::stof(value); + params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_TAU; } ).set_sparam()); add_opt(common_arg( diff --git a/common/common.cpp b/common/common.cpp index be31c66de1..10001f5469 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -8,6 +8,7 @@ #include "common.h" #include "log.h" #include "llama.h" +#include "sampling.h" #include #include @@ -957,6 +958,58 @@ std::vector fs_list(const std::string & path, bool include_dir // Model utils // +static inline void common_init_sampler_from_model( + const llama_model * model, + common_params_sampling & sparams) { + + const uint64_t config = sparams.user_sampling_config; + + auto get_int32 = [&](const char * key, int32_t & dst, uint64_t user_config) { + if (config & user_config) return; + + char buf[64] = {0}; + if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) { + char * end = nullptr; + int32_t v = strtol(buf, &end, 10); + if (end && end != buf) dst = v; + } + }; + + auto get_float = [&](const char * key, float & dst, uint64_t user_config) { + if (config & user_config) return; + + char buf[128] = {0}; + if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) { + char * end = nullptr; + float v = strtof(buf, &end); + if (end && end != buf) dst = v; + } + }; + + // Sampling sequence + if (!(config & common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_SAMPLERS)) { + char buf[512] = {0}; + if (llama_model_meta_val_str(model, llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_SEQUENCE), buf, sizeof(buf)) > 0) { + const std::vector sampler_names = string_split(std::string(buf), ';'); + if (!sampler_names.empty()) { + sparams.samplers = common_sampler_types_from_names(sampler_names, true); + } + } + } + + get_int32(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_TOP_K), sparams.top_k, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_K); + get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_TOP_P), sparams.top_p, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_P); + get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIN_P), sparams.min_p, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIN_P); + get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_XTC_PROBABILITY), sparams.xtc_probability, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_PROBABILITY); + 
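    // A minimal sketch (editor's illustration with hypothetical values, not part of the patch) of
    // the precedence rule the getters in this function implement: a value stored in the model's
    // GGUF metadata is applied only when the matching bit in `user_sampling_config` is clear,
    // i.e. only when the user did not already set that parameter on the command line:
    //
    //     common_params_sampling sparams;
    //     sparams.top_k = 50;                                                   // user passed --top-k 50
    //     sparams.user_sampling_config |= COMMON_PARAMS_SAMPLING_CONFIG_TOP_K;  // recorded in arg.cpp
    //
    //     // inside common_init_sampler_from_model():
    //     //   TOP_K bit set   -> the LLAMA_MODEL_META_KEY_SAMPLING_TOP_K metadata is ignored, 50 is kept
    //     //   TOP_P bit clear -> the LLAMA_MODEL_META_KEY_SAMPLING_TOP_P metadata, if present,
    //     //                      overrides the built-in default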
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_XTC_THRESHOLD), sparams.xtc_threshold, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_THRESHOLD); + get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_TEMP), sparams.temp, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TEMP); + get_int32(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_LAST_N), sparams.penalty_last_n, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_LAST_N); + get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_REPEAT), sparams.penalty_repeat, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_REPEAT); + get_int32(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT), sparams.mirostat, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT); + get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_TAU), sparams.mirostat_tau, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_TAU); + get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_ETA), sparams.mirostat_eta, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA); +} + struct common_init_result common_init_from_params(common_params & params) { common_init_result iparams; auto mparams = common_model_params_to_llama(params); @@ -968,6 +1021,8 @@ struct common_init_result common_init_from_params(common_params & params) { return iparams; } + common_init_sampler_from_model(model, params.sampling); + const llama_vocab * vocab = llama_model_get_vocab(model); auto cparams = common_context_params_to_llama(params); diff --git a/common/common.h b/common/common.h index 5569814738..1225aa191d 100644 --- a/common/common.h +++ b/common/common.h @@ -138,6 +138,22 @@ struct common_grammar_trigger { llama_token token = LLAMA_TOKEN_NULL; }; +enum common_params_sampling_config : uint64_t { + COMMON_PARAMS_SAMPLING_CONFIG_SAMPLERS = 1 << 0, + COMMON_PARAMS_SAMPLING_CONFIG_TOP_K = 1 << 1, + COMMON_PARAMS_SAMPLING_CONFIG_TOP_P = 1 << 2, + COMMON_PARAMS_SAMPLING_CONFIG_MIN_P = 1 << 3, + COMMON_PARAMS_SAMPLING_CONFIG_XTC_PROBABILITY = 1 << 4, + COMMON_PARAMS_SAMPLING_CONFIG_XTC_THRESHOLD = 1 << 5, + COMMON_PARAMS_SAMPLING_CONFIG_TEMP = 1 << 6, + COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_LAST_N = 1 << 7, + COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_REPEAT = 1 << 8, + COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT = 1 << 9, + COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_TAU = 1 << 10, + COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA = 1 << 11, +}; + + // sampling parameters struct common_params_sampling { uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler @@ -170,6 +186,8 @@ struct common_params_sampling { bool no_perf = false; // disable performance metrics bool timing_per_token = false; + uint64_t user_sampling_config = 0; // bitfield to track user-specified samplers + std::vector dry_sequence_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 6cbaee03df..daaf0bf497 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -565,7 +565,7 @@ class ModelBase: gguf.MODEL_TENSOR.ALTUP_PREDICT_COEF, ) ) - or not new_name.endswith(".weight") + or new_name[-7:] not in (".weight", ".lora_a", ".lora_b") ): data_qtype = gguf.GGMLQuantizationType.F32 @@ -10061,6 +10061,25 @@ class LazyTorchTensor(gguf.LazyBase): torch.uint8: np.uint8, } + # only used when byteswapping data. 
Only correct size is needed + _dtype_byteswap_map: dict[torch.dtype, type] = { + torch.float64: np.float64, + torch.float32: np.float32, + torch.bfloat16: np.float16, + torch.float16: np.float16, + torch.int64: np.int64, + torch.uint64: np.uint64, + torch.int32: np.int32, + torch.uint32: np.uint32, + torch.int16: np.int16, + torch.uint16: np.uint16, + torch.int8: np.int8, + torch.uint8: np.uint8, + torch.bool: np.uint8, + torch.float8_e4m3fn: np.uint8, + torch.float8_e5m2: np.uint8, + } + # used for safetensors slices # ref: https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/src/lib.rs#L1046 # TODO: uncomment U64, U32, and U16, ref: https://github.com/pytorch/pytorch/issues/58734 @@ -10104,8 +10123,14 @@ class LazyTorchTensor(gguf.LazyBase): @classmethod def from_local_tensor(cls, t: gguf.utility.LocalTensor) -> Tensor: def load_tensor(tensor: gguf.utility.LocalTensor) -> Tensor: + def byteswap_tensor(tensor: np.ndarray, dtype: type) -> np.ndarray: + if sys.byteorder == 'big': + # switch data back to big endian + tensor = tensor.view(dtype).byteswap(inplace=False) + return tensor dtype = cls._dtype_str_map[tensor.dtype] - return torch.from_numpy(tensor.mmap_bytes()).view(dtype).reshape(tensor.shape) + numpy_dtype = cls._dtype_byteswap_map[dtype] + return torch.from_numpy(byteswap_tensor(tensor.mmap_bytes(), numpy_dtype)).view(dtype).reshape(tensor.shape) dtype = cls._dtype_str_map[t.dtype] shape = t.shape lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(t,), func=lambda r: load_tensor(r)) @@ -10113,10 +10138,16 @@ class LazyTorchTensor(gguf.LazyBase): @classmethod def from_remote_tensor(cls, remote_tensor: gguf.utility.RemoteTensor): + def byteswap_tensor(tensor: np.ndarray, dtype: type) -> np.ndarray: + if sys.byteorder == 'big': + # switch data back to big endian + tensor = tensor.view(dtype).byteswap(inplace=False) + return tensor dtype = cls._dtype_str_map[remote_tensor.dtype] + numpy_dtype = cls._dtype_byteswap_map[dtype] shape = remote_tensor.shape meta = cls.meta_with_dtype_and_shape(dtype, shape) - lazy = cls(meta=meta, args=(remote_tensor,), func=lambda r: torch.frombuffer(r.data(), dtype=dtype).reshape(shape)) + lazy = cls(meta=meta, args=(remote_tensor,), func=lambda r: torch.from_numpy(byteswap_tensor(np.frombuffer(r.data(), dtype=numpy_dtype), numpy_dtype)).view(dtype).reshape(shape)) return cast(torch.Tensor, lazy) @classmethod diff --git a/convert_lora_to_gguf.py b/convert_lora_to_gguf.py index 57c6cd0df1..b0adde8a8b 100755 --- a/convert_lora_to_gguf.py +++ b/convert_lora_to_gguf.py @@ -242,7 +242,7 @@ def parse_args() -> argparse.Namespace: help="path to write to; default: based on input. 
{ftype} will be replaced by the outtype.", ) parser.add_argument( - "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "auto"], default="f16", + "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "auto"], default="f32", help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type", ) parser.add_argument( diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 605fcfcb9c..4dbca868bc 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -530,6 +530,7 @@ extern "C" { GGML_OP_ARANGE, GGML_OP_TIMESTEP_EMBEDDING, GGML_OP_ARGSORT, + GGML_OP_TOP_K, GGML_OP_LEAKY_RELU, GGML_OP_TRI, GGML_OP_FILL, @@ -2258,18 +2259,25 @@ extern "C" { struct ggml_tensor * a, enum ggml_sort_order order); + // similar to ggml_top_k but implemented as `argsort` + `view` + GGML_API struct ggml_tensor * ggml_argsort_top_k( + struct ggml_context * ctx, + struct ggml_tensor * a, + int k); + + // top k elements per row + // note: the resulting top k indices are in no particular order + GGML_API struct ggml_tensor * ggml_top_k( + struct ggml_context * ctx, + struct ggml_tensor * a, + int k); + GGML_API struct ggml_tensor * ggml_arange( struct ggml_context * ctx, float start, float stop, float step); - // top k elements per row - GGML_API struct ggml_tensor * ggml_top_k( - struct ggml_context * ctx, - struct ggml_tensor * a, - int k); - #define GGML_KQ_MASK_PAD 64 // q: [n_embd_k, n_batch, n_head, ne3 ] diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index 606c6d1783..48f4b7db69 100644 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -42,6 +42,7 @@ #include #include #include +#include #include #include #include @@ -2206,78 +2207,120 @@ static void aclnn_index_fill_tensor(ggml_backend_cann_context & ctx, } /** - * @brief Initializes and caches sine/cosine positional encoding values - * (used in RoPE, Rotary Position Embedding) for attention layers. + * @brief Initializes and caches all intermediate tensors required for RoPE + * (Rotary Position Embedding), including support for Yarn, mRoPE, + * i-mRoPE, Neox repeat strategy, independent sectors, frequency factors, + * and multi-section rotary groups. * - * This function computes and caches the sin/cos values of - * θ = position * theta_scale for RoPE encoding. The cache is shared - * across attention layers, and only the first attention layer will - * trigger initialization. The cache includes repeated sin/cos values - * with different repeat methods depending on the @param is_neox flag. + * This function computes and caches the per-dimension θ coefficients used for + * Q/K rotary embedding. The cache is shared across layers, and recomputed only + * when any dependent parameter changes. * - * Steps performed by this function: - * 1. Identify whether the target tensor belongs to Q/K in attention - * and restrict computation to the first layer only. - * 2. Initialize the theta scale array (arange → power → freq scaling). - * 3. Allocate sin/cos caches if the max prompt length increases. - * 4. Compute θ = position * theta_scale. - * 5. Compute sin(θ), cos(θ) and optionally scale by attn_factor. - * 6. Expand sin/cos values by repeat or repeat_interleave depending - * on whether @param is_neox is enabled. 
+ *
+ * The function now supports:
+ * - Yarn RoPE extrapolation (via @param corr_dims and @param ext_factor)
+ * - Per-dimension independent sector exponent rules (indep_sects + sections[])
+ * - Multi-section RoPE (mRoPE) index mapping (mrope_used + is_imrope)
+ * - Frequency factor division (src2)
+ * - Neox / normal repeat expansion modes
  *
- * @param ctx           The CANN backend context, holding memory pool,
- *                      stream, and persistent buffers for rope init/cache.
- * @param dst           The destination ggml_tensor whose computation
- *                      depends on the RoPE values (usually Qcur/Kcur).
- * @param theta_scale   Scalar exponent base for computing theta scale values.
- * @param freq_scale    Frequency scaling factor, applied to theta scale.
- * @param attn_factor   Attention scaling factor, applied to sin/cos.
- * @param is_neox       Whether to use Neox-style repeat strategy
- *                      (dim expansion vs repeat_interleave).
+ * @param ctx           CANN backend context, containing memory pool,
+ *                      cached buffers, and runtime stream.
+ * @param dst           Destination ggml_tensor whose computation
+ *                      depends on RoPE (typically Qcur or Kcur).
+ * @param corr_dims     [low, high] Yarn correction range.
+ * @param ext_factor    Yarn extrapolation strength. 0 = disabled.
+ * @param theta_scale   Base multiplier for per-dimension θ exponent.
+ * @param freq_scale    Global frequency scaling factor.
+ * @param attn_factor   Optional scaling applied to sin/cos (if needed).
+ * @param is_neox       Whether to use Neox-style dimension interleave.
+ * @param sections      4-way sector sizes for independent-section RoPE
+ *                      and multi-section mRoPE (t/h/w/e).
+ * @param mrope_used    Whether to enable multi-section rotary embedding.
+ * @param is_imrope     Whether to apply interleaved mRoPE rules.
+ * @param indep_sects   Whether each dimension runs independent exponent
+ *                      resets based on @p sections.
  */
-static void aclnn_cache_init(ggml_backend_cann_context & ctx,
-                             ggml_tensor *               dst,
-                             float *                     corr_dims,
-                             float                       ext_factor,
-                             float                       theta_scale,
-                             float                       freq_scale,
-                             float                       attn_factor,
-                             bool                        is_neox) {
+static void aclnn_rope_cache_init(ggml_backend_cann_context & ctx,
+                                  ggml_tensor *               dst,
+                                  float *                     corr_dims,
+                                  float                       ext_factor,
+                                  float                       theta_scale,
+                                  float                       freq_scale,
+                                  float                       attn_factor,
+                                  bool                        is_neox,
+                                  int                         sections[4],
+                                  bool                        mrope_used,
+                                  bool                        is_imrope,
+                                  bool                        indep_sects) {
     ggml_tensor * src0 = dst->src[0];  // input
     ggml_tensor * src1 = dst->src[1];  // position
     ggml_tensor * src2 = dst->src[2];  // freq_factors

-    if (src2 == nullptr && ctx.rope_cache.cached && ctx.rope_cache.ext_factor == ext_factor &&
-        ctx.rope_cache.theta_scale == theta_scale && ctx.rope_cache.freq_scale == freq_scale &&
-        ctx.rope_cache.attn_factor == attn_factor && ctx.rope_cache.is_neox == is_neox) {
+    int64_t theta_scale_length = src0->ne[0] / 2;
+    int64_t position_length    = dst->ne[2];
+
+    // TODO: check theta_scale_length and position_length.
+    if (src2 == nullptr && ctx.rope_cache.cached &&
+        ctx.rope_cache.equal(theta_scale_length, position_length, ext_factor, theta_scale, freq_scale, attn_factor,
+                             is_neox, indep_sects, mrope_used, is_imrope, sections)) {
         // use cache.
         return;
     }

-    int64_t theta_scale_length = src0->ne[0] / 2;
-    int64_t theta_scale_ne[]   = { theta_scale_length, 1, 1, 1 };
-    size_t  theta_scale_nb[]   = { sizeof(float), sizeof(float), sizeof(float), theta_scale_length * sizeof(float) };
+    // Step0: calculate tensor shape.
+ int64_t theta_scale_ne[] = { theta_scale_length, 1, 1, 1 }; + size_t theta_scale_nb[] = { sizeof(float), theta_scale_length * sizeof(float), theta_scale_length * sizeof(float), + theta_scale_length * sizeof(float) }; GGML_ASSERT(src1->type == GGML_TYPE_I32); - int64_t position_length = src1->ne[0]; - int64_t position_ne[] = { 1, 1, position_length, 1 }; - size_t position_nb[] = { sizeof(int32_t), sizeof(int32_t), sizeof(int32_t), sizeof(int32_t) * position_length }; + int64_t position_ne[] = { 1, 1, position_length, 1 }; + size_t position_nb[] = { sizeof(int32_t), sizeof(int32_t), sizeof(int32_t), sizeof(int32_t) * position_length }; - int64_t theta_ne[] = { theta_scale_length, 1, position_length, 1 }; - size_t theta_nb[GGML_MAX_DIMS]; - theta_nb[0] = sizeof(float); + int64_t cache_ne[] = { theta_scale_length, 1, position_length, 1 }; + size_t cache_nb[GGML_MAX_DIMS]; + cache_nb[0] = sizeof(float); for (int i = 1; i < GGML_MAX_DIMS; i++) { - theta_nb[i] = theta_nb[i - 1] * theta_ne[i - 1]; + cache_nb[i] = cache_nb[i - 1] * cache_ne[i - 1]; } - // theta_scale arange, [0,1,...,ne00/2 - 1] + // Step1: Compute the coefficient of theta. During the cache_init process, aside from + // (1) multiplying by the position, + // (2) dividing by freq_factors, + // (3) computing the sine and cosine, + // the other parameters used in the computation generally do not change in most scenarios. + // Therefore, we can first compute this part of the result and then cache it. + + // Step1.1: prepare theta_scale exponent. if this exponent updated, should update theta_scale_tensor. acl_tensor_ptr acl_theta_scale_tensor; - // cache theta scale - if (ctx.rope_cache.theta_scale_length != theta_scale_length || - // theta_scale and freq_scale should not change during the current token inference process, - // so we can directly use == here instead of comparing the absolute difference. 
- ctx.rope_cache.theta_scale != theta_scale || ctx.rope_cache.freq_scale != freq_scale) { - ctx.rope_cache.theta_scale_length = theta_scale_length; + bool theta_scale_updated = false; + if (ctx.rope_cache.theta_scale_length != theta_scale_length || ctx.rope_cache.theta_scale != theta_scale || + ctx.rope_cache.indep_sects != indep_sects) { + theta_scale_updated = true; + if (ctx.rope_cache.theta_scale_exp_host != nullptr) { + free(ctx.rope_cache.theta_scale_exp_host); + } + ctx.rope_cache.theta_scale_exp_host = (float *) malloc(theta_scale_length * sizeof(float)); + GGML_ASSERT(ctx.rope_cache.theta_scale_exp_host != nullptr); + if (!indep_sects) { + ctx.rope_cache.theta_scale_exp_host[0] = 1; + for (int i = 1; i < theta_scale_length; i++) { + ctx.rope_cache.theta_scale_exp_host[i] = ctx.rope_cache.theta_scale_exp_host[i - 1] * theta_scale; + } + } else { + int sect_dims = sections[0] + sections[1] + sections[2] + sections[3]; + int sec_w = sections[1] + sections[0]; + int sec_e = sections[2] + sec_w; + + ctx.rope_cache.theta_scale_exp_host[0] = 1; + for (int i = 1; i < theta_scale_length; i++) { + int sector = i % sect_dims; + if (sector == 0 || sector == sections[0] || sector == sec_w || sector == sec_e) { + ctx.rope_cache.theta_scale_exp_host[i] = 1; + continue; + } + ctx.rope_cache.theta_scale_exp_host[i] = ctx.rope_cache.theta_scale_exp_host[i - 1] * theta_scale; + } + } if (ctx.rope_cache.theta_scale_cache != nullptr) { ACL_CHECK(aclrtFree(ctx.rope_cache.theta_scale_cache)); @@ -2285,74 +2328,138 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx, ACL_CHECK(aclrtMalloc(&ctx.rope_cache.theta_scale_cache, theta_scale_length * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMemcpyAsync(ctx.rope_cache.theta_scale_cache, theta_scale_length * sizeof(float), + ctx.rope_cache.theta_scale_exp_host, theta_scale_length * sizeof(float), + ACL_MEMCPY_HOST_TO_DEVICE, ctx.stream())); + acl_theta_scale_tensor = ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float), theta_scale_ne, theta_scale_nb, 1); + } - float start = 0; - float step = 1; - float stop = theta_scale_length; - float n_elements = theta_scale_length; - aclnn_arange(ctx, acl_theta_scale_tensor.get(), start, stop, step, n_elements); + // Step1.2: prepare rope_yarn_ramp, if this part updated, should update theta_scale_tensor. + bool yarn_ramp_tensor_updated = false; + ggml_cann_pool_alloc yarn_ramp_allocator(ctx.pool()); + acl_tensor_ptr acl_yarn_ramp_tensor; + if (ext_factor != 0 && + // TODO: check more parameter. 
+ (ctx.rope_cache.theta_scale_length != theta_scale_length || ctx.rope_cache.freq_scale != freq_scale)) { + yarn_ramp_tensor_updated = true; - ggml_cann_pool_alloc yarn_ramp_allocator(ctx.pool()); - acl_tensor_ptr acl_yarn_ramp_tensor; - if (ext_factor != 0) { - // -rope_yarn_ramp - // const float y = (i0 / 2 - low) / MAX(0.001f, high - low); - // return MIN(1, MAX(0, y)) - 1; - yarn_ramp_allocator.alloc(theta_scale_length * sizeof(float)); - void * yarn_ramp_buffer = yarn_ramp_allocator.get(); - acl_yarn_ramp_tensor = - ggml_cann_create_tensor(yarn_ramp_buffer, ACL_FLOAT, sizeof(float), theta_scale_ne, theta_scale_nb, 1); - float zero_value = 0, one_value = 1; - float denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]); - acl_scalar_ptr low = ggml_cann_create_scalar(&corr_dims[0], aclDataType::ACL_FLOAT); - acl_scalar_ptr zero = ggml_cann_create_scalar(&zero_value, aclDataType::ACL_FLOAT); - acl_scalar_ptr one = ggml_cann_create_scalar(&one_value, aclDataType::ACL_FLOAT); - acl_scalar_ptr denom_safe = ggml_cann_create_scalar(&denom_safe_value, aclDataType::ACL_FLOAT); - acl_scalar_ptr ext_factor_sc = ggml_cann_create_scalar(&ext_factor, aclDataType::ACL_FLOAT); + // -rope_yarn_ramp + // const float y = (i0 / 2 - low) / MAX(0.001f, high - low); + // return MIN(1, MAX(0, y)) - 1; + yarn_ramp_allocator.alloc(theta_scale_length * sizeof(float)); + void * yarn_ramp_buffer = yarn_ramp_allocator.get(); + acl_yarn_ramp_tensor = + ggml_cann_create_tensor(yarn_ramp_buffer, ACL_FLOAT, sizeof(float), theta_scale_ne, theta_scale_nb, 1); + float zero_value = 0, one_value = 1; + float denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]); + acl_scalar_ptr low = ggml_cann_create_scalar(&corr_dims[0], aclDataType::ACL_FLOAT); + acl_scalar_ptr zero = ggml_cann_create_scalar(&zero_value, aclDataType::ACL_FLOAT); + acl_scalar_ptr one = ggml_cann_create_scalar(&one_value, aclDataType::ACL_FLOAT); + acl_scalar_ptr denom_safe = ggml_cann_create_scalar(&denom_safe_value, aclDataType::ACL_FLOAT); + acl_scalar_ptr ext_factor_sc = ggml_cann_create_scalar(&ext_factor, aclDataType::ACL_FLOAT); - GGML_CANN_CALL_ACLNN_OP(ctx, Subs, acl_theta_scale_tensor.get(), low.get(), one.get(), - acl_yarn_ramp_tensor.get()); - GGML_CANN_CALL_ACLNN_OP(ctx, InplaceDivs, acl_yarn_ramp_tensor.get(), denom_safe.get()); - GGML_CANN_CALL_ACLNN_OP(ctx, InplaceThreshold, acl_yarn_ramp_tensor.get(), zero.get(), zero.get()); - GGML_CANN_CALL_ACLNN_OP(ctx, InplaceClampMax, acl_yarn_ramp_tensor.get(), one.get()); - GGML_CANN_CALL_ACLNN_OP(ctx, InplaceSubs, acl_yarn_ramp_tensor.get(), one.get(), one.get()); - GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor.get(), ext_factor_sc.get()); + aclnn_arange(ctx, acl_yarn_ramp_tensor.get(), 0, theta_scale_length, 1, theta_scale_length); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceSubs, acl_yarn_ramp_tensor.get(), low.get(), one.get()); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceDivs, acl_yarn_ramp_tensor.get(), denom_safe.get()); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceThreshold, acl_yarn_ramp_tensor.get(), zero.get(), zero.get()); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceClampMax, acl_yarn_ramp_tensor.get(), one.get()); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceSubs, acl_yarn_ramp_tensor.get(), one.get(), one.get()); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor.get(), ext_factor_sc.get()); - // theta_interp = freq_scale * theta_extrap; - // theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; - // theta = freq_scale * theta_extrap * (1 - ramp_mix) + theta_extrap 
* ramp_mix; - // theta = freq_scale * theta_extrap - freq_scale * theta_extrap * ramp_mix + theta_extrap * ramp_mix; - // theta = theta_extrap * (freq_scale - freq_scale * ramp_mix + ramp_mix); - // - // we cache (freq_scale - freq_scale * ramp_mix + ramp_mix), Considering that the rope_yarn_ramp here is the inverse - // cache freq_scale + (freq_scale - 1) * ramp_mix - float freq_scale_1 = freq_scale - 1; - acl_scalar_ptr freq_scale_sc = ggml_cann_create_scalar(&freq_scale, aclDataType::ACL_FLOAT); - acl_scalar_ptr freq_scale_1_sc = ggml_cann_create_scalar(&freq_scale_1, aclDataType::ACL_FLOAT); - GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor.get(), freq_scale_1_sc.get()); - GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdds, acl_yarn_ramp_tensor.get(), freq_scale_sc.get(), one.get()); - } + // theta_interp = freq_scale * theta_extrap; + // theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; + // theta = freq_scale * theta_extrap * (1 - ramp_mix) + theta_extrap * ramp_mix; + // theta = freq_scale * theta_extrap - freq_scale * theta_extrap * ramp_mix + theta_extrap * ramp_mix; + // theta = theta_extrap * (freq_scale - freq_scale * ramp_mix + ramp_mix); + // + // we cache (freq_scale - freq_scale * ramp_mix + ramp_mix), Considering that the rope_yarn_ramp here is the inverse + // cache freq_scale + (freq_scale - 1) * ramp_mix + float freq_scale_1 = freq_scale - 1; + acl_scalar_ptr freq_scale_sc = ggml_cann_create_scalar(&freq_scale, aclDataType::ACL_FLOAT); + acl_scalar_ptr freq_scale_1_sc = ggml_cann_create_scalar(&freq_scale_1, aclDataType::ACL_FLOAT); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor.get(), freq_scale_1_sc.get()); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdds, acl_yarn_ramp_tensor.get(), freq_scale_sc.get(), one.get()); + } - // power - acl_scalar_ptr acl_theta_scale = ggml_cann_create_scalar(&theta_scale, aclDataType::ACL_FLOAT); - GGML_CANN_CALL_ACLNN_OP(ctx, PowScalarTensor, acl_theta_scale.get(), acl_theta_scale_tensor.get(), - acl_theta_scale_tensor.get()); - - if (ext_factor != 0) { + // Step 1.3: update theta_scale_tensor according to ext_factor or freq_scale. + if (ext_factor != 0) { + if (theta_scale_updated || yarn_ramp_tensor_updated) { + theta_scale_updated = true; aclnn_mul(ctx, acl_theta_scale_tensor.get(), acl_yarn_ramp_tensor.get()); - } else if (freq_scale != 1) { - aclnn_muls(ctx, acl_theta_scale_tensor.get(), freq_scale, nullptr, true); } } else { - // use cache + if (freq_scale != 1 && (ctx.rope_cache.freq_scale != freq_scale || theta_scale_updated)) { + theta_scale_updated = true; + aclnn_muls(ctx, acl_theta_scale_tensor.get(), freq_scale, nullptr, true); + } + } + + // Nothing changed, use cache. 
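    // Worked example (editor's sketch with assumed values, not part of the patch) of what the
    // coefficient table prepared above holds before the cache-reuse branch below runs. With
    // indep_sects == false, ext_factor == 0 and
    //
    //     theta_scale = 0.5f, freq_scale = 1.0f, theta_scale_length = 4
    //
    // the host-side table is the running product of theta_scale:
    //
    //     theta_scale_exp_host = { 1.0f, 0.5f, 0.25f, 0.125f }
    //
    // Step4 later multiplies each entry by the token position, so for position p = 3 the angles
    // passed to sin/cos are { 3.0f, 1.5f, 0.75f, 0.375f }. With indep_sects == true the running
    // product resets to 1 at every section boundary given by `sections`, and with freq_scale != 1
    // (and ext_factor == 0) every entry is additionally scaled by freq_scale.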
+ if (!theta_scale_updated) { acl_theta_scale_tensor = ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float), theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS); } + // Step 1.4: prepare select index if mrope + acl_tensor_ptr position_select_index_tensor; + if (mrope_used) { + if (ctx.rope_cache.sections[0] != sections[0] || ctx.rope_cache.sections[1] != sections[1] || + ctx.rope_cache.sections[2] != sections[2] || ctx.rope_cache.sections[3] != sections[3] || + ctx.rope_cache.theta_scale_length != theta_scale_length || ctx.rope_cache.is_imrope != is_imrope) { + if (ctx.rope_cache.position_select_index_host != nullptr) { + free(ctx.rope_cache.position_select_index_host); + } + ctx.rope_cache.position_select_index_host = (int *) malloc(theta_scale_length * sizeof(int)); + GGML_ASSERT(ctx.rope_cache.position_select_index_host != nullptr); + int sect_dims = sections[0] + sections[1] + sections[2] + sections[3]; + int sec_w = sections[1] + sections[0]; + int sec_e = sections[2] + sec_w; + // t,h,w,e + for (int i = 0; i < theta_scale_length; i++) { + int sector = i % sect_dims; + + if (is_imrope) { // qwen3vl apply interleaved mrope + if (sector % 3 == 1 && sector < 3 * sections[1]) { + ctx.rope_cache.position_select_index_host[i] = 1; + } else if (sector % 3 == 2 && sector < 3 * sections[2]) { + ctx.rope_cache.position_select_index_host[i] = 2; + } else if (sector % 3 == 0 && sector < 3 * sections[0]) { + ctx.rope_cache.position_select_index_host[i] = 0; + } else { + ctx.rope_cache.position_select_index_host[i] = 3; + } + } else { + if (sector >= sections[0] && sector < sec_w) { + ctx.rope_cache.position_select_index_host[i] = 1; + } else if (sector >= sec_w && sector < sec_e) { + ctx.rope_cache.position_select_index_host[i] = 2; + } else if (sector >= sec_e) { + ctx.rope_cache.position_select_index_host[i] = 3; + } else { + ctx.rope_cache.position_select_index_host[i] = 0; + } + } + } + + if (ctx.rope_cache.position_select_index != nullptr) { + ACL_CHECK(aclrtFree(ctx.rope_cache.position_select_index)); + } + ACL_CHECK(aclrtMalloc(&ctx.rope_cache.position_select_index, theta_scale_length * sizeof(int), + ACL_MEM_MALLOC_HUGE_FIRST)); + + ACL_CHECK(aclrtMemcpyAsync(ctx.rope_cache.position_select_index, theta_scale_length * sizeof(int), + ctx.rope_cache.position_select_index_host, theta_scale_length * sizeof(int), + ACL_MEMCPY_HOST_TO_DEVICE, ctx.stream())); + } + + position_select_index_tensor = ggml_cann_create_tensor(ctx.rope_cache.position_select_index, ACL_INT32, + sizeof(int), theta_scale_ne, theta_scale_nb, 1); + } + + // Step2: divide by freq_factors ggml_cann_pool_alloc freq_fac_res_allocator(ctx.pool()); - // freq_factors if (src2) { freq_fac_res_allocator.alloc(theta_scale_length * sizeof(float)); void * freq_fac_res_ptr = freq_fac_res_allocator.get(); @@ -2365,6 +2472,85 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx, std::swap(acl_theta_scale_tensor, acl_freq_fac_res_tensor); } + // Step3: prepare position_tensor + acl_tensor_ptr acl_position_tensor; + ggml_cann_pool_alloc mrope_position_acllocator(ctx.pool()); + if (mrope_used) { + // Step3.1: select current position; + // position : + // pos1: [[0, 1 ,2 ,3 ], + // pos2: [4, 5 ,6 ,7 ], + // pos3: [8, 9 ,10,11], + // pos4: [12,13,14,15] ] + // + // select index = [0, 1, 2, 2, 1, 0] + // + // selected_tensor: + // [[0, 1 ,2 ,3 ], + // [4, 5 ,6 ,7 ], + // [8, 9 ,10,11], + // [8, 9 ,10,11], + // [4, 5 ,6 ,7 ], + // [0, 1 ,2 ,3 ]] + // + // transpose, from [seq_len:dims] to [dims:seq_len] + // 
[0, 4, 8 ,8 ,4, 0], + // [1, 5, 9, 9, 5, 1], + // [2, 6, 10,10,6 ,2], + // [3, 7, 11,11,7 3 ]] + // + // multipy by theta_scale_tensor + // [theta_scale^0, theta_scale^1, ..., theta_scale ^ n] + + int64_t mrope_position_ne[] = { position_length, 4 }; + size_t mrope_position_nb[] = { sizeof(int), position_length * sizeof(int) }; + acl_tensor_ptr mrope_position = + ggml_cann_create_tensor(src1->data, ggml_cann_type_mapping(src1->type), ggml_type_size(src1->type), + mrope_position_ne, mrope_position_nb, 2); + + // selected position tensor's shape is a transpose of cache tensor. + int64_t selected_position_ne[] = { position_length, theta_scale_length }; + size_t selected_position_nb[] = { sizeof(float), position_length * sizeof(float) }; + mrope_position_acllocator.alloc(theta_scale_length * position_length * sizeof(float)); + void * mrope_position_buffer = mrope_position_acllocator.get(); + acl_position_tensor = + ggml_cann_create_tensor(mrope_position_buffer, ggml_cann_type_mapping(src1->type), + ggml_type_size(src1->type), selected_position_ne, selected_position_nb, 2); + GGML_CANN_CALL_ACLNN_OP(ctx, IndexSelect, mrope_position.get(), 0, position_select_index_tensor.get(), + acl_position_tensor.get()); + + // transpose + int64_t transposed_ne[] = { position_length, 1, theta_scale_length, 1 }; + size_t transposed_nb[GGML_MAX_DIMS]; + transposed_nb[0] = sizeof(float); + for (int i = 1; i < GGML_MAX_DIMS; i++) { + transposed_nb[i] = transposed_nb[i - 1] * transposed_ne[i - 1]; + } + + std::swap(transposed_ne[0], transposed_ne[2]); + std::swap(transposed_nb[0], transposed_nb[2]); + + acl_position_tensor = + ggml_cann_create_tensor(mrope_position_buffer, ggml_cann_type_mapping(src1->type), + ggml_type_size(src1->type), transposed_ne, transposed_nb, GGML_MAX_DIMS); + + } else { + // auto bcast. + acl_position_tensor = + ggml_cann_create_tensor(src1->data, ggml_cann_type_mapping(src1->type), ggml_type_size(src1->type), + position_ne, position_nb, GGML_MAX_DIMS); + } + + // Step4: multiply by the position + int64_t theta_length = theta_scale_length * position_length; + ggml_cann_pool_alloc theta_allocator(ctx.pool(), theta_length * sizeof(float)); + void * theta_buffer = theta_allocator.get(); + + acl_tensor_ptr acl_theta_tensor = + ggml_cann_create_tensor(theta_buffer, ACL_FLOAT, sizeof(float), cache_ne, cache_nb, GGML_MAX_DIMS); + aclnn_mul(ctx, acl_position_tensor.get(), acl_theta_scale_tensor.get(), acl_theta_tensor.get()); + + // Step5: calculate sin cos. 
// init sin_repeat && cos_repeat, only to accelerate first layer on each device if (position_length > ctx.rope_cache.position_length) { ctx.rope_cache.position_length = position_length; @@ -2381,44 +2567,30 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx, aclrtMalloc(&ctx.rope_cache.cos_cache, repeat_theta_length * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST)); } - // position - acl_tensor_ptr acl_position_tensor = - ggml_cann_create_tensor(src1->data, ggml_cann_type_mapping(src1->type), ggml_type_size(src1->type), position_ne, - position_nb, GGML_MAX_DIMS); - - // power * position - int64_t theta_length = theta_scale_length * position_length; - ggml_cann_pool_alloc theta_allocator(ctx.pool(), theta_length * sizeof(float)); - void * theta_buffer = theta_allocator.get(); - - acl_tensor_ptr acl_theta_tensor = - ggml_cann_create_tensor(theta_buffer, ACL_FLOAT, sizeof(float), theta_ne, theta_nb, GGML_MAX_DIMS); - aclnn_mul(ctx, acl_position_tensor.get(), acl_theta_scale_tensor.get(), acl_theta_tensor.get()); - // sin/cos ggml_cann_pool_alloc sin_allocator(ctx.pool(), theta_length * sizeof(float)); void * sin_buffer = sin_allocator.get(); acl_tensor_ptr acl_sin_tensor = - ggml_cann_create_tensor(sin_buffer, ACL_FLOAT, sizeof(float), theta_ne, theta_nb, GGML_MAX_DIMS, ACL_FORMAT_ND); + ggml_cann_create_tensor(sin_buffer, ACL_FLOAT, sizeof(float), cache_ne, cache_nb, GGML_MAX_DIMS, ACL_FORMAT_ND); aclnn_sin(ctx, acl_theta_tensor.get(), acl_sin_tensor.get()); ggml_cann_pool_alloc cos_allocator(ctx.pool(), theta_length * sizeof(float)); void * cos_buffer = cos_allocator.get(); acl_tensor_ptr acl_cos_tensor = - ggml_cann_create_tensor(cos_buffer, ACL_FLOAT, sizeof(float), theta_ne, theta_nb, GGML_MAX_DIMS, ACL_FORMAT_ND); + ggml_cann_create_tensor(cos_buffer, ACL_FLOAT, sizeof(float), cache_ne, cache_nb, GGML_MAX_DIMS, ACL_FORMAT_ND); aclnn_cos(ctx, acl_theta_tensor.get(), acl_cos_tensor.get()); if (ext_factor != 0) { attn_factor *= 1.0f + 0.1f * logf(1.0f / freq_scale); } - // attn_factor + // Step 5: multiply by attn_factor if (attn_factor != 1) { aclnn_muls(ctx, acl_sin_tensor.get(), attn_factor, nullptr, true); aclnn_muls(ctx, acl_cos_tensor.get(), attn_factor, nullptr, true); } - int64_t sin_reshape_ne[4] = { src0->ne[0], 1, src0->ne[2], 1 }; + int64_t sin_reshape_ne[4] = { src0->ne[0], 1, dst->ne[2], 1 }; size_t sin_reshape_nb[GGML_MAX_DIMS]; sin_reshape_nb[0] = sizeof(float); for (int i = 1; i < GGML_MAX_DIMS; i++) { @@ -2429,8 +2601,9 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx, acl_tensor_ptr acl_cos_repeat_tensor = ggml_cann_create_tensor(ctx.rope_cache.cos_cache, ACL_FLOAT, sizeof(float), sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS); - // repeat + // Step 6: repeat if (is_neox) { + // [sinθ1, sinθ1, sinθ2, sinθ2, ..., sinθn, sinθn] int64_t repeatsArray[] = { 1, 1, 1, 2 }; aclnn_repeat(ctx, acl_sin_tensor.get(), acl_sin_repeat_tensor.get(), repeatsArray); aclnn_repeat(ctx, acl_cos_tensor.get(), acl_cos_repeat_tensor.get(), repeatsArray); @@ -2438,17 +2611,15 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx, int64_t num_repeats = 2; int64_t dim = 3; int64_t output_size = theta_scale_length * num_repeats; + // [sinθ1, sinθ2, ..., sinθn, sinθ1, sinθ2, ..., sinθn] aclnn_repeat_interleave(ctx, acl_sin_tensor.get(), acl_sin_repeat_tensor.get(), dim, num_repeats, output_size); aclnn_repeat_interleave(ctx, acl_cos_tensor.get(), acl_cos_repeat_tensor.get(), dim, num_repeats, output_size); } - // Other layers use cache except first layer. 
- ctx.rope_cache.cached = true; - ctx.rope_cache.ext_factor = ext_factor; - ctx.rope_cache.theta_scale = theta_scale; - ctx.rope_cache.freq_scale = freq_scale; - ctx.rope_cache.attn_factor = attn_factor; - ctx.rope_cache.is_neox = is_neox; + // Update cached value. + ctx.rope_cache.cached = true; + ctx.rope_cache.set(theta_scale_length, position_length, ext_factor, theta_scale, freq_scale, attn_factor, is_neox, + indep_sects, mrope_used, is_imrope, sections); } #ifdef __cplusplus @@ -2474,6 +2645,7 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) { // param float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; + int sections[4]; // const int n_past = ((int32_t *) dst->op_params)[0]; const int n_dims = ((int32_t *) dst->op_params)[1]; const int mode = ((int32_t *) dst->op_params)[2]; @@ -2482,12 +2654,13 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) { GGML_TENSOR_UNARY_OP_LOCALS - memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); - memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); - memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); - memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); - memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); - memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); + memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); + memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); + memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); + memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); + memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); + memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); + memcpy(§ions, (int32_t *) dst->op_params + 11, sizeof(int)*4); // TODO: n_dims <= ne0 GGML_ASSERT(n_dims == ne0); @@ -2498,10 +2671,25 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) { float corr_dims[2]; ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); - const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; + bool is_neox = mode & GGML_ROPE_TYPE_NEOX; + const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE; // qwen3vl apply interleaved mrope + const bool mrope_used = mode & GGML_ROPE_TYPE_MROPE; // ggml_rope_multi, note: also true for vision (24 & 8 == true) and for imrope + const bool is_vision = mode == GGML_ROPE_TYPE_VISION; + + if (mrope_used) { + GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0); + } + + if (is_vision) { + GGML_ASSERT(n_dims == ne0/2); + } + + if (is_imrope || mrope_used) { + is_neox = true; + } // init ctx.rope_cos/rope_sin cache - aclnn_cache_init(ctx, dst, corr_dims, ext_factor, theta_scale, freq_scale, attn_factor, is_neox); + aclnn_rope_cache_init(ctx, dst, corr_dims, ext_factor, theta_scale, freq_scale, attn_factor, is_neox, sections, mrope_used, is_imrope, is_vision); int64_t sin_reshape_ne[4] = { ne00, 1, ne02, 1 }; size_t sin_reshape_nb[GGML_MAX_DIMS]; @@ -2657,8 +2845,7 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) { return; #endif - // ggml_mode = 0 --> aclnn_model = 1 - int64_t acl_mode = mode == 0 ? 1 : mode; + int64_t acl_mode = is_neox ? 
0 : 1; switch (src0->type) { case GGML_TYPE_F32: @@ -3236,3 +3423,64 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context & ctx, ggml_tensor * dst GGML_ABORT("Function is not implemented."); } } + +static void ggml_cann_out_prod_fp(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + ggml_tensor * src0 = dst->src[0]; // weight + ggml_tensor * src1 = dst->src[1]; // input + GGML_TENSOR_BINARY_OP_LOCALS + + acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, acl_dst.get()); + + const int64_t dps2 = ne2 / ne02; + const int64_t dps3 = ne3 / ne03; + for (int64_t i3 = 0; i3 < ne3; i3++) { + for (int64_t i2 = 0; i2 < ne2; i2++) { + const int64_t i02 = i2 / dps2; + const int64_t i03 = i3 / dps3; + + const int64_t i12 = i2; + const int64_t i13 = i3; + acl_tensor_ptr accumulator = + ggml_cann_create_tensor((char *) dst->data + i2 * nb2 + i3 * nb3, ggml_cann_type_mapping(dst->type), + ggml_type_size(dst->type), dst->ne, dst->nb, 2); + + // The outer product needs to be accumulated in this dimension. + for (int64_t i1 = 0; i1 < ne11; i1++) { + acl_tensor_ptr acl_input = ggml_cann_create_tensor( + (char *) src1->data + i1 * nb11 + i12 * nb12 + i13 * nb13, ggml_cann_type_mapping(src0->type), + ggml_type_size(src0->type), src1->ne, src1->nb, 1); + + acl_tensor_ptr acl_weight = ggml_cann_create_tensor( + (char *) src0->data + i1 * nb01 + i02 * nb02 + i03 * nb03, ggml_cann_type_mapping(src0->type), + ggml_type_size(src0->type), src0->ne, src0->nb, 1); + + ggml_cann_pool_alloc output_allocator(ctx.pool()); + void * output_buffer = output_allocator.alloc(ggml_nbytes(dst)); + acl_tensor_ptr acl_out = ggml_cann_create_tensor(output_buffer, ggml_cann_type_mapping(dst->type), + ggml_type_size(dst->type), dst->ne, dst->nb, 2); + + GGML_CANN_CALL_ACLNN_OP(ctx, Ger, acl_input.get(), acl_weight.get(), acl_out.get()); + float alpha_value = 1.0f; + aclScalar * alpha = aclCreateScalar(&alpha_value, ACL_FLOAT); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdd, accumulator.get(), acl_out.get(), alpha); + } + } + } +} + +void ggml_cann_out_prod(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + ggml_tensor * src0 = dst->src[0]; + + const enum ggml_type type = src0->type; + + switch (type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + ggml_cann_out_prod_fp(ctx, dst); + break; + default: + GGML_ABORT("Unsupport type for GGML_OP_OUT_PROD"); + break; + } +} diff --git a/ggml/src/ggml-cann/aclnn_ops.h b/ggml/src/ggml-cann/aclnn_ops.h index a6c2eb1226..1ebbc769c7 100644 --- a/ggml/src/ggml-cann/aclnn_ops.h +++ b/ggml/src/ggml-cann/aclnn_ops.h @@ -1125,3 +1125,23 @@ void ggml_cann_op_unary_gated(std::functionsrc[0]` and `dst->src[1]`. 
+ * + * @see GGML_CANN_CALL_ACLNN_OP for CANN operator invocation + */ +void ggml_cann_out_prod(ggml_backend_cann_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cann/common.h b/ggml/src/ggml-cann/common.h index d4ef24eaa7..b17445bb9a 100644 --- a/ggml/src/ggml-cann/common.h +++ b/ggml/src/ggml-cann/common.h @@ -300,30 +300,92 @@ struct ggml_cann_graph_lru_cache { struct ggml_cann_rope_cache { ~ggml_cann_rope_cache() { - if (theta_scale_cache != nullptr) { + if (theta_scale_cache) { ACL_CHECK(aclrtFree(theta_scale_cache)); } - if (sin_cache != nullptr) { + if (sin_cache) { ACL_CHECK(aclrtFree(sin_cache)); } - if (cos_cache != nullptr) { + if (cos_cache) { ACL_CHECK(aclrtFree(cos_cache)); } + if (position_select_index) { + ACL_CHECK(aclrtFree(position_select_index)); + } + if (theta_scale_exp_host) { + free(theta_scale_exp_host); + } + if(position_select_index_host) { + free(position_select_index_host); + } } - void * theta_scale_cache = nullptr; - int64_t theta_scale_length = 0; + bool equal(int64_t theta_scale_length, + int64_t position_length, + float ext_factor, + float theta_scale, + float freq_scale, + float attn_factor, + bool is_neox, + bool indep_sects, + bool mrope_used, + bool is_imrope, + int sections[4]) { + return this->theta_scale_length == theta_scale_length && this->position_length == position_length && + this->ext_factor == ext_factor && this->theta_scale == theta_scale && this->freq_scale == freq_scale && + this->attn_factor == attn_factor && this->is_neox == is_neox && this->indep_sects == indep_sects && + this->mrope_used == mrope_used && this->is_imrope == is_imrope && this->sections[0] == sections[0] && + this->sections[1] == sections[1] && this->sections[2] == sections[2] && this->sections[3] == sections[3]; + } + + void set(int64_t theta_scale_length, + int64_t position_length, + float ext_factor, + float theta_scale, + float freq_scale, + float attn_factor, + bool is_neox, + bool indep_sects, + bool mrope_used, + bool is_imrope, + int sections[4]) { + this->theta_scale_length = theta_scale_length; + this->position_length = position_length; + this->ext_factor = ext_factor; + this->theta_scale = theta_scale; + this->freq_scale = freq_scale; + this->attn_factor = attn_factor; + this->is_neox = is_neox; + this->indep_sects = indep_sects; + this->mrope_used = mrope_used; + this->is_imrope = is_imrope; + this->sections[0] = sections[0]; + this->sections[1] = sections[1]; + this->sections[2] = sections[2]; + this->sections[3] = sections[3]; + } + + // memory cache, prepare before inferencing. 
+ void * theta_scale_cache = nullptr; + float * theta_scale_exp_host = nullptr; + int * position_select_index_host = nullptr; + void * position_select_index = nullptr; // sin/cos cache, used only to accelerate first layer on each device - void * sin_cache = nullptr; - void * cos_cache = nullptr; - int64_t position_length = 0; + void * sin_cache = nullptr; + void * cos_cache = nullptr; // Properties to check before reusing the sincos cache - bool cached = false; - float ext_factor = 0.0f; - float theta_scale = 0.0f; - float freq_scale = 0.0f; - float attn_factor = 0.0f; - bool is_neox = false; + int64_t theta_scale_length = 0; + int64_t position_length = 0; + bool cached = false; + float ext_factor = 0.0f; + float theta_scale = 0.0f; + float freq_scale = 0.0f; + float attn_factor = 0.0f; + bool is_neox = false; + bool indep_sects = false; + bool mrope_used = false; + int sections[4] = { 0, 0, 0, 0 }; + bool is_imrope = false; }; struct ggml_cann_tensor_cache { diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index 3c67c48ffa..df28d67fb0 100644 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -1886,6 +1886,9 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct gg case GGML_OP_FLASH_ATTN_EXT: ggml_cann_flash_attn_ext(ctx, dst); break; + case GGML_OP_OUT_PROD: + ggml_cann_out_prod(ctx, dst); + break; default: return false; } @@ -2477,13 +2480,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten return false; } - const int mode = ((const int32_t *) op->op_params)[2]; - if (mode & GGML_ROPE_TYPE_MROPE) { - return false; - } - if (mode & GGML_ROPE_TYPE_VISION) { - return false; - } if (op->src[0]->ne[0] > 896) { return false; } @@ -2563,6 +2559,16 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten case GGML_OP_PAD_REFLECT_1D: case GGML_OP_COUNT_EQUAL: return true; + case GGML_OP_OUT_PROD: + { + switch (op->src[0]->type) { + case GGML_TYPE_F16: + case GGML_TYPE_F32: + return true; + default: + return false; + } + } case GGML_OP_CONV_TRANSPOSE_1D: // TODO: ((weightL - 1) * dilationW - padLeft)=1336 should not be larger than 255. 
return (op->src[0]->ne[0] - 1) <= 255; diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt index feb5617386..7e53a57b7b 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -224,7 +224,8 @@ function(ggml_add_cpu_backend_variant_impl tag_name) include(CheckCXXSourceCompiles) set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS}) - set(CMAKE_REQUIRED_FLAGS "${ARCH_FLAGS}") + string(REPLACE ";" " " ARCH_FLAGS_STR "${ARCH_FLAGS}") + set(CMAKE_REQUIRED_FLAGS "${ARCH_FLAGS_STR}") foreach(feature DOTPROD SVE MATMUL_INT8 FMA FP16_VECTOR_ARITHMETIC SME) set(ARM_FEATURE "HAVE_${feature}") check_cxx_source_compiles( diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index c7348cc26c..3247af8bb0 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -1927,6 +1927,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_argsort(params, tensor); } break; + case GGML_OP_TOP_K: + { + ggml_compute_forward_top_k(params, tensor); + } break; case GGML_OP_LEAKY_RELU: { ggml_compute_forward_leaky_relu(params, tensor); @@ -2311,6 +2315,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_OP_ARANGE: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_ARGSORT: + case GGML_OP_TOP_K: case GGML_OP_FLASH_ATTN_EXT: case GGML_OP_FLASH_ATTN_BACK: case GGML_OP_SSM_CONV: @@ -2834,6 +2839,10 @@ struct ggml_cplan ggml_graph_plan( cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02*ne03; cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12; } break; + case GGML_OP_TOP_K: + { + cur += sizeof(int32_t)*node->src[0]->ne[0]*n_tasks; + } break; case GGML_OP_FLASH_ATTN_EXT: { const int64_t ne10 = node->src[1]->ne[0]; // DK diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 41e89d83c2..d405696539 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -7794,7 +7794,7 @@ void ggml_compute_forward_timestep_embedding( // ggml_compute_forward_argsort template -struct argsort_cmp { +struct cmp_argsort { const float * data; bool operator()(int32_t a, int32_t b) const { if constexpr (order == GGML_SORT_ORDER_ASC) { @@ -7833,11 +7833,11 @@ static void ggml_compute_forward_argsort_f32( switch (order) { case GGML_SORT_ORDER_ASC: - std::sort(dst_data, dst_data + ne0, argsort_cmp{src_data}); + std::sort(dst_data, dst_data + ne0, cmp_argsort{src_data}); break; case GGML_SORT_ORDER_DESC: - std::sort(dst_data, dst_data + ne0, argsort_cmp{src_data}); + std::sort(dst_data, dst_data + ne0, cmp_argsort{src_data}); break; default: @@ -7864,6 +7864,72 @@ void ggml_compute_forward_argsort( } } +// ggml_compute_forward_top_k + +struct cmp_top_k { + const float * data; + bool operator()(int32_t a, int32_t b) const { + return data[a] > data[b]; + } +}; + +static void ggml_compute_forward_top_k_f32( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + GGML_TENSOR_UNARY_OP_LOCALS + + GGML_ASSERT(nb0 == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t nr = ggml_nrows(src0); + + const int top_k = ne0; + + int32_t * tmp = (int32_t *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith; + + for (int64_t i = ith; i < nr; i += nth) { + const float * src_data = (float *)((char *) src0->data + i*nb01); + + for (int64_t j = 0; j < ne00; j++) { + tmp[j] = j; + } + + std::partial_sort(tmp, tmp + top_k, tmp + ne00, cmp_top_k{src_data}); + + int32_t * dst_data = 
(int32_t *)((char *) dst->data + i*nb1); + + std::copy(tmp, tmp + top_k, dst_data); + + // emphasize that the order is not important + if (top_k > 1) { + std::swap(dst_data[0], dst_data[1]); + } + } +} + +void ggml_compute_forward_top_k( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_top_k_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + // ggml_compute_forward_flash_attn_ext static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( diff --git a/ggml/src/ggml-cpu/ops.h b/ggml/src/ggml-cpu/ops.h index 98a0eec16d..0fdfee7976 100644 --- a/ggml/src/ggml-cpu/ops.h +++ b/ggml/src/ggml-cpu/ops.h @@ -81,6 +81,7 @@ void ggml_compute_forward_roll(const struct ggml_compute_params * params, struct void ggml_compute_forward_arange(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_timestep_embedding(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_argsort(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_top_k(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_leaky_relu(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_fill(const struct ggml_compute_params * params, struct ggml_tensor * dst); diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h index e0e9540433..bd80805fdc 100644 --- a/ggml/src/ggml-cpu/vec.h +++ b/ggml/src/ggml-cpu/vec.h @@ -397,119 +397,118 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const } inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y, const ggml_fp16_t * GGML_RESTRICT x, const float v) { -#if defined(GGML_SIMD) - #if defined(__ARM_FEATURE_SVE) - const int sve_register_length = svcntb() * 8; - const int ggml_f16_epr = sve_register_length / 16; - const int ggml_f16_step = 8 * ggml_f16_epr; +#if defined(GGML_SIMD) && defined(__ARM_FEATURE_SVE) + const int sve_register_length = svcntb() * 8; + const int ggml_f16_epr = sve_register_length / 16; + const int ggml_f16_step = 8 * ggml_f16_epr; - GGML_F16x_VEC vx = GGML_F16x_VEC_SET1(v); + GGML_F16x_VEC vx = GGML_F16x_VEC_SET1(v); - const int np= (n & ~(ggml_f16_step - 1)); + int np = (n & ~(ggml_f16_step - 1)); - svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8; - svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8; - for (int i = 0; i < np; i += ggml_f16_step) { - ax1 = GGML_F16x_VEC_LOAD(x + i + 0 * ggml_f16_epr, 0); - ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0); - ay1 = GGML_F16x_VEC_FMA(ay1, ax1, vx); + svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8; + svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8; + for (int i = 0; i < np; i += ggml_f16_step) { + ax1 = GGML_F16x_VEC_LOAD(x + i + 0 * ggml_f16_epr, 0); + ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0); + ay1 = GGML_F16x_VEC_FMA(ay1, ax1, vx); - GGML_F16x_VEC_STORE(y + i + 0 * ggml_f16_epr, ay1, 0); + GGML_F16x_VEC_STORE(y + i + 0 * ggml_f16_epr, ay1, 0); - ax2 = GGML_F16x_VEC_LOAD(x + i + 1 * ggml_f16_epr, 1); - ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1); - ay2 = GGML_F16x_VEC_FMA(ay2, ax2, vx); + ax2 = GGML_F16x_VEC_LOAD(x + i + 1 * ggml_f16_epr, 1); + ay2 = GGML_F16x_VEC_LOAD(y 
+ i + 1 * ggml_f16_epr, 1); + ay2 = GGML_F16x_VEC_FMA(ay2, ax2, vx); - GGML_F16x_VEC_STORE(y + i + 1 * ggml_f16_epr, ay2, 1); + GGML_F16x_VEC_STORE(y + i + 1 * ggml_f16_epr, ay2, 1); - ax3 = GGML_F16x_VEC_LOAD(x + i + 2 * ggml_f16_epr, 2); - ay3 = GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2); - ay3 = GGML_F16x_VEC_FMA(ay3, ax3, vx); + ax3 = GGML_F16x_VEC_LOAD(x + i + 2 * ggml_f16_epr, 2); + ay3 = GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2); + ay3 = GGML_F16x_VEC_FMA(ay3, ax3, vx); - GGML_F16x_VEC_STORE(y + i + 2 * ggml_f16_epr, ay3, 2); + GGML_F16x_VEC_STORE(y + i + 2 * ggml_f16_epr, ay3, 2); - ax4 = GGML_F16x_VEC_LOAD(x + i + 3 * ggml_f16_epr, 3); - ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3); - ay4 = GGML_F16x_VEC_FMA(ay4, ax4, vx); + ax4 = GGML_F16x_VEC_LOAD(x + i + 3 * ggml_f16_epr, 3); + ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3); + ay4 = GGML_F16x_VEC_FMA(ay4, ax4, vx); - GGML_F16x_VEC_STORE(y + i + 3 * ggml_f16_epr, ay4, 3); + GGML_F16x_VEC_STORE(y + i + 3 * ggml_f16_epr, ay4, 3); - ax5 = GGML_F16x_VEC_LOAD(x + i + 4 * ggml_f16_epr, 4); - ay5 = GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4); - ay5 = GGML_F16x_VEC_FMA(ay5, ax5, vx); + ax5 = GGML_F16x_VEC_LOAD(x + i + 4 * ggml_f16_epr, 4); + ay5 = GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4); + ay5 = GGML_F16x_VEC_FMA(ay5, ax5, vx); - GGML_F16x_VEC_STORE(y + i + 4 * ggml_f16_epr, ay5, 4); + GGML_F16x_VEC_STORE(y + i + 4 * ggml_f16_epr, ay5, 4); - ax6 = GGML_F16x_VEC_LOAD(x + i + 5 * ggml_f16_epr, 5); - ay6 = GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5); - ay6 = GGML_F16x_VEC_FMA(ay6, ax6, vx); + ax6 = GGML_F16x_VEC_LOAD(x + i + 5 * ggml_f16_epr, 5); + ay6 = GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5); + ay6 = GGML_F16x_VEC_FMA(ay6, ax6, vx); - GGML_F16x_VEC_STORE(y + i + 5 * ggml_f16_epr, ay6, 5); + GGML_F16x_VEC_STORE(y + i + 5 * ggml_f16_epr, ay6, 5); - ax7 = GGML_F16x_VEC_LOAD(x + i + 6 * ggml_f16_epr, 6); - ay7 = GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6); - ay7 = GGML_F16x_VEC_FMA(ay7, ax7, vx); + ax7 = GGML_F16x_VEC_LOAD(x + i + 6 * ggml_f16_epr, 6); + ay7 = GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6); + ay7 = GGML_F16x_VEC_FMA(ay7, ax7, vx); - GGML_F16x_VEC_STORE(y + i + 6 * ggml_f16_epr, ay7, 6); + GGML_F16x_VEC_STORE(y + i + 6 * ggml_f16_epr, ay7, 6); - ax8 = GGML_F16x_VEC_LOAD(x + i + 7 * ggml_f16_epr, 7); - ay8 = GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7); - ay8 = GGML_F16x_VEC_FMA(ay8, ax8, vx); + ax8 = GGML_F16x_VEC_LOAD(x + i + 7 * ggml_f16_epr, 7); + ay8 = GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7); + ay8 = GGML_F16x_VEC_FMA(ay8, ax8, vx); - GGML_F16x_VEC_STORE(y + i + 7 * ggml_f16_epr, ay8, 7); + GGML_F16x_VEC_STORE(y + i + 7 * ggml_f16_epr, ay8, 7); + } + const int np2 = (n & ~(ggml_f16_epr - 1)); + for (int k = np; k < np2; k += ggml_f16_epr) { + svfloat16_t rx = GGML_F16x_VEC_LOAD(x + k, 0); + svfloat16_t ry = GGML_F16x_VEC_LOAD(y + k, 0); + ry = GGML_F16x_VEC_FMA(ry, rx, vx); + + GGML_F16x_VEC_STORE(y + k, ry, 0); + } + + if (np2 < n) { + svbool_t pg = svwhilelt_b16(np2, n); + svfloat16_t hx = svld1_f16(pg, (const __fp16 *)(x + np2)); + svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2)); + hy = svmad_f16_x(pg, hx, vx, hy); + svst1_f16(pg, (__fp16 *)(y + np2), hy); + } + np = n; +#elif defined(__riscv_zvfh) // implies __riscv_v_intrinsic + const int np = n; + _Float16 hv = (_Float16)v; + for (int i = 0, avl; i < n; i += avl) { + avl = __riscv_vsetvl_e16m8(n - i); + vfloat16m8_t ax = __riscv_vle16_v_f16m8((const _Float16 *)&x[i], avl); + vfloat16m8_t ay 
= __riscv_vle16_v_f16m8((_Float16 *)&y[i], avl); + vfloat16m8_t ny = __riscv_vfmadd_vf_f16m8(ax, hv, ay, avl); + __riscv_vse16_v_f16m8((_Float16 *)&y[i], ny, avl); + } +#elif defined(GGML_SIMD) + const int np = (n & ~(GGML_F16_STEP - 1)); + + GGML_F16_VEC vx = GGML_F16_VEC_SET1(v); + + GGML_F16_VEC ax[GGML_F16_ARR]; + GGML_F16_VEC ay[GGML_F16_ARR]; + + for (int i = 0; i < np; i += GGML_F16_STEP) { + for (int j = 0; j < GGML_F16_ARR; j++) { + ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j); + ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); + ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx); + + GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j); } - const int np2 = (n & ~(ggml_f16_epr - 1)); - for (int k = np; k < np2; k += ggml_f16_epr) { - svfloat16_t rx = GGML_F16x_VEC_LOAD(x + k, 0); - svfloat16_t ry = GGML_F16x_VEC_LOAD(y + k, 0); - ry = GGML_F16x_VEC_FMA(ry, rx, vx); - - GGML_F16x_VEC_STORE(y + k, ry, 0); - } - - if (np2 < n) { - svbool_t pg = svwhilelt_b16(np2, n); - svfloat16_t hx = svld1_f16(pg, (const __fp16 *)(x + np2)); - svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2)); - hy = svmad_f16_x(pg, hx, vx, hy); - svst1_f16(pg, (__fp16 *)(y + np2), hy); - } - - #elif defined(__riscv_v_intrinsic) - // todo: RVV impl - // scalar - for (int i = 0; i < n; ++i) { - y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v); - } - #else - const int np = (n & ~(GGML_F16_STEP - 1)); - - GGML_F16_VEC vx = GGML_F16_VEC_SET1(v); - - GGML_F16_VEC ax[GGML_F16_ARR]; - GGML_F16_VEC ay[GGML_F16_ARR]; - - for (int i = 0; i < np; i += GGML_F16_STEP) { - for (int j = 0; j < GGML_F16_ARR; j++) { - ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j); - ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); - ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx); - - GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j); - } - } - - // leftovers - for (int i = np; i < n; ++i) { - y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v); - } - #endif + } #else - // scalar - for (int i = 0; i < n; ++i) { + const int np = 0; +#endif + + // leftovers + for (int i = np; i < n; ++i) { y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v); } -#endif } // xs and vs are byte strides of x and v diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh index c3c4b77996..c0a9c2c08a 100644 --- a/ggml/src/ggml-cuda/mma.cuh +++ b/ggml/src/ggml-cuda/mma.cuh @@ -73,34 +73,7 @@ namespace ggml_cuda_mma { static constexpr int I = I_; static constexpr int J = J_; -#if defined(GGML_USE_HIP) -#if defined(RDNA4) - static constexpr int ne = I * J / 32; - T x[ne] = {0}; - - static constexpr __device__ bool supported() { - if (I == 16 && J == 16) return true; - return false; - } - - static __device__ __forceinline__ int get_i(const int l) { - if constexpr (I == 16 && J == 16) { - return 8 * (threadIdx.x / 16) + l; - } else { - NO_DEVICE_CODE; - return -1; - } - } - - static __device__ __forceinline__ int get_j(const int l) { - if constexpr (I == 16 && J == 16) { - return threadIdx.x % 16; - } else { - NO_DEVICE_CODE; - return -1; - } - } -#else +#if defined(AMD_MFMA_AVAILABLE) static constexpr int ne = I * J / 64; T x[ne] = {0}; @@ -146,7 +119,6 @@ namespace ggml_cuda_mma { return -1; } } -#endif // defined(RDNA4) #elif __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA static constexpr int ne = I * J / 32; T x[ne] = {0}; @@ -177,6 +149,34 @@ namespace ggml_cuda_mma { return -1; } } +#elif defined(AMD_WMMA_AVAILABLE) +#if defined(RDNA4) + 
static constexpr int ne = I * J / 32; + T x[ne] = {0}; + + static constexpr __device__ bool supported() { + if (I == 16 && J == 16) return true; + return false; + } + + static __device__ __forceinline__ int get_i(const int l) { + if constexpr (I == 16 && J == 16) { + return 8 * (threadIdx.x / 16) + l; + } else { + NO_DEVICE_CODE; + return -1; + } + } + + static __device__ __forceinline__ int get_j(const int l) { + if constexpr (I == 16 && J == 16) { + return threadIdx.x % 16; + } else { + NO_DEVICE_CODE; + return -1; + } + } +#endif #else static constexpr int ne = I * J / 32; T x[ne] = {0}; @@ -437,7 +437,29 @@ namespace ggml_cuda_mma { xi[0] = xs[0]; } #elif defined(AMD_WMMA_AVAILABLE) - ggml_cuda_memcpy_1(t.x, xs0 + t.get_i(0) * stride + t.get_j(0)); + if constexpr (std::is_same_v || std::is_same_v) { + ggml_cuda_memcpy_1(t.x, xs0 + t.get_i(0) * stride + t.get_j(0)); + + } else if constexpr (std::is_same_v) { + if constexpr (I == 16 && J == 4) { + int64_t * xi = (int64_t *) t.x; + const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I)); + xi[0] = xs[0]; + + }else if constexpr (I == 16 && J == 8) { + int64_t * xi = (int64_t *) t.x; + const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I)); + xi[0] = xs[0]; + + const int64_t * xs1 = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I) + 2); + xi[1] = xs1[0]; + + }else{ + NO_DEVICE_CODE; + } + } else { + NO_DEVICE_CODE; + } #else #pragma unroll for (int l = 0; l < t.ne; ++l) { @@ -772,6 +794,36 @@ namespace ggml_cuda_mma { acc[0], 0, 0, 0); #endif // defined(CDNA3) + +#elif defined(AMD_WMMA_AVAILABLE) + using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int; + int32x2_t * a_vec = (int32x2_t *) A.x; + int32x2_t * b_vec = (int32x2_t *) B.x; + + using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int; + int32x8_t * acc = (int32x8_t *) D.x; + +#if defined(RDNA4) + + acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12( + true, + a_vec[0], + true, + b_vec[0], + acc[0], + true + ); + + acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12( + true, + a_vec[1], + true, + b_vec[1], + acc[0], + true + ); +#endif // defined(RDNA4) + #else GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; @@ -798,6 +850,7 @@ namespace ggml_cuda_mma { acc[0], 0, 0, 0); #endif // defined(CDNA3) + #else GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; @@ -842,4 +895,31 @@ namespace ggml_cuda_mma { mma(D16[1], A16[1], B); #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE } + +static __device__ __forceinline__ void mma( + tile<16, 16, int> & D, const tile<16, 4, int> & A, const tile<16, 4, int> & B) { +#if defined(AMD_WMMA_AVAILABLE) + using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int; + int32x2_t * a_vec = (int32x2_t *) A.x; + int32x2_t * b_vec = (int32x2_t *) B.x; + + using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int; + int32x8_t * acc = (int32x8_t *) D.x; + + acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12( + true, + a_vec[0], + true, + b_vec[0], + acc[0], + false + ); +#else + GGML_UNUSED(D); + GGML_UNUSED(A); + GGML_UNUSED(B); + NO_DEVICE_CODE; +#endif + } } + diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index a2c8760abe..03ceba874d 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -306,5 +306,11 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) { return false; } - return 
(!GGML_CUDA_CC_IS_RDNA4(cc) && !GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; + if (amd_wmma_available(cc)) { + if (GGML_CUDA_CC_IS_RDNA4(cc)) { + return true; + } + } + + return (!GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; } diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 2e133b6bda..82468b384e 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -92,7 +92,7 @@ struct tile_x_sizes { }; static int get_mmq_x_max_host(const int cc) { - return (amd_mfma_available(cc) || turing_mma_available(cc)) ? 128 : + return (amd_mfma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc)) ? 128 : GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ? #ifdef GGML_CUDA_FORCE_MMQ 128 : 64; @@ -102,7 +102,7 @@ static int get_mmq_x_max_host(const int cc) { } static constexpr __device__ int get_mmq_x_max_device() { -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) return 128; #else // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) @@ -121,7 +121,7 @@ static constexpr __device__ int get_mmq_x_max_device() { #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA #endif // defined(GGML_USE_HIP) -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } static int get_mmq_y_host(const int cc) { @@ -231,7 +231,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { #define MMQ_TILE_Y_K (MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI8_1) static int mmq_get_granularity_host(const int mmq_x, const int cc) { - if (amd_mfma_available(cc)) { + if (amd_mfma_available(cc) || amd_wmma_available(cc)) { return mmq_x >= 128 ? 32 : 16; } else if (turing_mma_available(cc) && mmq_x >= 48) { return 16; @@ -240,7 +240,7 @@ static int mmq_get_granularity_host(const int mmq_x, const int cc) { } } -#if defined(AMD_MFMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) { return mmq_x >= 128 ? 
32 : 16; } @@ -265,7 +265,7 @@ static int mmq_get_nwarps_host(const int /*cc*/, const int warp_size) { #endif // (GGML_USE_HIP) static constexpr __device__ int mmq_get_nwarps_device() { -#if defined(AMD_MFMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) return 8; #else return 256/ggml_cuda_get_physical_warp_size(); @@ -279,14 +279,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + 2*MMQ_TILE_NE_K); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_0); constexpr int nrows = warp_size / threads_per_row; @@ -305,7 +305,7 @@ template static __device__ __forceinline__ void loa const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx; const int qs0 = get_int_b2(bxi->qs, kqsx); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + 0] = __vsubss4((qs0 >> 0) & 0x0F0F0F0F, 0x08080808); x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + QI4_0] = __vsubss4((qs0 >> 4) & 0x0F0F0F0F, 0x08080808); #else @@ -327,11 +327,11 @@ template static __device__ __forceinline__ void loa const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; #else x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + kbxd] = bxi->d; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } @@ -382,14 +382,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y); int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_1); constexpr int nrows = warp_size / threads_per_row; @@ -408,12 +408,12 @@ template static __device__ __forceinline__ void loa const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx; const int qs0 = get_int_b4(bxi->qs, kqsx); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || 
defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + 0] = (qs0 >> 0) & 0x0F0F0F0F; x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + QI4_1] = (qs0 >> 4) & 0x0F0F0F0F; #else x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_1; @@ -430,11 +430,11 @@ template static __device__ __forceinline__ void loa const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm; #else x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + kbxd] = bxi->dm; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } @@ -485,14 +485,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_0); constexpr int nrows = warp_size / threads_per_row; @@ -527,13 +527,13 @@ template static __device__ __forceinline__ void loa qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 qs1 = __vsubss4(qs1, 0x10101010); // subtract 16 -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + 0] = qs0; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + 0] = qs0; x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_0; @@ -550,11 +550,11 @@ template static __device__ __forceinline__ void loa const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; #else x_df[i*(MMQ_TILE_NE_K/QI5_0) + i/QI5_0 + kbxd] = bxi->d; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } @@ -563,14 +563,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); 
constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y); int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_1); constexpr int nrows = warp_size / threads_per_row; @@ -603,13 +603,13 @@ template static __device__ __forceinline__ void loa qs1 |= (qh << 2) & 0x00100000; // 18 -> 20 qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + 0] = qs0; x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + 0] = qs0; x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_1; @@ -626,11 +626,11 @@ template static __device__ __forceinline__ void loa const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm; #else x_dm[i*(MMQ_TILE_NE_K/QI5_1) + i/QI5_1 + kbxd] = bxi->dm; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } @@ -639,14 +639,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_tile + 2*MMQ_TILE_NE_K); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) // MMQ_ITER_K / (4 * QR8_0) == 64 required. 
but NV has only 32 threads per warp constexpr int threads_per_row = 32; @@ -665,13 +665,13 @@ template static __device__ __forceinline__ void loa const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0 + txi] = get_int_b2(bxi[0].qs, kqsx); x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx); #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + 0 + txi] = get_int_b2(bxi[0].qs, kqsx); x_qs[i*(2*MMQ_TILE_NE_K + 1) + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } constexpr int blocks_per_tile_x_row = 2*MMQ_TILE_NE_K / QI8_0; @@ -688,11 +688,11 @@ template static __device__ __forceinline__ void loa const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; #else x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + kbxd] = bxi->d; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } @@ -701,14 +701,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_MXFP4, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR_MXFP4); constexpr int nrows = warp_size / threads_per_row; @@ -730,13 +730,13 @@ template static __device__ __forceinline__ void loa const int2 v = get_int_from_table_16(aux_q4, kvalues_mxfp4); const int k0 = kbx * (2 * QI_MXFP4) + kqsx; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + 0] = v.x; x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + QI_MXFP4] = v.y; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x; x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI_MXFP4] = v.y; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI_MXFP4; @@ -753,11 +753,11 @@ template static __device__ __forceinline__ void loa const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i*stride + kbxd; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) 
x_df[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f; #else x_df[i*(MMQ_TILE_NE_K/QI_MXFP4) + i/QI_MXFP4 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } @@ -796,7 +796,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a( template static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { -#if defined(AMD_MFMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) typedef tile<16, 8, int> tile_A; typedef tile<16, 8, int> tile_B; typedef tile<16, 16, int> tile_C; @@ -927,7 +927,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( } } } -#endif // defined(AMD_MFMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } template @@ -965,7 +965,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a( template static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { -#if defined(AMD_MFMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) typedef tile<16, 8, int> tile_A; typedef tile<16, 8, int> tile_B; typedef tile<16, 16, int> tile_C; @@ -1087,7 +1087,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( } } } -#endif // defined(AMD_MFMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } // Used for Q3_K, IQ2_S, and IQ2_XS @@ -1170,6 +1170,54 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( tile_C C; mma(C, A[n], B[0]); +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + const int i = i0 + n*tile_C::I + tile_C::get_i(l); + sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4] * dB; + } + } + } + } +#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles + typedef tile<16, 4, int> tile_A; + typedef tile<16, 4, int> tile_B; + typedef tile<16, 16, int> tile_C; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = granularity; + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. 
+ + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + + const int i0 = (threadIdx.y / ntx) * rows_per_warp; + + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) { + const int k0 = k00 + k01; + + tile_A A[ntx]; +#pragma unroll + for (int n = 0; n < ntx; ++n) { + load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K); + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { + tile_B B; + load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + + const int j = j0 + tile_C::get_j(0); + const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1]; + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + tile_C C; + mma(C, A[n], B); + #pragma unroll for (int l = 0; l < tile_C::ne; ++l) { const int i = i0 + n*tile_C::I + tile_C::get_i(l); @@ -1257,21 +1305,21 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( #else GGML_UNUSED_VARS(x, y, sum, k00); NO_DEVICE_CODE; -#endif // AMD_MFMA_AVAILABLE +#endif // AMD_MFMA_AVAILABLE || AMD_WMMA_AVAILABLE } template static __device__ __forceinline__ void load_tiles_q2_K( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { constexpr int nwarps = mmq_get_nwarps_device(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y); int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR2_K); constexpr int nrows = ggml_cuda_get_physical_warp_size() / threads_per_row; @@ -1295,11 +1343,11 @@ template static __device__ __forceinline__ void loa const int x_qs_k = (x_ql_0 >> (2*l)) & 0x03030303; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k] = x_qs_k; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } const int sc_m = bxi->scales[kqsx]; @@ -1310,11 +1358,11 @@ template static __device__ __forceinline__ void loa const half2 x_dm_ik = make_half2(bxi_dmf.x*(sc_m & 0x0F), bxi_dmf.y*(sc_m >> 4)); #endif // FAST_FP16_AVAILABLE -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + kqsx] = x_dm_ik; #else x_dm[i*(MMQ_TILE_NE_K + 1) + kqsx] = x_dm_ik; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } @@ -1438,6 +1486,72 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( tile_C Cd; mma(Cd, A[n], B[0]); +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + const 
int i = i0 + n*tile_C::I + tile_C::get_i(l); + const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/4]); + float tmp = Cd.x[l]*dm.x; + if (k01 >= MMQ_TILE_NE_K * 3/4) { + tmp -= Cm.x[l]*dm.y; + } + sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*dB; + sum[(j0/tile_C::J + n)*tile_C::ne + l] -= dm.y*sB; + } + } + } + } +#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles + + typedef tile<16, 4, int> tile_A; + typedef tile<16, 4, int> tile_B; + typedef tile<16, 16, int> tile_C; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = granularity; + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2; + const int * y_qs = (const int *) y + 4; + const half2 * y_ds = (const half2 *) y; + + const int i0 = (threadIdx.y / ntx) * rows_per_warp; + + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) { + const int k0 = k00 + k01; + + tile_A A[ntx]; +#pragma unroll + for (int n = 0; n < ntx; ++n) { + load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K); + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { + tile_B B; + load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + + const int j = j0 + tile_C::get_j(0); + const float dB = (k01 < MMQ_TILE_NE_K/2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K]).x : __half22float2(y_ds[j*MMQ_TILE_Y_K]).y; + const float sB = (k01 >= MMQ_TILE_NE_K * 3/4) ? 0 + : (((k01/4)%2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).y + : __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).x); + + tile_C Cm; + if (k01 >= MMQ_TILE_NE_K * 3/4) { + tile_A A1; + A1.x[0] = 0x01010101; + A1.x[1] = 0x01010101; + mma(Cm, A1, B); + } + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + tile_C Cd; + mma(Cd, A[n], B); + #pragma unroll for (int l = 0; l < tile_C::ne; ++l) { const int i = i0 + n*tile_C::I + tile_C::get_i(l); @@ -1574,7 +1688,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( #else GGML_UNUSED_VARS(x, y, sum, k00); NO_DEVICE_CODE; -#endif // AMD_MFMA_AVAILABLE +#endif // AMD_MFMA_AVAILABLE || AMD_WMMA_AVAILABLE } template static __device__ __forceinline__ void load_tiles_q3_K( @@ -1582,7 +1696,7 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else @@ -1618,11 +1732,11 @@ template static __device__ __forceinline__ void loa const int x_qs_k = __vsubss4(x_ql_k | x_qh_k, 0x04040404); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k] = x_qs_k; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } @@ -1649,7 +1763,7 @@ template static __device__ __forceinline__ void loa const int sc = __vsubss4(sc_low | sc_high, 
0x20202020); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) const int8_t * sc8 = (const int8_t *) &sc; const float d = bxi->d; @@ -1659,10 +1773,10 @@ template static __device__ __forceinline__ void loa } #else x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = sc; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } -#if !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)) +#if !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)) #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) { int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y; @@ -1675,7 +1789,7 @@ template static __device__ __forceinline__ void loa x_df[i] = bxi->d; } -#endif // !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)) +#endif // !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)) } template @@ -1728,7 +1842,7 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K); #else @@ -1736,7 +1850,7 @@ template static __device__ __forceinline__ void loa int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + txs.qs); int * x_sc = (int *) (x_dm + txs.dm); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_K); constexpr int nrows = warp_size / threads_per_row; @@ -1753,19 +1867,19 @@ template static __device__ __forceinline__ void loa const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride; const int qs0 = get_int_b4(bxi->qs, txi); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 0] = (qs0 >> 0) & 0x0F0F0F0F; x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 8] = (qs0 >> 4) & 0x0F0F0F0F; #else x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int rows_per_warp = warp_size / 2; #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) { -#if defined(AMD_MFMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) // Need if on AMD instead of % because warp_size == 64 // This causes double work and throughput loss (MI300X) // H100 loses about 100 t/s with 'if' condition over '%' int i = i0 + threadIdx.y*rows_per_warp + threadIdx.x/2; if (i < mmq_y) { #else int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/2) % mmq_y; { -#endif // defined(AMD_MFMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) ||
defined(AMD_WMMA_AVAILABLE) if (need_check) { i = min(i, i_max); } @@ -1829,7 +1943,7 @@ template static __device__ __forceinline__ void loa x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8; } -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } template @@ -1872,7 +1986,7 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + MMQ_TILE_NE_K*2); #else @@ -1908,16 +2022,16 @@ template static __device__ __forceinline__ void loa const int kq0 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + 0; const int kq1 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + QI5_K/4; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq0] = ql0 | qh0; x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq1] = ql1 | qh1; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = ql0 | qh0; x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = ql1 | qh1; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int rows_per_warp = warp_size / 2; #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) { @@ -1930,7 +2044,7 @@ template static __device__ __forceinline__ void loa #else int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/2) % mmq_y; { -#endif // defined(AMD_MFMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) if (need_check) { i = min(i, i_max); } @@ -1986,7 +2100,7 @@ template static __device__ __forceinline__ void loa x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8; } -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } template @@ -2029,7 +2143,7 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); int * x_sc = (int *) (x_df + MMQ_TILE_NE_K/QI6_K); @@ -2038,7 +2152,7 @@ template static __device__ __forceinline__ void loa int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); int * x_sc = (int *) (x_df + txs.dm); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR6_K); constexpr int nrows = warp_size / threads_per_row; @@ -2065,13 +2179,13 @@ template static __device__ __forceinline__ void loa const int kq0 = 2*txi - txi % (QI6_K/2) + 0; const int kq1 = 2*txi - txi % (QI6_K/2) + 
QI6_K/2; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq0] = __vsubss4(ql0 | qh0, 0x20202020); x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq1] = __vsubss4(ql1 | qh1, 0x20202020); #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020); x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } #pragma unroll @@ -2084,11 +2198,11 @@ template static __device__ __forceinline__ void loa const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q6_K] = bxi->d; #else x_df[i*(MMQ_TILE_NE_K/QI6_K) + i/QI6_K] = bxi->d; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } constexpr int rows_per_warp = warp_size / 4; @@ -2102,11 +2216,11 @@ template static __device__ __forceinline__ void loa const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (MMQ_TILE_NE_K/8)) / 4; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x%4] = get_int_b2(bxi->scales, threadIdx.x % (MMQ_TILE_NE_K/8)); #else x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + threadIdx.x%(MMQ_TILE_NE_K/8)] = get_int_b2(bxi->scales, threadIdx.x%(QI6_K/8)); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } @@ -2190,6 +2304,56 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( tile_C C; mma(C, A[n], B[0]); +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + const int i = i0 + n*tile_C::I + tile_C::get_i(l); + const int8_t * sc = (const int8_t *) (x_sc + i*MMQ_MMA_TILE_X_K_Q6_K + k00/16); + sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * sc[k01/4] * x_df[i*MMQ_MMA_TILE_X_K_Q6_K] * dB; + } + } + } + } +#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles + typedef tile<16, 4, int> tile_A; + typedef tile<16, 4, int> tile_B; + typedef tile<16, 16, int> tile_C; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = granularity; + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. 
+ + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2; + const int * x_sc = (const int *) x_df + MMQ_TILE_NE_K/QI6_K; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + + const int i0 = (threadIdx.y / ntx) * rows_per_warp; + + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) { + const int k0 = k00 + k01; + + tile_A A[ntx]; +#pragma unroll + for (int n = 0; n < ntx; ++n) { + load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K); + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { + tile_B B; + load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + + const int j = j0 + tile_C::get_j(0); + const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1]; + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + tile_C C; + mma(C, A[n], B); + #pragma unroll for (int l = 0; l < tile_C::ne; ++l) { const int i = i0 + n*tile_C::I + tile_C::get_i(l); @@ -2303,7 +2467,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( #else GGML_UNUSED_VARS(x, y, sum, k00); NO_DEVICE_CODE; -#endif // AMD_MFMA_AVAILABLE +#endif // AMD_MFMA_AVAILABLE || AMD_WMMA_AVAILABLE } template static __device__ __forceinline__ void load_tiles_iq4_nl( @@ -2311,14 +2475,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_NL, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_NL); constexpr int nrows = warp_size / threads_per_row; @@ -2340,13 +2504,13 @@ template static __device__ __forceinline__ void loa const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl); const int k0 = kbx * (2 * QI4_NL) + kqsx; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + QI4_NL] = v.y; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x; x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI4_NL] = v.y; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_NL; @@ -2363,11 +2527,11 @@ template static __device__ __forceinline__ void loa const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbxd; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = __half2float(bxi->d); #else x_df[i*(MMQ_TILE_NE_K/QI4_NL) + i/QI4_NL + kbxd] = __half2float(bxi->d); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // 
defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } @@ -2376,14 +2540,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_XXS, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XXS)) / 2; constexpr int nrows = warp_size / threads_per_row; @@ -2414,22 +2578,22 @@ template static __device__ __forceinline__ void loa const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000); const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid1; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid0; x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid1; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } const int ls = aux32 >> 28; const float d = bxi->d; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4; #else x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/4; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } @@ -2438,14 +2602,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16; int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XS)) / 2; constexpr int nrows = warp_size / threads_per_row; @@ -2472,24 +2636,24 @@ template static __device__ __forceinline__ void loa const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]); const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 
8*kqsx + (2*l + 0)] = grid_l; x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l; x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } const int ls = bxi->scales[kqsx]; const float d = bxi->d; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; #else x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } @@ -2498,15 +2662,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_S, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_S)) / 2; constexpr int nrows = warp_size / threads_per_row; const int kqsx = threadIdx.x % threads_per_row; @@ -2539,24 +2702,24 @@ template static __device__ __forceinline__ void loa const int grid_l = __vsub4(grid_pos[0] ^ signs0, signs0); const int grid_h = __vsub4(grid_pos[1] ^ signs1, signs1); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l; x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l; x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } const int ls = bxi->scales[kqsx]; const float d = bxi->d; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; #else x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } @@ -2565,14 +2728,14 @@ 
template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_XXS, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_XXS)) / 2; constexpr int nrows = warp_size / threads_per_row; @@ -2601,22 +2764,22 @@ template static __device__ __forceinline__ void loa const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]); const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid_h; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l; x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } const int ls = aux32 >> 28; const float d = bxi->d; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/2; #else x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/2; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } @@ -2625,14 +2788,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_S)) / 2; constexpr int nrows = warp_size / threads_per_row; @@ -2668,22 +2831,22 @@ template static __device__ __forceinline__ void loa const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0); const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+0)] = grid_l; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+1)] = grid_h; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + 
(2*l+0)] = grid_l; x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid_h; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } const int ls = 1 + 2*((bxi->scales[kqsx/2] >> (((2*kqsx) << 1) & 0x04)) & 0x0F); const float d = bxi->d; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = ls*d; #else x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = ls*d; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } @@ -2692,14 +2855,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; half2 * x_ds = (half2 *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y); int * x_qs = (int *) x_tile; half2 * x_ds = (half2 *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR1_S); constexpr int nrows = warp_size / threads_per_row; @@ -2727,23 +2890,23 @@ template static __device__ __forceinline__ void loa const int grid0 = (grid >> 0) & 0x0F0F0F0F; const int grid1 = (grid >> 4) & 0x0F0F0F0F; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+0)] = grid0; x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+1)] = grid1; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid0; x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid1; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } const float d1q = __half2float(bxi->d) * (((qh >> 11) & 0x0E) + 1); const float delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_ds[i*MMQ_MMA_TILE_X_K_Q8_1 + kqsx] = make_half2(d1q, d1q*delta); #else x_ds[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = make_half2(d1q, d1q*delta); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } @@ -2752,14 +2915,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, 
mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_XS); constexpr int nrows = warp_size / threads_per_row; @@ -2779,13 +2942,13 @@ template static __device__ __forceinline__ void loa const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl); const int k0 = 8 * (kqsx / 4) + kqsx % 4; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x; x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 4] = v.y; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } constexpr int rows_per_warp = warp_size / 8; @@ -2804,11 +2967,11 @@ template static __device__ __forceinline__ void loa const int ls = ((bxi->scales_l[(threadIdx.x % 8)/2] >> (4*(threadIdx.x % 2))) & 0x0F) | (((bxi->scales_h >> (2*(threadIdx.x % 8))) & 0x03) << 4); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * (ls - 32); #else x_df[i*(MMQ_TILE_NE_K/4) + i/4 + threadIdx.x % 8] = d * (ls - 32); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } @@ -2848,7 +3011,7 @@ static __device__ __forceinline__ void mmq_write_back_mma( constexpr int granularity = mmq_get_granularity_device(mmq_x); constexpr int nwarps = mmq_get_nwarps_device(); -#if defined(AMD_MFMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int tileC_IJ = mmq_get_granularity_device(0); typedef tile tile_C; constexpr int rows_per_warp = granularity; @@ -2859,11 +3022,11 @@ static __device__ __forceinline__ void mmq_write_back_mma( constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. 
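Note: every load_tiles hunk above extends the same three-way guard, so the RDNA WMMA build selects the MMA-style shared tile layout instead of the dp4a one. A sketch of how the repeated condition could be centralized (an illustrative suggestion, not part of this patch; it reuses only identifiers already present in mmq.cuh):

// sketch only: fold the three-way availability test into one switch so a future
// target needs a single edit instead of touching every loader
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
#    define MMQ_USE_MMA_TILE 1
#else
#    define MMQ_USE_MMA_TILE 0
#endif

// each loader then reduces to:
#if MMQ_USE_MMA_TILE
    int   * x_qs = (int   *)  x_tile;
    float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);                            // fixed MMA tile stride
#else
    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_XXS, mmq_y);
    int   * x_qs = (int   *)  x_tile;
    float * x_df = (float *) (x_qs + txs.qs);                                     // per-type dp4a tile stride
#endif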
const int i0 = (threadIdx.y / ntx) * (ntx*tile_C::I); -#if defined(TURING_MMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) +#if defined(TURING_MMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) static_assert(nwarps*tile_C::I == mmq_y, "nwarps*tile_C::I != mmq_y"); #else GGML_UNUSED(nwarps); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { @@ -3063,13 +3226,13 @@ static __device__ __forceinline__ void mul_mat_q_process_tile( int * tile_y = data_mul_mat_q + mmq_x; int * tile_x = tile_y + GGML_PAD(mmq_x*MMQ_TILE_Y_K, nwarps*warp_size); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr vec_dot_mmq_t vec_dot = mmq_type_traits::vec_dot_mma; constexpr mmq_write_back_t write_back = mmq_write_back_mma; #else constexpr vec_dot_mmq_t vec_dot = mmq_type_traits::vec_dot_dp4a; constexpr mmq_write_back_t write_back = mmq_write_back_dp4a; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int blocks_per_iter = MMQ_ITER_K / qk; @@ -3538,7 +3701,7 @@ static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y); const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type); const size_t nbs_ids = mmq_x*sizeof(int); - const size_t nbs_x = (turing_mma_available(cc) || amd_mfma_available(cc)) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int); + const size_t nbs_x = (turing_mma_available(cc) || amd_mfma_available(cc) || amd_wmma_available(cc)) ? 
mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int); const size_t nbs_y = mmq_x*sizeof(block_q8_1_mmq); return nbs_ids + nbs_x + GGML_PAD(nbs_y, nwarps*warp_size*sizeof(int)); } diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 0eefc0b137..329500a03e 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -1009,6 +1009,64 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort_merge(ggml_metal_l return res; } +// note: reuse the argsort kernel for top_k +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k(ggml_metal_library_t lib, const ggml_tensor * op) { + assert(op->op == GGML_OP_TOP_K); + + char base[256]; + char name[256]; + + // note: the top_k kernel is always descending order + ggml_sort_order order = GGML_SORT_ORDER_DESC; + + const char * order_str = "undefined"; + switch (order) { + case GGML_SORT_ORDER_ASC: order_str = "asc"; break; + case GGML_SORT_ORDER_DESC: order_str = "desc"; break; + default: GGML_ABORT("fatal error"); + }; + + snprintf(base, 256, "kernel_argsort_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k_merge(ggml_metal_library_t lib, const ggml_tensor * op) { + assert(op->op == GGML_OP_TOP_K); + + char base[256]; + char name[256]; + + ggml_sort_order order = GGML_SORT_ORDER_DESC; + + const char * order_str = "undefined"; + switch (order) { + case GGML_SORT_ORDER_ASC: order_str = "asc"; break; + case GGML_SORT_ORDER_DESC: order_str = "desc"; break; + default: GGML_ABORT("fatal error"); + }; + + snprintf(base, 256, "kernel_argsort_merge_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + return res; +} + ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad( ggml_metal_library_t lib, const struct ggml_tensor * op, diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index 39ee6e3427..3976e622b9 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -128,6 +128,8 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id (ggml_me ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort_merge (ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k (ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k_merge (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm 
(ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op); diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index acf9dfd5fc..09b1b50311 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -905,6 +905,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_OP_LEAKY_RELU: return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_ARGSORT: + case GGML_OP_TOP_K: case GGML_OP_ARANGE: return true; case GGML_OP_FLASH_ATTN_EXT: diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 0fae97029f..342dc4f8c3 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -832,14 +832,19 @@ typedef struct { } ggml_metal_kargs_leaky_relu; typedef struct { - int64_t ne00; - int64_t ne01; - int64_t ne02; - int64_t ne03; + int32_t ne00; + int32_t ne01; + int32_t ne02; + int32_t ne03; uint64_t nb00; uint64_t nb01; uint64_t nb02; uint64_t nb03; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + int32_t top_k; } ggml_metal_kargs_argsort; typedef struct { @@ -851,6 +856,11 @@ typedef struct { uint64_t nb01; uint64_t nb02; uint64_t nb03; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + int32_t top_k; int32_t len; } ggml_metal_kargs_argsort_merge; diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index e46970dad4..9871e976f2 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -406,6 +406,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { { n_fuse = ggml_metal_op_argsort(ctx, idx); } break; + case GGML_OP_TOP_K: + { + n_fuse = ggml_metal_op_top_k(ctx, idx); + } break; case GGML_OP_LEAKY_RELU: { n_fuse = ggml_metal_op_leaky_relu(ctx, idx); @@ -3678,14 +3682,19 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) { } ggml_metal_kargs_argsort args = { - /*.ne00 =*/ ne00, - /*.ne01 =*/ ne01, - /*.ne02 =*/ ne02, - /*.ne03 =*/ ne03, - /*.nb00 =*/ nb00, - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - /*.nb03 =*/ nb03, + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.top_k =*/ nth, }; ggml_metal_encoder_set_pipeline(enc, pipeline); @@ -3705,15 +3714,20 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) { ggml_metal_op_concurrency_reset(ctx); ggml_metal_kargs_argsort_merge args_merge = { - .ne00 = ne00, - .ne01 = ne01, - .ne02 = ne02, - .ne03 = ne03, - .nb00 = nb00, - .nb01 = nb01, - .nb02 = nb02, - .nb03 = nb03, - .len = len, + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.top_k =*/ ne00, + /*.len =*/ len, }; // merges per row @@ -3737,6 +3751,118 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) { return 1; } +int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_ASSERT(ggml_is_contiguous_rows(op->src[0])); + + GGML_TENSOR_LOCALS( int32_t, ne0, 
op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_top_k(lib, op); + + // bitonic sort requires the number of elements to be power of 2 + int nth = 1; + while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + nth *= 2; + } + + // blocks per row + const int npr = (ne00 + nth - 1)/nth; + + const size_t smem = GGML_PAD(nth*sizeof(int32_t), 16); + + ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]); + ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op); + + ggml_metal_buffer_id bid_tmp = bid_dst; + bid_tmp.offs += sizeof(int32_t)*ggml_nelements(op->src[0]); + + if ((int) ceil(std::log(npr) / std::log(2)) % 2 == 1) { + std::swap(bid_dst, bid_tmp); + } + + const int top_k = ne0; + + ggml_metal_kargs_argsort args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.top_k =*/ std::min(nth, top_k), // for each block, keep just the top_k indices + }; + + if (npr > 1) { + args.ne0 = (npr - 1)*args.top_k + std::min(ne00 - (npr - 1)*nth, args.top_k); + } + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, bid_src0, 1); + ggml_metal_encoder_set_buffer (enc, bid_dst, 2); + + ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); + + ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1); + + ggml_metal_pipeline_t pipeline_merge = ggml_metal_library_get_pipeline_top_k_merge(lib, op); + + int len = args.top_k; + + while (len < args.ne0) { + ggml_metal_op_concurrency_reset(ctx); + + // merges per row + const int nm = (args.ne0 + 2*len - 1) / (2*len); + + const int nth = std::min(512, std::min(len, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_merge))); + + ggml_metal_kargs_argsort_merge args_merge = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ args.ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.top_k =*/ nm == 1 ? 
top_k : args.ne0, // the final merge outputs top_k elements + /*.len =*/ len, + }; + + ggml_metal_encoder_set_pipeline(enc, pipeline_merge); + ggml_metal_encoder_set_bytes (enc, &args_merge, sizeof(args_merge), 0); + ggml_metal_encoder_set_buffer (enc, bid_src0, 1); + ggml_metal_encoder_set_buffer (enc, bid_dst, 2); + ggml_metal_encoder_set_buffer (enc, bid_tmp, 3); + + ggml_metal_encoder_dispatch_threadgroups(enc, nm*ne01, ne02, ne03, nth, 1, 1); + + std::swap(bid_dst, bid_tmp); + + len <<= 1; + } + + return 1; +} + int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) { ggml_tensor * op = ctx->node(idx); diff --git a/ggml/src/ggml-metal/ggml-metal-ops.h b/ggml/src/ggml-metal/ggml-metal-ops.h index 332e550ee7..b5546146e1 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.h +++ b/ggml/src/ggml-metal/ggml-metal-ops.h @@ -81,6 +81,7 @@ int ggml_metal_op_arange (ggml_metal_op_t ctx, int idx); int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx); int ggml_metal_op_argmax (ggml_metal_op_t ctx, int idx); int ggml_metal_op_argsort (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_top_k (ggml_metal_op_t ctx, int idx); int ggml_metal_op_leaky_relu (ggml_metal_op_t ctx, int idx); int ggml_metal_op_opt_step_adamw (ggml_metal_op_t ctx, int idx); int ggml_metal_op_opt_step_sgd (ggml_metal_op_t ctx, int idx); diff --git a/ggml/src/ggml-metal/ggml-metal.cpp b/ggml/src/ggml-metal/ggml-metal.cpp index f6033ddc97..70bf6f3d98 100644 --- a/ggml/src/ggml-metal/ggml-metal.cpp +++ b/ggml/src/ggml-metal/ggml-metal.cpp @@ -202,6 +202,10 @@ static size_t ggml_backend_metal_buffer_type_get_alloc_size(ggml_backend_buffer_ { res *= 2; } break; + case GGML_OP_TOP_K: + { + res = 2*sizeof(int32_t)*ggml_nelements(tensor->src[0]); + } break; default: break; } diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 59e5761704..73b45c762d 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -4670,11 +4670,12 @@ kernel void kernel_argsort_f32_i32( ushort3 ntg[[threads_per_threadgroup]]) { // bitonic sort const int col = tpitg[0]; + const int ib = tgpig[0] / args.ne01; - const int i00 = (tgpig[0]/args.ne01)*ntg.x; - const int i01 = tgpig[0]%args.ne01; - const int i02 = tgpig[1]; - const int i03 = tgpig[2]; + const int i00 = ib*ntg.x; + const int i01 = tgpig[0] % args.ne01; + const int i02 = tgpig[1]; + const int i03 = tgpig[2]; device const float * src0_row = (device const float *) (src0 + args.nb01*i01 + args.nb02*i02 + args.nb03*i03); @@ -4710,9 +4711,11 @@ kernel void kernel_argsort_f32_i32( } } + const int64_t i0 = ib*args.top_k; + // copy the result to dst without the padding - if (i00 + col < args.ne00) { - dst += i00 + args.ne00*i01 + args.ne00*args.ne01*i02 + args.ne00*args.ne01*args.ne02*i03; + if (i0 + col < args.ne0 && col < args.top_k) { + dst += i0 + args.ne0*i01 + args.ne0*args.ne1*i02 + args.ne0*args.ne1*args.ne2*i03; dst[col] = shmem_i32[col]; } @@ -4747,22 +4750,22 @@ kernel void kernel_argsort_merge_f32_i32( const int start = im * (2 * args.len); - const int len0 = MIN(args.len, MAX(0, args.ne00 - (int)(start))); - const int len1 = MIN(args.len, MAX(0, args.ne00 - (int)(start + args.len))); + const int len0 = MIN(args.len, MAX(0, args.ne0 - (int)(start))); + const int len1 = MIN(args.len, MAX(0, args.ne0 - (int)(start + args.len))); const int total = len0 + len1; device const int32_t * tmp0 = tmp + start - + i01*args.ne00 - + i02*args.ne00*args.ne01 - + i03*args.ne00*args.ne01*args.ne02; + + i01*args.ne0 + + 
i02*args.ne0*args.ne01 + + i03*args.ne0*args.ne01*args.ne02; device const int32_t * tmp1 = tmp0 + args.len; dst += start - + i01*args.ne00 - + i02*args.ne00*args.ne01 - + i03*args.ne00*args.ne01*args.ne02; + + i01*args.top_k + + i02*args.top_k*args.ne01 + + i03*args.top_k*args.ne01*args.ne02; device const float * src0_row = (device const float *)(src0 + args.nb01*i01 @@ -4776,7 +4779,11 @@ kernel void kernel_argsort_merge_f32_i32( const int chunk = (total + ntg.x - 1) / ntg.x; const int k0 = tpitg.x * chunk; - const int k1 = min(k0 + chunk, total); + const int k1 = MIN(MIN(k0 + chunk, total), args.top_k); + + if (k0 >= args.top_k) { + return; + } if (k0 >= total) { return; diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index bc8d3cdcb5..05e585ebff 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -705,6 +705,7 @@ struct vk_device_struct { vk_pipeline pipeline_argsort_f32[num_argsort_pipelines]; vk_pipeline pipeline_argsort_large_f32[num_argsort_pipelines]; vk_pipeline pipeline_sum_rows_f32; + vk_pipeline pipeline_cumsum_f32; vk_pipeline pipeline_argmax_f32; vk_pipeline pipeline_count_equal_i32; vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16; @@ -1629,6 +1630,22 @@ class vk_perf_logger { timings[name].push_back(time); return; } + if (node->op == GGML_OP_FLASH_ATTN_EXT) { + const ggml_tensor * dst = node; + const ggml_tensor * q = node->src[0]; + const ggml_tensor * k = node->src[1]; + const ggml_tensor * v = node->src[2]; + const ggml_tensor * m = node->src[3]; + std::stringstream name; + name << ggml_op_name(node->op) << + " dst(" << dst->ne[0] << "," << dst->ne[1] << "," << dst->ne[2] << "," << dst->ne[3] << "), " << + " q(" << q->ne[0] << "," << q->ne[1] << "," << q->ne[2] << "," << q->ne[3] << "), " << + " k(" << k->ne[0] << "," << k->ne[1] << "," << k->ne[2] << "," << k->ne[3] << "), " << + " v(" << v->ne[0] << "," << v->ne[1] << "," << v->ne[2] << "," << v->ne[3] << "), " << + " m(" << (m?m->ne[0]:0) << "," << (m?m->ne[1]:0) << "," << (m?m->ne[2]:0) << "," << (m?m->ne[3]:0) << ")"; + timings[name.str()].push_back(time); + return; + } timings[ggml_op_name(node->op)].push_back(time); } private: @@ -2485,9 +2502,11 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector&& events static constexpr uint32_t flash_attention_num_small_rows = 32; static constexpr uint32_t scalar_flash_attention_num_small_rows = 1; -static uint32_t get_fa_scalar_num_large_rows(uint32_t hsv) { +static uint32_t get_fa_scalar_num_large_rows(uint32_t hsk, uint32_t hsv) { if (hsv >= 192) { return 2; + } else if ((hsv | hsk) & 8) { + return 4; } else { return 8; } @@ -2519,9 +2538,9 @@ static std::array fa_rows_cols(FaCodePath path, uint32_t hsk, uint3 if ((hsv | hsk) & 8) { // HSV/HSK not being a multiple of 16 makes D_split smaller, which makes cols_per_iter // larger, and Bc needs to be >= cols_per_thread. 64 is large enough, 32 is not. 
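Note on the Metal GGML_OP_TOP_K path added above: each row is processed in power-of-two blocks of nth elements, every block is bitonic-sorted in descending order and only its best min(nth, top_k) indices are kept, and then kernel_argsort_merge passes combine pairs of sorted partial lists (len doubling per pass) until one list of top_k indices per row remains. The destination buffer is allocated at twice the source element count so its second half serves as the ping-pong temp, and the initial dst/tmp swap when the number of merge passes is odd makes the final pass land in dst. A minimal host-side C++ sketch of the same block-then-merge selection (reference semantics only, simplified, not the Metal code):

#include <algorithm>
#include <cstdint>
#include <iterator>
#include <vector>

// reference sketch: sort fixed-size blocks, keep k indices per block,
// then merge sorted partial lists pairwise until one list of k indices is left
static std::vector<int32_t> top_k_ref(const std::vector<float> & row, size_t k, size_t block) {
    auto by_val_desc = [&](int32_t a, int32_t b) { return row[a] > row[b]; };

    std::vector<std::vector<int32_t>> lists;
    for (size_t i0 = 0; i0 < row.size(); i0 += block) {
        std::vector<int32_t> idx;
        for (size_t i = i0; i < std::min(row.size(), i0 + block); ++i) {
            idx.push_back((int32_t) i);
        }
        std::sort(idx.begin(), idx.end(), by_val_desc); // the kernel uses a bitonic sort
        idx.resize(std::min(idx.size(), k));            // keep only the block-local top-k
        lists.push_back(std::move(idx));
    }

    while (lists.size() > 1) {                          // the kernel's len <<= 1 merge loop
        std::vector<std::vector<int32_t>> next;
        for (size_t i = 0; i + 1 < lists.size(); i += 2) {
            std::vector<int32_t> m;
            std::merge(lists[i].begin(), lists[i].end(),
                       lists[i + 1].begin(), lists[i + 1].end(),
                       std::back_inserter(m), by_val_desc);
            m.resize(std::min(m.size(), k));            // only the top-k survive each merge
            next.push_back(std::move(m));
        }
        if (lists.size() % 2 == 1) {
            next.push_back(lists.back());
        }
        lists = std::move(next);
    }

    return lists.empty() ? std::vector<int32_t>{} : lists.front();
}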
- return {get_fa_scalar_num_large_rows(hsv), 64}; + return {get_fa_scalar_num_large_rows(hsk, hsv), 64}; } else { - return {get_fa_scalar_num_large_rows(hsv), 32}; + return {get_fa_scalar_num_large_rows(hsk, hsv), 32}; } } } @@ -3950,6 +3969,8 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(device, device->pipeline_cumsum_f32, "cumsum_f32", cumsum_f32_len, cumsum_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { 128, device->subgroup_size }, 1, true, true, device->subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1); #define IM2COL(bda) \ @@ -7724,7 +7745,7 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con // Needs to be kept up to date on shader changes GGML_UNUSED(hsv); const uint32_t wg_size = scalar_flash_attention_workgroup_size; - const uint32_t Br = get_fa_scalar_num_large_rows(hsv); + const uint32_t Br = get_fa_scalar_num_large_rows(hsk, hsv); const uint32_t Bc = scalar_flash_attention_Bc; const uint32_t tmpsh = wg_size * sizeof(float); @@ -7855,7 +7876,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx case FA_SCALAR: case FA_COOPMAT1: // We may switch from coopmat1 to scalar, so use the scalar limit for both - max_gqa = get_fa_scalar_num_large_rows(HSV); + max_gqa = get_fa_scalar_num_large_rows(HSK, HSV); break; case FA_COOPMAT2: max_gqa = get_fa_num_small_rows(FA_COOPMAT2); @@ -8439,6 +8460,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_sum_rows_f32; } return nullptr; + case GGML_OP_CUMSUM: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_cumsum_f32; + } + return nullptr; case GGML_OP_ARGMAX: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) { return ctx->device->pipeline_argmax_f32; @@ -8803,6 +8829,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX_BACK: case GGML_OP_SUM_ROWS: + case GGML_OP_CUMSUM: case GGML_OP_MEAN: case GGML_OP_ARGMAX: { @@ -10132,6 +10159,11 @@ static void ggml_vk_mean(ggml_backend_vk_context * ctx, vk_context& subctx, cons ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_MEAN, p); } +static void ggml_vk_cumsum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { + vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_CUMSUM, p); +} + static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], 0.0f, 0.0f }); } @@ -11731,6 +11763,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_SUM_ROWS: ggml_vk_sum_rows(ctx, compute_ctx, src0, node); + break; + case GGML_OP_CUMSUM: + ggml_vk_cumsum(ctx, compute_ctx, src0, node); + break; 
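Note on the GGML_OP_CUMSUM wiring above: the op computes an inclusive prefix sum along each row, which the cumsum.comp shader further down in this diff parallelizes with subgroupInclusiveAdd plus shared per-subgroup partials and a last_sum carry between BLOCK_SIZE-wide chunks. For reference, the sequential semantics (assuming contiguous F32 rows, as the supports_op check below requires) are just:

// reference only: inclusive prefix sum over each row of n_cols contiguous floats
static void cumsum_rows_ref(const float * src, float * dst, size_t n_rows, size_t n_cols) {
    for (size_t r = 0; r < n_rows; ++r) {
        float acc = 0.0f;
        for (size_t c = 0; c < n_cols; ++c) {
            acc += src[r*n_cols + c];
            dst[r*n_cols + c] = acc; // dst[c] = src[0] + ... + src[c]
        }
    }
}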
case GGML_OP_MEAN: ggml_vk_mean(ctx, compute_ctx, src0, node); @@ -13768,6 +13804,15 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_SUM_ROWS: case GGML_OP_MEAN: return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]); + case GGML_OP_CUMSUM: + { + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + auto device = ggml_vk_get_device(ctx->device); + if (device->subgroup_arithmetic && device->subgroup_require_full_support) { + return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]); + } + return false; + } case GGML_OP_ARGMAX: case GGML_OP_COUNT_EQUAL: case GGML_OP_IM2COL: @@ -14418,6 +14463,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * tensor_clone = ggml_sum(ggml_ctx, src_clone[0]); } else if (tensor->op == GGML_OP_SUM_ROWS) { tensor_clone = ggml_sum_rows(ggml_ctx, src_clone[0]); + } else if (tensor->op == GGML_OP_CUMSUM) { + tensor_clone = ggml_cumsum(ggml_ctx, src_clone[0]); } else if (tensor->op == GGML_OP_MEAN) { tensor_clone = ggml_mean(ggml_ctx, src_clone[0]); } else if (tensor->op == GGML_OP_ARGMAX) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp b/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp new file mode 100644 index 0000000000..a4c8fc354e --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp @@ -0,0 +1,69 @@ +#version 450 + +#include "types.glsl" +#include "sum_rows.glsl" + +#extension GL_EXT_control_flow_attributes : enable +#extension GL_KHR_shader_subgroup_arithmetic : enable +#extension GL_KHR_shader_subgroup_basic : enable + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +layout (constant_id = 0) const uint BLOCK_SIZE = 128; +layout (constant_id = 1) const uint SUBGROUP_SIZE = 32; + +#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b)) + +shared FLOAT_TYPE partial[BLOCK_SIZE / SUBGROUP_SIZE]; +shared FLOAT_TYPE last_sum; + +void main() { + const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; + const uint tid = gl_LocalInvocationID.x; + + const uint i03 = fastdiv(row, p.ne0_12mp, p.ne0_12L); + const uint i03_offset = i03 * p.ne01*p.ne02; + const uint i02 = fastdiv(row - i03_offset, p.ne0_1mp, p.ne0_1L); + const uint i01 = row - i03_offset - i02*p.ne01; + + const uint src_idx = get_aoffset() + i01 * p.nb01 + i02 * p.nb02 + i03 * p.nb03; + const uint dst_idx = get_doffset() + i01 * p.nb11 + i02 * p.nb12 + i03 * p.nb13; + + uint subgroup_id = tid / SUBGROUP_SIZE; + + if (tid == 0) { + last_sum = 0; + } + + uint col = tid; + uint num_iter = CEIL_DIV(p.n_cols, BLOCK_SIZE); + for (int i = 0; i < num_iter; ++i) { + FLOAT_TYPE v = 0; + if (col < p.n_cols) { + v = FLOAT_TYPE(data_a[src_idx + col]); + } + v = subgroupInclusiveAdd(v); + + // Store the largest partial sum for each subgroup, then add the partials for all + // lower subgroups and the final partial sum from the previous iteration. 
+ if (gl_SubgroupInvocationID == SUBGROUP_SIZE - 1) { + partial[subgroup_id] = v; + } + barrier(); + for (int j = 0; j < subgroup_id; ++j) { + v += partial[j]; + } + v += last_sum; + barrier(); + if (tid == BLOCK_SIZE - 1) { + last_sum = v; + } + if (col < p.n_cols) { + data_d[dst_idx + col] = D_TYPE(v); + } + col += BLOCK_SIZE; + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp b/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp index bc22aa7bd7..13ba2e99dc 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp @@ -1,6 +1,7 @@ #version 450 #include "types.glsl" +#include "sum_rows.glsl" #extension GL_EXT_control_flow_attributes : enable @@ -11,30 +12,6 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; layout (constant_id = 0) const uint BLOCK_SIZE = 32; -layout (push_constant) uniform parameter -{ - uint n_cols; - uint ne01, ne02; - uint nb01, nb02, nb03; - uint nb11, nb12, nb13; - float weight; - uint misalign_offsets; - uint ne0_12mp, ne0_12L; - uint ne0_1mp, ne0_1L; -} p; - -uint get_aoffset() { return p.misalign_offsets >> 16; } -uint get_doffset() { return p.misalign_offsets & 0xFFFF; } - -// see init_fastdiv_values in ggml-vulkan.cpp -uint fastdiv(uint n, uint mp, uint L) { - uint msbs, lsbs; - // msbs = mulhi(n, mp) - umulExtended(n, mp, msbs, lsbs); - return (msbs + n) >> L; -} - - shared FLOAT_TYPE tmp[BLOCK_SIZE]; void main() { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl new file mode 100644 index 0000000000..2b841baa6b --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl @@ -0,0 +1,25 @@ + +// vk_op_sum_rows_push_constants +layout (push_constant) uniform parameter +{ + uint n_cols; + uint ne01, ne02; + uint nb01, nb02, nb03; + uint nb11, nb12, nb13; + float weight; + uint misalign_offsets; + uint ne0_12mp, ne0_12L; + uint ne0_1mp, ne0_1L; +} p; + +uint get_aoffset() { return p.misalign_offsets >> 16; } +uint get_doffset() { return p.misalign_offsets & 0xFFFF; } + +// see init_fastdiv_values in ggml-vulkan.cpp +uint fastdiv(uint n, uint mp, uint L) { + uint msbs, lsbs; + // msbs = mulhi(n, mp) + umulExtended(n, mp, msbs, lsbs); + return (msbs + n) >> L; +} + diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index bc992068f8..164af35658 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -916,6 +916,7 @@ void process_shaders() { string_to_spv("argmax_f32", "argmax.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "int"}})); string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}})); + string_to_spv("cumsum_f32", "cumsum.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); for (std::string dim_str : {"", "_3d"}) { for (bool bda : {false, true}) { diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index a5846a2393..b99345a2e9 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -990,6 +990,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "ARANGE", "TIMESTEP_EMBEDDING", "ARGSORT", + "TOP_K", "LEAKY_RELU", "TRI", "FILL", @@ -1023,7 +1024,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "GLU", }; 
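Note on the ggml.c changes that begin here: TOP_K is registered as a first-class op, and the builders further down in this diff split the old behaviour in two. ggml_argsort_top_k() keeps the previous construction (a full descending ggml_argsort followed by a view of the first k columns), which llama-graph.cpp switches its MoE expert selection to, while ggml_top_k() now emits a GGML_OP_TOP_K node that backends can implement without fully sorting each row (the Metal path above keeps only the top_k indices of each block, for example). Both still return I32 index tensors, not values. A small graph-building sketch (assumes a valid ggml_context * ctx and illustrative sizes n_vocab, n_tokens, k with n_vocab >= k):

// sketch: the two top-k constructions after this change
ggml_tensor * logits = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_vocab, n_tokens);

// dedicated op: shape [k, n_tokens], backends may avoid fully sorting each row
ggml_tensor * idx_top_k  = ggml_top_k(ctx, logits, k);

// argsort + view: full descending sort, then a view of the first k columns
ggml_tensor * idx_sorted = ggml_argsort_top_k(ctx, logits, k);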
-static_assert(GGML_OP_COUNT == 94, "GGML_OP_COUNT != 94"); +static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -1098,6 +1099,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "arange(start, stop, step)", "timestep_embedding(timesteps, dim, max_period)", "argsort(x)", + "top_k(x)", "leaky_relu(x)", "tri(x)", "fill(x, c)", @@ -1131,7 +1133,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "glu(x)", }; -static_assert(GGML_OP_COUNT == 94, "GGML_OP_COUNT != 94"); +static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -5036,28 +5038,6 @@ struct ggml_tensor * ggml_roll( return result; } -// ggml_arange - -struct ggml_tensor * ggml_arange( - struct ggml_context * ctx, - float start, - float stop, - float step) { - GGML_ASSERT(stop > start); - - const int64_t steps = (int64_t) ceilf((stop - start) / step); - - struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, steps); - - ggml_set_op_params_f32(result, 0, start); - ggml_set_op_params_f32(result, 1, stop); - ggml_set_op_params_f32(result, 2, step); - - result->op = GGML_OP_ARANGE; - - return result; -} - // ggml_timestep_embedding struct ggml_tensor * ggml_timestep_embedding( @@ -5139,6 +5119,7 @@ struct ggml_tensor * ggml_argsort( struct ggml_tensor * a, enum ggml_sort_order order) { GGML_ASSERT(a->ne[0] <= INT32_MAX); + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_I32, GGML_MAX_DIMS, a->ne); ggml_set_op_params_i32(result, 0, (int32_t) order); @@ -5149,6 +5130,24 @@ struct ggml_tensor * ggml_argsort( return result; } +// ggml_argsort_top_k + +struct ggml_tensor * ggml_argsort_top_k( + struct ggml_context * ctx, + struct ggml_tensor * a, + int k) { + GGML_ASSERT(a->ne[0] >= k); + + struct ggml_tensor * result = ggml_argsort(ctx, a, GGML_SORT_ORDER_DESC); + + result = ggml_view_4d(ctx, result, + k, result->ne[1], result->ne[2], result->ne[3], + result->nb[1], result->nb[2], result->nb[3], + 0); + + return result; +} + // ggml_top_k struct ggml_tensor * ggml_top_k( @@ -5157,12 +5156,32 @@ struct ggml_tensor * ggml_top_k( int k) { GGML_ASSERT(a->ne[0] >= k); - struct ggml_tensor * result = ggml_argsort(ctx, a, GGML_SORT_ORDER_DESC); + struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_I32, k, a->ne[1], a->ne[2], a->ne[3]); - result = ggml_view_4d(ctx, result, - k, result->ne[1], result->ne[2], result->ne[3], - result->nb[1], result->nb[2], result->nb[3], - 0); + result->op = GGML_OP_TOP_K; + result->src[0] = a; + + return result; +} + +// ggml_arange + +struct ggml_tensor * ggml_arange( + struct ggml_context * ctx, + float start, + float stop, + float step) { + GGML_ASSERT(stop > start); + + const int64_t steps = (int64_t) ceilf((stop - start) / step); + + struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, steps); + + ggml_set_op_params_f32(result, 0, start); + ggml_set_op_params_f32(result, 1, stop); + ggml_set_op_params_f32(result, 2, step); + + result->op = GGML_OP_ARANGE; return result; } diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 8bc558fe4b..6f5a742e04 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -25,6 +25,20 @@ class Keys: ALIGNMENT = "general.alignment" FILE_TYPE = "general.file_type" + # Recommended Sampler Parameters + SAMPLING_SEQUENCE = "general.sampling.sequence" + SAMPLING_TOP_K = "general.sampling.top_k" + SAMPLING_TOP_P = "general.sampling.top_p" + 
SAMPLING_MIN_P = "general.sampling.min_p" + SAMPLING_XTC_PROBABILITY = "general.sampling.xtc_probability" + SAMPLING_XTC_THRESHOLD = "general.sampling.xtc_threshold" + SAMPLING_TEMP = "general.sampling.temp" + SAMPLING_PENALTY_LAST_N = "general.sampling.penalty_last_n" + SAMPLING_PENALTY_REPEAT = "general.sampling.penalty_repeat" + SAMPLING_MIROSTAT = "general.sampling.mirostat" + SAMPLING_MIROSTAT_TAU = "general.sampling.mirostat_tau" + SAMPLING_MIROSTAT_ETA = "general.sampling.mirostat_eta" + # Authorship Metadata NAME = "general.name" AUTHOR = "general.author" diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index a051daeeb1..57ca2035fe 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -4,6 +4,7 @@ import logging import os import shutil import struct +import sys import tempfile from dataclasses import dataclass from enum import Enum, auto @@ -372,8 +373,10 @@ class GGUFWriter: self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None, raw_dtype: GGMLQuantizationType | None = None, ) -> None: - if self.endianess == GGUFEndian.BIG: - tensor.byteswap(inplace=True) + if (self.endianess == GGUFEndian.BIG and sys.byteorder != 'big') or \ + (self.endianess == GGUFEndian.LITTLE and sys.byteorder != 'little'): + # Don't byteswap inplace since lazy copies cannot handle it + tensor = tensor.byteswap(inplace=False) if self.use_temp_file and self.temp_file is None: fp = tempfile.SpooledTemporaryFile(mode="w+b", max_size=256 * 1024 * 1024) fp.seek(0) @@ -399,8 +402,10 @@ class GGUFWriter: raise ValueError(f'Expected output file to contain tensor info or weights, got {self.state}') assert self.fout is not None - if self.endianess == GGUFEndian.BIG: - tensor.byteswap(inplace=True) + if (self.endianess == GGUFEndian.BIG and sys.byteorder != 'big') or \ + (self.endianess == GGUFEndian.LITTLE and sys.byteorder != 'little'): + # Don't byteswap inplace since lazy copies cannot handle it + tensor = tensor.byteswap(inplace=False) file_id = -1 for i, tensors in enumerate(self.tensors): @@ -496,6 +501,42 @@ class GGUFWriter: def add_file_type(self, ftype: int) -> None: self.add_uint32(Keys.General.FILE_TYPE, ftype) + def add_sampling_sequence(self, sequence: str) -> None: + self.add_string(Keys.General.SAMPLING_SEQUENCE, sequence) + + def add_sampling_top_k(self, top_k: int) -> None: + self.add_int32(Keys.General.SAMPLING_TOP_K, top_k) + + def add_sampling_top_p(self, top_p: float) -> None: + self.add_float32(Keys.General.SAMPLING_TOP_P, top_p) + + def add_sampling_min_p(self, min_p: float) -> None: + self.add_float32(Keys.General.SAMPLING_MIN_P, min_p) + + def add_sampling_xtc_probability(self, xtc_probability: float) -> None: + self.add_float32(Keys.General.SAMPLING_XTC_PROBABILITY, xtc_probability) + + def add_sampling_xtc_threshold(self, xtc_threshold: float) -> None: + self.add_float32(Keys.General.SAMPLING_XTC_THRESHOLD, xtc_threshold) + + def add_sampling_temp(self, temp: float) -> None: + self.add_float32(Keys.General.SAMPLING_TEMP, temp) + + def add_sampling_penalty_last_n(self, penalty_last_n: int) -> None: + self.add_int32(Keys.General.SAMPLING_PENALTY_LAST_N, penalty_last_n) + + def add_sampling_penalty_repeat(self, penalty_repeat: float) -> None: + self.add_float32(Keys.General.SAMPLING_PENALTY_REPEAT, penalty_repeat) + + def add_sampling_mirostat(self, mirostat: int) -> None: + self.add_int32(Keys.General.SAMPLING_MIROSTAT, mirostat) + + def add_sampling_mirostat_tau(self, mirostat_tau: float) -> None: + 
self.add_float32(Keys.General.SAMPLING_MIROSTAT_TAU, mirostat_tau) + + def add_sampling_mirostat_eta(self, mirostat_eta: float) -> None: + self.add_float32(Keys.General.SAMPLING_MIROSTAT_ETA, mirostat_eta) + def add_name(self, name: str) -> None: self.add_string(Keys.General.NAME, name) diff --git a/gguf-py/gguf/metadata.py b/gguf-py/gguf/metadata.py index 67efedbdbc..e0d478ce95 100644 --- a/gguf-py/gguf/metadata.py +++ b/gguf-py/gguf/metadata.py @@ -17,6 +17,20 @@ logger = logging.getLogger("metadata") @dataclass class Metadata: + # Recommended Sampler Parameters to be written to GGUF KV Store + sampling_sequence: Optional[str] = None + sampling_top_k: Optional[int] = None + sampling_top_p: Optional[float] = None + sampling_min_p: Optional[float] = None + sampling_xtc_probability: Optional[float] = None + sampling_xtc_threshold: Optional[float] = None + sampling_temp: Optional[float] = None + sampling_penalty_last_n: Optional[int] = None + sampling_penalty_repeat: Optional[float] = None + sampling_mirostat: Optional[int] = None + sampling_mirostat_tau: Optional[float] = None + sampling_mirostat_eta: Optional[float] = None + # Authorship Metadata to be written to GGUF KV Store name: Optional[str] = None author: Optional[str] = None @@ -54,15 +68,43 @@ class Metadata: model_card = Metadata.load_model_card(model_path) hf_params = Metadata.load_hf_parameters(model_path) + gen_config = Metadata.load_generation_config(model_path) # TODO: load adapter_config.json when possible, it usually contains the base model of the LoRA adapter # heuristics metadata = Metadata.apply_metadata_heuristic(metadata, model_card, hf_params, model_path, total_params) + if gen_config: + metadata.sampling_sequence = gen_config.get("sequence", metadata.sampling_sequence) + metadata.sampling_top_k = gen_config.get("top_k", metadata.sampling_top_k) + metadata.sampling_top_p = gen_config.get("top_p", metadata.sampling_top_p) + metadata.sampling_min_p = gen_config.get("min_p", metadata.sampling_min_p) + metadata.sampling_xtc_probability = gen_config.get("xtc_probability", metadata.sampling_xtc_probability) + metadata.sampling_xtc_threshold = gen_config.get("xtc_threshold", metadata.sampling_xtc_threshold) + metadata.sampling_temp = gen_config.get("temperature", metadata.sampling_temp) + metadata.sampling_penalty_last_n = gen_config.get("penalty_last_n", metadata.sampling_penalty_last_n) + metadata.sampling_penalty_repeat = gen_config.get("penalty_repeat", metadata.sampling_penalty_repeat) + metadata.sampling_mirostat = gen_config.get("mirostat", metadata.sampling_mirostat) + metadata.sampling_mirostat_tau = gen_config.get("mirostat_tau", metadata.sampling_mirostat_tau) + metadata.sampling_mirostat_eta = gen_config.get("mirostat_eta", metadata.sampling_mirostat_eta) + # Metadata Override File Provided # This is based on LLM_KV_NAMES mapping in llama.cpp metadata_override = Metadata.load_metadata_override(metadata_override_path) + metadata.sampling_sequence = metadata_override.get(Keys.General.SAMPLING_SEQUENCE, metadata.sampling_sequence) + metadata.sampling_top_k = metadata_override.get(Keys.General.SAMPLING_TOP_K, metadata.sampling_top_k) + metadata.sampling_top_p = metadata_override.get(Keys.General.SAMPLING_TOP_P, metadata.sampling_top_p) + metadata.sampling_min_p = metadata_override.get(Keys.General.SAMPLING_MIN_P, metadata.sampling_min_p) + metadata.sampling_xtc_probability = metadata_override.get(Keys.General.SAMPLING_XTC_PROBABILITY, metadata.sampling_xtc_probability) + metadata.sampling_xtc_threshold = 
metadata_override.get(Keys.General.SAMPLING_XTC_THRESHOLD, metadata.sampling_xtc_threshold) + metadata.sampling_temp = metadata_override.get(Keys.General.SAMPLING_TEMP, metadata.sampling_temp) + metadata.sampling_penalty_last_n = metadata_override.get(Keys.General.SAMPLING_PENALTY_LAST_N, metadata.sampling_penalty_last_n) + metadata.sampling_penalty_repeat = metadata_override.get(Keys.General.SAMPLING_PENALTY_REPEAT, metadata.sampling_penalty_repeat) + metadata.sampling_mirostat = metadata_override.get(Keys.General.SAMPLING_MIROSTAT, metadata.sampling_mirostat) + metadata.sampling_mirostat_tau = metadata_override.get(Keys.General.SAMPLING_MIROSTAT_TAU, metadata.sampling_mirostat_tau) + metadata.sampling_mirostat_eta = metadata_override.get(Keys.General.SAMPLING_MIROSTAT_ETA, metadata.sampling_mirostat_eta) + metadata.name = metadata_override.get(Keys.General.NAME, metadata.name) metadata.author = metadata_override.get(Keys.General.AUTHOR, metadata.author) metadata.version = metadata_override.get(Keys.General.VERSION, metadata.version) @@ -172,6 +214,23 @@ class Metadata: with open(config_path, "r", encoding="utf-8") as f: return json.load(f) + @staticmethod + def load_generation_config(model_path: Optional[Path] = None) -> dict[str, Any]: + if model_path is None or not model_path.is_dir(): + return {} + + generation_config_path = model_path / "generation_config.json" + + if not generation_config_path.is_file(): + return {} + + try: + with open(generation_config_path, "r", encoding="utf-8") as f: + return json.load(f) + except (json.JSONDecodeError, IOError): + # not all models have valid generation_config.json + return {} + @staticmethod def id_to_title(string): # Convert capitalization into title form unless acronym or version number @@ -546,6 +605,32 @@ class Metadata: def set_gguf_meta_model(self, gguf_writer: gguf.GGUFWriter): assert self.name is not None + + if self.sampling_sequence is not None: + gguf_writer.add_sampling_sequence(self.sampling_sequence) + if self.sampling_top_k is not None: + gguf_writer.add_sampling_top_k(self.sampling_top_k) + if self.sampling_top_p is not None: + gguf_writer.add_sampling_top_p(self.sampling_top_p) + if self.sampling_min_p is not None: + gguf_writer.add_sampling_min_p(self.sampling_min_p) + if self.sampling_xtc_probability is not None: + gguf_writer.add_sampling_xtc_probability(self.sampling_xtc_probability) + if self.sampling_xtc_threshold is not None: + gguf_writer.add_sampling_xtc_threshold(self.sampling_xtc_threshold) + if self.sampling_temp is not None: + gguf_writer.add_sampling_temp(self.sampling_temp) + if self.sampling_penalty_last_n is not None: + gguf_writer.add_sampling_penalty_last_n(self.sampling_penalty_last_n) + if self.sampling_penalty_repeat is not None: + gguf_writer.add_sampling_penalty_repeat(self.sampling_penalty_repeat) + if self.sampling_mirostat is not None: + gguf_writer.add_sampling_mirostat(self.sampling_mirostat) + if self.sampling_mirostat_tau is not None: + gguf_writer.add_sampling_mirostat_tau(self.sampling_mirostat_tau) + if self.sampling_mirostat_eta is not None: + gguf_writer.add_sampling_mirostat_eta(self.sampling_mirostat_eta) + gguf_writer.add_name(self.name) if self.author is not None: diff --git a/include/llama.h b/include/llama.h index 8547226ff2..b52eaacfa7 100644 --- a/include/llama.h +++ b/include/llama.h @@ -246,6 +246,21 @@ extern "C" { LLAMA_KV_OVERRIDE_TYPE_STR, }; + enum llama_model_meta_key { + LLAMA_MODEL_META_KEY_SAMPLING_SEQUENCE, + LLAMA_MODEL_META_KEY_SAMPLING_TOP_K, + 
LLAMA_MODEL_META_KEY_SAMPLING_TOP_P, + LLAMA_MODEL_META_KEY_SAMPLING_MIN_P, + LLAMA_MODEL_META_KEY_SAMPLING_XTC_PROBABILITY, + LLAMA_MODEL_META_KEY_SAMPLING_XTC_THRESHOLD, + LLAMA_MODEL_META_KEY_SAMPLING_TEMP, + LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_LAST_N, + LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_REPEAT, + LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT, + LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_TAU, + LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_ETA, + }; + struct llama_model_kv_override { enum llama_model_kv_override_type tag; @@ -518,6 +533,9 @@ extern "C" { // Get the number of metadata key/value pairs LLAMA_API int32_t llama_model_meta_count(const struct llama_model * model); + // Get sampling metadata key name. Returns nullptr if the key is invalid + LLAMA_API const char * llama_model_meta_key_str(enum llama_model_meta_key key); + // Get metadata key name by index LLAMA_API int32_t llama_model_meta_key_by_index(const struct llama_model * model, int32_t i, char * buf, size_t buf_size); diff --git a/scripts/sync_vendor.py b/scripts/sync_vendor.py index 439c47a799..637f4cdc18 100755 --- a/scripts/sync_vendor.py +++ b/scripts/sync_vendor.py @@ -16,7 +16,7 @@ vendor = { # "https://github.com/mackron/miniaudio/raw/refs/tags/0.11.23/miniaudio.h": "vendor/miniaudio/miniaudio.h", "https://github.com/mackron/miniaudio/raw/669ed3e844524fcd883231b13095baee9f6de304/miniaudio.h": "vendor/miniaudio/miniaudio.h", - "https://raw.githubusercontent.com/yhirose/cpp-httplib/refs/tags/v0.27.0/httplib.h": "vendor/cpp-httplib/httplib.h", + "https://raw.githubusercontent.com/yhirose/cpp-httplib/refs/tags/v0.28.0/httplib.h": "vendor/cpp-httplib/httplib.h", "https://raw.githubusercontent.com/sheredom/subprocess.h/b49c56e9fe214488493021017bf3954b91c7c1f5/subprocess.h": "vendor/sheredom/subprocess.h", } diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index fc6cddc92f..7ef87acf1b 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -114,19 +114,31 @@ static const std::map LLM_ARCH_NAMES = { }; static const std::map LLM_KV_NAMES = { - { LLM_KV_GENERAL_TYPE, "general.type" }, - { LLM_KV_GENERAL_ARCHITECTURE, "general.architecture" }, - { LLM_KV_GENERAL_QUANTIZATION_VERSION, "general.quantization_version" }, - { LLM_KV_GENERAL_ALIGNMENT, "general.alignment" }, - { LLM_KV_GENERAL_FILE_TYPE, "general.file_type" }, - { LLM_KV_GENERAL_NAME, "general.name" }, - { LLM_KV_GENERAL_AUTHOR, "general.author" }, - { LLM_KV_GENERAL_VERSION, "general.version" }, - { LLM_KV_GENERAL_URL, "general.url" }, - { LLM_KV_GENERAL_DESCRIPTION, "general.description" }, - { LLM_KV_GENERAL_LICENSE, "general.license" }, - { LLM_KV_GENERAL_SOURCE_URL, "general.source.url" }, - { LLM_KV_GENERAL_SOURCE_HF_REPO, "general.source.huggingface.repository" }, + { LLM_KV_GENERAL_TYPE, "general.type" }, + { LLM_KV_GENERAL_ARCHITECTURE, "general.architecture" }, + { LLM_KV_GENERAL_QUANTIZATION_VERSION, "general.quantization_version" }, + { LLM_KV_GENERAL_ALIGNMENT, "general.alignment" }, + { LLM_KV_GENERAL_FILE_TYPE, "general.file_type" }, + { LLM_KV_GENERAL_SAMPLING_SEQUENCE, "general.sampling.sequence" }, + { LLM_KV_GENERAL_SAMPLING_TOP_K, "general.sampling.top_k" }, + { LLM_KV_GENERAL_SAMPLING_TOP_P, "general.sampling.top_p" }, + { LLM_KV_GENERAL_SAMPLING_MIN_P, "general.sampling.min_p" }, + { LLM_KV_GENERAL_SAMPLING_XTC_PROBABILITY, "general.sampling.xtc_probability" }, + { LLM_KV_GENERAL_SAMPLING_XTC_THRESHOLD, "general.sampling.xtc_threshold" }, + { LLM_KV_GENERAL_SAMPLING_TEMP, "general.sampling.temp" }, + { LLM_KV_GENERAL_SAMPLING_PENALTY_LAST_N, 
"general.sampling.penalty_last_n" }, + { LLM_KV_GENERAL_SAMPLING_PENALTY_REPEAT, "general.sampling.penalty_repeat" }, + { LLM_KV_GENERAL_SAMPLING_MIROSTAT, "general.sampling.mirostat" }, + { LLM_KV_GENERAL_SAMPLING_MIROSTAT_TAU, "general.sampling.mirostat_tau" }, + { LLM_KV_GENERAL_SAMPLING_MIROSTAT_ETA, "general.sampling.mirostat_eta" }, + { LLM_KV_GENERAL_NAME, "general.name" }, + { LLM_KV_GENERAL_AUTHOR, "general.author" }, + { LLM_KV_GENERAL_VERSION, "general.version" }, + { LLM_KV_GENERAL_URL, "general.url" }, + { LLM_KV_GENERAL_DESCRIPTION, "general.description" }, + { LLM_KV_GENERAL_LICENSE, "general.license" }, + { LLM_KV_GENERAL_SOURCE_URL, "general.source.url" }, + { LLM_KV_GENERAL_SOURCE_HF_REPO, "general.source.huggingface.repository" }, { LLM_KV_VOCAB_SIZE, "%s.vocab_size" }, { LLM_KV_CONTEXT_LENGTH, "%s.context_length" }, diff --git a/src/llama-arch.h b/src/llama-arch.h index 02a1c2dc25..9ad3157bf6 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -123,6 +123,18 @@ enum llm_kv { LLM_KV_GENERAL_QUANTIZATION_VERSION, LLM_KV_GENERAL_ALIGNMENT, LLM_KV_GENERAL_FILE_TYPE, + LLM_KV_GENERAL_SAMPLING_SEQUENCE, + LLM_KV_GENERAL_SAMPLING_TOP_K, + LLM_KV_GENERAL_SAMPLING_TOP_P, + LLM_KV_GENERAL_SAMPLING_MIN_P, + LLM_KV_GENERAL_SAMPLING_XTC_PROBABILITY, + LLM_KV_GENERAL_SAMPLING_XTC_THRESHOLD, + LLM_KV_GENERAL_SAMPLING_TEMP, + LLM_KV_GENERAL_SAMPLING_PENALTY_LAST_N, + LLM_KV_GENERAL_SAMPLING_PENALTY_REPEAT, + LLM_KV_GENERAL_SAMPLING_MIROSTAT, + LLM_KV_GENERAL_SAMPLING_MIROSTAT_TAU, + LLM_KV_GENERAL_SAMPLING_MIROSTAT_ETA, LLM_KV_GENERAL_NAME, LLM_KV_GENERAL_AUTHOR, LLM_KV_GENERAL_VERSION, diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 70a3ec62df..2aa6d52a24 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -1248,7 +1248,7 @@ int llama_context::decode(const llama_batch & batch_inp) { // make the outputs have the same order they had in the user-provided batch // note: this is mostly relevant for recurrent models atm - if (!sorted_output) { + if (!sorted_output && n_outputs > 1) { GGML_ASSERT((size_t) n_outputs == out_ids.size()); // TODO: is there something more efficient which also minimizes swaps? 
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 650e40ec6f..1d012e09ab 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -961,14 +961,14 @@ ggml_tensor * llm_graph_context::build_moe_ffn( // organize experts into n_expert_groups ggml_tensor * selection_groups = ggml_reshape_3d(ctx0, selection_probs, n_exp_per_group, hparams.n_expert_groups, n_tokens); // [n_exp_per_group, n_expert_groups, n_tokens] - ggml_tensor * group_scores = ggml_top_k(ctx0, selection_groups, 2); // [2, n_expert_groups, n_tokens] + ggml_tensor * group_scores = ggml_argsort_top_k(ctx0, selection_groups, 2); // [2, n_expert_groups, n_tokens] group_scores = ggml_get_rows(ctx0, ggml_reshape_4d(ctx0, selection_groups, 1, selection_groups->ne[0], selection_groups->ne[1], selection_groups->ne[2]), group_scores); // [1, 2, n_expert_groups, n_tokens] // get top n_group_used expert groups group_scores = ggml_sum_rows(ctx0, ggml_reshape_3d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2], group_scores->ne[3])); // [1, n_expert_groups, n_tokens] group_scores = ggml_reshape_2d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2]); // [n_expert_groups, n_tokens] - ggml_tensor * expert_groups = ggml_top_k(ctx0, group_scores, hparams.n_group_used); // [n_group_used, n_tokens] + ggml_tensor * expert_groups = ggml_argsort_top_k(ctx0, group_scores, hparams.n_group_used); // [n_group_used, n_tokens] cb(expert_groups, "ffn_moe_group_topk", il); // mask out the other groups @@ -979,7 +979,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn( } // select experts - ggml_tensor * selected_experts = ggml_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens] + ggml_tensor * selected_experts = ggml_argsort_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens] cb(selected_experts->src[0], "ffn_moe_argsort", il); cb(selected_experts, "ffn_moe_topk", il); diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 35179a98e0..a042ea9632 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -7687,6 +7687,24 @@ int32_t llama_model_meta_count(const llama_model * model) { return (int)model->gguf_kv.size(); } +const char * llama_model_meta_key_str(llama_model_meta_key key) { + switch (key) { + case LLAMA_MODEL_META_KEY_SAMPLING_SEQUENCE: return "general.sampling.sequence"; + case LLAMA_MODEL_META_KEY_SAMPLING_TOP_K: return "general.sampling.top_k"; + case LLAMA_MODEL_META_KEY_SAMPLING_TOP_P: return "general.sampling.top_p"; + case LLAMA_MODEL_META_KEY_SAMPLING_MIN_P: return "general.sampling.min_p"; + case LLAMA_MODEL_META_KEY_SAMPLING_XTC_PROBABILITY: return "general.sampling.xtc_probability"; + case LLAMA_MODEL_META_KEY_SAMPLING_XTC_THRESHOLD: return "general.sampling.xtc_threshold"; + case LLAMA_MODEL_META_KEY_SAMPLING_TEMP: return "general.sampling.temp"; + case LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_LAST_N: return "general.sampling.penalty_last_n"; + case LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_REPEAT: return "general.sampling.penalty_repeat"; + case LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT: return "general.sampling.mirostat"; + case LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_TAU: return "general.sampling.mirostat_tau"; + case LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_ETA: return "general.sampling.mirostat_eta"; + default: return nullptr; + } +} + int32_t llama_model_meta_key_by_index(const llama_model * model, int i, char * buf, size_t buf_size) { if (i < 0 || i >= (int)model->gguf_kv.size()) { if (buf_size > 0) { diff --git a/tests/test-backend-ops.cpp 
b/tests/test-backend-ops.cpp index ce8c068d7a..2fc7d4c3c7 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -39,6 +39,7 @@ #include #include #include +#include static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) { size_t nels = ggml_nelements(tensor); @@ -269,6 +270,34 @@ static double nmse(const float * a, const float * b, size_t n) { return mse_a_b / mse_a_0; } +// difference between 2 integer sets (Jaccard distance, 0 - no difference, 1 - no overlap) +static double jdst(const int32_t * a, const int32_t * b, size_t n) { + std::unordered_map set_a; + std::unordered_map set_b; + + for (size_t i = 0; i < n; ++i) { + set_a[a[i]]++; + set_b[b[i]]++; + } + + size_t diff = 0; + + for (const auto & p : set_a) { + const int64_t na = p.second; + const int64_t nb = set_b.find(p.first) != set_b.end() ? set_b.at(p.first) : 0; + + diff += std::abs(na - nb); + } + + for (const auto & p : set_b) { + if (set_a.find(p.first) == set_a.end()) { + diff += p.second; + } + } + + return (double) diff / (2*n); +} + // maximum absolute asymmetry between a and b // asymmetry: (a - b) / (a + b) // This is more stable than relative error if one of the values fluctuates towards zero. @@ -1051,6 +1080,14 @@ struct test_case { return 1e-4; } + virtual double max_err() { + return max_nmse_err(); + } + + virtual double err(const float * a, const float * b, size_t n) { + return nmse(a, b, n); + } + virtual float grad_eps() { return 1e-1f; } @@ -1257,16 +1294,16 @@ struct test_case { // compare struct callback_userdata { bool ok; - double max_err; + test_case * tc; ggml_backend_t backend1; ggml_backend_t backend2; }; callback_userdata ud { true, - max_nmse_err(), + this, backend1, - backend2 + backend2, }; auto callback = [](int index, ggml_tensor * t1, ggml_tensor * t2, void * user_data) -> bool { @@ -1314,9 +1351,9 @@ struct test_case { } } - double err = nmse(f1.data(), f2.data(), f1.size()); - if (err > ud->max_err) { - printf("[%s] NMSE = %.9f > %.9f ", ggml_op_desc(t1), err, ud->max_err); + double err = ud->tc->err(f1.data(), f2.data(), f1.size()); + if (err > ud->tc->max_err()) { + printf("[%s] ERR = %.9f > %.9f ", ggml_op_desc(t1), err, ud->tc->max_err()); //for (int i = 0; i < (int) f1.size(); i++) { // printf("%5d %9.6f %9.6f, diff = %9.6f\n", i, f1[i], f2[i], f1[i] - f2[i]); //} @@ -4943,7 +4980,71 @@ struct test_argsort : public test_case { } }; -struct test_topk_moe: public test_case { +// GGML_OP_TOP_K +struct test_top_k : public test_case { + const ggml_type type; + const std::array ne; + const int k; + + std::string vars() override { + return VARS_TO_STR3(type, ne, k); + } + + test_top_k(ggml_type type = GGML_TYPE_F32, + std::array ne = {16, 10, 10, 10}, + int k = 4) + : type(type), ne(ne), k(k) {} + + double max_err() override { + return 0.0; + } + + double err(const float * a, const float * b, size_t n) override { + std::vector ia(n); + std::vector ib(n); + + double diff = 0.0f; + + for (size_t i = 0; i < n; i++) { + ia[i] = (int32_t) a[i]; + ib[i] = (int32_t) b[i]; + + // penalize the result if the data is not integer valued + diff += std::fabs(a[i] - ia[i]); + diff += std::fabs(b[i] - ib[i]); + } + + return diff + jdst(ia.data(), ib.data(), n); + } + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_set_name(a, "a"); + + ggml_tensor * out = ggml_top_k(ctx, a, k); + ggml_set_name(out, "out"); + + return out; + } + + void initialize_tensors(ggml_context * ctx) 
override { + std::random_device rd; + std::default_random_engine rng(rd()); + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { + // initialize with unique values to avoid ties + for (int64_t r = 0; r < ggml_nrows(t); r++) { + std::vector data(t->ne[0]); + for (int i = 0; i < t->ne[0]; i++) { + data[i] = i; + } + std::shuffle(data.begin(), data.end(), rng); + ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(float)); + } + } + } +}; + +struct test_topk_moe : public test_case { const std::array ne; const int n_expert_used; const bool with_norm; @@ -4976,7 +5077,7 @@ struct test_topk_moe: public test_case { ggml_tensor * logits = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne.data()); ggml_tensor * probs = delayed_softmax ? logits : ggml_soft_max(ctx, logits); - ggml_tensor * selected_experts = ggml_top_k(ctx, probs, n_expert_used); // [n_expert_used, n_tokens] + ggml_tensor * selected_experts = ggml_argsort_top_k(ctx, probs, n_expert_used); // [n_expert_used, n_tokens] ggml_tensor * out = ggml_get_rows(ctx, ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens] @@ -7534,6 +7635,23 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2, 8, 8192, 1}, order)); // bailingmoe2 (group selection) } + for (int k : {1, 2, 3, 7, 15}) { + test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {16, 10, 10, 10}, k)); + test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {60, 10, 10, 10}, k)); + test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {1023, 2, 1, 3}, k)); + test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {1024, 2, 1, 3}, k)); + test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {1025, 2, 1, 3}, k)); + test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {16384, 1, 1, 1}, k)); + test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {2047, 2, 1, 3}, k)); + test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {2048, 2, 1, 3}, k)); + test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {2049, 2, 1, 3}, k)); + } + + // exhaustive top_k tests + //for (int i = 1; i < 9999; ++i) { + // test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {i, 2, 1, 3}, rand() % i + 1)); + //} + for (ggml_scale_mode mode : {GGML_SCALE_MODE_NEAREST, GGML_SCALE_MODE_BILINEAR, GGML_SCALE_MODE_BICUBIC}) { test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, {512, 512, 3, 2}, 2, mode)); test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, {512, 512, 3, 2}, 2, mode, true)); @@ -7859,6 +7977,9 @@ static std::vector> make_test_cases_perf() { } } + // Qwen3-VL-8B https://github.com/ggml-org/llama.cpp/issues/17012 + test_cases.emplace_back(new test_flash_attn_ext(72, 72, 16, {1, 1}, 5776, 5776, false, false, 0, 0, GGML_PREC_F32, GGML_TYPE_F16)); + for (int kv : { 4096, 8192, 16384, }) { for (int hs : { 64, 128, }) { for (int nr : { 1, 4, }) { @@ -7911,6 +8032,7 @@ static std::vector> make_test_cases_perf() { } test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {65000, 16, 1, 1})); + test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {65000, 16, 1, 1}, 40)); return test_cases; } diff --git a/tools/server/public/index.html.gz b/tools/server/public/index.html.gz deleted file mode 100644 index ff8e18007c..0000000000 Binary files a/tools/server/public/index.html.gz and /dev/null differ diff --git a/tools/server/webui/src/lib/components/app/misc/MarkdownContent.svelte b/tools/server/webui/src/lib/components/app/misc/MarkdownContent.svelte 
index 7e83d30f13..176a98b212 100644 --- a/tools/server/webui/src/lib/components/app/misc/MarkdownContent.svelte +++ b/tools/server/webui/src/lib/components/app/misc/MarkdownContent.svelte @@ -8,6 +8,7 @@ import rehypeKatex from 'rehype-katex'; import rehypeStringify from 'rehype-stringify'; import { copyCodeToClipboard } from '$lib/utils/copy'; + import { rehypeRestoreTableHtml } from '$lib/markdown/table-html-restorer'; import { preprocessLaTeX } from '$lib/utils/latex-protection'; import { browser } from '$app/environment'; import '$styles/katex-custom.scss'; @@ -60,6 +61,7 @@ .use(remarkRehype) // Convert Markdown AST to rehype .use(rehypeKatex) // Render math using KaTeX .use(rehypeHighlight) // Add syntax highlighting + .use(rehypeRestoreTableHtml) // Restore limited HTML (e.g.,
<br>, <ul>) inside Markdown tables .use(rehypeStringify); // Convert to HTML string }); diff --git a/tools/server/webui/src/lib/constants/table-html-restorer.ts b/tools/server/webui/src/lib/constants/table-html-restorer.ts new file mode 100644 index 0000000000..e5d5b12011 --- /dev/null +++ b/tools/server/webui/src/lib/constants/table-html-restorer.ts @@ -0,0 +1,20 @@ +/** + * Matches <br>, <br/>, <br /> tags (case-insensitive). + * Used to detect line breaks in table cell text content. + */ +export const BR_PATTERN = /<br\s*\/?>/gi; + +/** + * Matches a complete <ul>...</ul> block. + * Captures the inner content (group 1) for further <li> extraction. + * Case-insensitive, allows multiline content. + */ +export const LIST_PATTERN = /^<ul>([\s\S]*)<\/ul>$/i; + +/** + * Matches individual <li>...</li> elements within a list. + * Captures the inner content (group 1) of each list item. + * Non-greedy to handle multiple consecutive items. + * Case-insensitive, allows multiline content. + */ +export const LI_PATTERN = /<li>([\s\S]*?)<\/li>/gi; diff --git a/tools/server/webui/src/lib/markdown/table-html-restorer.ts b/tools/server/webui/src/lib/markdown/table-html-restorer.ts new file mode 100644 index 0000000000..918aa46811 --- /dev/null +++ b/tools/server/webui/src/lib/markdown/table-html-restorer.ts @@ -0,0 +1,181 @@ +/** + * Rehype plugin to restore limited HTML elements inside Markdown table cells. + * + * ## Problem + * The remark/rehype pipeline neutralizes inline HTML as literal text + * (remarkLiteralHtml) so that XML/HTML snippets in LLM responses display + * as-is instead of being rendered. This causes <br>
and <ul> markup in + * table cells to show as plain text. + * + * ## Solution + * This plugin traverses the HAST post-conversion, parses whitelisted HTML + * patterns from text nodes, and replaces them with actual HAST element nodes + * that will be rendered as real HTML. + * + * ## Supported HTML + * - `<br>` / `<br/>` / `<br />` - Line breaks (inline) + * - `<ul><li>...</li></ul>` - Unordered lists (block) + * + * ## Key Implementation Details + * + * ### 1. Sibling Combination (Critical) + * The Markdown pipeline may fragment content across multiple text nodes and `<br>` + * elements. For example, `<ul><li>a</li></ul>` might arrive as: + * - Text: `"<ul>"` + * - Element: `<br>` + * - Text: `"<li>a</li></ul>"` + * + * We must combine consecutive text nodes and `<br>` elements into a single string + * before attempting to parse list markup. Without this, list detection fails. + * + * ### 2. visitParents for Deep Traversal + * Table cell content may be wrapped in intermediate elements (e.g., `<p>` tags). + * Using `visitParents` instead of direct child iteration ensures we find text + * nodes at any depth within the cell. + * + * ### 3. Reference Comparison for No-Op Detection + * When checking if `<br>` expansion changed anything, we compare: + * `expanded.length !== 1 || expanded[0] !== textNode` + * + * This catches both cases: + * - Multiple nodes created (text was split) + * - Single NEW node created (original had only `<br>`, now it's an element) + * + * A simple `length > 1` check would miss the single `<br>` case. + * + * ### 4. Strict List Validation + * `parseList()` rejects malformed markup by checking for garbage text between + * `<li>` elements. This prevents creating broken DOM from partial matches like + * `<ul>garbage<li>a</li></ul>`. + * + * ### 5. Newline Substitution for `<br>
` in Combined String + * When combining siblings, existing `<br>` elements become `\n` in the combined + * string. This allows list content to span visual lines while still being parsed + * as a single unit. + * + * @example + * // Input Markdown: + * // | Feature | Notes | + * // |---------|-------| + * // | Multi-line | First<br>Second | + * // | List | <ul><li>A</li><li>B</li></ul> | + * // + * // Without this plugin: <br> and <ul> render as literal text + * // With this plugin: <br> becomes line break, <ul> becomes actual list + */ + +import type { Plugin } from 'unified'; +import type { Element, ElementContent, Root, Text } from 'hast'; +import { visit } from 'unist-util-visit'; +import { visitParents } from 'unist-util-visit-parents'; +import { BR_PATTERN, LIST_PATTERN, LI_PATTERN } from '$lib/constants/table-html-restorer'; + +/** + * Expands text containing `<br>` tags into an array of text nodes and br elements. + */ +function expandBrTags(value: string): ElementContent[] { + const matches = [...value.matchAll(BR_PATTERN)]; + if (!matches.length) return [{ type: 'text', value } as Text]; + + const result: ElementContent[] = []; + let cursor = 0; + + for (const m of matches) { + if (m.index! > cursor) { + result.push({ type: 'text', value: value.slice(cursor, m.index) } as Text); + } + result.push({ type: 'element', tagName: 'br', properties: {}, children: [] } as Element); + cursor = m.index! + m[0].length; + } + + if (cursor < value.length) { + result.push({ type: 'text', value: value.slice(cursor) } as Text); + } + + return result; +} + +/** + * Parses a `<ul><li>...</li></ul>` string into a HAST element. + * Returns null if the markup is malformed or contains unexpected content. + */ +function parseList(value: string): Element | null { + const match = value.trim().match(LIST_PATTERN); + if (!match) return null; + + const body = match[1]; + const items: ElementContent[] = []; + let cursor = 0; + + for (const liMatch of body.matchAll(LI_PATTERN)) { + // Reject if there's non-whitespace between list items + if (body.slice(cursor, liMatch.index!).trim()) return null; + + items.push({ + type: 'element', + tagName: 'li', + properties: {}, + children: expandBrTags(liMatch[1] ?? '') + } as Element); + + cursor = liMatch.index! + liMatch[0].length; + } + + // Reject if no items found or trailing garbage exists + if (!items.length || body.slice(cursor).trim()) return null; + + return { type: 'element', tagName: 'ul', properties: {}, children: items } as Element; +} + +/** + * Processes a single table cell, restoring HTML elements from text content. + */ +function processCell(cell: Element) { + visitParents(cell, 'text', (textNode: Text, ancestors) => { + const parent = ancestors[ancestors.length - 1]; + if (!parent || parent.type !== 'element') return; + + const parentEl = parent as Element; + const siblings = parentEl.children as ElementContent[]; + const startIndex = siblings.indexOf(textNode as ElementContent); + if (startIndex === -1) return; + + // Combine consecutive text nodes and <br> elements into one string + let combined = ''; + let endIndex = startIndex; + + for (let i = startIndex; i < siblings.length; i++) { + const sib = siblings[i]; + if (sib.type === 'text') { + combined += (sib as Text).value; + endIndex = i; + } else if (sib.type === 'element' && (sib as Element).tagName === 'br') { + combined += '\n'; + endIndex = i; + } else { + break; + } + } + + // Try parsing as list first (replaces entire combined range) + const list = parseList(combined); + if (list) { + siblings.splice(startIndex, endIndex - startIndex + 1, list); + return; + } + + // Otherwise, just expand <br>
            tags in this text node + const expanded = expandBrTags(textNode.value); + if (expanded.length !== 1 || expanded[0] !== textNode) { + siblings.splice(startIndex, 1, ...expanded); + } + }); +} + +export const rehypeRestoreTableHtml: Plugin<[], Root> = () => (tree) => { + visit(tree, 'element', (node: Element) => { + if (node.tagName === 'td' || node.tagName === 'th') { + processCell(node); + } + }); +}; diff --git a/vendor/cpp-httplib/CMakeLists.txt b/vendor/cpp-httplib/CMakeLists.txt index 8e0f8064f7..0fa1cd9831 100644 --- a/vendor/cpp-httplib/CMakeLists.txt +++ b/vendor/cpp-httplib/CMakeLists.txt @@ -31,13 +31,16 @@ if (LLAMA_BUILD_BORINGSSL) message(STATUS "Fetching BoringSSL version ${BORINGSSL_VERSION}") - include(FetchContent) - FetchContent_Declare( - boringssl + set(BORINGSSL_ARGS GIT_REPOSITORY ${BORINGSSL_GIT} GIT_TAG ${BORINGSSL_VERSION} - PATCH_COMMAND ${CMAKE_COMMAND} -P "${CMAKE_CURRENT_SOURCE_DIR}/patch-boringssl.cmake" ) + if(CMAKE_VERSION VERSION_GREATER_EQUAL 3.28) + list(APPEND BORINGSSL_ARGS EXCLUDE_FROM_ALL) + endif() + + include(FetchContent) + FetchContent_Declare(boringssl ${BORINGSSL_ARGS}) set(SAVED_BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS}) set(SAVED_BUILD_TESTING ${BUILD_TESTING}) @@ -45,7 +48,15 @@ if (LLAMA_BUILD_BORINGSSL) set(BUILD_SHARED_LIBS OFF) set(BUILD_TESTING OFF) - FetchContent_MakeAvailable(boringssl) + if(CMAKE_VERSION VERSION_GREATER_EQUAL 3.28) + FetchContent_MakeAvailable(boringssl) + else() + FetchContent_GetProperties(boringssl) + if(NOT boringssl_POPULATED) + FetchContent_Populate(boringssl) + add_subdirectory(${boringssl_SOURCE_DIR} ${boringssl_BINARY_DIR} EXCLUDE_FROM_ALL) + endif() + endif() set(BUILD_SHARED_LIBS ${SAVED_BUILD_SHARED_LIBS}) set(BUILD_TESTING ${SAVED_BUILD_TESTING}) diff --git a/vendor/cpp-httplib/httplib.cpp b/vendor/cpp-httplib/httplib.cpp index 5432db69b4..b86e6a2310 100644 --- a/vendor/cpp-httplib/httplib.cpp +++ b/vendor/cpp-httplib/httplib.cpp @@ -1087,22 +1087,30 @@ int getaddrinfo_with_timeout(const char *node, const char *service, // Fallback implementation using thread-based timeout for other Unix systems struct GetAddrInfoState { + ~GetAddrInfoState() { + if (info) { freeaddrinfo(info); } + } + std::mutex mutex; std::condition_variable result_cv; bool completed = false; int result = EAI_SYSTEM; - std::string node = node; - std::string service = service; - struct addrinfo hints = hints; + std::string node; + std::string service; + struct addrinfo hints; struct addrinfo *info = nullptr; }; // Allocate on the heap, so the resolver thread can keep using the data. 
auto state = std::make_shared(); + state->node = node; + state->service = service; + state->hints = *hints; - std::thread resolve_thread([=]() { - auto thread_result = getaddrinfo( - state->node.c_str(), state->service.c_str(), hints, &state->info); + std::thread resolve_thread([state]() { + auto thread_result = + getaddrinfo(state->node.c_str(), state->service.c_str(), &state->hints, + &state->info); std::lock_guard lock(state->mutex); state->result = thread_result; @@ -1120,6 +1128,7 @@ int getaddrinfo_with_timeout(const char *node, const char *service, // Operation completed within timeout resolve_thread.join(); *res = state->info; + state->info = nullptr; // Pass ownership to caller return state->result; } else { // Timeout occurred @@ -4970,7 +4979,8 @@ bool Server::write_response_core(Stream &strm, bool close_connection, if (need_apply_ranges) { apply_ranges(req, res, content_type, boundary); } // Prepare additional headers - if (close_connection || req.get_header_value("Connection") == "close") { + if (close_connection || req.get_header_value("Connection") == "close" || + 400 <= res.status) { // Don't leave connections open after errors res.set_header("Connection", "close"); } else { std::string s = "timeout="; @@ -5173,7 +5183,11 @@ bool Server::read_content_core( size_t /*len*/) { return receiver(buf, n); }; } - if (req.method == "DELETE" && !req.has_header("Content-Length")) { + // RFC 7230 Section 3.3.3: If this is a request message and none of the above + // are true (no Transfer-Encoding and no Content-Length), then the message + // body length is zero (no message body is present). + if (!req.has_header("Content-Length") && + !detail::is_chunked_transfer_encoding(req.headers)) { return true; } @@ -5681,8 +5695,6 @@ Server::process_request(Stream &strm, const std::string &remote_addr, // Check if the request URI doesn't exceed the limit if (req.target.size() > CPPHTTPLIB_REQUEST_URI_MAX_LENGTH) { - Headers dummy; - detail::read_headers(strm, dummy); res.status = StatusCode::UriTooLong_414; output_error_log(Error::ExceedUriMaxLength, &req); return write_response(strm, close_connection, req, res); @@ -6666,11 +6678,13 @@ bool ClientImpl::write_request(Stream &strm, Request &req, return true; } -std::unique_ptr ClientImpl::send_with_content_provider( +std::unique_ptr +ClientImpl::send_with_content_provider_and_receiver( Request &req, const char *body, size_t content_length, ContentProvider content_provider, ContentProviderWithoutLength content_provider_without_length, - const std::string &content_type, Error &error) { + const std::string &content_type, ContentReceiver content_receiver, + Error &error) { if (!content_type.empty()) { req.set_header("Content-Type", content_type); } #ifdef CPPHTTPLIB_ZLIB_SUPPORT @@ -6743,15 +6757,24 @@ std::unique_ptr ClientImpl::send_with_content_provider( } } + if (content_receiver) { + req.content_receiver = + [content_receiver](const char *data, size_t data_length, + size_t /*offset*/, size_t /*total_length*/) { + return content_receiver(data, data_length); + }; + } + auto res = detail::make_unique(); return send(req, *res, error) ? 
std::move(res) : nullptr; } -Result ClientImpl::send_with_content_provider( +Result ClientImpl::send_with_content_provider_and_receiver( const std::string &method, const std::string &path, const Headers &headers, const char *body, size_t content_length, ContentProvider content_provider, ContentProviderWithoutLength content_provider_without_length, - const std::string &content_type, UploadProgress progress) { + const std::string &content_type, ContentReceiver content_receiver, + UploadProgress progress) { Request req; req.method = method; req.headers = headers; @@ -6763,9 +6786,10 @@ Result ClientImpl::send_with_content_provider( auto error = Error::Success; - auto res = send_with_content_provider( + auto res = send_with_content_provider_and_receiver( req, body, content_length, std::move(content_provider), - std::move(content_provider_without_length), content_type, error); + std::move(content_provider_without_length), content_type, + std::move(content_receiver), error); #ifdef CPPHTTPLIB_OPENSSL_SUPPORT return Result{std::move(res), error, std::move(req.headers), last_ssl_error_, @@ -7094,6 +7118,15 @@ Result ClientImpl::Post(const std::string &path, size_t content_length, content_type, progress); } +Result ClientImpl::Post(const std::string &path, size_t content_length, + ContentProvider content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return Post(path, Headers(), content_length, std::move(content_provider), + content_type, std::move(content_receiver), progress); +} + Result ClientImpl::Post(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, @@ -7102,6 +7135,15 @@ Result ClientImpl::Post(const std::string &path, progress); } +Result ClientImpl::Post(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return Post(path, Headers(), std::move(content_provider), content_type, + std::move(content_receiver), progress); +} + Result ClientImpl::Post(const std::string &path, const Headers &headers, const Params ¶ms) { auto query = detail::params_to_query_str(params); @@ -7142,17 +7184,18 @@ Result ClientImpl::Post(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress) { - return send_with_content_provider("POST", path, headers, body, content_length, - nullptr, nullptr, content_type, progress); + return send_with_content_provider_and_receiver( + "POST", path, headers, body, content_length, nullptr, nullptr, + content_type, nullptr, progress); } Result ClientImpl::Post(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress) { - return send_with_content_provider("POST", path, headers, body.data(), - body.size(), nullptr, nullptr, content_type, - progress); + return send_with_content_provider_and_receiver( + "POST", path, headers, body.data(), body.size(), nullptr, nullptr, + content_type, nullptr, progress); } Result ClientImpl::Post(const std::string &path, const Headers &headers, @@ -7160,18 +7203,40 @@ Result ClientImpl::Post(const std::string &path, const Headers &headers, ContentProvider content_provider, const std::string &content_type, UploadProgress progress) { - return send_with_content_provider("POST", path, headers, nullptr, - content_length, std::move(content_provider), - nullptr, 
content_type, progress); + return send_with_content_provider_and_receiver( + "POST", path, headers, nullptr, content_length, + std::move(content_provider), nullptr, content_type, nullptr, progress); +} + +Result ClientImpl::Post(const std::string &path, const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + DownloadProgress progress) { + return send_with_content_provider_and_receiver( + "POST", path, headers, nullptr, content_length, + std::move(content_provider), nullptr, content_type, + std::move(content_receiver), std::move(progress)); } Result ClientImpl::Post(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress) { - return send_with_content_provider("POST", path, headers, nullptr, 0, nullptr, - std::move(content_provider), content_type, - progress); + return send_with_content_provider_and_receiver( + "POST", path, headers, nullptr, 0, nullptr, std::move(content_provider), + content_type, nullptr, progress); +} + +Result ClientImpl::Post(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + DownloadProgress progress) { + return send_with_content_provider_and_receiver( + "POST", path, headers, nullptr, 0, nullptr, std::move(content_provider), + content_type, std::move(content_receiver), std::move(progress)); } Result ClientImpl::Post(const std::string &path, const Headers &headers, @@ -7181,10 +7246,10 @@ Result ClientImpl::Post(const std::string &path, const Headers &headers, const auto &boundary = detail::make_multipart_data_boundary(); const auto &content_type = detail::serialize_multipart_formdata_get_content_type(boundary); - return send_with_content_provider( + return send_with_content_provider_and_receiver( "POST", path, headers, nullptr, 0, nullptr, get_multipart_content_provider(boundary, items, provider_items), - content_type, progress); + content_type, nullptr, progress); } Result ClientImpl::Post(const std::string &path, const Headers &headers, @@ -7246,6 +7311,15 @@ Result ClientImpl::Put(const std::string &path, size_t content_length, content_type, progress); } +Result ClientImpl::Put(const std::string &path, size_t content_length, + ContentProvider content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return Put(path, Headers(), content_length, std::move(content_provider), + content_type, std::move(content_receiver), progress); +} + Result ClientImpl::Put(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, @@ -7254,6 +7328,15 @@ Result ClientImpl::Put(const std::string &path, progress); } +Result ClientImpl::Put(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return Put(path, Headers(), std::move(content_provider), content_type, + std::move(content_receiver), progress); +} + Result ClientImpl::Put(const std::string &path, const Headers &headers, const Params ¶ms) { auto query = detail::params_to_query_str(params); @@ -7294,17 +7377,18 @@ Result ClientImpl::Put(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress) { - return 
send_with_content_provider("PUT", path, headers, body, content_length, - nullptr, nullptr, content_type, progress); + return send_with_content_provider_and_receiver( + "PUT", path, headers, body, content_length, nullptr, nullptr, + content_type, nullptr, progress); } Result ClientImpl::Put(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress) { - return send_with_content_provider("PUT", path, headers, body.data(), - body.size(), nullptr, nullptr, content_type, - progress); + return send_with_content_provider_and_receiver( + "PUT", path, headers, body.data(), body.size(), nullptr, nullptr, + content_type, nullptr, progress); } Result ClientImpl::Put(const std::string &path, const Headers &headers, @@ -7312,18 +7396,40 @@ Result ClientImpl::Put(const std::string &path, const Headers &headers, ContentProvider content_provider, const std::string &content_type, UploadProgress progress) { - return send_with_content_provider("PUT", path, headers, nullptr, - content_length, std::move(content_provider), - nullptr, content_type, progress); + return send_with_content_provider_and_receiver( + "PUT", path, headers, nullptr, content_length, + std::move(content_provider), nullptr, content_type, nullptr, progress); +} + +Result ClientImpl::Put(const std::string &path, const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return send_with_content_provider_and_receiver( + "PUT", path, headers, nullptr, content_length, + std::move(content_provider), nullptr, content_type, + std::move(content_receiver), progress); } Result ClientImpl::Put(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress) { - return send_with_content_provider("PUT", path, headers, nullptr, 0, nullptr, - std::move(content_provider), content_type, - progress); + return send_with_content_provider_and_receiver( + "PUT", path, headers, nullptr, 0, nullptr, std::move(content_provider), + content_type, nullptr, progress); +} + +Result ClientImpl::Put(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return send_with_content_provider_and_receiver( + "PUT", path, headers, nullptr, 0, nullptr, std::move(content_provider), + content_type, std::move(content_receiver), progress); } Result ClientImpl::Put(const std::string &path, const Headers &headers, @@ -7333,10 +7439,10 @@ Result ClientImpl::Put(const std::string &path, const Headers &headers, const auto &boundary = detail::make_multipart_data_boundary(); const auto &content_type = detail::serialize_multipart_formdata_get_content_type(boundary); - return send_with_content_provider( + return send_with_content_provider_and_receiver( "PUT", path, headers, nullptr, 0, nullptr, get_multipart_content_provider(boundary, items, provider_items), - content_type, progress); + content_type, nullptr, progress); } Result ClientImpl::Put(const std::string &path, const Headers &headers, @@ -7400,6 +7506,15 @@ Result ClientImpl::Patch(const std::string &path, size_t content_length, content_type, progress); } +Result ClientImpl::Patch(const std::string &path, size_t content_length, + ContentProvider content_provider, + const std::string &content_type, + 
ContentReceiver content_receiver, + UploadProgress progress) { + return Patch(path, Headers(), content_length, std::move(content_provider), + content_type, std::move(content_receiver), progress); +} + Result ClientImpl::Patch(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, @@ -7408,6 +7523,15 @@ Result ClientImpl::Patch(const std::string &path, progress); } +Result ClientImpl::Patch(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return Patch(path, Headers(), std::move(content_provider), content_type, + std::move(content_receiver), progress); +} + Result ClientImpl::Patch(const std::string &path, const Headers &headers, const Params ¶ms) { auto query = detail::params_to_query_str(params); @@ -7448,18 +7572,18 @@ Result ClientImpl::Patch(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress) { - return send_with_content_provider("PATCH", path, headers, body, - content_length, nullptr, nullptr, - content_type, progress); + return send_with_content_provider_and_receiver( + "PATCH", path, headers, body, content_length, nullptr, nullptr, + content_type, nullptr, progress); } Result ClientImpl::Patch(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress) { - return send_with_content_provider("PATCH", path, headers, body.data(), - body.size(), nullptr, nullptr, content_type, - progress); + return send_with_content_provider_and_receiver( + "PATCH", path, headers, body.data(), body.size(), nullptr, nullptr, + content_type, nullptr, progress); } Result ClientImpl::Patch(const std::string &path, const Headers &headers, @@ -7467,18 +7591,40 @@ Result ClientImpl::Patch(const std::string &path, const Headers &headers, ContentProvider content_provider, const std::string &content_type, UploadProgress progress) { - return send_with_content_provider("PATCH", path, headers, nullptr, - content_length, std::move(content_provider), - nullptr, content_type, progress); + return send_with_content_provider_and_receiver( + "PATCH", path, headers, nullptr, content_length, + std::move(content_provider), nullptr, content_type, nullptr, progress); +} + +Result ClientImpl::Patch(const std::string &path, const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return send_with_content_provider_and_receiver( + "PATCH", path, headers, nullptr, content_length, + std::move(content_provider), nullptr, content_type, + std::move(content_receiver), progress); } Result ClientImpl::Patch(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress) { - return send_with_content_provider("PATCH", path, headers, nullptr, 0, nullptr, - std::move(content_provider), content_type, - progress); + return send_with_content_provider_and_receiver( + "PATCH", path, headers, nullptr, 0, nullptr, std::move(content_provider), + content_type, nullptr, progress); +} + +Result ClientImpl::Patch(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress 
progress) { + return send_with_content_provider_and_receiver( + "PATCH", path, headers, nullptr, 0, nullptr, std::move(content_provider), + content_type, std::move(content_receiver), progress); } Result ClientImpl::Patch(const std::string &path, const Headers &headers, @@ -7488,10 +7634,10 @@ Result ClientImpl::Patch(const std::string &path, const Headers &headers, const auto &boundary = detail::make_multipart_data_boundary(); const auto &content_type = detail::serialize_multipart_formdata_get_content_type(boundary); - return send_with_content_provider( + return send_with_content_provider_and_receiver( "PATCH", path, headers, nullptr, 0, nullptr, get_multipart_content_provider(boundary, items, provider_items), - content_type, progress); + content_type, nullptr, progress); } Result ClientImpl::Patch(const std::string &path, const Headers &headers, @@ -8883,12 +9029,28 @@ Result Client::Post(const std::string &path, size_t content_length, return cli_->Post(path, content_length, std::move(content_provider), content_type, progress); } +Result Client::Post(const std::string &path, size_t content_length, + ContentProvider content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return cli_->Post(path, content_length, std::move(content_provider), + content_type, std::move(content_receiver), progress); +} Result Client::Post(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress) { return cli_->Post(path, std::move(content_provider), content_type, progress); } +Result Client::Post(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return cli_->Post(path, std::move(content_provider), content_type, + std::move(content_receiver), progress); +} Result Client::Post(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, @@ -8897,6 +9059,15 @@ Result Client::Post(const std::string &path, const Headers &headers, return cli_->Post(path, headers, content_length, std::move(content_provider), content_type, progress); } +Result Client::Post(const std::string &path, const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + DownloadProgress progress) { + return cli_->Post(path, headers, content_length, std::move(content_provider), + content_type, std::move(content_receiver), progress); +} Result Client::Post(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, @@ -8904,6 +9075,14 @@ Result Client::Post(const std::string &path, const Headers &headers, return cli_->Post(path, headers, std::move(content_provider), content_type, progress); } +Result Client::Post(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + DownloadProgress progress) { + return cli_->Post(path, headers, std::move(content_provider), content_type, + std::move(content_receiver), progress); +} Result Client::Post(const std::string &path, const Params ¶ms) { return cli_->Post(path, params); } @@ -8938,8 +9117,8 @@ Result Client::Post(const std::string &path, const Headers &headers, const std::string &content_type, ContentReceiver content_receiver, 
DownloadProgress progress) { - return cli_->Post(path, headers, body, content_type, content_receiver, - progress); + return cli_->Post(path, headers, body, content_type, + std::move(content_receiver), progress); } Result Client::Put(const std::string &path) { return cli_->Put(path); } @@ -8976,12 +9155,28 @@ Result Client::Put(const std::string &path, size_t content_length, return cli_->Put(path, content_length, std::move(content_provider), content_type, progress); } +Result Client::Put(const std::string &path, size_t content_length, + ContentProvider content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return cli_->Put(path, content_length, std::move(content_provider), + content_type, std::move(content_receiver), progress); +} Result Client::Put(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress) { return cli_->Put(path, std::move(content_provider), content_type, progress); } +Result Client::Put(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return cli_->Put(path, std::move(content_provider), content_type, + std::move(content_receiver), progress); +} Result Client::Put(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, @@ -8990,6 +9185,15 @@ Result Client::Put(const std::string &path, const Headers &headers, return cli_->Put(path, headers, content_length, std::move(content_provider), content_type, progress); } +Result Client::Put(const std::string &path, const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return cli_->Put(path, headers, content_length, std::move(content_provider), + content_type, std::move(content_receiver), progress); +} Result Client::Put(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, @@ -8997,6 +9201,14 @@ Result Client::Put(const std::string &path, const Headers &headers, return cli_->Put(path, headers, std::move(content_provider), content_type, progress); } +Result Client::Put(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return cli_->Put(path, headers, std::move(content_provider), content_type, + std::move(content_receiver), progress); +} Result Client::Put(const std::string &path, const Params ¶ms) { return cli_->Put(path, params); } @@ -9072,12 +9284,28 @@ Result Client::Patch(const std::string &path, size_t content_length, return cli_->Patch(path, content_length, std::move(content_provider), content_type, progress); } +Result Client::Patch(const std::string &path, size_t content_length, + ContentProvider content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return cli_->Patch(path, content_length, std::move(content_provider), + content_type, std::move(content_receiver), progress); +} Result Client::Patch(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress) { return cli_->Patch(path, std::move(content_provider), content_type, progress); 
} +Result Client::Patch(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return cli_->Patch(path, std::move(content_provider), content_type, + std::move(content_receiver), progress); +} Result Client::Patch(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, @@ -9086,6 +9314,15 @@ Result Client::Patch(const std::string &path, const Headers &headers, return cli_->Patch(path, headers, content_length, std::move(content_provider), content_type, progress); } +Result Client::Patch(const std::string &path, const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return cli_->Patch(path, headers, content_length, std::move(content_provider), + content_type, std::move(content_receiver), progress); +} Result Client::Patch(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, @@ -9093,6 +9330,14 @@ Result Client::Patch(const std::string &path, const Headers &headers, return cli_->Patch(path, headers, std::move(content_provider), content_type, progress); } +Result Client::Patch(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return cli_->Patch(path, headers, std::move(content_provider), content_type, + std::move(content_receiver), progress); +} Result Client::Patch(const std::string &path, const Params ¶ms) { return cli_->Patch(path, params); } diff --git a/vendor/cpp-httplib/httplib.h b/vendor/cpp-httplib/httplib.h index 083f795036..c9bd9fd86b 100644 --- a/vendor/cpp-httplib/httplib.h +++ b/vendor/cpp-httplib/httplib.h @@ -8,8 +8,8 @@ #ifndef CPPHTTPLIB_HTTPLIB_H #define CPPHTTPLIB_HTTPLIB_H -#define CPPHTTPLIB_VERSION "0.27.0" -#define CPPHTTPLIB_VERSION_NUM "0x001B00" +#define CPPHTTPLIB_VERSION "0.28.0" +#define CPPHTTPLIB_VERSION_NUM "0x001C00" /* * Platform compatibility check @@ -257,6 +257,7 @@ using socklen_t = int; #include #ifdef __linux__ #include +#undef _res // Undefine _res macro to avoid conflicts with user code (#2278) #endif #include #include @@ -1421,14 +1422,18 @@ public: Result Post(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); Result Post(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); Result Post(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Post(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr); Result Post(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Post(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr); Result Post(const std::string &path, const Params ¶ms); Result Post(const std::string &path, const UploadFormDataItems 
&items, UploadProgress progress = nullptr); Result Post(const std::string &path, const Headers &headers); Result Post(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); Result Post(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); Result Post(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Post(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, DownloadProgress progress = nullptr); Result Post(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Post(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, DownloadProgress progress = nullptr); Result Post(const std::string &path, const Headers &headers, const Params ¶ms); Result Post(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr); Result Post(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr); @@ -1439,14 +1444,18 @@ public: Result Put(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); Result Put(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); Result Put(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Put(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr); Result Put(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Put(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr); Result Put(const std::string &path, const Params ¶ms); Result Put(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr); Result Put(const std::string &path, const Headers &headers); Result Put(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); Result Put(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); Result Put(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Put(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress 
progress = nullptr);
   Result Put(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
+  Result Put(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
   Result Put(const std::string &path, const Headers &headers, const Params &params);
   Result Put(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr);
   Result Put(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr);
@@ -1457,14 +1466,18 @@ public:
   Result Patch(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
   Result Patch(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
   Result Patch(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
+  Result Patch(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
   Result Patch(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
+  Result Patch(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
   Result Patch(const std::string &path, const Params &params);
   Result Patch(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr);
   Result Patch(const std::string &path, const Headers &headers, UploadProgress progress = nullptr);
   Result Patch(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
   Result Patch(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
   Result Patch(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
+  Result Patch(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
   Result Patch(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
+  Result Patch(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
   Result Patch(const std::string &path, const Headers &headers, const Params &params);
   Result Patch(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr);
   Result Patch(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr);
@@ -1712,17 +1725,19 @@ private:
   template <typename ClientType> void setup_redirect_client(ClientType &client);
   bool handle_request(Stream &strm, Request &req, Response &res, bool close_connection, Error &error);
-  std::unique_ptr<Response> send_with_content_provider(
+  std::unique_ptr<Response> send_with_content_provider_and_receiver(
       Request &req, const char *body, size_t content_length,
       ContentProvider content_provider,
       ContentProviderWithoutLength content_provider_without_length,
-      const std::string &content_type, Error &error);
-  Result send_with_content_provider(
+      const std::string &content_type, ContentReceiver content_receiver,
+      Error &error);
+  Result send_with_content_provider_and_receiver(
       const std::string &method, const std::string &path,
       const Headers &headers, const char *body, size_t content_length,
       ContentProvider content_provider,
       ContentProviderWithoutLength content_provider_without_length,
-      const std::string &content_type, UploadProgress progress);
+      const std::string &content_type, ContentReceiver content_receiver,
+      UploadProgress progress);
   ContentProviderWithoutLength get_multipart_content_provider(
       const std::string &boundary, const UploadFormDataItems &items,
       const FormDataProviderItems &provider_items) const;
@@ -1775,14 +1790,18 @@ public:
   Result Post(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
   Result Post(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
   Result Post(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
+  Result Post(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
   Result Post(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
+  Result Post(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
   Result Post(const std::string &path, const Params &params);
   Result Post(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr);
   Result Post(const std::string &path, const Headers &headers);
   Result Post(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
   Result Post(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
   Result Post(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
+  Result Post(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, DownloadProgress progress = nullptr);
   Result Post(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
+  Result Post(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, DownloadProgress progress = nullptr);
   Result Post(const std::string &path, const Headers &headers, const Params &params);
   Result Post(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr);
   Result Post(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr);
@@ -1793,14 +1812,18 @@ public:
   Result Put(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
   Result Put(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
   Result Put(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
+  Result Put(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
   Result Put(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
+  Result Put(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
   Result Put(const std::string &path, const Params &params);
   Result Put(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr);
   Result Put(const std::string &path, const Headers &headers);
   Result Put(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
   Result Put(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
   Result Put(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
+  Result Put(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
   Result Put(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
+  Result Put(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
   Result Put(const std::string &path, const Headers &headers, const Params &params);
   Result Put(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr);
   Result Put(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr);
@@ -1811,14 +1834,18 @@ public:
   Result Patch(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
   Result Patch(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
   Result Patch(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
+  Result Patch(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
   Result Patch(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
+  Result Patch(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
   Result Patch(const std::string &path, const Params &params);
   Result Patch(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr);
   Result Patch(const std::string &path, const Headers &headers);
   Result Patch(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
   Result Patch(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
   Result Patch(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
+  Result Patch(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
   Result Patch(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
+  Result Patch(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
   Result Patch(const std::string &path, const Headers &headers, const Params &params);
   Result Patch(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr);
   Result Patch(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr);
diff --git a/vendor/cpp-httplib/patch-boringssl.cmake b/vendor/cpp-httplib/patch-boringssl.cmake
deleted file mode 100644
index 2914e1dddb..0000000000
--- a/vendor/cpp-httplib/patch-boringssl.cmake
+++ /dev/null
@@ -1,6 +0,0 @@
-# Remove bssl
-file(READ "CMakeLists.txt" content)
-string(REPLACE "add_executable(bssl" "#add_executable(bssl" content "${content}")
-string(REPLACE "target_link_libraries(bssl" "#target_link_libraries(bssl" content "${content}")
-string(REPLACE "install(TARGETS bssl" "#install(TARGETS bssl" content "${content}")
-file(WRITE "CMakeLists.txt" "${content}")
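A minimal usage sketch of one of the new `ContentReceiver` overloads added above: the request body is streamed through a `ContentProvider` while the response body is consumed incrementally through the new `ContentReceiver` parameter. The host, port, and `/completion` endpoint are hypothetical placeholders, and error handling is kept to the bare minimum.

```cpp
#include "httplib.h"

#include <cstdio>
#include <string>

int main() {
    // Hypothetical local server; adjust host/port/path as needed.
    httplib::Client cli("http://localhost:8080");

    const std::string body = "{\"prompt\": \"hello\"}";
    std::string response;

    // New overload: Post(path, headers, content_length, ContentProvider,
    // content_type, ContentReceiver [, progress]).
    auto res = cli.Post(
        "/completion", httplib::Headers{}, body.size(),
        [&](size_t offset, size_t length, httplib::DataSink & sink) {
            // Stream the request body in the chunks the client asks for.
            return sink.write(body.data() + offset, length);
        },
        "application/json",
        [&](const char * data, size_t len) {
            // Accumulate the (possibly chunked) response body.
            response.append(data, len);
            return true; // keep receiving
        });

    if (res && res->status == 200) {
        std::printf("%s\n", response.c_str());
    }
    return 0;
}
```

The same pattern applies to the new `Put` and `Patch` overloads, which differ only in the HTTP method.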