Merge branch 'master' into riscv
This commit is contained in:
commit
2786a97ef0
|
|
@ -10061,6 +10061,25 @@ class LazyTorchTensor(gguf.LazyBase):
|
||||||
torch.uint8: np.uint8,
|
torch.uint8: np.uint8,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# only used when byteswapping data. Only correct size is needed
|
||||||
|
_dtype_byteswap_map: dict[torch.dtype, type] = {
|
||||||
|
torch.float64: np.float64,
|
||||||
|
torch.float32: np.float32,
|
||||||
|
torch.bfloat16: np.float16,
|
||||||
|
torch.float16: np.float16,
|
||||||
|
torch.int64: np.int64,
|
||||||
|
torch.uint64: np.uint64,
|
||||||
|
torch.int32: np.int32,
|
||||||
|
torch.uint32: np.uint32,
|
||||||
|
torch.int16: np.int16,
|
||||||
|
torch.uint16: np.uint16,
|
||||||
|
torch.int8: np.int8,
|
||||||
|
torch.uint8: np.uint8,
|
||||||
|
torch.bool: np.uint8,
|
||||||
|
torch.float8_e4m3fn: np.uint8,
|
||||||
|
torch.float8_e5m2: np.uint8,
|
||||||
|
}
|
||||||
|
|
||||||
# used for safetensors slices
|
# used for safetensors slices
|
||||||
# ref: https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/src/lib.rs#L1046
|
# ref: https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/src/lib.rs#L1046
|
||||||
# TODO: uncomment U64, U32, and U16, ref: https://github.com/pytorch/pytorch/issues/58734
|
# TODO: uncomment U64, U32, and U16, ref: https://github.com/pytorch/pytorch/issues/58734
|
||||||
|
|
@ -10104,8 +10123,14 @@ class LazyTorchTensor(gguf.LazyBase):
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_local_tensor(cls, t: gguf.utility.LocalTensor) -> Tensor:
|
def from_local_tensor(cls, t: gguf.utility.LocalTensor) -> Tensor:
|
||||||
def load_tensor(tensor: gguf.utility.LocalTensor) -> Tensor:
|
def load_tensor(tensor: gguf.utility.LocalTensor) -> Tensor:
|
||||||
|
def byteswap_tensor(tensor: np.ndarray, dtype: type) -> np.ndarray:
|
||||||
|
if sys.byteorder == 'big':
|
||||||
|
# switch data back to big endian
|
||||||
|
tensor = tensor.view(dtype).byteswap(inplace=False)
|
||||||
|
return tensor
|
||||||
dtype = cls._dtype_str_map[tensor.dtype]
|
dtype = cls._dtype_str_map[tensor.dtype]
|
||||||
return torch.from_numpy(tensor.mmap_bytes()).view(dtype).reshape(tensor.shape)
|
numpy_dtype = cls._dtype_byteswap_map[dtype]
|
||||||
|
return torch.from_numpy(byteswap_tensor(tensor.mmap_bytes(), numpy_dtype)).view(dtype).reshape(tensor.shape)
|
||||||
dtype = cls._dtype_str_map[t.dtype]
|
dtype = cls._dtype_str_map[t.dtype]
|
||||||
shape = t.shape
|
shape = t.shape
|
||||||
lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(t,), func=lambda r: load_tensor(r))
|
lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(t,), func=lambda r: load_tensor(r))
|
||||||
|
|
@ -10113,10 +10138,16 @@ class LazyTorchTensor(gguf.LazyBase):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_remote_tensor(cls, remote_tensor: gguf.utility.RemoteTensor):
|
def from_remote_tensor(cls, remote_tensor: gguf.utility.RemoteTensor):
|
||||||
|
def byteswap_tensor(tensor: np.ndarray, dtype: type) -> np.ndarray:
|
||||||
|
if sys.byteorder == 'big':
|
||||||
|
# switch data back to big endian
|
||||||
|
tensor = tensor.view(dtype).byteswap(inplace=False)
|
||||||
|
return tensor
|
||||||
dtype = cls._dtype_str_map[remote_tensor.dtype]
|
dtype = cls._dtype_str_map[remote_tensor.dtype]
|
||||||
|
numpy_dtype = cls._dtype_byteswap_map[dtype]
|
||||||
shape = remote_tensor.shape
|
shape = remote_tensor.shape
|
||||||
meta = cls.meta_with_dtype_and_shape(dtype, shape)
|
meta = cls.meta_with_dtype_and_shape(dtype, shape)
|
||||||
lazy = cls(meta=meta, args=(remote_tensor,), func=lambda r: torch.frombuffer(r.data(), dtype=dtype).reshape(shape))
|
lazy = cls(meta=meta, args=(remote_tensor,), func=lambda r: torch.from_numpy(byteswap_tensor(np.frombuffer(r.data(), dtype=numpy_dtype), numpy_dtype)).view(dtype).reshape(shape))
|
||||||
return cast(torch.Tensor, lazy)
|
return cast(torch.Tensor, lazy)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
|
||||||
|
|
@ -530,6 +530,7 @@ extern "C" {
|
||||||
GGML_OP_ARANGE,
|
GGML_OP_ARANGE,
|
||||||
GGML_OP_TIMESTEP_EMBEDDING,
|
GGML_OP_TIMESTEP_EMBEDDING,
|
||||||
GGML_OP_ARGSORT,
|
GGML_OP_ARGSORT,
|
||||||
|
GGML_OP_TOP_K,
|
||||||
GGML_OP_LEAKY_RELU,
|
GGML_OP_LEAKY_RELU,
|
||||||
GGML_OP_TRI,
|
GGML_OP_TRI,
|
||||||
GGML_OP_FILL,
|
GGML_OP_FILL,
|
||||||
|
|
@ -2258,18 +2259,25 @@ extern "C" {
|
||||||
struct ggml_tensor * a,
|
struct ggml_tensor * a,
|
||||||
enum ggml_sort_order order);
|
enum ggml_sort_order order);
|
||||||
|
|
||||||
|
// similar to ggml_top_k but implemented as `argsort` + `view`
|
||||||
|
GGML_API struct ggml_tensor * ggml_argsort_top_k(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
struct ggml_tensor * a,
|
||||||
|
int k);
|
||||||
|
|
||||||
|
// top k elements per row
|
||||||
|
// note: the resulting top k indices are in no particular order
|
||||||
|
GGML_API struct ggml_tensor * ggml_top_k(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
struct ggml_tensor * a,
|
||||||
|
int k);
|
||||||
|
|
||||||
GGML_API struct ggml_tensor * ggml_arange(
|
GGML_API struct ggml_tensor * ggml_arange(
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
float start,
|
float start,
|
||||||
float stop,
|
float stop,
|
||||||
float step);
|
float step);
|
||||||
|
|
||||||
// top k elements per row
|
|
||||||
GGML_API struct ggml_tensor * ggml_top_k(
|
|
||||||
struct ggml_context * ctx,
|
|
||||||
struct ggml_tensor * a,
|
|
||||||
int k);
|
|
||||||
|
|
||||||
#define GGML_KQ_MASK_PAD 64
|
#define GGML_KQ_MASK_PAD 64
|
||||||
|
|
||||||
// q: [n_embd_k, n_batch, n_head, ne3 ]
|
// q: [n_embd_k, n_batch, n_head, ne3 ]
|
||||||
|
|
|
||||||
|
|
@ -2207,78 +2207,120 @@ static void aclnn_index_fill_tensor(ggml_backend_cann_context & ctx,
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Initializes and caches sine/cosine positional encoding values
|
* @brief Initializes and caches all intermediate tensors required for RoPE
|
||||||
* (used in RoPE, Rotary Position Embedding) for attention layers.
|
* (Rotary Position Embedding), including support for Yarn, mRoPE,
|
||||||
|
* i-mRoPE, Neox repeat strategy, independent sectors, frequency factors,
|
||||||
|
* and multi-section rotary groups.
|
||||||
*
|
*
|
||||||
* This function computes and caches the sin/cos values of
|
* This function computes and caches the per-dimension θ coefficients used for
|
||||||
* θ = position * theta_scale for RoPE encoding. The cache is shared
|
* Q/K rotary embedding. The cache is shared across layers, and recomputed only
|
||||||
* across attention layers, and only the first attention layer will
|
* when any dependent parameter changes.
|
||||||
* trigger initialization. The cache includes repeated sin/cos values
|
|
||||||
* with different repeat methods depending on the @param is_neox flag.
|
|
||||||
*
|
*
|
||||||
* Steps performed by this function:
|
* The function now supports:
|
||||||
* 1. Identify whether the target tensor belongs to Q/K in attention
|
* - Yarn RoPE extrapolation (via @param corr_dims and @param ext_factor)
|
||||||
* and restrict computation to the first layer only.
|
* - Per-dimension independent sector exponent rules (indep_sects + sections[])
|
||||||
* 2. Initialize the theta scale array (arange → power → freq scaling).
|
* - Multi-section RoPE (mRoPE) index mapping (mrope_used + is_imrope)
|
||||||
* 3. Allocate sin/cos caches if the max prompt length increases.
|
* - Frequency factor division (src2)
|
||||||
* 4. Compute θ = position * theta_scale.
|
* - Neox / normal repeat expansion modes
|
||||||
* 5. Compute sin(θ), cos(θ) and optionally scale by attn_factor.
|
|
||||||
* 6. Expand sin/cos values by repeat or repeat_interleave depending
|
|
||||||
* on whether @param is_neox is enabled.
|
|
||||||
*
|
*
|
||||||
* @param ctx The CANN backend context, holding memory pool,
|
* @param ctx CANN backend context, containing memory pool,
|
||||||
* stream, and persistent buffers for rope init/cache.
|
* cached buffers, and runtime stream.
|
||||||
* @param dst The destination ggml_tensor whose computation
|
* @param dst Destination ggml_tensor whose computation
|
||||||
* depends on the RoPE values (usually Qcur/Kcur).
|
* depends on RoPE (typically Qcur or Kcur).
|
||||||
* @param theta_scale Scalar exponent base for computing theta scale values.
|
* @param corr_dims [low, high] Yarn correction range.
|
||||||
* @param freq_scale Frequency scaling factor, applied to theta scale.
|
* @param ext_factor Yarn extrapolation strength. 0 = disabled.
|
||||||
* @param attn_factor Attention scaling factor, applied to sin/cos.
|
* @param theta_scale Base multiplier for per-dimension θ exponent.
|
||||||
* @param is_neox Whether to use Neox-style repeat strategy
|
* @param freq_scale Global frequency scaling factor.
|
||||||
* (dim expansion vs repeat_interleave).
|
* @param attn_factor Optional scaling applied to sin/cos (if needed).
|
||||||
|
* @param is_neox Whether to use Neox-style dimension interleave.
|
||||||
|
* @param sections 4-way sector sizes for independent-section RoPE
|
||||||
|
* and multi-section mRoPE (t/h/w/e).
|
||||||
|
* @param mrope_used Whether to enable multi-section rotary embedding.
|
||||||
|
* @param is_imrope Whether to apply interleaved mRoPE rules.
|
||||||
|
* @param indep_sects Whether each dimension runs independent exponent
|
||||||
|
* resets based on @p sections.
|
||||||
*/
|
*/
|
||||||
static void aclnn_cache_init(ggml_backend_cann_context & ctx,
|
static void aclnn_rope_cache_init(ggml_backend_cann_context & ctx,
|
||||||
ggml_tensor * dst,
|
ggml_tensor * dst,
|
||||||
float * corr_dims,
|
float * corr_dims,
|
||||||
float ext_factor,
|
float ext_factor,
|
||||||
float theta_scale,
|
float theta_scale,
|
||||||
float freq_scale,
|
float freq_scale,
|
||||||
float attn_factor,
|
float attn_factor,
|
||||||
bool is_neox) {
|
bool is_neox,
|
||||||
|
int sections[4],
|
||||||
|
bool mrope_used,
|
||||||
|
bool is_imrope,
|
||||||
|
bool indep_sects) {
|
||||||
ggml_tensor * src0 = dst->src[0]; // input
|
ggml_tensor * src0 = dst->src[0]; // input
|
||||||
ggml_tensor * src1 = dst->src[1]; // position
|
ggml_tensor * src1 = dst->src[1]; // position
|
||||||
ggml_tensor * src2 = dst->src[2]; // freq_factors
|
ggml_tensor * src2 = dst->src[2]; // freq_factors
|
||||||
|
|
||||||
if (src2 == nullptr && ctx.rope_cache.cached && ctx.rope_cache.ext_factor == ext_factor &&
|
int64_t theta_scale_length = src0->ne[0] / 2;
|
||||||
ctx.rope_cache.theta_scale == theta_scale && ctx.rope_cache.freq_scale == freq_scale &&
|
int64_t position_length = dst->ne[2];
|
||||||
ctx.rope_cache.attn_factor == attn_factor && ctx.rope_cache.is_neox == is_neox) {
|
|
||||||
|
// TODO: check theta_scale_length and position_length.
|
||||||
|
if (src2 == nullptr && ctx.rope_cache.cached &&
|
||||||
|
ctx.rope_cache.equal(theta_scale_length, position_length, ext_factor, theta_scale, freq_scale, attn_factor,
|
||||||
|
is_neox, indep_sects, mrope_used, is_imrope, sections)) {
|
||||||
// use cache.
|
// use cache.
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
int64_t theta_scale_length = src0->ne[0] / 2;
|
// Step0: calculate tensor shape.
|
||||||
int64_t theta_scale_ne[] = { theta_scale_length, 1, 1, 1 };
|
int64_t theta_scale_ne[] = { theta_scale_length, 1, 1, 1 };
|
||||||
size_t theta_scale_nb[] = { sizeof(float), sizeof(float), sizeof(float), theta_scale_length * sizeof(float) };
|
size_t theta_scale_nb[] = { sizeof(float), theta_scale_length * sizeof(float), theta_scale_length * sizeof(float),
|
||||||
|
theta_scale_length * sizeof(float) };
|
||||||
|
|
||||||
GGML_ASSERT(src1->type == GGML_TYPE_I32);
|
GGML_ASSERT(src1->type == GGML_TYPE_I32);
|
||||||
int64_t position_length = src1->ne[0];
|
int64_t position_ne[] = { 1, 1, position_length, 1 };
|
||||||
int64_t position_ne[] = { 1, 1, position_length, 1 };
|
size_t position_nb[] = { sizeof(int32_t), sizeof(int32_t), sizeof(int32_t), sizeof(int32_t) * position_length };
|
||||||
size_t position_nb[] = { sizeof(int32_t), sizeof(int32_t), sizeof(int32_t), sizeof(int32_t) * position_length };
|
|
||||||
|
|
||||||
int64_t theta_ne[] = { theta_scale_length, 1, position_length, 1 };
|
int64_t cache_ne[] = { theta_scale_length, 1, position_length, 1 };
|
||||||
size_t theta_nb[GGML_MAX_DIMS];
|
size_t cache_nb[GGML_MAX_DIMS];
|
||||||
theta_nb[0] = sizeof(float);
|
cache_nb[0] = sizeof(float);
|
||||||
for (int i = 1; i < GGML_MAX_DIMS; i++) {
|
for (int i = 1; i < GGML_MAX_DIMS; i++) {
|
||||||
theta_nb[i] = theta_nb[i - 1] * theta_ne[i - 1];
|
cache_nb[i] = cache_nb[i - 1] * cache_ne[i - 1];
|
||||||
}
|
}
|
||||||
|
|
||||||
// theta_scale arange, [0,1,...,ne00/2 - 1]
|
// Step1: Compute the coefficient of theta. During the cache_init process, aside from
|
||||||
|
// (1) multiplying by the position,
|
||||||
|
// (2) dividing by freq_factors,
|
||||||
|
// (3) computing the sine and cosine,
|
||||||
|
// the other parameters used in the computation generally do not change in most scenarios.
|
||||||
|
// Therefore, we can first compute this part of the result and then cache it.
|
||||||
|
|
||||||
|
// Step1.1: prepare theta_scale exponent. if this exponent updated, should update theta_scale_tensor.
|
||||||
acl_tensor_ptr acl_theta_scale_tensor;
|
acl_tensor_ptr acl_theta_scale_tensor;
|
||||||
// cache theta scale
|
bool theta_scale_updated = false;
|
||||||
if (ctx.rope_cache.theta_scale_length != theta_scale_length ||
|
if (ctx.rope_cache.theta_scale_length != theta_scale_length || ctx.rope_cache.theta_scale != theta_scale ||
|
||||||
// theta_scale and freq_scale should not change during the current token inference process,
|
ctx.rope_cache.indep_sects != indep_sects) {
|
||||||
// so we can directly use == here instead of comparing the absolute difference.
|
theta_scale_updated = true;
|
||||||
ctx.rope_cache.theta_scale != theta_scale || ctx.rope_cache.freq_scale != freq_scale) {
|
if (ctx.rope_cache.theta_scale_exp_host != nullptr) {
|
||||||
ctx.rope_cache.theta_scale_length = theta_scale_length;
|
free(ctx.rope_cache.theta_scale_exp_host);
|
||||||
|
}
|
||||||
|
ctx.rope_cache.theta_scale_exp_host = (float *) malloc(theta_scale_length * sizeof(float));
|
||||||
|
GGML_ASSERT(ctx.rope_cache.theta_scale_exp_host != nullptr);
|
||||||
|
if (!indep_sects) {
|
||||||
|
ctx.rope_cache.theta_scale_exp_host[0] = 1;
|
||||||
|
for (int i = 1; i < theta_scale_length; i++) {
|
||||||
|
ctx.rope_cache.theta_scale_exp_host[i] = ctx.rope_cache.theta_scale_exp_host[i - 1] * theta_scale;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
int sect_dims = sections[0] + sections[1] + sections[2] + sections[3];
|
||||||
|
int sec_w = sections[1] + sections[0];
|
||||||
|
int sec_e = sections[2] + sec_w;
|
||||||
|
|
||||||
|
ctx.rope_cache.theta_scale_exp_host[0] = 1;
|
||||||
|
for (int i = 1; i < theta_scale_length; i++) {
|
||||||
|
int sector = i % sect_dims;
|
||||||
|
if (sector == 0 || sector == sections[0] || sector == sec_w || sector == sec_e) {
|
||||||
|
ctx.rope_cache.theta_scale_exp_host[i] = 1;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
ctx.rope_cache.theta_scale_exp_host[i] = ctx.rope_cache.theta_scale_exp_host[i - 1] * theta_scale;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (ctx.rope_cache.theta_scale_cache != nullptr) {
|
if (ctx.rope_cache.theta_scale_cache != nullptr) {
|
||||||
ACL_CHECK(aclrtFree(ctx.rope_cache.theta_scale_cache));
|
ACL_CHECK(aclrtFree(ctx.rope_cache.theta_scale_cache));
|
||||||
|
|
@ -2286,74 +2328,138 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx,
|
||||||
ACL_CHECK(aclrtMalloc(&ctx.rope_cache.theta_scale_cache, theta_scale_length * sizeof(float),
|
ACL_CHECK(aclrtMalloc(&ctx.rope_cache.theta_scale_cache, theta_scale_length * sizeof(float),
|
||||||
ACL_MEM_MALLOC_HUGE_FIRST));
|
ACL_MEM_MALLOC_HUGE_FIRST));
|
||||||
|
|
||||||
|
ACL_CHECK(aclrtMemcpyAsync(ctx.rope_cache.theta_scale_cache, theta_scale_length * sizeof(float),
|
||||||
|
ctx.rope_cache.theta_scale_exp_host, theta_scale_length * sizeof(float),
|
||||||
|
ACL_MEMCPY_HOST_TO_DEVICE, ctx.stream()));
|
||||||
|
|
||||||
acl_theta_scale_tensor = ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float),
|
acl_theta_scale_tensor = ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float),
|
||||||
theta_scale_ne, theta_scale_nb, 1);
|
theta_scale_ne, theta_scale_nb, 1);
|
||||||
|
}
|
||||||
|
|
||||||
float start = 0;
|
// Step1.2: prepare rope_yarn_ramp, if this part updated, should update theta_scale_tensor.
|
||||||
float step = 1;
|
bool yarn_ramp_tensor_updated = false;
|
||||||
float stop = theta_scale_length;
|
ggml_cann_pool_alloc yarn_ramp_allocator(ctx.pool());
|
||||||
float n_elements = theta_scale_length;
|
acl_tensor_ptr acl_yarn_ramp_tensor;
|
||||||
aclnn_arange(ctx, acl_theta_scale_tensor.get(), start, stop, step, n_elements);
|
if (ext_factor != 0 &&
|
||||||
|
// TODO: check more parameter.
|
||||||
|
(ctx.rope_cache.theta_scale_length != theta_scale_length || ctx.rope_cache.freq_scale != freq_scale)) {
|
||||||
|
yarn_ramp_tensor_updated = true;
|
||||||
|
|
||||||
ggml_cann_pool_alloc yarn_ramp_allocator(ctx.pool());
|
// -rope_yarn_ramp
|
||||||
acl_tensor_ptr acl_yarn_ramp_tensor;
|
// const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
|
||||||
if (ext_factor != 0) {
|
// return MIN(1, MAX(0, y)) - 1;
|
||||||
// -rope_yarn_ramp
|
yarn_ramp_allocator.alloc(theta_scale_length * sizeof(float));
|
||||||
// const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
|
void * yarn_ramp_buffer = yarn_ramp_allocator.get();
|
||||||
// return MIN(1, MAX(0, y)) - 1;
|
acl_yarn_ramp_tensor =
|
||||||
yarn_ramp_allocator.alloc(theta_scale_length * sizeof(float));
|
ggml_cann_create_tensor(yarn_ramp_buffer, ACL_FLOAT, sizeof(float), theta_scale_ne, theta_scale_nb, 1);
|
||||||
void * yarn_ramp_buffer = yarn_ramp_allocator.get();
|
float zero_value = 0, one_value = 1;
|
||||||
acl_yarn_ramp_tensor =
|
float denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]);
|
||||||
ggml_cann_create_tensor(yarn_ramp_buffer, ACL_FLOAT, sizeof(float), theta_scale_ne, theta_scale_nb, 1);
|
acl_scalar_ptr low = ggml_cann_create_scalar(&corr_dims[0], aclDataType::ACL_FLOAT);
|
||||||
float zero_value = 0, one_value = 1;
|
acl_scalar_ptr zero = ggml_cann_create_scalar(&zero_value, aclDataType::ACL_FLOAT);
|
||||||
float denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]);
|
acl_scalar_ptr one = ggml_cann_create_scalar(&one_value, aclDataType::ACL_FLOAT);
|
||||||
acl_scalar_ptr low = ggml_cann_create_scalar(&corr_dims[0], aclDataType::ACL_FLOAT);
|
acl_scalar_ptr denom_safe = ggml_cann_create_scalar(&denom_safe_value, aclDataType::ACL_FLOAT);
|
||||||
acl_scalar_ptr zero = ggml_cann_create_scalar(&zero_value, aclDataType::ACL_FLOAT);
|
acl_scalar_ptr ext_factor_sc = ggml_cann_create_scalar(&ext_factor, aclDataType::ACL_FLOAT);
|
||||||
acl_scalar_ptr one = ggml_cann_create_scalar(&one_value, aclDataType::ACL_FLOAT);
|
|
||||||
acl_scalar_ptr denom_safe = ggml_cann_create_scalar(&denom_safe_value, aclDataType::ACL_FLOAT);
|
|
||||||
acl_scalar_ptr ext_factor_sc = ggml_cann_create_scalar(&ext_factor, aclDataType::ACL_FLOAT);
|
|
||||||
|
|
||||||
GGML_CANN_CALL_ACLNN_OP(ctx, Subs, acl_theta_scale_tensor.get(), low.get(), one.get(),
|
aclnn_arange(ctx, acl_yarn_ramp_tensor.get(), 0, theta_scale_length, 1, theta_scale_length);
|
||||||
acl_yarn_ramp_tensor.get());
|
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceSubs, acl_yarn_ramp_tensor.get(), low.get(), one.get());
|
||||||
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceDivs, acl_yarn_ramp_tensor.get(), denom_safe.get());
|
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceDivs, acl_yarn_ramp_tensor.get(), denom_safe.get());
|
||||||
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceThreshold, acl_yarn_ramp_tensor.get(), zero.get(), zero.get());
|
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceThreshold, acl_yarn_ramp_tensor.get(), zero.get(), zero.get());
|
||||||
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceClampMax, acl_yarn_ramp_tensor.get(), one.get());
|
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceClampMax, acl_yarn_ramp_tensor.get(), one.get());
|
||||||
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceSubs, acl_yarn_ramp_tensor.get(), one.get(), one.get());
|
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceSubs, acl_yarn_ramp_tensor.get(), one.get(), one.get());
|
||||||
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor.get(), ext_factor_sc.get());
|
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor.get(), ext_factor_sc.get());
|
||||||
|
|
||||||
// theta_interp = freq_scale * theta_extrap;
|
// theta_interp = freq_scale * theta_extrap;
|
||||||
// theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
|
// theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
|
||||||
// theta = freq_scale * theta_extrap * (1 - ramp_mix) + theta_extrap * ramp_mix;
|
// theta = freq_scale * theta_extrap * (1 - ramp_mix) + theta_extrap * ramp_mix;
|
||||||
// theta = freq_scale * theta_extrap - freq_scale * theta_extrap * ramp_mix + theta_extrap * ramp_mix;
|
// theta = freq_scale * theta_extrap - freq_scale * theta_extrap * ramp_mix + theta_extrap * ramp_mix;
|
||||||
// theta = theta_extrap * (freq_scale - freq_scale * ramp_mix + ramp_mix);
|
// theta = theta_extrap * (freq_scale - freq_scale * ramp_mix + ramp_mix);
|
||||||
//
|
//
|
||||||
// we cache (freq_scale - freq_scale * ramp_mix + ramp_mix), Considering that the rope_yarn_ramp here is the inverse
|
// we cache (freq_scale - freq_scale * ramp_mix + ramp_mix), Considering that the rope_yarn_ramp here is the inverse
|
||||||
// cache freq_scale + (freq_scale - 1) * ramp_mix
|
// cache freq_scale + (freq_scale - 1) * ramp_mix
|
||||||
float freq_scale_1 = freq_scale - 1;
|
float freq_scale_1 = freq_scale - 1;
|
||||||
acl_scalar_ptr freq_scale_sc = ggml_cann_create_scalar(&freq_scale, aclDataType::ACL_FLOAT);
|
acl_scalar_ptr freq_scale_sc = ggml_cann_create_scalar(&freq_scale, aclDataType::ACL_FLOAT);
|
||||||
acl_scalar_ptr freq_scale_1_sc = ggml_cann_create_scalar(&freq_scale_1, aclDataType::ACL_FLOAT);
|
acl_scalar_ptr freq_scale_1_sc = ggml_cann_create_scalar(&freq_scale_1, aclDataType::ACL_FLOAT);
|
||||||
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor.get(), freq_scale_1_sc.get());
|
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor.get(), freq_scale_1_sc.get());
|
||||||
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdds, acl_yarn_ramp_tensor.get(), freq_scale_sc.get(), one.get());
|
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdds, acl_yarn_ramp_tensor.get(), freq_scale_sc.get(), one.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
// power
|
// Step 1.3: update theta_scale_tensor according to ext_factor or freq_scale.
|
||||||
acl_scalar_ptr acl_theta_scale = ggml_cann_create_scalar(&theta_scale, aclDataType::ACL_FLOAT);
|
if (ext_factor != 0) {
|
||||||
GGML_CANN_CALL_ACLNN_OP(ctx, PowScalarTensor, acl_theta_scale.get(), acl_theta_scale_tensor.get(),
|
if (theta_scale_updated || yarn_ramp_tensor_updated) {
|
||||||
acl_theta_scale_tensor.get());
|
theta_scale_updated = true;
|
||||||
|
|
||||||
if (ext_factor != 0) {
|
|
||||||
aclnn_mul(ctx, acl_theta_scale_tensor.get(), acl_yarn_ramp_tensor.get());
|
aclnn_mul(ctx, acl_theta_scale_tensor.get(), acl_yarn_ramp_tensor.get());
|
||||||
} else if (freq_scale != 1) {
|
|
||||||
aclnn_muls(ctx, acl_theta_scale_tensor.get(), freq_scale, nullptr, true);
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// use cache
|
if (freq_scale != 1 && (ctx.rope_cache.freq_scale != freq_scale || theta_scale_updated)) {
|
||||||
|
theta_scale_updated = true;
|
||||||
|
aclnn_muls(ctx, acl_theta_scale_tensor.get(), freq_scale, nullptr, true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Nothing changed, use cache.
|
||||||
|
if (!theta_scale_updated) {
|
||||||
acl_theta_scale_tensor = ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float),
|
acl_theta_scale_tensor = ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float),
|
||||||
theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
|
theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Step 1.4: prepare select index if mrope
|
||||||
|
acl_tensor_ptr position_select_index_tensor;
|
||||||
|
if (mrope_used) {
|
||||||
|
if (ctx.rope_cache.sections[0] != sections[0] || ctx.rope_cache.sections[1] != sections[1] ||
|
||||||
|
ctx.rope_cache.sections[2] != sections[2] || ctx.rope_cache.sections[3] != sections[3] ||
|
||||||
|
ctx.rope_cache.theta_scale_length != theta_scale_length || ctx.rope_cache.is_imrope != is_imrope) {
|
||||||
|
if (ctx.rope_cache.position_select_index_host != nullptr) {
|
||||||
|
free(ctx.rope_cache.position_select_index_host);
|
||||||
|
}
|
||||||
|
ctx.rope_cache.position_select_index_host = (int *) malloc(theta_scale_length * sizeof(int));
|
||||||
|
GGML_ASSERT(ctx.rope_cache.position_select_index_host != nullptr);
|
||||||
|
int sect_dims = sections[0] + sections[1] + sections[2] + sections[3];
|
||||||
|
int sec_w = sections[1] + sections[0];
|
||||||
|
int sec_e = sections[2] + sec_w;
|
||||||
|
// t,h,w,e
|
||||||
|
for (int i = 0; i < theta_scale_length; i++) {
|
||||||
|
int sector = i % sect_dims;
|
||||||
|
|
||||||
|
if (is_imrope) { // qwen3vl apply interleaved mrope
|
||||||
|
if (sector % 3 == 1 && sector < 3 * sections[1]) {
|
||||||
|
ctx.rope_cache.position_select_index_host[i] = 1;
|
||||||
|
} else if (sector % 3 == 2 && sector < 3 * sections[2]) {
|
||||||
|
ctx.rope_cache.position_select_index_host[i] = 2;
|
||||||
|
} else if (sector % 3 == 0 && sector < 3 * sections[0]) {
|
||||||
|
ctx.rope_cache.position_select_index_host[i] = 0;
|
||||||
|
} else {
|
||||||
|
ctx.rope_cache.position_select_index_host[i] = 3;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (sector >= sections[0] && sector < sec_w) {
|
||||||
|
ctx.rope_cache.position_select_index_host[i] = 1;
|
||||||
|
} else if (sector >= sec_w && sector < sec_e) {
|
||||||
|
ctx.rope_cache.position_select_index_host[i] = 2;
|
||||||
|
} else if (sector >= sec_e) {
|
||||||
|
ctx.rope_cache.position_select_index_host[i] = 3;
|
||||||
|
} else {
|
||||||
|
ctx.rope_cache.position_select_index_host[i] = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ctx.rope_cache.position_select_index != nullptr) {
|
||||||
|
ACL_CHECK(aclrtFree(ctx.rope_cache.position_select_index));
|
||||||
|
}
|
||||||
|
ACL_CHECK(aclrtMalloc(&ctx.rope_cache.position_select_index, theta_scale_length * sizeof(int),
|
||||||
|
ACL_MEM_MALLOC_HUGE_FIRST));
|
||||||
|
|
||||||
|
ACL_CHECK(aclrtMemcpyAsync(ctx.rope_cache.position_select_index, theta_scale_length * sizeof(int),
|
||||||
|
ctx.rope_cache.position_select_index_host, theta_scale_length * sizeof(int),
|
||||||
|
ACL_MEMCPY_HOST_TO_DEVICE, ctx.stream()));
|
||||||
|
}
|
||||||
|
|
||||||
|
position_select_index_tensor = ggml_cann_create_tensor(ctx.rope_cache.position_select_index, ACL_INT32,
|
||||||
|
sizeof(int), theta_scale_ne, theta_scale_nb, 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step2: divide by freq_factors
|
||||||
ggml_cann_pool_alloc freq_fac_res_allocator(ctx.pool());
|
ggml_cann_pool_alloc freq_fac_res_allocator(ctx.pool());
|
||||||
// freq_factors
|
|
||||||
if (src2) {
|
if (src2) {
|
||||||
freq_fac_res_allocator.alloc(theta_scale_length * sizeof(float));
|
freq_fac_res_allocator.alloc(theta_scale_length * sizeof(float));
|
||||||
void * freq_fac_res_ptr = freq_fac_res_allocator.get();
|
void * freq_fac_res_ptr = freq_fac_res_allocator.get();
|
||||||
|
|
@ -2366,6 +2472,85 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx,
|
||||||
std::swap(acl_theta_scale_tensor, acl_freq_fac_res_tensor);
|
std::swap(acl_theta_scale_tensor, acl_freq_fac_res_tensor);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Step3: prepare position_tensor
|
||||||
|
acl_tensor_ptr acl_position_tensor;
|
||||||
|
ggml_cann_pool_alloc mrope_position_acllocator(ctx.pool());
|
||||||
|
if (mrope_used) {
|
||||||
|
// Step3.1: select current position;
|
||||||
|
// position :
|
||||||
|
// pos1: [[0, 1 ,2 ,3 ],
|
||||||
|
// pos2: [4, 5 ,6 ,7 ],
|
||||||
|
// pos3: [8, 9 ,10,11],
|
||||||
|
// pos4: [12,13,14,15] ]
|
||||||
|
//
|
||||||
|
// select index = [0, 1, 2, 2, 1, 0]
|
||||||
|
//
|
||||||
|
// selected_tensor:
|
||||||
|
// [[0, 1 ,2 ,3 ],
|
||||||
|
// [4, 5 ,6 ,7 ],
|
||||||
|
// [8, 9 ,10,11],
|
||||||
|
// [8, 9 ,10,11],
|
||||||
|
// [4, 5 ,6 ,7 ],
|
||||||
|
// [0, 1 ,2 ,3 ]]
|
||||||
|
//
|
||||||
|
// transpose, from [seq_len:dims] to [dims:seq_len]
|
||||||
|
// [0, 4, 8 ,8 ,4, 0],
|
||||||
|
// [1, 5, 9, 9, 5, 1],
|
||||||
|
// [2, 6, 10,10,6 ,2],
|
||||||
|
// [3, 7, 11,11,7 3 ]]
|
||||||
|
//
|
||||||
|
// multipy by theta_scale_tensor
|
||||||
|
// [theta_scale^0, theta_scale^1, ..., theta_scale ^ n]
|
||||||
|
|
||||||
|
int64_t mrope_position_ne[] = { position_length, 4 };
|
||||||
|
size_t mrope_position_nb[] = { sizeof(int), position_length * sizeof(int) };
|
||||||
|
acl_tensor_ptr mrope_position =
|
||||||
|
ggml_cann_create_tensor(src1->data, ggml_cann_type_mapping(src1->type), ggml_type_size(src1->type),
|
||||||
|
mrope_position_ne, mrope_position_nb, 2);
|
||||||
|
|
||||||
|
// selected position tensor's shape is a transpose of cache tensor.
|
||||||
|
int64_t selected_position_ne[] = { position_length, theta_scale_length };
|
||||||
|
size_t selected_position_nb[] = { sizeof(float), position_length * sizeof(float) };
|
||||||
|
mrope_position_acllocator.alloc(theta_scale_length * position_length * sizeof(float));
|
||||||
|
void * mrope_position_buffer = mrope_position_acllocator.get();
|
||||||
|
acl_position_tensor =
|
||||||
|
ggml_cann_create_tensor(mrope_position_buffer, ggml_cann_type_mapping(src1->type),
|
||||||
|
ggml_type_size(src1->type), selected_position_ne, selected_position_nb, 2);
|
||||||
|
GGML_CANN_CALL_ACLNN_OP(ctx, IndexSelect, mrope_position.get(), 0, position_select_index_tensor.get(),
|
||||||
|
acl_position_tensor.get());
|
||||||
|
|
||||||
|
// transpose
|
||||||
|
int64_t transposed_ne[] = { position_length, 1, theta_scale_length, 1 };
|
||||||
|
size_t transposed_nb[GGML_MAX_DIMS];
|
||||||
|
transposed_nb[0] = sizeof(float);
|
||||||
|
for (int i = 1; i < GGML_MAX_DIMS; i++) {
|
||||||
|
transposed_nb[i] = transposed_nb[i - 1] * transposed_ne[i - 1];
|
||||||
|
}
|
||||||
|
|
||||||
|
std::swap(transposed_ne[0], transposed_ne[2]);
|
||||||
|
std::swap(transposed_nb[0], transposed_nb[2]);
|
||||||
|
|
||||||
|
acl_position_tensor =
|
||||||
|
ggml_cann_create_tensor(mrope_position_buffer, ggml_cann_type_mapping(src1->type),
|
||||||
|
ggml_type_size(src1->type), transposed_ne, transposed_nb, GGML_MAX_DIMS);
|
||||||
|
|
||||||
|
} else {
|
||||||
|
// auto bcast.
|
||||||
|
acl_position_tensor =
|
||||||
|
ggml_cann_create_tensor(src1->data, ggml_cann_type_mapping(src1->type), ggml_type_size(src1->type),
|
||||||
|
position_ne, position_nb, GGML_MAX_DIMS);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step4: multiply by the position
|
||||||
|
int64_t theta_length = theta_scale_length * position_length;
|
||||||
|
ggml_cann_pool_alloc theta_allocator(ctx.pool(), theta_length * sizeof(float));
|
||||||
|
void * theta_buffer = theta_allocator.get();
|
||||||
|
|
||||||
|
acl_tensor_ptr acl_theta_tensor =
|
||||||
|
ggml_cann_create_tensor(theta_buffer, ACL_FLOAT, sizeof(float), cache_ne, cache_nb, GGML_MAX_DIMS);
|
||||||
|
aclnn_mul(ctx, acl_position_tensor.get(), acl_theta_scale_tensor.get(), acl_theta_tensor.get());
|
||||||
|
|
||||||
|
// Step5: calculate sin cos.
|
||||||
// init sin_repeat && cos_repeat, only to accelerate first layer on each device
|
// init sin_repeat && cos_repeat, only to accelerate first layer on each device
|
||||||
if (position_length > ctx.rope_cache.position_length) {
|
if (position_length > ctx.rope_cache.position_length) {
|
||||||
ctx.rope_cache.position_length = position_length;
|
ctx.rope_cache.position_length = position_length;
|
||||||
|
|
@ -2382,44 +2567,30 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx,
|
||||||
aclrtMalloc(&ctx.rope_cache.cos_cache, repeat_theta_length * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST));
|
aclrtMalloc(&ctx.rope_cache.cos_cache, repeat_theta_length * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST));
|
||||||
}
|
}
|
||||||
|
|
||||||
// position
|
|
||||||
acl_tensor_ptr acl_position_tensor =
|
|
||||||
ggml_cann_create_tensor(src1->data, ggml_cann_type_mapping(src1->type), ggml_type_size(src1->type), position_ne,
|
|
||||||
position_nb, GGML_MAX_DIMS);
|
|
||||||
|
|
||||||
// power * position
|
|
||||||
int64_t theta_length = theta_scale_length * position_length;
|
|
||||||
ggml_cann_pool_alloc theta_allocator(ctx.pool(), theta_length * sizeof(float));
|
|
||||||
void * theta_buffer = theta_allocator.get();
|
|
||||||
|
|
||||||
acl_tensor_ptr acl_theta_tensor =
|
|
||||||
ggml_cann_create_tensor(theta_buffer, ACL_FLOAT, sizeof(float), theta_ne, theta_nb, GGML_MAX_DIMS);
|
|
||||||
aclnn_mul(ctx, acl_position_tensor.get(), acl_theta_scale_tensor.get(), acl_theta_tensor.get());
|
|
||||||
|
|
||||||
// sin/cos
|
// sin/cos
|
||||||
ggml_cann_pool_alloc sin_allocator(ctx.pool(), theta_length * sizeof(float));
|
ggml_cann_pool_alloc sin_allocator(ctx.pool(), theta_length * sizeof(float));
|
||||||
void * sin_buffer = sin_allocator.get();
|
void * sin_buffer = sin_allocator.get();
|
||||||
acl_tensor_ptr acl_sin_tensor =
|
acl_tensor_ptr acl_sin_tensor =
|
||||||
ggml_cann_create_tensor(sin_buffer, ACL_FLOAT, sizeof(float), theta_ne, theta_nb, GGML_MAX_DIMS, ACL_FORMAT_ND);
|
ggml_cann_create_tensor(sin_buffer, ACL_FLOAT, sizeof(float), cache_ne, cache_nb, GGML_MAX_DIMS, ACL_FORMAT_ND);
|
||||||
aclnn_sin(ctx, acl_theta_tensor.get(), acl_sin_tensor.get());
|
aclnn_sin(ctx, acl_theta_tensor.get(), acl_sin_tensor.get());
|
||||||
|
|
||||||
ggml_cann_pool_alloc cos_allocator(ctx.pool(), theta_length * sizeof(float));
|
ggml_cann_pool_alloc cos_allocator(ctx.pool(), theta_length * sizeof(float));
|
||||||
void * cos_buffer = cos_allocator.get();
|
void * cos_buffer = cos_allocator.get();
|
||||||
acl_tensor_ptr acl_cos_tensor =
|
acl_tensor_ptr acl_cos_tensor =
|
||||||
ggml_cann_create_tensor(cos_buffer, ACL_FLOAT, sizeof(float), theta_ne, theta_nb, GGML_MAX_DIMS, ACL_FORMAT_ND);
|
ggml_cann_create_tensor(cos_buffer, ACL_FLOAT, sizeof(float), cache_ne, cache_nb, GGML_MAX_DIMS, ACL_FORMAT_ND);
|
||||||
aclnn_cos(ctx, acl_theta_tensor.get(), acl_cos_tensor.get());
|
aclnn_cos(ctx, acl_theta_tensor.get(), acl_cos_tensor.get());
|
||||||
|
|
||||||
if (ext_factor != 0) {
|
if (ext_factor != 0) {
|
||||||
attn_factor *= 1.0f + 0.1f * logf(1.0f / freq_scale);
|
attn_factor *= 1.0f + 0.1f * logf(1.0f / freq_scale);
|
||||||
}
|
}
|
||||||
|
|
||||||
// attn_factor
|
// Step 5: multiply by attn_factor
|
||||||
if (attn_factor != 1) {
|
if (attn_factor != 1) {
|
||||||
aclnn_muls(ctx, acl_sin_tensor.get(), attn_factor, nullptr, true);
|
aclnn_muls(ctx, acl_sin_tensor.get(), attn_factor, nullptr, true);
|
||||||
aclnn_muls(ctx, acl_cos_tensor.get(), attn_factor, nullptr, true);
|
aclnn_muls(ctx, acl_cos_tensor.get(), attn_factor, nullptr, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
int64_t sin_reshape_ne[4] = { src0->ne[0], 1, src0->ne[2], 1 };
|
int64_t sin_reshape_ne[4] = { src0->ne[0], 1, dst->ne[2], 1 };
|
||||||
size_t sin_reshape_nb[GGML_MAX_DIMS];
|
size_t sin_reshape_nb[GGML_MAX_DIMS];
|
||||||
sin_reshape_nb[0] = sizeof(float);
|
sin_reshape_nb[0] = sizeof(float);
|
||||||
for (int i = 1; i < GGML_MAX_DIMS; i++) {
|
for (int i = 1; i < GGML_MAX_DIMS; i++) {
|
||||||
|
|
@ -2430,8 +2601,9 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx,
|
||||||
acl_tensor_ptr acl_cos_repeat_tensor = ggml_cann_create_tensor(ctx.rope_cache.cos_cache, ACL_FLOAT, sizeof(float),
|
acl_tensor_ptr acl_cos_repeat_tensor = ggml_cann_create_tensor(ctx.rope_cache.cos_cache, ACL_FLOAT, sizeof(float),
|
||||||
sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
|
sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
|
||||||
|
|
||||||
// repeat
|
// Step 6: repeat
|
||||||
if (is_neox) {
|
if (is_neox) {
|
||||||
|
// [sinθ1, sinθ1, sinθ2, sinθ2, ..., sinθn, sinθn]
|
||||||
int64_t repeatsArray[] = { 1, 1, 1, 2 };
|
int64_t repeatsArray[] = { 1, 1, 1, 2 };
|
||||||
aclnn_repeat(ctx, acl_sin_tensor.get(), acl_sin_repeat_tensor.get(), repeatsArray);
|
aclnn_repeat(ctx, acl_sin_tensor.get(), acl_sin_repeat_tensor.get(), repeatsArray);
|
||||||
aclnn_repeat(ctx, acl_cos_tensor.get(), acl_cos_repeat_tensor.get(), repeatsArray);
|
aclnn_repeat(ctx, acl_cos_tensor.get(), acl_cos_repeat_tensor.get(), repeatsArray);
|
||||||
|
|
@ -2439,17 +2611,15 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx,
|
||||||
int64_t num_repeats = 2;
|
int64_t num_repeats = 2;
|
||||||
int64_t dim = 3;
|
int64_t dim = 3;
|
||||||
int64_t output_size = theta_scale_length * num_repeats;
|
int64_t output_size = theta_scale_length * num_repeats;
|
||||||
|
// [sinθ1, sinθ2, ..., sinθn, sinθ1, sinθ2, ..., sinθn]
|
||||||
aclnn_repeat_interleave(ctx, acl_sin_tensor.get(), acl_sin_repeat_tensor.get(), dim, num_repeats, output_size);
|
aclnn_repeat_interleave(ctx, acl_sin_tensor.get(), acl_sin_repeat_tensor.get(), dim, num_repeats, output_size);
|
||||||
aclnn_repeat_interleave(ctx, acl_cos_tensor.get(), acl_cos_repeat_tensor.get(), dim, num_repeats, output_size);
|
aclnn_repeat_interleave(ctx, acl_cos_tensor.get(), acl_cos_repeat_tensor.get(), dim, num_repeats, output_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Other layers use cache except first layer.
|
// Update cached value.
|
||||||
ctx.rope_cache.cached = true;
|
ctx.rope_cache.cached = true;
|
||||||
ctx.rope_cache.ext_factor = ext_factor;
|
ctx.rope_cache.set(theta_scale_length, position_length, ext_factor, theta_scale, freq_scale, attn_factor, is_neox,
|
||||||
ctx.rope_cache.theta_scale = theta_scale;
|
indep_sects, mrope_used, is_imrope, sections);
|
||||||
ctx.rope_cache.freq_scale = freq_scale;
|
|
||||||
ctx.rope_cache.attn_factor = attn_factor;
|
|
||||||
ctx.rope_cache.is_neox = is_neox;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
|
|
@ -2475,6 +2645,7 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
||||||
|
|
||||||
// param
|
// param
|
||||||
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
||||||
|
int sections[4];
|
||||||
// const int n_past = ((int32_t *) dst->op_params)[0];
|
// const int n_past = ((int32_t *) dst->op_params)[0];
|
||||||
const int n_dims = ((int32_t *) dst->op_params)[1];
|
const int n_dims = ((int32_t *) dst->op_params)[1];
|
||||||
const int mode = ((int32_t *) dst->op_params)[2];
|
const int mode = ((int32_t *) dst->op_params)[2];
|
||||||
|
|
@ -2483,12 +2654,13 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
||||||
|
|
||||||
GGML_TENSOR_UNARY_OP_LOCALS
|
GGML_TENSOR_UNARY_OP_LOCALS
|
||||||
|
|
||||||
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
||||||
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
|
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
|
||||||
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
|
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
|
||||||
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
|
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
|
||||||
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
|
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
|
||||||
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
||||||
|
memcpy(§ions, (int32_t *) dst->op_params + 11, sizeof(int)*4);
|
||||||
|
|
||||||
// TODO: n_dims <= ne0
|
// TODO: n_dims <= ne0
|
||||||
GGML_ASSERT(n_dims == ne0);
|
GGML_ASSERT(n_dims == ne0);
|
||||||
|
|
@ -2499,10 +2671,25 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
||||||
float corr_dims[2];
|
float corr_dims[2];
|
||||||
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
|
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
|
||||||
|
|
||||||
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
|
bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
|
||||||
|
const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE; // qwen3vl apply interleaved mrope
|
||||||
|
const bool mrope_used = mode & GGML_ROPE_TYPE_MROPE; // ggml_rope_multi, note: also true for vision (24 & 8 == true) and for imrope
|
||||||
|
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
|
||||||
|
|
||||||
|
if (mrope_used) {
|
||||||
|
GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (is_vision) {
|
||||||
|
GGML_ASSERT(n_dims == ne0/2);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (is_imrope || mrope_used) {
|
||||||
|
is_neox = true;
|
||||||
|
}
|
||||||
|
|
||||||
// init ctx.rope_cos/rope_sin cache
|
// init ctx.rope_cos/rope_sin cache
|
||||||
aclnn_cache_init(ctx, dst, corr_dims, ext_factor, theta_scale, freq_scale, attn_factor, is_neox);
|
aclnn_rope_cache_init(ctx, dst, corr_dims, ext_factor, theta_scale, freq_scale, attn_factor, is_neox, sections, mrope_used, is_imrope, is_vision);
|
||||||
|
|
||||||
int64_t sin_reshape_ne[4] = { ne00, 1, ne02, 1 };
|
int64_t sin_reshape_ne[4] = { ne00, 1, ne02, 1 };
|
||||||
size_t sin_reshape_nb[GGML_MAX_DIMS];
|
size_t sin_reshape_nb[GGML_MAX_DIMS];
|
||||||
|
|
@ -2658,8 +2845,7 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
||||||
return;
|
return;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
// ggml_mode = 0 --> aclnn_model = 1
|
int64_t acl_mode = is_neox ? 0 : 1;
|
||||||
int64_t acl_mode = mode == 0 ? 1 : mode;
|
|
||||||
|
|
||||||
switch (src0->type) {
|
switch (src0->type) {
|
||||||
case GGML_TYPE_F32:
|
case GGML_TYPE_F32:
|
||||||
|
|
|
||||||
|
|
@ -300,30 +300,92 @@ struct ggml_cann_graph_lru_cache {
|
||||||
|
|
||||||
struct ggml_cann_rope_cache {
|
struct ggml_cann_rope_cache {
|
||||||
~ggml_cann_rope_cache() {
|
~ggml_cann_rope_cache() {
|
||||||
if (theta_scale_cache != nullptr) {
|
if (theta_scale_cache) {
|
||||||
ACL_CHECK(aclrtFree(theta_scale_cache));
|
ACL_CHECK(aclrtFree(theta_scale_cache));
|
||||||
}
|
}
|
||||||
if (sin_cache != nullptr) {
|
if (sin_cache) {
|
||||||
ACL_CHECK(aclrtFree(sin_cache));
|
ACL_CHECK(aclrtFree(sin_cache));
|
||||||
}
|
}
|
||||||
if (cos_cache != nullptr) {
|
if (cos_cache) {
|
||||||
ACL_CHECK(aclrtFree(cos_cache));
|
ACL_CHECK(aclrtFree(cos_cache));
|
||||||
}
|
}
|
||||||
|
if (position_select_index) {
|
||||||
|
ACL_CHECK(aclrtFree(position_select_index));
|
||||||
|
}
|
||||||
|
if (theta_scale_exp_host) {
|
||||||
|
free(theta_scale_exp_host);
|
||||||
|
}
|
||||||
|
if(position_select_index_host) {
|
||||||
|
free(position_select_index_host);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void * theta_scale_cache = nullptr;
|
bool equal(int64_t theta_scale_length,
|
||||||
int64_t theta_scale_length = 0;
|
int64_t position_length,
|
||||||
|
float ext_factor,
|
||||||
|
float theta_scale,
|
||||||
|
float freq_scale,
|
||||||
|
float attn_factor,
|
||||||
|
bool is_neox,
|
||||||
|
bool indep_sects,
|
||||||
|
bool mrope_used,
|
||||||
|
bool is_imrope,
|
||||||
|
int sections[4]) {
|
||||||
|
return this->theta_scale_length == theta_scale_length && this->position_length == position_length &&
|
||||||
|
this->ext_factor == ext_factor && this->theta_scale == theta_scale && this->freq_scale == freq_scale &&
|
||||||
|
this->attn_factor == attn_factor && this->is_neox == is_neox && this->indep_sects == indep_sects &&
|
||||||
|
this->mrope_used == mrope_used && this->is_imrope == is_imrope && this->sections[0] == sections[0] &&
|
||||||
|
this->sections[1] == sections[1] && this->sections[2] == sections[2] && this->sections[3] == sections[3];
|
||||||
|
}
|
||||||
|
|
||||||
|
void set(int64_t theta_scale_length,
|
||||||
|
int64_t position_length,
|
||||||
|
float ext_factor,
|
||||||
|
float theta_scale,
|
||||||
|
float freq_scale,
|
||||||
|
float attn_factor,
|
||||||
|
bool is_neox,
|
||||||
|
bool indep_sects,
|
||||||
|
bool mrope_used,
|
||||||
|
bool is_imrope,
|
||||||
|
int sections[4]) {
|
||||||
|
this->theta_scale_length = theta_scale_length;
|
||||||
|
this->position_length = position_length;
|
||||||
|
this->ext_factor = ext_factor;
|
||||||
|
this->theta_scale = theta_scale;
|
||||||
|
this->freq_scale = freq_scale;
|
||||||
|
this->attn_factor = attn_factor;
|
||||||
|
this->is_neox = is_neox;
|
||||||
|
this->indep_sects = indep_sects;
|
||||||
|
this->mrope_used = mrope_used;
|
||||||
|
this->is_imrope = is_imrope;
|
||||||
|
this->sections[0] = sections[0];
|
||||||
|
this->sections[1] = sections[1];
|
||||||
|
this->sections[2] = sections[2];
|
||||||
|
this->sections[3] = sections[3];
|
||||||
|
}
|
||||||
|
|
||||||
|
// memory cache, prepare before inferencing.
|
||||||
|
void * theta_scale_cache = nullptr;
|
||||||
|
float * theta_scale_exp_host = nullptr;
|
||||||
|
int * position_select_index_host = nullptr;
|
||||||
|
void * position_select_index = nullptr;
|
||||||
// sin/cos cache, used only to accelerate first layer on each device
|
// sin/cos cache, used only to accelerate first layer on each device
|
||||||
void * sin_cache = nullptr;
|
void * sin_cache = nullptr;
|
||||||
void * cos_cache = nullptr;
|
void * cos_cache = nullptr;
|
||||||
int64_t position_length = 0;
|
|
||||||
// Properties to check before reusing the sincos cache
|
// Properties to check before reusing the sincos cache
|
||||||
bool cached = false;
|
int64_t theta_scale_length = 0;
|
||||||
float ext_factor = 0.0f;
|
int64_t position_length = 0;
|
||||||
float theta_scale = 0.0f;
|
bool cached = false;
|
||||||
float freq_scale = 0.0f;
|
float ext_factor = 0.0f;
|
||||||
float attn_factor = 0.0f;
|
float theta_scale = 0.0f;
|
||||||
bool is_neox = false;
|
float freq_scale = 0.0f;
|
||||||
|
float attn_factor = 0.0f;
|
||||||
|
bool is_neox = false;
|
||||||
|
bool indep_sects = false;
|
||||||
|
bool mrope_used = false;
|
||||||
|
int sections[4] = { 0, 0, 0, 0 };
|
||||||
|
bool is_imrope = false;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct ggml_cann_tensor_cache {
|
struct ggml_cann_tensor_cache {
|
||||||
|
|
|
||||||
|
|
@ -2480,13 +2480,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
const int mode = ((const int32_t *) op->op_params)[2];
|
|
||||||
if (mode & GGML_ROPE_TYPE_MROPE) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
if (mode & GGML_ROPE_TYPE_VISION) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
if (op->src[0]->ne[0] > 896) {
|
if (op->src[0]->ne[0] > 896) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -224,7 +224,8 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||||
|
|
||||||
include(CheckCXXSourceCompiles)
|
include(CheckCXXSourceCompiles)
|
||||||
set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS})
|
set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS})
|
||||||
set(CMAKE_REQUIRED_FLAGS "${ARCH_FLAGS}")
|
string(REPLACE ";" " " ARCH_FLAGS_STR "${ARCH_FLAGS}")
|
||||||
|
set(CMAKE_REQUIRED_FLAGS "${ARCH_FLAGS_STR}")
|
||||||
foreach(feature DOTPROD SVE MATMUL_INT8 FMA FP16_VECTOR_ARITHMETIC SME)
|
foreach(feature DOTPROD SVE MATMUL_INT8 FMA FP16_VECTOR_ARITHMETIC SME)
|
||||||
set(ARM_FEATURE "HAVE_${feature}")
|
set(ARM_FEATURE "HAVE_${feature}")
|
||||||
check_cxx_source_compiles(
|
check_cxx_source_compiles(
|
||||||
|
|
|
||||||
|
|
@ -1927,6 +1927,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
||||||
{
|
{
|
||||||
ggml_compute_forward_argsort(params, tensor);
|
ggml_compute_forward_argsort(params, tensor);
|
||||||
} break;
|
} break;
|
||||||
|
case GGML_OP_TOP_K:
|
||||||
|
{
|
||||||
|
ggml_compute_forward_top_k(params, tensor);
|
||||||
|
} break;
|
||||||
case GGML_OP_LEAKY_RELU:
|
case GGML_OP_LEAKY_RELU:
|
||||||
{
|
{
|
||||||
ggml_compute_forward_leaky_relu(params, tensor);
|
ggml_compute_forward_leaky_relu(params, tensor);
|
||||||
|
|
@ -2311,6 +2315,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
||||||
case GGML_OP_ARANGE:
|
case GGML_OP_ARANGE:
|
||||||
case GGML_OP_TIMESTEP_EMBEDDING:
|
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||||
case GGML_OP_ARGSORT:
|
case GGML_OP_ARGSORT:
|
||||||
|
case GGML_OP_TOP_K:
|
||||||
case GGML_OP_FLASH_ATTN_EXT:
|
case GGML_OP_FLASH_ATTN_EXT:
|
||||||
case GGML_OP_FLASH_ATTN_BACK:
|
case GGML_OP_FLASH_ATTN_BACK:
|
||||||
case GGML_OP_SSM_CONV:
|
case GGML_OP_SSM_CONV:
|
||||||
|
|
@ -2834,6 +2839,10 @@ struct ggml_cplan ggml_graph_plan(
|
||||||
cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02*ne03;
|
cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02*ne03;
|
||||||
cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12;
|
cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12;
|
||||||
} break;
|
} break;
|
||||||
|
case GGML_OP_TOP_K:
|
||||||
|
{
|
||||||
|
cur += sizeof(int32_t)*node->src[0]->ne[0]*n_tasks;
|
||||||
|
} break;
|
||||||
case GGML_OP_FLASH_ATTN_EXT:
|
case GGML_OP_FLASH_ATTN_EXT:
|
||||||
{
|
{
|
||||||
const int64_t ne10 = node->src[1]->ne[0]; // DK
|
const int64_t ne10 = node->src[1]->ne[0]; // DK
|
||||||
|
|
|
||||||
|
|
@ -7794,7 +7794,7 @@ void ggml_compute_forward_timestep_embedding(
|
||||||
// ggml_compute_forward_argsort
|
// ggml_compute_forward_argsort
|
||||||
|
|
||||||
template<enum ggml_sort_order order>
|
template<enum ggml_sort_order order>
|
||||||
struct argsort_cmp {
|
struct cmp_argsort {
|
||||||
const float * data;
|
const float * data;
|
||||||
bool operator()(int32_t a, int32_t b) const {
|
bool operator()(int32_t a, int32_t b) const {
|
||||||
if constexpr (order == GGML_SORT_ORDER_ASC) {
|
if constexpr (order == GGML_SORT_ORDER_ASC) {
|
||||||
|
|
@ -7833,11 +7833,11 @@ static void ggml_compute_forward_argsort_f32(
|
||||||
|
|
||||||
switch (order) {
|
switch (order) {
|
||||||
case GGML_SORT_ORDER_ASC:
|
case GGML_SORT_ORDER_ASC:
|
||||||
std::sort(dst_data, dst_data + ne0, argsort_cmp<GGML_SORT_ORDER_ASC>{src_data});
|
std::sort(dst_data, dst_data + ne0, cmp_argsort<GGML_SORT_ORDER_ASC>{src_data});
|
||||||
break;
|
break;
|
||||||
|
|
||||||
case GGML_SORT_ORDER_DESC:
|
case GGML_SORT_ORDER_DESC:
|
||||||
std::sort(dst_data, dst_data + ne0, argsort_cmp<GGML_SORT_ORDER_DESC>{src_data});
|
std::sort(dst_data, dst_data + ne0, cmp_argsort<GGML_SORT_ORDER_DESC>{src_data});
|
||||||
break;
|
break;
|
||||||
|
|
||||||
default:
|
default:
|
||||||
|
|
@ -7864,6 +7864,72 @@ void ggml_compute_forward_argsort(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ggml_compute_forward_top_k
|
||||||
|
|
||||||
|
struct cmp_top_k {
|
||||||
|
const float * data;
|
||||||
|
bool operator()(int32_t a, int32_t b) const {
|
||||||
|
return data[a] > data[b];
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
static void ggml_compute_forward_top_k_f32(
|
||||||
|
const ggml_compute_params * params,
|
||||||
|
ggml_tensor * dst) {
|
||||||
|
|
||||||
|
const ggml_tensor * src0 = dst->src[0];
|
||||||
|
|
||||||
|
GGML_TENSOR_UNARY_OP_LOCALS
|
||||||
|
|
||||||
|
GGML_ASSERT(nb0 == sizeof(float));
|
||||||
|
|
||||||
|
const int ith = params->ith;
|
||||||
|
const int nth = params->nth;
|
||||||
|
|
||||||
|
const int64_t nr = ggml_nrows(src0);
|
||||||
|
|
||||||
|
const int top_k = ne0;
|
||||||
|
|
||||||
|
int32_t * tmp = (int32_t *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
|
||||||
|
|
||||||
|
for (int64_t i = ith; i < nr; i += nth) {
|
||||||
|
const float * src_data = (float *)((char *) src0->data + i*nb01);
|
||||||
|
|
||||||
|
for (int64_t j = 0; j < ne00; j++) {
|
||||||
|
tmp[j] = j;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::partial_sort(tmp, tmp + top_k, tmp + ne00, cmp_top_k{src_data});
|
||||||
|
|
||||||
|
int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
|
||||||
|
|
||||||
|
std::copy(tmp, tmp + top_k, dst_data);
|
||||||
|
|
||||||
|
// emphasize that the order is not important
|
||||||
|
if (top_k > 1) {
|
||||||
|
std::swap(dst_data[0], dst_data[1]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void ggml_compute_forward_top_k(
|
||||||
|
const ggml_compute_params * params,
|
||||||
|
ggml_tensor * dst) {
|
||||||
|
|
||||||
|
const ggml_tensor * src0 = dst->src[0];
|
||||||
|
|
||||||
|
switch (src0->type) {
|
||||||
|
case GGML_TYPE_F32:
|
||||||
|
{
|
||||||
|
ggml_compute_forward_top_k_f32(params, dst);
|
||||||
|
} break;
|
||||||
|
default:
|
||||||
|
{
|
||||||
|
GGML_ABORT("fatal error");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ggml_compute_forward_flash_attn_ext
|
// ggml_compute_forward_flash_attn_ext
|
||||||
|
|
||||||
static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
|
static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
|
||||||
|
|
|
||||||
|
|
@ -81,6 +81,7 @@ void ggml_compute_forward_roll(const struct ggml_compute_params * params, struct
|
||||||
void ggml_compute_forward_arange(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
void ggml_compute_forward_arange(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||||
void ggml_compute_forward_timestep_embedding(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
void ggml_compute_forward_timestep_embedding(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||||
void ggml_compute_forward_argsort(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
void ggml_compute_forward_argsort(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||||
|
void ggml_compute_forward_top_k(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||||
void ggml_compute_forward_leaky_relu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
void ggml_compute_forward_leaky_relu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||||
void ggml_compute_forward_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
void ggml_compute_forward_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||||
void ggml_compute_forward_fill(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
void ggml_compute_forward_fill(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||||
|
|
|
||||||
|
|
@ -455,146 +455,141 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const
|
||||||
}
|
}
|
||||||
|
|
||||||
inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y, const ggml_fp16_t * GGML_RESTRICT x, const float v) {
|
inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y, const ggml_fp16_t * GGML_RESTRICT x, const float v) {
|
||||||
#if defined(GGML_SIMD)
|
#if defined(GGML_SIMD) && defined(__ARM_FEATURE_SVE)
|
||||||
#if defined(__ARM_FEATURE_SVE)
|
const int sve_register_length = svcntb() * 8;
|
||||||
const int sve_register_length = svcntb() * 8;
|
const int ggml_f16_epr = sve_register_length / 16;
|
||||||
const int ggml_f16_epr = sve_register_length / 16;
|
const int ggml_f16_step = 8 * ggml_f16_epr;
|
||||||
const int ggml_f16_step = 8 * ggml_f16_epr;
|
|
||||||
|
|
||||||
GGML_F16x_VEC vx = GGML_F16x_VEC_SET1(v);
|
GGML_F16x_VEC vx = GGML_F16x_VEC_SET1(v);
|
||||||
|
|
||||||
const int np= (n & ~(ggml_f16_step - 1));
|
int np = (n & ~(ggml_f16_step - 1));
|
||||||
|
|
||||||
svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8;
|
svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8;
|
||||||
svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8;
|
svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8;
|
||||||
for (int i = 0; i < np; i += ggml_f16_step) {
|
for (int i = 0; i < np; i += ggml_f16_step) {
|
||||||
ax1 = GGML_F16x_VEC_LOAD(x + i + 0 * ggml_f16_epr, 0);
|
ax1 = GGML_F16x_VEC_LOAD(x + i + 0 * ggml_f16_epr, 0);
|
||||||
ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0);
|
ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0);
|
||||||
ay1 = GGML_F16x_VEC_FMA(ay1, ax1, vx);
|
ay1 = GGML_F16x_VEC_FMA(ay1, ax1, vx);
|
||||||
|
|
||||||
GGML_F16x_VEC_STORE(y + i + 0 * ggml_f16_epr, ay1, 0);
|
GGML_F16x_VEC_STORE(y + i + 0 * ggml_f16_epr, ay1, 0);
|
||||||
|
|
||||||
ax2 = GGML_F16x_VEC_LOAD(x + i + 1 * ggml_f16_epr, 1);
|
ax2 = GGML_F16x_VEC_LOAD(x + i + 1 * ggml_f16_epr, 1);
|
||||||
ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1);
|
ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1);
|
||||||
ay2 = GGML_F16x_VEC_FMA(ay2, ax2, vx);
|
ay2 = GGML_F16x_VEC_FMA(ay2, ax2, vx);
|
||||||
|
|
||||||
GGML_F16x_VEC_STORE(y + i + 1 * ggml_f16_epr, ay2, 1);
|
GGML_F16x_VEC_STORE(y + i + 1 * ggml_f16_epr, ay2, 1);
|
||||||
|
|
||||||
ax3 = GGML_F16x_VEC_LOAD(x + i + 2 * ggml_f16_epr, 2);
|
ax3 = GGML_F16x_VEC_LOAD(x + i + 2 * ggml_f16_epr, 2);
|
||||||
ay3 = GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2);
|
ay3 = GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2);
|
||||||
ay3 = GGML_F16x_VEC_FMA(ay3, ax3, vx);
|
ay3 = GGML_F16x_VEC_FMA(ay3, ax3, vx);
|
||||||
|
|
||||||
GGML_F16x_VEC_STORE(y + i + 2 * ggml_f16_epr, ay3, 2);
|
GGML_F16x_VEC_STORE(y + i + 2 * ggml_f16_epr, ay3, 2);
|
||||||
|
|
||||||
ax4 = GGML_F16x_VEC_LOAD(x + i + 3 * ggml_f16_epr, 3);
|
ax4 = GGML_F16x_VEC_LOAD(x + i + 3 * ggml_f16_epr, 3);
|
||||||
ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3);
|
ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3);
|
||||||
ay4 = GGML_F16x_VEC_FMA(ay4, ax4, vx);
|
ay4 = GGML_F16x_VEC_FMA(ay4, ax4, vx);
|
||||||
|
|
||||||
GGML_F16x_VEC_STORE(y + i + 3 * ggml_f16_epr, ay4, 3);
|
GGML_F16x_VEC_STORE(y + i + 3 * ggml_f16_epr, ay4, 3);
|
||||||
|
|
||||||
ax5 = GGML_F16x_VEC_LOAD(x + i + 4 * ggml_f16_epr, 4);
|
ax5 = GGML_F16x_VEC_LOAD(x + i + 4 * ggml_f16_epr, 4);
|
||||||
ay5 = GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4);
|
ay5 = GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4);
|
||||||
ay5 = GGML_F16x_VEC_FMA(ay5, ax5, vx);
|
ay5 = GGML_F16x_VEC_FMA(ay5, ax5, vx);
|
||||||
|
|
||||||
GGML_F16x_VEC_STORE(y + i + 4 * ggml_f16_epr, ay5, 4);
|
GGML_F16x_VEC_STORE(y + i + 4 * ggml_f16_epr, ay5, 4);
|
||||||
|
|
||||||
ax6 = GGML_F16x_VEC_LOAD(x + i + 5 * ggml_f16_epr, 5);
|
ax6 = GGML_F16x_VEC_LOAD(x + i + 5 * ggml_f16_epr, 5);
|
||||||
ay6 = GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5);
|
ay6 = GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5);
|
||||||
ay6 = GGML_F16x_VEC_FMA(ay6, ax6, vx);
|
ay6 = GGML_F16x_VEC_FMA(ay6, ax6, vx);
|
||||||
|
|
||||||
GGML_F16x_VEC_STORE(y + i + 5 * ggml_f16_epr, ay6, 5);
|
GGML_F16x_VEC_STORE(y + i + 5 * ggml_f16_epr, ay6, 5);
|
||||||
|
|
||||||
ax7 = GGML_F16x_VEC_LOAD(x + i + 6 * ggml_f16_epr, 6);
|
ax7 = GGML_F16x_VEC_LOAD(x + i + 6 * ggml_f16_epr, 6);
|
||||||
ay7 = GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6);
|
ay7 = GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6);
|
||||||
ay7 = GGML_F16x_VEC_FMA(ay7, ax7, vx);
|
ay7 = GGML_F16x_VEC_FMA(ay7, ax7, vx);
|
||||||
|
|
||||||
GGML_F16x_VEC_STORE(y + i + 6 * ggml_f16_epr, ay7, 6);
|
GGML_F16x_VEC_STORE(y + i + 6 * ggml_f16_epr, ay7, 6);
|
||||||
|
|
||||||
ax8 = GGML_F16x_VEC_LOAD(x + i + 7 * ggml_f16_epr, 7);
|
ax8 = GGML_F16x_VEC_LOAD(x + i + 7 * ggml_f16_epr, 7);
|
||||||
ay8 = GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7);
|
ay8 = GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7);
|
||||||
ay8 = GGML_F16x_VEC_FMA(ay8, ax8, vx);
|
ay8 = GGML_F16x_VEC_FMA(ay8, ax8, vx);
|
||||||
|
|
||||||
GGML_F16x_VEC_STORE(y + i + 7 * ggml_f16_epr, ay8, 7);
|
GGML_F16x_VEC_STORE(y + i + 7 * ggml_f16_epr, ay8, 7);
|
||||||
|
}
|
||||||
|
const int np2 = (n & ~(ggml_f16_epr - 1));
|
||||||
|
for (int k = np; k < np2; k += ggml_f16_epr) {
|
||||||
|
svfloat16_t rx = GGML_F16x_VEC_LOAD(x + k, 0);
|
||||||
|
svfloat16_t ry = GGML_F16x_VEC_LOAD(y + k, 0);
|
||||||
|
ry = GGML_F16x_VEC_FMA(ry, rx, vx);
|
||||||
|
|
||||||
|
GGML_F16x_VEC_STORE(y + k, ry, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (np2 < n) {
|
||||||
|
svbool_t pg = svwhilelt_b16(np2, n);
|
||||||
|
svfloat16_t hx = svld1_f16(pg, (const __fp16 *)(x + np2));
|
||||||
|
svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2));
|
||||||
|
hy = svmad_f16_x(pg, hx, vx, hy);
|
||||||
|
svst1_f16(pg, (__fp16 *)(y + np2), hy);
|
||||||
|
}
|
||||||
|
np = n;
|
||||||
|
#elif defined(__riscv_zvfh) // implies __riscv_v_intrinsic
|
||||||
|
const ggml_fp16_t s = GGML_CPU_FP32_TO_FP16(v);
|
||||||
|
const _Float16 scale = *(const _Float16*)(&s);
|
||||||
|
|
||||||
|
// calculate step size
|
||||||
|
const int epr = __riscv_vsetvlmax_e16m4();
|
||||||
|
const int step = epr * 2;
|
||||||
|
const int np = (n & ~(step - 1));
|
||||||
|
|
||||||
|
// unroll by 2
|
||||||
|
for (int i = 0; i < np; i += step) {
|
||||||
|
vfloat16m4_t ax0 = __riscv_vle16_v_f16m4((const _Float16*)x + i, epr);
|
||||||
|
vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, epr);
|
||||||
|
ay0 = __riscv_vfmacc_vf_f16m4(ay0, scale, ax0, epr);
|
||||||
|
__riscv_vse16_v_f16m4((_Float16*)y + i, ay0, epr);
|
||||||
|
__asm__ __volatile__ ("" ::: "memory");
|
||||||
|
|
||||||
|
vfloat16m4_t ax1 = __riscv_vle16_v_f16m4((const _Float16*)x + i + epr, epr);
|
||||||
|
vfloat16m4_t ay1 = __riscv_vle16_v_f16m4((const _Float16*)y + i + epr, epr);
|
||||||
|
ay1 = __riscv_vfmacc_vf_f16m4(ay1, scale, ax1, epr);
|
||||||
|
__riscv_vse16_v_f16m4((_Float16*)y + i + epr, ay1, epr);
|
||||||
|
__asm__ __volatile__ ("" ::: "memory");
|
||||||
|
}
|
||||||
|
|
||||||
|
// leftovers
|
||||||
|
int vl;
|
||||||
|
for (int i = np; i < n; i += vl) {
|
||||||
|
vl = __riscv_vsetvl_e16m4(n - i);
|
||||||
|
vfloat16m4_t ax0 = __riscv_vle16_v_f16m4((const _Float16*)x + i , vl);
|
||||||
|
vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, vl);
|
||||||
|
ay0 = __riscv_vfmacc_vf_f16m4(ay0, scale, ax0, vl);
|
||||||
|
__riscv_vse16_v_f16m4((_Float16*)y + i, ay0, vl);
|
||||||
|
}
|
||||||
|
#elif defined(GGML_SIMD)
|
||||||
|
const int np = (n & ~(GGML_F16_STEP - 1));
|
||||||
|
|
||||||
|
GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
|
||||||
|
|
||||||
|
GGML_F16_VEC ax[GGML_F16_ARR];
|
||||||
|
GGML_F16_VEC ay[GGML_F16_ARR];
|
||||||
|
|
||||||
|
for (int i = 0; i < np; i += GGML_F16_STEP) {
|
||||||
|
for (int j = 0; j < GGML_F16_ARR; j++) {
|
||||||
|
ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
|
||||||
|
ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
|
||||||
|
ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx);
|
||||||
|
|
||||||
|
GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
|
||||||
}
|
}
|
||||||
const int np2 = (n & ~(ggml_f16_epr - 1));
|
}
|
||||||
for (int k = np; k < np2; k += ggml_f16_epr) {
|
|
||||||
svfloat16_t rx = GGML_F16x_VEC_LOAD(x + k, 0);
|
|
||||||
svfloat16_t ry = GGML_F16x_VEC_LOAD(y + k, 0);
|
|
||||||
ry = GGML_F16x_VEC_FMA(ry, rx, vx);
|
|
||||||
|
|
||||||
GGML_F16x_VEC_STORE(y + k, ry, 0);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (np2 < n) {
|
|
||||||
svbool_t pg = svwhilelt_b16(np2, n);
|
|
||||||
svfloat16_t hx = svld1_f16(pg, (const __fp16 *)(x + np2));
|
|
||||||
svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2));
|
|
||||||
hy = svmad_f16_x(pg, hx, vx, hy);
|
|
||||||
svst1_f16(pg, (__fp16 *)(y + np2), hy);
|
|
||||||
}
|
|
||||||
|
|
||||||
#elif defined(__riscv_v_intrinsic) && defined(__riscv_zvfh)
|
|
||||||
const ggml_fp16_t s = GGML_CPU_FP32_TO_FP16(v);
|
|
||||||
const _Float16 scale = *(const _Float16*)(&s);
|
|
||||||
|
|
||||||
// calculate step size
|
|
||||||
const int epr = __riscv_vsetvlmax_e16m4();
|
|
||||||
const int step = epr * 2;
|
|
||||||
const int np = (n & ~(step - 1));
|
|
||||||
|
|
||||||
// unroll by 2
|
|
||||||
for (int i = 0; i < np; i += step) {
|
|
||||||
vfloat16m4_t ax0 = __riscv_vle16_v_f16m4((const _Float16*)x + i, epr);
|
|
||||||
vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, epr);
|
|
||||||
ay0 = __riscv_vfmacc_vf_f16m4(ay0, scale, ax0, epr);
|
|
||||||
__riscv_vse16_v_f16m4((_Float16*)y + i, ay0, epr);
|
|
||||||
__asm__ __volatile__ ("" ::: "memory");
|
|
||||||
|
|
||||||
vfloat16m4_t ax1 = __riscv_vle16_v_f16m4((const _Float16*)x + i + epr, epr);
|
|
||||||
vfloat16m4_t ay1 = __riscv_vle16_v_f16m4((const _Float16*)y + i + epr, epr);
|
|
||||||
ay1 = __riscv_vfmacc_vf_f16m4(ay1, scale, ax1, epr);
|
|
||||||
__riscv_vse16_v_f16m4((_Float16*)y + i + epr, ay1, epr);
|
|
||||||
__asm__ __volatile__ ("" ::: "memory");
|
|
||||||
}
|
|
||||||
|
|
||||||
// leftovers
|
|
||||||
int vl;
|
|
||||||
for (int i = np; i < n; i += vl) {
|
|
||||||
vl = __riscv_vsetvl_e16m4(n - i);
|
|
||||||
vfloat16m4_t ax0 = __riscv_vle16_v_f16m4((const _Float16*)x + i , vl);
|
|
||||||
vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, vl);
|
|
||||||
ay0 = __riscv_vfmacc_vf_f16m4(ay0, scale, ax0, vl);
|
|
||||||
__riscv_vse16_v_f16m4((_Float16*)y + i, ay0, vl);
|
|
||||||
}
|
|
||||||
#else
|
|
||||||
const int np = (n & ~(GGML_F16_STEP - 1));
|
|
||||||
|
|
||||||
GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
|
|
||||||
|
|
||||||
GGML_F16_VEC ax[GGML_F16_ARR];
|
|
||||||
GGML_F16_VEC ay[GGML_F16_ARR];
|
|
||||||
|
|
||||||
for (int i = 0; i < np; i += GGML_F16_STEP) {
|
|
||||||
for (int j = 0; j < GGML_F16_ARR; j++) {
|
|
||||||
ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
|
|
||||||
ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
|
|
||||||
ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx);
|
|
||||||
|
|
||||||
GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// leftovers
|
|
||||||
for (int i = np; i < n; ++i) {
|
|
||||||
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v);
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
#else
|
#else
|
||||||
// scalar
|
const int np = 0;
|
||||||
for (int i = 0; i < n; ++i) {
|
#endif
|
||||||
|
|
||||||
|
// leftovers
|
||||||
|
for (int i = np; i < n; ++i) {
|
||||||
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v);
|
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v);
|
||||||
}
|
}
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// xs and vs are byte strides of x and v
|
// xs and vs are byte strides of x and v
|
||||||
|
|
|
||||||
|
|
@ -437,18 +437,27 @@ namespace ggml_cuda_mma {
|
||||||
xi[0] = xs[0];
|
xi[0] = xs[0];
|
||||||
}
|
}
|
||||||
#elif defined(AMD_WMMA_AVAILABLE)
|
#elif defined(AMD_WMMA_AVAILABLE)
|
||||||
if constexpr (I == 16 && J == 4) {
|
if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
|
||||||
int64_t * xi = (int64_t *) t.x;
|
ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
|
||||||
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
|
|
||||||
xi[0] = xs[0];
|
|
||||||
}else if constexpr (I == 16 && J == 8) {
|
|
||||||
int64_t * xi = (int64_t *) t.x;
|
|
||||||
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I));
|
|
||||||
xi[0] = xs[0];
|
|
||||||
|
|
||||||
const int64_t * xs1 = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I) + 2);
|
} else if constexpr (std::is_same_v<T, int>) {
|
||||||
xi[1] = xs1[0];
|
if constexpr (I == 16 && J == 4) {
|
||||||
}else{
|
int64_t * xi = (int64_t *) t.x;
|
||||||
|
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
|
||||||
|
xi[0] = xs[0];
|
||||||
|
|
||||||
|
}else if constexpr (I == 16 && J == 8) {
|
||||||
|
int64_t * xi = (int64_t *) t.x;
|
||||||
|
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I));
|
||||||
|
xi[0] = xs[0];
|
||||||
|
|
||||||
|
const int64_t * xs1 = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I) + 2);
|
||||||
|
xi[1] = xs1[0];
|
||||||
|
|
||||||
|
}else{
|
||||||
|
NO_DEVICE_CODE;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
NO_DEVICE_CODE;
|
NO_DEVICE_CODE;
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
|
|
|
||||||
|
|
@ -3701,7 +3701,7 @@ static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int
|
||||||
const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y);
|
const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y);
|
||||||
const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type);
|
const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type);
|
||||||
const size_t nbs_ids = mmq_x*sizeof(int);
|
const size_t nbs_ids = mmq_x*sizeof(int);
|
||||||
const size_t nbs_x = (turing_mma_available(cc) || amd_mfma_available(cc)) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
|
const size_t nbs_x = (turing_mma_available(cc) || amd_mfma_available(cc) || amd_wmma_available(cc)) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
|
||||||
const size_t nbs_y = mmq_x*sizeof(block_q8_1_mmq);
|
const size_t nbs_y = mmq_x*sizeof(block_q8_1_mmq);
|
||||||
return nbs_ids + nbs_x + GGML_PAD(nbs_y, nwarps*warp_size*sizeof(int));
|
return nbs_ids + nbs_x + GGML_PAD(nbs_y, nwarps*warp_size*sizeof(int));
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1009,6 +1009,64 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort_merge(ggml_metal_l
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// note: reuse the argsort kernel for top_k
|
||||||
|
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||||
|
assert(op->op == GGML_OP_TOP_K);
|
||||||
|
|
||||||
|
char base[256];
|
||||||
|
char name[256];
|
||||||
|
|
||||||
|
// note: the top_k kernel is always descending order
|
||||||
|
ggml_sort_order order = GGML_SORT_ORDER_DESC;
|
||||||
|
|
||||||
|
const char * order_str = "undefined";
|
||||||
|
switch (order) {
|
||||||
|
case GGML_SORT_ORDER_ASC: order_str = "asc"; break;
|
||||||
|
case GGML_SORT_ORDER_DESC: order_str = "desc"; break;
|
||||||
|
default: GGML_ABORT("fatal error");
|
||||||
|
};
|
||||||
|
|
||||||
|
snprintf(base, 256, "kernel_argsort_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str);
|
||||||
|
snprintf(name, 256, "%s", base);
|
||||||
|
|
||||||
|
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
|
||||||
|
if (res) {
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
||||||
|
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k_merge(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||||
|
assert(op->op == GGML_OP_TOP_K);
|
||||||
|
|
||||||
|
char base[256];
|
||||||
|
char name[256];
|
||||||
|
|
||||||
|
ggml_sort_order order = GGML_SORT_ORDER_DESC;
|
||||||
|
|
||||||
|
const char * order_str = "undefined";
|
||||||
|
switch (order) {
|
||||||
|
case GGML_SORT_ORDER_ASC: order_str = "asc"; break;
|
||||||
|
case GGML_SORT_ORDER_DESC: order_str = "desc"; break;
|
||||||
|
default: GGML_ABORT("fatal error");
|
||||||
|
};
|
||||||
|
|
||||||
|
snprintf(base, 256, "kernel_argsort_merge_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str);
|
||||||
|
snprintf(name, 256, "%s", base);
|
||||||
|
|
||||||
|
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
|
||||||
|
if (res) {
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
||||||
|
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
|
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
|
||||||
ggml_metal_library_t lib,
|
ggml_metal_library_t lib,
|
||||||
const struct ggml_tensor * op,
|
const struct ggml_tensor * op,
|
||||||
|
|
|
||||||
|
|
@ -128,6 +128,8 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id (ggml_me
|
||||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||||
|
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||||
|
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row);
|
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row);
|
||||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||||
|
|
|
||||||
|
|
@ -905,6 +905,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||||
case GGML_OP_LEAKY_RELU:
|
case GGML_OP_LEAKY_RELU:
|
||||||
return op->src[0]->type == GGML_TYPE_F32;
|
return op->src[0]->type == GGML_TYPE_F32;
|
||||||
case GGML_OP_ARGSORT:
|
case GGML_OP_ARGSORT:
|
||||||
|
case GGML_OP_TOP_K:
|
||||||
case GGML_OP_ARANGE:
|
case GGML_OP_ARANGE:
|
||||||
return true;
|
return true;
|
||||||
case GGML_OP_FLASH_ATTN_EXT:
|
case GGML_OP_FLASH_ATTN_EXT:
|
||||||
|
|
|
||||||
|
|
@ -832,14 +832,19 @@ typedef struct {
|
||||||
} ggml_metal_kargs_leaky_relu;
|
} ggml_metal_kargs_leaky_relu;
|
||||||
|
|
||||||
typedef struct {
|
typedef struct {
|
||||||
int64_t ne00;
|
int32_t ne00;
|
||||||
int64_t ne01;
|
int32_t ne01;
|
||||||
int64_t ne02;
|
int32_t ne02;
|
||||||
int64_t ne03;
|
int32_t ne03;
|
||||||
uint64_t nb00;
|
uint64_t nb00;
|
||||||
uint64_t nb01;
|
uint64_t nb01;
|
||||||
uint64_t nb02;
|
uint64_t nb02;
|
||||||
uint64_t nb03;
|
uint64_t nb03;
|
||||||
|
int32_t ne0;
|
||||||
|
int32_t ne1;
|
||||||
|
int32_t ne2;
|
||||||
|
int32_t ne3;
|
||||||
|
int32_t top_k;
|
||||||
} ggml_metal_kargs_argsort;
|
} ggml_metal_kargs_argsort;
|
||||||
|
|
||||||
typedef struct {
|
typedef struct {
|
||||||
|
|
@ -851,6 +856,11 @@ typedef struct {
|
||||||
uint64_t nb01;
|
uint64_t nb01;
|
||||||
uint64_t nb02;
|
uint64_t nb02;
|
||||||
uint64_t nb03;
|
uint64_t nb03;
|
||||||
|
int32_t ne0;
|
||||||
|
int32_t ne1;
|
||||||
|
int32_t ne2;
|
||||||
|
int32_t ne3;
|
||||||
|
int32_t top_k;
|
||||||
int32_t len;
|
int32_t len;
|
||||||
} ggml_metal_kargs_argsort_merge;
|
} ggml_metal_kargs_argsort_merge;
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -406,6 +406,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
||||||
{
|
{
|
||||||
n_fuse = ggml_metal_op_argsort(ctx, idx);
|
n_fuse = ggml_metal_op_argsort(ctx, idx);
|
||||||
} break;
|
} break;
|
||||||
|
case GGML_OP_TOP_K:
|
||||||
|
{
|
||||||
|
n_fuse = ggml_metal_op_top_k(ctx, idx);
|
||||||
|
} break;
|
||||||
case GGML_OP_LEAKY_RELU:
|
case GGML_OP_LEAKY_RELU:
|
||||||
{
|
{
|
||||||
n_fuse = ggml_metal_op_leaky_relu(ctx, idx);
|
n_fuse = ggml_metal_op_leaky_relu(ctx, idx);
|
||||||
|
|
@ -3678,14 +3682,19 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_metal_kargs_argsort args = {
|
ggml_metal_kargs_argsort args = {
|
||||||
/*.ne00 =*/ ne00,
|
/*.ne00 =*/ ne00,
|
||||||
/*.ne01 =*/ ne01,
|
/*.ne01 =*/ ne01,
|
||||||
/*.ne02 =*/ ne02,
|
/*.ne02 =*/ ne02,
|
||||||
/*.ne03 =*/ ne03,
|
/*.ne03 =*/ ne03,
|
||||||
/*.nb00 =*/ nb00,
|
/*.nb00 =*/ nb00,
|
||||||
/*.nb01 =*/ nb01,
|
/*.nb01 =*/ nb01,
|
||||||
/*.nb02 =*/ nb02,
|
/*.nb02 =*/ nb02,
|
||||||
/*.nb03 =*/ nb03,
|
/*.nb03 =*/ nb03,
|
||||||
|
/*.ne0 =*/ ne0,
|
||||||
|
/*.ne1 =*/ ne1,
|
||||||
|
/*.ne2 =*/ ne2,
|
||||||
|
/*.ne3 =*/ ne3,
|
||||||
|
/*.top_k =*/ nth,
|
||||||
};
|
};
|
||||||
|
|
||||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||||
|
|
@ -3705,15 +3714,20 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
|
||||||
ggml_metal_op_concurrency_reset(ctx);
|
ggml_metal_op_concurrency_reset(ctx);
|
||||||
|
|
||||||
ggml_metal_kargs_argsort_merge args_merge = {
|
ggml_metal_kargs_argsort_merge args_merge = {
|
||||||
.ne00 = ne00,
|
/*.ne00 =*/ ne00,
|
||||||
.ne01 = ne01,
|
/*.ne01 =*/ ne01,
|
||||||
.ne02 = ne02,
|
/*.ne02 =*/ ne02,
|
||||||
.ne03 = ne03,
|
/*.ne03 =*/ ne03,
|
||||||
.nb00 = nb00,
|
/*.nb00 =*/ nb00,
|
||||||
.nb01 = nb01,
|
/*.nb01 =*/ nb01,
|
||||||
.nb02 = nb02,
|
/*.nb02 =*/ nb02,
|
||||||
.nb03 = nb03,
|
/*.nb03 =*/ nb03,
|
||||||
.len = len,
|
/*.ne0 =*/ ne0,
|
||||||
|
/*.ne1 =*/ ne1,
|
||||||
|
/*.ne2 =*/ ne2,
|
||||||
|
/*.ne3 =*/ ne3,
|
||||||
|
/*.top_k =*/ ne00,
|
||||||
|
/*.len =*/ len,
|
||||||
};
|
};
|
||||||
|
|
||||||
// merges per row
|
// merges per row
|
||||||
|
|
@ -3737,6 +3751,118 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) {
|
||||||
|
ggml_tensor * op = ctx->node(idx);
|
||||||
|
|
||||||
|
ggml_metal_library_t lib = ctx->lib;
|
||||||
|
ggml_metal_encoder_t enc = ctx->enc;
|
||||||
|
|
||||||
|
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
|
||||||
|
|
||||||
|
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||||
|
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||||
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||||
|
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||||
|
|
||||||
|
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_top_k(lib, op);
|
||||||
|
|
||||||
|
// bitonic sort requires the number of elements to be power of 2
|
||||||
|
int nth = 1;
|
||||||
|
while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
||||||
|
nth *= 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
// blocks per row
|
||||||
|
const int npr = (ne00 + nth - 1)/nth;
|
||||||
|
|
||||||
|
const size_t smem = GGML_PAD(nth*sizeof(int32_t), 16);
|
||||||
|
|
||||||
|
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
|
||||||
|
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
|
||||||
|
|
||||||
|
ggml_metal_buffer_id bid_tmp = bid_dst;
|
||||||
|
bid_tmp.offs += sizeof(int32_t)*ggml_nelements(op->src[0]);
|
||||||
|
|
||||||
|
if ((int) ceil(std::log(npr) / std::log(2)) % 2 == 1) {
|
||||||
|
std::swap(bid_dst, bid_tmp);
|
||||||
|
}
|
||||||
|
|
||||||
|
const int top_k = ne0;
|
||||||
|
|
||||||
|
ggml_metal_kargs_argsort args = {
|
||||||
|
/*.ne00 =*/ ne00,
|
||||||
|
/*.ne01 =*/ ne01,
|
||||||
|
/*.ne02 =*/ ne02,
|
||||||
|
/*.ne03 =*/ ne03,
|
||||||
|
/*.nb00 =*/ nb00,
|
||||||
|
/*.nb01 =*/ nb01,
|
||||||
|
/*.nb02 =*/ nb02,
|
||||||
|
/*.nb03 =*/ nb03,
|
||||||
|
/*.ne0 =*/ ne0,
|
||||||
|
/*.ne1 =*/ ne1,
|
||||||
|
/*.ne2 =*/ ne2,
|
||||||
|
/*.ne3 =*/ ne3,
|
||||||
|
/*.top_k =*/ std::min(nth, top_k), // for each block, keep just the top_k indices
|
||||||
|
};
|
||||||
|
|
||||||
|
if (npr > 1) {
|
||||||
|
args.ne0 = (npr - 1)*args.top_k + std::min(ne00 - (npr - 1)*nth, args.top_k);
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||||
|
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||||
|
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
|
||||||
|
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
|
||||||
|
|
||||||
|
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
||||||
|
|
||||||
|
ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1);
|
||||||
|
|
||||||
|
ggml_metal_pipeline_t pipeline_merge = ggml_metal_library_get_pipeline_top_k_merge(lib, op);
|
||||||
|
|
||||||
|
int len = args.top_k;
|
||||||
|
|
||||||
|
while (len < args.ne0) {
|
||||||
|
ggml_metal_op_concurrency_reset(ctx);
|
||||||
|
|
||||||
|
// merges per row
|
||||||
|
const int nm = (args.ne0 + 2*len - 1) / (2*len);
|
||||||
|
|
||||||
|
const int nth = std::min(512, std::min(len, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_merge)));
|
||||||
|
|
||||||
|
ggml_metal_kargs_argsort_merge args_merge = {
|
||||||
|
/*.ne00 =*/ ne00,
|
||||||
|
/*.ne01 =*/ ne01,
|
||||||
|
/*.ne02 =*/ ne02,
|
||||||
|
/*.ne03 =*/ ne03,
|
||||||
|
/*.nb00 =*/ nb00,
|
||||||
|
/*.nb01 =*/ nb01,
|
||||||
|
/*.nb02 =*/ nb02,
|
||||||
|
/*.nb03 =*/ nb03,
|
||||||
|
/*.ne0 =*/ args.ne0,
|
||||||
|
/*.ne1 =*/ ne1,
|
||||||
|
/*.ne2 =*/ ne2,
|
||||||
|
/*.ne3 =*/ ne3,
|
||||||
|
/*.top_k =*/ nm == 1 ? top_k : args.ne0, // the final merge outputs top_k elements
|
||||||
|
/*.len =*/ len,
|
||||||
|
};
|
||||||
|
|
||||||
|
ggml_metal_encoder_set_pipeline(enc, pipeline_merge);
|
||||||
|
ggml_metal_encoder_set_bytes (enc, &args_merge, sizeof(args_merge), 0);
|
||||||
|
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
|
||||||
|
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
|
||||||
|
ggml_metal_encoder_set_buffer (enc, bid_tmp, 3);
|
||||||
|
|
||||||
|
ggml_metal_encoder_dispatch_threadgroups(enc, nm*ne01, ne02, ne03, nth, 1, 1);
|
||||||
|
|
||||||
|
std::swap(bid_dst, bid_tmp);
|
||||||
|
|
||||||
|
len <<= 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) {
|
int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) {
|
||||||
ggml_tensor * op = ctx->node(idx);
|
ggml_tensor * op = ctx->node(idx);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -81,6 +81,7 @@ int ggml_metal_op_arange (ggml_metal_op_t ctx, int idx);
|
||||||
int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx);
|
int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx);
|
||||||
int ggml_metal_op_argmax (ggml_metal_op_t ctx, int idx);
|
int ggml_metal_op_argmax (ggml_metal_op_t ctx, int idx);
|
||||||
int ggml_metal_op_argsort (ggml_metal_op_t ctx, int idx);
|
int ggml_metal_op_argsort (ggml_metal_op_t ctx, int idx);
|
||||||
|
int ggml_metal_op_top_k (ggml_metal_op_t ctx, int idx);
|
||||||
int ggml_metal_op_leaky_relu (ggml_metal_op_t ctx, int idx);
|
int ggml_metal_op_leaky_relu (ggml_metal_op_t ctx, int idx);
|
||||||
int ggml_metal_op_opt_step_adamw (ggml_metal_op_t ctx, int idx);
|
int ggml_metal_op_opt_step_adamw (ggml_metal_op_t ctx, int idx);
|
||||||
int ggml_metal_op_opt_step_sgd (ggml_metal_op_t ctx, int idx);
|
int ggml_metal_op_opt_step_sgd (ggml_metal_op_t ctx, int idx);
|
||||||
|
|
|
||||||
|
|
@ -202,6 +202,10 @@ static size_t ggml_backend_metal_buffer_type_get_alloc_size(ggml_backend_buffer_
|
||||||
{
|
{
|
||||||
res *= 2;
|
res *= 2;
|
||||||
} break;
|
} break;
|
||||||
|
case GGML_OP_TOP_K:
|
||||||
|
{
|
||||||
|
res = 2*sizeof(int32_t)*ggml_nelements(tensor->src[0]);
|
||||||
|
} break;
|
||||||
default:
|
default:
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -4670,11 +4670,12 @@ kernel void kernel_argsort_f32_i32(
|
||||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||||
// bitonic sort
|
// bitonic sort
|
||||||
const int col = tpitg[0];
|
const int col = tpitg[0];
|
||||||
|
const int ib = tgpig[0] / args.ne01;
|
||||||
|
|
||||||
const int i00 = (tgpig[0]/args.ne01)*ntg.x;
|
const int i00 = ib*ntg.x;
|
||||||
const int i01 = tgpig[0]%args.ne01;
|
const int i01 = tgpig[0] % args.ne01;
|
||||||
const int i02 = tgpig[1];
|
const int i02 = tgpig[1];
|
||||||
const int i03 = tgpig[2];
|
const int i03 = tgpig[2];
|
||||||
|
|
||||||
device const float * src0_row = (device const float *) (src0 + args.nb01*i01 + args.nb02*i02 + args.nb03*i03);
|
device const float * src0_row = (device const float *) (src0 + args.nb01*i01 + args.nb02*i02 + args.nb03*i03);
|
||||||
|
|
||||||
|
|
@ -4710,9 +4711,11 @@ kernel void kernel_argsort_f32_i32(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const int64_t i0 = ib*args.top_k;
|
||||||
|
|
||||||
// copy the result to dst without the padding
|
// copy the result to dst without the padding
|
||||||
if (i00 + col < args.ne00) {
|
if (i0 + col < args.ne0 && col < args.top_k) {
|
||||||
dst += i00 + args.ne00*i01 + args.ne00*args.ne01*i02 + args.ne00*args.ne01*args.ne02*i03;
|
dst += i0 + args.ne0*i01 + args.ne0*args.ne1*i02 + args.ne0*args.ne1*args.ne2*i03;
|
||||||
|
|
||||||
dst[col] = shmem_i32[col];
|
dst[col] = shmem_i32[col];
|
||||||
}
|
}
|
||||||
|
|
@ -4747,22 +4750,22 @@ kernel void kernel_argsort_merge_f32_i32(
|
||||||
|
|
||||||
const int start = im * (2 * args.len);
|
const int start = im * (2 * args.len);
|
||||||
|
|
||||||
const int len0 = MIN(args.len, MAX(0, args.ne00 - (int)(start)));
|
const int len0 = MIN(args.len, MAX(0, args.ne0 - (int)(start)));
|
||||||
const int len1 = MIN(args.len, MAX(0, args.ne00 - (int)(start + args.len)));
|
const int len1 = MIN(args.len, MAX(0, args.ne0 - (int)(start + args.len)));
|
||||||
|
|
||||||
const int total = len0 + len1;
|
const int total = len0 + len1;
|
||||||
|
|
||||||
device const int32_t * tmp0 = tmp + start
|
device const int32_t * tmp0 = tmp + start
|
||||||
+ i01*args.ne00
|
+ i01*args.ne0
|
||||||
+ i02*args.ne00*args.ne01
|
+ i02*args.ne0*args.ne01
|
||||||
+ i03*args.ne00*args.ne01*args.ne02;
|
+ i03*args.ne0*args.ne01*args.ne02;
|
||||||
|
|
||||||
device const int32_t * tmp1 = tmp0 + args.len;
|
device const int32_t * tmp1 = tmp0 + args.len;
|
||||||
|
|
||||||
dst += start
|
dst += start
|
||||||
+ i01*args.ne00
|
+ i01*args.top_k
|
||||||
+ i02*args.ne00*args.ne01
|
+ i02*args.top_k*args.ne01
|
||||||
+ i03*args.ne00*args.ne01*args.ne02;
|
+ i03*args.top_k*args.ne01*args.ne02;
|
||||||
|
|
||||||
device const float * src0_row = (device const float *)(src0
|
device const float * src0_row = (device const float *)(src0
|
||||||
+ args.nb01*i01
|
+ args.nb01*i01
|
||||||
|
|
@ -4776,7 +4779,11 @@ kernel void kernel_argsort_merge_f32_i32(
|
||||||
const int chunk = (total + ntg.x - 1) / ntg.x;
|
const int chunk = (total + ntg.x - 1) / ntg.x;
|
||||||
|
|
||||||
const int k0 = tpitg.x * chunk;
|
const int k0 = tpitg.x * chunk;
|
||||||
const int k1 = min(k0 + chunk, total);
|
const int k1 = MIN(MIN(k0 + chunk, total), args.top_k);
|
||||||
|
|
||||||
|
if (k0 >= args.top_k) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
if (k0 >= total) {
|
if (k0 >= total) {
|
||||||
return;
|
return;
|
||||||
|
|
|
||||||
|
|
@ -409,6 +409,7 @@ enum shader_reduction_mode {
|
||||||
// argsort pipelines for up to 1<<10 invocations per workgroup
|
// argsort pipelines for up to 1<<10 invocations per workgroup
|
||||||
static constexpr uint32_t num_argsort_pipelines = 11;
|
static constexpr uint32_t num_argsort_pipelines = 11;
|
||||||
static constexpr uint32_t num_topk_moe_pipelines = 10;
|
static constexpr uint32_t num_topk_moe_pipelines = 10;
|
||||||
|
static constexpr uint32_t num_topk_pipelines = 11;
|
||||||
|
|
||||||
static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
|
static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
|
||||||
GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
|
GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
|
||||||
|
|
@ -515,6 +516,7 @@ struct vk_device_struct {
|
||||||
bool single_queue;
|
bool single_queue;
|
||||||
bool support_async;
|
bool support_async;
|
||||||
uint32_t subgroup_size;
|
uint32_t subgroup_size;
|
||||||
|
uint32_t subgroup_size_log2;
|
||||||
uint32_t shader_core_count;
|
uint32_t shader_core_count;
|
||||||
bool uma;
|
bool uma;
|
||||||
bool prefer_host_memory;
|
bool prefer_host_memory;
|
||||||
|
|
@ -704,7 +706,9 @@ struct vk_device_struct {
|
||||||
vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16;
|
vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16;
|
||||||
vk_pipeline pipeline_argsort_f32[num_argsort_pipelines];
|
vk_pipeline pipeline_argsort_f32[num_argsort_pipelines];
|
||||||
vk_pipeline pipeline_argsort_large_f32[num_argsort_pipelines];
|
vk_pipeline pipeline_argsort_large_f32[num_argsort_pipelines];
|
||||||
|
vk_pipeline pipeline_topk_f32[num_topk_pipelines];
|
||||||
vk_pipeline pipeline_sum_rows_f32;
|
vk_pipeline pipeline_sum_rows_f32;
|
||||||
|
vk_pipeline pipeline_cumsum_f32;
|
||||||
vk_pipeline pipeline_argmax_f32;
|
vk_pipeline pipeline_argmax_f32;
|
||||||
vk_pipeline pipeline_count_equal_i32;
|
vk_pipeline pipeline_count_equal_i32;
|
||||||
vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
|
vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
|
||||||
|
|
@ -1204,6 +1208,15 @@ struct vk_op_argsort_push_constants {
|
||||||
uint32_t inner_end;
|
uint32_t inner_end;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct vk_op_topk_push_constants {
|
||||||
|
uint32_t orig_ncols;
|
||||||
|
uint32_t ncols_input;
|
||||||
|
uint32_t ncols_output;
|
||||||
|
uint32_t nrows;
|
||||||
|
uint32_t first_pass;
|
||||||
|
uint32_t last_pass;
|
||||||
|
};
|
||||||
|
|
||||||
struct vk_op_im2col_push_constants {
|
struct vk_op_im2col_push_constants {
|
||||||
uint64_t dst_addr;
|
uint64_t dst_addr;
|
||||||
uint32_t batch_offset; uint32_t offset_delta;
|
uint32_t batch_offset; uint32_t offset_delta;
|
||||||
|
|
@ -3964,10 +3977,29 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||||
ggml_vk_create_pipeline2(device, device->pipeline_argsort_large_f32[i], "argsort_large_f32_"+std::to_string(i), argsort_large_f32_len, argsort_large_f32_data, "main", 3, sizeof(vk_op_argsort_push_constants), {BLOCK_SIZE * WG_UNROLL_FACTOR, 1, 1}, {BLOCK_SIZE, WG_UNROLL_FACTOR}, 1, true);
|
ggml_vk_create_pipeline2(device, device->pipeline_argsort_large_f32[i], "argsort_large_f32_"+std::to_string(i), argsort_large_f32_len, argsort_large_f32_data, "main", 3, sizeof(vk_op_argsort_push_constants), {BLOCK_SIZE * WG_UNROLL_FACTOR, 1, 1}, {BLOCK_SIZE, WG_UNROLL_FACTOR}, 1, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for (uint32_t i = 0; i < num_topk_pipelines; ++i) {
|
||||||
|
const uint32_t BLOCK_SIZE = 1u << i;
|
||||||
|
const uint32_t NCOLS_PADDED_LOG2 = i;
|
||||||
|
if (i <= device->max_workgroup_size_log2) {
|
||||||
|
uint32_t nary_shmem = 2 * sizeof(int) * BLOCK_SIZE +
|
||||||
|
sizeof(int) * device->subgroup_size +
|
||||||
|
2 * sizeof(int) +
|
||||||
|
(BLOCK_SIZE / device->subgroup_size) * sizeof(int);
|
||||||
|
if (device->subgroup_arithmetic && device->subgroup_require_full_support && device->subgroup_shuffle && device->subgroup_ballot &&
|
||||||
|
nary_shmem <= device->properties.limits.maxComputeSharedMemorySize) {
|
||||||
|
ggml_vk_create_pipeline2(device, device->pipeline_topk_f32[i], "topk_f32_"+std::to_string(i), topk_nary_search_f32_len, topk_nary_search_f32_data, "main", 2, sizeof(vk_op_topk_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, device->subgroup_size, device->subgroup_size_log2}, 1, true, true, device->subgroup_size);
|
||||||
|
} else if (2 * sizeof(int) * BLOCK_SIZE <= device->properties.limits.maxComputeSharedMemorySize) {
|
||||||
|
ggml_vk_create_pipeline2(device, device->pipeline_topk_f32[i], "topk_f32_"+std::to_string(i), topk_argsort_f32_len, topk_argsort_f32_data, "main", 2, sizeof(vk_op_topk_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, NCOLS_PADDED_LOG2}, 1, true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
|
||||||
|
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
|
||||||
|
|
||||||
|
ggml_vk_create_pipeline(device, device->pipeline_cumsum_f32, "cumsum_f32", cumsum_f32_len, cumsum_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { 128, device->subgroup_size }, 1, true, true, device->subgroup_size);
|
||||||
|
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
|
||||||
|
|
||||||
#define IM2COL(bda) \
|
#define IM2COL(bda) \
|
||||||
|
|
@ -4333,6 +4365,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
||||||
device->suballocation_block_size = std::min(device->suballocation_block_size, device->max_memory_allocation_size);
|
device->suballocation_block_size = std::min(device->suballocation_block_size, device->max_memory_allocation_size);
|
||||||
|
|
||||||
device->subgroup_size = subgroup_props.subgroupSize;
|
device->subgroup_size = subgroup_props.subgroupSize;
|
||||||
|
device->subgroup_size_log2 = uint32_t(log2f(float(device->subgroup_size)));
|
||||||
device->uma = device->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
|
device->uma = device->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
|
||||||
if (sm_builtins) {
|
if (sm_builtins) {
|
||||||
device->shader_core_count = sm_props.shaderSMCount;
|
device->shader_core_count = sm_props.shaderSMCount;
|
||||||
|
|
@ -8457,6 +8490,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
||||||
return ctx->device->pipeline_sum_rows_f32;
|
return ctx->device->pipeline_sum_rows_f32;
|
||||||
}
|
}
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
case GGML_OP_CUMSUM:
|
||||||
|
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||||
|
return ctx->device->pipeline_cumsum_f32;
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
case GGML_OP_ARGMAX:
|
case GGML_OP_ARGMAX:
|
||||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
|
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
|
||||||
return ctx->device->pipeline_argmax_f32;
|
return ctx->device->pipeline_argmax_f32;
|
||||||
|
|
@ -8821,6 +8859,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
||||||
case GGML_OP_SOFT_MAX:
|
case GGML_OP_SOFT_MAX:
|
||||||
case GGML_OP_SOFT_MAX_BACK:
|
case GGML_OP_SOFT_MAX_BACK:
|
||||||
case GGML_OP_SUM_ROWS:
|
case GGML_OP_SUM_ROWS:
|
||||||
|
case GGML_OP_CUMSUM:
|
||||||
case GGML_OP_MEAN:
|
case GGML_OP_MEAN:
|
||||||
case GGML_OP_ARGMAX:
|
case GGML_OP_ARGMAX:
|
||||||
{
|
{
|
||||||
|
|
@ -10134,6 +10173,104 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, c
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
|
||||||
|
uint32_t ncols = src0->ne[0];
|
||||||
|
uint32_t nrows = ggml_nrows(src0);
|
||||||
|
uint32_t k = dst->ne[0];
|
||||||
|
|
||||||
|
vk_op_topk_push_constants pc { ncols, ncols, k, nrows, 0, 0 };
|
||||||
|
|
||||||
|
// Reserve space for ivec2 per element, double buffered
|
||||||
|
const size_t dbl_buf_size = size_t{ncols} * nrows * 2 * sizeof(int);
|
||||||
|
const size_t x_sz = dbl_buf_size * 2;
|
||||||
|
uint32_t dbl_buf_index = 0;
|
||||||
|
|
||||||
|
if (ctx->prealloc_size_x < x_sz) {
|
||||||
|
ctx->prealloc_size_x = x_sz;
|
||||||
|
ggml_vk_preallocate_buffers(ctx, subctx);
|
||||||
|
}
|
||||||
|
if (ctx->prealloc_x_need_sync) {
|
||||||
|
ggml_vk_sync_buffers(ctx, subctx);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::array<uint32_t, 3> elements;
|
||||||
|
elements[1] = std::min(nrows, ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
|
||||||
|
elements[2] = 1;
|
||||||
|
|
||||||
|
uint32_t num_elements = ncols;
|
||||||
|
|
||||||
|
// Each iteration reduces a workgroup's worth of elements down to the K
|
||||||
|
// largest elements. Repeat until we have the top K elements.
|
||||||
|
// Need to do at least one iteration to write out the results.
|
||||||
|
bool done_one_iter = false;
|
||||||
|
while (num_elements > k || !done_one_iter) {
|
||||||
|
done_one_iter = true;
|
||||||
|
|
||||||
|
// Prefer going as small as num_topk_pipelines - 3 for perf reasons.
|
||||||
|
// But if K is larger, then we need a larger workgroup
|
||||||
|
uint32_t max_pipeline = num_topk_pipelines - 3;
|
||||||
|
uint32_t min_pipeline = (uint32_t)log2f(float(k)) + 1;
|
||||||
|
// require full subgroup
|
||||||
|
min_pipeline = std::max(min_pipeline, ctx->device->subgroup_size_log2);
|
||||||
|
|
||||||
|
uint32_t pipeline_idx = (uint32_t)ceilf(log2f(float(num_elements)));
|
||||||
|
pipeline_idx = std::min(pipeline_idx, max_pipeline);
|
||||||
|
pipeline_idx = std::max(pipeline_idx, min_pipeline);
|
||||||
|
|
||||||
|
if (num_elements > (1u << pipeline_idx)) {
|
||||||
|
// If we could finish on this loop iteration (i.e. a single workgroup)
|
||||||
|
// then do so. It's better than the overhead of another pass.
|
||||||
|
for (uint32_t i = pipeline_idx; i < num_topk_pipelines; ++i) {
|
||||||
|
if (num_elements <= (1u << i)) {
|
||||||
|
pipeline_idx = i;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
vk_pipeline pipeline = ctx->device->pipeline_topk_f32[pipeline_idx];
|
||||||
|
// If the device doesn't support a pipeline this large, use smaller
|
||||||
|
while (!pipeline) {
|
||||||
|
pipeline_idx--;
|
||||||
|
GGML_ASSERT(pipeline_idx >= min_pipeline);
|
||||||
|
pipeline = ctx->device->pipeline_topk_f32[pipeline_idx];
|
||||||
|
}
|
||||||
|
|
||||||
|
vk_op_topk_push_constants pc2 = pc;
|
||||||
|
pc2.ncols_input = num_elements;
|
||||||
|
|
||||||
|
// Number of elements remaining after this pass
|
||||||
|
uint32_t num_dst_elements = (num_elements / pipeline->wg_denoms[0]) * k + std::min(k, num_elements % pipeline->wg_denoms[0]);
|
||||||
|
|
||||||
|
vk_subbuffer src_buf;
|
||||||
|
vk_subbuffer dst_buf;
|
||||||
|
|
||||||
|
if (num_elements == ncols) {
|
||||||
|
pc2.first_pass = 1;
|
||||||
|
src_buf = ggml_vk_tensor_subbuffer(ctx, src0);
|
||||||
|
} else {
|
||||||
|
src_buf = { ctx->prealloc_x, dbl_buf_index * dbl_buf_size, dbl_buf_size };
|
||||||
|
}
|
||||||
|
if (num_dst_elements == k) {
|
||||||
|
pc2.last_pass = 1;
|
||||||
|
dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
|
||||||
|
} else {
|
||||||
|
dst_buf = { ctx->prealloc_x, (dbl_buf_index ^ 1) * dbl_buf_size, dbl_buf_size };
|
||||||
|
}
|
||||||
|
|
||||||
|
elements[0] = num_elements;
|
||||||
|
|
||||||
|
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
|
||||||
|
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src_buf, dst_buf }, pc2, elements);
|
||||||
|
num_elements = num_dst_elements;
|
||||||
|
dbl_buf_index ^= 1;
|
||||||
|
if (num_elements > k) {
|
||||||
|
ggml_vk_sync_buffers(ctx, subctx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ctx->prealloc_x_need_sync = true;
|
||||||
|
}
|
||||||
|
|
||||||
static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
|
static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
|
||||||
vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, ggml_nelements(src0));
|
vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, ggml_nelements(src0));
|
||||||
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SUM, p);
|
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SUM, p);
|
||||||
|
|
@ -10150,6 +10287,11 @@ static void ggml_vk_mean(ggml_backend_vk_context * ctx, vk_context& subctx, cons
|
||||||
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_MEAN, p);
|
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_MEAN, p);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void ggml_vk_cumsum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
|
||||||
|
vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]);
|
||||||
|
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_CUMSUM, p);
|
||||||
|
}
|
||||||
|
|
||||||
static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
|
static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
|
||||||
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], 0.0f, 0.0f });
|
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], 0.0f, 0.0f });
|
||||||
}
|
}
|
||||||
|
|
@ -11741,6 +11883,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
||||||
ggml_vk_argsort(ctx, compute_ctx, src0, node);
|
ggml_vk_argsort(ctx, compute_ctx, src0, node);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
break;
|
||||||
|
case GGML_OP_TOP_K:
|
||||||
|
ggml_vk_topk(ctx, compute_ctx, src0, node);
|
||||||
|
|
||||||
break;
|
break;
|
||||||
case GGML_OP_SUM:
|
case GGML_OP_SUM:
|
||||||
ggml_vk_sum(ctx, compute_ctx, src0, node);
|
ggml_vk_sum(ctx, compute_ctx, src0, node);
|
||||||
|
|
@ -11749,6 +11895,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
||||||
case GGML_OP_SUM_ROWS:
|
case GGML_OP_SUM_ROWS:
|
||||||
ggml_vk_sum_rows(ctx, compute_ctx, src0, node);
|
ggml_vk_sum_rows(ctx, compute_ctx, src0, node);
|
||||||
|
|
||||||
|
break;
|
||||||
|
case GGML_OP_CUMSUM:
|
||||||
|
ggml_vk_cumsum(ctx, compute_ctx, src0, node);
|
||||||
|
|
||||||
break;
|
break;
|
||||||
case GGML_OP_MEAN:
|
case GGML_OP_MEAN:
|
||||||
ggml_vk_mean(ctx, compute_ctx, src0, node);
|
ggml_vk_mean(ctx, compute_ctx, src0, node);
|
||||||
|
|
@ -13008,24 +13158,6 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
|
||||||
return false;
|
return false;
|
||||||
};
|
};
|
||||||
|
|
||||||
// This function tries to reorder the graph to allow nodes to run in parallel.
|
|
||||||
// This helps with small batches, but for large batches its a slowdown, probably
|
|
||||||
// due to cache contention. So only reorder if the majority of nodes have few rows.
|
|
||||||
int num_small_nodes = 0;
|
|
||||||
int num_counted_nodes = 0;
|
|
||||||
for (int i = 0; i < graph->n_nodes; ++i) {
|
|
||||||
if (!is_empty(graph->nodes[i]) &&
|
|
||||||
graph->nodes[i]->op != GGML_OP_SET_ROWS) {
|
|
||||||
if (ggml_nrows(graph->nodes[i]) <= 8) {
|
|
||||||
num_small_nodes++;
|
|
||||||
}
|
|
||||||
num_counted_nodes++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (num_small_nodes < num_counted_nodes / 2) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<ggml_tensor *> new_order;
|
std::vector<ggml_tensor *> new_order;
|
||||||
std::vector<bool> used(graph->n_nodes, false);
|
std::vector<bool> used(graph->n_nodes, false);
|
||||||
std::set<ggml_tensor *> used_node_set;
|
std::set<ggml_tensor *> used_node_set;
|
||||||
|
|
@ -13769,6 +13901,22 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||||
return op->ne[0] <= (1 << device->max_workgroup_size_log2);
|
return op->ne[0] <= (1 << device->max_workgroup_size_log2);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
case GGML_OP_TOP_K:
|
||||||
|
{
|
||||||
|
if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
|
||||||
|
auto device = ggml_vk_get_device(ctx->device);
|
||||||
|
// We could potentially support larger, using argsort to sort the
|
||||||
|
// whole thing. Not clear if this is needed.
|
||||||
|
uint32_t min_pipeline = (uint32_t)log2f(float(op->ne[0])) + 1;
|
||||||
|
if (min_pipeline >= num_topk_pipelines ||
|
||||||
|
!device->pipeline_topk_f32[min_pipeline]) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
case GGML_OP_UPSCALE:
|
case GGML_OP_UPSCALE:
|
||||||
case GGML_OP_ACC:
|
case GGML_OP_ACC:
|
||||||
case GGML_OP_CONCAT:
|
case GGML_OP_CONCAT:
|
||||||
|
|
@ -13786,6 +13934,15 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||||
case GGML_OP_SUM_ROWS:
|
case GGML_OP_SUM_ROWS:
|
||||||
case GGML_OP_MEAN:
|
case GGML_OP_MEAN:
|
||||||
return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]);
|
return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]);
|
||||||
|
case GGML_OP_CUMSUM:
|
||||||
|
{
|
||||||
|
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
|
||||||
|
auto device = ggml_vk_get_device(ctx->device);
|
||||||
|
if (device->subgroup_arithmetic && device->subgroup_require_full_support) {
|
||||||
|
return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]);
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
case GGML_OP_ARGMAX:
|
case GGML_OP_ARGMAX:
|
||||||
case GGML_OP_COUNT_EQUAL:
|
case GGML_OP_COUNT_EQUAL:
|
||||||
case GGML_OP_IM2COL:
|
case GGML_OP_IM2COL:
|
||||||
|
|
@ -14432,10 +14589,14 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
|
||||||
tensor_clone = ggml_get_rows(ggml_ctx, src_clone[0], src_clone[1]);
|
tensor_clone = ggml_get_rows(ggml_ctx, src_clone[0], src_clone[1]);
|
||||||
} else if (tensor->op == GGML_OP_ARGSORT) {
|
} else if (tensor->op == GGML_OP_ARGSORT) {
|
||||||
tensor_clone = ggml_argsort(ggml_ctx, src_clone[0], (ggml_sort_order) *(int *)tensor->op_params);
|
tensor_clone = ggml_argsort(ggml_ctx, src_clone[0], (ggml_sort_order) *(int *)tensor->op_params);
|
||||||
|
} else if (tensor->op == GGML_OP_TOP_K) {
|
||||||
|
tensor_clone = ggml_top_k(ggml_ctx, src_clone[0], tensor->ne[0]);
|
||||||
} else if (tensor->op == GGML_OP_SUM) {
|
} else if (tensor->op == GGML_OP_SUM) {
|
||||||
tensor_clone = ggml_sum(ggml_ctx, src_clone[0]);
|
tensor_clone = ggml_sum(ggml_ctx, src_clone[0]);
|
||||||
} else if (tensor->op == GGML_OP_SUM_ROWS) {
|
} else if (tensor->op == GGML_OP_SUM_ROWS) {
|
||||||
tensor_clone = ggml_sum_rows(ggml_ctx, src_clone[0]);
|
tensor_clone = ggml_sum_rows(ggml_ctx, src_clone[0]);
|
||||||
|
} else if (tensor->op == GGML_OP_CUMSUM) {
|
||||||
|
tensor_clone = ggml_cumsum(ggml_ctx, src_clone[0]);
|
||||||
} else if (tensor->op == GGML_OP_MEAN) {
|
} else if (tensor->op == GGML_OP_MEAN) {
|
||||||
tensor_clone = ggml_mean(ggml_ctx, src_clone[0]);
|
tensor_clone = ggml_mean(ggml_ctx, src_clone[0]);
|
||||||
} else if (tensor->op == GGML_OP_ARGMAX) {
|
} else if (tensor->op == GGML_OP_ARGMAX) {
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,69 @@
|
||||||
|
#version 450
|
||||||
|
|
||||||
|
#include "types.glsl"
|
||||||
|
#include "sum_rows.glsl"
|
||||||
|
|
||||||
|
#extension GL_EXT_control_flow_attributes : enable
|
||||||
|
#extension GL_KHR_shader_subgroup_arithmetic : enable
|
||||||
|
#extension GL_KHR_shader_subgroup_basic : enable
|
||||||
|
|
||||||
|
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
|
||||||
|
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
||||||
|
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
|
||||||
|
|
||||||
|
layout (constant_id = 0) const uint BLOCK_SIZE = 128;
|
||||||
|
layout (constant_id = 1) const uint SUBGROUP_SIZE = 32;
|
||||||
|
|
||||||
|
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
|
||||||
|
|
||||||
|
shared FLOAT_TYPE partial[BLOCK_SIZE / SUBGROUP_SIZE];
|
||||||
|
shared FLOAT_TYPE last_sum;
|
||||||
|
|
||||||
|
void main() {
|
||||||
|
const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
|
||||||
|
const uint tid = gl_LocalInvocationID.x;
|
||||||
|
|
||||||
|
const uint i03 = fastdiv(row, p.ne0_12mp, p.ne0_12L);
|
||||||
|
const uint i03_offset = i03 * p.ne01*p.ne02;
|
||||||
|
const uint i02 = fastdiv(row - i03_offset, p.ne0_1mp, p.ne0_1L);
|
||||||
|
const uint i01 = row - i03_offset - i02*p.ne01;
|
||||||
|
|
||||||
|
const uint src_idx = get_aoffset() + i01 * p.nb01 + i02 * p.nb02 + i03 * p.nb03;
|
||||||
|
const uint dst_idx = get_doffset() + i01 * p.nb11 + i02 * p.nb12 + i03 * p.nb13;
|
||||||
|
|
||||||
|
uint subgroup_id = tid / SUBGROUP_SIZE;
|
||||||
|
|
||||||
|
if (tid == 0) {
|
||||||
|
last_sum = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint col = tid;
|
||||||
|
uint num_iter = CEIL_DIV(p.n_cols, BLOCK_SIZE);
|
||||||
|
for (int i = 0; i < num_iter; ++i) {
|
||||||
|
FLOAT_TYPE v = 0;
|
||||||
|
if (col < p.n_cols) {
|
||||||
|
v = FLOAT_TYPE(data_a[src_idx + col]);
|
||||||
|
}
|
||||||
|
v = subgroupInclusiveAdd(v);
|
||||||
|
|
||||||
|
// Store the largest partial sum for each subgroup, then add the partials for all
|
||||||
|
// lower subgroups and the final partial sum from the previous iteration.
|
||||||
|
if (gl_SubgroupInvocationID == SUBGROUP_SIZE - 1) {
|
||||||
|
partial[subgroup_id] = v;
|
||||||
|
}
|
||||||
|
barrier();
|
||||||
|
for (int j = 0; j < subgroup_id; ++j) {
|
||||||
|
v += partial[j];
|
||||||
|
}
|
||||||
|
v += last_sum;
|
||||||
|
barrier();
|
||||||
|
if (tid == BLOCK_SIZE - 1) {
|
||||||
|
last_sum = v;
|
||||||
|
}
|
||||||
|
if (col < p.n_cols) {
|
||||||
|
data_d[dst_idx + col] = D_TYPE(v);
|
||||||
|
}
|
||||||
|
col += BLOCK_SIZE;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
#version 450
|
#version 450
|
||||||
|
|
||||||
#include "types.glsl"
|
#include "types.glsl"
|
||||||
|
#include "sum_rows.glsl"
|
||||||
|
|
||||||
#extension GL_EXT_control_flow_attributes : enable
|
#extension GL_EXT_control_flow_attributes : enable
|
||||||
|
|
||||||
|
|
@ -11,30 +12,6 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
|
||||||
|
|
||||||
layout (constant_id = 0) const uint BLOCK_SIZE = 32;
|
layout (constant_id = 0) const uint BLOCK_SIZE = 32;
|
||||||
|
|
||||||
layout (push_constant) uniform parameter
|
|
||||||
{
|
|
||||||
uint n_cols;
|
|
||||||
uint ne01, ne02;
|
|
||||||
uint nb01, nb02, nb03;
|
|
||||||
uint nb11, nb12, nb13;
|
|
||||||
float weight;
|
|
||||||
uint misalign_offsets;
|
|
||||||
uint ne0_12mp, ne0_12L;
|
|
||||||
uint ne0_1mp, ne0_1L;
|
|
||||||
} p;
|
|
||||||
|
|
||||||
uint get_aoffset() { return p.misalign_offsets >> 16; }
|
|
||||||
uint get_doffset() { return p.misalign_offsets & 0xFFFF; }
|
|
||||||
|
|
||||||
// see init_fastdiv_values in ggml-vulkan.cpp
|
|
||||||
uint fastdiv(uint n, uint mp, uint L) {
|
|
||||||
uint msbs, lsbs;
|
|
||||||
// msbs = mulhi(n, mp)
|
|
||||||
umulExtended(n, mp, msbs, lsbs);
|
|
||||||
return (msbs + n) >> L;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
shared FLOAT_TYPE tmp[BLOCK_SIZE];
|
shared FLOAT_TYPE tmp[BLOCK_SIZE];
|
||||||
|
|
||||||
void main() {
|
void main() {
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,25 @@
|
||||||
|
|
||||||
|
// vk_op_sum_rows_push_constants
|
||||||
|
layout (push_constant) uniform parameter
|
||||||
|
{
|
||||||
|
uint n_cols;
|
||||||
|
uint ne01, ne02;
|
||||||
|
uint nb01, nb02, nb03;
|
||||||
|
uint nb11, nb12, nb13;
|
||||||
|
float weight;
|
||||||
|
uint misalign_offsets;
|
||||||
|
uint ne0_12mp, ne0_12L;
|
||||||
|
uint ne0_1mp, ne0_1L;
|
||||||
|
} p;
|
||||||
|
|
||||||
|
uint get_aoffset() { return p.misalign_offsets >> 16; }
|
||||||
|
uint get_doffset() { return p.misalign_offsets & 0xFFFF; }
|
||||||
|
|
||||||
|
// see init_fastdiv_values in ggml-vulkan.cpp
|
||||||
|
uint fastdiv(uint n, uint mp, uint L) {
|
||||||
|
uint msbs, lsbs;
|
||||||
|
// msbs = mulhi(n, mp)
|
||||||
|
umulExtended(n, mp, msbs, lsbs);
|
||||||
|
return (msbs + n) >> L;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
@ -0,0 +1,113 @@
|
||||||
|
#version 450
|
||||||
|
#extension GL_EXT_control_flow_attributes : enable
|
||||||
|
|
||||||
|
#include "types.glsl"
|
||||||
|
|
||||||
|
layout(constant_id = 0) const int BLOCK_SIZE = 1024;
|
||||||
|
layout(constant_id = 1) const int NCOLS_PADDED_LOG2 = 10;
|
||||||
|
|
||||||
|
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
|
||||||
|
// Input can either be the source (A) or intermediate values (S).
|
||||||
|
// Similarly, output can be either destination (D) or intermediate values (S).
|
||||||
|
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
||||||
|
layout (binding = 0) readonly buffer S {ivec2 data_s[];};
|
||||||
|
layout (binding = 1) writeonly buffer D {int data_d[];};
|
||||||
|
layout (binding = 1) writeonly buffer T {ivec2 data_t[];};
|
||||||
|
|
||||||
|
layout (push_constant) uniform parameter {
|
||||||
|
uint orig_ncols;
|
||||||
|
uint ncols_input;
|
||||||
|
uint ncols_output;
|
||||||
|
uint nrows;
|
||||||
|
uint first_pass;
|
||||||
|
uint last_pass;
|
||||||
|
} p;
|
||||||
|
|
||||||
|
// pairs of (gid, value)
|
||||||
|
shared ivec2 dst_row[BLOCK_SIZE];
|
||||||
|
|
||||||
|
void topk(bool needs_bounds_check, const uint row) {
|
||||||
|
const int col = int(gl_LocalInvocationID.x);
|
||||||
|
|
||||||
|
// initialize indices
|
||||||
|
if (gl_GlobalInvocationID.x < p.ncols_input) {
|
||||||
|
if (p.first_pass != 0) {
|
||||||
|
const uint row_offset = row * p.ncols_input;
|
||||||
|
dst_row[col] = ivec2(gl_GlobalInvocationID.x, floatBitsToInt(data_a[row_offset + gl_GlobalInvocationID.x]));
|
||||||
|
} else {
|
||||||
|
const uint row_offset = row * p.orig_ncols;
|
||||||
|
dst_row[col] = data_s[row_offset + gl_GlobalInvocationID.x];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
dst_row[col] = ivec2(p.orig_ncols, 0);
|
||||||
|
}
|
||||||
|
barrier();
|
||||||
|
|
||||||
|
if (p.ncols_output == 1) {
|
||||||
|
// Fast path for single output - just do a max reduction
|
||||||
|
[[unroll]] for (int s = BLOCK_SIZE / 2; s >= 1; s /= 2) {
|
||||||
|
if (col < s) {
|
||||||
|
ivec2 a = dst_row[col];
|
||||||
|
ivec2 b = dst_row[col + s];
|
||||||
|
if (a.x >= p.orig_ncols ||
|
||||||
|
b.x < p.orig_ncols && b.y > a.y) {
|
||||||
|
dst_row[col] = b;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
barrier();
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// bitonic sort on this group of elements
|
||||||
|
uint num_outer_loop_iters = NCOLS_PADDED_LOG2;
|
||||||
|
for (uint k = 2, outer_idx = 0; outer_idx < num_outer_loop_iters; k *= 2, outer_idx++) {
|
||||||
|
uint num_inner_loop_iters = outer_idx + 1;
|
||||||
|
for (uint j = k / 2, inner_idx = 0; inner_idx < num_inner_loop_iters; j /= 2, inner_idx++) {
|
||||||
|
const int ixj = int(col ^ j);
|
||||||
|
|
||||||
|
int idx_0 = (col & k) == 0 ? col : ixj;
|
||||||
|
int idx_1 = (col & k) == 0 ? ixj : col;
|
||||||
|
|
||||||
|
ivec2 sh_idx_0 = dst_row[idx_0];
|
||||||
|
ivec2 sh_idx_1 = dst_row[idx_1];
|
||||||
|
bool idx_0_oob = needs_bounds_check ? sh_idx_0.x >= p.orig_ncols : false;
|
||||||
|
bool idx_1_oob = needs_bounds_check ? sh_idx_1.x >= p.orig_ncols : false;
|
||||||
|
|
||||||
|
if ((idx_0_oob ||
|
||||||
|
(!idx_1_oob && intBitsToFloat(sh_idx_0.y) < intBitsToFloat(sh_idx_1.y))) && (ixj > col)) {
|
||||||
|
dst_row[idx_0] = sh_idx_1;
|
||||||
|
dst_row[idx_1] = sh_idx_0;
|
||||||
|
}
|
||||||
|
|
||||||
|
barrier();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (col < p.ncols_output && gl_GlobalInvocationID.x < p.orig_ncols) {
|
||||||
|
if (p.last_pass != 0) {
|
||||||
|
const uint row_offset = row * p.ncols_output;
|
||||||
|
data_d[row_offset + col] = dst_row[col].x;
|
||||||
|
} else {
|
||||||
|
const uint row_offset = row * p.orig_ncols + gl_WorkGroupID.x * p.ncols_output;
|
||||||
|
data_t[row_offset + col] = dst_row[col];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void main() {
|
||||||
|
// Fast path for fully occupied workgroups
|
||||||
|
if ((p.ncols_input % BLOCK_SIZE) == 0) {
|
||||||
|
uint row = gl_WorkGroupID.y;
|
||||||
|
while (row < p.nrows) {
|
||||||
|
topk(false, row);
|
||||||
|
row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
uint row = gl_WorkGroupID.y;
|
||||||
|
while (row < p.nrows) {
|
||||||
|
topk(true, row);
|
||||||
|
row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,199 @@
|
||||||
|
#version 450
|
||||||
|
#extension GL_EXT_control_flow_attributes : enable
|
||||||
|
#extension GL_EXT_debug_printf : enable
|
||||||
|
#extension GL_KHR_shader_subgroup_basic : enable
|
||||||
|
#extension GL_KHR_shader_subgroup_ballot : enable
|
||||||
|
#extension GL_KHR_shader_subgroup_arithmetic : enable
|
||||||
|
#extension GL_KHR_shader_subgroup_shuffle : enable
|
||||||
|
|
||||||
|
#include "types.glsl"
|
||||||
|
|
||||||
|
layout(constant_id = 0) const int BLOCK_SIZE = 1024;
|
||||||
|
layout(constant_id = 1) const int SUBGROUP_SIZE = 32;
|
||||||
|
layout(constant_id = 2) const int SUBGROUP_SIZE_LOG2 = 5;
|
||||||
|
|
||||||
|
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
|
||||||
|
// Input can either be the source (A) or intermediate values (S).
|
||||||
|
// Similarly, output can be either destination (D) or intermediate values (S).
|
||||||
|
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
||||||
|
layout (binding = 0) readonly buffer S {ivec2 data_s[];};
|
||||||
|
layout (binding = 1) writeonly buffer D {int data_d[];};
|
||||||
|
layout (binding = 1) writeonly buffer T {ivec2 data_t[];};
|
||||||
|
|
||||||
|
layout (push_constant) uniform parameter {
|
||||||
|
uint orig_ncols;
|
||||||
|
uint ncols_input;
|
||||||
|
uint ncols_output;
|
||||||
|
uint nrows;
|
||||||
|
uint first_pass;
|
||||||
|
uint last_pass;
|
||||||
|
} p;
|
||||||
|
|
||||||
|
// pairs of (gid, value)
|
||||||
|
shared ivec2 dst_row[BLOCK_SIZE];
|
||||||
|
|
||||||
|
shared int counts[SUBGROUP_SIZE];
|
||||||
|
shared int sh_min_idx;
|
||||||
|
shared uint sh_total;
|
||||||
|
shared uint offset_partials[BLOCK_SIZE / SUBGROUP_SIZE];
|
||||||
|
|
||||||
|
// Map float values to uint such that comparisons still work.
|
||||||
|
// Positive values set the high bit, negative values are inverted.
|
||||||
|
// +0.0 -> 0x80000000, -0.0 -> 0x7FFFFFFF are in the correct places.
|
||||||
|
uint f2ui(float x) {
|
||||||
|
uint y = floatBitsToUint(x);
|
||||||
|
if ((y & 0x80000000) != 0) {
|
||||||
|
y ^= ~0;
|
||||||
|
} else {
|
||||||
|
y |= 0x80000000;
|
||||||
|
}
|
||||||
|
return y;
|
||||||
|
}
|
||||||
|
|
||||||
|
void topk(const uint row) {
|
||||||
|
const int tid = int(gl_LocalInvocationID.x);
|
||||||
|
|
||||||
|
// initialize indices
|
||||||
|
if (gl_GlobalInvocationID.x < p.ncols_input) {
|
||||||
|
if (p.first_pass != 0) {
|
||||||
|
const uint row_offset = row * p.ncols_input;
|
||||||
|
dst_row[tid] = ivec2(gl_GlobalInvocationID.x, floatBitsToInt(data_a[row_offset + gl_GlobalInvocationID.x]));
|
||||||
|
} else {
|
||||||
|
const uint row_offset = row * p.orig_ncols;
|
||||||
|
dst_row[tid] = data_s[row_offset + gl_GlobalInvocationID.x];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
dst_row[tid] = ivec2(p.orig_ncols, 0xFF800000); // -inf
|
||||||
|
}
|
||||||
|
barrier();
|
||||||
|
|
||||||
|
if (p.ncols_output == 1) {
|
||||||
|
// Fast path for single output - just do a max reduction
|
||||||
|
[[unroll]] for (int s = BLOCK_SIZE / 2; s >= 1; s /= 2) {
|
||||||
|
if (tid < s) {
|
||||||
|
ivec2 a = dst_row[tid];
|
||||||
|
ivec2 b = dst_row[tid + s];
|
||||||
|
if (a.x >= p.orig_ncols ||
|
||||||
|
b.x < p.orig_ncols && b.y > a.y) {
|
||||||
|
dst_row[tid] = b;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
barrier();
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Do an N-ary search to find the K-th largest value.
|
||||||
|
// We remap the float values to be comparable as unsigned integers,
|
||||||
|
// and split the range into 2^N smaller ranges where N is the
|
||||||
|
// subgroup size. Count how many values are in each range, if the K-th
|
||||||
|
// largest value is in the middle of one of thee ranges then repeat
|
||||||
|
// and split again.
|
||||||
|
|
||||||
|
// Mask is the current set of bits we're searching. Shift is the LSB index.
|
||||||
|
int shift = 32 - SUBGROUP_SIZE_LOG2;
|
||||||
|
uint mask = ((1 << SUBGROUP_SIZE_LOG2) - 1) << shift;
|
||||||
|
|
||||||
|
// The current range.
|
||||||
|
uint range_min = 0;
|
||||||
|
uint range_max = 0xFF800000;
|
||||||
|
// How many are above the current range, and how many we need to find.
|
||||||
|
uint total = 0;
|
||||||
|
uint limit = min(p.ncols_output, p.ncols_input - gl_WorkGroupID.x * BLOCK_SIZE);
|
||||||
|
|
||||||
|
while (mask != 0) {
|
||||||
|
barrier();
|
||||||
|
// Initialize bucket counts to zero.
|
||||||
|
if (tid < SUBGROUP_SIZE) {
|
||||||
|
counts[tid] = 0;
|
||||||
|
}
|
||||||
|
barrier();
|
||||||
|
// Count how many values are in each bucket.
|
||||||
|
if (tid < p.ncols_input) {
|
||||||
|
float y = intBitsToFloat(dst_row[tid].y);
|
||||||
|
uint fy = f2ui(y);
|
||||||
|
if (fy >= range_min && fy < range_max) {
|
||||||
|
uint bucket = (fy & mask) >> shift;
|
||||||
|
atomicAdd(counts[bucket], 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
barrier();
|
||||||
|
|
||||||
|
// On the first subgroup, do a scan to count (from the top down) how
|
||||||
|
// many elements are in the top N buckets. Find the index of the first
|
||||||
|
// that is over the limit. Copy it to the other invocations through
|
||||||
|
// shared memory.
|
||||||
|
if (tid < SUBGROUP_SIZE) {
|
||||||
|
uint partial_sum = counts[SUBGROUP_SIZE - 1 - tid];
|
||||||
|
partial_sum = subgroupInclusiveAdd(partial_sum) + total;
|
||||||
|
uint t = subgroupBallotFindLSB(subgroupBallot(partial_sum >= limit));
|
||||||
|
if (tid == t) {
|
||||||
|
sh_min_idx = int(SUBGROUP_SIZE - 1 - t);
|
||||||
|
sh_total = partial_sum;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
barrier();
|
||||||
|
int min_idx = sh_min_idx;
|
||||||
|
total = sh_total;
|
||||||
|
|
||||||
|
// Update the range, and break if we've found the K-th largest.
|
||||||
|
range_max = range_min + ((min_idx + 1) << shift);
|
||||||
|
range_min = range_min + (min_idx << shift);
|
||||||
|
|
||||||
|
if (total == p.ncols_output) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
total -= counts[min_idx];
|
||||||
|
mask >>= SUBGROUP_SIZE_LOG2;
|
||||||
|
shift -= SUBGROUP_SIZE_LOG2;
|
||||||
|
if (shift < 0) {
|
||||||
|
shift = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ivec2 v = dst_row[tid];
|
||||||
|
|
||||||
|
// We need to compact these values to the start of the dst_row array.
|
||||||
|
// Have each subgroup count how many items it'll store, so other
|
||||||
|
// subgroups can compute their base offset.
|
||||||
|
bool top = f2ui(intBitsToFloat(v.y)) >= range_min;
|
||||||
|
uvec4 b = subgroupBallot(top);
|
||||||
|
uint bit_count = subgroupBallotBitCount(b);
|
||||||
|
if ((tid % SUBGROUP_SIZE) == 0) {
|
||||||
|
offset_partials[tid / SUBGROUP_SIZE] = bit_count;
|
||||||
|
}
|
||||||
|
barrier();
|
||||||
|
|
||||||
|
uint out_idx = 0;
|
||||||
|
[[unroll]] for (int i = 0; i < BLOCK_SIZE / SUBGROUP_SIZE; ++i) {
|
||||||
|
if (i < tid / SUBGROUP_SIZE) {
|
||||||
|
out_idx += offset_partials[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
uint bit_count_ex = subgroupBallotExclusiveBitCount(b);
|
||||||
|
if (top) {
|
||||||
|
// TODO: Copy directly to the output?
|
||||||
|
dst_row[out_idx + bit_count_ex] = v;
|
||||||
|
}
|
||||||
|
|
||||||
|
barrier();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (tid < p.ncols_output && gl_GlobalInvocationID.x < p.orig_ncols) {
|
||||||
|
if (p.last_pass != 0) {
|
||||||
|
const uint row_offset = row * p.ncols_output;
|
||||||
|
data_d[row_offset + tid] = dst_row[tid].x;
|
||||||
|
} else {
|
||||||
|
const uint row_offset = row * p.orig_ncols + gl_WorkGroupID.x * p.ncols_output;
|
||||||
|
data_t[row_offset + tid] = dst_row[tid];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void main() {
|
||||||
|
uint row = gl_WorkGroupID.y;
|
||||||
|
while (row < p.nrows) {
|
||||||
|
topk(row);
|
||||||
|
row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -913,9 +913,13 @@ void process_shaders() {
|
||||||
string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}});
|
string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}});
|
||||||
string_to_spv("argsort_large_f32", "argsort_large.comp", {{"A_TYPE", "float"}});
|
string_to_spv("argsort_large_f32", "argsort_large.comp", {{"A_TYPE", "float"}});
|
||||||
|
|
||||||
|
string_to_spv("topk_argsort_f32", "topk_argsort.comp", {{"A_TYPE", "float"}});
|
||||||
|
string_to_spv("topk_nary_search_f32", "topk_nary_search.comp", {{"A_TYPE", "float"}});
|
||||||
|
|
||||||
string_to_spv("argmax_f32", "argmax.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "int"}}));
|
string_to_spv("argmax_f32", "argmax.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "int"}}));
|
||||||
string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||||
string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}}));
|
string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}}));
|
||||||
|
string_to_spv("cumsum_f32", "cumsum.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||||
|
|
||||||
for (std::string dim_str : {"", "_3d"}) {
|
for (std::string dim_str : {"", "_3d"}) {
|
||||||
for (bool bda : {false, true}) {
|
for (bool bda : {false, true}) {
|
||||||
|
|
|
||||||
|
|
@ -990,6 +990,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
||||||
"ARANGE",
|
"ARANGE",
|
||||||
"TIMESTEP_EMBEDDING",
|
"TIMESTEP_EMBEDDING",
|
||||||
"ARGSORT",
|
"ARGSORT",
|
||||||
|
"TOP_K",
|
||||||
"LEAKY_RELU",
|
"LEAKY_RELU",
|
||||||
"TRI",
|
"TRI",
|
||||||
"FILL",
|
"FILL",
|
||||||
|
|
@ -1023,7 +1024,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
||||||
"GLU",
|
"GLU",
|
||||||
};
|
};
|
||||||
|
|
||||||
static_assert(GGML_OP_COUNT == 94, "GGML_OP_COUNT != 94");
|
static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95");
|
||||||
|
|
||||||
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||||
"none",
|
"none",
|
||||||
|
|
@ -1098,6 +1099,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||||
"arange(start, stop, step)",
|
"arange(start, stop, step)",
|
||||||
"timestep_embedding(timesteps, dim, max_period)",
|
"timestep_embedding(timesteps, dim, max_period)",
|
||||||
"argsort(x)",
|
"argsort(x)",
|
||||||
|
"top_k(x)",
|
||||||
"leaky_relu(x)",
|
"leaky_relu(x)",
|
||||||
"tri(x)",
|
"tri(x)",
|
||||||
"fill(x, c)",
|
"fill(x, c)",
|
||||||
|
|
@ -1131,7 +1133,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||||
"glu(x)",
|
"glu(x)",
|
||||||
};
|
};
|
||||||
|
|
||||||
static_assert(GGML_OP_COUNT == 94, "GGML_OP_COUNT != 94");
|
static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95");
|
||||||
|
|
||||||
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
|
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
|
||||||
|
|
||||||
|
|
@ -5036,28 +5038,6 @@ struct ggml_tensor * ggml_roll(
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
// ggml_arange
|
|
||||||
|
|
||||||
struct ggml_tensor * ggml_arange(
|
|
||||||
struct ggml_context * ctx,
|
|
||||||
float start,
|
|
||||||
float stop,
|
|
||||||
float step) {
|
|
||||||
GGML_ASSERT(stop > start);
|
|
||||||
|
|
||||||
const int64_t steps = (int64_t) ceilf((stop - start) / step);
|
|
||||||
|
|
||||||
struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, steps);
|
|
||||||
|
|
||||||
ggml_set_op_params_f32(result, 0, start);
|
|
||||||
ggml_set_op_params_f32(result, 1, stop);
|
|
||||||
ggml_set_op_params_f32(result, 2, step);
|
|
||||||
|
|
||||||
result->op = GGML_OP_ARANGE;
|
|
||||||
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
// ggml_timestep_embedding
|
// ggml_timestep_embedding
|
||||||
|
|
||||||
struct ggml_tensor * ggml_timestep_embedding(
|
struct ggml_tensor * ggml_timestep_embedding(
|
||||||
|
|
@ -5139,6 +5119,7 @@ struct ggml_tensor * ggml_argsort(
|
||||||
struct ggml_tensor * a,
|
struct ggml_tensor * a,
|
||||||
enum ggml_sort_order order) {
|
enum ggml_sort_order order) {
|
||||||
GGML_ASSERT(a->ne[0] <= INT32_MAX);
|
GGML_ASSERT(a->ne[0] <= INT32_MAX);
|
||||||
|
|
||||||
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_I32, GGML_MAX_DIMS, a->ne);
|
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_I32, GGML_MAX_DIMS, a->ne);
|
||||||
|
|
||||||
ggml_set_op_params_i32(result, 0, (int32_t) order);
|
ggml_set_op_params_i32(result, 0, (int32_t) order);
|
||||||
|
|
@ -5149,6 +5130,24 @@ struct ggml_tensor * ggml_argsort(
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ggml_argsort_top_k
|
||||||
|
|
||||||
|
struct ggml_tensor * ggml_argsort_top_k(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
struct ggml_tensor * a,
|
||||||
|
int k) {
|
||||||
|
GGML_ASSERT(a->ne[0] >= k);
|
||||||
|
|
||||||
|
struct ggml_tensor * result = ggml_argsort(ctx, a, GGML_SORT_ORDER_DESC);
|
||||||
|
|
||||||
|
result = ggml_view_4d(ctx, result,
|
||||||
|
k, result->ne[1], result->ne[2], result->ne[3],
|
||||||
|
result->nb[1], result->nb[2], result->nb[3],
|
||||||
|
0);
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
// ggml_top_k
|
// ggml_top_k
|
||||||
|
|
||||||
struct ggml_tensor * ggml_top_k(
|
struct ggml_tensor * ggml_top_k(
|
||||||
|
|
@ -5157,12 +5156,32 @@ struct ggml_tensor * ggml_top_k(
|
||||||
int k) {
|
int k) {
|
||||||
GGML_ASSERT(a->ne[0] >= k);
|
GGML_ASSERT(a->ne[0] >= k);
|
||||||
|
|
||||||
struct ggml_tensor * result = ggml_argsort(ctx, a, GGML_SORT_ORDER_DESC);
|
struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_I32, k, a->ne[1], a->ne[2], a->ne[3]);
|
||||||
|
|
||||||
result = ggml_view_4d(ctx, result,
|
result->op = GGML_OP_TOP_K;
|
||||||
k, result->ne[1], result->ne[2], result->ne[3],
|
result->src[0] = a;
|
||||||
result->nb[1], result->nb[2], result->nb[3],
|
|
||||||
0);
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ggml_arange
|
||||||
|
|
||||||
|
struct ggml_tensor * ggml_arange(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
float start,
|
||||||
|
float stop,
|
||||||
|
float step) {
|
||||||
|
GGML_ASSERT(stop > start);
|
||||||
|
|
||||||
|
const int64_t steps = (int64_t) ceilf((stop - start) / step);
|
||||||
|
|
||||||
|
struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, steps);
|
||||||
|
|
||||||
|
ggml_set_op_params_f32(result, 0, start);
|
||||||
|
ggml_set_op_params_f32(result, 1, stop);
|
||||||
|
ggml_set_op_params_f32(result, 2, step);
|
||||||
|
|
||||||
|
result->op = GGML_OP_ARANGE;
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,7 @@ import logging
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import struct
|
import struct
|
||||||
|
import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import Enum, auto
|
from enum import Enum, auto
|
||||||
|
|
@ -372,8 +373,10 @@ class GGUFWriter:
|
||||||
self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None,
|
self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None,
|
||||||
raw_dtype: GGMLQuantizationType | None = None,
|
raw_dtype: GGMLQuantizationType | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
if self.endianess == GGUFEndian.BIG:
|
if (self.endianess == GGUFEndian.BIG and sys.byteorder != 'big') or \
|
||||||
tensor.byteswap(inplace=True)
|
(self.endianess == GGUFEndian.LITTLE and sys.byteorder != 'little'):
|
||||||
|
# Don't byteswap inplace since lazy copies cannot handle it
|
||||||
|
tensor = tensor.byteswap(inplace=False)
|
||||||
if self.use_temp_file and self.temp_file is None:
|
if self.use_temp_file and self.temp_file is None:
|
||||||
fp = tempfile.SpooledTemporaryFile(mode="w+b", max_size=256 * 1024 * 1024)
|
fp = tempfile.SpooledTemporaryFile(mode="w+b", max_size=256 * 1024 * 1024)
|
||||||
fp.seek(0)
|
fp.seek(0)
|
||||||
|
|
@ -399,8 +402,10 @@ class GGUFWriter:
|
||||||
raise ValueError(f'Expected output file to contain tensor info or weights, got {self.state}')
|
raise ValueError(f'Expected output file to contain tensor info or weights, got {self.state}')
|
||||||
assert self.fout is not None
|
assert self.fout is not None
|
||||||
|
|
||||||
if self.endianess == GGUFEndian.BIG:
|
if (self.endianess == GGUFEndian.BIG and sys.byteorder != 'big') or \
|
||||||
tensor.byteswap(inplace=True)
|
(self.endianess == GGUFEndian.LITTLE and sys.byteorder != 'little'):
|
||||||
|
# Don't byteswap inplace since lazy copies cannot handle it
|
||||||
|
tensor = tensor.byteswap(inplace=False)
|
||||||
|
|
||||||
file_id = -1
|
file_id = -1
|
||||||
for i, tensors in enumerate(self.tensors):
|
for i, tensors in enumerate(self.tensors):
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,7 @@ vendor = {
|
||||||
# "https://github.com/mackron/miniaudio/raw/refs/tags/0.11.23/miniaudio.h": "vendor/miniaudio/miniaudio.h",
|
# "https://github.com/mackron/miniaudio/raw/refs/tags/0.11.23/miniaudio.h": "vendor/miniaudio/miniaudio.h",
|
||||||
"https://github.com/mackron/miniaudio/raw/669ed3e844524fcd883231b13095baee9f6de304/miniaudio.h": "vendor/miniaudio/miniaudio.h",
|
"https://github.com/mackron/miniaudio/raw/669ed3e844524fcd883231b13095baee9f6de304/miniaudio.h": "vendor/miniaudio/miniaudio.h",
|
||||||
|
|
||||||
"https://raw.githubusercontent.com/yhirose/cpp-httplib/refs/tags/v0.27.0/httplib.h": "vendor/cpp-httplib/httplib.h",
|
"https://raw.githubusercontent.com/yhirose/cpp-httplib/refs/tags/v0.28.0/httplib.h": "vendor/cpp-httplib/httplib.h",
|
||||||
}
|
}
|
||||||
|
|
||||||
for url, filename in vendor.items():
|
for url, filename in vendor.items():
|
||||||
|
|
|
||||||
|
|
@ -961,14 +961,14 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
||||||
// organize experts into n_expert_groups
|
// organize experts into n_expert_groups
|
||||||
ggml_tensor * selection_groups = ggml_reshape_3d(ctx0, selection_probs, n_exp_per_group, hparams.n_expert_groups, n_tokens); // [n_exp_per_group, n_expert_groups, n_tokens]
|
ggml_tensor * selection_groups = ggml_reshape_3d(ctx0, selection_probs, n_exp_per_group, hparams.n_expert_groups, n_tokens); // [n_exp_per_group, n_expert_groups, n_tokens]
|
||||||
|
|
||||||
ggml_tensor * group_scores = ggml_top_k(ctx0, selection_groups, 2); // [2, n_expert_groups, n_tokens]
|
ggml_tensor * group_scores = ggml_argsort_top_k(ctx0, selection_groups, 2); // [2, n_expert_groups, n_tokens]
|
||||||
group_scores = ggml_get_rows(ctx0, ggml_reshape_4d(ctx0, selection_groups, 1, selection_groups->ne[0], selection_groups->ne[1], selection_groups->ne[2]), group_scores); // [1, 2, n_expert_groups, n_tokens]
|
group_scores = ggml_get_rows(ctx0, ggml_reshape_4d(ctx0, selection_groups, 1, selection_groups->ne[0], selection_groups->ne[1], selection_groups->ne[2]), group_scores); // [1, 2, n_expert_groups, n_tokens]
|
||||||
|
|
||||||
// get top n_group_used expert groups
|
// get top n_group_used expert groups
|
||||||
group_scores = ggml_sum_rows(ctx0, ggml_reshape_3d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2], group_scores->ne[3])); // [1, n_expert_groups, n_tokens]
|
group_scores = ggml_sum_rows(ctx0, ggml_reshape_3d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2], group_scores->ne[3])); // [1, n_expert_groups, n_tokens]
|
||||||
group_scores = ggml_reshape_2d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2]); // [n_expert_groups, n_tokens]
|
group_scores = ggml_reshape_2d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2]); // [n_expert_groups, n_tokens]
|
||||||
|
|
||||||
ggml_tensor * expert_groups = ggml_top_k(ctx0, group_scores, hparams.n_group_used); // [n_group_used, n_tokens]
|
ggml_tensor * expert_groups = ggml_argsort_top_k(ctx0, group_scores, hparams.n_group_used); // [n_group_used, n_tokens]
|
||||||
cb(expert_groups, "ffn_moe_group_topk", il);
|
cb(expert_groups, "ffn_moe_group_topk", il);
|
||||||
|
|
||||||
// mask out the other groups
|
// mask out the other groups
|
||||||
|
|
@ -979,7 +979,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
||||||
}
|
}
|
||||||
|
|
||||||
// select experts
|
// select experts
|
||||||
ggml_tensor * selected_experts = ggml_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
|
ggml_tensor * selected_experts = ggml_argsort_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
|
||||||
cb(selected_experts->src[0], "ffn_moe_argsort", il);
|
cb(selected_experts->src[0], "ffn_moe_argsort", il);
|
||||||
cb(selected_experts, "ffn_moe_topk", il);
|
cb(selected_experts, "ffn_moe_topk", il);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -39,6 +39,7 @@
|
||||||
#include <string_view>
|
#include <string_view>
|
||||||
#include <thread>
|
#include <thread>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include <unordered_map>
|
||||||
|
|
||||||
static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) {
|
static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) {
|
||||||
size_t nels = ggml_nelements(tensor);
|
size_t nels = ggml_nelements(tensor);
|
||||||
|
|
@ -269,6 +270,34 @@ static double nmse(const float * a, const float * b, size_t n) {
|
||||||
return mse_a_b / mse_a_0;
|
return mse_a_b / mse_a_0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// difference between 2 integer sets (Jaccard distance, 0 - no difference, 1 - no overlap)
|
||||||
|
static double jdst(const int32_t * a, const int32_t * b, size_t n) {
|
||||||
|
std::unordered_map<int32_t, size_t> set_a;
|
||||||
|
std::unordered_map<int32_t, size_t> set_b;
|
||||||
|
|
||||||
|
for (size_t i = 0; i < n; ++i) {
|
||||||
|
set_a[a[i]]++;
|
||||||
|
set_b[b[i]]++;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t diff = 0;
|
||||||
|
|
||||||
|
for (const auto & p : set_a) {
|
||||||
|
const int64_t na = p.second;
|
||||||
|
const int64_t nb = set_b.find(p.first) != set_b.end() ? set_b.at(p.first) : 0;
|
||||||
|
|
||||||
|
diff += std::abs(na - nb);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const auto & p : set_b) {
|
||||||
|
if (set_a.find(p.first) == set_a.end()) {
|
||||||
|
diff += p.second;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return (double) diff / (2*n);
|
||||||
|
}
|
||||||
|
|
||||||
// maximum absolute asymmetry between a and b
|
// maximum absolute asymmetry between a and b
|
||||||
// asymmetry: (a - b) / (a + b)
|
// asymmetry: (a - b) / (a + b)
|
||||||
// This is more stable than relative error if one of the values fluctuates towards zero.
|
// This is more stable than relative error if one of the values fluctuates towards zero.
|
||||||
|
|
@ -1051,6 +1080,14 @@ struct test_case {
|
||||||
return 1e-4;
|
return 1e-4;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
virtual double max_err() {
|
||||||
|
return max_nmse_err();
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual double err(const float * a, const float * b, size_t n) {
|
||||||
|
return nmse(a, b, n);
|
||||||
|
}
|
||||||
|
|
||||||
virtual float grad_eps() {
|
virtual float grad_eps() {
|
||||||
return 1e-1f;
|
return 1e-1f;
|
||||||
}
|
}
|
||||||
|
|
@ -1257,16 +1294,16 @@ struct test_case {
|
||||||
// compare
|
// compare
|
||||||
struct callback_userdata {
|
struct callback_userdata {
|
||||||
bool ok;
|
bool ok;
|
||||||
double max_err;
|
test_case * tc;
|
||||||
ggml_backend_t backend1;
|
ggml_backend_t backend1;
|
||||||
ggml_backend_t backend2;
|
ggml_backend_t backend2;
|
||||||
};
|
};
|
||||||
|
|
||||||
callback_userdata ud {
|
callback_userdata ud {
|
||||||
true,
|
true,
|
||||||
max_nmse_err(),
|
this,
|
||||||
backend1,
|
backend1,
|
||||||
backend2
|
backend2,
|
||||||
};
|
};
|
||||||
|
|
||||||
auto callback = [](int index, ggml_tensor * t1, ggml_tensor * t2, void * user_data) -> bool {
|
auto callback = [](int index, ggml_tensor * t1, ggml_tensor * t2, void * user_data) -> bool {
|
||||||
|
|
@ -1314,9 +1351,9 @@ struct test_case {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
double err = nmse(f1.data(), f2.data(), f1.size());
|
double err = ud->tc->err(f1.data(), f2.data(), f1.size());
|
||||||
if (err > ud->max_err) {
|
if (err > ud->tc->max_err()) {
|
||||||
printf("[%s] NMSE = %.9f > %.9f ", ggml_op_desc(t1), err, ud->max_err);
|
printf("[%s] ERR = %.9f > %.9f ", ggml_op_desc(t1), err, ud->tc->max_err());
|
||||||
//for (int i = 0; i < (int) f1.size(); i++) {
|
//for (int i = 0; i < (int) f1.size(); i++) {
|
||||||
// printf("%5d %9.6f %9.6f, diff = %9.6f\n", i, f1[i], f2[i], f1[i] - f2[i]);
|
// printf("%5d %9.6f %9.6f, diff = %9.6f\n", i, f1[i], f2[i], f1[i] - f2[i]);
|
||||||
//}
|
//}
|
||||||
|
|
@ -4943,7 +4980,71 @@ struct test_argsort : public test_case {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct test_topk_moe: public test_case {
|
// GGML_OP_TOP_K
|
||||||
|
struct test_top_k : public test_case {
|
||||||
|
const ggml_type type;
|
||||||
|
const std::array<int64_t, 4> ne;
|
||||||
|
const int k;
|
||||||
|
|
||||||
|
std::string vars() override {
|
||||||
|
return VARS_TO_STR3(type, ne, k);
|
||||||
|
}
|
||||||
|
|
||||||
|
test_top_k(ggml_type type = GGML_TYPE_F32,
|
||||||
|
std::array<int64_t, 4> ne = {16, 10, 10, 10},
|
||||||
|
int k = 4)
|
||||||
|
: type(type), ne(ne), k(k) {}
|
||||||
|
|
||||||
|
double max_err() override {
|
||||||
|
return 0.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
double err(const float * a, const float * b, size_t n) override {
|
||||||
|
std::vector<int32_t> ia(n);
|
||||||
|
std::vector<int32_t> ib(n);
|
||||||
|
|
||||||
|
double diff = 0.0f;
|
||||||
|
|
||||||
|
for (size_t i = 0; i < n; i++) {
|
||||||
|
ia[i] = (int32_t) a[i];
|
||||||
|
ib[i] = (int32_t) b[i];
|
||||||
|
|
||||||
|
// penalize the result if the data is not integer valued
|
||||||
|
diff += std::fabs(a[i] - ia[i]);
|
||||||
|
diff += std::fabs(b[i] - ib[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
return diff + jdst(ia.data(), ib.data(), n);
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||||
|
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
|
||||||
|
ggml_set_name(a, "a");
|
||||||
|
|
||||||
|
ggml_tensor * out = ggml_top_k(ctx, a, k);
|
||||||
|
ggml_set_name(out, "out");
|
||||||
|
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
void initialize_tensors(ggml_context * ctx) override {
|
||||||
|
std::random_device rd;
|
||||||
|
std::default_random_engine rng(rd());
|
||||||
|
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
|
||||||
|
// initialize with unique values to avoid ties
|
||||||
|
for (int64_t r = 0; r < ggml_nrows(t); r++) {
|
||||||
|
std::vector<float> data(t->ne[0]);
|
||||||
|
for (int i = 0; i < t->ne[0]; i++) {
|
||||||
|
data[i] = i;
|
||||||
|
}
|
||||||
|
std::shuffle(data.begin(), data.end(), rng);
|
||||||
|
ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(float));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct test_topk_moe : public test_case {
|
||||||
const std::array<int64_t, 4> ne;
|
const std::array<int64_t, 4> ne;
|
||||||
const int n_expert_used;
|
const int n_expert_used;
|
||||||
const bool with_norm;
|
const bool with_norm;
|
||||||
|
|
@ -4976,7 +5077,7 @@ struct test_topk_moe: public test_case {
|
||||||
|
|
||||||
ggml_tensor * logits = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne.data());
|
ggml_tensor * logits = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne.data());
|
||||||
ggml_tensor * probs = delayed_softmax ? logits : ggml_soft_max(ctx, logits);
|
ggml_tensor * probs = delayed_softmax ? logits : ggml_soft_max(ctx, logits);
|
||||||
ggml_tensor * selected_experts = ggml_top_k(ctx, probs, n_expert_used); // [n_expert_used, n_tokens]
|
ggml_tensor * selected_experts = ggml_argsort_top_k(ctx, probs, n_expert_used); // [n_expert_used, n_tokens]
|
||||||
|
|
||||||
ggml_tensor * out = ggml_get_rows(ctx, ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
|
ggml_tensor * out = ggml_get_rows(ctx, ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
|
||||||
|
|
||||||
|
|
@ -7534,6 +7635,31 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2, 8, 8192, 1}, order)); // bailingmoe2 (group selection)
|
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2, 8, 8192, 1}, order)); // bailingmoe2 (group selection)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < 20; ++i) {
|
||||||
|
for (int k : {1, 2, 3, 7, 15, 100, 500, 1023, 9999}) {
|
||||||
|
if (k <= 1<<i) {
|
||||||
|
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {(1<<i), 1, 1, 1}, k));
|
||||||
|
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {(1<<i) + 11, 1, 2, 1}, k));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (int k : {1, 2, 3, 7, 15}) {
|
||||||
|
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {16, 10, 10, 10}, k));
|
||||||
|
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {60, 10, 10, 10}, k));
|
||||||
|
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {1023, 2, 1, 3}, k));
|
||||||
|
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {1024, 2, 1, 3}, k));
|
||||||
|
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {1025, 2, 1, 3}, k));
|
||||||
|
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {16384, 1, 1, 1}, k));
|
||||||
|
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {2047, 2, 1, 3}, k));
|
||||||
|
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {2048, 2, 1, 3}, k));
|
||||||
|
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {2049, 2, 1, 3}, k));
|
||||||
|
}
|
||||||
|
|
||||||
|
// exhaustive top_k tests
|
||||||
|
//for (int i = 1; i < 9999; ++i) {
|
||||||
|
// test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {i, 2, 1, 3}, rand() % i + 1));
|
||||||
|
//}
|
||||||
|
|
||||||
for (ggml_scale_mode mode : {GGML_SCALE_MODE_NEAREST, GGML_SCALE_MODE_BILINEAR, GGML_SCALE_MODE_BICUBIC}) {
|
for (ggml_scale_mode mode : {GGML_SCALE_MODE_NEAREST, GGML_SCALE_MODE_BILINEAR, GGML_SCALE_MODE_BICUBIC}) {
|
||||||
test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, {512, 512, 3, 2}, 2, mode));
|
test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, {512, 512, 3, 2}, 2, mode));
|
||||||
test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, {512, 512, 3, 2}, 2, mode, true));
|
test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, {512, 512, 3, 2}, 2, mode, true));
|
||||||
|
|
@ -7914,6 +8040,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
|
||||||
}
|
}
|
||||||
|
|
||||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {65000, 16, 1, 1}));
|
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {65000, 16, 1, 1}));
|
||||||
|
for (auto k : {1, 10, 40}) {
|
||||||
|
for (auto nrows : {1, 16}) {
|
||||||
|
for (auto cols : {k, 1000, 65000, 200000}) {
|
||||||
|
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {cols, nrows, 1, 1}, k));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return test_cases;
|
return test_cases;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -31,13 +31,16 @@ if (LLAMA_BUILD_BORINGSSL)
|
||||||
|
|
||||||
message(STATUS "Fetching BoringSSL version ${BORINGSSL_VERSION}")
|
message(STATUS "Fetching BoringSSL version ${BORINGSSL_VERSION}")
|
||||||
|
|
||||||
include(FetchContent)
|
set(BORINGSSL_ARGS
|
||||||
FetchContent_Declare(
|
|
||||||
boringssl
|
|
||||||
GIT_REPOSITORY ${BORINGSSL_GIT}
|
GIT_REPOSITORY ${BORINGSSL_GIT}
|
||||||
GIT_TAG ${BORINGSSL_VERSION}
|
GIT_TAG ${BORINGSSL_VERSION}
|
||||||
PATCH_COMMAND ${CMAKE_COMMAND} -P "${CMAKE_CURRENT_SOURCE_DIR}/patch-boringssl.cmake"
|
|
||||||
)
|
)
|
||||||
|
if(CMAKE_VERSION VERSION_GREATER_EQUAL 3.28)
|
||||||
|
list(APPEND BORINGSSL_ARGS EXCLUDE_FROM_ALL)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
include(FetchContent)
|
||||||
|
FetchContent_Declare(boringssl ${BORINGSSL_ARGS})
|
||||||
|
|
||||||
set(SAVED_BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS})
|
set(SAVED_BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS})
|
||||||
set(SAVED_BUILD_TESTING ${BUILD_TESTING})
|
set(SAVED_BUILD_TESTING ${BUILD_TESTING})
|
||||||
|
|
@ -45,7 +48,15 @@ if (LLAMA_BUILD_BORINGSSL)
|
||||||
set(BUILD_SHARED_LIBS OFF)
|
set(BUILD_SHARED_LIBS OFF)
|
||||||
set(BUILD_TESTING OFF)
|
set(BUILD_TESTING OFF)
|
||||||
|
|
||||||
FetchContent_MakeAvailable(boringssl)
|
if(CMAKE_VERSION VERSION_GREATER_EQUAL 3.28)
|
||||||
|
FetchContent_MakeAvailable(boringssl)
|
||||||
|
else()
|
||||||
|
FetchContent_GetProperties(boringssl)
|
||||||
|
if(NOT boringssl_POPULATED)
|
||||||
|
FetchContent_Populate(boringssl)
|
||||||
|
add_subdirectory(${boringssl_SOURCE_DIR} ${boringssl_BINARY_DIR} EXCLUDE_FROM_ALL)
|
||||||
|
endif()
|
||||||
|
endif()
|
||||||
|
|
||||||
set(BUILD_SHARED_LIBS ${SAVED_BUILD_SHARED_LIBS})
|
set(BUILD_SHARED_LIBS ${SAVED_BUILD_SHARED_LIBS})
|
||||||
set(BUILD_TESTING ${SAVED_BUILD_TESTING})
|
set(BUILD_TESTING ${SAVED_BUILD_TESTING})
|
||||||
|
|
|
||||||
|
|
@ -1087,22 +1087,30 @@ int getaddrinfo_with_timeout(const char *node, const char *service,
|
||||||
// Fallback implementation using thread-based timeout for other Unix systems
|
// Fallback implementation using thread-based timeout for other Unix systems
|
||||||
|
|
||||||
struct GetAddrInfoState {
|
struct GetAddrInfoState {
|
||||||
|
~GetAddrInfoState() {
|
||||||
|
if (info) { freeaddrinfo(info); }
|
||||||
|
}
|
||||||
|
|
||||||
std::mutex mutex;
|
std::mutex mutex;
|
||||||
std::condition_variable result_cv;
|
std::condition_variable result_cv;
|
||||||
bool completed = false;
|
bool completed = false;
|
||||||
int result = EAI_SYSTEM;
|
int result = EAI_SYSTEM;
|
||||||
std::string node = node;
|
std::string node;
|
||||||
std::string service = service;
|
std::string service;
|
||||||
struct addrinfo hints = hints;
|
struct addrinfo hints;
|
||||||
struct addrinfo *info = nullptr;
|
struct addrinfo *info = nullptr;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Allocate on the heap, so the resolver thread can keep using the data.
|
// Allocate on the heap, so the resolver thread can keep using the data.
|
||||||
auto state = std::make_shared<GetAddrInfoState>();
|
auto state = std::make_shared<GetAddrInfoState>();
|
||||||
|
state->node = node;
|
||||||
|
state->service = service;
|
||||||
|
state->hints = *hints;
|
||||||
|
|
||||||
std::thread resolve_thread([=]() {
|
std::thread resolve_thread([state]() {
|
||||||
auto thread_result = getaddrinfo(
|
auto thread_result =
|
||||||
state->node.c_str(), state->service.c_str(), hints, &state->info);
|
getaddrinfo(state->node.c_str(), state->service.c_str(), &state->hints,
|
||||||
|
&state->info);
|
||||||
|
|
||||||
std::lock_guard<std::mutex> lock(state->mutex);
|
std::lock_guard<std::mutex> lock(state->mutex);
|
||||||
state->result = thread_result;
|
state->result = thread_result;
|
||||||
|
|
@ -1120,6 +1128,7 @@ int getaddrinfo_with_timeout(const char *node, const char *service,
|
||||||
// Operation completed within timeout
|
// Operation completed within timeout
|
||||||
resolve_thread.join();
|
resolve_thread.join();
|
||||||
*res = state->info;
|
*res = state->info;
|
||||||
|
state->info = nullptr; // Pass ownership to caller
|
||||||
return state->result;
|
return state->result;
|
||||||
} else {
|
} else {
|
||||||
// Timeout occurred
|
// Timeout occurred
|
||||||
|
|
@ -4970,7 +4979,8 @@ bool Server::write_response_core(Stream &strm, bool close_connection,
|
||||||
if (need_apply_ranges) { apply_ranges(req, res, content_type, boundary); }
|
if (need_apply_ranges) { apply_ranges(req, res, content_type, boundary); }
|
||||||
|
|
||||||
// Prepare additional headers
|
// Prepare additional headers
|
||||||
if (close_connection || req.get_header_value("Connection") == "close") {
|
if (close_connection || req.get_header_value("Connection") == "close" ||
|
||||||
|
400 <= res.status) { // Don't leave connections open after errors
|
||||||
res.set_header("Connection", "close");
|
res.set_header("Connection", "close");
|
||||||
} else {
|
} else {
|
||||||
std::string s = "timeout=";
|
std::string s = "timeout=";
|
||||||
|
|
@ -5173,7 +5183,11 @@ bool Server::read_content_core(
|
||||||
size_t /*len*/) { return receiver(buf, n); };
|
size_t /*len*/) { return receiver(buf, n); };
|
||||||
}
|
}
|
||||||
|
|
||||||
if (req.method == "DELETE" && !req.has_header("Content-Length")) {
|
// RFC 7230 Section 3.3.3: If this is a request message and none of the above
|
||||||
|
// are true (no Transfer-Encoding and no Content-Length), then the message
|
||||||
|
// body length is zero (no message body is present).
|
||||||
|
if (!req.has_header("Content-Length") &&
|
||||||
|
!detail::is_chunked_transfer_encoding(req.headers)) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -5681,8 +5695,6 @@ Server::process_request(Stream &strm, const std::string &remote_addr,
|
||||||
|
|
||||||
// Check if the request URI doesn't exceed the limit
|
// Check if the request URI doesn't exceed the limit
|
||||||
if (req.target.size() > CPPHTTPLIB_REQUEST_URI_MAX_LENGTH) {
|
if (req.target.size() > CPPHTTPLIB_REQUEST_URI_MAX_LENGTH) {
|
||||||
Headers dummy;
|
|
||||||
detail::read_headers(strm, dummy);
|
|
||||||
res.status = StatusCode::UriTooLong_414;
|
res.status = StatusCode::UriTooLong_414;
|
||||||
output_error_log(Error::ExceedUriMaxLength, &req);
|
output_error_log(Error::ExceedUriMaxLength, &req);
|
||||||
return write_response(strm, close_connection, req, res);
|
return write_response(strm, close_connection, req, res);
|
||||||
|
|
@ -6666,11 +6678,13 @@ bool ClientImpl::write_request(Stream &strm, Request &req,
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<Response> ClientImpl::send_with_content_provider(
|
std::unique_ptr<Response>
|
||||||
|
ClientImpl::send_with_content_provider_and_receiver(
|
||||||
Request &req, const char *body, size_t content_length,
|
Request &req, const char *body, size_t content_length,
|
||||||
ContentProvider content_provider,
|
ContentProvider content_provider,
|
||||||
ContentProviderWithoutLength content_provider_without_length,
|
ContentProviderWithoutLength content_provider_without_length,
|
||||||
const std::string &content_type, Error &error) {
|
const std::string &content_type, ContentReceiver content_receiver,
|
||||||
|
Error &error) {
|
||||||
if (!content_type.empty()) { req.set_header("Content-Type", content_type); }
|
if (!content_type.empty()) { req.set_header("Content-Type", content_type); }
|
||||||
|
|
||||||
#ifdef CPPHTTPLIB_ZLIB_SUPPORT
|
#ifdef CPPHTTPLIB_ZLIB_SUPPORT
|
||||||
|
|
@ -6743,15 +6757,24 @@ std::unique_ptr<Response> ClientImpl::send_with_content_provider(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (content_receiver) {
|
||||||
|
req.content_receiver =
|
||||||
|
[content_receiver](const char *data, size_t data_length,
|
||||||
|
size_t /*offset*/, size_t /*total_length*/) {
|
||||||
|
return content_receiver(data, data_length);
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
auto res = detail::make_unique<Response>();
|
auto res = detail::make_unique<Response>();
|
||||||
return send(req, *res, error) ? std::move(res) : nullptr;
|
return send(req, *res, error) ? std::move(res) : nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
Result ClientImpl::send_with_content_provider(
|
Result ClientImpl::send_with_content_provider_and_receiver(
|
||||||
const std::string &method, const std::string &path, const Headers &headers,
|
const std::string &method, const std::string &path, const Headers &headers,
|
||||||
const char *body, size_t content_length, ContentProvider content_provider,
|
const char *body, size_t content_length, ContentProvider content_provider,
|
||||||
ContentProviderWithoutLength content_provider_without_length,
|
ContentProviderWithoutLength content_provider_without_length,
|
||||||
const std::string &content_type, UploadProgress progress) {
|
const std::string &content_type, ContentReceiver content_receiver,
|
||||||
|
UploadProgress progress) {
|
||||||
Request req;
|
Request req;
|
||||||
req.method = method;
|
req.method = method;
|
||||||
req.headers = headers;
|
req.headers = headers;
|
||||||
|
|
@ -6763,9 +6786,10 @@ Result ClientImpl::send_with_content_provider(
|
||||||
|
|
||||||
auto error = Error::Success;
|
auto error = Error::Success;
|
||||||
|
|
||||||
auto res = send_with_content_provider(
|
auto res = send_with_content_provider_and_receiver(
|
||||||
req, body, content_length, std::move(content_provider),
|
req, body, content_length, std::move(content_provider),
|
||||||
std::move(content_provider_without_length), content_type, error);
|
std::move(content_provider_without_length), content_type,
|
||||||
|
std::move(content_receiver), error);
|
||||||
|
|
||||||
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
|
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
|
||||||
return Result{std::move(res), error, std::move(req.headers), last_ssl_error_,
|
return Result{std::move(res), error, std::move(req.headers), last_ssl_error_,
|
||||||
|
|
@ -7094,6 +7118,15 @@ Result ClientImpl::Post(const std::string &path, size_t content_length,
|
||||||
content_type, progress);
|
content_type, progress);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Result ClientImpl::Post(const std::string &path, size_t content_length,
|
||||||
|
ContentProvider content_provider,
|
||||||
|
const std::string &content_type,
|
||||||
|
ContentReceiver content_receiver,
|
||||||
|
UploadProgress progress) {
|
||||||
|
return Post(path, Headers(), content_length, std::move(content_provider),
|
||||||
|
content_type, std::move(content_receiver), progress);
|
||||||
|
}
|
||||||
|
|
||||||
Result ClientImpl::Post(const std::string &path,
|
Result ClientImpl::Post(const std::string &path,
|
||||||
ContentProviderWithoutLength content_provider,
|
ContentProviderWithoutLength content_provider,
|
||||||
const std::string &content_type,
|
const std::string &content_type,
|
||||||
|
|
@ -7102,6 +7135,15 @@ Result ClientImpl::Post(const std::string &path,
|
||||||
progress);
|
progress);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Result ClientImpl::Post(const std::string &path,
|
||||||
|
ContentProviderWithoutLength content_provider,
|
||||||
|
const std::string &content_type,
|
||||||
|
ContentReceiver content_receiver,
|
||||||
|
UploadProgress progress) {
|
||||||
|
return Post(path, Headers(), std::move(content_provider), content_type,
|
||||||
|
std::move(content_receiver), progress);
|
||||||
|
}
|
||||||
|
|
||||||
Result ClientImpl::Post(const std::string &path, const Headers &headers,
|
Result ClientImpl::Post(const std::string &path, const Headers &headers,
|
||||||
const Params ¶ms) {
|
const Params ¶ms) {
|
||||||
auto query = detail::params_to_query_str(params);
|
auto query = detail::params_to_query_str(params);
|
||||||
|
|
@ -7142,17 +7184,18 @@ Result ClientImpl::Post(const std::string &path, const Headers &headers,
|
||||||
const char *body, size_t content_length,
|
const char *body, size_t content_length,
|
||||||
const std::string &content_type,
|
const std::string &content_type,
|
||||||
UploadProgress progress) {
|
UploadProgress progress) {
|
||||||
return send_with_content_provider("POST", path, headers, body, content_length,
|
return send_with_content_provider_and_receiver(
|
||||||
nullptr, nullptr, content_type, progress);
|
"POST", path, headers, body, content_length, nullptr, nullptr,
|
||||||
|
content_type, nullptr, progress);
|
||||||
}
|
}
|
||||||
|
|
||||||
Result ClientImpl::Post(const std::string &path, const Headers &headers,
|
Result ClientImpl::Post(const std::string &path, const Headers &headers,
|
||||||
const std::string &body,
|
const std::string &body,
|
||||||
const std::string &content_type,
|
const std::string &content_type,
|
||||||
UploadProgress progress) {
|
UploadProgress progress) {
|
||||||
return send_with_content_provider("POST", path, headers, body.data(),
|
return send_with_content_provider_and_receiver(
|
||||||
body.size(), nullptr, nullptr, content_type,
|
"POST", path, headers, body.data(), body.size(), nullptr, nullptr,
|
||||||
progress);
|
content_type, nullptr, progress);
|
||||||
}
|
}
|
||||||
|
|
||||||
Result ClientImpl::Post(const std::string &path, const Headers &headers,
|
Result ClientImpl::Post(const std::string &path, const Headers &headers,
|
||||||
|
|
@ -7160,18 +7203,40 @@ Result ClientImpl::Post(const std::string &path, const Headers &headers,
|
||||||
ContentProvider content_provider,
|
ContentProvider content_provider,
|
||||||
const std::string &content_type,
|
const std::string &content_type,
|
||||||
UploadProgress progress) {
|
UploadProgress progress) {
|
||||||
return send_with_content_provider("POST", path, headers, nullptr,
|
return send_with_content_provider_and_receiver(
|
||||||
content_length, std::move(content_provider),
|
"POST", path, headers, nullptr, content_length,
|
||||||
nullptr, content_type, progress);
|
std::move(content_provider), nullptr, content_type, nullptr, progress);
|
||||||
|
}
|
||||||
|
|
||||||
|
Result ClientImpl::Post(const std::string &path, const Headers &headers,
|
||||||
|
size_t content_length,
|
||||||
|
ContentProvider content_provider,
|
||||||
|
const std::string &content_type,
|
||||||
|
ContentReceiver content_receiver,
|
||||||
|
DownloadProgress progress) {
|
||||||
|
return send_with_content_provider_and_receiver(
|
||||||
|
"POST", path, headers, nullptr, content_length,
|
||||||
|
std::move(content_provider), nullptr, content_type,
|
||||||
|
std::move(content_receiver), std::move(progress));
|
||||||
}
|
}
|
||||||
|
|
||||||
Result ClientImpl::Post(const std::string &path, const Headers &headers,
|
Result ClientImpl::Post(const std::string &path, const Headers &headers,
|
||||||
ContentProviderWithoutLength content_provider,
|
ContentProviderWithoutLength content_provider,
|
||||||
const std::string &content_type,
|
const std::string &content_type,
|
||||||
UploadProgress progress) {
|
UploadProgress progress) {
|
||||||
return send_with_content_provider("POST", path, headers, nullptr, 0, nullptr,
|
return send_with_content_provider_and_receiver(
|
||||||
std::move(content_provider), content_type,
|
"POST", path, headers, nullptr, 0, nullptr, std::move(content_provider),
|
||||||
progress);
|
content_type, nullptr, progress);
|
||||||
|
}
|
||||||
|
|
||||||
|
Result ClientImpl::Post(const std::string &path, const Headers &headers,
|
||||||
|
ContentProviderWithoutLength content_provider,
|
||||||
|
const std::string &content_type,
|
||||||
|
ContentReceiver content_receiver,
|
||||||
|
DownloadProgress progress) {
|
||||||
|
return send_with_content_provider_and_receiver(
|
||||||
|
"POST", path, headers, nullptr, 0, nullptr, std::move(content_provider),
|
||||||
|
content_type, std::move(content_receiver), std::move(progress));
|
||||||
}
|
}
|
||||||
|
|
||||||
Result ClientImpl::Post(const std::string &path, const Headers &headers,
|
Result ClientImpl::Post(const std::string &path, const Headers &headers,
|
||||||
|
|
@ -7181,10 +7246,10 @@ Result ClientImpl::Post(const std::string &path, const Headers &headers,
|
||||||
const auto &boundary = detail::make_multipart_data_boundary();
|
const auto &boundary = detail::make_multipart_data_boundary();
|
||||||
const auto &content_type =
|
const auto &content_type =
|
||||||
detail::serialize_multipart_formdata_get_content_type(boundary);
|
detail::serialize_multipart_formdata_get_content_type(boundary);
|
||||||
return send_with_content_provider(
|
return send_with_content_provider_and_receiver(
|
||||||
"POST", path, headers, nullptr, 0, nullptr,
|
"POST", path, headers, nullptr, 0, nullptr,
|
||||||
get_multipart_content_provider(boundary, items, provider_items),
|
get_multipart_content_provider(boundary, items, provider_items),
|
||||||
content_type, progress);
|
content_type, nullptr, progress);
|
||||||
}
|
}
|
||||||
|
|
||||||
Result ClientImpl::Post(const std::string &path, const Headers &headers,
|
Result ClientImpl::Post(const std::string &path, const Headers &headers,
|
||||||
|
|
@ -7246,6 +7311,15 @@ Result ClientImpl::Put(const std::string &path, size_t content_length,
|
||||||
content_type, progress);
|
content_type, progress);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Result ClientImpl::Put(const std::string &path, size_t content_length,
|
||||||
|
ContentProvider content_provider,
|
||||||
|
const std::string &content_type,
|
||||||
|
ContentReceiver content_receiver,
|
||||||
|
UploadProgress progress) {
|
||||||
|
return Put(path, Headers(), content_length, std::move(content_provider),
|
||||||
|
content_type, std::move(content_receiver), progress);
|
||||||
|
}
|
||||||
|
|
||||||
Result ClientImpl::Put(const std::string &path,
|
Result ClientImpl::Put(const std::string &path,
|
||||||
ContentProviderWithoutLength content_provider,
|
ContentProviderWithoutLength content_provider,
|
||||||
const std::string &content_type,
|
const std::string &content_type,
|
||||||
|
|
@ -7254,6 +7328,15 @@ Result ClientImpl::Put(const std::string &path,
|
||||||
progress);
|
progress);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Result ClientImpl::Put(const std::string &path,
|
||||||
|
ContentProviderWithoutLength content_provider,
|
||||||
|
const std::string &content_type,
|
||||||
|
ContentReceiver content_receiver,
|
||||||
|
UploadProgress progress) {
|
||||||
|
return Put(path, Headers(), std::move(content_provider), content_type,
|
||||||
|
std::move(content_receiver), progress);
|
||||||
|
}
|
||||||
|
|
||||||
Result ClientImpl::Put(const std::string &path, const Headers &headers,
|
Result ClientImpl::Put(const std::string &path, const Headers &headers,
|
||||||
const Params ¶ms) {
|
const Params ¶ms) {
|
||||||
auto query = detail::params_to_query_str(params);
|
auto query = detail::params_to_query_str(params);
|
||||||
|
|
@ -7294,17 +7377,18 @@ Result ClientImpl::Put(const std::string &path, const Headers &headers,
|
||||||
const char *body, size_t content_length,
|
const char *body, size_t content_length,
|
||||||
const std::string &content_type,
|
const std::string &content_type,
|
||||||
UploadProgress progress) {
|
UploadProgress progress) {
|
||||||
return send_with_content_provider("PUT", path, headers, body, content_length,
|
return send_with_content_provider_and_receiver(
|
||||||
nullptr, nullptr, content_type, progress);
|
"PUT", path, headers, body, content_length, nullptr, nullptr,
|
||||||
|
content_type, nullptr, progress);
|
||||||
}
|
}
|
||||||
|
|
||||||
Result ClientImpl::Put(const std::string &path, const Headers &headers,
|
Result ClientImpl::Put(const std::string &path, const Headers &headers,
|
||||||
const std::string &body,
|
const std::string &body,
|
||||||
const std::string &content_type,
|
const std::string &content_type,
|
||||||
UploadProgress progress) {
|
UploadProgress progress) {
|
||||||
return send_with_content_provider("PUT", path, headers, body.data(),
|
return send_with_content_provider_and_receiver(
|
||||||
body.size(), nullptr, nullptr, content_type,
|
"PUT", path, headers, body.data(), body.size(), nullptr, nullptr,
|
||||||
progress);
|
content_type, nullptr, progress);
|
||||||
}
|
}
|
||||||
|
|
||||||
Result ClientImpl::Put(const std::string &path, const Headers &headers,
|
Result ClientImpl::Put(const std::string &path, const Headers &headers,
|
||||||
|
|
@ -7312,18 +7396,40 @@ Result ClientImpl::Put(const std::string &path, const Headers &headers,
|
||||||
ContentProvider content_provider,
|
ContentProvider content_provider,
|
||||||
const std::string &content_type,
|
const std::string &content_type,
|
||||||
UploadProgress progress) {
|
UploadProgress progress) {
|
||||||
return send_with_content_provider("PUT", path, headers, nullptr,
|
return send_with_content_provider_and_receiver(
|
||||||
content_length, std::move(content_provider),
|
"PUT", path, headers, nullptr, content_length,
|
||||||
nullptr, content_type, progress);
|
std::move(content_provider), nullptr, content_type, nullptr, progress);
|
||||||
|
}
|
||||||
|
|
||||||
|
Result ClientImpl::Put(const std::string &path, const Headers &headers,
|
||||||
|
size_t content_length,
|
||||||
|
ContentProvider content_provider,
|
||||||
|
const std::string &content_type,
|
||||||
|
ContentReceiver content_receiver,
|
||||||
|
UploadProgress progress) {
|
||||||
|
return send_with_content_provider_and_receiver(
|
||||||
|
"PUT", path, headers, nullptr, content_length,
|
||||||
|
std::move(content_provider), nullptr, content_type,
|
||||||
|
std::move(content_receiver), progress);
|
||||||
}
|
}
|
||||||
|
|
||||||
Result ClientImpl::Put(const std::string &path, const Headers &headers,
|
Result ClientImpl::Put(const std::string &path, const Headers &headers,
|
||||||
ContentProviderWithoutLength content_provider,
|
ContentProviderWithoutLength content_provider,
|
||||||
const std::string &content_type,
|
const std::string &content_type,
|
||||||
UploadProgress progress) {
|
UploadProgress progress) {
|
||||||
return send_with_content_provider("PUT", path, headers, nullptr, 0, nullptr,
|
return send_with_content_provider_and_receiver(
|
||||||
std::move(content_provider), content_type,
|
"PUT", path, headers, nullptr, 0, nullptr, std::move(content_provider),
|
||||||
progress);
|
content_type, nullptr, progress);
|
||||||
|
}
|
||||||
|
|
||||||
|
Result ClientImpl::Put(const std::string &path, const Headers &headers,
|
||||||
|
ContentProviderWithoutLength content_provider,
|
||||||
|
const std::string &content_type,
|
||||||
|
ContentReceiver content_receiver,
|
||||||
|
UploadProgress progress) {
|
||||||
|
return send_with_content_provider_and_receiver(
|
||||||
|
"PUT", path, headers, nullptr, 0, nullptr, std::move(content_provider),
|
||||||
|
content_type, std::move(content_receiver), progress);
|
||||||
}
|
}
|
||||||
|
|
||||||
Result ClientImpl::Put(const std::string &path, const Headers &headers,
|
Result ClientImpl::Put(const std::string &path, const Headers &headers,
|
||||||
|
|
@ -7333,10 +7439,10 @@ Result ClientImpl::Put(const std::string &path, const Headers &headers,
|
||||||
const auto &boundary = detail::make_multipart_data_boundary();
|
const auto &boundary = detail::make_multipart_data_boundary();
|
||||||
const auto &content_type =
|
const auto &content_type =
|
||||||
detail::serialize_multipart_formdata_get_content_type(boundary);
|
detail::serialize_multipart_formdata_get_content_type(boundary);
|
||||||
return send_with_content_provider(
|
return send_with_content_provider_and_receiver(
|
||||||
"PUT", path, headers, nullptr, 0, nullptr,
|
"PUT", path, headers, nullptr, 0, nullptr,
|
||||||
get_multipart_content_provider(boundary, items, provider_items),
|
get_multipart_content_provider(boundary, items, provider_items),
|
||||||
content_type, progress);
|
content_type, nullptr, progress);
|
||||||
}
|
}
|
||||||
|
|
||||||
Result ClientImpl::Put(const std::string &path, const Headers &headers,
|
Result ClientImpl::Put(const std::string &path, const Headers &headers,
|
||||||
|
|
@ -7400,6 +7506,15 @@ Result ClientImpl::Patch(const std::string &path, size_t content_length,
|
||||||
content_type, progress);
|
content_type, progress);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Result ClientImpl::Patch(const std::string &path, size_t content_length,
|
||||||
|
ContentProvider content_provider,
|
||||||
|
const std::string &content_type,
|
||||||
|
ContentReceiver content_receiver,
|
||||||
|
UploadProgress progress) {
|
||||||
|
return Patch(path, Headers(), content_length, std::move(content_provider),
|
||||||
|
content_type, std::move(content_receiver), progress);
|
||||||
|
}
|
||||||
|
|
||||||
Result ClientImpl::Patch(const std::string &path,
|
Result ClientImpl::Patch(const std::string &path,
|
||||||
ContentProviderWithoutLength content_provider,
|
ContentProviderWithoutLength content_provider,
|
||||||
const std::string &content_type,
|
const std::string &content_type,
|
||||||
|
|
@ -7408,6 +7523,15 @@ Result ClientImpl::Patch(const std::string &path,
|
||||||
progress);
|
progress);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Result ClientImpl::Patch(const std::string &path,
|
||||||
|
ContentProviderWithoutLength content_provider,
|
||||||
|
const std::string &content_type,
|
||||||
|
ContentReceiver content_receiver,
|
||||||
|
UploadProgress progress) {
|
||||||
|
return Patch(path, Headers(), std::move(content_provider), content_type,
|
||||||
|
std::move(content_receiver), progress);
|
||||||
|
}
|
||||||
|
|
||||||
Result ClientImpl::Patch(const std::string &path, const Headers &headers,
|
Result ClientImpl::Patch(const std::string &path, const Headers &headers,
|
||||||
const Params ¶ms) {
|
const Params ¶ms) {
|
||||||
auto query = detail::params_to_query_str(params);
|
auto query = detail::params_to_query_str(params);
|
||||||
|
|
@ -7448,18 +7572,18 @@ Result ClientImpl::Patch(const std::string &path, const Headers &headers,
|
||||||
const char *body, size_t content_length,
|
const char *body, size_t content_length,
|
||||||
const std::string &content_type,
|
const std::string &content_type,
|
||||||
UploadProgress progress) {
|
UploadProgress progress) {
|
||||||
return send_with_content_provider("PATCH", path, headers, body,
|
return send_with_content_provider_and_receiver(
|
||||||
content_length, nullptr, nullptr,
|
"PATCH", path, headers, body, content_length, nullptr, nullptr,
|
||||||
content_type, progress);
|
content_type, nullptr, progress);
|
||||||
}
|
}
|
||||||
|
|
||||||
Result ClientImpl::Patch(const std::string &path, const Headers &headers,
|
Result ClientImpl::Patch(const std::string &path, const Headers &headers,
|
||||||
const std::string &body,
|
const std::string &body,
|
||||||
const std::string &content_type,
|
const std::string &content_type,
|
||||||
UploadProgress progress) {
|
UploadProgress progress) {
|
||||||
return send_with_content_provider("PATCH", path, headers, body.data(),
|
return send_with_content_provider_and_receiver(
|
||||||
body.size(), nullptr, nullptr, content_type,
|
"PATCH", path, headers, body.data(), body.size(), nullptr, nullptr,
|
||||||
progress);
|
content_type, nullptr, progress);
|
||||||
}
|
}
|
||||||
|
|
||||||
Result ClientImpl::Patch(const std::string &path, const Headers &headers,
|
Result ClientImpl::Patch(const std::string &path, const Headers &headers,
|
||||||
|
|
@ -7467,18 +7591,40 @@ Result ClientImpl::Patch(const std::string &path, const Headers &headers,
|
||||||
ContentProvider content_provider,
|
ContentProvider content_provider,
|
||||||
const std::string &content_type,
|
const std::string &content_type,
|
||||||
UploadProgress progress) {
|
UploadProgress progress) {
|
||||||
return send_with_content_provider("PATCH", path, headers, nullptr,
|
return send_with_content_provider_and_receiver(
|
||||||
content_length, std::move(content_provider),
|
"PATCH", path, headers, nullptr, content_length,
|
||||||
nullptr, content_type, progress);
|
std::move(content_provider), nullptr, content_type, nullptr, progress);
|
||||||
|
}
|
||||||
|
|
||||||
|
Result ClientImpl::Patch(const std::string &path, const Headers &headers,
|
||||||
|
size_t content_length,
|
||||||
|
ContentProvider content_provider,
|
||||||
|
const std::string &content_type,
|
||||||
|
ContentReceiver content_receiver,
|
||||||
|
UploadProgress progress) {
|
||||||
|
return send_with_content_provider_and_receiver(
|
||||||
|
"PATCH", path, headers, nullptr, content_length,
|
||||||
|
std::move(content_provider), nullptr, content_type,
|
||||||
|
std::move(content_receiver), progress);
|
||||||
}
|
}
|
||||||
|
|
||||||
Result ClientImpl::Patch(const std::string &path, const Headers &headers,
|
Result ClientImpl::Patch(const std::string &path, const Headers &headers,
|
||||||
ContentProviderWithoutLength content_provider,
|
ContentProviderWithoutLength content_provider,
|
||||||
const std::string &content_type,
|
const std::string &content_type,
|
||||||
UploadProgress progress) {
|
UploadProgress progress) {
|
||||||
return send_with_content_provider("PATCH", path, headers, nullptr, 0, nullptr,
|
return send_with_content_provider_and_receiver(
|
||||||
std::move(content_provider), content_type,
|
"PATCH", path, headers, nullptr, 0, nullptr, std::move(content_provider),
|
||||||
progress);
|
content_type, nullptr, progress);
|
||||||
|
}
|
||||||
|
|
||||||
|
Result ClientImpl::Patch(const std::string &path, const Headers &headers,
|
||||||
|
ContentProviderWithoutLength content_provider,
|
||||||
|
const std::string &content_type,
|
||||||
|
ContentReceiver content_receiver,
|
||||||
|
UploadProgress progress) {
|
||||||
|
return send_with_content_provider_and_receiver(
|
||||||
|
"PATCH", path, headers, nullptr, 0, nullptr, std::move(content_provider),
|
||||||
|
content_type, std::move(content_receiver), progress);
|
||||||
}
|
}
|
||||||
|
|
||||||
Result ClientImpl::Patch(const std::string &path, const Headers &headers,
|
Result ClientImpl::Patch(const std::string &path, const Headers &headers,
|
||||||
|
|
@ -7488,10 +7634,10 @@ Result ClientImpl::Patch(const std::string &path, const Headers &headers,
|
||||||
const auto &boundary = detail::make_multipart_data_boundary();
|
const auto &boundary = detail::make_multipart_data_boundary();
|
||||||
const auto &content_type =
|
const auto &content_type =
|
||||||
detail::serialize_multipart_formdata_get_content_type(boundary);
|
detail::serialize_multipart_formdata_get_content_type(boundary);
|
||||||
return send_with_content_provider(
|
return send_with_content_provider_and_receiver(
|
||||||
"PATCH", path, headers, nullptr, 0, nullptr,
|
"PATCH", path, headers, nullptr, 0, nullptr,
|
||||||
get_multipart_content_provider(boundary, items, provider_items),
|
get_multipart_content_provider(boundary, items, provider_items),
|
||||||
content_type, progress);
|
content_type, nullptr, progress);
|
||||||
}
|
}
|
||||||
|
|
||||||
Result ClientImpl::Patch(const std::string &path, const Headers &headers,
|
Result ClientImpl::Patch(const std::string &path, const Headers &headers,
|
||||||
|
|
@ -8883,12 +9029,28 @@ Result Client::Post(const std::string &path, size_t content_length,
|
||||||
return cli_->Post(path, content_length, std::move(content_provider),
|
return cli_->Post(path, content_length, std::move(content_provider),
|
||||||
content_type, progress);
|
content_type, progress);
|
||||||
}
|
}
|
||||||
|
Result Client::Post(const std::string &path, size_t content_length,
|
||||||
|
ContentProvider content_provider,
|
||||||
|
const std::string &content_type,
|
||||||
|
ContentReceiver content_receiver,
|
||||||
|
UploadProgress progress) {
|
||||||
|
return cli_->Post(path, content_length, std::move(content_provider),
|
||||||
|
content_type, std::move(content_receiver), progress);
|
||||||
|
}
|
||||||
Result Client::Post(const std::string &path,
|
Result Client::Post(const std::string &path,
|
||||||
ContentProviderWithoutLength content_provider,
|
ContentProviderWithoutLength content_provider,
|
||||||
const std::string &content_type,
|
const std::string &content_type,
|
||||||
UploadProgress progress) {
|
UploadProgress progress) {
|
||||||
return cli_->Post(path, std::move(content_provider), content_type, progress);
|
return cli_->Post(path, std::move(content_provider), content_type, progress);
|
||||||
}
|
}
|
||||||
|
Result Client::Post(const std::string &path,
|
||||||
|
ContentProviderWithoutLength content_provider,
|
||||||
|
const std::string &content_type,
|
||||||
|
ContentReceiver content_receiver,
|
||||||
|
UploadProgress progress) {
|
||||||
|
return cli_->Post(path, std::move(content_provider), content_type,
|
||||||
|
std::move(content_receiver), progress);
|
||||||
|
}
|
||||||
Result Client::Post(const std::string &path, const Headers &headers,
|
Result Client::Post(const std::string &path, const Headers &headers,
|
||||||
size_t content_length,
|
size_t content_length,
|
||||||
ContentProvider content_provider,
|
ContentProvider content_provider,
|
||||||
|
|
@ -8897,6 +9059,15 @@ Result Client::Post(const std::string &path, const Headers &headers,
|
||||||
return cli_->Post(path, headers, content_length, std::move(content_provider),
|
return cli_->Post(path, headers, content_length, std::move(content_provider),
|
||||||
content_type, progress);
|
content_type, progress);
|
||||||
}
|
}
|
||||||
|
Result Client::Post(const std::string &path, const Headers &headers,
|
||||||
|
size_t content_length,
|
||||||
|
ContentProvider content_provider,
|
||||||
|
const std::string &content_type,
|
||||||
|
ContentReceiver content_receiver,
|
||||||
|
DownloadProgress progress) {
|
||||||
|
return cli_->Post(path, headers, content_length, std::move(content_provider),
|
||||||
|
content_type, std::move(content_receiver), progress);
|
||||||
|
}
|
||||||
Result Client::Post(const std::string &path, const Headers &headers,
|
Result Client::Post(const std::string &path, const Headers &headers,
|
||||||
ContentProviderWithoutLength content_provider,
|
ContentProviderWithoutLength content_provider,
|
||||||
const std::string &content_type,
|
const std::string &content_type,
|
||||||
|
|
@ -8904,6 +9075,14 @@ Result Client::Post(const std::string &path, const Headers &headers,
|
||||||
return cli_->Post(path, headers, std::move(content_provider), content_type,
|
return cli_->Post(path, headers, std::move(content_provider), content_type,
|
||||||
progress);
|
progress);
|
||||||
}
|
}
|
||||||
|
Result Client::Post(const std::string &path, const Headers &headers,
|
||||||
|
ContentProviderWithoutLength content_provider,
|
||||||
|
const std::string &content_type,
|
||||||
|
ContentReceiver content_receiver,
|
||||||
|
DownloadProgress progress) {
|
||||||
|
return cli_->Post(path, headers, std::move(content_provider), content_type,
|
||||||
|
std::move(content_receiver), progress);
|
||||||
|
}
|
||||||
Result Client::Post(const std::string &path, const Params ¶ms) {
|
Result Client::Post(const std::string &path, const Params ¶ms) {
|
||||||
return cli_->Post(path, params);
|
return cli_->Post(path, params);
|
||||||
}
|
}
|
||||||
|
|
@ -8938,8 +9117,8 @@ Result Client::Post(const std::string &path, const Headers &headers,
|
||||||
const std::string &content_type,
|
const std::string &content_type,
|
||||||
ContentReceiver content_receiver,
|
ContentReceiver content_receiver,
|
||||||
DownloadProgress progress) {
|
DownloadProgress progress) {
|
||||||
return cli_->Post(path, headers, body, content_type, content_receiver,
|
return cli_->Post(path, headers, body, content_type,
|
||||||
progress);
|
std::move(content_receiver), progress);
|
||||||
}
|
}
|
||||||
|
|
||||||
Result Client::Put(const std::string &path) { return cli_->Put(path); }
|
Result Client::Put(const std::string &path) { return cli_->Put(path); }
|
||||||
|
|
@ -8976,12 +9155,28 @@ Result Client::Put(const std::string &path, size_t content_length,
|
||||||
return cli_->Put(path, content_length, std::move(content_provider),
|
return cli_->Put(path, content_length, std::move(content_provider),
|
||||||
content_type, progress);
|
content_type, progress);
|
||||||
}
|
}
|
||||||
|
Result Client::Put(const std::string &path, size_t content_length,
|
||||||
|
ContentProvider content_provider,
|
||||||
|
const std::string &content_type,
|
||||||
|
ContentReceiver content_receiver,
|
||||||
|
UploadProgress progress) {
|
||||||
|
return cli_->Put(path, content_length, std::move(content_provider),
|
||||||
|
content_type, std::move(content_receiver), progress);
|
||||||
|
}
|
||||||
Result Client::Put(const std::string &path,
|
Result Client::Put(const std::string &path,
|
||||||
ContentProviderWithoutLength content_provider,
|
ContentProviderWithoutLength content_provider,
|
||||||
const std::string &content_type,
|
const std::string &content_type,
|
||||||
UploadProgress progress) {
|
UploadProgress progress) {
|
||||||
return cli_->Put(path, std::move(content_provider), content_type, progress);
|
return cli_->Put(path, std::move(content_provider), content_type, progress);
|
||||||
}
|
}
|
||||||
|
Result Client::Put(const std::string &path,
|
||||||
|
ContentProviderWithoutLength content_provider,
|
||||||
|
const std::string &content_type,
|
||||||
|
ContentReceiver content_receiver,
|
||||||
|
UploadProgress progress) {
|
||||||
|
return cli_->Put(path, std::move(content_provider), content_type,
|
||||||
|
std::move(content_receiver), progress);
|
||||||
|
}
|
||||||
Result Client::Put(const std::string &path, const Headers &headers,
|
Result Client::Put(const std::string &path, const Headers &headers,
|
||||||
size_t content_length,
|
size_t content_length,
|
||||||
ContentProvider content_provider,
|
ContentProvider content_provider,
|
||||||
|
|
@ -8990,6 +9185,15 @@ Result Client::Put(const std::string &path, const Headers &headers,
|
||||||
return cli_->Put(path, headers, content_length, std::move(content_provider),
|
return cli_->Put(path, headers, content_length, std::move(content_provider),
|
||||||
content_type, progress);
|
content_type, progress);
|
||||||
}
|
}
|
||||||
|
Result Client::Put(const std::string &path, const Headers &headers,
|
||||||
|
size_t content_length,
|
||||||
|
ContentProvider content_provider,
|
||||||
|
const std::string &content_type,
|
||||||
|
ContentReceiver content_receiver,
|
||||||
|
UploadProgress progress) {
|
||||||
|
return cli_->Put(path, headers, content_length, std::move(content_provider),
|
||||||
|
content_type, std::move(content_receiver), progress);
|
||||||
|
}
|
||||||
Result Client::Put(const std::string &path, const Headers &headers,
|
Result Client::Put(const std::string &path, const Headers &headers,
|
||||||
ContentProviderWithoutLength content_provider,
|
ContentProviderWithoutLength content_provider,
|
||||||
const std::string &content_type,
|
const std::string &content_type,
|
||||||
|
|
@ -8997,6 +9201,14 @@ Result Client::Put(const std::string &path, const Headers &headers,
|
||||||
return cli_->Put(path, headers, std::move(content_provider), content_type,
|
return cli_->Put(path, headers, std::move(content_provider), content_type,
|
||||||
progress);
|
progress);
|
||||||
}
|
}
|
||||||
|
Result Client::Put(const std::string &path, const Headers &headers,
|
||||||
|
ContentProviderWithoutLength content_provider,
|
||||||
|
const std::string &content_type,
|
||||||
|
ContentReceiver content_receiver,
|
||||||
|
UploadProgress progress) {
|
||||||
|
return cli_->Put(path, headers, std::move(content_provider), content_type,
|
||||||
|
std::move(content_receiver), progress);
|
||||||
|
}
|
||||||
Result Client::Put(const std::string &path, const Params ¶ms) {
|
Result Client::Put(const std::string &path, const Params ¶ms) {
|
||||||
return cli_->Put(path, params);
|
return cli_->Put(path, params);
|
||||||
}
|
}
|
||||||
|
|
@ -9072,12 +9284,28 @@ Result Client::Patch(const std::string &path, size_t content_length,
|
||||||
return cli_->Patch(path, content_length, std::move(content_provider),
|
return cli_->Patch(path, content_length, std::move(content_provider),
|
||||||
content_type, progress);
|
content_type, progress);
|
||||||
}
|
}
|
||||||
|
Result Client::Patch(const std::string &path, size_t content_length,
|
||||||
|
ContentProvider content_provider,
|
||||||
|
const std::string &content_type,
|
||||||
|
ContentReceiver content_receiver,
|
||||||
|
UploadProgress progress) {
|
||||||
|
return cli_->Patch(path, content_length, std::move(content_provider),
|
||||||
|
content_type, std::move(content_receiver), progress);
|
||||||
|
}
|
||||||
Result Client::Patch(const std::string &path,
|
Result Client::Patch(const std::string &path,
|
||||||
ContentProviderWithoutLength content_provider,
|
ContentProviderWithoutLength content_provider,
|
||||||
const std::string &content_type,
|
const std::string &content_type,
|
||||||
UploadProgress progress) {
|
UploadProgress progress) {
|
||||||
return cli_->Patch(path, std::move(content_provider), content_type, progress);
|
return cli_->Patch(path, std::move(content_provider), content_type, progress);
|
||||||
}
|
}
|
||||||
|
Result Client::Patch(const std::string &path,
|
||||||
|
ContentProviderWithoutLength content_provider,
|
||||||
|
const std::string &content_type,
|
||||||
|
ContentReceiver content_receiver,
|
||||||
|
UploadProgress progress) {
|
||||||
|
return cli_->Patch(path, std::move(content_provider), content_type,
|
||||||
|
std::move(content_receiver), progress);
|
||||||
|
}
|
||||||
Result Client::Patch(const std::string &path, const Headers &headers,
|
Result Client::Patch(const std::string &path, const Headers &headers,
|
||||||
size_t content_length,
|
size_t content_length,
|
||||||
ContentProvider content_provider,
|
ContentProvider content_provider,
|
||||||
|
|
@ -9086,6 +9314,15 @@ Result Client::Patch(const std::string &path, const Headers &headers,
|
||||||
return cli_->Patch(path, headers, content_length, std::move(content_provider),
|
return cli_->Patch(path, headers, content_length, std::move(content_provider),
|
||||||
content_type, progress);
|
content_type, progress);
|
||||||
}
|
}
|
||||||
|
Result Client::Patch(const std::string &path, const Headers &headers,
|
||||||
|
size_t content_length,
|
||||||
|
ContentProvider content_provider,
|
||||||
|
const std::string &content_type,
|
||||||
|
ContentReceiver content_receiver,
|
||||||
|
UploadProgress progress) {
|
||||||
|
return cli_->Patch(path, headers, content_length, std::move(content_provider),
|
||||||
|
content_type, std::move(content_receiver), progress);
|
||||||
|
}
|
||||||
Result Client::Patch(const std::string &path, const Headers &headers,
|
Result Client::Patch(const std::string &path, const Headers &headers,
|
||||||
ContentProviderWithoutLength content_provider,
|
ContentProviderWithoutLength content_provider,
|
||||||
const std::string &content_type,
|
const std::string &content_type,
|
||||||
|
|
@ -9093,6 +9330,14 @@ Result Client::Patch(const std::string &path, const Headers &headers,
|
||||||
return cli_->Patch(path, headers, std::move(content_provider), content_type,
|
return cli_->Patch(path, headers, std::move(content_provider), content_type,
|
||||||
progress);
|
progress);
|
||||||
}
|
}
|
||||||
|
Result Client::Patch(const std::string &path, const Headers &headers,
|
||||||
|
ContentProviderWithoutLength content_provider,
|
||||||
|
const std::string &content_type,
|
||||||
|
ContentReceiver content_receiver,
|
||||||
|
UploadProgress progress) {
|
||||||
|
return cli_->Patch(path, headers, std::move(content_provider), content_type,
|
||||||
|
std::move(content_receiver), progress);
|
||||||
|
}
|
||||||
Result Client::Patch(const std::string &path, const Params ¶ms) {
|
Result Client::Patch(const std::string &path, const Params ¶ms) {
|
||||||
return cli_->Patch(path, params);
|
return cli_->Patch(path, params);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -8,8 +8,8 @@
|
||||||
#ifndef CPPHTTPLIB_HTTPLIB_H
|
#ifndef CPPHTTPLIB_HTTPLIB_H
|
||||||
#define CPPHTTPLIB_HTTPLIB_H
|
#define CPPHTTPLIB_HTTPLIB_H
|
||||||
|
|
||||||
#define CPPHTTPLIB_VERSION "0.27.0"
|
#define CPPHTTPLIB_VERSION "0.28.0"
|
||||||
#define CPPHTTPLIB_VERSION_NUM "0x001B00"
|
#define CPPHTTPLIB_VERSION_NUM "0x001C00"
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Platform compatibility check
|
* Platform compatibility check
|
||||||
|
|
@ -257,6 +257,7 @@ using socklen_t = int;
|
||||||
#include <netinet/in.h>
|
#include <netinet/in.h>
|
||||||
#ifdef __linux__
|
#ifdef __linux__
|
||||||
#include <resolv.h>
|
#include <resolv.h>
|
||||||
|
#undef _res // Undefine _res macro to avoid conflicts with user code (#2278)
|
||||||
#endif
|
#endif
|
||||||
#include <csignal>
|
#include <csignal>
|
||||||
#include <netinet/tcp.h>
|
#include <netinet/tcp.h>
|
||||||
|
|
@ -1421,14 +1422,18 @@ public:
|
||||||
Result Post(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
|
Result Post(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
|
||||||
Result Post(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
|
Result Post(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
|
||||||
Result Post(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
|
Result Post(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
|
||||||
|
Result Post(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
|
||||||
Result Post(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
|
Result Post(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
|
||||||
|
Result Post(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
|
||||||
Result Post(const std::string &path, const Params ¶ms);
|
Result Post(const std::string &path, const Params ¶ms);
|
||||||
Result Post(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr);
|
Result Post(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr);
|
||||||
Result Post(const std::string &path, const Headers &headers);
|
Result Post(const std::string &path, const Headers &headers);
|
||||||
Result Post(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
|
Result Post(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
|
||||||
Result Post(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
|
Result Post(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
|
||||||
Result Post(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
|
Result Post(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
|
||||||
|
Result Post(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, DownloadProgress progress = nullptr);
|
||||||
Result Post(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
|
Result Post(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
|
||||||
|
Result Post(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, DownloadProgress progress = nullptr);
|
||||||
Result Post(const std::string &path, const Headers &headers, const Params ¶ms);
|
Result Post(const std::string &path, const Headers &headers, const Params ¶ms);
|
||||||
Result Post(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr);
|
Result Post(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr);
|
||||||
Result Post(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr);
|
Result Post(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr);
|
||||||
|
|
@ -1439,14 +1444,18 @@ public:
|
||||||
Result Put(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
|
Result Put(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
|
||||||
Result Put(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
|
Result Put(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
|
||||||
Result Put(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
|
Result Put(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
|
||||||
|
Result Put(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
|
||||||
Result Put(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
|
Result Put(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
|
||||||
|
Result Put(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
|
||||||
Result Put(const std::string &path, const Params ¶ms);
|
Result Put(const std::string &path, const Params ¶ms);
|
||||||
Result Put(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr);
|
Result Put(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr);
|
||||||
Result Put(const std::string &path, const Headers &headers);
|
Result Put(const std::string &path, const Headers &headers);
|
||||||
Result Put(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
|
Result Put(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
|
||||||
Result Put(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
|
Result Put(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
|
||||||
Result Put(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
|
Result Put(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
|
||||||
|
Result Put(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
|
||||||
Result Put(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
|
Result Put(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
|
||||||
|
Result Put(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
|
||||||
Result Put(const std::string &path, const Headers &headers, const Params ¶ms);
|
Result Put(const std::string &path, const Headers &headers, const Params ¶ms);
|
||||||
Result Put(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr);
|
Result Put(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr);
|
||||||
Result Put(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr);
|
Result Put(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr);
|
||||||
|
|
@ -1457,14 +1466,18 @@ public:
|
||||||
Result Patch(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
|
Result Patch(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
|
||||||
Result Patch(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
|
Result Patch(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
|
||||||
Result Patch(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
|
Result Patch(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
|
||||||
|
Result Patch(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
|
||||||
Result Patch(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
|
Result Patch(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
|
||||||
|
Result Patch(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
|
||||||
Result Patch(const std::string &path, const Params ¶ms);
|
Result Patch(const std::string &path, const Params ¶ms);
|
||||||
Result Patch(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr);
|
Result Patch(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr);
|
||||||
Result Patch(const std::string &path, const Headers &headers, UploadProgress progress = nullptr);
|
Result Patch(const std::string &path, const Headers &headers, UploadProgress progress = nullptr);
|
||||||
Result Patch(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
|
Result Patch(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
|
||||||
Result Patch(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
|
Result Patch(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
|
||||||
Result Patch(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
|
Result Patch(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
|
||||||
|
Result Patch(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
|
||||||
Result Patch(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
|
Result Patch(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
|
||||||
|
Result Patch(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
|
||||||
Result Patch(const std::string &path, const Headers &headers, const Params ¶ms);
|
Result Patch(const std::string &path, const Headers &headers, const Params ¶ms);
|
||||||
Result Patch(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr);
|
Result Patch(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr);
|
||||||
Result Patch(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr);
|
Result Patch(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr);
|
||||||
|
|
@ -1712,17 +1725,19 @@ private:
|
||||||
template <typename ClientType> void setup_redirect_client(ClientType &client);
|
template <typename ClientType> void setup_redirect_client(ClientType &client);
|
||||||
bool handle_request(Stream &strm, Request &req, Response &res,
|
bool handle_request(Stream &strm, Request &req, Response &res,
|
||||||
bool close_connection, Error &error);
|
bool close_connection, Error &error);
|
||||||
std::unique_ptr<Response> send_with_content_provider(
|
std::unique_ptr<Response> send_with_content_provider_and_receiver(
|
||||||
Request &req, const char *body, size_t content_length,
|
Request &req, const char *body, size_t content_length,
|
||||||
ContentProvider content_provider,
|
ContentProvider content_provider,
|
||||||
ContentProviderWithoutLength content_provider_without_length,
|
ContentProviderWithoutLength content_provider_without_length,
|
||||||
const std::string &content_type, Error &error);
|
const std::string &content_type, ContentReceiver content_receiver,
|
||||||
Result send_with_content_provider(
|
Error &error);
|
||||||
|
Result send_with_content_provider_and_receiver(
|
||||||
const std::string &method, const std::string &path,
|
const std::string &method, const std::string &path,
|
||||||
const Headers &headers, const char *body, size_t content_length,
|
const Headers &headers, const char *body, size_t content_length,
|
||||||
ContentProvider content_provider,
|
ContentProvider content_provider,
|
||||||
ContentProviderWithoutLength content_provider_without_length,
|
ContentProviderWithoutLength content_provider_without_length,
|
||||||
const std::string &content_type, UploadProgress progress);
|
const std::string &content_type, ContentReceiver content_receiver,
|
||||||
|
UploadProgress progress);
|
||||||
ContentProviderWithoutLength get_multipart_content_provider(
|
ContentProviderWithoutLength get_multipart_content_provider(
|
||||||
const std::string &boundary, const UploadFormDataItems &items,
|
const std::string &boundary, const UploadFormDataItems &items,
|
||||||
const FormDataProviderItems &provider_items) const;
|
const FormDataProviderItems &provider_items) const;
|
||||||
|
|
@ -1775,14 +1790,18 @@ public:
|
||||||
Result Post(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
|
Result Post(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
|
||||||
Result Post(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
|
Result Post(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
|
||||||
Result Post(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
|
Result Post(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
|
||||||
|
Result Post(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
|
||||||
Result Post(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
|
Result Post(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
|
||||||
|
Result Post(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
|
||||||
Result Post(const std::string &path, const Params ¶ms);
|
Result Post(const std::string &path, const Params ¶ms);
|
||||||
Result Post(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr);
|
Result Post(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr);
|
||||||
Result Post(const std::string &path, const Headers &headers);
|
Result Post(const std::string &path, const Headers &headers);
|
||||||
Result Post(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
|
Result Post(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
|
||||||
Result Post(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
|
Result Post(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
|
||||||
Result Post(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
|
Result Post(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
|
||||||
|
Result Post(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, DownloadProgress progress = nullptr);
|
||||||
Result Post(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
|
Result Post(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
|
||||||
|
Result Post(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, DownloadProgress progress = nullptr);
|
||||||
Result Post(const std::string &path, const Headers &headers, const Params ¶ms);
|
Result Post(const std::string &path, const Headers &headers, const Params ¶ms);
|
||||||
Result Post(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr);
|
Result Post(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr);
|
||||||
Result Post(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr);
|
Result Post(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr);
|
||||||
|
|
@ -1793,14 +1812,18 @@ public:
|
||||||
Result Put(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
|
Result Put(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
|
||||||
Result Put(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
|
Result Put(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
|
||||||
Result Put(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
|
Result Put(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
|
||||||
|
Result Put(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
|
||||||
Result Put(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
|
Result Put(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
|
||||||
|
Result Put(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
|
||||||
Result Put(const std::string &path, const Params ¶ms);
|
Result Put(const std::string &path, const Params ¶ms);
|
||||||
Result Put(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr);
|
Result Put(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr);
|
||||||
Result Put(const std::string &path, const Headers &headers);
|
Result Put(const std::string &path, const Headers &headers);
|
||||||
Result Put(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
|
Result Put(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
|
||||||
Result Put(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
|
Result Put(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
|
||||||
Result Put(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
|
Result Put(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
|
||||||
|
Result Put(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
|
||||||
Result Put(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
|
Result Put(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
|
||||||
|
Result Put(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
|
||||||
Result Put(const std::string &path, const Headers &headers, const Params ¶ms);
|
Result Put(const std::string &path, const Headers &headers, const Params ¶ms);
|
||||||
Result Put(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr);
|
Result Put(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr);
|
||||||
Result Put(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr);
|
Result Put(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr);
|
||||||
|
|
@ -1811,14 +1834,18 @@ public:
|
||||||
Result Patch(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
|
Result Patch(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
|
||||||
Result Patch(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
|
Result Patch(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
|
||||||
Result Patch(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
|
Result Patch(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
|
||||||
|
Result Patch(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
|
||||||
Result Patch(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
|
Result Patch(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
|
||||||
|
Result Patch(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
|
||||||
Result Patch(const std::string &path, const Params ¶ms);
|
Result Patch(const std::string &path, const Params ¶ms);
|
||||||
Result Patch(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr);
|
Result Patch(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr);
|
||||||
Result Patch(const std::string &path, const Headers &headers);
|
Result Patch(const std::string &path, const Headers &headers);
|
||||||
Result Patch(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
|
Result Patch(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
|
||||||
Result Patch(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
|
Result Patch(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
|
||||||
Result Patch(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
|
Result Patch(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
|
||||||
|
Result Patch(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
|
||||||
Result Patch(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
|
Result Patch(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
|
||||||
|
Result Patch(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
|
||||||
Result Patch(const std::string &path, const Headers &headers, const Params ¶ms);
|
Result Patch(const std::string &path, const Headers &headers, const Params ¶ms);
|
||||||
Result Patch(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr);
|
Result Patch(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr);
|
||||||
Result Patch(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr);
|
Result Patch(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr);
|
||||||
|
|
|
||||||
|
|
@ -1,6 +0,0 @@
|
||||||
# Remove bssl
|
|
||||||
file(READ "CMakeLists.txt" content)
|
|
||||||
string(REPLACE "add_executable(bssl" "#add_executable(bssl" content "${content}")
|
|
||||||
string(REPLACE "target_link_libraries(bssl" "#target_link_libraries(bssl" content "${content}")
|
|
||||||
string(REPLACE "install(TARGETS bssl" "#install(TARGETS bssl" content "${content}")
|
|
||||||
file(WRITE "CMakeLists.txt" "${content}")
|
|
||||||
Loading…
Reference in New Issue