Merge remote-tracking branch 'upstream/master' into backend-sampling

Daniel Bevenius 2025-11-26 17:52:29 +01:00
commit 7c2bfb352e
20 changed files with 1475 additions and 391 deletions

View File

@ -2207,78 +2207,120 @@ static void aclnn_index_fill_tensor(ggml_backend_cann_context & ctx,
}
/**
* @brief Initializes and caches sine/cosine positional encoding values
* (used in RoPE, Rotary Position Embedding) for attention layers.
* @brief Initializes and caches all intermediate tensors required for RoPE
* (Rotary Position Embedding), including support for Yarn, mRoPE,
* i-mRoPE, Neox repeat strategy, independent sectors, frequency factors
* and multi-section rotary groups.
*
* This function computes and caches the sin/cos values of
* θ = position * theta_scale for RoPE encoding. The cache is shared
* across attention layers, and only the first attention layer will
* trigger initialization. The cache includes repeated sin/cos values
* with different repeat methods depending on the @param is_neox flag.
* This function computes and caches the per-dimension θ coefficients used for
* Q/K rotary embedding. The cache is shared across layers, and recomputed only
* when any dependent parameter changes.
*
* Steps performed by this function:
* 1. Identify whether the target tensor belongs to Q/K in attention
* and restrict computation to the first layer only.
* 2. Initialize the theta scale array (arange power freq scaling).
* 3. Allocate sin/cos caches if the max prompt length increases.
* 4. Compute θ = position * theta_scale.
* 5. Compute sin(θ), cos(θ) and optionally scale by attn_factor.
* 6. Expand sin/cos values by repeat or repeat_interleave depending
* on whether @param is_neox is enabled.
* The function now supports:
* - Yarn RoPE extrapolation (via @p corr_dims and @p ext_factor)
* - Per-dimension independent sector exponent rules (indep_sects + sections[])
* - Multi-section RoPE (mRoPE) index mapping (mrope_used + is_imrope)
* - Frequency factor division (src2)
* - Neox / normal repeat expansion modes
*
* @param ctx The CANN backend context, holding memory pool,
* stream, and persistent buffers for rope init/cache.
* @param dst The destination ggml_tensor whose computation
* depends on the RoPE values (usually Qcur/Kcur).
* @param theta_scale Scalar exponent base for computing theta scale values.
* @param freq_scale Frequency scaling factor, applied to theta scale.
* @param attn_factor Attention scaling factor, applied to sin/cos.
* @param is_neox Whether to use Neox-style repeat strategy
* (dim expansion vs repeat_interleave).
* @param ctx CANN backend context, containing memory pool,
* cached buffers, and runtime stream.
* @param dst Destination ggml_tensor whose computation
* depends on RoPE (typically Qcur or Kcur).
* @param corr_dims [low, high] Yarn correction range.
* @param ext_factor Yarn extrapolation strength. 0 = disabled.
* @param theta_scale Base multiplier for per-dimension θ exponent.
* @param freq_scale Global frequency scaling factor.
* @param attn_factor Optional scaling applied to sin/cos (if needed).
* @param is_neox Whether to use Neox-style dimension interleave.
* @param sections 4-way sector sizes for independent-section RoPE
* and multi-section mRoPE (t/h/w/e).
* @param mrope_used Whether to enable multi-section rotary embedding.
* @param is_imrope Whether to apply interleaved mRoPE rules.
* @param indep_sects Whether each dimension runs independent exponent
* resets based on @p sections.
*/
static void aclnn_cache_init(ggml_backend_cann_context & ctx,
ggml_tensor * dst,
float * corr_dims,
float ext_factor,
float theta_scale,
float freq_scale,
float attn_factor,
bool is_neox) {
static void aclnn_rope_cache_init(ggml_backend_cann_context & ctx,
ggml_tensor * dst,
float * corr_dims,
float ext_factor,
float theta_scale,
float freq_scale,
float attn_factor,
bool is_neox,
int sections[4],
bool mrope_used,
bool is_imrope,
bool indep_sects) {
ggml_tensor * src0 = dst->src[0]; // input
ggml_tensor * src1 = dst->src[1]; // position
ggml_tensor * src2 = dst->src[2]; // freq_factors
if (src2 == nullptr && ctx.rope_cache.cached && ctx.rope_cache.ext_factor == ext_factor &&
ctx.rope_cache.theta_scale == theta_scale && ctx.rope_cache.freq_scale == freq_scale &&
ctx.rope_cache.attn_factor == attn_factor && ctx.rope_cache.is_neox == is_neox) {
int64_t theta_scale_length = src0->ne[0] / 2;
int64_t position_length = dst->ne[2];
// TODO: check theta_scale_length and position_length.
if (src2 == nullptr && ctx.rope_cache.cached &&
ctx.rope_cache.equal(theta_scale_length, position_length, ext_factor, theta_scale, freq_scale, attn_factor,
is_neox, indep_sects, mrope_used, is_imrope, sections)) {
// use cache.
return;
}
int64_t theta_scale_length = src0->ne[0] / 2;
int64_t theta_scale_ne[] = { theta_scale_length, 1, 1, 1 };
size_t theta_scale_nb[] = { sizeof(float), sizeof(float), sizeof(float), theta_scale_length * sizeof(float) };
// Step0: calculate tensor shape.
int64_t theta_scale_ne[] = { theta_scale_length, 1, 1, 1 };
size_t theta_scale_nb[] = { sizeof(float), theta_scale_length * sizeof(float), theta_scale_length * sizeof(float),
theta_scale_length * sizeof(float) };
GGML_ASSERT(src1->type == GGML_TYPE_I32);
int64_t position_length = src1->ne[0];
int64_t position_ne[] = { 1, 1, position_length, 1 };
size_t position_nb[] = { sizeof(int32_t), sizeof(int32_t), sizeof(int32_t), sizeof(int32_t) * position_length };
int64_t position_ne[] = { 1, 1, position_length, 1 };
size_t position_nb[] = { sizeof(int32_t), sizeof(int32_t), sizeof(int32_t), sizeof(int32_t) * position_length };
int64_t theta_ne[] = { theta_scale_length, 1, position_length, 1 };
size_t theta_nb[GGML_MAX_DIMS];
theta_nb[0] = sizeof(float);
int64_t cache_ne[] = { theta_scale_length, 1, position_length, 1 };
size_t cache_nb[GGML_MAX_DIMS];
cache_nb[0] = sizeof(float);
for (int i = 1; i < GGML_MAX_DIMS; i++) {
theta_nb[i] = theta_nb[i - 1] * theta_ne[i - 1];
cache_nb[i] = cache_nb[i - 1] * cache_ne[i - 1];
}
// theta_scale arange, [0,1,...,ne00/2 - 1]
// Step1: Compute the coefficient of theta. During cache_init, aside from
// (1) multiplying by the position,
// (2) dividing by freq_factors,
// (3) computing the sine and cosine,
// the parameters involved in the computation generally do not change across calls,
// so this part of the result can be computed once and then cached.
// Step1.1: prepare the theta_scale exponent; if this exponent is updated, theta_scale_tensor must be updated as well.
acl_tensor_ptr acl_theta_scale_tensor;
// cache theta scale
if (ctx.rope_cache.theta_scale_length != theta_scale_length ||
// theta_scale and freq_scale should not change during the current token inference process,
// so we can directly use == here instead of comparing the absolute difference.
ctx.rope_cache.theta_scale != theta_scale || ctx.rope_cache.freq_scale != freq_scale) {
ctx.rope_cache.theta_scale_length = theta_scale_length;
bool theta_scale_updated = false;
if (ctx.rope_cache.theta_scale_length != theta_scale_length || ctx.rope_cache.theta_scale != theta_scale ||
ctx.rope_cache.indep_sects != indep_sects) {
theta_scale_updated = true;
if (ctx.rope_cache.theta_scale_exp_host != nullptr) {
free(ctx.rope_cache.theta_scale_exp_host);
}
ctx.rope_cache.theta_scale_exp_host = (float *) malloc(theta_scale_length * sizeof(float));
GGML_ASSERT(ctx.rope_cache.theta_scale_exp_host != nullptr);
if (!indep_sects) {
ctx.rope_cache.theta_scale_exp_host[0] = 1;
for (int i = 1; i < theta_scale_length; i++) {
ctx.rope_cache.theta_scale_exp_host[i] = ctx.rope_cache.theta_scale_exp_host[i - 1] * theta_scale;
}
} else {
int sect_dims = sections[0] + sections[1] + sections[2] + sections[3];
int sec_w = sections[1] + sections[0];
int sec_e = sections[2] + sec_w;
ctx.rope_cache.theta_scale_exp_host[0] = 1;
for (int i = 1; i < theta_scale_length; i++) {
int sector = i % sect_dims;
if (sector == 0 || sector == sections[0] || sector == sec_w || sector == sec_e) {
ctx.rope_cache.theta_scale_exp_host[i] = 1;
continue;
}
ctx.rope_cache.theta_scale_exp_host[i] = ctx.rope_cache.theta_scale_exp_host[i - 1] * theta_scale;
}
}
if (ctx.rope_cache.theta_scale_cache != nullptr) {
ACL_CHECK(aclrtFree(ctx.rope_cache.theta_scale_cache));
@ -2286,74 +2328,138 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx,
ACL_CHECK(aclrtMalloc(&ctx.rope_cache.theta_scale_cache, theta_scale_length * sizeof(float),
ACL_MEM_MALLOC_HUGE_FIRST));
ACL_CHECK(aclrtMemcpyAsync(ctx.rope_cache.theta_scale_cache, theta_scale_length * sizeof(float),
ctx.rope_cache.theta_scale_exp_host, theta_scale_length * sizeof(float),
ACL_MEMCPY_HOST_TO_DEVICE, ctx.stream()));
acl_theta_scale_tensor = ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float),
theta_scale_ne, theta_scale_nb, 1);
}
float start = 0;
float step = 1;
float stop = theta_scale_length;
float n_elements = theta_scale_length;
aclnn_arange(ctx, acl_theta_scale_tensor.get(), start, stop, step, n_elements);
// Step1.2: prepare rope_yarn_ramp; if this part is updated, theta_scale_tensor must be updated as well.
bool yarn_ramp_tensor_updated = false;
ggml_cann_pool_alloc yarn_ramp_allocator(ctx.pool());
acl_tensor_ptr acl_yarn_ramp_tensor;
if (ext_factor != 0 &&
// TODO: check more parameters.
(ctx.rope_cache.theta_scale_length != theta_scale_length || ctx.rope_cache.freq_scale != freq_scale)) {
yarn_ramp_tensor_updated = true;
ggml_cann_pool_alloc yarn_ramp_allocator(ctx.pool());
acl_tensor_ptr acl_yarn_ramp_tensor;
if (ext_factor != 0) {
// -rope_yarn_ramp
// const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
// return MIN(1, MAX(0, y)) - 1;
yarn_ramp_allocator.alloc(theta_scale_length * sizeof(float));
void * yarn_ramp_buffer = yarn_ramp_allocator.get();
acl_yarn_ramp_tensor =
ggml_cann_create_tensor(yarn_ramp_buffer, ACL_FLOAT, sizeof(float), theta_scale_ne, theta_scale_nb, 1);
float zero_value = 0, one_value = 1;
float denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]);
acl_scalar_ptr low = ggml_cann_create_scalar(&corr_dims[0], aclDataType::ACL_FLOAT);
acl_scalar_ptr zero = ggml_cann_create_scalar(&zero_value, aclDataType::ACL_FLOAT);
acl_scalar_ptr one = ggml_cann_create_scalar(&one_value, aclDataType::ACL_FLOAT);
acl_scalar_ptr denom_safe = ggml_cann_create_scalar(&denom_safe_value, aclDataType::ACL_FLOAT);
acl_scalar_ptr ext_factor_sc = ggml_cann_create_scalar(&ext_factor, aclDataType::ACL_FLOAT);
// -rope_yarn_ramp
// const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
// return MIN(1, MAX(0, y)) - 1;
yarn_ramp_allocator.alloc(theta_scale_length * sizeof(float));
void * yarn_ramp_buffer = yarn_ramp_allocator.get();
acl_yarn_ramp_tensor =
ggml_cann_create_tensor(yarn_ramp_buffer, ACL_FLOAT, sizeof(float), theta_scale_ne, theta_scale_nb, 1);
float zero_value = 0, one_value = 1;
float denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]);
acl_scalar_ptr low = ggml_cann_create_scalar(&corr_dims[0], aclDataType::ACL_FLOAT);
acl_scalar_ptr zero = ggml_cann_create_scalar(&zero_value, aclDataType::ACL_FLOAT);
acl_scalar_ptr one = ggml_cann_create_scalar(&one_value, aclDataType::ACL_FLOAT);
acl_scalar_ptr denom_safe = ggml_cann_create_scalar(&denom_safe_value, aclDataType::ACL_FLOAT);
acl_scalar_ptr ext_factor_sc = ggml_cann_create_scalar(&ext_factor, aclDataType::ACL_FLOAT);
GGML_CANN_CALL_ACLNN_OP(ctx, Subs, acl_theta_scale_tensor.get(), low.get(), one.get(),
acl_yarn_ramp_tensor.get());
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceDivs, acl_yarn_ramp_tensor.get(), denom_safe.get());
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceThreshold, acl_yarn_ramp_tensor.get(), zero.get(), zero.get());
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceClampMax, acl_yarn_ramp_tensor.get(), one.get());
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceSubs, acl_yarn_ramp_tensor.get(), one.get(), one.get());
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor.get(), ext_factor_sc.get());
aclnn_arange(ctx, acl_yarn_ramp_tensor.get(), 0, theta_scale_length, 1, theta_scale_length);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceSubs, acl_yarn_ramp_tensor.get(), low.get(), one.get());
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceDivs, acl_yarn_ramp_tensor.get(), denom_safe.get());
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceThreshold, acl_yarn_ramp_tensor.get(), zero.get(), zero.get());
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceClampMax, acl_yarn_ramp_tensor.get(), one.get());
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceSubs, acl_yarn_ramp_tensor.get(), one.get(), one.get());
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor.get(), ext_factor_sc.get());
// theta_interp = freq_scale * theta_extrap;
// theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
// theta = freq_scale * theta_extrap * (1 - ramp_mix) + theta_extrap * ramp_mix;
// theta = freq_scale * theta_extrap - freq_scale * theta_extrap * ramp_mix + theta_extrap * ramp_mix;
// theta = theta_extrap * (freq_scale - freq_scale * ramp_mix + ramp_mix);
//
// we cache (freq_scale - freq_scale * ramp_mix + ramp_mix); since the rope_yarn_ramp
// computed here is negated, this is equivalent to caching freq_scale + (freq_scale - 1) * ramp_mix
float freq_scale_1 = freq_scale - 1;
acl_scalar_ptr freq_scale_sc = ggml_cann_create_scalar(&freq_scale, aclDataType::ACL_FLOAT);
acl_scalar_ptr freq_scale_1_sc = ggml_cann_create_scalar(&freq_scale_1, aclDataType::ACL_FLOAT);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor.get(), freq_scale_1_sc.get());
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdds, acl_yarn_ramp_tensor.get(), freq_scale_sc.get(), one.get());
}
// theta_interp = freq_scale * theta_extrap;
// theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
// theta = freq_scale * theta_extrap * (1 - ramp_mix) + theta_extrap * ramp_mix;
// theta = freq_scale * theta_extrap - freq_scale * theta_extrap * ramp_mix + theta_extrap * ramp_mix;
// theta = theta_extrap * (freq_scale - freq_scale * ramp_mix + ramp_mix);
//
// we cache (freq_scale - freq_scale * ramp_mix + ramp_mix); since the rope_yarn_ramp
// computed here is negated, this is equivalent to caching freq_scale + (freq_scale - 1) * ramp_mix
float freq_scale_1 = freq_scale - 1;
acl_scalar_ptr freq_scale_sc = ggml_cann_create_scalar(&freq_scale, aclDataType::ACL_FLOAT);
acl_scalar_ptr freq_scale_1_sc = ggml_cann_create_scalar(&freq_scale_1, aclDataType::ACL_FLOAT);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor.get(), freq_scale_1_sc.get());
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdds, acl_yarn_ramp_tensor.get(), freq_scale_sc.get(), one.get());
}
// power
acl_scalar_ptr acl_theta_scale = ggml_cann_create_scalar(&theta_scale, aclDataType::ACL_FLOAT);
GGML_CANN_CALL_ACLNN_OP(ctx, PowScalarTensor, acl_theta_scale.get(), acl_theta_scale_tensor.get(),
acl_theta_scale_tensor.get());
if (ext_factor != 0) {
// Step 1.3: update theta_scale_tensor according to ext_factor or freq_scale.
if (ext_factor != 0) {
if (theta_scale_updated || yarn_ramp_tensor_updated) {
theta_scale_updated = true;
aclnn_mul(ctx, acl_theta_scale_tensor.get(), acl_yarn_ramp_tensor.get());
} else if (freq_scale != 1) {
aclnn_muls(ctx, acl_theta_scale_tensor.get(), freq_scale, nullptr, true);
}
} else {
// use cache
if (freq_scale != 1 && (ctx.rope_cache.freq_scale != freq_scale || theta_scale_updated)) {
theta_scale_updated = true;
aclnn_muls(ctx, acl_theta_scale_tensor.get(), freq_scale, nullptr, true);
}
}
// Nothing changed, use cache.
if (!theta_scale_updated) {
acl_theta_scale_tensor = ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float),
theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
}
// Step 1.4: prepare the select index if mrope is used
acl_tensor_ptr position_select_index_tensor;
if (mrope_used) {
if (ctx.rope_cache.sections[0] != sections[0] || ctx.rope_cache.sections[1] != sections[1] ||
ctx.rope_cache.sections[2] != sections[2] || ctx.rope_cache.sections[3] != sections[3] ||
ctx.rope_cache.theta_scale_length != theta_scale_length || ctx.rope_cache.is_imrope != is_imrope) {
if (ctx.rope_cache.position_select_index_host != nullptr) {
free(ctx.rope_cache.position_select_index_host);
}
ctx.rope_cache.position_select_index_host = (int *) malloc(theta_scale_length * sizeof(int));
GGML_ASSERT(ctx.rope_cache.position_select_index_host != nullptr);
int sect_dims = sections[0] + sections[1] + sections[2] + sections[3];
int sec_w = sections[1] + sections[0];
int sec_e = sections[2] + sec_w;
// t,h,w,e
for (int i = 0; i < theta_scale_length; i++) {
int sector = i % sect_dims;
if (is_imrope) { // qwen3vl applies interleaved mrope
if (sector % 3 == 1 && sector < 3 * sections[1]) {
ctx.rope_cache.position_select_index_host[i] = 1;
} else if (sector % 3 == 2 && sector < 3 * sections[2]) {
ctx.rope_cache.position_select_index_host[i] = 2;
} else if (sector % 3 == 0 && sector < 3 * sections[0]) {
ctx.rope_cache.position_select_index_host[i] = 0;
} else {
ctx.rope_cache.position_select_index_host[i] = 3;
}
} else {
if (sector >= sections[0] && sector < sec_w) {
ctx.rope_cache.position_select_index_host[i] = 1;
} else if (sector >= sec_w && sector < sec_e) {
ctx.rope_cache.position_select_index_host[i] = 2;
} else if (sector >= sec_e) {
ctx.rope_cache.position_select_index_host[i] = 3;
} else {
ctx.rope_cache.position_select_index_host[i] = 0;
}
}
}
if (ctx.rope_cache.position_select_index != nullptr) {
ACL_CHECK(aclrtFree(ctx.rope_cache.position_select_index));
}
ACL_CHECK(aclrtMalloc(&ctx.rope_cache.position_select_index, theta_scale_length * sizeof(int),
ACL_MEM_MALLOC_HUGE_FIRST));
ACL_CHECK(aclrtMemcpyAsync(ctx.rope_cache.position_select_index, theta_scale_length * sizeof(int),
ctx.rope_cache.position_select_index_host, theta_scale_length * sizeof(int),
ACL_MEMCPY_HOST_TO_DEVICE, ctx.stream()));
}
position_select_index_tensor = ggml_cann_create_tensor(ctx.rope_cache.position_select_index, ACL_INT32,
sizeof(int), theta_scale_ne, theta_scale_nb, 1);
}
// Step2: divide by freq_factors
ggml_cann_pool_alloc freq_fac_res_allocator(ctx.pool());
// freq_factors
if (src2) {
freq_fac_res_allocator.alloc(theta_scale_length * sizeof(float));
void * freq_fac_res_ptr = freq_fac_res_allocator.get();
@ -2366,6 +2472,85 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx,
std::swap(acl_theta_scale_tensor, acl_freq_fac_res_tensor);
}
// Step3: prepare position_tensor
acl_tensor_ptr acl_position_tensor;
ggml_cann_pool_alloc mrope_position_acllocator(ctx.pool());
if (mrope_used) {
// Step3.1: select the current positions;
// position:
// pos1: [[0,  1,  2,  3 ],
// pos2:  [4,  5,  6,  7 ],
// pos3:  [8,  9,  10, 11],
// pos4:  [12, 13, 14, 15]]
//
// select index = [0, 1, 2, 2, 1, 0]
//
// selected_tensor:
// [[0, 1,  2,  3 ],
//  [4, 5,  6,  7 ],
//  [8, 9,  10, 11],
//  [8, 9,  10, 11],
//  [4, 5,  6,  7 ],
//  [0, 1,  2,  3 ]]
//
// transpose, from [seq_len:dims] to [dims:seq_len]
// [[0, 4, 8,  8,  4, 0],
//  [1, 5, 9,  9,  5, 1],
//  [2, 6, 10, 10, 6, 2],
//  [3, 7, 11, 11, 7, 3]]
//
// multiply by theta_scale_tensor
// [theta_scale^0, theta_scale^1, ..., theta_scale^n]
int64_t mrope_position_ne[] = { position_length, 4 };
size_t mrope_position_nb[] = { sizeof(int), position_length * sizeof(int) };
acl_tensor_ptr mrope_position =
ggml_cann_create_tensor(src1->data, ggml_cann_type_mapping(src1->type), ggml_type_size(src1->type),
mrope_position_ne, mrope_position_nb, 2);
// selected position tensor's shape is a transpose of cache tensor.
int64_t selected_position_ne[] = { position_length, theta_scale_length };
size_t selected_position_nb[] = { sizeof(float), position_length * sizeof(float) };
mrope_position_acllocator.alloc(theta_scale_length * position_length * sizeof(float));
void * mrope_position_buffer = mrope_position_acllocator.get();
acl_position_tensor =
ggml_cann_create_tensor(mrope_position_buffer, ggml_cann_type_mapping(src1->type),
ggml_type_size(src1->type), selected_position_ne, selected_position_nb, 2);
GGML_CANN_CALL_ACLNN_OP(ctx, IndexSelect, mrope_position.get(), 0, position_select_index_tensor.get(),
acl_position_tensor.get());
// transpose
int64_t transposed_ne[] = { position_length, 1, theta_scale_length, 1 };
size_t transposed_nb[GGML_MAX_DIMS];
transposed_nb[0] = sizeof(float);
for (int i = 1; i < GGML_MAX_DIMS; i++) {
transposed_nb[i] = transposed_nb[i - 1] * transposed_ne[i - 1];
}
std::swap(transposed_ne[0], transposed_ne[2]);
std::swap(transposed_nb[0], transposed_nb[2]);
acl_position_tensor =
ggml_cann_create_tensor(mrope_position_buffer, ggml_cann_type_mapping(src1->type),
ggml_type_size(src1->type), transposed_ne, transposed_nb, GGML_MAX_DIMS);
} else {
// auto bcast.
acl_position_tensor =
ggml_cann_create_tensor(src1->data, ggml_cann_type_mapping(src1->type), ggml_type_size(src1->type),
position_ne, position_nb, GGML_MAX_DIMS);
}
// Step4: multiply by the position
int64_t theta_length = theta_scale_length * position_length;
ggml_cann_pool_alloc theta_allocator(ctx.pool(), theta_length * sizeof(float));
void * theta_buffer = theta_allocator.get();
acl_tensor_ptr acl_theta_tensor =
ggml_cann_create_tensor(theta_buffer, ACL_FLOAT, sizeof(float), cache_ne, cache_nb, GGML_MAX_DIMS);
aclnn_mul(ctx, acl_position_tensor.get(), acl_theta_scale_tensor.get(), acl_theta_tensor.get());
// Step5: calculate sin cos.
// init sin_repeat && cos_repeat, only to accelerate first layer on each device
if (position_length > ctx.rope_cache.position_length) {
ctx.rope_cache.position_length = position_length;
@ -2382,44 +2567,30 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx,
aclrtMalloc(&ctx.rope_cache.cos_cache, repeat_theta_length * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST));
}
// position
acl_tensor_ptr acl_position_tensor =
ggml_cann_create_tensor(src1->data, ggml_cann_type_mapping(src1->type), ggml_type_size(src1->type), position_ne,
position_nb, GGML_MAX_DIMS);
// power * position
int64_t theta_length = theta_scale_length * position_length;
ggml_cann_pool_alloc theta_allocator(ctx.pool(), theta_length * sizeof(float));
void * theta_buffer = theta_allocator.get();
acl_tensor_ptr acl_theta_tensor =
ggml_cann_create_tensor(theta_buffer, ACL_FLOAT, sizeof(float), theta_ne, theta_nb, GGML_MAX_DIMS);
aclnn_mul(ctx, acl_position_tensor.get(), acl_theta_scale_tensor.get(), acl_theta_tensor.get());
// sin/cos
ggml_cann_pool_alloc sin_allocator(ctx.pool(), theta_length * sizeof(float));
void * sin_buffer = sin_allocator.get();
acl_tensor_ptr acl_sin_tensor =
ggml_cann_create_tensor(sin_buffer, ACL_FLOAT, sizeof(float), theta_ne, theta_nb, GGML_MAX_DIMS, ACL_FORMAT_ND);
ggml_cann_create_tensor(sin_buffer, ACL_FLOAT, sizeof(float), cache_ne, cache_nb, GGML_MAX_DIMS, ACL_FORMAT_ND);
aclnn_sin(ctx, acl_theta_tensor.get(), acl_sin_tensor.get());
ggml_cann_pool_alloc cos_allocator(ctx.pool(), theta_length * sizeof(float));
void * cos_buffer = cos_allocator.get();
acl_tensor_ptr acl_cos_tensor =
ggml_cann_create_tensor(cos_buffer, ACL_FLOAT, sizeof(float), theta_ne, theta_nb, GGML_MAX_DIMS, ACL_FORMAT_ND);
ggml_cann_create_tensor(cos_buffer, ACL_FLOAT, sizeof(float), cache_ne, cache_nb, GGML_MAX_DIMS, ACL_FORMAT_ND);
aclnn_cos(ctx, acl_theta_tensor.get(), acl_cos_tensor.get());
if (ext_factor != 0) {
attn_factor *= 1.0f + 0.1f * logf(1.0f / freq_scale);
}
// attn_factor
// Step 5.1: multiply by attn_factor
if (attn_factor != 1) {
aclnn_muls(ctx, acl_sin_tensor.get(), attn_factor, nullptr, true);
aclnn_muls(ctx, acl_cos_tensor.get(), attn_factor, nullptr, true);
}
int64_t sin_reshape_ne[4] = { src0->ne[0], 1, src0->ne[2], 1 };
int64_t sin_reshape_ne[4] = { src0->ne[0], 1, dst->ne[2], 1 };
size_t sin_reshape_nb[GGML_MAX_DIMS];
sin_reshape_nb[0] = sizeof(float);
for (int i = 1; i < GGML_MAX_DIMS; i++) {
@ -2430,8 +2601,9 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx,
acl_tensor_ptr acl_cos_repeat_tensor = ggml_cann_create_tensor(ctx.rope_cache.cos_cache, ACL_FLOAT, sizeof(float),
sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
// repeat
// Step 6: repeat
if (is_neox) {
// [sinθ1, sinθ2, ..., sinθn, sinθ1, sinθ2, ..., sinθn]
int64_t repeatsArray[] = { 1, 1, 1, 2 };
aclnn_repeat(ctx, acl_sin_tensor.get(), acl_sin_repeat_tensor.get(), repeatsArray);
aclnn_repeat(ctx, acl_cos_tensor.get(), acl_cos_repeat_tensor.get(), repeatsArray);
@ -2439,17 +2611,15 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx,
int64_t num_repeats = 2;
int64_t dim = 3;
int64_t output_size = theta_scale_length * num_repeats;
// [sinθ1, sinθ1, sinθ2, sinθ2, ..., sinθn, sinθn]
aclnn_repeat_interleave(ctx, acl_sin_tensor.get(), acl_sin_repeat_tensor.get(), dim, num_repeats, output_size);
aclnn_repeat_interleave(ctx, acl_cos_tensor.get(), acl_cos_repeat_tensor.get(), dim, num_repeats, output_size);
}
// Layers after the first reuse the cache.
ctx.rope_cache.cached = true;
ctx.rope_cache.ext_factor = ext_factor;
ctx.rope_cache.theta_scale = theta_scale;
ctx.rope_cache.freq_scale = freq_scale;
ctx.rope_cache.attn_factor = attn_factor;
ctx.rope_cache.is_neox = is_neox;
// Update cached value.
ctx.rope_cache.cached = true;
ctx.rope_cache.set(theta_scale_length, position_length, ext_factor, theta_scale, freq_scale, attn_factor, is_neox,
indep_sects, mrope_used, is_imrope, sections);
}
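
Aside: a minimal host-side sketch (plain C++, not the ACL API) of the two expansion layouts produced in Step 6; a trailing repeat factor of 2 tiles the whole θ row, while repeat_interleave with 2 repeats duplicates each element in place:

#include <cstdio>
#include <vector>

int main() {
    const std::vector<float> sin_theta = {0.1f, 0.2f, 0.3f};

    // Tiled expansion (what a trailing repeat factor of 2 produces):
    // [s1, s2, s3, s1, s2, s3]
    std::vector<float> tiled;
    for (int r = 0; r < 2; ++r)
        for (float s : sin_theta) tiled.push_back(s);

    // Interleaved expansion (repeat_interleave with 2 repeats):
    // [s1, s1, s2, s2, s3, s3]
    std::vector<float> interleaved;
    for (float s : sin_theta)
        for (int r = 0; r < 2; ++r) interleaved.push_back(s);

    for (float s : tiled) printf("%.1f ", s);
    printf("\n");
    for (float s : interleaved) printf("%.1f ", s);
    printf("\n");
    return 0;
}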
#ifdef __cplusplus
@ -2475,6 +2645,7 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
// param
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
int sections[4];
// const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2];
@ -2483,12 +2654,13 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
GGML_TENSOR_UNARY_OP_LOCALS
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
memcpy(&sections, (int32_t *) dst->op_params + 11, sizeof(int)*4);
// TODO: n_dims <= ne0
GGML_ASSERT(n_dims == ne0);
@ -2499,10 +2671,25 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
float corr_dims[2];
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE; // qwen3vl applies interleaved mrope
const bool mrope_used = mode & GGML_ROPE_TYPE_MROPE; // ggml_rope_multi, note: also true for vision (24 & 8 == true) and for imrope
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
if (mrope_used) {
GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
}
if (is_vision) {
GGML_ASSERT(n_dims == ne0/2);
}
if (is_imrope || mrope_used) {
is_neox = true;
}
// init ctx.rope_cos/rope_sin cache
aclnn_cache_init(ctx, dst, corr_dims, ext_factor, theta_scale, freq_scale, attn_factor, is_neox);
aclnn_rope_cache_init(ctx, dst, corr_dims, ext_factor, theta_scale, freq_scale, attn_factor, is_neox, sections, mrope_used, is_imrope, is_vision);
int64_t sin_reshape_ne[4] = { ne00, 1, ne02, 1 };
size_t sin_reshape_nb[GGML_MAX_DIMS];
@ -2658,8 +2845,7 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
return;
#endif
// ggml_mode = 0 --> aclnn_mode = 1
int64_t acl_mode = mode == 0 ? 1 : mode;
int64_t acl_mode = is_neox ? 0 : 1;
switch (src0->type) {
case GGML_TYPE_F32:

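Aside: a standalone sketch of the Step 1.1 host-side exponent table (plain C++; the theta_scale and sections values are illustrative, not from a real model). Without indep_sects the running power grows monotonically; with it, the power resets at every section boundary:

#include <cstdio>
#include <vector>

// Mirrors the theta_scale_exp_host loop above; sections values are
// illustrative, not taken from a real model.
int main() {
    const float theta_scale = 0.5f;          // example base multiplier
    const int   sections[4] = {2, 2, 2, 0};  // t/h/w/e sector sizes
    const int   n           = 12;            // theta_scale_length

    const int sect_dims = sections[0] + sections[1] + sections[2] + sections[3];
    const int sec_w     = sections[0] + sections[1];
    const int sec_e     = sections[2] + sec_w;

    std::vector<float> exp_host(n);
    exp_host[0] = 1.0f;
    for (int i = 1; i < n; ++i) {
        const int sector = i % sect_dims;
        // reset the running power at every section boundary
        if (sector == 0 || sector == sections[0] || sector == sec_w || sector == sec_e) {
            exp_host[i] = 1.0f;
        } else {
            exp_host[i] = exp_host[i - 1] * theta_scale;
        }
    }
    for (float v : exp_host) printf("%g ", v);
    printf("\n");
    return 0;
}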
View File

@ -300,30 +300,92 @@ struct ggml_cann_graph_lru_cache {
struct ggml_cann_rope_cache {
~ggml_cann_rope_cache() {
if (theta_scale_cache != nullptr) {
if (theta_scale_cache) {
ACL_CHECK(aclrtFree(theta_scale_cache));
}
if (sin_cache != nullptr) {
if (sin_cache) {
ACL_CHECK(aclrtFree(sin_cache));
}
if (cos_cache != nullptr) {
if (cos_cache) {
ACL_CHECK(aclrtFree(cos_cache));
}
if (position_select_index) {
ACL_CHECK(aclrtFree(position_select_index));
}
if (theta_scale_exp_host) {
free(theta_scale_exp_host);
}
if (position_select_index_host) {
free(position_select_index_host);
}
}
void * theta_scale_cache = nullptr;
int64_t theta_scale_length = 0;
bool equal(int64_t theta_scale_length,
int64_t position_length,
float ext_factor,
float theta_scale,
float freq_scale,
float attn_factor,
bool is_neox,
bool indep_sects,
bool mrope_used,
bool is_imrope,
int sections[4]) {
return this->theta_scale_length == theta_scale_length && this->position_length == position_length &&
this->ext_factor == ext_factor && this->theta_scale == theta_scale && this->freq_scale == freq_scale &&
this->attn_factor == attn_factor && this->is_neox == is_neox && this->indep_sects == indep_sects &&
this->mrope_used == mrope_used && this->is_imrope == is_imrope && this->sections[0] == sections[0] &&
this->sections[1] == sections[1] && this->sections[2] == sections[2] && this->sections[3] == sections[3];
}
void set(int64_t theta_scale_length,
int64_t position_length,
float ext_factor,
float theta_scale,
float freq_scale,
float attn_factor,
bool is_neox,
bool indep_sects,
bool mrope_used,
bool is_imrope,
int sections[4]) {
this->theta_scale_length = theta_scale_length;
this->position_length = position_length;
this->ext_factor = ext_factor;
this->theta_scale = theta_scale;
this->freq_scale = freq_scale;
this->attn_factor = attn_factor;
this->is_neox = is_neox;
this->indep_sects = indep_sects;
this->mrope_used = mrope_used;
this->is_imrope = is_imrope;
this->sections[0] = sections[0];
this->sections[1] = sections[1];
this->sections[2] = sections[2];
this->sections[3] = sections[3];
}
// memory caches, prepared before inference.
void * theta_scale_cache = nullptr;
float * theta_scale_exp_host = nullptr;
int * position_select_index_host = nullptr;
void * position_select_index = nullptr;
// sin/cos cache, used only to accelerate first layer on each device
void * sin_cache = nullptr;
void * cos_cache = nullptr;
int64_t position_length = 0;
void * sin_cache = nullptr;
void * cos_cache = nullptr;
// Properties to check before reusing the sincos cache
bool cached = false;
float ext_factor = 0.0f;
float theta_scale = 0.0f;
float freq_scale = 0.0f;
float attn_factor = 0.0f;
bool is_neox = false;
int64_t theta_scale_length = 0;
int64_t position_length = 0;
bool cached = false;
float ext_factor = 0.0f;
float theta_scale = 0.0f;
float freq_scale = 0.0f;
float attn_factor = 0.0f;
bool is_neox = false;
bool indep_sects = false;
bool mrope_used = false;
int sections[4] = { 0, 0, 0, 0 };
bool is_imrope = false;
};
struct ggml_cann_tensor_cache {

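Aside: the position_select_index_host table stored above maps each θ dimension to one of the four position rows (t/h/w/e). A host-only sketch with hypothetical section sizes, covering both standard mRoPE and interleaved (imrope) order:

#include <cstdio>

int main() {
    const int sections[4] = {2, 1, 1, 0};  // illustrative t/h/w/e sizes
    const int n           = 8;             // theta dims
    const int sect_dims   = sections[0] + sections[1] + sections[2] + sections[3];
    const int sec_w       = sections[0] + sections[1];
    const int sec_e       = sections[2] + sec_w;

    for (int imrope = 0; imrope <= 1; ++imrope) {
        printf("%s:", imrope ? "imrope" : "mrope ");
        for (int i = 0; i < n; ++i) {
            const int sector = i % sect_dims;
            int idx;
            if (imrope) {  // interleaved: t/h/w cycle through sector % 3
                if      (sector % 3 == 1 && sector < 3 * sections[1]) idx = 1;
                else if (sector % 3 == 2 && sector < 3 * sections[2]) idx = 2;
                else if (sector % 3 == 0 && sector < 3 * sections[0]) idx = 0;
                else                                                  idx = 3;
            } else {       // contiguous sections: [t | h | w | e]
                if      (sector >= sec_e)       idx = 3;
                else if (sector >= sec_w)       idx = 2;
                else if (sector >= sections[0]) idx = 1;
                else                            idx = 0;
            }
            printf(" %d", idx);
        }
        printf("\n");
    }
    return 0;
}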
View File

@ -2480,13 +2480,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
return false;
}
const int mode = ((const int32_t *) op->op_params)[2];
if (mode & GGML_ROPE_TYPE_MROPE) {
return false;
}
if (mode & GGML_ROPE_TYPE_VISION) {
return false;
}
if (op->src[0]->ne[0] > 896) {
return false;
}

View File

@ -224,7 +224,8 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
include(CheckCXXSourceCompiles)
set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS})
set(CMAKE_REQUIRED_FLAGS "${ARCH_FLAGS}")
string(REPLACE ";" " " ARCH_FLAGS_STR "${ARCH_FLAGS}")
set(CMAKE_REQUIRED_FLAGS "${ARCH_FLAGS_STR}")
foreach(feature DOTPROD SVE MATMUL_INT8 FMA FP16_VECTOR_ARITHMETIC SME)
set(ARM_FEATURE "HAVE_${feature}")
check_cxx_source_compiles(

View File

@ -397,119 +397,118 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const
}
inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y, const ggml_fp16_t * GGML_RESTRICT x, const float v) {
#if defined(GGML_SIMD)
#if defined(__ARM_FEATURE_SVE)
const int sve_register_length = svcntb() * 8;
const int ggml_f16_epr = sve_register_length / 16;
const int ggml_f16_step = 8 * ggml_f16_epr;
#if defined(GGML_SIMD) && defined(__ARM_FEATURE_SVE)
const int sve_register_length = svcntb() * 8;
const int ggml_f16_epr = sve_register_length / 16;
const int ggml_f16_step = 8 * ggml_f16_epr;
GGML_F16x_VEC vx = GGML_F16x_VEC_SET1(v);
GGML_F16x_VEC vx = GGML_F16x_VEC_SET1(v);
const int np= (n & ~(ggml_f16_step - 1));
int np = (n & ~(ggml_f16_step - 1));
svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8;
svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8;
for (int i = 0; i < np; i += ggml_f16_step) {
ax1 = GGML_F16x_VEC_LOAD(x + i + 0 * ggml_f16_epr, 0);
ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0);
ay1 = GGML_F16x_VEC_FMA(ay1, ax1, vx);
svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8;
svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8;
for (int i = 0; i < np; i += ggml_f16_step) {
ax1 = GGML_F16x_VEC_LOAD(x + i + 0 * ggml_f16_epr, 0);
ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0);
ay1 = GGML_F16x_VEC_FMA(ay1, ax1, vx);
GGML_F16x_VEC_STORE(y + i + 0 * ggml_f16_epr, ay1, 0);
GGML_F16x_VEC_STORE(y + i + 0 * ggml_f16_epr, ay1, 0);
ax2 = GGML_F16x_VEC_LOAD(x + i + 1 * ggml_f16_epr, 1);
ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1);
ay2 = GGML_F16x_VEC_FMA(ay2, ax2, vx);
ax2 = GGML_F16x_VEC_LOAD(x + i + 1 * ggml_f16_epr, 1);
ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1);
ay2 = GGML_F16x_VEC_FMA(ay2, ax2, vx);
GGML_F16x_VEC_STORE(y + i + 1 * ggml_f16_epr, ay2, 1);
GGML_F16x_VEC_STORE(y + i + 1 * ggml_f16_epr, ay2, 1);
ax3 = GGML_F16x_VEC_LOAD(x + i + 2 * ggml_f16_epr, 2);
ay3 = GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2);
ay3 = GGML_F16x_VEC_FMA(ay3, ax3, vx);
ax3 = GGML_F16x_VEC_LOAD(x + i + 2 * ggml_f16_epr, 2);
ay3 = GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2);
ay3 = GGML_F16x_VEC_FMA(ay3, ax3, vx);
GGML_F16x_VEC_STORE(y + i + 2 * ggml_f16_epr, ay3, 2);
GGML_F16x_VEC_STORE(y + i + 2 * ggml_f16_epr, ay3, 2);
ax4 = GGML_F16x_VEC_LOAD(x + i + 3 * ggml_f16_epr, 3);
ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3);
ay4 = GGML_F16x_VEC_FMA(ay4, ax4, vx);
ax4 = GGML_F16x_VEC_LOAD(x + i + 3 * ggml_f16_epr, 3);
ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3);
ay4 = GGML_F16x_VEC_FMA(ay4, ax4, vx);
GGML_F16x_VEC_STORE(y + i + 3 * ggml_f16_epr, ay4, 3);
GGML_F16x_VEC_STORE(y + i + 3 * ggml_f16_epr, ay4, 3);
ax5 = GGML_F16x_VEC_LOAD(x + i + 4 * ggml_f16_epr, 4);
ay5 = GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4);
ay5 = GGML_F16x_VEC_FMA(ay5, ax5, vx);
ax5 = GGML_F16x_VEC_LOAD(x + i + 4 * ggml_f16_epr, 4);
ay5 = GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4);
ay5 = GGML_F16x_VEC_FMA(ay5, ax5, vx);
GGML_F16x_VEC_STORE(y + i + 4 * ggml_f16_epr, ay5, 4);
GGML_F16x_VEC_STORE(y + i + 4 * ggml_f16_epr, ay5, 4);
ax6 = GGML_F16x_VEC_LOAD(x + i + 5 * ggml_f16_epr, 5);
ay6 = GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5);
ay6 = GGML_F16x_VEC_FMA(ay6, ax6, vx);
ax6 = GGML_F16x_VEC_LOAD(x + i + 5 * ggml_f16_epr, 5);
ay6 = GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5);
ay6 = GGML_F16x_VEC_FMA(ay6, ax6, vx);
GGML_F16x_VEC_STORE(y + i + 5 * ggml_f16_epr, ay6, 5);
GGML_F16x_VEC_STORE(y + i + 5 * ggml_f16_epr, ay6, 5);
ax7 = GGML_F16x_VEC_LOAD(x + i + 6 * ggml_f16_epr, 6);
ay7 = GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6);
ay7 = GGML_F16x_VEC_FMA(ay7, ax7, vx);
ax7 = GGML_F16x_VEC_LOAD(x + i + 6 * ggml_f16_epr, 6);
ay7 = GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6);
ay7 = GGML_F16x_VEC_FMA(ay7, ax7, vx);
GGML_F16x_VEC_STORE(y + i + 6 * ggml_f16_epr, ay7, 6);
GGML_F16x_VEC_STORE(y + i + 6 * ggml_f16_epr, ay7, 6);
ax8 = GGML_F16x_VEC_LOAD(x + i + 7 * ggml_f16_epr, 7);
ay8 = GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7);
ay8 = GGML_F16x_VEC_FMA(ay8, ax8, vx);
ax8 = GGML_F16x_VEC_LOAD(x + i + 7 * ggml_f16_epr, 7);
ay8 = GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7);
ay8 = GGML_F16x_VEC_FMA(ay8, ax8, vx);
GGML_F16x_VEC_STORE(y + i + 7 * ggml_f16_epr, ay8, 7);
GGML_F16x_VEC_STORE(y + i + 7 * ggml_f16_epr, ay8, 7);
}
const int np2 = (n & ~(ggml_f16_epr - 1));
for (int k = np; k < np2; k += ggml_f16_epr) {
svfloat16_t rx = GGML_F16x_VEC_LOAD(x + k, 0);
svfloat16_t ry = GGML_F16x_VEC_LOAD(y + k, 0);
ry = GGML_F16x_VEC_FMA(ry, rx, vx);
GGML_F16x_VEC_STORE(y + k, ry, 0);
}
if (np2 < n) {
svbool_t pg = svwhilelt_b16(np2, n);
svfloat16_t hx = svld1_f16(pg, (const __fp16 *)(x + np2));
svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2));
hy = svmad_f16_x(pg, hx, vx, hy);
svst1_f16(pg, (__fp16 *)(y + np2), hy);
}
np = n;
#elif defined(__riscv_zvfh) // implies __riscv_v_intrinsic
const int np = n;
_Float16 hv = (_Float16)v;
for (int i = 0, avl; i < n; i += avl) {
avl = __riscv_vsetvl_e16m8(n - i);
vfloat16m8_t ax = __riscv_vle16_v_f16m8((const _Float16 *)&x[i], avl);
vfloat16m8_t ay = __riscv_vle16_v_f16m8((_Float16 *)&y[i], avl);
vfloat16m8_t ny = __riscv_vfmadd_vf_f16m8(ax, hv, ay, avl);
__riscv_vse16_v_f16m8((_Float16 *)&y[i], ny, avl);
}
#elif defined(GGML_SIMD)
const int np = (n & ~(GGML_F16_STEP - 1));
GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
GGML_F16_VEC ax[GGML_F16_ARR];
GGML_F16_VEC ay[GGML_F16_ARR];
for (int i = 0; i < np; i += GGML_F16_STEP) {
for (int j = 0; j < GGML_F16_ARR; j++) {
ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx);
GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
}
const int np2 = (n & ~(ggml_f16_epr - 1));
for (int k = np; k < np2; k += ggml_f16_epr) {
svfloat16_t rx = GGML_F16x_VEC_LOAD(x + k, 0);
svfloat16_t ry = GGML_F16x_VEC_LOAD(y + k, 0);
ry = GGML_F16x_VEC_FMA(ry, rx, vx);
GGML_F16x_VEC_STORE(y + k, ry, 0);
}
if (np2 < n) {
svbool_t pg = svwhilelt_b16(np2, n);
svfloat16_t hx = svld1_f16(pg, (const __fp16 *)(x + np2));
svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2));
hy = svmad_f16_x(pg, hx, vx, hy);
svst1_f16(pg, (__fp16 *)(y + np2), hy);
}
#elif defined(__riscv_v_intrinsic)
// todo: RVV impl
// scalar
for (int i = 0; i < n; ++i) {
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v);
}
#else
const int np = (n & ~(GGML_F16_STEP - 1));
GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
GGML_F16_VEC ax[GGML_F16_ARR];
GGML_F16_VEC ay[GGML_F16_ARR];
for (int i = 0; i < np; i += GGML_F16_STEP) {
for (int j = 0; j < GGML_F16_ARR; j++) {
ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx);
GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
}
}
// leftovers
for (int i = np; i < n; ++i) {
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v);
}
#endif
}
#else
// scalar
for (int i = 0; i < n; ++i) {
const int np = 0;
#endif
// leftovers
for (int i = np; i < n; ++i) {
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v);
}
#endif
}
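
Aside: the SVE path above peels the loop three ways: an unrolled multi-vector main loop, a single-vector loop, and a predicated tail. A scalar C++ sketch of the same structure, with a made-up width of 4 standing in for the hardware vector length:

#include <cstdio>

// y[i] += x[i] * v with three-tier loop peeling, mirroring the SVE path:
// unrolled main loop, single-"vector" loop, then a masked/scalar tail.
static void mad_f32(int n, float * y, const float * x, float v) {
    const int epr  = 4;         // stand-in for elements per vector register
    const int step = 2 * epr;   // unroll factor of 2 (the SVE code uses 8)

    const int np = n & ~(step - 1);
    for (int i = 0; i < np; i += step) {      // main unrolled loop
        for (int j = 0; j < step; ++j) y[i + j] += x[i + j] * v;
    }
    const int np2 = n & ~(epr - 1);
    for (int i = np; i < np2; i += epr) {     // one vector at a time
        for (int j = 0; j < epr; ++j) y[i + j] += x[i + j] * v;
    }
    for (int i = np2; i < n; ++i) {           // predicated tail, here scalar
        y[i] += x[i] * v;
    }
}

int main() {
    float x[11], y[11];
    for (int i = 0; i < 11; ++i) { x[i] = 1.0f; y[i] = float(i); }
    mad_f32(11, y, x, 0.5f);
    for (float f : y) printf("%.1f ", f);
    printf("\n");
    return 0;
}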
// xs and vs are byte strides of x and v

View File

@ -437,18 +437,27 @@ namespace ggml_cuda_mma {
xi[0] = xs[0];
}
#elif defined(AMD_WMMA_AVAILABLE)
if constexpr (I == 16 && J == 4) {
int64_t * xi = (int64_t *) t.x;
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
xi[0] = xs[0];
}else if constexpr (I == 16 && J == 8) {
int64_t * xi = (int64_t *) t.x;
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I));
xi[0] = xs[0];
if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
const int64_t * xs1 = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I) + 2);
xi[1] = xs1[0];
}else{
} else if constexpr (std::is_same_v<T, int>) {
if constexpr (I == 16 && J == 4) {
int64_t * xi = (int64_t *) t.x;
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
xi[0] = xs[0];
} else if constexpr (I == 16 && J == 8) {
int64_t * xi = (int64_t *) t.x;
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I));
xi[0] = xs[0];
const int64_t * xs1 = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I) + 2);
xi[1] = xs1[0];
} else {
NO_DEVICE_CODE;
}
} else {
NO_DEVICE_CODE;
}
#else

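Aside: the reworked branch dispatches on the tile element type at compile time. A minimal sketch of the same if constexpr / std::is_same_v pattern in generic C++ (the types and print statements are illustrative, not the CUDA tile API):

#include <cstdio>
#include <type_traits>

template <typename T>
void load_tile(const T * src) {
    if constexpr (std::is_same_v<T, float>) {
        // wide-copy path for floating-point tiles
        printf("float path: %f\n", double(src[0]));
    } else if constexpr (std::is_same_v<T, int>) {
        // integer path with its own addressing
        printf("int path: %d\n", src[0]);
    } else {
        // dependent-false assert: fires only if instantiated
        static_assert(!sizeof(T), "unsupported tile element type");
    }
}

int main() {
    float f = 1.5f;
    int   i = 7;
    load_tile(&f);
    load_tile(&i);
    return 0;
}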
View File

@ -3701,7 +3701,7 @@ static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int
const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y);
const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type);
const size_t nbs_ids = mmq_x*sizeof(int);
const size_t nbs_x = (turing_mma_available(cc) || amd_mfma_available(cc)) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
const size_t nbs_x = (turing_mma_available(cc) || amd_mfma_available(cc) || amd_wmma_available(cc)) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
const size_t nbs_y = mmq_x*sizeof(block_q8_1_mmq);
return nbs_ids + nbs_x + GGML_PAD(nbs_y, nwarps*warp_size*sizeof(int));
}

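Aside: the return expression pads the Q8 staging bytes to a multiple of nwarps*warp_size*sizeof(int). A small sketch of that round-up, assuming GGML_PAD has its usual align-up definition:

#include <cstdio>
#include <cstddef>

// Round x up to the next multiple of n (n a power of two), which is what
// GGML_PAD is assumed to do in the expression above.
static size_t pad(size_t x, size_t n) {
    return (x + n - 1) & ~(n - 1);
}

int main() {
    const size_t nbs_y  = 1000;                             // example byte count
    const size_t nwarps = 8, warp_size = 32;
    const size_t gran   = nwarps * warp_size * sizeof(int); // 1024
    printf("padded: %zu\n", pad(nbs_y, gran));              // -> 1024
    return 0;
}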
View File

@ -409,6 +409,7 @@ enum shader_reduction_mode {
// argsort pipelines for up to 1<<10 invocations per workgroup
static constexpr uint32_t num_argsort_pipelines = 11;
static constexpr uint32_t num_topk_moe_pipelines = 10;
static constexpr uint32_t num_topk_pipelines = 11;
static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
@ -515,6 +516,7 @@ struct vk_device_struct {
bool single_queue;
bool support_async;
uint32_t subgroup_size;
uint32_t subgroup_size_log2;
uint32_t shader_core_count;
bool uma;
bool prefer_host_memory;
@ -704,7 +706,9 @@ struct vk_device_struct {
vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16;
vk_pipeline pipeline_argsort_f32[num_argsort_pipelines];
vk_pipeline pipeline_argsort_large_f32[num_argsort_pipelines];
vk_pipeline pipeline_topk_f32[num_topk_pipelines];
vk_pipeline pipeline_sum_rows_f32;
vk_pipeline pipeline_cumsum_f32;
vk_pipeline pipeline_argmax_f32;
vk_pipeline pipeline_count_equal_i32;
vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
@ -1204,6 +1208,15 @@ struct vk_op_argsort_push_constants {
uint32_t inner_end;
};
struct vk_op_topk_push_constants {
uint32_t orig_ncols;
uint32_t ncols_input;
uint32_t ncols_output;
uint32_t nrows;
uint32_t first_pass;
uint32_t last_pass;
};
struct vk_op_im2col_push_constants {
uint64_t dst_addr;
uint32_t batch_offset; uint32_t offset_delta;
@ -3964,10 +3977,29 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline2(device, device->pipeline_argsort_large_f32[i], "argsort_large_f32_"+std::to_string(i), argsort_large_f32_len, argsort_large_f32_data, "main", 3, sizeof(vk_op_argsort_push_constants), {BLOCK_SIZE * WG_UNROLL_FACTOR, 1, 1}, {BLOCK_SIZE, WG_UNROLL_FACTOR}, 1, true);
}
for (uint32_t i = 0; i < num_topk_pipelines; ++i) {
const uint32_t BLOCK_SIZE = 1u << i;
const uint32_t NCOLS_PADDED_LOG2 = i;
if (i <= device->max_workgroup_size_log2) {
uint32_t nary_shmem = 2 * sizeof(int) * BLOCK_SIZE +
sizeof(int) * device->subgroup_size +
2 * sizeof(int) +
(BLOCK_SIZE / device->subgroup_size) * sizeof(int);
if (device->subgroup_arithmetic && device->subgroup_require_full_support && device->subgroup_shuffle && device->subgroup_ballot &&
nary_shmem <= device->properties.limits.maxComputeSharedMemorySize) {
ggml_vk_create_pipeline2(device, device->pipeline_topk_f32[i], "topk_f32_"+std::to_string(i), topk_nary_search_f32_len, topk_nary_search_f32_data, "main", 2, sizeof(vk_op_topk_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, device->subgroup_size, device->subgroup_size_log2}, 1, true, true, device->subgroup_size);
} else if (2 * sizeof(int) * BLOCK_SIZE <= device->properties.limits.maxComputeSharedMemorySize) {
ggml_vk_create_pipeline2(device, device->pipeline_topk_f32[i], "topk_f32_"+std::to_string(i), topk_argsort_f32_len, topk_argsort_f32_data, "main", 2, sizeof(vk_op_topk_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, NCOLS_PADDED_LOG2}, 1, true);
}
}
}
ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
ggml_vk_create_pipeline(device, device->pipeline_cumsum_f32, "cumsum_f32", cumsum_f32_len, cumsum_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { 128, device->subgroup_size }, 1, true, true, device->subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
#define IM2COL(bda) \
@ -4333,6 +4365,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
device->suballocation_block_size = std::min(device->suballocation_block_size, device->max_memory_allocation_size);
device->subgroup_size = subgroup_props.subgroupSize;
device->subgroup_size_log2 = uint32_t(log2f(float(device->subgroup_size)));
device->uma = device->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
if (sm_builtins) {
device->shader_core_count = sm_props.shaderSMCount;
@ -8457,6 +8490,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_sum_rows_f32;
}
return nullptr;
case GGML_OP_CUMSUM:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_cumsum_f32;
}
return nullptr;
case GGML_OP_ARGMAX:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
return ctx->device->pipeline_argmax_f32;
@ -8821,6 +8859,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
case GGML_OP_SOFT_MAX:
case GGML_OP_SOFT_MAX_BACK:
case GGML_OP_SUM_ROWS:
case GGML_OP_CUMSUM:
case GGML_OP_MEAN:
case GGML_OP_ARGMAX:
{
@ -10134,6 +10173,104 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, c
}
}
static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
uint32_t ncols = src0->ne[0];
uint32_t nrows = ggml_nrows(src0);
uint32_t k = dst->ne[0];
vk_op_topk_push_constants pc { ncols, ncols, k, nrows, 0, 0 };
// Reserve space for ivec2 per element, double buffered
const size_t dbl_buf_size = size_t{ncols} * nrows * 2 * sizeof(int);
const size_t x_sz = dbl_buf_size * 2;
uint32_t dbl_buf_index = 0;
if (ctx->prealloc_size_x < x_sz) {
ctx->prealloc_size_x = x_sz;
ggml_vk_preallocate_buffers(ctx, subctx);
}
if (ctx->prealloc_x_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}
std::array<uint32_t, 3> elements;
elements[1] = std::min(nrows, ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
elements[2] = 1;
uint32_t num_elements = ncols;
// Each iteration reduces a workgroup's worth of elements down to the K
// largest elements. Repeat until we have the top K elements.
// Need to do at least one iteration to write out the results.
bool done_one_iter = false;
while (num_elements > k || !done_one_iter) {
done_one_iter = true;
// Prefer going as small as num_topk_pipelines - 3 for perf reasons.
// But if K is larger, then we need a larger workgroup
uint32_t max_pipeline = num_topk_pipelines - 3;
uint32_t min_pipeline = (uint32_t)log2f(float(k)) + 1;
// require full subgroup
min_pipeline = std::max(min_pipeline, ctx->device->subgroup_size_log2);
uint32_t pipeline_idx = (uint32_t)ceilf(log2f(float(num_elements)));
pipeline_idx = std::min(pipeline_idx, max_pipeline);
pipeline_idx = std::max(pipeline_idx, min_pipeline);
if (num_elements > (1u << pipeline_idx)) {
// If we could finish on this loop iteration (i.e. a single workgroup)
// then do so. It's better than the overhead of another pass.
for (uint32_t i = pipeline_idx; i < num_topk_pipelines; ++i) {
if (num_elements <= (1u << i)) {
pipeline_idx = i;
break;
}
}
}
vk_pipeline pipeline = ctx->device->pipeline_topk_f32[pipeline_idx];
// If the device doesn't support a pipeline this large, use smaller
while (!pipeline) {
pipeline_idx--;
GGML_ASSERT(pipeline_idx >= min_pipeline);
pipeline = ctx->device->pipeline_topk_f32[pipeline_idx];
}
vk_op_topk_push_constants pc2 = pc;
pc2.ncols_input = num_elements;
// Number of elements remaining after this pass
uint32_t num_dst_elements = (num_elements / pipeline->wg_denoms[0]) * k + std::min(k, num_elements % pipeline->wg_denoms[0]);
vk_subbuffer src_buf;
vk_subbuffer dst_buf;
if (num_elements == ncols) {
pc2.first_pass = 1;
src_buf = ggml_vk_tensor_subbuffer(ctx, src0);
} else {
src_buf = { ctx->prealloc_x, dbl_buf_index * dbl_buf_size, dbl_buf_size };
}
if (num_dst_elements == k) {
pc2.last_pass = 1;
dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
} else {
dst_buf = { ctx->prealloc_x, (dbl_buf_index ^ 1) * dbl_buf_size, dbl_buf_size };
}
elements[0] = num_elements;
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src_buf, dst_buf }, pc2, elements);
num_elements = num_dst_elements;
dbl_buf_index ^= 1;
if (num_elements > k) {
ggml_vk_sync_buffers(ctx, subctx);
}
}
ctx->prealloc_x_need_sync = true;
}
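
Aside: a host-side trace (plain C++, no Vulkan; the workgroup size of 256 is assumed for illustration) of how the pass loop converges: every full workgroup of inputs is reduced to its top k, so the element count shrinks by roughly wg/k per pass:

#include <cstdio>
#include <algorithm>

int main() {
    unsigned ncols = 100000, k = 16;
    const unsigned wg = 256;  // assumed workgroup size for this sketch

    unsigned num_elements = ncols;
    int pass = 0;
    bool done_one = false;
    while (num_elements > k || !done_one) {
        done_one = true;
        // same reduction formula as ggml_vk_topk's num_dst_elements:
        unsigned out = (num_elements / wg) * k + std::min(k, num_elements % wg);
        printf("pass %d: %u -> %u\n", ++pass, num_elements, out);
        num_elements = out;
    }
    return 0;
}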
static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, ggml_nelements(src0));
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SUM, p);
@ -10150,6 +10287,11 @@ static void ggml_vk_mean(ggml_backend_vk_context * ctx, vk_context& subctx, cons
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_MEAN, p);
}
static void ggml_vk_cumsum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]);
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_CUMSUM, p);
}
static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], 0.0f, 0.0f });
}
@ -11741,6 +11883,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
ggml_vk_argsort(ctx, compute_ctx, src0, node);
}
break;
case GGML_OP_TOP_K:
ggml_vk_topk(ctx, compute_ctx, src0, node);
break;
case GGML_OP_SUM:
ggml_vk_sum(ctx, compute_ctx, src0, node);
@ -11749,6 +11895,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_SUM_ROWS:
ggml_vk_sum_rows(ctx, compute_ctx, src0, node);
break;
case GGML_OP_CUMSUM:
ggml_vk_cumsum(ctx, compute_ctx, src0, node);
break;
case GGML_OP_MEAN:
ggml_vk_mean(ctx, compute_ctx, src0, node);
@ -13008,24 +13158,6 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
return false;
};
// This function tries to reorder the graph to allow nodes to run in parallel.
// This helps with small batches, but for large batches it's a slowdown, probably
// due to cache contention. So only reorder if the majority of nodes have few rows.
int num_small_nodes = 0;
int num_counted_nodes = 0;
for (int i = 0; i < graph->n_nodes; ++i) {
if (!is_empty(graph->nodes[i]) &&
graph->nodes[i]->op != GGML_OP_SET_ROWS) {
if (ggml_nrows(graph->nodes[i]) <= 8) {
num_small_nodes++;
}
num_counted_nodes++;
}
}
if (num_small_nodes < num_counted_nodes / 2) {
return;
}
std::vector<ggml_tensor *> new_order;
std::vector<bool> used(graph->n_nodes, false);
std::set<ggml_tensor *> used_node_set;
@ -13769,6 +13901,22 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
return op->ne[0] <= (1 << device->max_workgroup_size_log2);
}
}
case GGML_OP_TOP_K:
{
if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) {
return false;
}
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
auto device = ggml_vk_get_device(ctx->device);
// We could potentially support larger, using argsort to sort the
// whole thing. Not clear if this is needed.
uint32_t min_pipeline = (uint32_t)log2f(float(op->ne[0])) + 1;
if (min_pipeline >= num_topk_pipelines ||
!device->pipeline_topk_f32[min_pipeline]) {
return false;
}
}
return true;
case GGML_OP_UPSCALE:
case GGML_OP_ACC:
case GGML_OP_CONCAT:
@ -13786,6 +13934,15 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:
return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]);
case GGML_OP_CUMSUM:
{
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
auto device = ggml_vk_get_device(ctx->device);
if (device->subgroup_arithmetic && device->subgroup_require_full_support) {
return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]);
}
return false;
}
case GGML_OP_ARGMAX:
case GGML_OP_COUNT_EQUAL:
case GGML_OP_IM2COL:
@ -14432,10 +14589,14 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
tensor_clone = ggml_get_rows(ggml_ctx, src_clone[0], src_clone[1]);
} else if (tensor->op == GGML_OP_ARGSORT) {
tensor_clone = ggml_argsort(ggml_ctx, src_clone[0], (ggml_sort_order) *(int *)tensor->op_params);
} else if (tensor->op == GGML_OP_TOP_K) {
tensor_clone = ggml_top_k(ggml_ctx, src_clone[0], tensor->ne[0]);
} else if (tensor->op == GGML_OP_SUM) {
tensor_clone = ggml_sum(ggml_ctx, src_clone[0]);
} else if (tensor->op == GGML_OP_SUM_ROWS) {
tensor_clone = ggml_sum_rows(ggml_ctx, src_clone[0]);
} else if (tensor->op == GGML_OP_CUMSUM) {
tensor_clone = ggml_cumsum(ggml_ctx, src_clone[0]);
} else if (tensor->op == GGML_OP_MEAN) {
tensor_clone = ggml_mean(ggml_ctx, src_clone[0]);
} else if (tensor->op == GGML_OP_ARGMAX) {

View File

@ -0,0 +1,69 @@
#version 450
#include "types.glsl"
#include "sum_rows.glsl"
#extension GL_EXT_control_flow_attributes : enable
#extension GL_KHR_shader_subgroup_arithmetic : enable
#extension GL_KHR_shader_subgroup_basic : enable
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
layout (constant_id = 0) const uint BLOCK_SIZE = 128;
layout (constant_id = 1) const uint SUBGROUP_SIZE = 32;
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
shared FLOAT_TYPE partial[BLOCK_SIZE / SUBGROUP_SIZE];
shared FLOAT_TYPE last_sum;
void main() {
const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
const uint tid = gl_LocalInvocationID.x;
const uint i03 = fastdiv(row, p.ne0_12mp, p.ne0_12L);
const uint i03_offset = i03 * p.ne01*p.ne02;
const uint i02 = fastdiv(row - i03_offset, p.ne0_1mp, p.ne0_1L);
const uint i01 = row - i03_offset - i02*p.ne01;
const uint src_idx = get_aoffset() + i01 * p.nb01 + i02 * p.nb02 + i03 * p.nb03;
const uint dst_idx = get_doffset() + i01 * p.nb11 + i02 * p.nb12 + i03 * p.nb13;
uint subgroup_id = tid / SUBGROUP_SIZE;
if (tid == 0) {
last_sum = 0;
}
uint col = tid;
uint num_iter = CEIL_DIV(p.n_cols, BLOCK_SIZE);
for (int i = 0; i < num_iter; ++i) {
FLOAT_TYPE v = 0;
if (col < p.n_cols) {
v = FLOAT_TYPE(data_a[src_idx + col]);
}
v = subgroupInclusiveAdd(v);
// Store the largest partial sum for each subgroup, then add the partials for all
// lower subgroups and the final partial sum from the previous iteration.
if (gl_SubgroupInvocationID == SUBGROUP_SIZE - 1) {
partial[subgroup_id] = v;
}
barrier();
for (int j = 0; j < subgroup_id; ++j) {
v += partial[j];
}
v += last_sum;
barrier();
if (tid == BLOCK_SIZE - 1) {
last_sum = v;
}
if (col < p.n_cols) {
data_d[dst_idx + col] = D_TYPE(v);
}
col += BLOCK_SIZE;
}
}
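
Aside: a CPU sketch of this shader's scan structure: an inclusive scan within each subgroup, a shared array of per-subgroup totals, and a last_sum carry across block-sized chunks (plain C++ with illustrative sizes):

#include <cstdio>
#include <vector>

int main() {
    const int SUBGROUP = 4, BLOCK = 8;           // illustrative sizes
    std::vector<float> data(19, 1.0f);           // all-ones row
    std::vector<float> out(data.size());

    float last_sum = 0.0f;                       // carry between chunks
    for (size_t base = 0; base < data.size(); base += BLOCK) {
        // inclusive scan inside each subgroup
        std::vector<float> v(BLOCK, 0.0f), partial(BLOCK / SUBGROUP);
        for (int t = 0; t < BLOCK; ++t) {
            float x = (base + t < data.size()) ? data[base + t] : 0.0f;
            v[t] = (t % SUBGROUP == 0) ? x : v[t - 1] + x;
            if (t % SUBGROUP == SUBGROUP - 1) partial[t / SUBGROUP] = v[t];
        }
        // add totals of lower subgroups and the carry from earlier chunks
        for (int t = 0; t < BLOCK; ++t) {
            for (int j = 0; j < t / SUBGROUP; ++j) v[t] += partial[j];
            v[t] += last_sum;
            if (base + t < data.size()) out[base + t] = v[t];
        }
        last_sum = v[BLOCK - 1];
    }
    for (float f : out) printf("%g ", f);        // 1 2 3 ... 19
    printf("\n");
    return 0;
}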

View File

@ -1,6 +1,7 @@
#version 450
#include "types.glsl"
#include "sum_rows.glsl"
#extension GL_EXT_control_flow_attributes : enable
@ -11,30 +12,6 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
layout (constant_id = 0) const uint BLOCK_SIZE = 32;
layout (push_constant) uniform parameter
{
uint n_cols;
uint ne01, ne02;
uint nb01, nb02, nb03;
uint nb11, nb12, nb13;
float weight;
uint misalign_offsets;
uint ne0_12mp, ne0_12L;
uint ne0_1mp, ne0_1L;
} p;
uint get_aoffset() { return p.misalign_offsets >> 16; }
uint get_doffset() { return p.misalign_offsets & 0xFFFF; }
// see init_fastdiv_values in ggml-vulkan.cpp
uint fastdiv(uint n, uint mp, uint L) {
uint msbs, lsbs;
// msbs = mulhi(n, mp)
umulExtended(n, mp, msbs, lsbs);
return (msbs + n) >> L;
}
shared FLOAT_TYPE tmp[BLOCK_SIZE];
void main() {

View File

@ -0,0 +1,25 @@
// vk_op_sum_rows_push_constants
layout (push_constant) uniform parameter
{
uint n_cols;
uint ne01, ne02;
uint nb01, nb02, nb03;
uint nb11, nb12, nb13;
float weight;
uint misalign_offsets;
uint ne0_12mp, ne0_12L;
uint ne0_1mp, ne0_1L;
} p;
uint get_aoffset() { return p.misalign_offsets >> 16; }
uint get_doffset() { return p.misalign_offsets & 0xFFFF; }
// see init_fastdiv_values in ggml-vulkan.cpp
uint fastdiv(uint n, uint mp, uint L) {
uint msbs, lsbs;
// msbs = mulhi(n, mp)
umulExtended(n, mp, msbs, lsbs);
return (msbs + n) >> L;
}
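The fastdiv() above assumes the host precomputed (mp, L) for each divisor, which is what init_fastdiv_values in ggml-vulkan.cpp does. A sketch of that precomputation (paraphrased for illustration, not copied verbatim from the patch):

#include <cstdint>

// For a fixed divisor d, pick (mp, L) so that for any 32-bit n:
//   (mulhi(n, mp) + n) >> L == n / d
// This is the "round-up" magic-number division scheme.
void init_fastdiv_values(uint32_t d, uint32_t & mp, uint32_t & L) {
    // L = ceil(log2(d))
    L = 0;
    while (L < 32 && (uint32_t{1} << L) < d) {
        L++;
    }
    // magic multiplier; the +1 rounds the reciprocal up
    mp = uint32_t(((uint64_t{1} << 32) * ((uint64_t{1} << L) - d)) / d + 1);
}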

View File

@ -0,0 +1,113 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable
#include "types.glsl"
layout(constant_id = 0) const int BLOCK_SIZE = 1024;
layout(constant_id = 1) const int NCOLS_PADDED_LOG2 = 10;
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
// Input can either be the source (A) or intermediate values (S).
// Similarly, output can be either destination (D) or intermediate values (S).
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 0) readonly buffer S {ivec2 data_s[];};
layout (binding = 1) writeonly buffer D {int data_d[];};
layout (binding = 1) writeonly buffer T {ivec2 data_t[];};
layout (push_constant) uniform parameter {
uint orig_ncols;
uint ncols_input;
uint ncols_output;
uint nrows;
uint first_pass;
uint last_pass;
} p;
// pairs of (gid, value)
shared ivec2 dst_row[BLOCK_SIZE];
void topk(bool needs_bounds_check, const uint row) {
const int col = int(gl_LocalInvocationID.x);
// initialize indices
if (gl_GlobalInvocationID.x < p.ncols_input) {
if (p.first_pass != 0) {
const uint row_offset = row * p.ncols_input;
dst_row[col] = ivec2(gl_GlobalInvocationID.x, floatBitsToInt(data_a[row_offset + gl_GlobalInvocationID.x]));
} else {
const uint row_offset = row * p.orig_ncols;
dst_row[col] = data_s[row_offset + gl_GlobalInvocationID.x];
}
} else {
dst_row[col] = ivec2(p.orig_ncols, 0);
}
barrier();
if (p.ncols_output == 1) {
// Fast path for single output - just do a max reduction
[[unroll]] for (int s = BLOCK_SIZE / 2; s >= 1; s /= 2) {
if (col < s) {
ivec2 a = dst_row[col];
ivec2 b = dst_row[col + s];
if (a.x >= p.orig_ncols ||
b.x < p.orig_ncols && b.y > a.y) {
dst_row[col] = b;
}
}
barrier();
}
} else {
// bitonic sort on this group of elements
uint num_outer_loop_iters = NCOLS_PADDED_LOG2;
for (uint k = 2, outer_idx = 0; outer_idx < num_outer_loop_iters; k *= 2, outer_idx++) {
uint num_inner_loop_iters = outer_idx + 1;
for (uint j = k / 2, inner_idx = 0; inner_idx < num_inner_loop_iters; j /= 2, inner_idx++) {
const int ixj = int(col ^ j);
int idx_0 = (col & k) == 0 ? col : ixj;
int idx_1 = (col & k) == 0 ? ixj : col;
ivec2 sh_idx_0 = dst_row[idx_0];
ivec2 sh_idx_1 = dst_row[idx_1];
bool idx_0_oob = needs_bounds_check ? sh_idx_0.x >= p.orig_ncols : false;
bool idx_1_oob = needs_bounds_check ? sh_idx_1.x >= p.orig_ncols : false;
if ((idx_0_oob ||
(!idx_1_oob && intBitsToFloat(sh_idx_0.y) < intBitsToFloat(sh_idx_1.y))) && (ixj > col)) {
dst_row[idx_0] = sh_idx_1;
dst_row[idx_1] = sh_idx_0;
}
barrier();
}
}
}
if (col < p.ncols_output && gl_GlobalInvocationID.x < p.orig_ncols) {
if (p.last_pass != 0) {
const uint row_offset = row * p.ncols_output;
data_d[row_offset + col] = dst_row[col].x;
} else {
const uint row_offset = row * p.orig_ncols + gl_WorkGroupID.x * p.ncols_output;
data_t[row_offset + col] = dst_row[col];
}
}
}
void main() {
// Fast path for fully occupied workgroups
if ((p.ncols_input % BLOCK_SIZE) == 0) {
uint row = gl_WorkGroupID.y;
while (row < p.nrows) {
topk(false, row);
row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
}
} else {
uint row = gl_WorkGroupID.y;
while (row < p.nrows) {
topk(true, row);
row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
}
}
}
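The k/j loops above form a standard bitonic sorting network over the BLOCK_SIZE shared entries: NCOLS_PADDED_LOG2 outer stages, with (col & k) selecting the compare direction so the row ends up in descending order. A CPU mirror of the same network (a simplified sketch: plain floats instead of (index, value) pairs, no out-of-bounds handling):

#include <algorithm>
#include <cstddef>
#include <vector>

// Bitonic sort, descending; n must be a power of two.
void bitonic_sort_desc(std::vector<float> & v) {
    const size_t n = v.size();
    for (size_t k = 2; k <= n; k *= 2) {            // outer stages, the shader's k
        for (size_t j = k / 2; j >= 1; j /= 2) {    // inner passes, the shader's j
            for (size_t col = 0; col < n; ++col) {
                const size_t ixj = col ^ j;
                if (ixj > col) {
                    // (col & k) == 0 keeps the larger value at the lower index
                    const bool up = (col & k) == 0;
                    if (up ? (v[col] < v[ixj]) : (v[col] > v[ixj])) {
                        std::swap(v[col], v[ixj]);
                    }
                }
            }
        }
    }
}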

View File

@ -0,0 +1,199 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_debug_printf : enable
#extension GL_KHR_shader_subgroup_basic : enable
#extension GL_KHR_shader_subgroup_ballot : enable
#extension GL_KHR_shader_subgroup_arithmetic : enable
#extension GL_KHR_shader_subgroup_shuffle : enable
#include "types.glsl"
layout(constant_id = 0) const int BLOCK_SIZE = 1024;
layout(constant_id = 1) const int SUBGROUP_SIZE = 32;
layout(constant_id = 2) const int SUBGROUP_SIZE_LOG2 = 5;
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
// Input can either be the source (A) or intermediate values (S).
// Similarly, output can be either destination (D) or intermediate values (S).
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 0) readonly buffer S {ivec2 data_s[];};
layout (binding = 1) writeonly buffer D {int data_d[];};
layout (binding = 1) writeonly buffer T {ivec2 data_t[];};
layout (push_constant) uniform parameter {
uint orig_ncols;
uint ncols_input;
uint ncols_output;
uint nrows;
uint first_pass;
uint last_pass;
} p;
// pairs of (gid, value)
shared ivec2 dst_row[BLOCK_SIZE];
shared int counts[SUBGROUP_SIZE];
shared int sh_min_idx;
shared uint sh_total;
shared uint offset_partials[BLOCK_SIZE / SUBGROUP_SIZE];
// Map float values to uint such that comparisons still work.
// Positive values set the high bit, negative values are inverted.
// +0.0 -> 0x80000000 and -0.0 -> 0x7FFFFFFF, so the two zeros order correctly.
uint f2ui(float x) {
uint y = floatBitsToUint(x);
if ((y & 0x80000000) != 0) {
y ^= ~0;
} else {
y |= 0x80000000;
}
return y;
}
void topk(const uint row) {
const int tid = int(gl_LocalInvocationID.x);
// initialize indices
if (gl_GlobalInvocationID.x < p.ncols_input) {
if (p.first_pass != 0) {
const uint row_offset = row * p.ncols_input;
dst_row[tid] = ivec2(gl_GlobalInvocationID.x, floatBitsToInt(data_a[row_offset + gl_GlobalInvocationID.x]));
} else {
const uint row_offset = row * p.orig_ncols;
dst_row[tid] = data_s[row_offset + gl_GlobalInvocationID.x];
}
} else {
dst_row[tid] = ivec2(p.orig_ncols, 0xFF800000); // -inf
}
barrier();
if (p.ncols_output == 1) {
// Fast path for single output - just do a max reduction
[[unroll]] for (int s = BLOCK_SIZE / 2; s >= 1; s /= 2) {
if (tid < s) {
ivec2 a = dst_row[tid];
ivec2 b = dst_row[tid + s];
if (a.x >= p.orig_ncols ||
b.x < p.orig_ncols && b.y > a.y) {
dst_row[tid] = b;
}
}
barrier();
}
} else {
// Do an N-ary search to find the K-th largest value.
// We remap the float values to be comparable as unsigned integers,
// and split the range into N smaller ranges, where N is the
// subgroup size. Count how many values fall in each range; if the K-th
// largest value lands inside one of these ranges, then repeat
// and split again.
// Mask is the current set of bits we're searching. Shift is the LSB index.
int shift = 32 - SUBGROUP_SIZE_LOG2;
uint mask = ((1 << SUBGROUP_SIZE_LOG2) - 1) << shift;
// The current range.
uint range_min = 0;
uint range_max = 0xFF800000;
// How many are above the current range, and how many we need to find.
uint total = 0;
uint limit = min(p.ncols_output, p.ncols_input - gl_WorkGroupID.x * BLOCK_SIZE);
while (mask != 0) {
barrier();
// Initialize bucket counts to zero.
if (tid < SUBGROUP_SIZE) {
counts[tid] = 0;
}
barrier();
// Count how many values are in each bucket.
if (tid < p.ncols_input) {
float y = intBitsToFloat(dst_row[tid].y);
uint fy = f2ui(y);
if (fy >= range_min && fy < range_max) {
uint bucket = (fy & mask) >> shift;
atomicAdd(counts[bucket], 1);
}
}
barrier();
// On the first subgroup, do a scan to count (from the top down) how
// many elements are in the top N buckets. Find the index of the first
// that is over the limit. Copy it to the other invocations through
// shared memory.
if (tid < SUBGROUP_SIZE) {
uint partial_sum = counts[SUBGROUP_SIZE - 1 - tid];
partial_sum = subgroupInclusiveAdd(partial_sum) + total;
uint t = subgroupBallotFindLSB(subgroupBallot(partial_sum >= limit));
if (tid == t) {
sh_min_idx = int(SUBGROUP_SIZE - 1 - t);
sh_total = partial_sum;
}
}
barrier();
int min_idx = sh_min_idx;
total = sh_total;
// Update the range, and break if we've found the K-th largest.
range_max = range_min + ((min_idx + 1) << shift);
range_min = range_min + (min_idx << shift);
if (total == p.ncols_output) {
break;
}
total -= counts[min_idx];
mask >>= SUBGROUP_SIZE_LOG2;
shift -= SUBGROUP_SIZE_LOG2;
if (shift < 0) {
shift = 0;
}
}
ivec2 v = dst_row[tid];
// We need to compact these values to the start of the dst_row array.
// Have each subgroup count how many items it'll store, so other
// subgroups can compute their base offset.
bool top = f2ui(intBitsToFloat(v.y)) >= range_min;
uvec4 b = subgroupBallot(top);
uint bit_count = subgroupBallotBitCount(b);
if ((tid % SUBGROUP_SIZE) == 0) {
offset_partials[tid / SUBGROUP_SIZE] = bit_count;
}
barrier();
uint out_idx = 0;
[[unroll]] for (int i = 0; i < BLOCK_SIZE / SUBGROUP_SIZE; ++i) {
if (i < tid / SUBGROUP_SIZE) {
out_idx += offset_partials[i];
}
}
uint bit_count_ex = subgroupBallotExclusiveBitCount(b);
if (top) {
// TODO: Copy directly to the output?
dst_row[out_idx + bit_count_ex] = v;
}
barrier();
}
if (tid < p.ncols_output && gl_GlobalInvocationID.x < p.orig_ncols) {
if (p.last_pass != 0) {
const uint row_offset = row * p.ncols_output;
data_d[row_offset + tid] = dst_row[tid].x;
} else {
const uint row_offset = row * p.orig_ncols + gl_WorkGroupID.x * p.ncols_output;
data_t[row_offset + tid] = dst_row[tid];
}
}
}
void main() {
uint row = gl_WorkGroupID.y;
while (row < p.nrows) {
topk(row);
row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
}
}
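f2ui() is the classic monotone float-to-uint remap: flip all bits of negative values, set the sign bit of non-negative ones, so unsigned integer comparison agrees with float comparison and the two zeros stay distinct and ordered. A standalone check (illustration only):

#include <cassert>
#include <cstdint>
#include <cstring>

uint32_t f2ui(float x) {
    uint32_t y;
    std::memcpy(&y, &x, sizeof y);               // bit-cast without UB
    return (y & 0x80000000u) ? ~y : (y | 0x80000000u);
}

int main() {
    assert(f2ui(-2.0f) < f2ui(-1.0f));
    assert(f2ui(-1.0f) < f2ui(-0.0f));
    assert(f2ui(-0.0f) < f2ui(0.0f));            // -0.0 -> 0x7FFFFFFF, +0.0 -> 0x80000000
    assert(f2ui(0.0f)  < f2ui(1.5f));
    assert(f2ui(1.5f)  < f2ui(2.0f));
    return 0;
}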
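The while (mask != 0) loop is a most-significant-digit bucket search over those remapped values: count occupancy of SUBGROUP_SIZE buckets, walk them from the top until the running count reaches K, narrow the range to that bucket, and repeat. A scalar sketch of the same narrowing (hypothetical helper, 16 buckets instead of the subgroup size; assumes k <= vals.size()):

#include <cstdint>
#include <vector>

// Returns a threshold t with at least k elements v >= t; when values are
// distinct those are exactly the top-k. Input is already remapped via f2ui().
uint32_t topk_threshold(const std::vector<uint32_t> & vals, uint32_t k) {
    const int LOG2B     = 4;                           // 16 buckets per round
    int       shift     = 32 - LOG2B;
    uint32_t  mask      = ((1u << LOG2B) - 1) << shift;
    uint32_t  range_min = 0;
    uint64_t  range_max = uint64_t{1} << 32;           // half-open upper bound
    uint32_t  total     = 0;                           // elements above the range
    while (mask != 0) {
        uint32_t counts[1 << LOG2B] = {0};
        for (uint32_t v : vals) {
            if (v >= range_min && v < range_max) {
                counts[(v & mask) >> shift]++;
            }
        }
        // Scan buckets from the top until the running count reaches k.
        int      idx = 0;
        uint32_t run = total;
        for (int b = (1 << LOG2B) - 1; b >= 0; --b) {
            run += counts[b];
            if (run >= k) { idx = b; break; }
        }
        range_max = range_min + (uint64_t(idx + 1) << shift);
        range_min = range_min + (uint32_t(idx) << shift);
        if (run == k) {
            break;                                     // exactly k elements >= range_min
        }
        total = run - counts[idx];
        mask >>= LOG2B;
        shift -= LOG2B;
    }
    return range_min;
}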

View File

@ -913,9 +913,13 @@ void process_shaders() {
string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}});
string_to_spv("argsort_large_f32", "argsort_large.comp", {{"A_TYPE", "float"}});
string_to_spv("topk_argsort_f32", "topk_argsort.comp", {{"A_TYPE", "float"}});
string_to_spv("topk_nary_search_f32", "topk_nary_search.comp", {{"A_TYPE", "float"}});
string_to_spv("argmax_f32", "argmax.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "int"}}));
string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}}));
string_to_spv("cumsum_f32", "cumsum.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
for (std::string dim_str : {"", "_3d"}) {
for (bool bda : {false, true}) {

View File

@ -16,7 +16,7 @@ vendor = {
# "https://github.com/mackron/miniaudio/raw/refs/tags/0.11.23/miniaudio.h": "vendor/miniaudio/miniaudio.h",
"https://github.com/mackron/miniaudio/raw/669ed3e844524fcd883231b13095baee9f6de304/miniaudio.h": "vendor/miniaudio/miniaudio.h",
"https://raw.githubusercontent.com/yhirose/cpp-httplib/refs/tags/v0.27.0/httplib.h": "vendor/cpp-httplib/httplib.h",
"https://raw.githubusercontent.com/yhirose/cpp-httplib/refs/tags/v0.28.0/httplib.h": "vendor/cpp-httplib/httplib.h",
}
for url, filename in vendor.items():

View File

@ -7635,6 +7635,14 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2, 8, 8192, 1}, order)); // bailingmoe2 (group selection)
}
for (int i = 0; i < 20; ++i) {
for (int k : {1, 2, 3, 7, 15, 100, 500, 1023, 9999}) {
if (k <= 1<<i) {
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {(1<<i), 1, 1, 1}, k));
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {(1<<i) + 11, 1, 2, 1}, k));
}
}
}
for (int k : {1, 2, 3, 7, 15}) {
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {16, 10, 10, 10}, k));
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {60, 10, 10, 10}, k));
@ -8034,12 +8042,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {65000, 16, 1, 1}));
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {200000, 1, 1, 1}));
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {200000, 16, 1, 1}));
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {65000, 16, 1, 1}, 40));
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {65000, 1, 1, 1}, 40));
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {65000, 1, 1, 1}, 1));
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {200000, 1, 1, 1}, 400));
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {200000, 1, 1, 1}, 40));
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {200000, 1, 1, 1}, 1));
for (auto k : {1, 10, 40, 400}) {
for (auto nrows : {1, 16}) {
for (auto cols : {k, 1000, 65000, 200000}) {
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {cols, nrows, 1, 1}, k));
}
}
}
return test_cases;
}

View File

@ -31,13 +31,16 @@ if (LLAMA_BUILD_BORINGSSL)
message(STATUS "Fetching BoringSSL version ${BORINGSSL_VERSION}")
include(FetchContent)
FetchContent_Declare(
boringssl
set(BORINGSSL_ARGS
GIT_REPOSITORY ${BORINGSSL_GIT}
GIT_TAG ${BORINGSSL_VERSION}
PATCH_COMMAND ${CMAKE_COMMAND} -P "${CMAKE_CURRENT_SOURCE_DIR}/patch-boringssl.cmake"
)
if(CMAKE_VERSION VERSION_GREATER_EQUAL 3.28)
list(APPEND BORINGSSL_ARGS EXCLUDE_FROM_ALL)
endif()
include(FetchContent)
FetchContent_Declare(boringssl ${BORINGSSL_ARGS})
set(SAVED_BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS})
set(SAVED_BUILD_TESTING ${BUILD_TESTING})
@ -45,7 +48,15 @@ if (LLAMA_BUILD_BORINGSSL)
set(BUILD_SHARED_LIBS OFF)
set(BUILD_TESTING OFF)
FetchContent_MakeAvailable(boringssl)
if(CMAKE_VERSION VERSION_GREATER_EQUAL 3.28)
FetchContent_MakeAvailable(boringssl)
else()
FetchContent_GetProperties(boringssl)
if(NOT boringssl_POPULATED)
FetchContent_Populate(boringssl)
add_subdirectory(${boringssl_SOURCE_DIR} ${boringssl_BINARY_DIR} EXCLUDE_FROM_ALL)
endif()
endif()
set(BUILD_SHARED_LIBS ${SAVED_BUILD_SHARED_LIBS})
set(BUILD_TESTING ${SAVED_BUILD_TESTING})

View File

@ -1087,22 +1087,30 @@ int getaddrinfo_with_timeout(const char *node, const char *service,
// Fallback implementation using thread-based timeout for other Unix systems
struct GetAddrInfoState {
~GetAddrInfoState() {
if (info) { freeaddrinfo(info); }
}
std::mutex mutex;
std::condition_variable result_cv;
bool completed = false;
int result = EAI_SYSTEM;
std::string node = node;
std::string service = service;
struct addrinfo hints = hints;
std::string node;
std::string service;
struct addrinfo hints;
struct addrinfo *info = nullptr;
};
// Allocate on the heap, so the resolver thread can keep using the data.
auto state = std::make_shared<GetAddrInfoState>();
state->node = node;
state->service = service;
state->hints = *hints;
std::thread resolve_thread([=]() {
auto thread_result = getaddrinfo(
state->node.c_str(), state->service.c_str(), hints, &state->info);
std::thread resolve_thread([state]() {
auto thread_result =
getaddrinfo(state->node.c_str(), state->service.c_str(), &state->hints,
&state->info);
std::lock_guard<std::mutex> lock(state->mutex);
state->result = thread_result;
@ -1120,6 +1128,7 @@ int getaddrinfo_with_timeout(const char *node, const char *service,
// Operation completed within timeout
resolve_thread.join();
*res = state->info;
state->info = nullptr; // Pass ownership to caller
return state->result;
} else {
// Timeout occurred
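The fix above replaces the self-referencing member initializers with explicit copies into heap-allocated state, so a timed-out (and detached) resolver thread can finish safely on its own. The underlying pattern, reduced to a sketch (hypothetical helper, not cpp-httplib API):

#include <chrono>
#include <condition_variable>
#include <memory>
#include <mutex>
#include <thread>

// Run a blocking call with a timeout. State lives in a shared_ptr so that a
// detached worker can outlive the caller without touching freed memory.
template <typename Fn>
bool run_with_timeout(Fn blocking_op, std::chrono::milliseconds timeout) {
    struct State {
        std::mutex              mutex;
        std::condition_variable cv;
        bool                    completed = false;
    };
    auto state = std::make_shared<State>();

    std::thread worker([state, blocking_op]() {
        blocking_op();                              // may outlive the caller
        std::lock_guard<std::mutex> lock(state->mutex);
        state->completed = true;
        state->cv.notify_one();
    });

    std::unique_lock<std::mutex> lock(state->mutex);
    if (state->cv.wait_for(lock, timeout, [&] { return state->completed; })) {
        lock.unlock();
        worker.join();                              // finished in time
        return true;
    }
    lock.unlock();
    worker.detach();                                // abandoned; state stays alive
    return false;
}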
@ -4970,7 +4979,8 @@ bool Server::write_response_core(Stream &strm, bool close_connection,
if (need_apply_ranges) { apply_ranges(req, res, content_type, boundary); }
// Prepare additional headers
if (close_connection || req.get_header_value("Connection") == "close") {
if (close_connection || req.get_header_value("Connection") == "close" ||
400 <= res.status) { // Don't leave connections open after errors
res.set_header("Connection", "close");
} else {
std::string s = "timeout=";
@ -5173,7 +5183,11 @@ bool Server::read_content_core(
size_t /*len*/) { return receiver(buf, n); };
}
if (req.method == "DELETE" && !req.has_header("Content-Length")) {
// RFC 7230 Section 3.3.3: If this is a request message and none of the above
// are true (no Transfer-Encoding and no Content-Length), then the message
// body length is zero (no message body is present).
if (!req.has_header("Content-Length") &&
!detail::is_chunked_transfer_encoding(req.headers)) {
return true;
}
@ -5681,8 +5695,6 @@ Server::process_request(Stream &strm, const std::string &remote_addr,
// Check if the request URI doesn't exceed the limit
if (req.target.size() > CPPHTTPLIB_REQUEST_URI_MAX_LENGTH) {
Headers dummy;
detail::read_headers(strm, dummy);
res.status = StatusCode::UriTooLong_414;
output_error_log(Error::ExceedUriMaxLength, &req);
return write_response(strm, close_connection, req, res);
@ -6666,11 +6678,13 @@ bool ClientImpl::write_request(Stream &strm, Request &req,
return true;
}
std::unique_ptr<Response> ClientImpl::send_with_content_provider(
std::unique_ptr<Response>
ClientImpl::send_with_content_provider_and_receiver(
Request &req, const char *body, size_t content_length,
ContentProvider content_provider,
ContentProviderWithoutLength content_provider_without_length,
const std::string &content_type, Error &error) {
const std::string &content_type, ContentReceiver content_receiver,
Error &error) {
if (!content_type.empty()) { req.set_header("Content-Type", content_type); }
#ifdef CPPHTTPLIB_ZLIB_SUPPORT
@ -6743,15 +6757,24 @@ std::unique_ptr<Response> ClientImpl::send_with_content_provider(
}
}
if (content_receiver) {
req.content_receiver =
[content_receiver](const char *data, size_t data_length,
size_t /*offset*/, size_t /*total_length*/) {
return content_receiver(data, data_length);
};
}
auto res = detail::make_unique<Response>();
return send(req, *res, error) ? std::move(res) : nullptr;
}
Result ClientImpl::send_with_content_provider(
Result ClientImpl::send_with_content_provider_and_receiver(
const std::string &method, const std::string &path, const Headers &headers,
const char *body, size_t content_length, ContentProvider content_provider,
ContentProviderWithoutLength content_provider_without_length,
const std::string &content_type, UploadProgress progress) {
const std::string &content_type, ContentReceiver content_receiver,
UploadProgress progress) {
Request req;
req.method = method;
req.headers = headers;
@ -6763,9 +6786,10 @@ Result ClientImpl::send_with_content_provider(
auto error = Error::Success;
auto res = send_with_content_provider(
auto res = send_with_content_provider_and_receiver(
req, body, content_length, std::move(content_provider),
std::move(content_provider_without_length), content_type, error);
std::move(content_provider_without_length), content_type,
std::move(content_receiver), error);
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
return Result{std::move(res), error, std::move(req.headers), last_ssl_error_,
@ -7094,6 +7118,15 @@ Result ClientImpl::Post(const std::string &path, size_t content_length,
content_type, progress);
}
Result ClientImpl::Post(const std::string &path, size_t content_length,
ContentProvider content_provider,
const std::string &content_type,
ContentReceiver content_receiver,
UploadProgress progress) {
return Post(path, Headers(), content_length, std::move(content_provider),
content_type, std::move(content_receiver), progress);
}
Result ClientImpl::Post(const std::string &path,
ContentProviderWithoutLength content_provider,
const std::string &content_type,
@ -7102,6 +7135,15 @@ Result ClientImpl::Post(const std::string &path,
progress);
}
Result ClientImpl::Post(const std::string &path,
ContentProviderWithoutLength content_provider,
const std::string &content_type,
ContentReceiver content_receiver,
UploadProgress progress) {
return Post(path, Headers(), std::move(content_provider), content_type,
std::move(content_receiver), progress);
}
Result ClientImpl::Post(const std::string &path, const Headers &headers,
const Params &params) {
auto query = detail::params_to_query_str(params);
@ -7142,17 +7184,18 @@ Result ClientImpl::Post(const std::string &path, const Headers &headers,
const char *body, size_t content_length,
const std::string &content_type,
UploadProgress progress) {
return send_with_content_provider("POST", path, headers, body, content_length,
nullptr, nullptr, content_type, progress);
return send_with_content_provider_and_receiver(
"POST", path, headers, body, content_length, nullptr, nullptr,
content_type, nullptr, progress);
}
Result ClientImpl::Post(const std::string &path, const Headers &headers,
const std::string &body,
const std::string &content_type,
UploadProgress progress) {
return send_with_content_provider("POST", path, headers, body.data(),
body.size(), nullptr, nullptr, content_type,
progress);
return send_with_content_provider_and_receiver(
"POST", path, headers, body.data(), body.size(), nullptr, nullptr,
content_type, nullptr, progress);
}
Result ClientImpl::Post(const std::string &path, const Headers &headers,
@ -7160,18 +7203,40 @@ Result ClientImpl::Post(const std::string &path, const Headers &headers,
ContentProvider content_provider,
const std::string &content_type,
UploadProgress progress) {
return send_with_content_provider("POST", path, headers, nullptr,
content_length, std::move(content_provider),
nullptr, content_type, progress);
return send_with_content_provider_and_receiver(
"POST", path, headers, nullptr, content_length,
std::move(content_provider), nullptr, content_type, nullptr, progress);
}
Result ClientImpl::Post(const std::string &path, const Headers &headers,
size_t content_length,
ContentProvider content_provider,
const std::string &content_type,
ContentReceiver content_receiver,
DownloadProgress progress) {
return send_with_content_provider_and_receiver(
"POST", path, headers, nullptr, content_length,
std::move(content_provider), nullptr, content_type,
std::move(content_receiver), std::move(progress));
}
Result ClientImpl::Post(const std::string &path, const Headers &headers,
ContentProviderWithoutLength content_provider,
const std::string &content_type,
UploadProgress progress) {
return send_with_content_provider("POST", path, headers, nullptr, 0, nullptr,
std::move(content_provider), content_type,
progress);
return send_with_content_provider_and_receiver(
"POST", path, headers, nullptr, 0, nullptr, std::move(content_provider),
content_type, nullptr, progress);
}
Result ClientImpl::Post(const std::string &path, const Headers &headers,
ContentProviderWithoutLength content_provider,
const std::string &content_type,
ContentReceiver content_receiver,
DownloadProgress progress) {
return send_with_content_provider_and_receiver(
"POST", path, headers, nullptr, 0, nullptr, std::move(content_provider),
content_type, std::move(content_receiver), std::move(progress));
}
Result ClientImpl::Post(const std::string &path, const Headers &headers,
@ -7181,10 +7246,10 @@ Result ClientImpl::Post(const std::string &path, const Headers &headers,
const auto &boundary = detail::make_multipart_data_boundary();
const auto &content_type =
detail::serialize_multipart_formdata_get_content_type(boundary);
return send_with_content_provider(
return send_with_content_provider_and_receiver(
"POST", path, headers, nullptr, 0, nullptr,
get_multipart_content_provider(boundary, items, provider_items),
content_type, progress);
content_type, nullptr, progress);
}
Result ClientImpl::Post(const std::string &path, const Headers &headers,
@ -7246,6 +7311,15 @@ Result ClientImpl::Put(const std::string &path, size_t content_length,
content_type, progress);
}
Result ClientImpl::Put(const std::string &path, size_t content_length,
ContentProvider content_provider,
const std::string &content_type,
ContentReceiver content_receiver,
UploadProgress progress) {
return Put(path, Headers(), content_length, std::move(content_provider),
content_type, std::move(content_receiver), progress);
}
Result ClientImpl::Put(const std::string &path,
ContentProviderWithoutLength content_provider,
const std::string &content_type,
@ -7254,6 +7328,15 @@ Result ClientImpl::Put(const std::string &path,
progress);
}
Result ClientImpl::Put(const std::string &path,
ContentProviderWithoutLength content_provider,
const std::string &content_type,
ContentReceiver content_receiver,
UploadProgress progress) {
return Put(path, Headers(), std::move(content_provider), content_type,
std::move(content_receiver), progress);
}
Result ClientImpl::Put(const std::string &path, const Headers &headers,
const Params &params) {
auto query = detail::params_to_query_str(params);
@ -7294,17 +7377,18 @@ Result ClientImpl::Put(const std::string &path, const Headers &headers,
const char *body, size_t content_length,
const std::string &content_type,
UploadProgress progress) {
return send_with_content_provider("PUT", path, headers, body, content_length,
nullptr, nullptr, content_type, progress);
return send_with_content_provider_and_receiver(
"PUT", path, headers, body, content_length, nullptr, nullptr,
content_type, nullptr, progress);
}
Result ClientImpl::Put(const std::string &path, const Headers &headers,
const std::string &body,
const std::string &content_type,
UploadProgress progress) {
return send_with_content_provider("PUT", path, headers, body.data(),
body.size(), nullptr, nullptr, content_type,
progress);
return send_with_content_provider_and_receiver(
"PUT", path, headers, body.data(), body.size(), nullptr, nullptr,
content_type, nullptr, progress);
}
Result ClientImpl::Put(const std::string &path, const Headers &headers,
@ -7312,18 +7396,40 @@ Result ClientImpl::Put(const std::string &path, const Headers &headers,
ContentProvider content_provider,
const std::string &content_type,
UploadProgress progress) {
return send_with_content_provider("PUT", path, headers, nullptr,
content_length, std::move(content_provider),
nullptr, content_type, progress);
return send_with_content_provider_and_receiver(
"PUT", path, headers, nullptr, content_length,
std::move(content_provider), nullptr, content_type, nullptr, progress);
}
Result ClientImpl::Put(const std::string &path, const Headers &headers,
size_t content_length,
ContentProvider content_provider,
const std::string &content_type,
ContentReceiver content_receiver,
UploadProgress progress) {
return send_with_content_provider_and_receiver(
"PUT", path, headers, nullptr, content_length,
std::move(content_provider), nullptr, content_type,
std::move(content_receiver), progress);
}
Result ClientImpl::Put(const std::string &path, const Headers &headers,
ContentProviderWithoutLength content_provider,
const std::string &content_type,
UploadProgress progress) {
return send_with_content_provider("PUT", path, headers, nullptr, 0, nullptr,
std::move(content_provider), content_type,
progress);
return send_with_content_provider_and_receiver(
"PUT", path, headers, nullptr, 0, nullptr, std::move(content_provider),
content_type, nullptr, progress);
}
Result ClientImpl::Put(const std::string &path, const Headers &headers,
ContentProviderWithoutLength content_provider,
const std::string &content_type,
ContentReceiver content_receiver,
UploadProgress progress) {
return send_with_content_provider_and_receiver(
"PUT", path, headers, nullptr, 0, nullptr, std::move(content_provider),
content_type, std::move(content_receiver), progress);
}
Result ClientImpl::Put(const std::string &path, const Headers &headers,
@ -7333,10 +7439,10 @@ Result ClientImpl::Put(const std::string &path, const Headers &headers,
const auto &boundary = detail::make_multipart_data_boundary();
const auto &content_type =
detail::serialize_multipart_formdata_get_content_type(boundary);
return send_with_content_provider(
return send_with_content_provider_and_receiver(
"PUT", path, headers, nullptr, 0, nullptr,
get_multipart_content_provider(boundary, items, provider_items),
content_type, progress);
content_type, nullptr, progress);
}
Result ClientImpl::Put(const std::string &path, const Headers &headers,
@ -7400,6 +7506,15 @@ Result ClientImpl::Patch(const std::string &path, size_t content_length,
content_type, progress);
}
Result ClientImpl::Patch(const std::string &path, size_t content_length,
ContentProvider content_provider,
const std::string &content_type,
ContentReceiver content_receiver,
UploadProgress progress) {
return Patch(path, Headers(), content_length, std::move(content_provider),
content_type, std::move(content_receiver), progress);
}
Result ClientImpl::Patch(const std::string &path,
ContentProviderWithoutLength content_provider,
const std::string &content_type,
@ -7408,6 +7523,15 @@ Result ClientImpl::Patch(const std::string &path,
progress);
}
Result ClientImpl::Patch(const std::string &path,
ContentProviderWithoutLength content_provider,
const std::string &content_type,
ContentReceiver content_receiver,
UploadProgress progress) {
return Patch(path, Headers(), std::move(content_provider), content_type,
std::move(content_receiver), progress);
}
Result ClientImpl::Patch(const std::string &path, const Headers &headers,
const Params &params) {
auto query = detail::params_to_query_str(params);
@ -7448,18 +7572,18 @@ Result ClientImpl::Patch(const std::string &path, const Headers &headers,
const char *body, size_t content_length,
const std::string &content_type,
UploadProgress progress) {
return send_with_content_provider("PATCH", path, headers, body,
content_length, nullptr, nullptr,
content_type, progress);
return send_with_content_provider_and_receiver(
"PATCH", path, headers, body, content_length, nullptr, nullptr,
content_type, nullptr, progress);
}
Result ClientImpl::Patch(const std::string &path, const Headers &headers,
const std::string &body,
const std::string &content_type,
UploadProgress progress) {
return send_with_content_provider("PATCH", path, headers, body.data(),
body.size(), nullptr, nullptr, content_type,
progress);
return send_with_content_provider_and_receiver(
"PATCH", path, headers, body.data(), body.size(), nullptr, nullptr,
content_type, nullptr, progress);
}
Result ClientImpl::Patch(const std::string &path, const Headers &headers,
@ -7467,18 +7591,40 @@ Result ClientImpl::Patch(const std::string &path, const Headers &headers,
ContentProvider content_provider,
const std::string &content_type,
UploadProgress progress) {
return send_with_content_provider("PATCH", path, headers, nullptr,
content_length, std::move(content_provider),
nullptr, content_type, progress);
return send_with_content_provider_and_receiver(
"PATCH", path, headers, nullptr, content_length,
std::move(content_provider), nullptr, content_type, nullptr, progress);
}
Result ClientImpl::Patch(const std::string &path, const Headers &headers,
size_t content_length,
ContentProvider content_provider,
const std::string &content_type,
ContentReceiver content_receiver,
UploadProgress progress) {
return send_with_content_provider_and_receiver(
"PATCH", path, headers, nullptr, content_length,
std::move(content_provider), nullptr, content_type,
std::move(content_receiver), progress);
}
Result ClientImpl::Patch(const std::string &path, const Headers &headers,
ContentProviderWithoutLength content_provider,
const std::string &content_type,
UploadProgress progress) {
return send_with_content_provider("PATCH", path, headers, nullptr, 0, nullptr,
std::move(content_provider), content_type,
progress);
return send_with_content_provider_and_receiver(
"PATCH", path, headers, nullptr, 0, nullptr, std::move(content_provider),
content_type, nullptr, progress);
}
Result ClientImpl::Patch(const std::string &path, const Headers &headers,
ContentProviderWithoutLength content_provider,
const std::string &content_type,
ContentReceiver content_receiver,
UploadProgress progress) {
return send_with_content_provider_and_receiver(
"PATCH", path, headers, nullptr, 0, nullptr, std::move(content_provider),
content_type, std::move(content_receiver), progress);
}
Result ClientImpl::Patch(const std::string &path, const Headers &headers,
@ -7488,10 +7634,10 @@ Result ClientImpl::Patch(const std::string &path, const Headers &headers,
const auto &boundary = detail::make_multipart_data_boundary();
const auto &content_type =
detail::serialize_multipart_formdata_get_content_type(boundary);
return send_with_content_provider(
return send_with_content_provider_and_receiver(
"PATCH", path, headers, nullptr, 0, nullptr,
get_multipart_content_provider(boundary, items, provider_items),
content_type, progress);
content_type, nullptr, progress);
}
Result ClientImpl::Patch(const std::string &path, const Headers &headers,
@ -8883,12 +9029,28 @@ Result Client::Post(const std::string &path, size_t content_length,
return cli_->Post(path, content_length, std::move(content_provider),
content_type, progress);
}
Result Client::Post(const std::string &path, size_t content_length,
ContentProvider content_provider,
const std::string &content_type,
ContentReceiver content_receiver,
UploadProgress progress) {
return cli_->Post(path, content_length, std::move(content_provider),
content_type, std::move(content_receiver), progress);
}
Result Client::Post(const std::string &path,
ContentProviderWithoutLength content_provider,
const std::string &content_type,
UploadProgress progress) {
return cli_->Post(path, std::move(content_provider), content_type, progress);
}
Result Client::Post(const std::string &path,
ContentProviderWithoutLength content_provider,
const std::string &content_type,
ContentReceiver content_receiver,
UploadProgress progress) {
return cli_->Post(path, std::move(content_provider), content_type,
std::move(content_receiver), progress);
}
Result Client::Post(const std::string &path, const Headers &headers,
size_t content_length,
ContentProvider content_provider,
@ -8897,6 +9059,15 @@ Result Client::Post(const std::string &path, const Headers &headers,
return cli_->Post(path, headers, content_length, std::move(content_provider),
content_type, progress);
}
Result Client::Post(const std::string &path, const Headers &headers,
size_t content_length,
ContentProvider content_provider,
const std::string &content_type,
ContentReceiver content_receiver,
DownloadProgress progress) {
return cli_->Post(path, headers, content_length, std::move(content_provider),
content_type, std::move(content_receiver), progress);
}
Result Client::Post(const std::string &path, const Headers &headers,
ContentProviderWithoutLength content_provider,
const std::string &content_type,
@ -8904,6 +9075,14 @@ Result Client::Post(const std::string &path, const Headers &headers,
return cli_->Post(path, headers, std::move(content_provider), content_type,
progress);
}
Result Client::Post(const std::string &path, const Headers &headers,
ContentProviderWithoutLength content_provider,
const std::string &content_type,
ContentReceiver content_receiver,
DownloadProgress progress) {
return cli_->Post(path, headers, std::move(content_provider), content_type,
std::move(content_receiver), progress);
}
Result Client::Post(const std::string &path, const Params &params) {
return cli_->Post(path, params);
}
@ -8938,8 +9117,8 @@ Result Client::Post(const std::string &path, const Headers &headers,
const std::string &content_type,
ContentReceiver content_receiver,
DownloadProgress progress) {
return cli_->Post(path, headers, body, content_type, content_receiver,
progress);
return cli_->Post(path, headers, body, content_type,
std::move(content_receiver), progress);
}
Result Client::Put(const std::string &path) { return cli_->Put(path); }
@ -8976,12 +9155,28 @@ Result Client::Put(const std::string &path, size_t content_length,
return cli_->Put(path, content_length, std::move(content_provider),
content_type, progress);
}
Result Client::Put(const std::string &path, size_t content_length,
ContentProvider content_provider,
const std::string &content_type,
ContentReceiver content_receiver,
UploadProgress progress) {
return cli_->Put(path, content_length, std::move(content_provider),
content_type, std::move(content_receiver), progress);
}
Result Client::Put(const std::string &path,
ContentProviderWithoutLength content_provider,
const std::string &content_type,
UploadProgress progress) {
return cli_->Put(path, std::move(content_provider), content_type, progress);
}
Result Client::Put(const std::string &path,
ContentProviderWithoutLength content_provider,
const std::string &content_type,
ContentReceiver content_receiver,
UploadProgress progress) {
return cli_->Put(path, std::move(content_provider), content_type,
std::move(content_receiver), progress);
}
Result Client::Put(const std::string &path, const Headers &headers,
size_t content_length,
ContentProvider content_provider,
@ -8990,6 +9185,15 @@ Result Client::Put(const std::string &path, const Headers &headers,
return cli_->Put(path, headers, content_length, std::move(content_provider),
content_type, progress);
}
Result Client::Put(const std::string &path, const Headers &headers,
size_t content_length,
ContentProvider content_provider,
const std::string &content_type,
ContentReceiver content_receiver,
UploadProgress progress) {
return cli_->Put(path, headers, content_length, std::move(content_provider),
content_type, std::move(content_receiver), progress);
}
Result Client::Put(const std::string &path, const Headers &headers,
ContentProviderWithoutLength content_provider,
const std::string &content_type,
@ -8997,6 +9201,14 @@ Result Client::Put(const std::string &path, const Headers &headers,
return cli_->Put(path, headers, std::move(content_provider), content_type,
progress);
}
Result Client::Put(const std::string &path, const Headers &headers,
ContentProviderWithoutLength content_provider,
const std::string &content_type,
ContentReceiver content_receiver,
UploadProgress progress) {
return cli_->Put(path, headers, std::move(content_provider), content_type,
std::move(content_receiver), progress);
}
Result Client::Put(const std::string &path, const Params &params) {
return cli_->Put(path, params);
}
@ -9072,12 +9284,28 @@ Result Client::Patch(const std::string &path, size_t content_length,
return cli_->Patch(path, content_length, std::move(content_provider),
content_type, progress);
}
Result Client::Patch(const std::string &path, size_t content_length,
ContentProvider content_provider,
const std::string &content_type,
ContentReceiver content_receiver,
UploadProgress progress) {
return cli_->Patch(path, content_length, std::move(content_provider),
content_type, std::move(content_receiver), progress);
}
Result Client::Patch(const std::string &path,
ContentProviderWithoutLength content_provider,
const std::string &content_type,
UploadProgress progress) {
return cli_->Patch(path, std::move(content_provider), content_type, progress);
}
Result Client::Patch(const std::string &path,
ContentProviderWithoutLength content_provider,
const std::string &content_type,
ContentReceiver content_receiver,
UploadProgress progress) {
return cli_->Patch(path, std::move(content_provider), content_type,
std::move(content_receiver), progress);
}
Result Client::Patch(const std::string &path, const Headers &headers,
size_t content_length,
ContentProvider content_provider,
@ -9086,6 +9314,15 @@ Result Client::Patch(const std::string &path, const Headers &headers,
return cli_->Patch(path, headers, content_length, std::move(content_provider),
content_type, progress);
}
Result Client::Patch(const std::string &path, const Headers &headers,
size_t content_length,
ContentProvider content_provider,
const std::string &content_type,
ContentReceiver content_receiver,
UploadProgress progress) {
return cli_->Patch(path, headers, content_length, std::move(content_provider),
content_type, std::move(content_receiver), progress);
}
Result Client::Patch(const std::string &path, const Headers &headers,
ContentProviderWithoutLength content_provider,
const std::string &content_type,
@ -9093,6 +9330,14 @@ Result Client::Patch(const std::string &path, const Headers &headers,
return cli_->Patch(path, headers, std::move(content_provider), content_type,
progress);
}
Result Client::Patch(const std::string &path, const Headers &headers,
ContentProviderWithoutLength content_provider,
const std::string &content_type,
ContentReceiver content_receiver,
UploadProgress progress) {
return cli_->Patch(path, headers, std::move(content_provider), content_type,
std::move(content_receiver), progress);
}
Result Client::Patch(const std::string &path, const Params &params) {
return cli_->Patch(path, params);
}
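Taken together, the new overloads allow streaming the request body through a ContentProvider while also streaming the response body through a ContentReceiver. A usage sketch (assumes a hypothetical server at localhost:8080 with an /echo endpoint):

#include "httplib.h"
#include <string>

int main() {
    httplib::Client cli("localhost", 8080);
    const std::string payload = "hello";
    std::string response_body;

    auto res = cli.Post(
        "/echo", payload.size(),
        // ContentProvider: stream the request body in chunks.
        [&](size_t offset, size_t length, httplib::DataSink & sink) {
            sink.write(payload.data() + offset, length);
            return true;
        },
        "text/plain",
        // ContentReceiver: collect the response body as it arrives.
        [&](const char * data, size_t data_length) {
            response_body.append(data, data_length);
            return true;
        });

    return (res && res->status == 200) ? 0 : 1;
}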

View File

@ -8,8 +8,8 @@
#ifndef CPPHTTPLIB_HTTPLIB_H
#define CPPHTTPLIB_HTTPLIB_H
#define CPPHTTPLIB_VERSION "0.27.0"
#define CPPHTTPLIB_VERSION_NUM "0x001B00"
#define CPPHTTPLIB_VERSION "0.28.0"
#define CPPHTTPLIB_VERSION_NUM "0x001C00"
/*
* Platform compatibility check
@ -257,6 +257,7 @@ using socklen_t = int;
#include <netinet/in.h>
#ifdef __linux__
#include <resolv.h>
#undef _res // Undefine _res macro to avoid conflicts with user code (#2278)
#endif
#include <csignal>
#include <netinet/tcp.h>
@ -1421,14 +1422,18 @@ public:
Result Post(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
Result Post(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
Result Post(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Post(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Post(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Post(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Post(const std::string &path, const Params &params);
Result Post(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr);
Result Post(const std::string &path, const Headers &headers);
Result Post(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
Result Post(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
Result Post(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Post(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, DownloadProgress progress = nullptr);
Result Post(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Post(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, DownloadProgress progress = nullptr);
Result Post(const std::string &path, const Headers &headers, const Params &params);
Result Post(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr);
Result Post(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr);
@ -1439,14 +1444,18 @@ public:
Result Put(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
Result Put(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
Result Put(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Put(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Put(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Put(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Put(const std::string &path, const Params &params);
Result Put(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr);
Result Put(const std::string &path, const Headers &headers);
Result Put(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
Result Put(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
Result Put(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Put(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Put(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Put(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Put(const std::string &path, const Headers &headers, const Params &params);
Result Put(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr);
Result Put(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr);
@ -1457,14 +1466,18 @@ public:
Result Patch(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
Result Patch(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Patch(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Patch(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Patch(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Params &params);
Result Patch(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Headers &headers, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Headers &headers, const Params &params);
Result Patch(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr);
@ -1712,17 +1725,19 @@ private:
template <typename ClientType> void setup_redirect_client(ClientType &client);
bool handle_request(Stream &strm, Request &req, Response &res,
bool close_connection, Error &error);
std::unique_ptr<Response> send_with_content_provider(
std::unique_ptr<Response> send_with_content_provider_and_receiver(
Request &req, const char *body, size_t content_length,
ContentProvider content_provider,
ContentProviderWithoutLength content_provider_without_length,
const std::string &content_type, Error &error);
Result send_with_content_provider(
const std::string &content_type, ContentReceiver content_receiver,
Error &error);
Result send_with_content_provider_and_receiver(
const std::string &method, const std::string &path,
const Headers &headers, const char *body, size_t content_length,
ContentProvider content_provider,
ContentProviderWithoutLength content_provider_without_length,
const std::string &content_type, UploadProgress progress);
const std::string &content_type, ContentReceiver content_receiver,
UploadProgress progress);
ContentProviderWithoutLength get_multipart_content_provider(
const std::string &boundary, const UploadFormDataItems &items,
const FormDataProviderItems &provider_items) const;
@ -1775,14 +1790,18 @@ public:
Result Post(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
Result Post(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
Result Post(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Post(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Post(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Post(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Post(const std::string &path, const Params &params);
Result Post(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr);
Result Post(const std::string &path, const Headers &headers);
Result Post(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
Result Post(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
Result Post(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Post(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, DownloadProgress progress = nullptr);
Result Post(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Post(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, DownloadProgress progress = nullptr);
Result Post(const std::string &path, const Headers &headers, const Params &params);
Result Post(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr);
Result Post(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr);
@ -1793,14 +1812,18 @@ public:
Result Put(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
Result Put(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
Result Put(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Put(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Put(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Put(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Put(const std::string &path, const Params &params);
Result Put(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr);
Result Put(const std::string &path, const Headers &headers);
Result Put(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
Result Put(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
Result Put(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Put(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Put(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Put(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Put(const std::string &path, const Headers &headers, const Params &params);
Result Put(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr);
Result Put(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr);
@ -1811,14 +1834,18 @@ public:
Result Patch(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
Result Patch(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Patch(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Patch(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Patch(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Params &params);
Result Patch(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Headers &headers);
Result Patch(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Headers &headers, const Params &params);
Result Patch(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr);
Result Patch(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr);

View File

@ -1,6 +0,0 @@
# Remove bssl
file(READ "CMakeLists.txt" content)
string(REPLACE "add_executable(bssl" "#add_executable(bssl" content "${content}")
string(REPLACE "target_link_libraries(bssl" "#target_link_libraries(bssl" content "${content}")
string(REPLACE "install(TARGETS bssl" "#install(TARGETS bssl" content "${content}")
file(WRITE "CMakeLists.txt" "${content}")