Merge branch 'master' into riscv

This commit is contained in:
Taimur Ahmad 2025-11-27 00:35:21 +05:00, committed by GitHub
commit 2786a97ef0
36 changed files with 2055 additions and 503 deletions


@ -10061,6 +10061,25 @@ class LazyTorchTensor(gguf.LazyBase):
    torch.uint8: np.uint8,
}

# only used when byteswapping data. Only correct size is needed
_dtype_byteswap_map: dict[torch.dtype, type] = {
    torch.float64: np.float64,
    torch.float32: np.float32,
    torch.bfloat16: np.float16,
    torch.float16: np.float16,
    torch.int64: np.int64,
    torch.uint64: np.uint64,
    torch.int32: np.int32,
    torch.uint32: np.uint32,
    torch.int16: np.int16,
    torch.uint16: np.uint16,
    torch.int8: np.int8,
    torch.uint8: np.uint8,
    torch.bool: np.uint8,
    torch.float8_e4m3fn: np.uint8,
    torch.float8_e5m2: np.uint8,
}

# used for safetensors slices
# ref: https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/src/lib.rs#L1046
# TODO: uncomment U64, U32, and U16, ref: https://github.com/pytorch/pytorch/issues/58734
@ -10104,8 +10123,14 @@ class LazyTorchTensor(gguf.LazyBase):
@classmethod
def from_local_tensor(cls, t: gguf.utility.LocalTensor) -> Tensor:
    def load_tensor(tensor: gguf.utility.LocalTensor) -> Tensor:
        def byteswap_tensor(tensor: np.ndarray, dtype: type) -> np.ndarray:
            if sys.byteorder == 'big':
                # switch data back to big endian
                tensor = tensor.view(dtype).byteswap(inplace=False)
            return tensor
        dtype = cls._dtype_str_map[tensor.dtype]
        numpy_dtype = cls._dtype_byteswap_map[dtype]
        return torch.from_numpy(byteswap_tensor(tensor.mmap_bytes(), numpy_dtype)).view(dtype).reshape(tensor.shape)
    dtype = cls._dtype_str_map[t.dtype]
    shape = t.shape
    lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(t,), func=lambda r: load_tensor(r))
@ -10113,10 +10138,16 @@ class LazyTorchTensor(gguf.LazyBase):
@classmethod
def from_remote_tensor(cls, remote_tensor: gguf.utility.RemoteTensor):
    def byteswap_tensor(tensor: np.ndarray, dtype: type) -> np.ndarray:
        if sys.byteorder == 'big':
            # switch data back to big endian
            tensor = tensor.view(dtype).byteswap(inplace=False)
        return tensor
    dtype = cls._dtype_str_map[remote_tensor.dtype]
    numpy_dtype = cls._dtype_byteswap_map[dtype]
    shape = remote_tensor.shape
    meta = cls.meta_with_dtype_and_shape(dtype, shape)
    lazy = cls(meta=meta, args=(remote_tensor,), func=lambda r: torch.from_numpy(byteswap_tensor(np.frombuffer(r.data(), dtype=numpy_dtype), numpy_dtype)).view(dtype).reshape(shape))
    return cast(torch.Tensor, lazy)

@classmethod
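Both helpers exist because safetensors payloads are little-endian while numpy's view() reinterprets raw bytes in host order; on a big-endian host the bytes must be swapped after the reinterpretation (this is also why torch.frombuffer was replaced with an np.frombuffer round-trip, so the swap can happen in numpy). A standalone sketch of the idea, with hypothetical values, not part of the diff:

    import sys
    import numpy as np
    import torch

    # little-endian float32 payload for [1.0, 2.0], as stored in a safetensors file
    raw = np.array([1.0, 2.0], dtype='<f4').tobytes()

    buf = np.frombuffer(raw, dtype=np.uint8)
    vals = buf.view(np.float32)              # reinterpreted in *host* byte order
    if sys.byteorder == 'big':
        vals = vals.byteswap(inplace=False)  # fix up the little-endian payload on big-endian hosts
    print(torch.from_numpy(vals.copy()))     # tensor([1., 2.]) on either endianness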


@ -530,6 +530,7 @@ extern "C" {
GGML_OP_ARANGE,
GGML_OP_TIMESTEP_EMBEDDING,
GGML_OP_ARGSORT,
GGML_OP_TOP_K,
GGML_OP_LEAKY_RELU,
GGML_OP_TRI,
GGML_OP_FILL,
@ -2258,18 +2259,25 @@ extern "C" {
struct ggml_tensor * a,
enum ggml_sort_order order);

// similar to ggml_top_k but implemented as `argsort` + `view`
GGML_API struct ggml_tensor * ggml_argsort_top_k(
    struct ggml_context * ctx,
    struct ggml_tensor * a,
    int k);

// top k elements per row
// note: the resulting top k indices are in no particular order
GGML_API struct ggml_tensor * ggml_top_k(
    struct ggml_context * ctx,
    struct ggml_tensor * a,
    int k);

GGML_API struct ggml_tensor * ggml_arange(
    struct ggml_context * ctx,
    float start,
    float stop,
    float step);

#define GGML_KQ_MASK_PAD 64

// q: [n_embd_k, n_batch, n_head, ne3 ]
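The two declarations differ in their contract, not just their implementation: ggml_argsort_top_k yields the k indices sorted by value, while ggml_top_k only guarantees which indices are returned. A rough numpy model of the two contracts (illustration only, not ggml code):

    import numpy as np

    x = np.array([0.1, 3.0, 2.0, 7.0, 1.0], dtype=np.float32)
    k = 3

    ordered = np.argsort(-x)[:k]                # like ggml_argsort_top_k: [3, 1, 2], sorted by value
    unordered = np.argpartition(-x, k - 1)[:k]  # like ggml_top_k: same set, arbitrary order

    assert set(ordered) == set(unordered)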


@ -2207,78 +2207,120 @@ static void aclnn_index_fill_tensor(ggml_backend_cann_context & ctx,
}

/**
 * @brief Initializes and caches all intermediate tensors required for RoPE
 *        (Rotary Position Embedding), including support for Yarn, mRoPE,
 *        i-mRoPE, Neox repeat strategy, independent sectors, frequency factors
 *        and multi-section rotary groups.
 *
 * This function computes and caches the per-dimension θ coefficients used for
 * Q/K rotary embedding. The cache is shared across layers, and recomputed only
 * when any dependent parameter changes.
 *
 * The function now supports:
 * - Yarn RoPE extrapolation (via @param corr_dims and @param ext_factor)
 * - Per-dimension independent sector exponent rules (indep_sects + sections[])
 * - Multi-section RoPE (mRoPE) index mapping (mrope_used + is_imrope)
 * - Frequency factor division (src2)
 * - Neox / normal repeat expansion modes
 *
 * @param ctx          CANN backend context, containing memory pool,
 *                     cached buffers, and runtime stream.
 * @param dst          Destination ggml_tensor whose computation
 *                     depends on RoPE (typically Qcur or Kcur).
 * @param corr_dims    [low, high] Yarn correction range.
 * @param ext_factor   Yarn extrapolation strength. 0 = disabled.
 * @param theta_scale  Base multiplier for per-dimension θ exponent.
 * @param freq_scale   Global frequency scaling factor.
 * @param attn_factor  Optional scaling applied to sin/cos (if needed).
 * @param is_neox      Whether to use Neox-style dimension interleave.
 * @param sections     4-way sector sizes for independent-section RoPE
 *                     and multi-section mRoPE (t/h/w/e).
 * @param mrope_used   Whether to enable multi-section rotary embedding.
 * @param is_imrope    Whether to apply interleaved mRoPE rules.
 * @param indep_sects  Whether each dimension runs independent exponent
 *                     resets based on @p sections.
 */
static void aclnn_rope_cache_init(ggml_backend_cann_context & ctx,
                                  ggml_tensor * dst,
                                  float * corr_dims,
                                  float ext_factor,
                                  float theta_scale,
                                  float freq_scale,
                                  float attn_factor,
                                  bool is_neox,
                                  int sections[4],
                                  bool mrope_used,
                                  bool is_imrope,
                                  bool indep_sects) {
ggml_tensor * src0 = dst->src[0]; // input
ggml_tensor * src1 = dst->src[1]; // position
ggml_tensor * src2 = dst->src[2]; // freq_factors

int64_t theta_scale_length = src0->ne[0] / 2;
int64_t position_length = dst->ne[2];

// TODO: check theta_scale_length and position_length.
if (src2 == nullptr && ctx.rope_cache.cached &&
    ctx.rope_cache.equal(theta_scale_length, position_length, ext_factor, theta_scale, freq_scale, attn_factor,
                         is_neox, indep_sects, mrope_used, is_imrope, sections)) {
// use cache.
return;
}
// Step0: calculate tensor shape.
int64_t theta_scale_ne[] = { theta_scale_length, 1, 1, 1 };
size_t theta_scale_nb[] = { sizeof(float), theta_scale_length * sizeof(float), theta_scale_length * sizeof(float),
                            theta_scale_length * sizeof(float) };

GGML_ASSERT(src1->type == GGML_TYPE_I32);
int64_t position_ne[] = { 1, 1, position_length, 1 };
size_t position_nb[] = { sizeof(int32_t), sizeof(int32_t), sizeof(int32_t), sizeof(int32_t) * position_length };

int64_t cache_ne[] = { theta_scale_length, 1, position_length, 1 };
size_t cache_nb[GGML_MAX_DIMS];
cache_nb[0] = sizeof(float);
for (int i = 1; i < GGML_MAX_DIMS; i++) {
cache_nb[i] = cache_nb[i - 1] * cache_ne[i - 1];
}
// Step1: Compute the coefficient of theta. During the cache_init process, aside from
// (1) multiplying by the position,
// (2) dividing by freq_factors,
// (3) computing the sine and cosine,
// the other parameters used in the computation generally do not change in most scenarios.
// Therefore, we can first compute this part of the result and then cache it.
// Step1.1: prepare theta_scale exponent. if this exponent updated, should update theta_scale_tensor.
acl_tensor_ptr acl_theta_scale_tensor;
bool theta_scale_updated = false;
if (ctx.rope_cache.theta_scale_length != theta_scale_length || ctx.rope_cache.theta_scale != theta_scale ||
    ctx.rope_cache.indep_sects != indep_sects) {
theta_scale_updated = true;
if (ctx.rope_cache.theta_scale_exp_host != nullptr) {
free(ctx.rope_cache.theta_scale_exp_host);
}
ctx.rope_cache.theta_scale_exp_host = (float *) malloc(theta_scale_length * sizeof(float));
GGML_ASSERT(ctx.rope_cache.theta_scale_exp_host != nullptr);
if (!indep_sects) {
ctx.rope_cache.theta_scale_exp_host[0] = 1;
for (int i = 1; i < theta_scale_length; i++) {
ctx.rope_cache.theta_scale_exp_host[i] = ctx.rope_cache.theta_scale_exp_host[i - 1] * theta_scale;
}
} else {
int sect_dims = sections[0] + sections[1] + sections[2] + sections[3];
int sec_w = sections[1] + sections[0];
int sec_e = sections[2] + sec_w;
ctx.rope_cache.theta_scale_exp_host[0] = 1;
for (int i = 1; i < theta_scale_length; i++) {
int sector = i % sect_dims;
if (sector == 0 || sector == sections[0] || sector == sec_w || sector == sec_e) {
ctx.rope_cache.theta_scale_exp_host[i] = 1;
continue;
}
ctx.rope_cache.theta_scale_exp_host[i] = ctx.rope_cache.theta_scale_exp_host[i - 1] * theta_scale;
}
}
if (ctx.rope_cache.theta_scale_cache != nullptr) {
ACL_CHECK(aclrtFree(ctx.rope_cache.theta_scale_cache));
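A small Python model of the host-side table built in Step1.1 (hypothetical sections values; 0.5 stands in for theta_scale): without indep_sects the exponent runs over all dimensions, with it the progression restarts at every t/h/w/e section boundary.

    def theta_scale_exp(theta_scale, n, sections, indep_sects):
        out = [1.0] * n
        sect_dims = sum(sections) or 1   # guard for the non-mrope case
        sec_w = sections[0] + sections[1]
        sec_e = sec_w + sections[2]
        for i in range(1, n):
            sector = i % sect_dims
            if indep_sects and sector in (0, sections[0], sec_w, sec_e):
                out[i] = 1.0             # restart the progression at each t/h/w/e boundary
            else:
                out[i] = out[i - 1] * theta_scale
        return out

    print(theta_scale_exp(0.5, 8, [2, 2, 2, 0], indep_sects=True))
    # [1.0, 0.5, 1.0, 0.5, 1.0, 0.5, 1.0, 0.5] -- each 2-dim section restarts at 1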
@ -2286,18 +2328,23 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx,
ACL_CHECK(aclrtMalloc(&ctx.rope_cache.theta_scale_cache, theta_scale_length * sizeof(float),
                      ACL_MEM_MALLOC_HUGE_FIRST));
ACL_CHECK(aclrtMemcpyAsync(ctx.rope_cache.theta_scale_cache, theta_scale_length * sizeof(float),
                           ctx.rope_cache.theta_scale_exp_host, theta_scale_length * sizeof(float),
                           ACL_MEMCPY_HOST_TO_DEVICE, ctx.stream()));
acl_theta_scale_tensor = ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float),
                                                 theta_scale_ne, theta_scale_nb, 1);
}

// Step1.2: prepare rope_yarn_ramp, if this part updated, should update theta_scale_tensor.
bool yarn_ramp_tensor_updated = false;
ggml_cann_pool_alloc yarn_ramp_allocator(ctx.pool());
acl_tensor_ptr acl_yarn_ramp_tensor;
if (ext_factor != 0 &&
    // TODO: check more parameter.
    (ctx.rope_cache.theta_scale_length != theta_scale_length || ctx.rope_cache.freq_scale != freq_scale)) {
yarn_ramp_tensor_updated = true;
// -rope_yarn_ramp
// const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
// return MIN(1, MAX(0, y)) - 1;
@ -2313,8 +2360,8 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx,
acl_scalar_ptr denom_safe = ggml_cann_create_scalar(&denom_safe_value, aclDataType::ACL_FLOAT);
acl_scalar_ptr ext_factor_sc = ggml_cann_create_scalar(&ext_factor, aclDataType::ACL_FLOAT);

aclnn_arange(ctx, acl_yarn_ramp_tensor.get(), 0, theta_scale_length, 1, theta_scale_length);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceSubs, acl_yarn_ramp_tensor.get(), low.get(), one.get());
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceDivs, acl_yarn_ramp_tensor.get(), denom_safe.get());
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceThreshold, acl_yarn_ramp_tensor.get(), zero.get(), zero.get());
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceClampMax, acl_yarn_ramp_tensor.get(), one.get());
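Scalar reference for the ramp being built tensor-wise above (a sketch; i indexes theta pairs, i.e. i == i0/2 in the commented formula, and low/high come from ggml_rope_yarn_corr_dims):

    def neg_rope_yarn_ramp(i, low, high):
        y = (i - low) / max(0.001, high - low)
        return min(1.0, max(0.0, y)) - 1.0  # -1 below the corr range, 0 above it, linear in between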
@ -2336,24 +2383,83 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx,
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdds, acl_yarn_ramp_tensor.get(), freq_scale_sc.get(), one.get());
}

// Step 1.3: update theta_scale_tensor according to ext_factor or freq_scale.
if (ext_factor != 0) {
if (theta_scale_updated || yarn_ramp_tensor_updated) {
theta_scale_updated = true;
aclnn_mul(ctx, acl_theta_scale_tensor.get(), acl_yarn_ramp_tensor.get());
}
} else {
if (freq_scale != 1 && (ctx.rope_cache.freq_scale != freq_scale || theta_scale_updated)) {
theta_scale_updated = true;
aclnn_muls(ctx, acl_theta_scale_tensor.get(), freq_scale, nullptr, true);
}
}

// Nothing changed, use cache.
if (!theta_scale_updated) {
acl_theta_scale_tensor = ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float),
                                                 theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
}
// Step 1.4: prepare select index if mrope
acl_tensor_ptr position_select_index_tensor;
if (mrope_used) {
if (ctx.rope_cache.sections[0] != sections[0] || ctx.rope_cache.sections[1] != sections[1] ||
ctx.rope_cache.sections[2] != sections[2] || ctx.rope_cache.sections[3] != sections[3] ||
ctx.rope_cache.theta_scale_length != theta_scale_length || ctx.rope_cache.is_imrope != is_imrope) {
if (ctx.rope_cache.position_select_index_host != nullptr) {
free(ctx.rope_cache.position_select_index_host);
}
ctx.rope_cache.position_select_index_host = (int *) malloc(theta_scale_length * sizeof(int));
GGML_ASSERT(ctx.rope_cache.position_select_index_host != nullptr);
int sect_dims = sections[0] + sections[1] + sections[2] + sections[3];
int sec_w = sections[1] + sections[0];
int sec_e = sections[2] + sec_w;
// t,h,w,e
for (int i = 0; i < theta_scale_length; i++) {
int sector = i % sect_dims;
if (is_imrope) { // qwen3vl apply interleaved mrope
if (sector % 3 == 1 && sector < 3 * sections[1]) {
ctx.rope_cache.position_select_index_host[i] = 1;
} else if (sector % 3 == 2 && sector < 3 * sections[2]) {
ctx.rope_cache.position_select_index_host[i] = 2;
} else if (sector % 3 == 0 && sector < 3 * sections[0]) {
ctx.rope_cache.position_select_index_host[i] = 0;
} else {
ctx.rope_cache.position_select_index_host[i] = 3;
}
} else {
if (sector >= sections[0] && sector < sec_w) {
ctx.rope_cache.position_select_index_host[i] = 1;
} else if (sector >= sec_w && sector < sec_e) {
ctx.rope_cache.position_select_index_host[i] = 2;
} else if (sector >= sec_e) {
ctx.rope_cache.position_select_index_host[i] = 3;
} else {
ctx.rope_cache.position_select_index_host[i] = 0;
}
}
}
if (ctx.rope_cache.position_select_index != nullptr) {
ACL_CHECK(aclrtFree(ctx.rope_cache.position_select_index));
}
ACL_CHECK(aclrtMalloc(&ctx.rope_cache.position_select_index, theta_scale_length * sizeof(int),
ACL_MEM_MALLOC_HUGE_FIRST));
ACL_CHECK(aclrtMemcpyAsync(ctx.rope_cache.position_select_index, theta_scale_length * sizeof(int),
ctx.rope_cache.position_select_index_host, theta_scale_length * sizeof(int),
ACL_MEMCPY_HOST_TO_DEVICE, ctx.stream()));
}
position_select_index_tensor = ggml_cann_create_tensor(ctx.rope_cache.position_select_index, ACL_INT32,
sizeof(int), theta_scale_ne, theta_scale_nb, 1);
}
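Before moving on to Step2, the Step 1.4 selection rule is worth seeing in isolation; a Python model of the index table (0/1/2/3 pick the t/h/w/e position row; sections values are hypothetical):

    def mrope_select_index(n, sections, is_imrope):
        sect_dims = sum(sections) or 1
        sec_w = sections[0] + sections[1]
        sec_e = sec_w + sections[2]
        idx = []
        for i in range(n):
            s = i % sect_dims
            if is_imrope:  # qwen3vl-style: t/h/w interleaved every 3 dims
                if s % 3 == 1 and s < 3 * sections[1]:
                    idx.append(1)
                elif s % 3 == 2 and s < 3 * sections[2]:
                    idx.append(2)
                elif s % 3 == 0 and s < 3 * sections[0]:
                    idx.append(0)
                else:
                    idx.append(3)
            else:          # plain mrope: contiguous t, h, w, e blocks
                idx.append(0 if s < sections[0] else 1 if s < sec_w else 2 if s < sec_e else 3)
        return idx

    print(mrope_select_index(12, [2, 2, 2, 0], is_imrope=False))
    # [0, 0, 1, 1, 2, 2, 0, 0, 1, 1, 2, 2]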
// Step2: divide by freq_factors
ggml_cann_pool_alloc freq_fac_res_allocator(ctx.pool());
if (src2) {
freq_fac_res_allocator.alloc(theta_scale_length * sizeof(float));
void * freq_fac_res_ptr = freq_fac_res_allocator.get();
@ -2366,6 +2472,85 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx,
std::swap(acl_theta_scale_tensor, acl_freq_fac_res_tensor);
}
// Step3: prepare position_tensor
acl_tensor_ptr acl_position_tensor;
ggml_cann_pool_alloc mrope_position_acllocator(ctx.pool());
if (mrope_used) {
// Step3.1: select current position;
// position :
// pos1: [[0, 1 ,2 ,3 ],
// pos2: [4, 5 ,6 ,7 ],
// pos3: [8, 9 ,10,11],
// pos4: [12,13,14,15] ]
//
// select index = [0, 1, 2, 2, 1, 0]
//
// selected_tensor:
// [[0, 1 ,2 ,3 ],
// [4, 5 ,6 ,7 ],
// [8, 9 ,10,11],
// [8, 9 ,10,11],
// [4, 5 ,6 ,7 ],
// [0, 1 ,2 ,3 ]]
//
// transpose, from [seq_len:dims] to [dims:seq_len]
// [0, 4, 8 ,8 ,4, 0],
// [1, 5, 9, 9, 5, 1],
// [2, 6, 10,10,6 ,2],
// [3, 7, 11,11,7, 3 ]]
//
// multiply by theta_scale_tensor
// [theta_scale^0, theta_scale^1, ..., theta_scale ^ n]
int64_t mrope_position_ne[] = { position_length, 4 };
size_t mrope_position_nb[] = { sizeof(int), position_length * sizeof(int) };
acl_tensor_ptr mrope_position =
ggml_cann_create_tensor(src1->data, ggml_cann_type_mapping(src1->type), ggml_type_size(src1->type),
mrope_position_ne, mrope_position_nb, 2);
// selected position tensor's shape is a transpose of cache tensor.
int64_t selected_position_ne[] = { position_length, theta_scale_length };
size_t selected_position_nb[] = { sizeof(float), position_length * sizeof(float) };
mrope_position_acllocator.alloc(theta_scale_length * position_length * sizeof(float));
void * mrope_position_buffer = mrope_position_acllocator.get();
acl_position_tensor =
ggml_cann_create_tensor(mrope_position_buffer, ggml_cann_type_mapping(src1->type),
ggml_type_size(src1->type), selected_position_ne, selected_position_nb, 2);
GGML_CANN_CALL_ACLNN_OP(ctx, IndexSelect, mrope_position.get(), 0, position_select_index_tensor.get(),
acl_position_tensor.get());
// transpose
int64_t transposed_ne[] = { position_length, 1, theta_scale_length, 1 };
size_t transposed_nb[GGML_MAX_DIMS];
transposed_nb[0] = sizeof(float);
for (int i = 1; i < GGML_MAX_DIMS; i++) {
transposed_nb[i] = transposed_nb[i - 1] * transposed_ne[i - 1];
}
std::swap(transposed_ne[0], transposed_ne[2]);
std::swap(transposed_nb[0], transposed_nb[2]);
acl_position_tensor =
ggml_cann_create_tensor(mrope_position_buffer, ggml_cann_type_mapping(src1->type),
ggml_type_size(src1->type), transposed_ne, transposed_nb, GGML_MAX_DIMS);
} else {
// auto bcast.
acl_position_tensor =
ggml_cann_create_tensor(src1->data, ggml_cann_type_mapping(src1->type), ggml_type_size(src1->type),
position_ne, position_nb, GGML_MAX_DIMS);
}
// Step4: multiply by the position
int64_t theta_length = theta_scale_length * position_length;
ggml_cann_pool_alloc theta_allocator(ctx.pool(), theta_length * sizeof(float));
void * theta_buffer = theta_allocator.get();
acl_tensor_ptr acl_theta_tensor =
ggml_cann_create_tensor(theta_buffer, ACL_FLOAT, sizeof(float), cache_ne, cache_nb, GGML_MAX_DIMS);
aclnn_mul(ctx, acl_position_tensor.get(), acl_theta_scale_tensor.get(), acl_theta_tensor.get());
// Step5: calculate sin cos.
// init sin_repeat && cos_repeat, only to accelerate first layer on each device
if (position_length > ctx.rope_cache.position_length) {
ctx.rope_cache.position_length = position_length;
@ -2382,44 +2567,30 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx,
aclrtMalloc(&ctx.rope_cache.cos_cache, repeat_theta_length * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST));
}
// sin/cos
ggml_cann_pool_alloc sin_allocator(ctx.pool(), theta_length * sizeof(float));
void * sin_buffer = sin_allocator.get();
acl_tensor_ptr acl_sin_tensor =
    ggml_cann_create_tensor(sin_buffer, ACL_FLOAT, sizeof(float), cache_ne, cache_nb, GGML_MAX_DIMS, ACL_FORMAT_ND);
aclnn_sin(ctx, acl_theta_tensor.get(), acl_sin_tensor.get());

ggml_cann_pool_alloc cos_allocator(ctx.pool(), theta_length * sizeof(float));
void * cos_buffer = cos_allocator.get();
acl_tensor_ptr acl_cos_tensor =
    ggml_cann_create_tensor(cos_buffer, ACL_FLOAT, sizeof(float), cache_ne, cache_nb, GGML_MAX_DIMS, ACL_FORMAT_ND);
aclnn_cos(ctx, acl_theta_tensor.get(), acl_cos_tensor.get());
if (ext_factor != 0) {
attn_factor *= 1.0f + 0.1f * logf(1.0f / freq_scale);
}
// Step 6: multiply by attn_factor
if (attn_factor != 1) {
aclnn_muls(ctx, acl_sin_tensor.get(), attn_factor, nullptr, true);
aclnn_muls(ctx, acl_cos_tensor.get(), attn_factor, nullptr, true);
}
int64_t sin_reshape_ne[4] = { src0->ne[0], 1, dst->ne[2], 1 };
size_t sin_reshape_nb[GGML_MAX_DIMS];
sin_reshape_nb[0] = sizeof(float);
for (int i = 1; i < GGML_MAX_DIMS; i++) {
@ -2430,8 +2601,9 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx,
acl_tensor_ptr acl_cos_repeat_tensor = ggml_cann_create_tensor(ctx.rope_cache.cos_cache, ACL_FLOAT, sizeof(float),
                                                               sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);

// Step 7: repeat
if (is_neox) {
// [sinθ1, sinθ1, sinθ2, sinθ2, ..., sinθn, sinθn]
int64_t repeatsArray[] = { 1, 1, 1, 2 };
aclnn_repeat(ctx, acl_sin_tensor.get(), acl_sin_repeat_tensor.get(), repeatsArray);
aclnn_repeat(ctx, acl_cos_tensor.get(), acl_cos_repeat_tensor.get(), repeatsArray);
@ -2439,17 +2611,15 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx,
int64_t num_repeats = 2;
int64_t dim = 3;
int64_t output_size = theta_scale_length * num_repeats;
// [sinθ1, sinθ2, ..., sinθn, sinθ1, sinθ2, ..., sinθn]
aclnn_repeat_interleave(ctx, acl_sin_tensor.get(), acl_sin_repeat_tensor.get(), dim, num_repeats, output_size);
aclnn_repeat_interleave(ctx, acl_cos_tensor.get(), acl_cos_repeat_tensor.get(), dim, num_repeats, output_size);
}
// Update cached value.
ctx.rope_cache.cached = true;
ctx.rope_cache.set(theta_scale_length, position_length, ext_factor, theta_scale, freq_scale, attn_factor, is_neox,
                   indep_sects, mrope_used, is_imrope, sections);
}
#ifdef __cplusplus
@ -2475,6 +2645,7 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
// param
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
int sections[4];
// const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2];
@ -2489,6 +2660,7 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
memcpy(&sections, (int32_t *) dst->op_params + 11, sizeof(int)*4);
// TODO: n_dims <= ne0
GGML_ASSERT(n_dims == ne0);
@ -2499,10 +2671,25 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
float corr_dims[2];
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);

bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE; // qwen3vl apply interleaved mrope
const bool mrope_used = mode & GGML_ROPE_TYPE_MROPE; // ggml_rope_multi, note: also true for vision (24 & 8 == true) and for imrope
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
if (mrope_used) {
GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
}
if (is_vision) {
GGML_ASSERT(n_dims == ne0/2);
}
if (is_imrope || mrope_used) {
is_neox = true;
}
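For reference, the flag arithmetic above decodes as below (Python sketch; the constant values are ggml's GGML_ROPE_TYPE_* to the best of my knowledge, and the IMROPE value in particular is an assumption worth checking against ggml.h):

    GGML_ROPE_TYPE_NEOX   = 2
    GGML_ROPE_TYPE_MROPE  = 8
    GGML_ROPE_TYPE_VISION = 24   # MROPE | 16, hence the "24 & 8 == true" note above
    GGML_ROPE_TYPE_IMROPE = 40   # assumed MROPE | 32; check ggml.h

    def decode(mode):
        is_neox    = bool(mode & GGML_ROPE_TYPE_NEOX)
        is_imrope  = mode == GGML_ROPE_TYPE_IMROPE
        mrope_used = bool(mode & GGML_ROPE_TYPE_MROPE)  # also true for vision and imrope
        is_vision  = mode == GGML_ROPE_TYPE_VISION
        if is_imrope or mrope_used:
            is_neox = True
        return is_neox, is_imrope, mrope_used, is_vision

    print(decode(GGML_ROPE_TYPE_VISION))  # (True, False, True, True)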
// init ctx.rope_cos/rope_sin cache
aclnn_rope_cache_init(ctx, dst, corr_dims, ext_factor, theta_scale, freq_scale, attn_factor, is_neox, sections, mrope_used, is_imrope, is_vision);
int64_t sin_reshape_ne[4] = { ne00, 1, ne02, 1 };
size_t sin_reshape_nb[GGML_MAX_DIMS];
@ -2658,8 +2845,7 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
return;
#endif

int64_t acl_mode = is_neox ? 0 : 1;

switch (src0->type) {
case GGML_TYPE_F32:


@ -300,30 +300,92 @@ struct ggml_cann_graph_lru_cache {
struct ggml_cann_rope_cache {
~ggml_cann_rope_cache() {
if (theta_scale_cache) {
ACL_CHECK(aclrtFree(theta_scale_cache));
}
if (sin_cache) {
ACL_CHECK(aclrtFree(sin_cache));
}
if (cos_cache) {
ACL_CHECK(aclrtFree(cos_cache));
}
if (position_select_index) {
ACL_CHECK(aclrtFree(position_select_index));
}
if (theta_scale_exp_host) {
free(theta_scale_exp_host);
}
if (position_select_index_host) {
free(position_select_index_host);
}
}
bool equal(int64_t theta_scale_length,
int64_t position_length,
float ext_factor,
float theta_scale,
float freq_scale,
float attn_factor,
bool is_neox,
bool indep_sects,
bool mrope_used,
bool is_imrope,
int sections[4]) {
return this->theta_scale_length == theta_scale_length && this->position_length == position_length &&
this->ext_factor == ext_factor && this->theta_scale == theta_scale && this->freq_scale == freq_scale &&
this->attn_factor == attn_factor && this->is_neox == is_neox && this->indep_sects == indep_sects &&
this->mrope_used == mrope_used && this->is_imrope == is_imrope && this->sections[0] == sections[0] &&
this->sections[1] == sections[1] && this->sections[2] == sections[2] && this->sections[3] == sections[3];
}
void set(int64_t theta_scale_length,
int64_t position_length,
float ext_factor,
float theta_scale,
float freq_scale,
float attn_factor,
bool is_neox,
bool indep_sects,
bool mrope_used,
bool is_imrope,
int sections[4]) {
this->theta_scale_length = theta_scale_length;
this->position_length = position_length;
this->ext_factor = ext_factor;
this->theta_scale = theta_scale;
this->freq_scale = freq_scale;
this->attn_factor = attn_factor;
this->is_neox = is_neox;
this->indep_sects = indep_sects;
this->mrope_used = mrope_used;
this->is_imrope = is_imrope;
this->sections[0] = sections[0];
this->sections[1] = sections[1];
this->sections[2] = sections[2];
this->sections[3] = sections[3];
}
// memory cache, prepared before inference.
void * theta_scale_cache = nullptr;
float * theta_scale_exp_host = nullptr;
int * position_select_index_host = nullptr;
void * position_select_index = nullptr;
// sin/cos cache, used only to accelerate first layer on each device
void * sin_cache = nullptr;
void * cos_cache = nullptr;
// Properties to check before reusing the sincos cache
int64_t theta_scale_length = 0;
int64_t position_length = 0;
bool cached = false;
float ext_factor = 0.0f;
float theta_scale = 0.0f;
float freq_scale = 0.0f;
float attn_factor = 0.0f;
bool is_neox = false;
bool indep_sects = false;
bool mrope_used = false;
int sections[4] = { 0, 0, 0, 0 };
bool is_imrope = false;
};
struct ggml_cann_tensor_cache {


@ -2480,13 +2480,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
return false;
}

if (op->src[0]->ne[0] > 896) {
return false;
}


@ -224,7 +224,8 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
include(CheckCXXSourceCompiles)
set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS})
string(REPLACE ";" " " ARCH_FLAGS_STR "${ARCH_FLAGS}")
set(CMAKE_REQUIRED_FLAGS "${ARCH_FLAGS_STR}")
foreach(feature DOTPROD SVE MATMUL_INT8 FMA FP16_VECTOR_ARITHMETIC SME)
set(ARM_FEATURE "HAVE_${feature}")
check_cxx_source_compiles(


@ -1927,6 +1927,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_argsort(params, tensor);
} break;
case GGML_OP_TOP_K:
{
ggml_compute_forward_top_k(params, tensor);
} break;
case GGML_OP_LEAKY_RELU:
{
ggml_compute_forward_leaky_relu(params, tensor);
@ -2311,6 +2315,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_OP_ARANGE:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_ARGSORT:
case GGML_OP_TOP_K:
case GGML_OP_FLASH_ATTN_EXT:
case GGML_OP_FLASH_ATTN_BACK:
case GGML_OP_SSM_CONV:
@ -2834,6 +2839,10 @@ struct ggml_cplan ggml_graph_plan(
cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02*ne03;
cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12;
} break;
case GGML_OP_TOP_K:
{
cur += sizeof(int32_t)*node->src[0]->ne[0]*n_tasks;
} break;
case GGML_OP_FLASH_ATTN_EXT:
{
const int64_t ne10 = node->src[1]->ne[0]; // DK


@ -7794,7 +7794,7 @@ void ggml_compute_forward_timestep_embedding(
// ggml_compute_forward_argsort

template<enum ggml_sort_order order>
struct cmp_argsort {
const float * data;
bool operator()(int32_t a, int32_t b) const {
if constexpr (order == GGML_SORT_ORDER_ASC) {
@ -7833,11 +7833,11 @@ static void ggml_compute_forward_argsort_f32(
switch (order) {
case GGML_SORT_ORDER_ASC:
std::sort(dst_data, dst_data + ne0, cmp_argsort<GGML_SORT_ORDER_ASC>{src_data});
break;
case GGML_SORT_ORDER_DESC:
std::sort(dst_data, dst_data + ne0, cmp_argsort<GGML_SORT_ORDER_DESC>{src_data});
break;
default:
@ -7864,6 +7864,72 @@ void ggml_compute_forward_argsort(
}
}
// ggml_compute_forward_top_k
struct cmp_top_k {
const float * data;
bool operator()(int32_t a, int32_t b) const {
return data[a] > data[b];
}
};
static void ggml_compute_forward_top_k_f32(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
GGML_TENSOR_UNARY_OP_LOCALS
GGML_ASSERT(nb0 == sizeof(float));
const int ith = params->ith;
const int nth = params->nth;
const int64_t nr = ggml_nrows(src0);
const int top_k = ne0;
int32_t * tmp = (int32_t *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
for (int64_t i = ith; i < nr; i += nth) {
const float * src_data = (float *)((char *) src0->data + i*nb01);
for (int64_t j = 0; j < ne00; j++) {
tmp[j] = j;
}
std::partial_sort(tmp, tmp + top_k, tmp + ne00, cmp_top_k{src_data});
int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
std::copy(tmp, tmp + top_k, dst_data);
// emphasize that the order is not important
if (top_k > 1) {
std::swap(dst_data[0], dst_data[1]);
}
}
}
void ggml_compute_forward_top_k(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_top_k_f32(params, dst);
} break;
default:
{
GGML_ABORT("fatal error");
}
}
}
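In short, each thread walks rows with stride nth, fills an index scratch buffer, and std::partial_sort keeps the k largest entries (with a deliberate swap so callers cannot rely on ordering). A compact Python model of the per-row selection, with numpy's argpartition standing in for std::partial_sort:

    import numpy as np

    def top_k_indices(src, k):              # src: (n_rows, ne00) float32
        out = np.empty((src.shape[0], k), dtype=np.int32)
        for i in range(src.shape[0]):       # the C code strides rows across threads: i = ith, ith + nth, ...
            out[i] = np.argpartition(-src[i], k - 1)[:k]  # k largest, in no particular order
        return out

    print(top_k_indices(np.array([[0.1, 3.0, 2.0, 7.0]], dtype=np.float32), 2))  # some order of {3, 1}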
// ggml_compute_forward_flash_attn_ext

static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(


@ -81,6 +81,7 @@ void ggml_compute_forward_roll(const struct ggml_compute_params * params, struct
void ggml_compute_forward_arange(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_timestep_embedding(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_argsort(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_top_k(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_leaky_relu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_fill(const struct ggml_compute_params * params, struct ggml_tensor * dst);


@ -455,15 +455,14 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const
}

inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y, const ggml_fp16_t * GGML_RESTRICT x, const float v) {
#if defined(GGML_SIMD) && defined(__ARM_FEATURE_SVE)
const int sve_register_length = svcntb() * 8;
const int ggml_f16_epr = sve_register_length / 16;
const int ggml_f16_step = 8 * ggml_f16_epr;

GGML_F16x_VEC vx = GGML_F16x_VEC_SET1(v);

int np = (n & ~(ggml_f16_step - 1));

svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8;
svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8;
@ -532,8 +531,8 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y,
hy = svmad_f16_x(pg, hx, vx, hy);
svst1_f16(pg, (__fp16 *)(y + np2), hy);
}
np = n;
#elif defined(__riscv_zvfh) // implies __riscv_v_intrinsic
const ggml_fp16_t s = GGML_CPU_FP32_TO_FP16(v);
const _Float16 scale = *(const _Float16*)(&s);
@ -566,7 +565,7 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y,
ay0 = __riscv_vfmacc_vf_f16m4(ay0, scale, ax0, vl);
__riscv_vse16_v_f16m4((_Float16*)y + i, ay0, vl);
}
#elif defined(GGML_SIMD)
const int np = (n & ~(GGML_F16_STEP - 1));

GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
@ -583,18 +582,14 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y,
GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
}
}
#else
const int np = 0;
#endif
// leftovers
for (int i = np; i < n; ++i) {
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v);
}
}
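The restructuring above lets every branch fall through to one shared scalar leftovers loop: each branch defines np (the element count handled by vector code), the SVE path sets np = n after handling its own tail, and scalar-only builds get np = 0. The shape of the pattern, sketched in Python:

    def vec_mad_f16(y, x, v, step=8):
        # np_ is the number of elements covered by the vector loop; the tail is scalar.
        n = len(y)
        np_ = n & ~(step - 1)
        for i in range(0, np_, step):       # stands in for the unrolled SIMD body
            for j in range(step):
                y[i + j] += x[i + j] * v
        for i in range(np_, n):             # shared scalar leftovers loop
            y[i] += x[i] * v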
// xs and vs are byte strides of x and v


@ -437,10 +437,15 @@ namespace ggml_cuda_mma {
xi[0] = xs[0];
}
#elif defined(AMD_WMMA_AVAILABLE)
if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
} else if constexpr (std::is_same_v<T, int>) {
if constexpr (I == 16 && J == 4) {
int64_t * xi = (int64_t *) t.x;
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
xi[0] = xs[0];
}else if constexpr (I == 16 && J == 8) {
int64_t * xi = (int64_t *) t.x;
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I));
@ -448,9 +453,13 @@ namespace ggml_cuda_mma {
const int64_t * xs1 = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I) + 2);
xi[1] = xs1[0];
}else{
NO_DEVICE_CODE;
}
} else {
NO_DEVICE_CODE;
}
#else
#pragma unroll
for (int l = 0; l < t.ne; ++l) {


@ -3701,7 +3701,7 @@ static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int
const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y);
const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type);
const size_t nbs_ids = mmq_x*sizeof(int);
const size_t nbs_x = (turing_mma_available(cc) || amd_mfma_available(cc) || amd_wmma_available(cc)) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
const size_t nbs_y = mmq_x*sizeof(block_q8_1_mmq);
return nbs_ids + nbs_x + GGML_PAD(nbs_y, nwarps*warp_size*sizeof(int));
}


@ -1009,6 +1009,64 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort_merge(ggml_metal_l
return res;
}
// note: reuse the argsort kernel for top_k
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k(ggml_metal_library_t lib, const ggml_tensor * op) {
assert(op->op == GGML_OP_TOP_K);
char base[256];
char name[256];
// note: the top_k kernel is always descending order
ggml_sort_order order = GGML_SORT_ORDER_DESC;
const char * order_str = "undefined";
switch (order) {
case GGML_SORT_ORDER_ASC: order_str = "asc"; break;
case GGML_SORT_ORDER_DESC: order_str = "desc"; break;
default: GGML_ABORT("fatal error");
};
snprintf(base, 256, "kernel_argsort_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str);
snprintf(name, 256, "%s", base);
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
if (res) {
return res;
}
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
return res;
}
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k_merge(ggml_metal_library_t lib, const ggml_tensor * op) {
assert(op->op == GGML_OP_TOP_K);
char base[256];
char name[256];
ggml_sort_order order = GGML_SORT_ORDER_DESC;
const char * order_str = "undefined";
switch (order) {
case GGML_SORT_ORDER_ASC: order_str = "asc"; break;
case GGML_SORT_ORDER_DESC: order_str = "desc"; break;
default: GGML_ABORT("fatal error");
};
snprintf(base, 256, "kernel_argsort_merge_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str);
snprintf(name, 256, "%s", base);
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
if (res) {
return res;
}
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
return res;
}
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
ggml_metal_library_t lib,
const struct ggml_tensor * op,


@ -128,6 +128,8 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id (ggml_me
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);


@ -905,6 +905,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_OP_LEAKY_RELU:
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_ARGSORT:
case GGML_OP_TOP_K:
case GGML_OP_ARANGE:
return true;
case GGML_OP_FLASH_ATTN_EXT:


@ -832,14 +832,19 @@ typedef struct {
} ggml_metal_kargs_leaky_relu;

typedef struct {
int32_t ne00;
int32_t ne01;
int32_t ne02;
int32_t ne03;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne0;
int32_t ne1;
int32_t ne2;
int32_t ne3;
int32_t top_k;
} ggml_metal_kargs_argsort;

typedef struct {
@ -851,6 +856,11 @@ typedef struct {
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne0;
int32_t ne1;
int32_t ne2;
int32_t ne3;
int32_t top_k;
int32_t len;
} ggml_metal_kargs_argsort_merge;


@ -406,6 +406,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
{
n_fuse = ggml_metal_op_argsort(ctx, idx);
} break;
case GGML_OP_TOP_K:
{
n_fuse = ggml_metal_op_top_k(ctx, idx);
} break;
case GGML_OP_LEAKY_RELU:
{
n_fuse = ggml_metal_op_leaky_relu(ctx, idx);
@ -3686,6 +3690,11 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.ne3 =*/ ne3,
/*.top_k =*/ nth,
};

ggml_metal_encoder_set_pipeline(enc, pipeline);
@ -3705,15 +3714,20 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
ggml_metal_op_concurrency_reset(ctx);

ggml_metal_kargs_argsort_merge args_merge = {
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.ne3 =*/ ne3,
/*.top_k =*/ ne00,
/*.len =*/ len,
};

// merges per row
@ -3737,6 +3751,118 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
return 1;
}
int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_top_k(lib, op);
// bitonic sort requires the number of elements to be power of 2
int nth = 1;
while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
nth *= 2;
}
// blocks per row
const int npr = (ne00 + nth - 1)/nth;
const size_t smem = GGML_PAD(nth*sizeof(int32_t), 16);
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
ggml_metal_buffer_id bid_tmp = bid_dst;
bid_tmp.offs += sizeof(int32_t)*ggml_nelements(op->src[0]);
if ((int) ceil(std::log(npr) / std::log(2)) % 2 == 1) {
std::swap(bid_dst, bid_tmp);
}
const int top_k = ne0;
ggml_metal_kargs_argsort args = {
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.ne3 =*/ ne3,
/*.top_k =*/ std::min(nth, top_k), // for each block, keep just the top_k indices
};
if (npr > 1) {
args.ne0 = (npr - 1)*args.top_k + std::min(ne00 - (npr - 1)*nth, args.top_k);
}
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1);
ggml_metal_pipeline_t pipeline_merge = ggml_metal_library_get_pipeline_top_k_merge(lib, op);
int len = args.top_k;
while (len < args.ne0) {
ggml_metal_op_concurrency_reset(ctx);
// merges per row
const int nm = (args.ne0 + 2*len - 1) / (2*len);
const int nth = std::min(512, std::min(len, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_merge)));
ggml_metal_kargs_argsort_merge args_merge = {
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne0 =*/ args.ne0,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.ne3 =*/ ne3,
/*.top_k =*/ nm == 1 ? top_k : args.ne0, // the final merge outputs top_k elements
/*.len =*/ len,
};
ggml_metal_encoder_set_pipeline(enc, pipeline_merge);
ggml_metal_encoder_set_bytes (enc, &args_merge, sizeof(args_merge), 0);
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
ggml_metal_encoder_set_buffer (enc, bid_tmp, 3);
ggml_metal_encoder_dispatch_threadgroups(enc, nm*ne01, ne02, ne03, nth, 1, 1);
std::swap(bid_dst, bid_tmp);
len <<= 1;
}
return 1;
}
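The pre-swap of bid_dst/bid_tmp on an odd ceil(log2(npr)) makes the ping-pong merge finish in the tensor's real buffer rather than the scratch half. The number of candidates fed into the merge follows directly from args.top_k; a host-side sketch of that arithmetic (hypothetical helper, not part of the backend):

#include <algorithm>
#include <cstdio>

// How many candidates survive the per-block top_k pass, matching the
// args.ne0 computation above. Assumes nth is the power-of-2 block size
// and top_k <= ne00.
static int topk_candidates(int ne00, int nth, int top_k) {
    const int npr  = (ne00 + nth - 1)/nth;  // blocks per row
    const int keep = std::min(nth, top_k);  // kept per full block
    if (npr == 1) {
        return std::min(ne00, keep);
    }
    // each full block keeps `keep`; the tail block keeps at most `keep`
    return (npr - 1)*keep + std::min(ne00 - (npr - 1)*nth, keep);
}

int main() {
    // e.g. 10000 elements per row, 1024-wide blocks, k = 40:
    std::printf("%d\n", topk_candidates(10000, 1024, 40)); // 400
    return 0;
}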
int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) { int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx); ggml_tensor * op = ctx->node(idx);

View File

@ -81,6 +81,7 @@ int ggml_metal_op_arange (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx); int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx);
int ggml_metal_op_argmax (ggml_metal_op_t ctx, int idx); int ggml_metal_op_argmax (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_argsort (ggml_metal_op_t ctx, int idx); int ggml_metal_op_argsort (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_top_k (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_leaky_relu (ggml_metal_op_t ctx, int idx); int ggml_metal_op_leaky_relu (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_opt_step_adamw (ggml_metal_op_t ctx, int idx); int ggml_metal_op_opt_step_adamw (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_opt_step_sgd (ggml_metal_op_t ctx, int idx); int ggml_metal_op_opt_step_sgd (ggml_metal_op_t ctx, int idx);

View File

@ -202,6 +202,10 @@ static size_t ggml_backend_metal_buffer_type_get_alloc_size(ggml_backend_buffer_
{ {
res *= 2; res *= 2;
} break; } break;
case GGML_OP_TOP_K:
{
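// two int32 regions the size of src0: the dst-side working area plus the
// tmp half used for ping-pong merging (see ggml_metal_op_top_k)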
res = 2*sizeof(int32_t)*ggml_nelements(tensor->src[0]);
} break;
default: default:
break; break;
} }

View File

@ -4670,9 +4670,10 @@ kernel void kernel_argsort_f32_i32(
ushort3 ntg[[threads_per_threadgroup]]) { ushort3 ntg[[threads_per_threadgroup]]) {
// bitonic sort // bitonic sort
const int col = tpitg[0]; const int col = tpitg[0];
const int ib = tgpig[0] / args.ne01;
const int i00 = (tgpig[0]/args.ne01)*ntg.x; const int i00 = ib*ntg.x;
const int i01 = tgpig[0]%args.ne01; const int i01 = tgpig[0] % args.ne01;
const int i02 = tgpig[1]; const int i02 = tgpig[1];
const int i03 = tgpig[2]; const int i03 = tgpig[2];
@ -4710,9 +4711,11 @@ kernel void kernel_argsort_f32_i32(
} }
} }
const int64_t i0 = ib*args.top_k;
// copy the result to dst without the padding // copy the result to dst without the padding
if (i00 + col < args.ne00) { if (i0 + col < args.ne0 && col < args.top_k) {
dst += i00 + args.ne00*i01 + args.ne00*args.ne01*i02 + args.ne00*args.ne01*args.ne02*i03; dst += i0 + args.ne0*i01 + args.ne0*args.ne1*i02 + args.ne0*args.ne1*args.ne2*i03;
dst[col] = shmem_i32[col]; dst[col] = shmem_i32[col];
} }
@ -4747,22 +4750,22 @@ kernel void kernel_argsort_merge_f32_i32(
const int start = im * (2 * args.len); const int start = im * (2 * args.len);
const int len0 = MIN(args.len, MAX(0, args.ne00 - (int)(start))); const int len0 = MIN(args.len, MAX(0, args.ne0 - (int)(start)));
const int len1 = MIN(args.len, MAX(0, args.ne00 - (int)(start + args.len))); const int len1 = MIN(args.len, MAX(0, args.ne0 - (int)(start + args.len)));
const int total = len0 + len1; const int total = len0 + len1;
device const int32_t * tmp0 = tmp + start device const int32_t * tmp0 = tmp + start
+ i01*args.ne00 + i01*args.ne0
+ i02*args.ne00*args.ne01 + i02*args.ne0*args.ne01
+ i03*args.ne00*args.ne01*args.ne02; + i03*args.ne0*args.ne01*args.ne02;
device const int32_t * tmp1 = tmp0 + args.len; device const int32_t * tmp1 = tmp0 + args.len;
dst += start dst += start
+ i01*args.ne00 + i01*args.top_k
+ i02*args.ne00*args.ne01 + i02*args.top_k*args.ne01
+ i03*args.ne00*args.ne01*args.ne02; + i03*args.top_k*args.ne01*args.ne02;
device const float * src0_row = (device const float *)(src0 device const float * src0_row = (device const float *)(src0
+ args.nb01*i01 + args.nb01*i01
@ -4776,7 +4779,11 @@ kernel void kernel_argsort_merge_f32_i32(
const int chunk = (total + ntg.x - 1) / ntg.x; const int chunk = (total + ntg.x - 1) / ntg.x;
const int k0 = tpitg.x * chunk; const int k0 = tpitg.x * chunk;
const int k1 = min(k0 + chunk, total); const int k1 = MIN(MIN(k0 + chunk, total), args.top_k);
if (k0 >= args.top_k) {
return;
}
if (k0 >= total) { if (k0 >= total) {
return; return;

View File

@ -409,6 +409,7 @@ enum shader_reduction_mode {
// argsort pipelines for up to 1<<10 invocations per workgroup // argsort pipelines for up to 1<<10 invocations per workgroup
static constexpr uint32_t num_argsort_pipelines = 11; static constexpr uint32_t num_argsort_pipelines = 11;
static constexpr uint32_t num_topk_moe_pipelines = 10; static constexpr uint32_t num_topk_moe_pipelines = 10;
static constexpr uint32_t num_topk_pipelines = 11;
static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT, static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE, GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
@ -515,6 +516,7 @@ struct vk_device_struct {
bool single_queue; bool single_queue;
bool support_async; bool support_async;
uint32_t subgroup_size; uint32_t subgroup_size;
uint32_t subgroup_size_log2;
uint32_t shader_core_count; uint32_t shader_core_count;
bool uma; bool uma;
bool prefer_host_memory; bool prefer_host_memory;
@ -704,7 +706,9 @@ struct vk_device_struct {
vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16; vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16;
vk_pipeline pipeline_argsort_f32[num_argsort_pipelines]; vk_pipeline pipeline_argsort_f32[num_argsort_pipelines];
vk_pipeline pipeline_argsort_large_f32[num_argsort_pipelines]; vk_pipeline pipeline_argsort_large_f32[num_argsort_pipelines];
vk_pipeline pipeline_topk_f32[num_topk_pipelines];
vk_pipeline pipeline_sum_rows_f32; vk_pipeline pipeline_sum_rows_f32;
vk_pipeline pipeline_cumsum_f32;
vk_pipeline pipeline_argmax_f32; vk_pipeline pipeline_argmax_f32;
vk_pipeline pipeline_count_equal_i32; vk_pipeline pipeline_count_equal_i32;
vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16; vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
@ -1204,6 +1208,15 @@ struct vk_op_argsort_push_constants {
uint32_t inner_end; uint32_t inner_end;
}; };
struct vk_op_topk_push_constants {
uint32_t orig_ncols;
uint32_t ncols_input;
uint32_t ncols_output;
uint32_t nrows;
uint32_t first_pass;
uint32_t last_pass;
};
struct vk_op_im2col_push_constants { struct vk_op_im2col_push_constants {
uint64_t dst_addr; uint64_t dst_addr;
uint32_t batch_offset; uint32_t offset_delta; uint32_t batch_offset; uint32_t offset_delta;
@ -3964,10 +3977,29 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline2(device, device->pipeline_argsort_large_f32[i], "argsort_large_f32_"+std::to_string(i), argsort_large_f32_len, argsort_large_f32_data, "main", 3, sizeof(vk_op_argsort_push_constants), {BLOCK_SIZE * WG_UNROLL_FACTOR, 1, 1}, {BLOCK_SIZE, WG_UNROLL_FACTOR}, 1, true); ggml_vk_create_pipeline2(device, device->pipeline_argsort_large_f32[i], "argsort_large_f32_"+std::to_string(i), argsort_large_f32_len, argsort_large_f32_data, "main", 3, sizeof(vk_op_argsort_push_constants), {BLOCK_SIZE * WG_UNROLL_FACTOR, 1, 1}, {BLOCK_SIZE, WG_UNROLL_FACTOR}, 1, true);
} }
for (uint32_t i = 0; i < num_topk_pipelines; ++i) {
const uint32_t BLOCK_SIZE = 1u << i;
const uint32_t NCOLS_PADDED_LOG2 = i;
if (i <= device->max_workgroup_size_log2) {
uint32_t nary_shmem = 2 * sizeof(int) * BLOCK_SIZE +
sizeof(int) * device->subgroup_size +
2 * sizeof(int) +
(BLOCK_SIZE / device->subgroup_size) * sizeof(int);
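// e.g. BLOCK_SIZE = 1024, subgroup of 32: 8192 + 128 + 8 + 128 = 8456 bytes
// (dst_row pairs + counts[] + sh_min_idx/sh_total + offset_partials)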
if (device->subgroup_arithmetic && device->subgroup_require_full_support && device->subgroup_shuffle && device->subgroup_ballot &&
nary_shmem <= device->properties.limits.maxComputeSharedMemorySize) {
ggml_vk_create_pipeline2(device, device->pipeline_topk_f32[i], "topk_f32_"+std::to_string(i), topk_nary_search_f32_len, topk_nary_search_f32_data, "main", 2, sizeof(vk_op_topk_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, device->subgroup_size, device->subgroup_size_log2}, 1, true, true, device->subgroup_size);
} else if (2 * sizeof(int) * BLOCK_SIZE <= device->properties.limits.maxComputeSharedMemorySize) {
ggml_vk_create_pipeline2(device, device->pipeline_topk_f32[i], "topk_f32_"+std::to_string(i), topk_argsort_f32_len, topk_argsort_f32_data, "main", 2, sizeof(vk_op_topk_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, NCOLS_PADDED_LOG2}, 1, true);
}
}
}
ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
ggml_vk_create_pipeline(device, device->pipeline_cumsum_f32, "cumsum_f32", cumsum_f32_len, cumsum_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { 128, device->subgroup_size }, 1, true, true, device->subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1); ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
#define IM2COL(bda) \ #define IM2COL(bda) \
@ -4333,6 +4365,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
device->suballocation_block_size = std::min(device->suballocation_block_size, device->max_memory_allocation_size); device->suballocation_block_size = std::min(device->suballocation_block_size, device->max_memory_allocation_size);
device->subgroup_size = subgroup_props.subgroupSize; device->subgroup_size = subgroup_props.subgroupSize;
device->subgroup_size_log2 = uint32_t(log2f(float(device->subgroup_size)));
device->uma = device->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu; device->uma = device->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
if (sm_builtins) { if (sm_builtins) {
device->shader_core_count = sm_props.shaderSMCount; device->shader_core_count = sm_props.shaderSMCount;
@ -8457,6 +8490,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_sum_rows_f32; return ctx->device->pipeline_sum_rows_f32;
} }
return nullptr; return nullptr;
case GGML_OP_CUMSUM:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_cumsum_f32;
}
return nullptr;
case GGML_OP_ARGMAX: case GGML_OP_ARGMAX:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) { if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
return ctx->device->pipeline_argmax_f32; return ctx->device->pipeline_argmax_f32;
@ -8821,6 +8859,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX:
case GGML_OP_SOFT_MAX_BACK: case GGML_OP_SOFT_MAX_BACK:
case GGML_OP_SUM_ROWS: case GGML_OP_SUM_ROWS:
case GGML_OP_CUMSUM:
case GGML_OP_MEAN: case GGML_OP_MEAN:
case GGML_OP_ARGMAX: case GGML_OP_ARGMAX:
{ {
@ -10134,6 +10173,104 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, c
} }
} }
static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
uint32_t ncols = src0->ne[0];
uint32_t nrows = ggml_nrows(src0);
uint32_t k = dst->ne[0];
vk_op_topk_push_constants pc { ncols, ncols, k, nrows, 0, 0 };
// Reserve space for ivec2 per element, double buffered
const size_t dbl_buf_size = size_t{ncols} * nrows * 2 * sizeof(int);
const size_t x_sz = dbl_buf_size * 2;
uint32_t dbl_buf_index = 0;
if (ctx->prealloc_size_x < x_sz) {
ctx->prealloc_size_x = x_sz;
ggml_vk_preallocate_buffers(ctx, subctx);
}
if (ctx->prealloc_x_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}
std::array<uint32_t, 3> elements;
elements[1] = std::min(nrows, ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
elements[2] = 1;
uint32_t num_elements = ncols;
// Each iteration reduces a workgroup's worth of elements down to the K
// largest elements. Repeat until we have the top K elements.
// Need to do at least one iteration to write out the results.
bool done_one_iter = false;
while (num_elements > k || !done_one_iter) {
done_one_iter = true;
// Prefer to cap the pipeline index at num_topk_pipelines - 3 for perf reasons.
// But if K is larger, then we need a larger workgroup.
uint32_t max_pipeline = num_topk_pipelines - 3;
uint32_t min_pipeline = (uint32_t)log2f(float(k)) + 1;
// require full subgroup
min_pipeline = std::max(min_pipeline, ctx->device->subgroup_size_log2);
uint32_t pipeline_idx = (uint32_t)ceilf(log2f(float(num_elements)));
pipeline_idx = std::min(pipeline_idx, max_pipeline);
pipeline_idx = std::max(pipeline_idx, min_pipeline);
if (num_elements > (1u << pipeline_idx)) {
// If we could finish on this loop iteration (i.e. a single workgroup)
// then do so. It's better than the overhead of another pass.
for (uint32_t i = pipeline_idx; i < num_topk_pipelines; ++i) {
if (num_elements <= (1u << i)) {
pipeline_idx = i;
break;
}
}
}
vk_pipeline pipeline = ctx->device->pipeline_topk_f32[pipeline_idx];
// If the device doesn't support a pipeline this large, use a smaller one
while (!pipeline) {
pipeline_idx--;
GGML_ASSERT(pipeline_idx >= min_pipeline);
pipeline = ctx->device->pipeline_topk_f32[pipeline_idx];
}
vk_op_topk_push_constants pc2 = pc;
pc2.ncols_input = num_elements;
// Number of elements remaining after this pass
uint32_t num_dst_elements = (num_elements / pipeline->wg_denoms[0]) * k + std::min(k, num_elements % pipeline->wg_denoms[0]);
vk_subbuffer src_buf;
vk_subbuffer dst_buf;
if (num_elements == ncols) {
pc2.first_pass = 1;
src_buf = ggml_vk_tensor_subbuffer(ctx, src0);
} else {
src_buf = { ctx->prealloc_x, dbl_buf_index * dbl_buf_size, dbl_buf_size };
}
if (num_dst_elements == k) {
pc2.last_pass = 1;
dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
} else {
dst_buf = { ctx->prealloc_x, (dbl_buf_index ^ 1) * dbl_buf_size, dbl_buf_size };
}
elements[0] = num_elements;
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src_buf, dst_buf }, pc2, elements);
num_elements = num_dst_elements;
dbl_buf_index ^= 1;
if (num_elements > k) {
ggml_vk_sync_buffers(ctx, subctx);
}
}
ctx->prealloc_x_need_sync = true;
}
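Each dispatch shrinks every workgroup-sized slice of a row to its local top-k, so the candidate count follows the recurrence n -> (n/wg)*k + min(k, n % wg) until it reaches k. A simplified trace with a fixed workgroup size (the real code re-selects the pipeline, and thus wg, on every pass):

#include <algorithm>
#include <cstdio>

int main() {
    const unsigned k = 40, wg = 256; // assumed fixed workgroup size
    unsigned n = 200000;
    int pass = 0;
    while (n > k) {
        n = (n/wg)*k + std::min(k, n % wg);
        std::printf("after pass %d: %u candidates\n", ++pass, n);
    }
    // 200000 -> 31280 -> 4920 -> 800 -> 152 -> 40
    return 0;
}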
static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, ggml_nelements(src0)); vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, ggml_nelements(src0));
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SUM, p); ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SUM, p);
@ -10150,6 +10287,11 @@ static void ggml_vk_mean(ggml_backend_vk_context * ctx, vk_context& subctx, cons
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_MEAN, p); ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_MEAN, p);
} }
static void ggml_vk_cumsum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]);
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_CUMSUM, p);
}
static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], 0.0f, 0.0f }); ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], 0.0f, 0.0f });
} }
@ -11741,6 +11883,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
ggml_vk_argsort(ctx, compute_ctx, src0, node); ggml_vk_argsort(ctx, compute_ctx, src0, node);
} }
break;
case GGML_OP_TOP_K:
ggml_vk_topk(ctx, compute_ctx, src0, node);
break; break;
case GGML_OP_SUM: case GGML_OP_SUM:
ggml_vk_sum(ctx, compute_ctx, src0, node); ggml_vk_sum(ctx, compute_ctx, src0, node);
@ -11749,6 +11895,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_SUM_ROWS: case GGML_OP_SUM_ROWS:
ggml_vk_sum_rows(ctx, compute_ctx, src0, node); ggml_vk_sum_rows(ctx, compute_ctx, src0, node);
break;
case GGML_OP_CUMSUM:
ggml_vk_cumsum(ctx, compute_ctx, src0, node);
break; break;
case GGML_OP_MEAN: case GGML_OP_MEAN:
ggml_vk_mean(ctx, compute_ctx, src0, node); ggml_vk_mean(ctx, compute_ctx, src0, node);
@ -13008,24 +13158,6 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
return false; return false;
}; };
// This function tries to reorder the graph to allow nodes to run in parallel.
// This helps with small batches, but for large batches it's a slowdown, probably
// due to cache contention. So only reorder if the majority of nodes have few rows.
int num_small_nodes = 0;
int num_counted_nodes = 0;
for (int i = 0; i < graph->n_nodes; ++i) {
if (!is_empty(graph->nodes[i]) &&
graph->nodes[i]->op != GGML_OP_SET_ROWS) {
if (ggml_nrows(graph->nodes[i]) <= 8) {
num_small_nodes++;
}
num_counted_nodes++;
}
}
if (num_small_nodes < num_counted_nodes / 2) {
return;
}
std::vector<ggml_tensor *> new_order; std::vector<ggml_tensor *> new_order;
std::vector<bool> used(graph->n_nodes, false); std::vector<bool> used(graph->n_nodes, false);
std::set<ggml_tensor *> used_node_set; std::set<ggml_tensor *> used_node_set;
@ -13769,6 +13901,22 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
return op->ne[0] <= (1 << device->max_workgroup_size_log2); return op->ne[0] <= (1 << device->max_workgroup_size_log2);
} }
} }
case GGML_OP_TOP_K:
{
if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) {
return false;
}
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
auto device = ggml_vk_get_device(ctx->device);
// We could potentially support larger, using argsort to sort the
// whole thing. Not clear if this is needed.
uint32_t min_pipeline = (uint32_t)log2f(float(op->ne[0])) + 1;
if (min_pipeline >= num_topk_pipelines ||
!device->pipeline_topk_f32[min_pipeline]) {
return false;
}
}
return true;
case GGML_OP_UPSCALE: case GGML_OP_UPSCALE:
case GGML_OP_ACC: case GGML_OP_ACC:
case GGML_OP_CONCAT: case GGML_OP_CONCAT:
@ -13786,6 +13934,15 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_OP_SUM_ROWS: case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN: case GGML_OP_MEAN:
return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]); return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]);
case GGML_OP_CUMSUM:
{
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
auto device = ggml_vk_get_device(ctx->device);
if (device->subgroup_arithmetic && device->subgroup_require_full_support) {
return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]);
}
return false;
}
case GGML_OP_ARGMAX: case GGML_OP_ARGMAX:
case GGML_OP_COUNT_EQUAL: case GGML_OP_COUNT_EQUAL:
case GGML_OP_IM2COL: case GGML_OP_IM2COL:
@ -14432,10 +14589,14 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
tensor_clone = ggml_get_rows(ggml_ctx, src_clone[0], src_clone[1]); tensor_clone = ggml_get_rows(ggml_ctx, src_clone[0], src_clone[1]);
} else if (tensor->op == GGML_OP_ARGSORT) { } else if (tensor->op == GGML_OP_ARGSORT) {
tensor_clone = ggml_argsort(ggml_ctx, src_clone[0], (ggml_sort_order) *(int *)tensor->op_params); tensor_clone = ggml_argsort(ggml_ctx, src_clone[0], (ggml_sort_order) *(int *)tensor->op_params);
} else if (tensor->op == GGML_OP_TOP_K) {
tensor_clone = ggml_top_k(ggml_ctx, src_clone[0], tensor->ne[0]);
} else if (tensor->op == GGML_OP_SUM) { } else if (tensor->op == GGML_OP_SUM) {
tensor_clone = ggml_sum(ggml_ctx, src_clone[0]); tensor_clone = ggml_sum(ggml_ctx, src_clone[0]);
} else if (tensor->op == GGML_OP_SUM_ROWS) { } else if (tensor->op == GGML_OP_SUM_ROWS) {
tensor_clone = ggml_sum_rows(ggml_ctx, src_clone[0]); tensor_clone = ggml_sum_rows(ggml_ctx, src_clone[0]);
} else if (tensor->op == GGML_OP_CUMSUM) {
tensor_clone = ggml_cumsum(ggml_ctx, src_clone[0]);
} else if (tensor->op == GGML_OP_MEAN) { } else if (tensor->op == GGML_OP_MEAN) {
tensor_clone = ggml_mean(ggml_ctx, src_clone[0]); tensor_clone = ggml_mean(ggml_ctx, src_clone[0]);
} else if (tensor->op == GGML_OP_ARGMAX) { } else if (tensor->op == GGML_OP_ARGMAX) {

View File

@ -0,0 +1,69 @@
#version 450
#include "types.glsl"
#include "sum_rows.glsl"
#extension GL_EXT_control_flow_attributes : enable
#extension GL_KHR_shader_subgroup_arithmetic : enable
#extension GL_KHR_shader_subgroup_basic : enable
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
layout (constant_id = 0) const uint BLOCK_SIZE = 128;
layout (constant_id = 1) const uint SUBGROUP_SIZE = 32;
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
shared FLOAT_TYPE partial[BLOCK_SIZE / SUBGROUP_SIZE];
shared FLOAT_TYPE last_sum;
void main() {
const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
const uint tid = gl_LocalInvocationID.x;
const uint i03 = fastdiv(row, p.ne0_12mp, p.ne0_12L);
const uint i03_offset = i03 * p.ne01*p.ne02;
const uint i02 = fastdiv(row - i03_offset, p.ne0_1mp, p.ne0_1L);
const uint i01 = row - i03_offset - i02*p.ne01;
const uint src_idx = get_aoffset() + i01 * p.nb01 + i02 * p.nb02 + i03 * p.nb03;
const uint dst_idx = get_doffset() + i01 * p.nb11 + i02 * p.nb12 + i03 * p.nb13;
uint subgroup_id = tid / SUBGROUP_SIZE;
if (tid == 0) {
last_sum = 0;
}
uint col = tid;
uint num_iter = CEIL_DIV(p.n_cols, BLOCK_SIZE);
for (int i = 0; i < num_iter; ++i) {
FLOAT_TYPE v = 0;
if (col < p.n_cols) {
v = FLOAT_TYPE(data_a[src_idx + col]);
}
v = subgroupInclusiveAdd(v);
// Store each subgroup's total partial sum, then add the totals of all
// lower subgroups and the carried sum from the previous iteration.
if (gl_SubgroupInvocationID == SUBGROUP_SIZE - 1) {
partial[subgroup_id] = v;
}
barrier();
for (int j = 0; j < subgroup_id; ++j) {
v += partial[j];
}
v += last_sum;
barrier();
if (tid == BLOCK_SIZE - 1) {
last_sum = v;
}
if (col < p.n_cols) {
data_d[dst_idx + col] = D_TYPE(v);
}
col += BLOCK_SIZE;
}
}
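The kernel is a blocked inclusive scan: each BLOCK_SIZE chunk is scanned with subgroup prefix sums, and last_sum carries the running row total into the next chunk. A scalar reference of the same blocked-carry scheme (illustration only, not the dispatch code):

#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
    const size_t BLOCK_SIZE = 4; // tiny block for illustration
    std::vector<float> row = {1, 2, 3, 4, 5, 6}, out(row.size());
    float last_sum = 0.0f; // carried across chunks, like the shared variable
    for (size_t base = 0; base < row.size(); base += BLOCK_SIZE) {
        const size_t end = std::min(base + BLOCK_SIZE, row.size());
        float acc = 0.0f;
        for (size_t i = base; i < end; ++i) {
            acc += row[i];           // inclusive scan within the chunk
            out[i] = acc + last_sum; // add the carry from previous chunks
        }
        last_sum = out[end - 1];
    }
    for (float v : out) std::printf("%g ", v); // 1 3 6 10 15 21
    return 0;
}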

View File

@ -1,6 +1,7 @@
#version 450 #version 450
#include "types.glsl" #include "types.glsl"
#include "sum_rows.glsl"
#extension GL_EXT_control_flow_attributes : enable #extension GL_EXT_control_flow_attributes : enable
@ -11,30 +12,6 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
layout (constant_id = 0) const uint BLOCK_SIZE = 32; layout (constant_id = 0) const uint BLOCK_SIZE = 32;
layout (push_constant) uniform parameter
{
uint n_cols;
uint ne01, ne02;
uint nb01, nb02, nb03;
uint nb11, nb12, nb13;
float weight;
uint misalign_offsets;
uint ne0_12mp, ne0_12L;
uint ne0_1mp, ne0_1L;
} p;
uint get_aoffset() { return p.misalign_offsets >> 16; }
uint get_doffset() { return p.misalign_offsets & 0xFFFF; }
// see init_fastdiv_values in ggml-vulkan.cpp
uint fastdiv(uint n, uint mp, uint L) {
uint msbs, lsbs;
// msbs = mulhi(n, mp)
umulExtended(n, mp, msbs, lsbs);
return (msbs + n) >> L;
}
shared FLOAT_TYPE tmp[BLOCK_SIZE]; shared FLOAT_TYPE tmp[BLOCK_SIZE];
void main() { void main() {

View File

@ -0,0 +1,25 @@
// vk_op_sum_rows_push_constants
layout (push_constant) uniform parameter
{
uint n_cols;
uint ne01, ne02;
uint nb01, nb02, nb03;
uint nb11, nb12, nb13;
float weight;
uint misalign_offsets;
uint ne0_12mp, ne0_12L;
uint ne0_1mp, ne0_1L;
} p;
uint get_aoffset() { return p.misalign_offsets >> 16; }
uint get_doffset() { return p.misalign_offsets & 0xFFFF; }
// see init_fastdiv_values in ggml-vulkan.cpp
uint fastdiv(uint n, uint mp, uint L) {
uint msbs, lsbs;
// msbs = mulhi(n, mp)
umulExtended(n, mp, msbs, lsbs);
return (msbs + n) >> L;
}
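fastdiv divides by a fixed divisor using a precomputed magic pair (mp, L). A hedged host-side sketch of the setup, assuming the (mulhi(n, mp) + n) >> L round-up variant that the shader evaluates; the authoritative constants come from init_fastdiv_values in ggml-vulkan.cpp:

#include <cassert>
#include <cstdint>

static void init_fastdiv(uint32_t d, uint32_t & mp, uint32_t & L) {
    // L = ceil(log2(d)); mp makes 2^32 + mp act as ceil(2^(32+L) / d)
    L = 0;
    while (L < 32 && (uint32_t(1) << L) < d) { L++; }
    mp = uint32_t((uint64_t(1) << 32) * ((uint64_t(1) << L) - d) / d + 1);
}

static uint32_t fastdiv(uint32_t n, uint32_t mp, uint32_t L) {
    const uint32_t msbs = uint32_t((uint64_t(n) * mp) >> 32); // mulhi(n, mp)
    return (msbs + n) >> L;
}

int main() {
    uint32_t mp, L;
    init_fastdiv(12, mp, L);
    for (uint32_t n = 0; n < 1000000; ++n) {
        assert(fastdiv(n, mp, L) == n/12);
    }
    return 0;
}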

View File

@ -0,0 +1,113 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable
#include "types.glsl"
layout(constant_id = 0) const int BLOCK_SIZE = 1024;
layout(constant_id = 1) const int NCOLS_PADDED_LOG2 = 10;
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
// Input can either be the source (A) or intermediate values (S).
// Similarly, output can be either destination (D) or intermediate values (T).
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 0) readonly buffer S {ivec2 data_s[];};
layout (binding = 1) writeonly buffer D {int data_d[];};
layout (binding = 1) writeonly buffer T {ivec2 data_t[];};
layout (push_constant) uniform parameter {
uint orig_ncols;
uint ncols_input;
uint ncols_output;
uint nrows;
uint first_pass;
uint last_pass;
} p;
// pairs of (gid, value)
shared ivec2 dst_row[BLOCK_SIZE];
void topk(bool needs_bounds_check, const uint row) {
const int col = int(gl_LocalInvocationID.x);
// initialize indices
if (gl_GlobalInvocationID.x < p.ncols_input) {
if (p.first_pass != 0) {
const uint row_offset = row * p.ncols_input;
dst_row[col] = ivec2(gl_GlobalInvocationID.x, floatBitsToInt(data_a[row_offset + gl_GlobalInvocationID.x]));
} else {
const uint row_offset = row * p.orig_ncols;
dst_row[col] = data_s[row_offset + gl_GlobalInvocationID.x];
}
} else {
dst_row[col] = ivec2(p.orig_ncols, 0);
}
barrier();
if (p.ncols_output == 1) {
// Fast path for single output - just do a max reduction
[[unroll]] for (int s = BLOCK_SIZE / 2; s >= 1; s /= 2) {
if (col < s) {
ivec2 a = dst_row[col];
ivec2 b = dst_row[col + s];
if (a.x >= p.orig_ncols ||
b.x < p.orig_ncols && b.y > a.y) {
dst_row[col] = b;
}
}
barrier();
}
} else {
// bitonic sort on this group of elements
uint num_outer_loop_iters = NCOLS_PADDED_LOG2;
for (uint k = 2, outer_idx = 0; outer_idx < num_outer_loop_iters; k *= 2, outer_idx++) {
uint num_inner_loop_iters = outer_idx + 1;
for (uint j = k / 2, inner_idx = 0; inner_idx < num_inner_loop_iters; j /= 2, inner_idx++) {
const int ixj = int(col ^ j);
int idx_0 = (col & k) == 0 ? col : ixj;
int idx_1 = (col & k) == 0 ? ixj : col;
ivec2 sh_idx_0 = dst_row[idx_0];
ivec2 sh_idx_1 = dst_row[idx_1];
bool idx_0_oob = needs_bounds_check ? sh_idx_0.x >= p.orig_ncols : false;
bool idx_1_oob = needs_bounds_check ? sh_idx_1.x >= p.orig_ncols : false;
if ((idx_0_oob ||
(!idx_1_oob && intBitsToFloat(sh_idx_0.y) < intBitsToFloat(sh_idx_1.y))) && (ixj > col)) {
dst_row[idx_0] = sh_idx_1;
dst_row[idx_1] = sh_idx_0;
}
barrier();
}
}
}
if (col < p.ncols_output && gl_GlobalInvocationID.x < p.orig_ncols) {
if (p.last_pass != 0) {
const uint row_offset = row * p.ncols_output;
data_d[row_offset + col] = dst_row[col].x;
} else {
const uint row_offset = row * p.orig_ncols + gl_WorkGroupID.x * p.ncols_output;
data_t[row_offset + col] = dst_row[col];
}
}
}
void main() {
// Fast path for fully occupied workgroups
if ((p.ncols_input % BLOCK_SIZE) == 0) {
uint row = gl_WorkGroupID.y;
while (row < p.nrows) {
topk(false, row);
row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
}
} else {
uint row = gl_WorkGroupID.y;
while (row < p.nrows) {
topk(true, row);
row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
}
}
}
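What runs above is the textbook bitonic network over a power-of-two padded block of (index, value) pairs, sorted descending so the first ncols_output entries are the block's largest. A single-threaded sketch of the same k/j loop structure (values made up; the shader additionally carries source indices and bounds-checks the padding):

#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
    std::vector<float> v = {0.3f, 2.0f, -1.0f, 5.0f, 4.0f, 0.0f, 7.0f, 1.5f};
    const size_t n = v.size(); // must be a power of two
    for (size_t k = 2; k <= n; k *= 2) {
        for (size_t j = k/2; j > 0; j /= 2) {
            for (size_t i = 0; i < n; ++i) {
                const size_t ixj = i ^ j;
                if (ixj > i) {
                    // blocks with (i & k) == 0 sort descending, others ascending
                    const bool desc = (i & k) == 0;
                    if (desc ? v[i] < v[ixj] : v[i] > v[ixj]) {
                        std::swap(v[i], v[ixj]);
                    }
                }
            }
        }
    }
    for (float x : v) std::printf("%g ", x); // 7 5 4 2 1.5 0.3 0 -1
    return 0;
}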

View File

@ -0,0 +1,199 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_debug_printf : enable
#extension GL_KHR_shader_subgroup_basic : enable
#extension GL_KHR_shader_subgroup_ballot : enable
#extension GL_KHR_shader_subgroup_arithmetic : enable
#extension GL_KHR_shader_subgroup_shuffle : enable
#include "types.glsl"
layout(constant_id = 0) const int BLOCK_SIZE = 1024;
layout(constant_id = 1) const int SUBGROUP_SIZE = 32;
layout(constant_id = 2) const int SUBGROUP_SIZE_LOG2 = 5;
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
// Input can either be the source (A) or intermediate values (S).
// Similarly, output can be either destination (D) or intermediate values (T).
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 0) readonly buffer S {ivec2 data_s[];};
layout (binding = 1) writeonly buffer D {int data_d[];};
layout (binding = 1) writeonly buffer T {ivec2 data_t[];};
layout (push_constant) uniform parameter {
uint orig_ncols;
uint ncols_input;
uint ncols_output;
uint nrows;
uint first_pass;
uint last_pass;
} p;
// pairs of (gid, value)
shared ivec2 dst_row[BLOCK_SIZE];
shared int counts[SUBGROUP_SIZE];
shared int sh_min_idx;
shared uint sh_total;
shared uint offset_partials[BLOCK_SIZE / SUBGROUP_SIZE];
// Map float values to uint such that comparisons still work.
// Positive values set the high bit, negative values are inverted.
// +0.0 -> 0x80000000, -0.0 -> 0x7FFFFFFF are in the correct places.
uint f2ui(float x) {
uint y = floatBitsToUint(x);
if ((y & 0x80000000) != 0) {
y ^= ~0;
} else {
y |= 0x80000000;
}
return y;
}
void topk(const uint row) {
const int tid = int(gl_LocalInvocationID.x);
// initialize indices
if (gl_GlobalInvocationID.x < p.ncols_input) {
if (p.first_pass != 0) {
const uint row_offset = row * p.ncols_input;
dst_row[tid] = ivec2(gl_GlobalInvocationID.x, floatBitsToInt(data_a[row_offset + gl_GlobalInvocationID.x]));
} else {
const uint row_offset = row * p.orig_ncols;
dst_row[tid] = data_s[row_offset + gl_GlobalInvocationID.x];
}
} else {
dst_row[tid] = ivec2(p.orig_ncols, 0xFF800000); // -inf
}
barrier();
if (p.ncols_output == 1) {
// Fast path for single output - just do a max reduction
[[unroll]] for (int s = BLOCK_SIZE / 2; s >= 1; s /= 2) {
if (tid < s) {
ivec2 a = dst_row[tid];
ivec2 b = dst_row[tid + s];
if (a.x >= p.orig_ncols ||
b.x < p.orig_ncols && b.y > a.y) {
dst_row[tid] = b;
}
}
barrier();
}
} else {
// Do an N-ary search to find the K-th largest value.
// We remap the float values to be comparable as unsigned integers,
// and split the range into N smaller ranges, where N is the
// subgroup size. Count how many values fall in each range; if the K-th
// largest value lands in the middle of one of these ranges, then repeat
// and split again.
// Mask is the current set of bits we're searching. Shift is the LSB index.
int shift = 32 - SUBGROUP_SIZE_LOG2;
uint mask = ((1 << SUBGROUP_SIZE_LOG2) - 1) << shift;
// The current range.
uint range_min = 0;
uint range_max = 0xFF800000;
// How many are above the current range, and how many we need to find.
uint total = 0;
uint limit = min(p.ncols_output, p.ncols_input - gl_WorkGroupID.x * BLOCK_SIZE);
while (mask != 0) {
barrier();
// Initialize bucket counts to zero.
if (tid < SUBGROUP_SIZE) {
counts[tid] = 0;
}
barrier();
// Count how many values are in each bucket.
if (tid < p.ncols_input) {
float y = intBitsToFloat(dst_row[tid].y);
uint fy = f2ui(y);
if (fy >= range_min && fy < range_max) {
uint bucket = (fy & mask) >> shift;
atomicAdd(counts[bucket], 1);
}
}
barrier();
// On the first subgroup, do a scan to count (from the top down) how
// many elements are in the top N buckets. Find the index of the first
// bucket that is over the limit. Copy it to the other invocations through
// shared memory.
if (tid < SUBGROUP_SIZE) {
uint partial_sum = counts[SUBGROUP_SIZE - 1 - tid];
partial_sum = subgroupInclusiveAdd(partial_sum) + total;
uint t = subgroupBallotFindLSB(subgroupBallot(partial_sum >= limit));
if (tid == t) {
sh_min_idx = int(SUBGROUP_SIZE - 1 - t);
sh_total = partial_sum;
}
}
barrier();
int min_idx = sh_min_idx;
total = sh_total;
// Update the range, and break if we've found the K-th largest.
range_max = range_min + ((min_idx + 1) << shift);
range_min = range_min + (min_idx << shift);
if (total == p.ncols_output) {
break;
}
total -= counts[min_idx];
mask >>= SUBGROUP_SIZE_LOG2;
shift -= SUBGROUP_SIZE_LOG2;
if (shift < 0) {
shift = 0;
}
}
ivec2 v = dst_row[tid];
// We need to compact these values to the start of the dst_row array.
// Have each subgroup count how many items it'll store, so other
// subgroups can compute their base offset.
bool top = f2ui(intBitsToFloat(v.y)) >= range_min;
uvec4 b = subgroupBallot(top);
uint bit_count = subgroupBallotBitCount(b);
if ((tid % SUBGROUP_SIZE) == 0) {
offset_partials[tid / SUBGROUP_SIZE] = bit_count;
}
barrier();
uint out_idx = 0;
[[unroll]] for (int i = 0; i < BLOCK_SIZE / SUBGROUP_SIZE; ++i) {
if (i < tid / SUBGROUP_SIZE) {
out_idx += offset_partials[i];
}
}
uint bit_count_ex = subgroupBallotExclusiveBitCount(b);
if (top) {
// TODO: Copy directly to the output?
dst_row[out_idx + bit_count_ex] = v;
}
barrier();
}
if (tid < p.ncols_output && gl_GlobalInvocationID.x < p.orig_ncols) {
if (p.last_pass != 0) {
const uint row_offset = row * p.ncols_output;
data_d[row_offset + tid] = dst_row[tid].x;
} else {
const uint row_offset = row * p.orig_ncols + gl_WorkGroupID.x * p.ncols_output;
data_t[row_offset + tid] = dst_row[tid];
}
}
}
void main() {
uint row = gl_WorkGroupID.y;
while (row < p.nrows) {
topk(row);
row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
}
}
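A scalar sketch of the shader's selection idea: remap floats to uints with the same ordering (f2ui), then narrow down the K-th largest by bucketed counting, a few bits per round. Simplified relative to the shader (single thread, assumes distinct values, fixed 16 buckets instead of one per subgroup lane):

#include <cstdint>
#include <cstdio>
#include <cstring>
#include <vector>

static uint32_t f2ui(float x) {
    uint32_t y;
    std::memcpy(&y, &x, sizeof(y));
    return (y & 0x80000000u) ? ~y : (y | 0x80000000u);
}

int main() {
    const std::vector<float> data = {0.5f, -2.0f, 3.0f, 1.0f, -0.25f, 7.0f, 0.0f, 2.5f};
    const uint32_t K = 3;       // keep the top 3 (expect 7, 3, 2.5)
    const int LOG2B = 4;        // 16 buckets per round
    int shift = 32 - LOG2B;
    uint32_t mask = ((1u << LOG2B) - 1u) << shift;
    uint32_t range_min = 0, range_max = 0xFFFFFFFFu;
    uint32_t above = 0;         // values known to lie above the current range
    while (mask != 0) {
        uint32_t counts[1u << LOG2B] = {};
        for (float x : data) {
            const uint32_t u = f2ui(x);
            if (u >= range_min && u <= range_max) {
                counts[(u & mask) >> shift]++;
            }
        }
        // walk buckets from the top until at least K values are covered
        uint32_t total = above;
        int b = (1 << LOG2B) - 1;
        for (; b >= 0; --b) {
            total += counts[b];
            if (total >= K) break;
        }
        range_max = range_min + ((uint32_t(b) + 1u) << shift) - 1u;
        range_min = range_min + (uint32_t(b) << shift);
        if (total == K) break;  // exactly K values are >= range_min
        above = total - counts[b];
        mask >>= LOG2B;
        shift = shift < LOG2B ? 0 : shift - LOG2B;
    }
    for (float x : data) {
        if (f2ui(x) >= range_min) std::printf("%g ", x); // 3 7 2.5
    }
    return 0;
}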

View File

@ -913,9 +913,13 @@ void process_shaders() {
string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}}); string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}});
string_to_spv("argsort_large_f32", "argsort_large.comp", {{"A_TYPE", "float"}}); string_to_spv("argsort_large_f32", "argsort_large.comp", {{"A_TYPE", "float"}});
string_to_spv("topk_argsort_f32", "topk_argsort.comp", {{"A_TYPE", "float"}});
string_to_spv("topk_nary_search_f32", "topk_nary_search.comp", {{"A_TYPE", "float"}});
string_to_spv("argmax_f32", "argmax.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "int"}})); string_to_spv("argmax_f32", "argmax.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "int"}}));
string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}})); string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}}));
string_to_spv("cumsum_f32", "cumsum.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
for (std::string dim_str : {"", "_3d"}) { for (std::string dim_str : {"", "_3d"}) {
for (bool bda : {false, true}) { for (bool bda : {false, true}) {

View File

@ -990,6 +990,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"ARANGE", "ARANGE",
"TIMESTEP_EMBEDDING", "TIMESTEP_EMBEDDING",
"ARGSORT", "ARGSORT",
"TOP_K",
"LEAKY_RELU", "LEAKY_RELU",
"TRI", "TRI",
"FILL", "FILL",
@ -1023,7 +1024,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"GLU", "GLU",
}; };
static_assert(GGML_OP_COUNT == 94, "GGML_OP_COUNT != 94"); static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none", "none",
@ -1098,6 +1099,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"arange(start, stop, step)", "arange(start, stop, step)",
"timestep_embedding(timesteps, dim, max_period)", "timestep_embedding(timesteps, dim, max_period)",
"argsort(x)", "argsort(x)",
"top_k(x)",
"leaky_relu(x)", "leaky_relu(x)",
"tri(x)", "tri(x)",
"fill(x, c)", "fill(x, c)",
@ -1131,7 +1133,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"glu(x)", "glu(x)",
}; };
static_assert(GGML_OP_COUNT == 94, "GGML_OP_COUNT != 94"); static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95");
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@ -5036,28 +5038,6 @@ struct ggml_tensor * ggml_roll(
return result; return result;
} }
// ggml_arange
struct ggml_tensor * ggml_arange(
struct ggml_context * ctx,
float start,
float stop,
float step) {
GGML_ASSERT(stop > start);
const int64_t steps = (int64_t) ceilf((stop - start) / step);
struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, steps);
ggml_set_op_params_f32(result, 0, start);
ggml_set_op_params_f32(result, 1, stop);
ggml_set_op_params_f32(result, 2, step);
result->op = GGML_OP_ARANGE;
return result;
}
// ggml_timestep_embedding // ggml_timestep_embedding
struct ggml_tensor * ggml_timestep_embedding( struct ggml_tensor * ggml_timestep_embedding(
@ -5139,6 +5119,7 @@ struct ggml_tensor * ggml_argsort(
struct ggml_tensor * a, struct ggml_tensor * a,
enum ggml_sort_order order) { enum ggml_sort_order order) {
GGML_ASSERT(a->ne[0] <= INT32_MAX); GGML_ASSERT(a->ne[0] <= INT32_MAX);
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_I32, GGML_MAX_DIMS, a->ne); struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_I32, GGML_MAX_DIMS, a->ne);
ggml_set_op_params_i32(result, 0, (int32_t) order); ggml_set_op_params_i32(result, 0, (int32_t) order);
@ -5149,6 +5130,24 @@ struct ggml_tensor * ggml_argsort(
return result; return result;
} }
// ggml_argsort_top_k
struct ggml_tensor * ggml_argsort_top_k(
struct ggml_context * ctx,
struct ggml_tensor * a,
int k) {
GGML_ASSERT(a->ne[0] >= k);
struct ggml_tensor * result = ggml_argsort(ctx, a, GGML_SORT_ORDER_DESC);
result = ggml_view_4d(ctx, result,
k, result->ne[1], result->ne[2], result->ne[3],
result->nb[1], result->nb[2], result->nb[3],
0);
return result;
}
// ggml_top_k // ggml_top_k
struct ggml_tensor * ggml_top_k( struct ggml_tensor * ggml_top_k(
@ -5157,12 +5156,32 @@ struct ggml_tensor * ggml_top_k(
int k) { int k) {
GGML_ASSERT(a->ne[0] >= k); GGML_ASSERT(a->ne[0] >= k);
struct ggml_tensor * result = ggml_argsort(ctx, a, GGML_SORT_ORDER_DESC); struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_I32, k, a->ne[1], a->ne[2], a->ne[3]);
result = ggml_view_4d(ctx, result, result->op = GGML_OP_TOP_K;
k, result->ne[1], result->ne[2], result->ne[3], result->src[0] = a;
result->nb[1], result->nb[2], result->nb[3],
0); return result;
}
// ggml_arange
struct ggml_tensor * ggml_arange(
struct ggml_context * ctx,
float start,
float stop,
float step) {
GGML_ASSERT(stop > start);
const int64_t steps = (int64_t) ceilf((stop - start) / step);
struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, steps);
ggml_set_op_params_f32(result, 0, start);
ggml_set_op_params_f32(result, 1, stop);
ggml_set_op_params_f32(result, 2, step);
result->op = GGML_OP_ARANGE;
return result; return result;
} }
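With this change ggml_top_k becomes a standalone op (I32 output of shape [k, ne1, ne2, ne3]) instead of an argsort view, while ggml_argsort_top_k keeps the old behavior for callers that rely on fully sorted indices, as build_moe_ffn below now does. A hedged usage sketch (ctx and logits assumed to exist):

#include "ggml.h"

static struct ggml_tensor * select_top8(struct ggml_context * ctx,
                                        struct ggml_tensor  * logits) {
    // dedicated op: indices of the 8 largest values per row; the new tests
    // compare the result as a set, so no particular per-row order is assumed
    struct ggml_tensor * top = ggml_top_k(ctx, logits, 8);
    (void) top;
    // old behavior: full descending argsort, then an 8-wide view (sorted)
    return ggml_argsort_top_k(ctx, logits, 8);
}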

View File

@ -4,6 +4,7 @@ import logging
import os import os
import shutil import shutil
import struct import struct
import sys
import tempfile import tempfile
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum, auto from enum import Enum, auto
@ -372,8 +373,10 @@ class GGUFWriter:
self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None, self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None,
raw_dtype: GGMLQuantizationType | None = None, raw_dtype: GGMLQuantizationType | None = None,
) -> None: ) -> None:
if self.endianess == GGUFEndian.BIG: if (self.endianess == GGUFEndian.BIG and sys.byteorder != 'big') or \
tensor.byteswap(inplace=True) (self.endianess == GGUFEndian.LITTLE and sys.byteorder != 'little'):
# Don't byteswap inplace since lazy copies cannot handle it
tensor = tensor.byteswap(inplace=False)
if self.use_temp_file and self.temp_file is None: if self.use_temp_file and self.temp_file is None:
fp = tempfile.SpooledTemporaryFile(mode="w+b", max_size=256 * 1024 * 1024) fp = tempfile.SpooledTemporaryFile(mode="w+b", max_size=256 * 1024 * 1024)
fp.seek(0) fp.seek(0)
@ -399,8 +402,10 @@ class GGUFWriter:
raise ValueError(f'Expected output file to contain tensor info or weights, got {self.state}') raise ValueError(f'Expected output file to contain tensor info or weights, got {self.state}')
assert self.fout is not None assert self.fout is not None
if self.endianess == GGUFEndian.BIG: if (self.endianess == GGUFEndian.BIG and sys.byteorder != 'big') or \
tensor.byteswap(inplace=True) (self.endianess == GGUFEndian.LITTLE and sys.byteorder != 'little'):
# Don't byteswap inplace since lazy copies cannot handle it
tensor = tensor.byteswap(inplace=False)
file_id = -1 file_id = -1
for i, tensors in enumerate(self.tensors): for i, tensors in enumerate(self.tensors):

View File

@ -16,7 +16,7 @@ vendor = {
# "https://github.com/mackron/miniaudio/raw/refs/tags/0.11.23/miniaudio.h": "vendor/miniaudio/miniaudio.h", # "https://github.com/mackron/miniaudio/raw/refs/tags/0.11.23/miniaudio.h": "vendor/miniaudio/miniaudio.h",
"https://github.com/mackron/miniaudio/raw/669ed3e844524fcd883231b13095baee9f6de304/miniaudio.h": "vendor/miniaudio/miniaudio.h", "https://github.com/mackron/miniaudio/raw/669ed3e844524fcd883231b13095baee9f6de304/miniaudio.h": "vendor/miniaudio/miniaudio.h",
"https://raw.githubusercontent.com/yhirose/cpp-httplib/refs/tags/v0.27.0/httplib.h": "vendor/cpp-httplib/httplib.h", "https://raw.githubusercontent.com/yhirose/cpp-httplib/refs/tags/v0.28.0/httplib.h": "vendor/cpp-httplib/httplib.h",
} }
for url, filename in vendor.items(): for url, filename in vendor.items():

View File

@ -961,14 +961,14 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
// organize experts into n_expert_groups // organize experts into n_expert_groups
ggml_tensor * selection_groups = ggml_reshape_3d(ctx0, selection_probs, n_exp_per_group, hparams.n_expert_groups, n_tokens); // [n_exp_per_group, n_expert_groups, n_tokens] ggml_tensor * selection_groups = ggml_reshape_3d(ctx0, selection_probs, n_exp_per_group, hparams.n_expert_groups, n_tokens); // [n_exp_per_group, n_expert_groups, n_tokens]
ggml_tensor * group_scores = ggml_top_k(ctx0, selection_groups, 2); // [2, n_expert_groups, n_tokens] ggml_tensor * group_scores = ggml_argsort_top_k(ctx0, selection_groups, 2); // [2, n_expert_groups, n_tokens]
group_scores = ggml_get_rows(ctx0, ggml_reshape_4d(ctx0, selection_groups, 1, selection_groups->ne[0], selection_groups->ne[1], selection_groups->ne[2]), group_scores); // [1, 2, n_expert_groups, n_tokens] group_scores = ggml_get_rows(ctx0, ggml_reshape_4d(ctx0, selection_groups, 1, selection_groups->ne[0], selection_groups->ne[1], selection_groups->ne[2]), group_scores); // [1, 2, n_expert_groups, n_tokens]
// get top n_group_used expert groups // get top n_group_used expert groups
group_scores = ggml_sum_rows(ctx0, ggml_reshape_3d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2], group_scores->ne[3])); // [1, n_expert_groups, n_tokens] group_scores = ggml_sum_rows(ctx0, ggml_reshape_3d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2], group_scores->ne[3])); // [1, n_expert_groups, n_tokens]
group_scores = ggml_reshape_2d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2]); // [n_expert_groups, n_tokens] group_scores = ggml_reshape_2d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2]); // [n_expert_groups, n_tokens]
ggml_tensor * expert_groups = ggml_top_k(ctx0, group_scores, hparams.n_group_used); // [n_group_used, n_tokens] ggml_tensor * expert_groups = ggml_argsort_top_k(ctx0, group_scores, hparams.n_group_used); // [n_group_used, n_tokens]
cb(expert_groups, "ffn_moe_group_topk", il); cb(expert_groups, "ffn_moe_group_topk", il);
// mask out the other groups // mask out the other groups
@ -979,7 +979,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
} }
// select experts // select experts
ggml_tensor * selected_experts = ggml_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens] ggml_tensor * selected_experts = ggml_argsort_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
cb(selected_experts->src[0], "ffn_moe_argsort", il); cb(selected_experts->src[0], "ffn_moe_argsort", il);
cb(selected_experts, "ffn_moe_topk", il); cb(selected_experts, "ffn_moe_topk", il);

View File

@ -39,6 +39,7 @@
#include <string_view> #include <string_view>
#include <thread> #include <thread>
#include <vector> #include <vector>
#include <unordered_map>
static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) { static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) {
size_t nels = ggml_nelements(tensor); size_t nels = ggml_nelements(tensor);
@ -269,6 +270,34 @@ static double nmse(const float * a, const float * b, size_t n) {
return mse_a_b / mse_a_0; return mse_a_b / mse_a_0;
} }
// difference between 2 integer sets (Jaccard distance, 0 - no difference, 1 - no overlap)
static double jdst(const int32_t * a, const int32_t * b, size_t n) {
std::unordered_map<int32_t, size_t> set_a;
std::unordered_map<int32_t, size_t> set_b;
for (size_t i = 0; i < n; ++i) {
set_a[a[i]]++;
set_b[b[i]]++;
}
size_t diff = 0;
for (const auto & p : set_a) {
const int64_t na = p.second;
const int64_t nb = set_b.find(p.first) != set_b.end() ? set_b.at(p.first) : 0;
diff += std::abs(na - nb);
}
for (const auto & p : set_b) {
if (set_a.find(p.first) == set_a.end()) {
diff += p.second;
}
}
return (double) diff / (2*n);
}
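For intuition, a worked evaluation of jdst on made-up values:

// a = {1, 2, 2, 5}, b = {2, 3, 5, 5}
// set_a = {1:1, 2:2, 5:1}, set_b = {2:1, 3:1, 5:2}
// diff  = |1-0| + |2-1| + |1-2|   (keys from set_a)
//       + 1                      (key 3, only in set_b)
//       = 4
// jdst  = 4 / (2*4) = 0.5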
// maximum absolute asymmetry between a and b // maximum absolute asymmetry between a and b
// asymmetry: (a - b) / (a + b) // asymmetry: (a - b) / (a + b)
// This is more stable than relative error if one of the values fluctuates towards zero. // This is more stable than relative error if one of the values fluctuates towards zero.
@ -1051,6 +1080,14 @@ struct test_case {
return 1e-4; return 1e-4;
} }
virtual double max_err() {
return max_nmse_err();
}
virtual double err(const float * a, const float * b, size_t n) {
return nmse(a, b, n);
}
virtual float grad_eps() { virtual float grad_eps() {
return 1e-1f; return 1e-1f;
} }
@ -1257,16 +1294,16 @@ struct test_case {
// compare // compare
struct callback_userdata { struct callback_userdata {
bool ok; bool ok;
double max_err; test_case * tc;
ggml_backend_t backend1; ggml_backend_t backend1;
ggml_backend_t backend2; ggml_backend_t backend2;
}; };
callback_userdata ud { callback_userdata ud {
true, true,
max_nmse_err(), this,
backend1, backend1,
backend2 backend2,
}; };
auto callback = [](int index, ggml_tensor * t1, ggml_tensor * t2, void * user_data) -> bool { auto callback = [](int index, ggml_tensor * t1, ggml_tensor * t2, void * user_data) -> bool {
@ -1314,9 +1351,9 @@ struct test_case {
} }
} }
double err = nmse(f1.data(), f2.data(), f1.size()); double err = ud->tc->err(f1.data(), f2.data(), f1.size());
if (err > ud->max_err) { if (err > ud->tc->max_err()) {
printf("[%s] NMSE = %.9f > %.9f ", ggml_op_desc(t1), err, ud->max_err); printf("[%s] ERR = %.9f > %.9f ", ggml_op_desc(t1), err, ud->tc->max_err());
//for (int i = 0; i < (int) f1.size(); i++) { //for (int i = 0; i < (int) f1.size(); i++) {
// printf("%5d %9.6f %9.6f, diff = %9.6f\n", i, f1[i], f2[i], f1[i] - f2[i]); // printf("%5d %9.6f %9.6f, diff = %9.6f\n", i, f1[i], f2[i], f1[i] - f2[i]);
//} //}
@ -4943,7 +4980,71 @@ struct test_argsort : public test_case {
} }
}; };
struct test_topk_moe: public test_case { // GGML_OP_TOP_K
struct test_top_k : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne;
const int k;
std::string vars() override {
return VARS_TO_STR3(type, ne, k);
}
test_top_k(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {16, 10, 10, 10},
int k = 4)
: type(type), ne(ne), k(k) {}
double max_err() override {
return 0.0;
}
double err(const float * a, const float * b, size_t n) override {
std::vector<int32_t> ia(n);
std::vector<int32_t> ib(n);
double diff = 0.0f;
for (size_t i = 0; i < n; i++) {
ia[i] = (int32_t) a[i];
ib[i] = (int32_t) b[i];
// penalize the result if the data is not integer valued
diff += std::fabs(a[i] - ia[i]);
diff += std::fabs(b[i] - ib[i]);
}
return diff + jdst(ia.data(), ib.data(), n);
}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_name(a, "a");
ggml_tensor * out = ggml_top_k(ctx, a, k);
ggml_set_name(out, "out");
return out;
}
void initialize_tensors(ggml_context * ctx) override {
std::random_device rd;
std::default_random_engine rng(rd());
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
// initialize with unique values to avoid ties
for (int64_t r = 0; r < ggml_nrows(t); r++) {
std::vector<float> data(t->ne[0]);
for (int i = 0; i < t->ne[0]; i++) {
data[i] = i;
}
std::shuffle(data.begin(), data.end(), rng);
ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(float));
}
}
}
};
struct test_topk_moe : public test_case {
const std::array<int64_t, 4> ne; const std::array<int64_t, 4> ne;
const int n_expert_used; const int n_expert_used;
const bool with_norm; const bool with_norm;
@ -4976,7 +5077,7 @@ struct test_topk_moe: public test_case {
ggml_tensor * logits = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne.data()); ggml_tensor * logits = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne.data());
ggml_tensor * probs = delayed_softmax ? logits : ggml_soft_max(ctx, logits); ggml_tensor * probs = delayed_softmax ? logits : ggml_soft_max(ctx, logits);
ggml_tensor * selected_experts = ggml_top_k(ctx, probs, n_expert_used); // [n_expert_used, n_tokens] ggml_tensor * selected_experts = ggml_argsort_top_k(ctx, probs, n_expert_used); // [n_expert_used, n_tokens]
ggml_tensor * out = ggml_get_rows(ctx, ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens] ggml_tensor * out = ggml_get_rows(ctx, ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
@ -7534,6 +7635,31 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2, 8, 8192, 1}, order)); // bailingmoe2 (group selection) test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2, 8, 8192, 1}, order)); // bailingmoe2 (group selection)
} }
for (int i = 0; i < 20; ++i) {
for (int k : {1, 2, 3, 7, 15, 100, 500, 1023, 9999}) {
if (k <= 1<<i) {
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {(1<<i), 1, 1, 1}, k));
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {(1<<i) + 11, 1, 2, 1}, k));
}
}
}
for (int k : {1, 2, 3, 7, 15}) {
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {16, 10, 10, 10}, k));
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {60, 10, 10, 10}, k));
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {1023, 2, 1, 3}, k));
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {1024, 2, 1, 3}, k));
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {1025, 2, 1, 3}, k));
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {16384, 1, 1, 1}, k));
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {2047, 2, 1, 3}, k));
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {2048, 2, 1, 3}, k));
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {2049, 2, 1, 3}, k));
}
// exhaustive top_k tests
//for (int i = 1; i < 9999; ++i) {
// test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {i, 2, 1, 3}, rand() % i + 1));
//}
for (ggml_scale_mode mode : {GGML_SCALE_MODE_NEAREST, GGML_SCALE_MODE_BILINEAR, GGML_SCALE_MODE_BICUBIC}) { for (ggml_scale_mode mode : {GGML_SCALE_MODE_NEAREST, GGML_SCALE_MODE_BILINEAR, GGML_SCALE_MODE_BICUBIC}) {
test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, {512, 512, 3, 2}, 2, mode)); test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, {512, 512, 3, 2}, 2, mode));
test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, {512, 512, 3, 2}, 2, mode, true)); test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, {512, 512, 3, 2}, 2, mode, true));
@ -7914,6 +8040,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
} }
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {65000, 16, 1, 1})); test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {65000, 16, 1, 1}));
for (auto k : {1, 10, 40}) {
for (auto nrows : {1, 16}) {
for (auto cols : {k, 1000, 65000, 200000}) {
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {cols, nrows, 1, 1}, k));
}
}
}
return test_cases; return test_cases;
} }

View File

@ -31,13 +31,16 @@ if (LLAMA_BUILD_BORINGSSL)
message(STATUS "Fetching BoringSSL version ${BORINGSSL_VERSION}") message(STATUS "Fetching BoringSSL version ${BORINGSSL_VERSION}")
include(FetchContent) set(BORINGSSL_ARGS
FetchContent_Declare(
boringssl
GIT_REPOSITORY ${BORINGSSL_GIT} GIT_REPOSITORY ${BORINGSSL_GIT}
GIT_TAG ${BORINGSSL_VERSION} GIT_TAG ${BORINGSSL_VERSION}
PATCH_COMMAND ${CMAKE_COMMAND} -P "${CMAKE_CURRENT_SOURCE_DIR}/patch-boringssl.cmake"
) )
if(CMAKE_VERSION VERSION_GREATER_EQUAL 3.28)
list(APPEND BORINGSSL_ARGS EXCLUDE_FROM_ALL)
endif()
include(FetchContent)
FetchContent_Declare(boringssl ${BORINGSSL_ARGS})
set(SAVED_BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS}) set(SAVED_BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS})
set(SAVED_BUILD_TESTING ${BUILD_TESTING}) set(SAVED_BUILD_TESTING ${BUILD_TESTING})
@ -45,7 +48,15 @@ if (LLAMA_BUILD_BORINGSSL)
set(BUILD_SHARED_LIBS OFF) set(BUILD_SHARED_LIBS OFF)
set(BUILD_TESTING OFF) set(BUILD_TESTING OFF)
if(CMAKE_VERSION VERSION_GREATER_EQUAL 3.28)
FetchContent_MakeAvailable(boringssl) FetchContent_MakeAvailable(boringssl)
else()
FetchContent_GetProperties(boringssl)
if(NOT boringssl_POPULATED)
FetchContent_Populate(boringssl)
add_subdirectory(${boringssl_SOURCE_DIR} ${boringssl_BINARY_DIR} EXCLUDE_FROM_ALL)
endif()
endif()
set(BUILD_SHARED_LIBS ${SAVED_BUILD_SHARED_LIBS}) set(BUILD_SHARED_LIBS ${SAVED_BUILD_SHARED_LIBS})
set(BUILD_TESTING ${SAVED_BUILD_TESTING}) set(BUILD_TESTING ${SAVED_BUILD_TESTING})

View File

@@ -1087,22 +1087,30 @@ int getaddrinfo_with_timeout(const char *node, const char *service,
   // Fallback implementation using thread-based timeout for other Unix systems
   struct GetAddrInfoState {
+    ~GetAddrInfoState() {
+      if (info) { freeaddrinfo(info); }
+    }
     std::mutex mutex;
     std::condition_variable result_cv;
     bool completed = false;
     int result = EAI_SYSTEM;
-    std::string node = node;
-    std::string service = service;
-    struct addrinfo hints = hints;
+    std::string node;
+    std::string service;
+    struct addrinfo hints;
     struct addrinfo *info = nullptr;
   };

   // Allocate on the heap, so the resolver thread can keep using the data.
   auto state = std::make_shared<GetAddrInfoState>();
+  state->node = node;
+  state->service = service;
+  state->hints = *hints;

-  std::thread resolve_thread([=]() {
-    auto thread_result = getaddrinfo(
-        state->node.c_str(), state->service.c_str(), hints, &state->info);
+  std::thread resolve_thread([state]() {
+    auto thread_result =
+        getaddrinfo(state->node.c_str(), state->service.c_str(), &state->hints,
+                    &state->info);

     std::lock_guard<std::mutex> lock(state->mutex);
     state->result = thread_result;
@@ -1120,6 +1128,7 @@ int getaddrinfo_with_timeout(const char *node, const char *service,
     // Operation completed within timeout
     resolve_thread.join();
     *res = state->info;
+    state->info = nullptr; // Pass ownership to caller
     return state->result;
   } else {
     // Timeout occurred
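The ownership handoff above is the subtle part of this fix: on the timeout path the resolver thread can outlive the caller, so the addrinfo result and the copied inputs are owned by the shared state, and the result is freed in the state's destructor unless the caller explicitly takes it by nulling state->info. A stripped-down sketch of the same pattern — names and the join-based wait are simplifications, not the library code:

#include <memory>
#include <mutex>
#include <netdb.h>
#include <string>
#include <thread>

struct ResolveState {                     // simplified stand-in for GetAddrInfoState
    ~ResolveState() {
        if (info) { freeaddrinfo(info); } // frees only if never handed off
    }
    std::mutex mutex;
    addrinfo hints{};
    addrinfo *info = nullptr;
};

int resolve(const std::string &node, const std::string &service, addrinfo **out) {
    auto state = std::make_shared<ResolveState>();
    // Capturing the shared_ptr and string copies by value keeps everything
    // alive even if the caller returns first (the bug the patch fixes).
    std::thread t([state, node, service] {
        getaddrinfo(node.c_str(), service.c_str(), &state->hints, &state->info);
    });
    t.join(); // the real code waits on a condition variable with a deadline

    std::lock_guard<std::mutex> lock(state->mutex);
    *out = state->info;
    state->info = nullptr; // ownership passes to the caller
    return 0;
}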
@@ -4970,7 +4979,8 @@ bool Server::write_response_core(Stream &strm, bool close_connection,
   if (need_apply_ranges) { apply_ranges(req, res, content_type, boundary); }

   // Prepare additional headers
-  if (close_connection || req.get_header_value("Connection") == "close") {
+  if (close_connection || req.get_header_value("Connection") == "close" ||
+      400 <= res.status) { // Don't leave connections open after errors
     res.set_header("Connection", "close");
   } else {
     std::string s = "timeout=";
@@ -5173,7 +5183,11 @@ bool Server::read_content_core(
                        size_t /*len*/) { return receiver(buf, n); };
   }

-  if (req.method == "DELETE" && !req.has_header("Content-Length")) {
+  // RFC 7230 Section 3.3.3: If this is a request message and none of the above
+  // are true (no Transfer-Encoding and no Content-Length), then the message
+  // body length is zero (no message body is present).
+  if (!req.has_header("Content-Length") &&
+      !detail::is_chunked_transfer_encoding(req.headers)) {
     return true;
   }
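For reference, the decision order this hunk implements — RFC 7230 Section 3.3.3, simplified to the request side — can be written out directly. The helper below is illustrative only, not part of httplib:

enum class RequestBody { Chunked, Sized, None };

// RFC 7230 Section 3.3.3, request side: chunked Transfer-Encoding wins,
// then Content-Length, otherwise the message has no body at all.
// Before this change, only DELETE got the "no body" treatment.
static RequestBody request_body_kind(bool chunked_te, bool has_content_length) {
    if (chunked_te) { return RequestBody::Chunked; }
    if (has_content_length) { return RequestBody::Sized; }
    return RequestBody::None;
}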
@@ -5681,8 +5695,6 @@ Server::process_request(Stream &strm, const std::string &remote_addr,
   // Check if the request URI doesn't exceed the limit
   if (req.target.size() > CPPHTTPLIB_REQUEST_URI_MAX_LENGTH) {
-    Headers dummy;
-    detail::read_headers(strm, dummy);
     res.status = StatusCode::UriTooLong_414;
     output_error_log(Error::ExceedUriMaxLength, &req);
     return write_response(strm, close_connection, req, res);
@@ -6666,11 +6678,13 @@ bool ClientImpl::write_request(Stream &strm, Request &req,
   return true;
 }

-std::unique_ptr<Response> ClientImpl::send_with_content_provider(
+std::unique_ptr<Response>
+ClientImpl::send_with_content_provider_and_receiver(
     Request &req, const char *body, size_t content_length,
     ContentProvider content_provider,
     ContentProviderWithoutLength content_provider_without_length,
-    const std::string &content_type, Error &error) {
+    const std::string &content_type, ContentReceiver content_receiver,
+    Error &error) {
   if (!content_type.empty()) { req.set_header("Content-Type", content_type); }

 #ifdef CPPHTTPLIB_ZLIB_SUPPORT
@@ -6743,15 +6757,24 @@ std::unique_ptr<Response> ClientImpl::send_with_content_provider(
     }
   }

+  if (content_receiver) {
+    req.content_receiver =
+        [content_receiver](const char *data, size_t data_length,
+                           size_t /*offset*/, size_t /*total_length*/) {
+          return content_receiver(data, data_length);
+        };
+  }
+
   auto res = detail::make_unique<Response>();
   return send(req, *res, error) ? std::move(res) : nullptr;
 }

-Result ClientImpl::send_with_content_provider(
+Result ClientImpl::send_with_content_provider_and_receiver(
     const std::string &method, const std::string &path, const Headers &headers,
     const char *body, size_t content_length, ContentProvider content_provider,
     ContentProviderWithoutLength content_provider_without_length,
-    const std::string &content_type, UploadProgress progress) {
+    const std::string &content_type, ContentReceiver content_receiver,
+    UploadProgress progress) {
   Request req;
   req.method = method;
   req.headers = headers;
@@ -6763,9 +6786,10 @@ Result ClientImpl::send_with_content_provider(
   auto error = Error::Success;

-  auto res = send_with_content_provider(
+  auto res = send_with_content_provider_and_receiver(
       req, body, content_length, std::move(content_provider),
-      std::move(content_provider_without_length), content_type, error);
+      std::move(content_provider_without_length), content_type,
+      std::move(content_receiver), error);

 #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
   return Result{std::move(res), error, std::move(req.headers), last_ssl_error_,
@@ -7094,6 +7118,15 @@ Result ClientImpl::Post(const std::string &path, size_t content_length,
               content_type, progress);
 }

+Result ClientImpl::Post(const std::string &path, size_t content_length,
+                        ContentProvider content_provider,
+                        const std::string &content_type,
+                        ContentReceiver content_receiver,
+                        UploadProgress progress) {
+  return Post(path, Headers(), content_length, std::move(content_provider),
+              content_type, std::move(content_receiver), progress);
+}
+
 Result ClientImpl::Post(const std::string &path,
                         ContentProviderWithoutLength content_provider,
                         const std::string &content_type,
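The new overloads let a caller stream the request body up while consuming the response body incrementally through a ContentReceiver, instead of buffering it in Response::body. A possible call site for the overload added above — the endpoint, payload, and handler bodies are made up for illustration:

#include "httplib.h"
#include <string>

void example() {
  httplib::Client cli("http://localhost:8080"); // hypothetical endpoint
  std::string upload(1 << 20, 'x');
  std::string downloaded;

  auto res = cli.Post(
      "/echo", upload.size(),
      [&](size_t offset, size_t length, httplib::DataSink &sink) {
        sink.write(upload.data() + offset, length); // provide the request body
        return true;
      },
      "application/octet-stream",
      [&](const char *data, size_t data_length) {
        downloaded.append(data, data_length); // response arrives incrementally
        return true;
      });
  if (res) { /* res->status is set; the body went through the receiver */ }
}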
@@ -7102,6 +7135,15 @@ Result ClientImpl::Post(const std::string &path,
               progress);
 }

+Result ClientImpl::Post(const std::string &path,
+                        ContentProviderWithoutLength content_provider,
+                        const std::string &content_type,
+                        ContentReceiver content_receiver,
+                        UploadProgress progress) {
+  return Post(path, Headers(), std::move(content_provider), content_type,
+              std::move(content_receiver), progress);
+}
+
 Result ClientImpl::Post(const std::string &path, const Headers &headers,
                         const Params &params) {
   auto query = detail::params_to_query_str(params);
@@ -7142,17 +7184,18 @@ Result ClientImpl::Post(const std::string &path, const Headers &headers,
                         const char *body, size_t content_length,
                         const std::string &content_type,
                         UploadProgress progress) {
-  return send_with_content_provider("POST", path, headers, body, content_length,
-                                    nullptr, nullptr, content_type, progress);
+  return send_with_content_provider_and_receiver(
+      "POST", path, headers, body, content_length, nullptr, nullptr,
+      content_type, nullptr, progress);
 }

 Result ClientImpl::Post(const std::string &path, const Headers &headers,
                         const std::string &body,
                         const std::string &content_type,
                         UploadProgress progress) {
-  return send_with_content_provider("POST", path, headers, body.data(),
-                                    body.size(), nullptr, nullptr, content_type,
-                                    progress);
+  return send_with_content_provider_and_receiver(
+      "POST", path, headers, body.data(), body.size(), nullptr, nullptr,
+      content_type, nullptr, progress);
 }
 Result ClientImpl::Post(const std::string &path, const Headers &headers,
@@ -7160,18 +7203,40 @@ Result ClientImpl::Post(const std::string &path, const Headers &headers,
                         ContentProvider content_provider,
                         const std::string &content_type,
                         UploadProgress progress) {
-  return send_with_content_provider("POST", path, headers, nullptr,
-                                    content_length, std::move(content_provider),
-                                    nullptr, content_type, progress);
+  return send_with_content_provider_and_receiver(
+      "POST", path, headers, nullptr, content_length,
+      std::move(content_provider), nullptr, content_type, nullptr, progress);
+}
+
+Result ClientImpl::Post(const std::string &path, const Headers &headers,
+                        size_t content_length,
+                        ContentProvider content_provider,
+                        const std::string &content_type,
+                        ContentReceiver content_receiver,
+                        DownloadProgress progress) {
+  return send_with_content_provider_and_receiver(
+      "POST", path, headers, nullptr, content_length,
+      std::move(content_provider), nullptr, content_type,
+      std::move(content_receiver), std::move(progress));
 }

 Result ClientImpl::Post(const std::string &path, const Headers &headers,
                         ContentProviderWithoutLength content_provider,
                         const std::string &content_type,
                         UploadProgress progress) {
-  return send_with_content_provider("POST", path, headers, nullptr, 0, nullptr,
-                                    std::move(content_provider), content_type,
-                                    progress);
+  return send_with_content_provider_and_receiver(
+      "POST", path, headers, nullptr, 0, nullptr, std::move(content_provider),
+      content_type, nullptr, progress);
+}
+
+Result ClientImpl::Post(const std::string &path, const Headers &headers,
+                        ContentProviderWithoutLength content_provider,
+                        const std::string &content_type,
+                        ContentReceiver content_receiver,
+                        DownloadProgress progress) {
+  return send_with_content_provider_and_receiver(
+      "POST", path, headers, nullptr, 0, nullptr, std::move(content_provider),
+      content_type, std::move(content_receiver), std::move(progress));
 }
 Result ClientImpl::Post(const std::string &path, const Headers &headers,
@@ -7181,10 +7246,10 @@ Result ClientImpl::Post(const std::string &path, const Headers &headers,
   const auto &boundary = detail::make_multipart_data_boundary();
   const auto &content_type =
       detail::serialize_multipart_formdata_get_content_type(boundary);
-  return send_with_content_provider(
+  return send_with_content_provider_and_receiver(
       "POST", path, headers, nullptr, 0, nullptr,
       get_multipart_content_provider(boundary, items, provider_items),
-      content_type, progress);
+      content_type, nullptr, progress);
 }
 Result ClientImpl::Post(const std::string &path, const Headers &headers,
@@ -7246,6 +7311,15 @@ Result ClientImpl::Put(const std::string &path, size_t content_length,
              content_type, progress);
 }

+Result ClientImpl::Put(const std::string &path, size_t content_length,
+                       ContentProvider content_provider,
+                       const std::string &content_type,
+                       ContentReceiver content_receiver,
+                       UploadProgress progress) {
+  return Put(path, Headers(), content_length, std::move(content_provider),
+             content_type, std::move(content_receiver), progress);
+}
+
 Result ClientImpl::Put(const std::string &path,
                        ContentProviderWithoutLength content_provider,
                        const std::string &content_type,
@@ -7254,6 +7328,15 @@ Result ClientImpl::Put(const std::string &path,
              progress);
 }

+Result ClientImpl::Put(const std::string &path,
+                       ContentProviderWithoutLength content_provider,
+                       const std::string &content_type,
+                       ContentReceiver content_receiver,
+                       UploadProgress progress) {
+  return Put(path, Headers(), std::move(content_provider), content_type,
+             std::move(content_receiver), progress);
+}
+
 Result ClientImpl::Put(const std::string &path, const Headers &headers,
                        const Params &params) {
   auto query = detail::params_to_query_str(params);
@@ -7294,17 +7377,18 @@ Result ClientImpl::Put(const std::string &path, const Headers &headers,
                        const char *body, size_t content_length,
                        const std::string &content_type,
                        UploadProgress progress) {
-  return send_with_content_provider("PUT", path, headers, body, content_length,
-                                    nullptr, nullptr, content_type, progress);
+  return send_with_content_provider_and_receiver(
+      "PUT", path, headers, body, content_length, nullptr, nullptr,
+      content_type, nullptr, progress);
 }

 Result ClientImpl::Put(const std::string &path, const Headers &headers,
                        const std::string &body,
                        const std::string &content_type,
                        UploadProgress progress) {
-  return send_with_content_provider("PUT", path, headers, body.data(),
-                                    body.size(), nullptr, nullptr, content_type,
-                                    progress);
+  return send_with_content_provider_and_receiver(
+      "PUT", path, headers, body.data(), body.size(), nullptr, nullptr,
+      content_type, nullptr, progress);
 }
 Result ClientImpl::Put(const std::string &path, const Headers &headers,
@@ -7312,18 +7396,40 @@ Result ClientImpl::Put(const std::string &path, const Headers &headers,
                        ContentProvider content_provider,
                        const std::string &content_type,
                        UploadProgress progress) {
-  return send_with_content_provider("PUT", path, headers, nullptr,
-                                    content_length, std::move(content_provider),
-                                    nullptr, content_type, progress);
+  return send_with_content_provider_and_receiver(
+      "PUT", path, headers, nullptr, content_length,
+      std::move(content_provider), nullptr, content_type, nullptr, progress);
+}
+
+Result ClientImpl::Put(const std::string &path, const Headers &headers,
+                       size_t content_length,
+                       ContentProvider content_provider,
+                       const std::string &content_type,
+                       ContentReceiver content_receiver,
+                       UploadProgress progress) {
+  return send_with_content_provider_and_receiver(
+      "PUT", path, headers, nullptr, content_length,
+      std::move(content_provider), nullptr, content_type,
+      std::move(content_receiver), progress);
 }

 Result ClientImpl::Put(const std::string &path, const Headers &headers,
                        ContentProviderWithoutLength content_provider,
                        const std::string &content_type,
                        UploadProgress progress) {
-  return send_with_content_provider("PUT", path, headers, nullptr, 0, nullptr,
-                                    std::move(content_provider), content_type,
-                                    progress);
+  return send_with_content_provider_and_receiver(
+      "PUT", path, headers, nullptr, 0, nullptr, std::move(content_provider),
+      content_type, nullptr, progress);
+}
+
+Result ClientImpl::Put(const std::string &path, const Headers &headers,
+                       ContentProviderWithoutLength content_provider,
+                       const std::string &content_type,
+                       ContentReceiver content_receiver,
+                       UploadProgress progress) {
+  return send_with_content_provider_and_receiver(
+      "PUT", path, headers, nullptr, 0, nullptr, std::move(content_provider),
+      content_type, std::move(content_receiver), progress);
 }
 Result ClientImpl::Put(const std::string &path, const Headers &headers,
@@ -7333,10 +7439,10 @@ Result ClientImpl::Put(const std::string &path, const Headers &headers,
   const auto &boundary = detail::make_multipart_data_boundary();
   const auto &content_type =
       detail::serialize_multipart_formdata_get_content_type(boundary);
-  return send_with_content_provider(
+  return send_with_content_provider_and_receiver(
       "PUT", path, headers, nullptr, 0, nullptr,
       get_multipart_content_provider(boundary, items, provider_items),
-      content_type, progress);
+      content_type, nullptr, progress);
 }
 Result ClientImpl::Put(const std::string &path, const Headers &headers,
@@ -7400,6 +7506,15 @@ Result ClientImpl::Patch(const std::string &path, size_t content_length,
                content_type, progress);
 }

+Result ClientImpl::Patch(const std::string &path, size_t content_length,
+                         ContentProvider content_provider,
+                         const std::string &content_type,
+                         ContentReceiver content_receiver,
+                         UploadProgress progress) {
+  return Patch(path, Headers(), content_length, std::move(content_provider),
+               content_type, std::move(content_receiver), progress);
+}
+
 Result ClientImpl::Patch(const std::string &path,
                          ContentProviderWithoutLength content_provider,
                          const std::string &content_type,
@@ -7408,6 +7523,15 @@ Result ClientImpl::Patch(const std::string &path,
                progress);
 }

+Result ClientImpl::Patch(const std::string &path,
+                         ContentProviderWithoutLength content_provider,
+                         const std::string &content_type,
+                         ContentReceiver content_receiver,
+                         UploadProgress progress) {
+  return Patch(path, Headers(), std::move(content_provider), content_type,
+               std::move(content_receiver), progress);
+}
+
 Result ClientImpl::Patch(const std::string &path, const Headers &headers,
                          const Params &params) {
   auto query = detail::params_to_query_str(params);
@@ -7448,18 +7572,18 @@ Result ClientImpl::Patch(const std::string &path, const Headers &headers,
                          const char *body, size_t content_length,
                          const std::string &content_type,
                          UploadProgress progress) {
-  return send_with_content_provider("PATCH", path, headers, body,
-                                    content_length, nullptr, nullptr,
-                                    content_type, progress);
+  return send_with_content_provider_and_receiver(
+      "PATCH", path, headers, body, content_length, nullptr, nullptr,
+      content_type, nullptr, progress);
 }

 Result ClientImpl::Patch(const std::string &path, const Headers &headers,
                          const std::string &body,
                          const std::string &content_type,
                          UploadProgress progress) {
-  return send_with_content_provider("PATCH", path, headers, body.data(),
-                                    body.size(), nullptr, nullptr, content_type,
-                                    progress);
+  return send_with_content_provider_and_receiver(
+      "PATCH", path, headers, body.data(), body.size(), nullptr, nullptr,
+      content_type, nullptr, progress);
 }
 Result ClientImpl::Patch(const std::string &path, const Headers &headers,
@@ -7467,18 +7591,40 @@ Result ClientImpl::Patch(const std::string &path, const Headers &headers,
                          ContentProvider content_provider,
                          const std::string &content_type,
                          UploadProgress progress) {
-  return send_with_content_provider("PATCH", path, headers, nullptr,
-                                    content_length, std::move(content_provider),
-                                    nullptr, content_type, progress);
+  return send_with_content_provider_and_receiver(
+      "PATCH", path, headers, nullptr, content_length,
+      std::move(content_provider), nullptr, content_type, nullptr, progress);
+}
+
+Result ClientImpl::Patch(const std::string &path, const Headers &headers,
+                         size_t content_length,
+                         ContentProvider content_provider,
+                         const std::string &content_type,
+                         ContentReceiver content_receiver,
+                         UploadProgress progress) {
+  return send_with_content_provider_and_receiver(
+      "PATCH", path, headers, nullptr, content_length,
+      std::move(content_provider), nullptr, content_type,
+      std::move(content_receiver), progress);
 }

 Result ClientImpl::Patch(const std::string &path, const Headers &headers,
                          ContentProviderWithoutLength content_provider,
                          const std::string &content_type,
                          UploadProgress progress) {
-  return send_with_content_provider("PATCH", path, headers, nullptr, 0, nullptr,
-                                    std::move(content_provider), content_type,
-                                    progress);
+  return send_with_content_provider_and_receiver(
+      "PATCH", path, headers, nullptr, 0, nullptr, std::move(content_provider),
+      content_type, nullptr, progress);
+}
+
+Result ClientImpl::Patch(const std::string &path, const Headers &headers,
+                         ContentProviderWithoutLength content_provider,
+                         const std::string &content_type,
+                         ContentReceiver content_receiver,
+                         UploadProgress progress) {
+  return send_with_content_provider_and_receiver(
+      "PATCH", path, headers, nullptr, 0, nullptr, std::move(content_provider),
+      content_type, std::move(content_receiver), progress);
 }
 Result ClientImpl::Patch(const std::string &path, const Headers &headers,
@@ -7488,10 +7634,10 @@ Result ClientImpl::Patch(const std::string &path, const Headers &headers,
   const auto &boundary = detail::make_multipart_data_boundary();
   const auto &content_type =
       detail::serialize_multipart_formdata_get_content_type(boundary);
-  return send_with_content_provider(
+  return send_with_content_provider_and_receiver(
       "PATCH", path, headers, nullptr, 0, nullptr,
       get_multipart_content_provider(boundary, items, provider_items),
-      content_type, progress);
+      content_type, nullptr, progress);
 }
 Result ClientImpl::Patch(const std::string &path, const Headers &headers,
@@ -8883,12 +9029,28 @@ Result Client::Post(const std::string &path, size_t content_length,
   return cli_->Post(path, content_length, std::move(content_provider),
                     content_type, progress);
 }

+Result Client::Post(const std::string &path, size_t content_length,
+                    ContentProvider content_provider,
+                    const std::string &content_type,
+                    ContentReceiver content_receiver,
+                    UploadProgress progress) {
+  return cli_->Post(path, content_length, std::move(content_provider),
+                    content_type, std::move(content_receiver), progress);
+}
+
 Result Client::Post(const std::string &path,
                     ContentProviderWithoutLength content_provider,
                     const std::string &content_type,
                     UploadProgress progress) {
   return cli_->Post(path, std::move(content_provider), content_type, progress);
 }

+Result Client::Post(const std::string &path,
+                    ContentProviderWithoutLength content_provider,
+                    const std::string &content_type,
+                    ContentReceiver content_receiver,
+                    UploadProgress progress) {
+  return cli_->Post(path, std::move(content_provider), content_type,
+                    std::move(content_receiver), progress);
+}
+
 Result Client::Post(const std::string &path, const Headers &headers,
                     size_t content_length,
                     ContentProvider content_provider,
@@ -8897,6 +9059,15 @@ Result Client::Post(const std::string &path, const Headers &headers,
   return cli_->Post(path, headers, content_length, std::move(content_provider),
                     content_type, progress);
 }

+Result Client::Post(const std::string &path, const Headers &headers,
+                    size_t content_length,
+                    ContentProvider content_provider,
+                    const std::string &content_type,
+                    ContentReceiver content_receiver,
+                    DownloadProgress progress) {
+  return cli_->Post(path, headers, content_length, std::move(content_provider),
+                    content_type, std::move(content_receiver), progress);
+}
+
 Result Client::Post(const std::string &path, const Headers &headers,
                     ContentProviderWithoutLength content_provider,
                     const std::string &content_type,
@@ -8904,6 +9075,14 @@ Result Client::Post(const std::string &path, const Headers &headers,
   return cli_->Post(path, headers, std::move(content_provider), content_type,
                     progress);
 }

+Result Client::Post(const std::string &path, const Headers &headers,
+                    ContentProviderWithoutLength content_provider,
+                    const std::string &content_type,
+                    ContentReceiver content_receiver,
+                    DownloadProgress progress) {
+  return cli_->Post(path, headers, std::move(content_provider), content_type,
+                    std::move(content_receiver), progress);
+}
+
 Result Client::Post(const std::string &path, const Params &params) {
   return cli_->Post(path, params);
 }
@@ -8938,8 +9117,8 @@ Result Client::Post(const std::string &path, const Headers &headers,
                     const std::string &content_type,
                     ContentReceiver content_receiver,
                     DownloadProgress progress) {
-  return cli_->Post(path, headers, body, content_type, content_receiver,
-                    progress);
+  return cli_->Post(path, headers, body, content_type,
+                    std::move(content_receiver), progress);
 }

 Result Client::Put(const std::string &path) { return cli_->Put(path); }
@@ -8976,12 +9155,28 @@ Result Client::Put(const std::string &path, size_t content_length,
   return cli_->Put(path, content_length, std::move(content_provider),
                    content_type, progress);
 }

+Result Client::Put(const std::string &path, size_t content_length,
+                   ContentProvider content_provider,
+                   const std::string &content_type,
+                   ContentReceiver content_receiver,
+                   UploadProgress progress) {
+  return cli_->Put(path, content_length, std::move(content_provider),
+                   content_type, std::move(content_receiver), progress);
+}
+
 Result Client::Put(const std::string &path,
                    ContentProviderWithoutLength content_provider,
                    const std::string &content_type,
                    UploadProgress progress) {
   return cli_->Put(path, std::move(content_provider), content_type, progress);
 }

+Result Client::Put(const std::string &path,
+                   ContentProviderWithoutLength content_provider,
+                   const std::string &content_type,
+                   ContentReceiver content_receiver,
+                   UploadProgress progress) {
+  return cli_->Put(path, std::move(content_provider), content_type,
+                   std::move(content_receiver), progress);
+}
+
 Result Client::Put(const std::string &path, const Headers &headers,
                    size_t content_length,
                    ContentProvider content_provider,
@@ -8990,6 +9185,15 @@ Result Client::Put(const std::string &path, const Headers &headers,
   return cli_->Put(path, headers, content_length, std::move(content_provider),
                    content_type, progress);
 }

+Result Client::Put(const std::string &path, const Headers &headers,
+                   size_t content_length,
+                   ContentProvider content_provider,
+                   const std::string &content_type,
+                   ContentReceiver content_receiver,
+                   UploadProgress progress) {
+  return cli_->Put(path, headers, content_length, std::move(content_provider),
+                   content_type, std::move(content_receiver), progress);
+}
+
 Result Client::Put(const std::string &path, const Headers &headers,
                    ContentProviderWithoutLength content_provider,
                    const std::string &content_type,
@@ -8997,6 +9201,14 @@ Result Client::Put(const std::string &path, const Headers &headers,
   return cli_->Put(path, headers, std::move(content_provider), content_type,
                    progress);
 }

+Result Client::Put(const std::string &path, const Headers &headers,
+                   ContentProviderWithoutLength content_provider,
+                   const std::string &content_type,
+                   ContentReceiver content_receiver,
+                   UploadProgress progress) {
+  return cli_->Put(path, headers, std::move(content_provider), content_type,
+                   std::move(content_receiver), progress);
+}
+
 Result Client::Put(const std::string &path, const Params &params) {
   return cli_->Put(path, params);
 }
@@ -9072,12 +9284,28 @@ Result Client::Patch(const std::string &path, size_t content_length,
   return cli_->Patch(path, content_length, std::move(content_provider),
                      content_type, progress);
 }

+Result Client::Patch(const std::string &path, size_t content_length,
+                     ContentProvider content_provider,
+                     const std::string &content_type,
+                     ContentReceiver content_receiver,
+                     UploadProgress progress) {
+  return cli_->Patch(path, content_length, std::move(content_provider),
+                     content_type, std::move(content_receiver), progress);
+}
+
 Result Client::Patch(const std::string &path,
                      ContentProviderWithoutLength content_provider,
                      const std::string &content_type,
                      UploadProgress progress) {
   return cli_->Patch(path, std::move(content_provider), content_type, progress);
 }

+Result Client::Patch(const std::string &path,
+                     ContentProviderWithoutLength content_provider,
+                     const std::string &content_type,
+                     ContentReceiver content_receiver,
+                     UploadProgress progress) {
+  return cli_->Patch(path, std::move(content_provider), content_type,
+                     std::move(content_receiver), progress);
+}
+
 Result Client::Patch(const std::string &path, const Headers &headers,
                      size_t content_length,
                      ContentProvider content_provider,
@@ -9086,6 +9314,15 @@ Result Client::Patch(const std::string &path, const Headers &headers,
   return cli_->Patch(path, headers, content_length, std::move(content_provider),
                      content_type, progress);
 }

+Result Client::Patch(const std::string &path, const Headers &headers,
+                     size_t content_length,
+                     ContentProvider content_provider,
+                     const std::string &content_type,
+                     ContentReceiver content_receiver,
+                     UploadProgress progress) {
+  return cli_->Patch(path, headers, content_length, std::move(content_provider),
+                     content_type, std::move(content_receiver), progress);
+}
+
 Result Client::Patch(const std::string &path, const Headers &headers,
                      ContentProviderWithoutLength content_provider,
                      const std::string &content_type,
@@ -9093,6 +9330,14 @@ Result Client::Patch(const std::string &path, const Headers &headers,
   return cli_->Patch(path, headers, std::move(content_provider), content_type,
                      progress);
 }

+Result Client::Patch(const std::string &path, const Headers &headers,
+                     ContentProviderWithoutLength content_provider,
+                     const std::string &content_type,
+                     ContentReceiver content_receiver,
+                     UploadProgress progress) {
+  return cli_->Patch(path, headers, std::move(content_provider), content_type,
+                     std::move(content_receiver), progress);
+}
+
 Result Client::Patch(const std::string &path, const Params &params) {
   return cli_->Patch(path, params);
 }

View File

@@ -8,8 +8,8 @@
 #ifndef CPPHTTPLIB_HTTPLIB_H
 #define CPPHTTPLIB_HTTPLIB_H

-#define CPPHTTPLIB_VERSION "0.27.0"
-#define CPPHTTPLIB_VERSION_NUM "0x001B00"
+#define CPPHTTPLIB_VERSION "0.28.0"
+#define CPPHTTPLIB_VERSION_NUM "0x001C00"

 /*
  * Platform compatibility check
@@ -257,6 +257,7 @@ using socklen_t = int;
 #include <netinet/in.h>
 #ifdef __linux__
 #include <resolv.h>
+#undef _res // Undefine _res macro to avoid conflicts with user code (#2278)
 #endif
 #include <csignal>
 #include <netinet/tcp.h>
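On Linux, <resolv.h> can define _res as a macro for the resolver state, so any user identifier named _res breaks once httplib.h pulls that header in; issue #2278 tracks this. An illustrative, runnable reproduction of the conflict the #undef avoids:

#include <resolv.h>
#undef _res // what httplib.h now does; without it the struct below breaks

struct Packet {
    int _res; // with the macro active, this name would expand into garbage
};

int main() {
    Packet p;
    p._res = 1; // would otherwise expand to p.(*__res_state()) on glibc
    return p._res - 1;
}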
@ -1421,14 +1422,18 @@ public:
Result Post(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); Result Post(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
Result Post(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); Result Post(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
Result Post(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); Result Post(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Post(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Post(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); Result Post(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Post(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Post(const std::string &path, const Params &params); Result Post(const std::string &path, const Params &params);
Result Post(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr); Result Post(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr);
Result Post(const std::string &path, const Headers &headers); Result Post(const std::string &path, const Headers &headers);
Result Post(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); Result Post(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
Result Post(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); Result Post(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
Result Post(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); Result Post(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Post(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, DownloadProgress progress = nullptr);
Result Post(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); Result Post(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Post(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, DownloadProgress progress = nullptr);
Result Post(const std::string &path, const Headers &headers, const Params &params); Result Post(const std::string &path, const Headers &headers, const Params &params);
Result Post(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr); Result Post(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr);
Result Post(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr); Result Post(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr);
@ -1439,14 +1444,18 @@ public:
Result Put(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); Result Put(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
Result Put(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); Result Put(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
Result Put(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); Result Put(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Put(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Put(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); Result Put(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Put(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Put(const std::string &path, const Params &params); Result Put(const std::string &path, const Params &params);
Result Put(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr); Result Put(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr);
Result Put(const std::string &path, const Headers &headers); Result Put(const std::string &path, const Headers &headers);
Result Put(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); Result Put(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
Result Put(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); Result Put(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
Result Put(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); Result Put(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Put(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Put(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); Result Put(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Put(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Put(const std::string &path, const Headers &headers, const Params &params); Result Put(const std::string &path, const Headers &headers, const Params &params);
Result Put(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr); Result Put(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr);
Result Put(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr); Result Put(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr);
@ -1457,14 +1466,18 @@ public:
Result Patch(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); Result Patch(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); Result Patch(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
Result Patch(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); Result Patch(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Patch(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Patch(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); Result Patch(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Patch(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Params &params); Result Patch(const std::string &path, const Params &params);
Result Patch(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr); Result Patch(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Headers &headers, UploadProgress progress = nullptr); Result Patch(const std::string &path, const Headers &headers, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); Result Patch(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); Result Patch(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); Result Patch(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); Result Patch(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Headers &headers, const Params &params); Result Patch(const std::string &path, const Headers &headers, const Params &params);
Result Patch(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr); Result Patch(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr); Result Patch(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr);
@ -1712,17 +1725,19 @@ private:
template <typename ClientType> void setup_redirect_client(ClientType &client); template <typename ClientType> void setup_redirect_client(ClientType &client);
bool handle_request(Stream &strm, Request &req, Response &res, bool handle_request(Stream &strm, Request &req, Response &res,
bool close_connection, Error &error); bool close_connection, Error &error);
std::unique_ptr<Response> send_with_content_provider( std::unique_ptr<Response> send_with_content_provider_and_receiver(
Request &req, const char *body, size_t content_length, Request &req, const char *body, size_t content_length,
ContentProvider content_provider, ContentProvider content_provider,
ContentProviderWithoutLength content_provider_without_length, ContentProviderWithoutLength content_provider_without_length,
const std::string &content_type, Error &error); const std::string &content_type, ContentReceiver content_receiver,
Result send_with_content_provider( Error &error);
Result send_with_content_provider_and_receiver(
const std::string &method, const std::string &path, const std::string &method, const std::string &path,
const Headers &headers, const char *body, size_t content_length, const Headers &headers, const char *body, size_t content_length,
ContentProvider content_provider, ContentProvider content_provider,
ContentProviderWithoutLength content_provider_without_length, ContentProviderWithoutLength content_provider_without_length,
const std::string &content_type, UploadProgress progress); const std::string &content_type, ContentReceiver content_receiver,
UploadProgress progress);
ContentProviderWithoutLength get_multipart_content_provider( ContentProviderWithoutLength get_multipart_content_provider(
const std::string &boundary, const UploadFormDataItems &items, const std::string &boundary, const UploadFormDataItems &items,
const FormDataProviderItems &provider_items) const; const FormDataProviderItems &provider_items) const;
@ -1775,14 +1790,18 @@ public:
Result Post(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); Result Post(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
Result Post(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); Result Post(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
Result Post(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); Result Post(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Post(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Post(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); Result Post(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Post(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Post(const std::string &path, const Params &params); Result Post(const std::string &path, const Params &params);
Result Post(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr); Result Post(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr);
Result Post(const std::string &path, const Headers &headers); Result Post(const std::string &path, const Headers &headers);
Result Post(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); Result Post(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
Result Post(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); Result Post(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
Result Post(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); Result Post(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Post(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, DownloadProgress progress = nullptr);
Result Post(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); Result Post(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Post(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, DownloadProgress progress = nullptr);
Result Post(const std::string &path, const Headers &headers, const Params &params); Result Post(const std::string &path, const Headers &headers, const Params &params);
Result Post(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr); Result Post(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr);
Result Post(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr); Result Post(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr);
@ -1793,14 +1812,18 @@ public:
Result Put(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); Result Put(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
Result Put(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); Result Put(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
Result Put(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); Result Put(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Put(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Put(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); Result Put(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Put(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Put(const std::string &path, const Params &params);
Result Put(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr);
Result Put(const std::string &path, const Headers &headers);
Result Put(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
Result Put(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
Result Put(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Put(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Put(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Put(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Put(const std::string &path, const Headers &headers, const Params &params);
Result Put(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr);
Result Put(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr);
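For the ContentProviderWithoutLength variants the body is sent with chunked transfer encoding, and the provider is called until it signals completion through DataSink::done(). A sketch of the Put overload with an UploadProgress callback, under the same assumptions as the Post example (the /stream path is illustrative):

#include "httplib.h"

int main() {
    httplib::Client cli("http://localhost:8080");

    auto res = cli.Put(
        "/stream",
        // ContentProviderWithoutLength: chunked body, no Content-Length header
        [](size_t /*offset*/, httplib::DataSink &sink) {
            static const char chunk[] = "one chunk of data";
            sink.write(chunk, sizeof(chunk) - 1);
            sink.done();  // no more chunks; ends the chunked body
            return true;
        },
        "application/octet-stream",
        // UploadProgress: bytes written so far (total may be unknown for chunked)
        [](uint64_t current, uint64_t total) {
            (void)total;
            return current < (1ull << 20);  // abort if more than 1 MiB gets sent
        });

    return res ? 0 : 1;
}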
@@ -1811,14 +1834,18 @@ public:
Result Patch(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
Result Patch(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Patch(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Patch(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Patch(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Params &params);
Result Patch(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Headers &headers);
Result Patch(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Headers &headers, const Params &params);
Result Patch(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr);
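Patch mirrors Post and Put, including the multipart form overloads. A short sketch of the UploadFormDataItems variant with progress reporting; the {name, content, filename, content_type} member order is an assumption based on the library's earlier form-data types, and the path and token are invented:

#include <string>
#include "httplib.h"

int main() {
    httplib::Client cli("http://localhost:8080");

    httplib::UploadFormDataItems items = {
        // {name, content, filename, content_type} (assumed member order)
        {"comment", "updated via PATCH", "", "text/plain"},
        {"blob", std::string(1024, '\0'), "blob.bin", "application/octet-stream"},
    };

    auto res = cli.Patch(
        "/resource/42", httplib::Headers{{"Authorization", "Bearer <token>"}},
        items,
        // UploadProgress over the encoded multipart body
        [](uint64_t current, uint64_t total) {
            return current <= total;  // keep going; return false to cancel
        });

    return (res && res->status == 200) ? 0 : 1;
}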

View File

@ -1,6 +0,0 @@
# Remove bssl: comment out the bssl command-line tool's targets in CMakeLists.txt
# so the tool is neither built nor installed
file(READ "CMakeLists.txt" content)
string(REPLACE "add_executable(bssl" "#add_executable(bssl" content "${content}")
string(REPLACE "target_link_libraries(bssl" "#target_link_libraries(bssl" content "${content}")
string(REPLACE "install(TARGETS bssl" "#install(TARGETS bssl" content "${content}")
file(WRITE "CMakeLists.txt" "${content}")
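This script rewrites the dependency's CMakeLists.txt (presumably BoringSSL's, given the bssl tool) in place, commenting out its add_executable, target_link_libraries, and install rules so only the libraries get built; helpers like this are typically run in CMake script mode (cmake -P) before configuring. Since the diff deletes the file, the workaround is apparently no longer needed after this merge.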