diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index d24a4682f3..daaf0bf497 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -10061,6 +10061,25 @@ class LazyTorchTensor(gguf.LazyBase): torch.uint8: np.uint8, } + # only used when byteswapping data. Only correct size is needed + _dtype_byteswap_map: dict[torch.dtype, type] = { + torch.float64: np.float64, + torch.float32: np.float32, + torch.bfloat16: np.float16, + torch.float16: np.float16, + torch.int64: np.int64, + torch.uint64: np.uint64, + torch.int32: np.int32, + torch.uint32: np.uint32, + torch.int16: np.int16, + torch.uint16: np.uint16, + torch.int8: np.int8, + torch.uint8: np.uint8, + torch.bool: np.uint8, + torch.float8_e4m3fn: np.uint8, + torch.float8_e5m2: np.uint8, + } + # used for safetensors slices # ref: https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/src/lib.rs#L1046 # TODO: uncomment U64, U32, and U16, ref: https://github.com/pytorch/pytorch/issues/58734 @@ -10104,8 +10123,14 @@ class LazyTorchTensor(gguf.LazyBase): @classmethod def from_local_tensor(cls, t: gguf.utility.LocalTensor) -> Tensor: def load_tensor(tensor: gguf.utility.LocalTensor) -> Tensor: + def byteswap_tensor(tensor: np.ndarray, dtype: type) -> np.ndarray: + if sys.byteorder == 'big': + # switch data back to big endian + tensor = tensor.view(dtype).byteswap(inplace=False) + return tensor dtype = cls._dtype_str_map[tensor.dtype] - return torch.from_numpy(tensor.mmap_bytes()).view(dtype).reshape(tensor.shape) + numpy_dtype = cls._dtype_byteswap_map[dtype] + return torch.from_numpy(byteswap_tensor(tensor.mmap_bytes(), numpy_dtype)).view(dtype).reshape(tensor.shape) dtype = cls._dtype_str_map[t.dtype] shape = t.shape lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(t,), func=lambda r: load_tensor(r)) @@ -10113,10 +10138,16 @@ class LazyTorchTensor(gguf.LazyBase): @classmethod def from_remote_tensor(cls, remote_tensor: gguf.utility.RemoteTensor): + def byteswap_tensor(tensor: np.ndarray, dtype: type) -> np.ndarray: + if sys.byteorder == 'big': + # switch data back to big endian + tensor = tensor.view(dtype).byteswap(inplace=False) + return tensor dtype = cls._dtype_str_map[remote_tensor.dtype] + numpy_dtype = cls._dtype_byteswap_map[dtype] shape = remote_tensor.shape meta = cls.meta_with_dtype_and_shape(dtype, shape) - lazy = cls(meta=meta, args=(remote_tensor,), func=lambda r: torch.frombuffer(r.data(), dtype=dtype).reshape(shape)) + lazy = cls(meta=meta, args=(remote_tensor,), func=lambda r: torch.from_numpy(byteswap_tensor(np.frombuffer(r.data(), dtype=numpy_dtype), numpy_dtype)).view(dtype).reshape(shape)) return cast(torch.Tensor, lazy) @classmethod diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 605fcfcb9c..4dbca868bc 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -530,6 +530,7 @@ extern "C" { GGML_OP_ARANGE, GGML_OP_TIMESTEP_EMBEDDING, GGML_OP_ARGSORT, + GGML_OP_TOP_K, GGML_OP_LEAKY_RELU, GGML_OP_TRI, GGML_OP_FILL, @@ -2258,18 +2259,25 @@ extern "C" { struct ggml_tensor * a, enum ggml_sort_order order); + // similar to ggml_top_k but implemented as `argsort` + `view` + GGML_API struct ggml_tensor * ggml_argsort_top_k( + struct ggml_context * ctx, + struct ggml_tensor * a, + int k); + + // top k elements per row + // note: the resulting top k indices are in no particular order + GGML_API struct ggml_tensor * ggml_top_k( + struct ggml_context * ctx, + struct ggml_tensor * a, + int k); + GGML_API struct ggml_tensor * ggml_arange( struct ggml_context * ctx, float start, float stop, float step); - // top k elements per row - GGML_API struct ggml_tensor * ggml_top_k( - struct ggml_context * ctx, - struct ggml_tensor * a, - int k); - #define GGML_KQ_MASK_PAD 64 // q: [n_embd_k, n_batch, n_head, ne3 ] diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index 08cf2b118b..48f4b7db69 100644 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -2207,78 +2207,120 @@ static void aclnn_index_fill_tensor(ggml_backend_cann_context & ctx, } /** - * @brief Initializes and caches sine/cosine positional encoding values - * (used in RoPE, Rotary Position Embedding) for attention layers. + * @brief Initializes and caches all intermediate tensors required for RoPE + * (Rotary Position Embedding), including support for Yarn, mRoPE, + * i-mRoPE, Neox repeat strategy, independent sectors, frequency factors, + * and multi-section rotary groups. * - * This function computes and caches the sin/cos values of - * θ = position * theta_scale for RoPE encoding. The cache is shared - * across attention layers, and only the first attention layer will - * trigger initialization. The cache includes repeated sin/cos values - * with different repeat methods depending on the @param is_neox flag. + * This function computes and caches the per-dimension θ coefficients used for + * Q/K rotary embedding. The cache is shared across layers, and recomputed only + * when any dependent parameter changes. * - * Steps performed by this function: - * 1. Identify whether the target tensor belongs to Q/K in attention - * and restrict computation to the first layer only. - * 2. Initialize the theta scale array (arange → power → freq scaling). - * 3. Allocate sin/cos caches if the max prompt length increases. - * 4. Compute θ = position * theta_scale. - * 5. Compute sin(θ), cos(θ) and optionally scale by attn_factor. - * 6. Expand sin/cos values by repeat or repeat_interleave depending - * on whether @param is_neox is enabled. + * The function now supports: + * - Yarn RoPE extrapolation (via @param corr_dims and @param ext_factor) + * - Per-dimension independent sector exponent rules (indep_sects + sections[]) + * - Multi-section RoPE (mRoPE) index mapping (mrope_used + is_imrope) + * - Frequency factor division (src2) + * - Neox / normal repeat expansion modes * - * @param ctx The CANN backend context, holding memory pool, - * stream, and persistent buffers for rope init/cache. - * @param dst The destination ggml_tensor whose computation - * depends on the RoPE values (usually Qcur/Kcur). - * @param theta_scale Scalar exponent base for computing theta scale values. - * @param freq_scale Frequency scaling factor, applied to theta scale. - * @param attn_factor Attention scaling factor, applied to sin/cos. - * @param is_neox Whether to use Neox-style repeat strategy - * (dim expansion vs repeat_interleave). + * @param ctx CANN backend context, containing memory pool, + * cached buffers, and runtime stream. + * @param dst Destination ggml_tensor whose computation + * depends on RoPE (typically Qcur or Kcur). + * @param corr_dims [low, high] Yarn correction range. + * @param ext_factor Yarn extrapolation strength. 0 = disabled. + * @param theta_scale Base multiplier for per-dimension θ exponent. + * @param freq_scale Global frequency scaling factor. + * @param attn_factor Optional scaling applied to sin/cos (if needed). + * @param is_neox Whether to use Neox-style dimension interleave. + * @param sections 4-way sector sizes for independent-section RoPE + * and multi-section mRoPE (t/h/w/e). + * @param mrope_used Whether to enable multi-section rotary embedding. + * @param is_imrope Whether to apply interleaved mRoPE rules. + * @param indep_sects Whether each dimension runs independent exponent + * resets based on @p sections. */ -static void aclnn_cache_init(ggml_backend_cann_context & ctx, - ggml_tensor * dst, - float * corr_dims, - float ext_factor, - float theta_scale, - float freq_scale, - float attn_factor, - bool is_neox) { +static void aclnn_rope_cache_init(ggml_backend_cann_context & ctx, + ggml_tensor * dst, + float * corr_dims, + float ext_factor, + float theta_scale, + float freq_scale, + float attn_factor, + bool is_neox, + int sections[4], + bool mrope_used, + bool is_imrope, + bool indep_sects) { ggml_tensor * src0 = dst->src[0]; // input ggml_tensor * src1 = dst->src[1]; // position ggml_tensor * src2 = dst->src[2]; // freq_factors - if (src2 == nullptr && ctx.rope_cache.cached && ctx.rope_cache.ext_factor == ext_factor && - ctx.rope_cache.theta_scale == theta_scale && ctx.rope_cache.freq_scale == freq_scale && - ctx.rope_cache.attn_factor == attn_factor && ctx.rope_cache.is_neox == is_neox) { + int64_t theta_scale_length = src0->ne[0] / 2; + int64_t position_length = dst->ne[2]; + + // TODO: check theta_scale_length and position_length. + if (src2 == nullptr && ctx.rope_cache.cached && + ctx.rope_cache.equal(theta_scale_length, position_length, ext_factor, theta_scale, freq_scale, attn_factor, + is_neox, indep_sects, mrope_used, is_imrope, sections)) { // use cache. return; } - int64_t theta_scale_length = src0->ne[0] / 2; - int64_t theta_scale_ne[] = { theta_scale_length, 1, 1, 1 }; - size_t theta_scale_nb[] = { sizeof(float), sizeof(float), sizeof(float), theta_scale_length * sizeof(float) }; + // Step0: calculate tensor shape. + int64_t theta_scale_ne[] = { theta_scale_length, 1, 1, 1 }; + size_t theta_scale_nb[] = { sizeof(float), theta_scale_length * sizeof(float), theta_scale_length * sizeof(float), + theta_scale_length * sizeof(float) }; GGML_ASSERT(src1->type == GGML_TYPE_I32); - int64_t position_length = src1->ne[0]; - int64_t position_ne[] = { 1, 1, position_length, 1 }; - size_t position_nb[] = { sizeof(int32_t), sizeof(int32_t), sizeof(int32_t), sizeof(int32_t) * position_length }; + int64_t position_ne[] = { 1, 1, position_length, 1 }; + size_t position_nb[] = { sizeof(int32_t), sizeof(int32_t), sizeof(int32_t), sizeof(int32_t) * position_length }; - int64_t theta_ne[] = { theta_scale_length, 1, position_length, 1 }; - size_t theta_nb[GGML_MAX_DIMS]; - theta_nb[0] = sizeof(float); + int64_t cache_ne[] = { theta_scale_length, 1, position_length, 1 }; + size_t cache_nb[GGML_MAX_DIMS]; + cache_nb[0] = sizeof(float); for (int i = 1; i < GGML_MAX_DIMS; i++) { - theta_nb[i] = theta_nb[i - 1] * theta_ne[i - 1]; + cache_nb[i] = cache_nb[i - 1] * cache_ne[i - 1]; } - // theta_scale arange, [0,1,...,ne00/2 - 1] + // Step1: Compute the coefficient of theta. During the cache_init process, aside from + // (1) multiplying by the position, + // (2) dividing by freq_factors, + // (3) computing the sine and cosine, + // the other parameters used in the computation generally do not change in most scenarios. + // Therefore, we can first compute this part of the result and then cache it. + + // Step1.1: prepare theta_scale exponent. if this exponent updated, should update theta_scale_tensor. acl_tensor_ptr acl_theta_scale_tensor; - // cache theta scale - if (ctx.rope_cache.theta_scale_length != theta_scale_length || - // theta_scale and freq_scale should not change during the current token inference process, - // so we can directly use == here instead of comparing the absolute difference. - ctx.rope_cache.theta_scale != theta_scale || ctx.rope_cache.freq_scale != freq_scale) { - ctx.rope_cache.theta_scale_length = theta_scale_length; + bool theta_scale_updated = false; + if (ctx.rope_cache.theta_scale_length != theta_scale_length || ctx.rope_cache.theta_scale != theta_scale || + ctx.rope_cache.indep_sects != indep_sects) { + theta_scale_updated = true; + if (ctx.rope_cache.theta_scale_exp_host != nullptr) { + free(ctx.rope_cache.theta_scale_exp_host); + } + ctx.rope_cache.theta_scale_exp_host = (float *) malloc(theta_scale_length * sizeof(float)); + GGML_ASSERT(ctx.rope_cache.theta_scale_exp_host != nullptr); + if (!indep_sects) { + ctx.rope_cache.theta_scale_exp_host[0] = 1; + for (int i = 1; i < theta_scale_length; i++) { + ctx.rope_cache.theta_scale_exp_host[i] = ctx.rope_cache.theta_scale_exp_host[i - 1] * theta_scale; + } + } else { + int sect_dims = sections[0] + sections[1] + sections[2] + sections[3]; + int sec_w = sections[1] + sections[0]; + int sec_e = sections[2] + sec_w; + + ctx.rope_cache.theta_scale_exp_host[0] = 1; + for (int i = 1; i < theta_scale_length; i++) { + int sector = i % sect_dims; + if (sector == 0 || sector == sections[0] || sector == sec_w || sector == sec_e) { + ctx.rope_cache.theta_scale_exp_host[i] = 1; + continue; + } + ctx.rope_cache.theta_scale_exp_host[i] = ctx.rope_cache.theta_scale_exp_host[i - 1] * theta_scale; + } + } if (ctx.rope_cache.theta_scale_cache != nullptr) { ACL_CHECK(aclrtFree(ctx.rope_cache.theta_scale_cache)); @@ -2286,74 +2328,138 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx, ACL_CHECK(aclrtMalloc(&ctx.rope_cache.theta_scale_cache, theta_scale_length * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMemcpyAsync(ctx.rope_cache.theta_scale_cache, theta_scale_length * sizeof(float), + ctx.rope_cache.theta_scale_exp_host, theta_scale_length * sizeof(float), + ACL_MEMCPY_HOST_TO_DEVICE, ctx.stream())); + acl_theta_scale_tensor = ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float), theta_scale_ne, theta_scale_nb, 1); + } - float start = 0; - float step = 1; - float stop = theta_scale_length; - float n_elements = theta_scale_length; - aclnn_arange(ctx, acl_theta_scale_tensor.get(), start, stop, step, n_elements); + // Step1.2: prepare rope_yarn_ramp, if this part updated, should update theta_scale_tensor. + bool yarn_ramp_tensor_updated = false; + ggml_cann_pool_alloc yarn_ramp_allocator(ctx.pool()); + acl_tensor_ptr acl_yarn_ramp_tensor; + if (ext_factor != 0 && + // TODO: check more parameter. + (ctx.rope_cache.theta_scale_length != theta_scale_length || ctx.rope_cache.freq_scale != freq_scale)) { + yarn_ramp_tensor_updated = true; - ggml_cann_pool_alloc yarn_ramp_allocator(ctx.pool()); - acl_tensor_ptr acl_yarn_ramp_tensor; - if (ext_factor != 0) { - // -rope_yarn_ramp - // const float y = (i0 / 2 - low) / MAX(0.001f, high - low); - // return MIN(1, MAX(0, y)) - 1; - yarn_ramp_allocator.alloc(theta_scale_length * sizeof(float)); - void * yarn_ramp_buffer = yarn_ramp_allocator.get(); - acl_yarn_ramp_tensor = - ggml_cann_create_tensor(yarn_ramp_buffer, ACL_FLOAT, sizeof(float), theta_scale_ne, theta_scale_nb, 1); - float zero_value = 0, one_value = 1; - float denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]); - acl_scalar_ptr low = ggml_cann_create_scalar(&corr_dims[0], aclDataType::ACL_FLOAT); - acl_scalar_ptr zero = ggml_cann_create_scalar(&zero_value, aclDataType::ACL_FLOAT); - acl_scalar_ptr one = ggml_cann_create_scalar(&one_value, aclDataType::ACL_FLOAT); - acl_scalar_ptr denom_safe = ggml_cann_create_scalar(&denom_safe_value, aclDataType::ACL_FLOAT); - acl_scalar_ptr ext_factor_sc = ggml_cann_create_scalar(&ext_factor, aclDataType::ACL_FLOAT); + // -rope_yarn_ramp + // const float y = (i0 / 2 - low) / MAX(0.001f, high - low); + // return MIN(1, MAX(0, y)) - 1; + yarn_ramp_allocator.alloc(theta_scale_length * sizeof(float)); + void * yarn_ramp_buffer = yarn_ramp_allocator.get(); + acl_yarn_ramp_tensor = + ggml_cann_create_tensor(yarn_ramp_buffer, ACL_FLOAT, sizeof(float), theta_scale_ne, theta_scale_nb, 1); + float zero_value = 0, one_value = 1; + float denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]); + acl_scalar_ptr low = ggml_cann_create_scalar(&corr_dims[0], aclDataType::ACL_FLOAT); + acl_scalar_ptr zero = ggml_cann_create_scalar(&zero_value, aclDataType::ACL_FLOAT); + acl_scalar_ptr one = ggml_cann_create_scalar(&one_value, aclDataType::ACL_FLOAT); + acl_scalar_ptr denom_safe = ggml_cann_create_scalar(&denom_safe_value, aclDataType::ACL_FLOAT); + acl_scalar_ptr ext_factor_sc = ggml_cann_create_scalar(&ext_factor, aclDataType::ACL_FLOAT); - GGML_CANN_CALL_ACLNN_OP(ctx, Subs, acl_theta_scale_tensor.get(), low.get(), one.get(), - acl_yarn_ramp_tensor.get()); - GGML_CANN_CALL_ACLNN_OP(ctx, InplaceDivs, acl_yarn_ramp_tensor.get(), denom_safe.get()); - GGML_CANN_CALL_ACLNN_OP(ctx, InplaceThreshold, acl_yarn_ramp_tensor.get(), zero.get(), zero.get()); - GGML_CANN_CALL_ACLNN_OP(ctx, InplaceClampMax, acl_yarn_ramp_tensor.get(), one.get()); - GGML_CANN_CALL_ACLNN_OP(ctx, InplaceSubs, acl_yarn_ramp_tensor.get(), one.get(), one.get()); - GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor.get(), ext_factor_sc.get()); + aclnn_arange(ctx, acl_yarn_ramp_tensor.get(), 0, theta_scale_length, 1, theta_scale_length); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceSubs, acl_yarn_ramp_tensor.get(), low.get(), one.get()); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceDivs, acl_yarn_ramp_tensor.get(), denom_safe.get()); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceThreshold, acl_yarn_ramp_tensor.get(), zero.get(), zero.get()); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceClampMax, acl_yarn_ramp_tensor.get(), one.get()); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceSubs, acl_yarn_ramp_tensor.get(), one.get(), one.get()); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor.get(), ext_factor_sc.get()); - // theta_interp = freq_scale * theta_extrap; - // theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; - // theta = freq_scale * theta_extrap * (1 - ramp_mix) + theta_extrap * ramp_mix; - // theta = freq_scale * theta_extrap - freq_scale * theta_extrap * ramp_mix + theta_extrap * ramp_mix; - // theta = theta_extrap * (freq_scale - freq_scale * ramp_mix + ramp_mix); - // - // we cache (freq_scale - freq_scale * ramp_mix + ramp_mix), Considering that the rope_yarn_ramp here is the inverse - // cache freq_scale + (freq_scale - 1) * ramp_mix - float freq_scale_1 = freq_scale - 1; - acl_scalar_ptr freq_scale_sc = ggml_cann_create_scalar(&freq_scale, aclDataType::ACL_FLOAT); - acl_scalar_ptr freq_scale_1_sc = ggml_cann_create_scalar(&freq_scale_1, aclDataType::ACL_FLOAT); - GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor.get(), freq_scale_1_sc.get()); - GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdds, acl_yarn_ramp_tensor.get(), freq_scale_sc.get(), one.get()); - } + // theta_interp = freq_scale * theta_extrap; + // theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; + // theta = freq_scale * theta_extrap * (1 - ramp_mix) + theta_extrap * ramp_mix; + // theta = freq_scale * theta_extrap - freq_scale * theta_extrap * ramp_mix + theta_extrap * ramp_mix; + // theta = theta_extrap * (freq_scale - freq_scale * ramp_mix + ramp_mix); + // + // we cache (freq_scale - freq_scale * ramp_mix + ramp_mix), Considering that the rope_yarn_ramp here is the inverse + // cache freq_scale + (freq_scale - 1) * ramp_mix + float freq_scale_1 = freq_scale - 1; + acl_scalar_ptr freq_scale_sc = ggml_cann_create_scalar(&freq_scale, aclDataType::ACL_FLOAT); + acl_scalar_ptr freq_scale_1_sc = ggml_cann_create_scalar(&freq_scale_1, aclDataType::ACL_FLOAT); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor.get(), freq_scale_1_sc.get()); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdds, acl_yarn_ramp_tensor.get(), freq_scale_sc.get(), one.get()); + } - // power - acl_scalar_ptr acl_theta_scale = ggml_cann_create_scalar(&theta_scale, aclDataType::ACL_FLOAT); - GGML_CANN_CALL_ACLNN_OP(ctx, PowScalarTensor, acl_theta_scale.get(), acl_theta_scale_tensor.get(), - acl_theta_scale_tensor.get()); - - if (ext_factor != 0) { + // Step 1.3: update theta_scale_tensor according to ext_factor or freq_scale. + if (ext_factor != 0) { + if (theta_scale_updated || yarn_ramp_tensor_updated) { + theta_scale_updated = true; aclnn_mul(ctx, acl_theta_scale_tensor.get(), acl_yarn_ramp_tensor.get()); - } else if (freq_scale != 1) { - aclnn_muls(ctx, acl_theta_scale_tensor.get(), freq_scale, nullptr, true); } } else { - // use cache + if (freq_scale != 1 && (ctx.rope_cache.freq_scale != freq_scale || theta_scale_updated)) { + theta_scale_updated = true; + aclnn_muls(ctx, acl_theta_scale_tensor.get(), freq_scale, nullptr, true); + } + } + + // Nothing changed, use cache. + if (!theta_scale_updated) { acl_theta_scale_tensor = ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float), theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS); } + // Step 1.4: prepare select index if mrope + acl_tensor_ptr position_select_index_tensor; + if (mrope_used) { + if (ctx.rope_cache.sections[0] != sections[0] || ctx.rope_cache.sections[1] != sections[1] || + ctx.rope_cache.sections[2] != sections[2] || ctx.rope_cache.sections[3] != sections[3] || + ctx.rope_cache.theta_scale_length != theta_scale_length || ctx.rope_cache.is_imrope != is_imrope) { + if (ctx.rope_cache.position_select_index_host != nullptr) { + free(ctx.rope_cache.position_select_index_host); + } + ctx.rope_cache.position_select_index_host = (int *) malloc(theta_scale_length * sizeof(int)); + GGML_ASSERT(ctx.rope_cache.position_select_index_host != nullptr); + int sect_dims = sections[0] + sections[1] + sections[2] + sections[3]; + int sec_w = sections[1] + sections[0]; + int sec_e = sections[2] + sec_w; + // t,h,w,e + for (int i = 0; i < theta_scale_length; i++) { + int sector = i % sect_dims; + + if (is_imrope) { // qwen3vl apply interleaved mrope + if (sector % 3 == 1 && sector < 3 * sections[1]) { + ctx.rope_cache.position_select_index_host[i] = 1; + } else if (sector % 3 == 2 && sector < 3 * sections[2]) { + ctx.rope_cache.position_select_index_host[i] = 2; + } else if (sector % 3 == 0 && sector < 3 * sections[0]) { + ctx.rope_cache.position_select_index_host[i] = 0; + } else { + ctx.rope_cache.position_select_index_host[i] = 3; + } + } else { + if (sector >= sections[0] && sector < sec_w) { + ctx.rope_cache.position_select_index_host[i] = 1; + } else if (sector >= sec_w && sector < sec_e) { + ctx.rope_cache.position_select_index_host[i] = 2; + } else if (sector >= sec_e) { + ctx.rope_cache.position_select_index_host[i] = 3; + } else { + ctx.rope_cache.position_select_index_host[i] = 0; + } + } + } + + if (ctx.rope_cache.position_select_index != nullptr) { + ACL_CHECK(aclrtFree(ctx.rope_cache.position_select_index)); + } + ACL_CHECK(aclrtMalloc(&ctx.rope_cache.position_select_index, theta_scale_length * sizeof(int), + ACL_MEM_MALLOC_HUGE_FIRST)); + + ACL_CHECK(aclrtMemcpyAsync(ctx.rope_cache.position_select_index, theta_scale_length * sizeof(int), + ctx.rope_cache.position_select_index_host, theta_scale_length * sizeof(int), + ACL_MEMCPY_HOST_TO_DEVICE, ctx.stream())); + } + + position_select_index_tensor = ggml_cann_create_tensor(ctx.rope_cache.position_select_index, ACL_INT32, + sizeof(int), theta_scale_ne, theta_scale_nb, 1); + } + + // Step2: divide by freq_factors ggml_cann_pool_alloc freq_fac_res_allocator(ctx.pool()); - // freq_factors if (src2) { freq_fac_res_allocator.alloc(theta_scale_length * sizeof(float)); void * freq_fac_res_ptr = freq_fac_res_allocator.get(); @@ -2366,6 +2472,85 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx, std::swap(acl_theta_scale_tensor, acl_freq_fac_res_tensor); } + // Step3: prepare position_tensor + acl_tensor_ptr acl_position_tensor; + ggml_cann_pool_alloc mrope_position_acllocator(ctx.pool()); + if (mrope_used) { + // Step3.1: select current position; + // position : + // pos1: [[0, 1 ,2 ,3 ], + // pos2: [4, 5 ,6 ,7 ], + // pos3: [8, 9 ,10,11], + // pos4: [12,13,14,15] ] + // + // select index = [0, 1, 2, 2, 1, 0] + // + // selected_tensor: + // [[0, 1 ,2 ,3 ], + // [4, 5 ,6 ,7 ], + // [8, 9 ,10,11], + // [8, 9 ,10,11], + // [4, 5 ,6 ,7 ], + // [0, 1 ,2 ,3 ]] + // + // transpose, from [seq_len:dims] to [dims:seq_len] + // [0, 4, 8 ,8 ,4, 0], + // [1, 5, 9, 9, 5, 1], + // [2, 6, 10,10,6 ,2], + // [3, 7, 11,11,7 3 ]] + // + // multipy by theta_scale_tensor + // [theta_scale^0, theta_scale^1, ..., theta_scale ^ n] + + int64_t mrope_position_ne[] = { position_length, 4 }; + size_t mrope_position_nb[] = { sizeof(int), position_length * sizeof(int) }; + acl_tensor_ptr mrope_position = + ggml_cann_create_tensor(src1->data, ggml_cann_type_mapping(src1->type), ggml_type_size(src1->type), + mrope_position_ne, mrope_position_nb, 2); + + // selected position tensor's shape is a transpose of cache tensor. + int64_t selected_position_ne[] = { position_length, theta_scale_length }; + size_t selected_position_nb[] = { sizeof(float), position_length * sizeof(float) }; + mrope_position_acllocator.alloc(theta_scale_length * position_length * sizeof(float)); + void * mrope_position_buffer = mrope_position_acllocator.get(); + acl_position_tensor = + ggml_cann_create_tensor(mrope_position_buffer, ggml_cann_type_mapping(src1->type), + ggml_type_size(src1->type), selected_position_ne, selected_position_nb, 2); + GGML_CANN_CALL_ACLNN_OP(ctx, IndexSelect, mrope_position.get(), 0, position_select_index_tensor.get(), + acl_position_tensor.get()); + + // transpose + int64_t transposed_ne[] = { position_length, 1, theta_scale_length, 1 }; + size_t transposed_nb[GGML_MAX_DIMS]; + transposed_nb[0] = sizeof(float); + for (int i = 1; i < GGML_MAX_DIMS; i++) { + transposed_nb[i] = transposed_nb[i - 1] * transposed_ne[i - 1]; + } + + std::swap(transposed_ne[0], transposed_ne[2]); + std::swap(transposed_nb[0], transposed_nb[2]); + + acl_position_tensor = + ggml_cann_create_tensor(mrope_position_buffer, ggml_cann_type_mapping(src1->type), + ggml_type_size(src1->type), transposed_ne, transposed_nb, GGML_MAX_DIMS); + + } else { + // auto bcast. + acl_position_tensor = + ggml_cann_create_tensor(src1->data, ggml_cann_type_mapping(src1->type), ggml_type_size(src1->type), + position_ne, position_nb, GGML_MAX_DIMS); + } + + // Step4: multiply by the position + int64_t theta_length = theta_scale_length * position_length; + ggml_cann_pool_alloc theta_allocator(ctx.pool(), theta_length * sizeof(float)); + void * theta_buffer = theta_allocator.get(); + + acl_tensor_ptr acl_theta_tensor = + ggml_cann_create_tensor(theta_buffer, ACL_FLOAT, sizeof(float), cache_ne, cache_nb, GGML_MAX_DIMS); + aclnn_mul(ctx, acl_position_tensor.get(), acl_theta_scale_tensor.get(), acl_theta_tensor.get()); + + // Step5: calculate sin cos. // init sin_repeat && cos_repeat, only to accelerate first layer on each device if (position_length > ctx.rope_cache.position_length) { ctx.rope_cache.position_length = position_length; @@ -2382,44 +2567,30 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx, aclrtMalloc(&ctx.rope_cache.cos_cache, repeat_theta_length * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST)); } - // position - acl_tensor_ptr acl_position_tensor = - ggml_cann_create_tensor(src1->data, ggml_cann_type_mapping(src1->type), ggml_type_size(src1->type), position_ne, - position_nb, GGML_MAX_DIMS); - - // power * position - int64_t theta_length = theta_scale_length * position_length; - ggml_cann_pool_alloc theta_allocator(ctx.pool(), theta_length * sizeof(float)); - void * theta_buffer = theta_allocator.get(); - - acl_tensor_ptr acl_theta_tensor = - ggml_cann_create_tensor(theta_buffer, ACL_FLOAT, sizeof(float), theta_ne, theta_nb, GGML_MAX_DIMS); - aclnn_mul(ctx, acl_position_tensor.get(), acl_theta_scale_tensor.get(), acl_theta_tensor.get()); - // sin/cos ggml_cann_pool_alloc sin_allocator(ctx.pool(), theta_length * sizeof(float)); void * sin_buffer = sin_allocator.get(); acl_tensor_ptr acl_sin_tensor = - ggml_cann_create_tensor(sin_buffer, ACL_FLOAT, sizeof(float), theta_ne, theta_nb, GGML_MAX_DIMS, ACL_FORMAT_ND); + ggml_cann_create_tensor(sin_buffer, ACL_FLOAT, sizeof(float), cache_ne, cache_nb, GGML_MAX_DIMS, ACL_FORMAT_ND); aclnn_sin(ctx, acl_theta_tensor.get(), acl_sin_tensor.get()); ggml_cann_pool_alloc cos_allocator(ctx.pool(), theta_length * sizeof(float)); void * cos_buffer = cos_allocator.get(); acl_tensor_ptr acl_cos_tensor = - ggml_cann_create_tensor(cos_buffer, ACL_FLOAT, sizeof(float), theta_ne, theta_nb, GGML_MAX_DIMS, ACL_FORMAT_ND); + ggml_cann_create_tensor(cos_buffer, ACL_FLOAT, sizeof(float), cache_ne, cache_nb, GGML_MAX_DIMS, ACL_FORMAT_ND); aclnn_cos(ctx, acl_theta_tensor.get(), acl_cos_tensor.get()); if (ext_factor != 0) { attn_factor *= 1.0f + 0.1f * logf(1.0f / freq_scale); } - // attn_factor + // Step 5: multiply by attn_factor if (attn_factor != 1) { aclnn_muls(ctx, acl_sin_tensor.get(), attn_factor, nullptr, true); aclnn_muls(ctx, acl_cos_tensor.get(), attn_factor, nullptr, true); } - int64_t sin_reshape_ne[4] = { src0->ne[0], 1, src0->ne[2], 1 }; + int64_t sin_reshape_ne[4] = { src0->ne[0], 1, dst->ne[2], 1 }; size_t sin_reshape_nb[GGML_MAX_DIMS]; sin_reshape_nb[0] = sizeof(float); for (int i = 1; i < GGML_MAX_DIMS; i++) { @@ -2430,8 +2601,9 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx, acl_tensor_ptr acl_cos_repeat_tensor = ggml_cann_create_tensor(ctx.rope_cache.cos_cache, ACL_FLOAT, sizeof(float), sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS); - // repeat + // Step 6: repeat if (is_neox) { + // [sinθ1, sinθ1, sinθ2, sinθ2, ..., sinθn, sinθn] int64_t repeatsArray[] = { 1, 1, 1, 2 }; aclnn_repeat(ctx, acl_sin_tensor.get(), acl_sin_repeat_tensor.get(), repeatsArray); aclnn_repeat(ctx, acl_cos_tensor.get(), acl_cos_repeat_tensor.get(), repeatsArray); @@ -2439,17 +2611,15 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx, int64_t num_repeats = 2; int64_t dim = 3; int64_t output_size = theta_scale_length * num_repeats; + // [sinθ1, sinθ2, ..., sinθn, sinθ1, sinθ2, ..., sinθn] aclnn_repeat_interleave(ctx, acl_sin_tensor.get(), acl_sin_repeat_tensor.get(), dim, num_repeats, output_size); aclnn_repeat_interleave(ctx, acl_cos_tensor.get(), acl_cos_repeat_tensor.get(), dim, num_repeats, output_size); } - // Other layers use cache except first layer. - ctx.rope_cache.cached = true; - ctx.rope_cache.ext_factor = ext_factor; - ctx.rope_cache.theta_scale = theta_scale; - ctx.rope_cache.freq_scale = freq_scale; - ctx.rope_cache.attn_factor = attn_factor; - ctx.rope_cache.is_neox = is_neox; + // Update cached value. + ctx.rope_cache.cached = true; + ctx.rope_cache.set(theta_scale_length, position_length, ext_factor, theta_scale, freq_scale, attn_factor, is_neox, + indep_sects, mrope_used, is_imrope, sections); } #ifdef __cplusplus @@ -2475,6 +2645,7 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) { // param float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; + int sections[4]; // const int n_past = ((int32_t *) dst->op_params)[0]; const int n_dims = ((int32_t *) dst->op_params)[1]; const int mode = ((int32_t *) dst->op_params)[2]; @@ -2483,12 +2654,13 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) { GGML_TENSOR_UNARY_OP_LOCALS - memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); - memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); - memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); - memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); - memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); - memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); + memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); + memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); + memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); + memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); + memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); + memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); + memcpy(§ions, (int32_t *) dst->op_params + 11, sizeof(int)*4); // TODO: n_dims <= ne0 GGML_ASSERT(n_dims == ne0); @@ -2499,10 +2671,25 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) { float corr_dims[2]; ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); - const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; + bool is_neox = mode & GGML_ROPE_TYPE_NEOX; + const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE; // qwen3vl apply interleaved mrope + const bool mrope_used = mode & GGML_ROPE_TYPE_MROPE; // ggml_rope_multi, note: also true for vision (24 & 8 == true) and for imrope + const bool is_vision = mode == GGML_ROPE_TYPE_VISION; + + if (mrope_used) { + GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0); + } + + if (is_vision) { + GGML_ASSERT(n_dims == ne0/2); + } + + if (is_imrope || mrope_used) { + is_neox = true; + } // init ctx.rope_cos/rope_sin cache - aclnn_cache_init(ctx, dst, corr_dims, ext_factor, theta_scale, freq_scale, attn_factor, is_neox); + aclnn_rope_cache_init(ctx, dst, corr_dims, ext_factor, theta_scale, freq_scale, attn_factor, is_neox, sections, mrope_used, is_imrope, is_vision); int64_t sin_reshape_ne[4] = { ne00, 1, ne02, 1 }; size_t sin_reshape_nb[GGML_MAX_DIMS]; @@ -2658,8 +2845,7 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) { return; #endif - // ggml_mode = 0 --> aclnn_model = 1 - int64_t acl_mode = mode == 0 ? 1 : mode; + int64_t acl_mode = is_neox ? 0 : 1; switch (src0->type) { case GGML_TYPE_F32: diff --git a/ggml/src/ggml-cann/common.h b/ggml/src/ggml-cann/common.h index d4ef24eaa7..b17445bb9a 100644 --- a/ggml/src/ggml-cann/common.h +++ b/ggml/src/ggml-cann/common.h @@ -300,30 +300,92 @@ struct ggml_cann_graph_lru_cache { struct ggml_cann_rope_cache { ~ggml_cann_rope_cache() { - if (theta_scale_cache != nullptr) { + if (theta_scale_cache) { ACL_CHECK(aclrtFree(theta_scale_cache)); } - if (sin_cache != nullptr) { + if (sin_cache) { ACL_CHECK(aclrtFree(sin_cache)); } - if (cos_cache != nullptr) { + if (cos_cache) { ACL_CHECK(aclrtFree(cos_cache)); } + if (position_select_index) { + ACL_CHECK(aclrtFree(position_select_index)); + } + if (theta_scale_exp_host) { + free(theta_scale_exp_host); + } + if(position_select_index_host) { + free(position_select_index_host); + } } - void * theta_scale_cache = nullptr; - int64_t theta_scale_length = 0; + bool equal(int64_t theta_scale_length, + int64_t position_length, + float ext_factor, + float theta_scale, + float freq_scale, + float attn_factor, + bool is_neox, + bool indep_sects, + bool mrope_used, + bool is_imrope, + int sections[4]) { + return this->theta_scale_length == theta_scale_length && this->position_length == position_length && + this->ext_factor == ext_factor && this->theta_scale == theta_scale && this->freq_scale == freq_scale && + this->attn_factor == attn_factor && this->is_neox == is_neox && this->indep_sects == indep_sects && + this->mrope_used == mrope_used && this->is_imrope == is_imrope && this->sections[0] == sections[0] && + this->sections[1] == sections[1] && this->sections[2] == sections[2] && this->sections[3] == sections[3]; + } + + void set(int64_t theta_scale_length, + int64_t position_length, + float ext_factor, + float theta_scale, + float freq_scale, + float attn_factor, + bool is_neox, + bool indep_sects, + bool mrope_used, + bool is_imrope, + int sections[4]) { + this->theta_scale_length = theta_scale_length; + this->position_length = position_length; + this->ext_factor = ext_factor; + this->theta_scale = theta_scale; + this->freq_scale = freq_scale; + this->attn_factor = attn_factor; + this->is_neox = is_neox; + this->indep_sects = indep_sects; + this->mrope_used = mrope_used; + this->is_imrope = is_imrope; + this->sections[0] = sections[0]; + this->sections[1] = sections[1]; + this->sections[2] = sections[2]; + this->sections[3] = sections[3]; + } + + // memory cache, prepare before inferencing. + void * theta_scale_cache = nullptr; + float * theta_scale_exp_host = nullptr; + int * position_select_index_host = nullptr; + void * position_select_index = nullptr; // sin/cos cache, used only to accelerate first layer on each device - void * sin_cache = nullptr; - void * cos_cache = nullptr; - int64_t position_length = 0; + void * sin_cache = nullptr; + void * cos_cache = nullptr; // Properties to check before reusing the sincos cache - bool cached = false; - float ext_factor = 0.0f; - float theta_scale = 0.0f; - float freq_scale = 0.0f; - float attn_factor = 0.0f; - bool is_neox = false; + int64_t theta_scale_length = 0; + int64_t position_length = 0; + bool cached = false; + float ext_factor = 0.0f; + float theta_scale = 0.0f; + float freq_scale = 0.0f; + float attn_factor = 0.0f; + bool is_neox = false; + bool indep_sects = false; + bool mrope_used = false; + int sections[4] = { 0, 0, 0, 0 }; + bool is_imrope = false; }; struct ggml_cann_tensor_cache { diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index 8995a5c121..df28d67fb0 100644 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -2480,13 +2480,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten return false; } - const int mode = ((const int32_t *) op->op_params)[2]; - if (mode & GGML_ROPE_TYPE_MROPE) { - return false; - } - if (mode & GGML_ROPE_TYPE_VISION) { - return false; - } if (op->src[0]->ne[0] > 896) { return false; } diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt index f4bccce277..8db7dc3d7c 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -224,7 +224,8 @@ function(ggml_add_cpu_backend_variant_impl tag_name) include(CheckCXXSourceCompiles) set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS}) - set(CMAKE_REQUIRED_FLAGS "${ARCH_FLAGS}") + string(REPLACE ";" " " ARCH_FLAGS_STR "${ARCH_FLAGS}") + set(CMAKE_REQUIRED_FLAGS "${ARCH_FLAGS_STR}") foreach(feature DOTPROD SVE MATMUL_INT8 FMA FP16_VECTOR_ARITHMETIC SME) set(ARM_FEATURE "HAVE_${feature}") check_cxx_source_compiles( diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index b2acbb27dd..cc9cbf9173 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -1927,6 +1927,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_argsort(params, tensor); } break; + case GGML_OP_TOP_K: + { + ggml_compute_forward_top_k(params, tensor); + } break; case GGML_OP_LEAKY_RELU: { ggml_compute_forward_leaky_relu(params, tensor); @@ -2311,6 +2315,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_OP_ARANGE: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_ARGSORT: + case GGML_OP_TOP_K: case GGML_OP_FLASH_ATTN_EXT: case GGML_OP_FLASH_ATTN_BACK: case GGML_OP_SSM_CONV: @@ -2834,6 +2839,10 @@ struct ggml_cplan ggml_graph_plan( cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02*ne03; cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12; } break; + case GGML_OP_TOP_K: + { + cur += sizeof(int32_t)*node->src[0]->ne[0]*n_tasks; + } break; case GGML_OP_FLASH_ATTN_EXT: { const int64_t ne10 = node->src[1]->ne[0]; // DK diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 41e89d83c2..d405696539 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -7794,7 +7794,7 @@ void ggml_compute_forward_timestep_embedding( // ggml_compute_forward_argsort template -struct argsort_cmp { +struct cmp_argsort { const float * data; bool operator()(int32_t a, int32_t b) const { if constexpr (order == GGML_SORT_ORDER_ASC) { @@ -7833,11 +7833,11 @@ static void ggml_compute_forward_argsort_f32( switch (order) { case GGML_SORT_ORDER_ASC: - std::sort(dst_data, dst_data + ne0, argsort_cmp{src_data}); + std::sort(dst_data, dst_data + ne0, cmp_argsort{src_data}); break; case GGML_SORT_ORDER_DESC: - std::sort(dst_data, dst_data + ne0, argsort_cmp{src_data}); + std::sort(dst_data, dst_data + ne0, cmp_argsort{src_data}); break; default: @@ -7864,6 +7864,72 @@ void ggml_compute_forward_argsort( } } +// ggml_compute_forward_top_k + +struct cmp_top_k { + const float * data; + bool operator()(int32_t a, int32_t b) const { + return data[a] > data[b]; + } +}; + +static void ggml_compute_forward_top_k_f32( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + GGML_TENSOR_UNARY_OP_LOCALS + + GGML_ASSERT(nb0 == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t nr = ggml_nrows(src0); + + const int top_k = ne0; + + int32_t * tmp = (int32_t *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith; + + for (int64_t i = ith; i < nr; i += nth) { + const float * src_data = (float *)((char *) src0->data + i*nb01); + + for (int64_t j = 0; j < ne00; j++) { + tmp[j] = j; + } + + std::partial_sort(tmp, tmp + top_k, tmp + ne00, cmp_top_k{src_data}); + + int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1); + + std::copy(tmp, tmp + top_k, dst_data); + + // emphasize that the order is not important + if (top_k > 1) { + std::swap(dst_data[0], dst_data[1]); + } + } +} + +void ggml_compute_forward_top_k( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_top_k_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + // ggml_compute_forward_flash_attn_ext static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( diff --git a/ggml/src/ggml-cpu/ops.h b/ggml/src/ggml-cpu/ops.h index 98a0eec16d..0fdfee7976 100644 --- a/ggml/src/ggml-cpu/ops.h +++ b/ggml/src/ggml-cpu/ops.h @@ -81,6 +81,7 @@ void ggml_compute_forward_roll(const struct ggml_compute_params * params, struct void ggml_compute_forward_arange(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_timestep_embedding(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_argsort(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_top_k(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_leaky_relu(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_fill(const struct ggml_compute_params * params, struct ggml_tensor * dst); diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h index 016993c39e..12b71a4898 100644 --- a/ggml/src/ggml-cpu/vec.h +++ b/ggml/src/ggml-cpu/vec.h @@ -455,146 +455,141 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const } inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y, const ggml_fp16_t * GGML_RESTRICT x, const float v) { -#if defined(GGML_SIMD) - #if defined(__ARM_FEATURE_SVE) - const int sve_register_length = svcntb() * 8; - const int ggml_f16_epr = sve_register_length / 16; - const int ggml_f16_step = 8 * ggml_f16_epr; +#if defined(GGML_SIMD) && defined(__ARM_FEATURE_SVE) + const int sve_register_length = svcntb() * 8; + const int ggml_f16_epr = sve_register_length / 16; + const int ggml_f16_step = 8 * ggml_f16_epr; - GGML_F16x_VEC vx = GGML_F16x_VEC_SET1(v); + GGML_F16x_VEC vx = GGML_F16x_VEC_SET1(v); - const int np= (n & ~(ggml_f16_step - 1)); + int np = (n & ~(ggml_f16_step - 1)); - svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8; - svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8; - for (int i = 0; i < np; i += ggml_f16_step) { - ax1 = GGML_F16x_VEC_LOAD(x + i + 0 * ggml_f16_epr, 0); - ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0); - ay1 = GGML_F16x_VEC_FMA(ay1, ax1, vx); + svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8; + svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8; + for (int i = 0; i < np; i += ggml_f16_step) { + ax1 = GGML_F16x_VEC_LOAD(x + i + 0 * ggml_f16_epr, 0); + ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0); + ay1 = GGML_F16x_VEC_FMA(ay1, ax1, vx); - GGML_F16x_VEC_STORE(y + i + 0 * ggml_f16_epr, ay1, 0); + GGML_F16x_VEC_STORE(y + i + 0 * ggml_f16_epr, ay1, 0); - ax2 = GGML_F16x_VEC_LOAD(x + i + 1 * ggml_f16_epr, 1); - ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1); - ay2 = GGML_F16x_VEC_FMA(ay2, ax2, vx); + ax2 = GGML_F16x_VEC_LOAD(x + i + 1 * ggml_f16_epr, 1); + ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1); + ay2 = GGML_F16x_VEC_FMA(ay2, ax2, vx); - GGML_F16x_VEC_STORE(y + i + 1 * ggml_f16_epr, ay2, 1); + GGML_F16x_VEC_STORE(y + i + 1 * ggml_f16_epr, ay2, 1); - ax3 = GGML_F16x_VEC_LOAD(x + i + 2 * ggml_f16_epr, 2); - ay3 = GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2); - ay3 = GGML_F16x_VEC_FMA(ay3, ax3, vx); + ax3 = GGML_F16x_VEC_LOAD(x + i + 2 * ggml_f16_epr, 2); + ay3 = GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2); + ay3 = GGML_F16x_VEC_FMA(ay3, ax3, vx); - GGML_F16x_VEC_STORE(y + i + 2 * ggml_f16_epr, ay3, 2); + GGML_F16x_VEC_STORE(y + i + 2 * ggml_f16_epr, ay3, 2); - ax4 = GGML_F16x_VEC_LOAD(x + i + 3 * ggml_f16_epr, 3); - ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3); - ay4 = GGML_F16x_VEC_FMA(ay4, ax4, vx); + ax4 = GGML_F16x_VEC_LOAD(x + i + 3 * ggml_f16_epr, 3); + ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3); + ay4 = GGML_F16x_VEC_FMA(ay4, ax4, vx); - GGML_F16x_VEC_STORE(y + i + 3 * ggml_f16_epr, ay4, 3); + GGML_F16x_VEC_STORE(y + i + 3 * ggml_f16_epr, ay4, 3); - ax5 = GGML_F16x_VEC_LOAD(x + i + 4 * ggml_f16_epr, 4); - ay5 = GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4); - ay5 = GGML_F16x_VEC_FMA(ay5, ax5, vx); + ax5 = GGML_F16x_VEC_LOAD(x + i + 4 * ggml_f16_epr, 4); + ay5 = GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4); + ay5 = GGML_F16x_VEC_FMA(ay5, ax5, vx); - GGML_F16x_VEC_STORE(y + i + 4 * ggml_f16_epr, ay5, 4); + GGML_F16x_VEC_STORE(y + i + 4 * ggml_f16_epr, ay5, 4); - ax6 = GGML_F16x_VEC_LOAD(x + i + 5 * ggml_f16_epr, 5); - ay6 = GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5); - ay6 = GGML_F16x_VEC_FMA(ay6, ax6, vx); + ax6 = GGML_F16x_VEC_LOAD(x + i + 5 * ggml_f16_epr, 5); + ay6 = GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5); + ay6 = GGML_F16x_VEC_FMA(ay6, ax6, vx); - GGML_F16x_VEC_STORE(y + i + 5 * ggml_f16_epr, ay6, 5); + GGML_F16x_VEC_STORE(y + i + 5 * ggml_f16_epr, ay6, 5); - ax7 = GGML_F16x_VEC_LOAD(x + i + 6 * ggml_f16_epr, 6); - ay7 = GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6); - ay7 = GGML_F16x_VEC_FMA(ay7, ax7, vx); + ax7 = GGML_F16x_VEC_LOAD(x + i + 6 * ggml_f16_epr, 6); + ay7 = GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6); + ay7 = GGML_F16x_VEC_FMA(ay7, ax7, vx); - GGML_F16x_VEC_STORE(y + i + 6 * ggml_f16_epr, ay7, 6); + GGML_F16x_VEC_STORE(y + i + 6 * ggml_f16_epr, ay7, 6); - ax8 = GGML_F16x_VEC_LOAD(x + i + 7 * ggml_f16_epr, 7); - ay8 = GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7); - ay8 = GGML_F16x_VEC_FMA(ay8, ax8, vx); + ax8 = GGML_F16x_VEC_LOAD(x + i + 7 * ggml_f16_epr, 7); + ay8 = GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7); + ay8 = GGML_F16x_VEC_FMA(ay8, ax8, vx); - GGML_F16x_VEC_STORE(y + i + 7 * ggml_f16_epr, ay8, 7); + GGML_F16x_VEC_STORE(y + i + 7 * ggml_f16_epr, ay8, 7); + } + const int np2 = (n & ~(ggml_f16_epr - 1)); + for (int k = np; k < np2; k += ggml_f16_epr) { + svfloat16_t rx = GGML_F16x_VEC_LOAD(x + k, 0); + svfloat16_t ry = GGML_F16x_VEC_LOAD(y + k, 0); + ry = GGML_F16x_VEC_FMA(ry, rx, vx); + + GGML_F16x_VEC_STORE(y + k, ry, 0); + } + + if (np2 < n) { + svbool_t pg = svwhilelt_b16(np2, n); + svfloat16_t hx = svld1_f16(pg, (const __fp16 *)(x + np2)); + svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2)); + hy = svmad_f16_x(pg, hx, vx, hy); + svst1_f16(pg, (__fp16 *)(y + np2), hy); + } + np = n; +#elif defined(__riscv_zvfh) // implies __riscv_v_intrinsic + const ggml_fp16_t s = GGML_CPU_FP32_TO_FP16(v); + const _Float16 scale = *(const _Float16*)(&s); + + // calculate step size + const int epr = __riscv_vsetvlmax_e16m4(); + const int step = epr * 2; + const int np = (n & ~(step - 1)); + + // unroll by 2 + for (int i = 0; i < np; i += step) { + vfloat16m4_t ax0 = __riscv_vle16_v_f16m4((const _Float16*)x + i, epr); + vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, epr); + ay0 = __riscv_vfmacc_vf_f16m4(ay0, scale, ax0, epr); + __riscv_vse16_v_f16m4((_Float16*)y + i, ay0, epr); + __asm__ __volatile__ ("" ::: "memory"); + + vfloat16m4_t ax1 = __riscv_vle16_v_f16m4((const _Float16*)x + i + epr, epr); + vfloat16m4_t ay1 = __riscv_vle16_v_f16m4((const _Float16*)y + i + epr, epr); + ay1 = __riscv_vfmacc_vf_f16m4(ay1, scale, ax1, epr); + __riscv_vse16_v_f16m4((_Float16*)y + i + epr, ay1, epr); + __asm__ __volatile__ ("" ::: "memory"); + } + + // leftovers + int vl; + for (int i = np; i < n; i += vl) { + vl = __riscv_vsetvl_e16m4(n - i); + vfloat16m4_t ax0 = __riscv_vle16_v_f16m4((const _Float16*)x + i , vl); + vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, vl); + ay0 = __riscv_vfmacc_vf_f16m4(ay0, scale, ax0, vl); + __riscv_vse16_v_f16m4((_Float16*)y + i, ay0, vl); + } +#elif defined(GGML_SIMD) + const int np = (n & ~(GGML_F16_STEP - 1)); + + GGML_F16_VEC vx = GGML_F16_VEC_SET1(v); + + GGML_F16_VEC ax[GGML_F16_ARR]; + GGML_F16_VEC ay[GGML_F16_ARR]; + + for (int i = 0; i < np; i += GGML_F16_STEP) { + for (int j = 0; j < GGML_F16_ARR; j++) { + ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j); + ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); + ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx); + + GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j); } - const int np2 = (n & ~(ggml_f16_epr - 1)); - for (int k = np; k < np2; k += ggml_f16_epr) { - svfloat16_t rx = GGML_F16x_VEC_LOAD(x + k, 0); - svfloat16_t ry = GGML_F16x_VEC_LOAD(y + k, 0); - ry = GGML_F16x_VEC_FMA(ry, rx, vx); - - GGML_F16x_VEC_STORE(y + k, ry, 0); - } - - if (np2 < n) { - svbool_t pg = svwhilelt_b16(np2, n); - svfloat16_t hx = svld1_f16(pg, (const __fp16 *)(x + np2)); - svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2)); - hy = svmad_f16_x(pg, hx, vx, hy); - svst1_f16(pg, (__fp16 *)(y + np2), hy); - } - - #elif defined(__riscv_v_intrinsic) && defined(__riscv_zvfh) - const ggml_fp16_t s = GGML_CPU_FP32_TO_FP16(v); - const _Float16 scale = *(const _Float16*)(&s); - - // calculate step size - const int epr = __riscv_vsetvlmax_e16m4(); - const int step = epr * 2; - const int np = (n & ~(step - 1)); - - // unroll by 2 - for (int i = 0; i < np; i += step) { - vfloat16m4_t ax0 = __riscv_vle16_v_f16m4((const _Float16*)x + i, epr); - vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, epr); - ay0 = __riscv_vfmacc_vf_f16m4(ay0, scale, ax0, epr); - __riscv_vse16_v_f16m4((_Float16*)y + i, ay0, epr); - __asm__ __volatile__ ("" ::: "memory"); - - vfloat16m4_t ax1 = __riscv_vle16_v_f16m4((const _Float16*)x + i + epr, epr); - vfloat16m4_t ay1 = __riscv_vle16_v_f16m4((const _Float16*)y + i + epr, epr); - ay1 = __riscv_vfmacc_vf_f16m4(ay1, scale, ax1, epr); - __riscv_vse16_v_f16m4((_Float16*)y + i + epr, ay1, epr); - __asm__ __volatile__ ("" ::: "memory"); - } - - // leftovers - int vl; - for (int i = np; i < n; i += vl) { - vl = __riscv_vsetvl_e16m4(n - i); - vfloat16m4_t ax0 = __riscv_vle16_v_f16m4((const _Float16*)x + i , vl); - vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, vl); - ay0 = __riscv_vfmacc_vf_f16m4(ay0, scale, ax0, vl); - __riscv_vse16_v_f16m4((_Float16*)y + i, ay0, vl); - } - #else - const int np = (n & ~(GGML_F16_STEP - 1)); - - GGML_F16_VEC vx = GGML_F16_VEC_SET1(v); - - GGML_F16_VEC ax[GGML_F16_ARR]; - GGML_F16_VEC ay[GGML_F16_ARR]; - - for (int i = 0; i < np; i += GGML_F16_STEP) { - for (int j = 0; j < GGML_F16_ARR; j++) { - ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j); - ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); - ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx); - - GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j); - } - } - - // leftovers - for (int i = np; i < n; ++i) { - y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v); - } - #endif + } #else - // scalar - for (int i = 0; i < n; ++i) { + const int np = 0; +#endif + + // leftovers + for (int i = np; i < n; ++i) { y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v); } -#endif } // xs and vs are byte strides of x and v diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh index caa08b360b..c0a9c2c08a 100644 --- a/ggml/src/ggml-cuda/mma.cuh +++ b/ggml/src/ggml-cuda/mma.cuh @@ -437,18 +437,27 @@ namespace ggml_cuda_mma { xi[0] = xs[0]; } #elif defined(AMD_WMMA_AVAILABLE) - if constexpr (I == 16 && J == 4) { - int64_t * xi = (int64_t *) t.x; - const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I)); - xi[0] = xs[0]; - }else if constexpr (I == 16 && J == 8) { - int64_t * xi = (int64_t *) t.x; - const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I)); - xi[0] = xs[0]; + if constexpr (std::is_same_v || std::is_same_v) { + ggml_cuda_memcpy_1(t.x, xs0 + t.get_i(0) * stride + t.get_j(0)); - const int64_t * xs1 = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I) + 2); - xi[1] = xs1[0]; - }else{ + } else if constexpr (std::is_same_v) { + if constexpr (I == 16 && J == 4) { + int64_t * xi = (int64_t *) t.x; + const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I)); + xi[0] = xs[0]; + + }else if constexpr (I == 16 && J == 8) { + int64_t * xi = (int64_t *) t.x; + const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I)); + xi[0] = xs[0]; + + const int64_t * xs1 = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I) + 2); + xi[1] = xs1[0]; + + }else{ + NO_DEVICE_CODE; + } + } else { NO_DEVICE_CODE; } #else diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 99760d56c7..82468b384e 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -3701,7 +3701,7 @@ static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y); const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type); const size_t nbs_ids = mmq_x*sizeof(int); - const size_t nbs_x = (turing_mma_available(cc) || amd_mfma_available(cc)) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int); + const size_t nbs_x = (turing_mma_available(cc) || amd_mfma_available(cc) || amd_wmma_available(cc)) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int); const size_t nbs_y = mmq_x*sizeof(block_q8_1_mmq); return nbs_ids + nbs_x + GGML_PAD(nbs_y, nwarps*warp_size*sizeof(int)); } diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 0eefc0b137..329500a03e 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -1009,6 +1009,64 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort_merge(ggml_metal_l return res; } +// note: reuse the argsort kernel for top_k +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k(ggml_metal_library_t lib, const ggml_tensor * op) { + assert(op->op == GGML_OP_TOP_K); + + char base[256]; + char name[256]; + + // note: the top_k kernel is always descending order + ggml_sort_order order = GGML_SORT_ORDER_DESC; + + const char * order_str = "undefined"; + switch (order) { + case GGML_SORT_ORDER_ASC: order_str = "asc"; break; + case GGML_SORT_ORDER_DESC: order_str = "desc"; break; + default: GGML_ABORT("fatal error"); + }; + + snprintf(base, 256, "kernel_argsort_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k_merge(ggml_metal_library_t lib, const ggml_tensor * op) { + assert(op->op == GGML_OP_TOP_K); + + char base[256]; + char name[256]; + + ggml_sort_order order = GGML_SORT_ORDER_DESC; + + const char * order_str = "undefined"; + switch (order) { + case GGML_SORT_ORDER_ASC: order_str = "asc"; break; + case GGML_SORT_ORDER_DESC: order_str = "desc"; break; + default: GGML_ABORT("fatal error"); + }; + + snprintf(base, 256, "kernel_argsort_merge_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + return res; +} + ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad( ggml_metal_library_t lib, const struct ggml_tensor * op, diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index 39ee6e3427..3976e622b9 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -128,6 +128,8 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id (ggml_me ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort_merge (ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k (ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k_merge (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op); diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index acf9dfd5fc..09b1b50311 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -905,6 +905,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_OP_LEAKY_RELU: return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_ARGSORT: + case GGML_OP_TOP_K: case GGML_OP_ARANGE: return true; case GGML_OP_FLASH_ATTN_EXT: diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 0fae97029f..342dc4f8c3 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -832,14 +832,19 @@ typedef struct { } ggml_metal_kargs_leaky_relu; typedef struct { - int64_t ne00; - int64_t ne01; - int64_t ne02; - int64_t ne03; + int32_t ne00; + int32_t ne01; + int32_t ne02; + int32_t ne03; uint64_t nb00; uint64_t nb01; uint64_t nb02; uint64_t nb03; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + int32_t top_k; } ggml_metal_kargs_argsort; typedef struct { @@ -851,6 +856,11 @@ typedef struct { uint64_t nb01; uint64_t nb02; uint64_t nb03; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + int32_t top_k; int32_t len; } ggml_metal_kargs_argsort_merge; diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index e46970dad4..9871e976f2 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -406,6 +406,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { { n_fuse = ggml_metal_op_argsort(ctx, idx); } break; + case GGML_OP_TOP_K: + { + n_fuse = ggml_metal_op_top_k(ctx, idx); + } break; case GGML_OP_LEAKY_RELU: { n_fuse = ggml_metal_op_leaky_relu(ctx, idx); @@ -3678,14 +3682,19 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) { } ggml_metal_kargs_argsort args = { - /*.ne00 =*/ ne00, - /*.ne01 =*/ ne01, - /*.ne02 =*/ ne02, - /*.ne03 =*/ ne03, - /*.nb00 =*/ nb00, - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - /*.nb03 =*/ nb03, + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.top_k =*/ nth, }; ggml_metal_encoder_set_pipeline(enc, pipeline); @@ -3705,15 +3714,20 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) { ggml_metal_op_concurrency_reset(ctx); ggml_metal_kargs_argsort_merge args_merge = { - .ne00 = ne00, - .ne01 = ne01, - .ne02 = ne02, - .ne03 = ne03, - .nb00 = nb00, - .nb01 = nb01, - .nb02 = nb02, - .nb03 = nb03, - .len = len, + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.top_k =*/ ne00, + /*.len =*/ len, }; // merges per row @@ -3737,6 +3751,118 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) { return 1; } +int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_ASSERT(ggml_is_contiguous_rows(op->src[0])); + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_top_k(lib, op); + + // bitonic sort requires the number of elements to be power of 2 + int nth = 1; + while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + nth *= 2; + } + + // blocks per row + const int npr = (ne00 + nth - 1)/nth; + + const size_t smem = GGML_PAD(nth*sizeof(int32_t), 16); + + ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]); + ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op); + + ggml_metal_buffer_id bid_tmp = bid_dst; + bid_tmp.offs += sizeof(int32_t)*ggml_nelements(op->src[0]); + + if ((int) ceil(std::log(npr) / std::log(2)) % 2 == 1) { + std::swap(bid_dst, bid_tmp); + } + + const int top_k = ne0; + + ggml_metal_kargs_argsort args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.top_k =*/ std::min(nth, top_k), // for each block, keep just the top_k indices + }; + + if (npr > 1) { + args.ne0 = (npr - 1)*args.top_k + std::min(ne00 - (npr - 1)*nth, args.top_k); + } + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, bid_src0, 1); + ggml_metal_encoder_set_buffer (enc, bid_dst, 2); + + ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); + + ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1); + + ggml_metal_pipeline_t pipeline_merge = ggml_metal_library_get_pipeline_top_k_merge(lib, op); + + int len = args.top_k; + + while (len < args.ne0) { + ggml_metal_op_concurrency_reset(ctx); + + // merges per row + const int nm = (args.ne0 + 2*len - 1) / (2*len); + + const int nth = std::min(512, std::min(len, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_merge))); + + ggml_metal_kargs_argsort_merge args_merge = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ args.ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.top_k =*/ nm == 1 ? top_k : args.ne0, // the final merge outputs top_k elements + /*.len =*/ len, + }; + + ggml_metal_encoder_set_pipeline(enc, pipeline_merge); + ggml_metal_encoder_set_bytes (enc, &args_merge, sizeof(args_merge), 0); + ggml_metal_encoder_set_buffer (enc, bid_src0, 1); + ggml_metal_encoder_set_buffer (enc, bid_dst, 2); + ggml_metal_encoder_set_buffer (enc, bid_tmp, 3); + + ggml_metal_encoder_dispatch_threadgroups(enc, nm*ne01, ne02, ne03, nth, 1, 1); + + std::swap(bid_dst, bid_tmp); + + len <<= 1; + } + + return 1; +} + int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) { ggml_tensor * op = ctx->node(idx); diff --git a/ggml/src/ggml-metal/ggml-metal-ops.h b/ggml/src/ggml-metal/ggml-metal-ops.h index 332e550ee7..b5546146e1 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.h +++ b/ggml/src/ggml-metal/ggml-metal-ops.h @@ -81,6 +81,7 @@ int ggml_metal_op_arange (ggml_metal_op_t ctx, int idx); int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx); int ggml_metal_op_argmax (ggml_metal_op_t ctx, int idx); int ggml_metal_op_argsort (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_top_k (ggml_metal_op_t ctx, int idx); int ggml_metal_op_leaky_relu (ggml_metal_op_t ctx, int idx); int ggml_metal_op_opt_step_adamw (ggml_metal_op_t ctx, int idx); int ggml_metal_op_opt_step_sgd (ggml_metal_op_t ctx, int idx); diff --git a/ggml/src/ggml-metal/ggml-metal.cpp b/ggml/src/ggml-metal/ggml-metal.cpp index f6033ddc97..70bf6f3d98 100644 --- a/ggml/src/ggml-metal/ggml-metal.cpp +++ b/ggml/src/ggml-metal/ggml-metal.cpp @@ -202,6 +202,10 @@ static size_t ggml_backend_metal_buffer_type_get_alloc_size(ggml_backend_buffer_ { res *= 2; } break; + case GGML_OP_TOP_K: + { + res = 2*sizeof(int32_t)*ggml_nelements(tensor->src[0]); + } break; default: break; } diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 59e5761704..73b45c762d 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -4670,11 +4670,12 @@ kernel void kernel_argsort_f32_i32( ushort3 ntg[[threads_per_threadgroup]]) { // bitonic sort const int col = tpitg[0]; + const int ib = tgpig[0] / args.ne01; - const int i00 = (tgpig[0]/args.ne01)*ntg.x; - const int i01 = tgpig[0]%args.ne01; - const int i02 = tgpig[1]; - const int i03 = tgpig[2]; + const int i00 = ib*ntg.x; + const int i01 = tgpig[0] % args.ne01; + const int i02 = tgpig[1]; + const int i03 = tgpig[2]; device const float * src0_row = (device const float *) (src0 + args.nb01*i01 + args.nb02*i02 + args.nb03*i03); @@ -4710,9 +4711,11 @@ kernel void kernel_argsort_f32_i32( } } + const int64_t i0 = ib*args.top_k; + // copy the result to dst without the padding - if (i00 + col < args.ne00) { - dst += i00 + args.ne00*i01 + args.ne00*args.ne01*i02 + args.ne00*args.ne01*args.ne02*i03; + if (i0 + col < args.ne0 && col < args.top_k) { + dst += i0 + args.ne0*i01 + args.ne0*args.ne1*i02 + args.ne0*args.ne1*args.ne2*i03; dst[col] = shmem_i32[col]; } @@ -4747,22 +4750,22 @@ kernel void kernel_argsort_merge_f32_i32( const int start = im * (2 * args.len); - const int len0 = MIN(args.len, MAX(0, args.ne00 - (int)(start))); - const int len1 = MIN(args.len, MAX(0, args.ne00 - (int)(start + args.len))); + const int len0 = MIN(args.len, MAX(0, args.ne0 - (int)(start))); + const int len1 = MIN(args.len, MAX(0, args.ne0 - (int)(start + args.len))); const int total = len0 + len1; device const int32_t * tmp0 = tmp + start - + i01*args.ne00 - + i02*args.ne00*args.ne01 - + i03*args.ne00*args.ne01*args.ne02; + + i01*args.ne0 + + i02*args.ne0*args.ne01 + + i03*args.ne0*args.ne01*args.ne02; device const int32_t * tmp1 = tmp0 + args.len; dst += start - + i01*args.ne00 - + i02*args.ne00*args.ne01 - + i03*args.ne00*args.ne01*args.ne02; + + i01*args.top_k + + i02*args.top_k*args.ne01 + + i03*args.top_k*args.ne01*args.ne02; device const float * src0_row = (device const float *)(src0 + args.nb01*i01 @@ -4776,7 +4779,11 @@ kernel void kernel_argsort_merge_f32_i32( const int chunk = (total + ntg.x - 1) / ntg.x; const int k0 = tpitg.x * chunk; - const int k1 = min(k0 + chunk, total); + const int k1 = MIN(MIN(k0 + chunk, total), args.top_k); + + if (k0 >= args.top_k) { + return; + } if (k0 >= total) { return; diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 6cf15b43bb..7f2cf795c9 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -409,6 +409,7 @@ enum shader_reduction_mode { // argsort pipelines for up to 1<<10 invocations per workgroup static constexpr uint32_t num_argsort_pipelines = 11; static constexpr uint32_t num_topk_moe_pipelines = 10; +static constexpr uint32_t num_topk_pipelines = 11; static constexpr std::initializer_list topk_moe_early_softmax_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE, @@ -515,6 +516,7 @@ struct vk_device_struct { bool single_queue; bool support_async; uint32_t subgroup_size; + uint32_t subgroup_size_log2; uint32_t shader_core_count; bool uma; bool prefer_host_memory; @@ -704,7 +706,9 @@ struct vk_device_struct { vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16; vk_pipeline pipeline_argsort_f32[num_argsort_pipelines]; vk_pipeline pipeline_argsort_large_f32[num_argsort_pipelines]; + vk_pipeline pipeline_topk_f32[num_topk_pipelines]; vk_pipeline pipeline_sum_rows_f32; + vk_pipeline pipeline_cumsum_f32; vk_pipeline pipeline_argmax_f32; vk_pipeline pipeline_count_equal_i32; vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16; @@ -1204,6 +1208,15 @@ struct vk_op_argsort_push_constants { uint32_t inner_end; }; +struct vk_op_topk_push_constants { + uint32_t orig_ncols; + uint32_t ncols_input; + uint32_t ncols_output; + uint32_t nrows; + uint32_t first_pass; + uint32_t last_pass; +}; + struct vk_op_im2col_push_constants { uint64_t dst_addr; uint32_t batch_offset; uint32_t offset_delta; @@ -3964,10 +3977,29 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline2(device, device->pipeline_argsort_large_f32[i], "argsort_large_f32_"+std::to_string(i), argsort_large_f32_len, argsort_large_f32_data, "main", 3, sizeof(vk_op_argsort_push_constants), {BLOCK_SIZE * WG_UNROLL_FACTOR, 1, 1}, {BLOCK_SIZE, WG_UNROLL_FACTOR}, 1, true); } + for (uint32_t i = 0; i < num_topk_pipelines; ++i) { + const uint32_t BLOCK_SIZE = 1u << i; + const uint32_t NCOLS_PADDED_LOG2 = i; + if (i <= device->max_workgroup_size_log2) { + uint32_t nary_shmem = 2 * sizeof(int) * BLOCK_SIZE + + sizeof(int) * device->subgroup_size + + 2 * sizeof(int) + + (BLOCK_SIZE / device->subgroup_size) * sizeof(int); + if (device->subgroup_arithmetic && device->subgroup_require_full_support && device->subgroup_shuffle && device->subgroup_ballot && + nary_shmem <= device->properties.limits.maxComputeSharedMemorySize) { + ggml_vk_create_pipeline2(device, device->pipeline_topk_f32[i], "topk_f32_"+std::to_string(i), topk_nary_search_f32_len, topk_nary_search_f32_data, "main", 2, sizeof(vk_op_topk_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, device->subgroup_size, device->subgroup_size_log2}, 1, true, true, device->subgroup_size); + } else if (2 * sizeof(int) * BLOCK_SIZE <= device->properties.limits.maxComputeSharedMemorySize) { + ggml_vk_create_pipeline2(device, device->pipeline_topk_f32[i], "topk_f32_"+std::to_string(i), topk_argsort_f32_len, topk_argsort_f32_data, "main", 2, sizeof(vk_op_topk_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, NCOLS_PADDED_LOG2}, 1, true); + } + } + } + ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(device, device->pipeline_cumsum_f32, "cumsum_f32", cumsum_f32_len, cumsum_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { 128, device->subgroup_size }, 1, true, true, device->subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1); #define IM2COL(bda) \ @@ -4333,6 +4365,7 @@ static vk_device ggml_vk_get_device(size_t idx) { device->suballocation_block_size = std::min(device->suballocation_block_size, device->max_memory_allocation_size); device->subgroup_size = subgroup_props.subgroupSize; + device->subgroup_size_log2 = uint32_t(log2f(float(device->subgroup_size))); device->uma = device->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu; if (sm_builtins) { device->shader_core_count = sm_props.shaderSMCount; @@ -8457,6 +8490,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_sum_rows_f32; } return nullptr; + case GGML_OP_CUMSUM: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_cumsum_f32; + } + return nullptr; case GGML_OP_ARGMAX: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) { return ctx->device->pipeline_argmax_f32; @@ -8821,6 +8859,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX_BACK: case GGML_OP_SUM_ROWS: + case GGML_OP_CUMSUM: case GGML_OP_MEAN: case GGML_OP_ARGMAX: { @@ -10134,6 +10173,104 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, c } } +static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { + uint32_t ncols = src0->ne[0]; + uint32_t nrows = ggml_nrows(src0); + uint32_t k = dst->ne[0]; + + vk_op_topk_push_constants pc { ncols, ncols, k, nrows, 0, 0 }; + + // Reserve space for ivec2 per element, double buffered + const size_t dbl_buf_size = size_t{ncols} * nrows * 2 * sizeof(int); + const size_t x_sz = dbl_buf_size * 2; + uint32_t dbl_buf_index = 0; + + if (ctx->prealloc_size_x < x_sz) { + ctx->prealloc_size_x = x_sz; + ggml_vk_preallocate_buffers(ctx, subctx); + } + if (ctx->prealloc_x_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + + std::array elements; + elements[1] = std::min(nrows, ctx->device->properties.limits.maxComputeWorkGroupCount[1]); + elements[2] = 1; + + uint32_t num_elements = ncols; + + // Each iteration reduces a workgroup's worth of elements down to the K + // largest elements. Repeat until we have the top K elements. + // Need to do at least one iteration to write out the results. + bool done_one_iter = false; + while (num_elements > k || !done_one_iter) { + done_one_iter = true; + + // Prefer going as small as num_topk_pipelines - 3 for perf reasons. + // But if K is larger, then we need a larger workgroup + uint32_t max_pipeline = num_topk_pipelines - 3; + uint32_t min_pipeline = (uint32_t)log2f(float(k)) + 1; + // require full subgroup + min_pipeline = std::max(min_pipeline, ctx->device->subgroup_size_log2); + + uint32_t pipeline_idx = (uint32_t)ceilf(log2f(float(num_elements))); + pipeline_idx = std::min(pipeline_idx, max_pipeline); + pipeline_idx = std::max(pipeline_idx, min_pipeline); + + if (num_elements > (1u << pipeline_idx)) { + // If we could finish on this loop iteration (i.e. a single workgroup) + // then do so. It's better than the overhead of another pass. + for (uint32_t i = pipeline_idx; i < num_topk_pipelines; ++i) { + if (num_elements <= (1u << i)) { + pipeline_idx = i; + break; + } + } + } + + vk_pipeline pipeline = ctx->device->pipeline_topk_f32[pipeline_idx]; + // If the device doesn't support a pipeline this large, use smaller + while (!pipeline) { + pipeline_idx--; + GGML_ASSERT(pipeline_idx >= min_pipeline); + pipeline = ctx->device->pipeline_topk_f32[pipeline_idx]; + } + + vk_op_topk_push_constants pc2 = pc; + pc2.ncols_input = num_elements; + + // Number of elements remaining after this pass + uint32_t num_dst_elements = (num_elements / pipeline->wg_denoms[0]) * k + std::min(k, num_elements % pipeline->wg_denoms[0]); + + vk_subbuffer src_buf; + vk_subbuffer dst_buf; + + if (num_elements == ncols) { + pc2.first_pass = 1; + src_buf = ggml_vk_tensor_subbuffer(ctx, src0); + } else { + src_buf = { ctx->prealloc_x, dbl_buf_index * dbl_buf_size, dbl_buf_size }; + } + if (num_dst_elements == k) { + pc2.last_pass = 1; + dst_buf = ggml_vk_tensor_subbuffer(ctx, dst); + } else { + dst_buf = { ctx->prealloc_x, (dbl_buf_index ^ 1) * dbl_buf_size, dbl_buf_size }; + } + + elements[0] = num_elements; + + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src_buf, dst_buf }, pc2, elements); + num_elements = num_dst_elements; + dbl_buf_index ^= 1; + if (num_elements > k) { + ggml_vk_sync_buffers(ctx, subctx); + } + } + ctx->prealloc_x_need_sync = true; +} + static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, ggml_nelements(src0)); ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SUM, p); @@ -10150,6 +10287,11 @@ static void ggml_vk_mean(ggml_backend_vk_context * ctx, vk_context& subctx, cons ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_MEAN, p); } +static void ggml_vk_cumsum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { + vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_CUMSUM, p); +} + static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], 0.0f, 0.0f }); } @@ -11741,6 +11883,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr ggml_vk_argsort(ctx, compute_ctx, src0, node); } + break; + case GGML_OP_TOP_K: + ggml_vk_topk(ctx, compute_ctx, src0, node); + break; case GGML_OP_SUM: ggml_vk_sum(ctx, compute_ctx, src0, node); @@ -11749,6 +11895,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_SUM_ROWS: ggml_vk_sum_rows(ctx, compute_ctx, src0, node); + break; + case GGML_OP_CUMSUM: + ggml_vk_cumsum(ctx, compute_ctx, src0, node); + break; case GGML_OP_MEAN: ggml_vk_mean(ctx, compute_ctx, src0, node); @@ -13008,24 +13158,6 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * return false; }; - // This function tries to reorder the graph to allow nodes to run in parallel. - // This helps with small batches, but for large batches its a slowdown, probably - // due to cache contention. So only reorder if the majority of nodes have few rows. - int num_small_nodes = 0; - int num_counted_nodes = 0; - for (int i = 0; i < graph->n_nodes; ++i) { - if (!is_empty(graph->nodes[i]) && - graph->nodes[i]->op != GGML_OP_SET_ROWS) { - if (ggml_nrows(graph->nodes[i]) <= 8) { - num_small_nodes++; - } - num_counted_nodes++; - } - } - if (num_small_nodes < num_counted_nodes / 2) { - return; - } - std::vector new_order; std::vector used(graph->n_nodes, false); std::set used_node_set; @@ -13769,6 +13901,22 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm return op->ne[0] <= (1 << device->max_workgroup_size_log2); } } + case GGML_OP_TOP_K: + { + if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) { + return false; + } + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + auto device = ggml_vk_get_device(ctx->device); + // We could potentially support larger, using argsort to sort the + // whole thing. Not clear if this is needed. + uint32_t min_pipeline = (uint32_t)log2f(float(op->ne[0])) + 1; + if (min_pipeline >= num_topk_pipelines || + !device->pipeline_topk_f32[min_pipeline]) { + return false; + } + } + return true; case GGML_OP_UPSCALE: case GGML_OP_ACC: case GGML_OP_CONCAT: @@ -13786,6 +13934,15 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_SUM_ROWS: case GGML_OP_MEAN: return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]); + case GGML_OP_CUMSUM: + { + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + auto device = ggml_vk_get_device(ctx->device); + if (device->subgroup_arithmetic && device->subgroup_require_full_support) { + return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]); + } + return false; + } case GGML_OP_ARGMAX: case GGML_OP_COUNT_EQUAL: case GGML_OP_IM2COL: @@ -14432,10 +14589,14 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * tensor_clone = ggml_get_rows(ggml_ctx, src_clone[0], src_clone[1]); } else if (tensor->op == GGML_OP_ARGSORT) { tensor_clone = ggml_argsort(ggml_ctx, src_clone[0], (ggml_sort_order) *(int *)tensor->op_params); + } else if (tensor->op == GGML_OP_TOP_K) { + tensor_clone = ggml_top_k(ggml_ctx, src_clone[0], tensor->ne[0]); } else if (tensor->op == GGML_OP_SUM) { tensor_clone = ggml_sum(ggml_ctx, src_clone[0]); } else if (tensor->op == GGML_OP_SUM_ROWS) { tensor_clone = ggml_sum_rows(ggml_ctx, src_clone[0]); + } else if (tensor->op == GGML_OP_CUMSUM) { + tensor_clone = ggml_cumsum(ggml_ctx, src_clone[0]); } else if (tensor->op == GGML_OP_MEAN) { tensor_clone = ggml_mean(ggml_ctx, src_clone[0]); } else if (tensor->op == GGML_OP_ARGMAX) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp b/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp new file mode 100644 index 0000000000..a4c8fc354e --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp @@ -0,0 +1,69 @@ +#version 450 + +#include "types.glsl" +#include "sum_rows.glsl" + +#extension GL_EXT_control_flow_attributes : enable +#extension GL_KHR_shader_subgroup_arithmetic : enable +#extension GL_KHR_shader_subgroup_basic : enable + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +layout (constant_id = 0) const uint BLOCK_SIZE = 128; +layout (constant_id = 1) const uint SUBGROUP_SIZE = 32; + +#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b)) + +shared FLOAT_TYPE partial[BLOCK_SIZE / SUBGROUP_SIZE]; +shared FLOAT_TYPE last_sum; + +void main() { + const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; + const uint tid = gl_LocalInvocationID.x; + + const uint i03 = fastdiv(row, p.ne0_12mp, p.ne0_12L); + const uint i03_offset = i03 * p.ne01*p.ne02; + const uint i02 = fastdiv(row - i03_offset, p.ne0_1mp, p.ne0_1L); + const uint i01 = row - i03_offset - i02*p.ne01; + + const uint src_idx = get_aoffset() + i01 * p.nb01 + i02 * p.nb02 + i03 * p.nb03; + const uint dst_idx = get_doffset() + i01 * p.nb11 + i02 * p.nb12 + i03 * p.nb13; + + uint subgroup_id = tid / SUBGROUP_SIZE; + + if (tid == 0) { + last_sum = 0; + } + + uint col = tid; + uint num_iter = CEIL_DIV(p.n_cols, BLOCK_SIZE); + for (int i = 0; i < num_iter; ++i) { + FLOAT_TYPE v = 0; + if (col < p.n_cols) { + v = FLOAT_TYPE(data_a[src_idx + col]); + } + v = subgroupInclusiveAdd(v); + + // Store the largest partial sum for each subgroup, then add the partials for all + // lower subgroups and the final partial sum from the previous iteration. + if (gl_SubgroupInvocationID == SUBGROUP_SIZE - 1) { + partial[subgroup_id] = v; + } + barrier(); + for (int j = 0; j < subgroup_id; ++j) { + v += partial[j]; + } + v += last_sum; + barrier(); + if (tid == BLOCK_SIZE - 1) { + last_sum = v; + } + if (col < p.n_cols) { + data_d[dst_idx + col] = D_TYPE(v); + } + col += BLOCK_SIZE; + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp b/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp index bc22aa7bd7..13ba2e99dc 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp @@ -1,6 +1,7 @@ #version 450 #include "types.glsl" +#include "sum_rows.glsl" #extension GL_EXT_control_flow_attributes : enable @@ -11,30 +12,6 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; layout (constant_id = 0) const uint BLOCK_SIZE = 32; -layout (push_constant) uniform parameter -{ - uint n_cols; - uint ne01, ne02; - uint nb01, nb02, nb03; - uint nb11, nb12, nb13; - float weight; - uint misalign_offsets; - uint ne0_12mp, ne0_12L; - uint ne0_1mp, ne0_1L; -} p; - -uint get_aoffset() { return p.misalign_offsets >> 16; } -uint get_doffset() { return p.misalign_offsets & 0xFFFF; } - -// see init_fastdiv_values in ggml-vulkan.cpp -uint fastdiv(uint n, uint mp, uint L) { - uint msbs, lsbs; - // msbs = mulhi(n, mp) - umulExtended(n, mp, msbs, lsbs); - return (msbs + n) >> L; -} - - shared FLOAT_TYPE tmp[BLOCK_SIZE]; void main() { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl new file mode 100644 index 0000000000..2b841baa6b --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl @@ -0,0 +1,25 @@ + +// vk_op_sum_rows_push_constants +layout (push_constant) uniform parameter +{ + uint n_cols; + uint ne01, ne02; + uint nb01, nb02, nb03; + uint nb11, nb12, nb13; + float weight; + uint misalign_offsets; + uint ne0_12mp, ne0_12L; + uint ne0_1mp, ne0_1L; +} p; + +uint get_aoffset() { return p.misalign_offsets >> 16; } +uint get_doffset() { return p.misalign_offsets & 0xFFFF; } + +// see init_fastdiv_values in ggml-vulkan.cpp +uint fastdiv(uint n, uint mp, uint L) { + uint msbs, lsbs; + // msbs = mulhi(n, mp) + umulExtended(n, mp, msbs, lsbs); + return (msbs + n) >> L; +} + diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp b/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp new file mode 100644 index 0000000000..cd858b7d32 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp @@ -0,0 +1,113 @@ +#version 450 +#extension GL_EXT_control_flow_attributes : enable + +#include "types.glsl" + +layout(constant_id = 0) const int BLOCK_SIZE = 1024; +layout(constant_id = 1) const int NCOLS_PADDED_LOG2 = 10; + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +// Input can either be the source (A) or intermediate values (S). +// Similarly, output can be either destination (D) or intermediate values (S). +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 0) readonly buffer S {ivec2 data_s[];}; +layout (binding = 1) writeonly buffer D {int data_d[];}; +layout (binding = 1) writeonly buffer T {ivec2 data_t[];}; + +layout (push_constant) uniform parameter { + uint orig_ncols; + uint ncols_input; + uint ncols_output; + uint nrows; + uint first_pass; + uint last_pass; +} p; + +// pairs of (gid, value) +shared ivec2 dst_row[BLOCK_SIZE]; + +void topk(bool needs_bounds_check, const uint row) { + const int col = int(gl_LocalInvocationID.x); + + // initialize indices + if (gl_GlobalInvocationID.x < p.ncols_input) { + if (p.first_pass != 0) { + const uint row_offset = row * p.ncols_input; + dst_row[col] = ivec2(gl_GlobalInvocationID.x, floatBitsToInt(data_a[row_offset + gl_GlobalInvocationID.x])); + } else { + const uint row_offset = row * p.orig_ncols; + dst_row[col] = data_s[row_offset + gl_GlobalInvocationID.x]; + } + } else { + dst_row[col] = ivec2(p.orig_ncols, 0); + } + barrier(); + + if (p.ncols_output == 1) { + // Fast path for single output - just do a max reduction + [[unroll]] for (int s = BLOCK_SIZE / 2; s >= 1; s /= 2) { + if (col < s) { + ivec2 a = dst_row[col]; + ivec2 b = dst_row[col + s]; + if (a.x >= p.orig_ncols || + b.x < p.orig_ncols && b.y > a.y) { + dst_row[col] = b; + } + } + barrier(); + } + } else { + // bitonic sort on this group of elements + uint num_outer_loop_iters = NCOLS_PADDED_LOG2; + for (uint k = 2, outer_idx = 0; outer_idx < num_outer_loop_iters; k *= 2, outer_idx++) { + uint num_inner_loop_iters = outer_idx + 1; + for (uint j = k / 2, inner_idx = 0; inner_idx < num_inner_loop_iters; j /= 2, inner_idx++) { + const int ixj = int(col ^ j); + + int idx_0 = (col & k) == 0 ? col : ixj; + int idx_1 = (col & k) == 0 ? ixj : col; + + ivec2 sh_idx_0 = dst_row[idx_0]; + ivec2 sh_idx_1 = dst_row[idx_1]; + bool idx_0_oob = needs_bounds_check ? sh_idx_0.x >= p.orig_ncols : false; + bool idx_1_oob = needs_bounds_check ? sh_idx_1.x >= p.orig_ncols : false; + + if ((idx_0_oob || + (!idx_1_oob && intBitsToFloat(sh_idx_0.y) < intBitsToFloat(sh_idx_1.y))) && (ixj > col)) { + dst_row[idx_0] = sh_idx_1; + dst_row[idx_1] = sh_idx_0; + } + + barrier(); + } + } + } + + if (col < p.ncols_output && gl_GlobalInvocationID.x < p.orig_ncols) { + if (p.last_pass != 0) { + const uint row_offset = row * p.ncols_output; + data_d[row_offset + col] = dst_row[col].x; + } else { + const uint row_offset = row * p.orig_ncols + gl_WorkGroupID.x * p.ncols_output; + data_t[row_offset + col] = dst_row[col]; + } + } +} + +void main() { + // Fast path for fully occupied workgroups + if ((p.ncols_input % BLOCK_SIZE) == 0) { + uint row = gl_WorkGroupID.y; + while (row < p.nrows) { + topk(false, row); + row += gl_WorkGroupSize.y * gl_NumWorkGroups.y; + } + } else { + uint row = gl_WorkGroupID.y; + while (row < p.nrows) { + topk(true, row); + row += gl_WorkGroupSize.y * gl_NumWorkGroups.y; + } + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp b/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp new file mode 100644 index 0000000000..c902e60237 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp @@ -0,0 +1,199 @@ +#version 450 +#extension GL_EXT_control_flow_attributes : enable +#extension GL_EXT_debug_printf : enable +#extension GL_KHR_shader_subgroup_basic : enable +#extension GL_KHR_shader_subgroup_ballot : enable +#extension GL_KHR_shader_subgroup_arithmetic : enable +#extension GL_KHR_shader_subgroup_shuffle : enable + +#include "types.glsl" + +layout(constant_id = 0) const int BLOCK_SIZE = 1024; +layout(constant_id = 1) const int SUBGROUP_SIZE = 32; +layout(constant_id = 2) const int SUBGROUP_SIZE_LOG2 = 5; + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +// Input can either be the source (A) or intermediate values (S). +// Similarly, output can be either destination (D) or intermediate values (S). +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 0) readonly buffer S {ivec2 data_s[];}; +layout (binding = 1) writeonly buffer D {int data_d[];}; +layout (binding = 1) writeonly buffer T {ivec2 data_t[];}; + +layout (push_constant) uniform parameter { + uint orig_ncols; + uint ncols_input; + uint ncols_output; + uint nrows; + uint first_pass; + uint last_pass; +} p; + +// pairs of (gid, value) +shared ivec2 dst_row[BLOCK_SIZE]; + +shared int counts[SUBGROUP_SIZE]; +shared int sh_min_idx; +shared uint sh_total; +shared uint offset_partials[BLOCK_SIZE / SUBGROUP_SIZE]; + +// Map float values to uint such that comparisons still work. +// Positive values set the high bit, negative values are inverted. +// +0.0 -> 0x80000000, -0.0 -> 0x7FFFFFFF are in the correct places. +uint f2ui(float x) { + uint y = floatBitsToUint(x); + if ((y & 0x80000000) != 0) { + y ^= ~0; + } else { + y |= 0x80000000; + } + return y; +} + +void topk(const uint row) { + const int tid = int(gl_LocalInvocationID.x); + + // initialize indices + if (gl_GlobalInvocationID.x < p.ncols_input) { + if (p.first_pass != 0) { + const uint row_offset = row * p.ncols_input; + dst_row[tid] = ivec2(gl_GlobalInvocationID.x, floatBitsToInt(data_a[row_offset + gl_GlobalInvocationID.x])); + } else { + const uint row_offset = row * p.orig_ncols; + dst_row[tid] = data_s[row_offset + gl_GlobalInvocationID.x]; + } + } else { + dst_row[tid] = ivec2(p.orig_ncols, 0xFF800000); // -inf + } + barrier(); + + if (p.ncols_output == 1) { + // Fast path for single output - just do a max reduction + [[unroll]] for (int s = BLOCK_SIZE / 2; s >= 1; s /= 2) { + if (tid < s) { + ivec2 a = dst_row[tid]; + ivec2 b = dst_row[tid + s]; + if (a.x >= p.orig_ncols || + b.x < p.orig_ncols && b.y > a.y) { + dst_row[tid] = b; + } + } + barrier(); + } + } else { + // Do an N-ary search to find the K-th largest value. + // We remap the float values to be comparable as unsigned integers, + // and split the range into 2^N smaller ranges where N is the + // subgroup size. Count how many values are in each range, if the K-th + // largest value is in the middle of one of thee ranges then repeat + // and split again. + + // Mask is the current set of bits we're searching. Shift is the LSB index. + int shift = 32 - SUBGROUP_SIZE_LOG2; + uint mask = ((1 << SUBGROUP_SIZE_LOG2) - 1) << shift; + + // The current range. + uint range_min = 0; + uint range_max = 0xFF800000; + // How many are above the current range, and how many we need to find. + uint total = 0; + uint limit = min(p.ncols_output, p.ncols_input - gl_WorkGroupID.x * BLOCK_SIZE); + + while (mask != 0) { + barrier(); + // Initialize bucket counts to zero. + if (tid < SUBGROUP_SIZE) { + counts[tid] = 0; + } + barrier(); + // Count how many values are in each bucket. + if (tid < p.ncols_input) { + float y = intBitsToFloat(dst_row[tid].y); + uint fy = f2ui(y); + if (fy >= range_min && fy < range_max) { + uint bucket = (fy & mask) >> shift; + atomicAdd(counts[bucket], 1); + } + } + barrier(); + + // On the first subgroup, do a scan to count (from the top down) how + // many elements are in the top N buckets. Find the index of the first + // that is over the limit. Copy it to the other invocations through + // shared memory. + if (tid < SUBGROUP_SIZE) { + uint partial_sum = counts[SUBGROUP_SIZE - 1 - tid]; + partial_sum = subgroupInclusiveAdd(partial_sum) + total; + uint t = subgroupBallotFindLSB(subgroupBallot(partial_sum >= limit)); + if (tid == t) { + sh_min_idx = int(SUBGROUP_SIZE - 1 - t); + sh_total = partial_sum; + } + } + barrier(); + int min_idx = sh_min_idx; + total = sh_total; + + // Update the range, and break if we've found the K-th largest. + range_max = range_min + ((min_idx + 1) << shift); + range_min = range_min + (min_idx << shift); + + if (total == p.ncols_output) { + break; + } + total -= counts[min_idx]; + mask >>= SUBGROUP_SIZE_LOG2; + shift -= SUBGROUP_SIZE_LOG2; + if (shift < 0) { + shift = 0; + } + } + + ivec2 v = dst_row[tid]; + + // We need to compact these values to the start of the dst_row array. + // Have each subgroup count how many items it'll store, so other + // subgroups can compute their base offset. + bool top = f2ui(intBitsToFloat(v.y)) >= range_min; + uvec4 b = subgroupBallot(top); + uint bit_count = subgroupBallotBitCount(b); + if ((tid % SUBGROUP_SIZE) == 0) { + offset_partials[tid / SUBGROUP_SIZE] = bit_count; + } + barrier(); + + uint out_idx = 0; + [[unroll]] for (int i = 0; i < BLOCK_SIZE / SUBGROUP_SIZE; ++i) { + if (i < tid / SUBGROUP_SIZE) { + out_idx += offset_partials[i]; + } + } + + uint bit_count_ex = subgroupBallotExclusiveBitCount(b); + if (top) { + // TODO: Copy directly to the output? + dst_row[out_idx + bit_count_ex] = v; + } + + barrier(); + } + + if (tid < p.ncols_output && gl_GlobalInvocationID.x < p.orig_ncols) { + if (p.last_pass != 0) { + const uint row_offset = row * p.ncols_output; + data_d[row_offset + tid] = dst_row[tid].x; + } else { + const uint row_offset = row * p.orig_ncols + gl_WorkGroupID.x * p.ncols_output; + data_t[row_offset + tid] = dst_row[tid]; + } + } +} + +void main() { + uint row = gl_WorkGroupID.y; + while (row < p.nrows) { + topk(row); + row += gl_WorkGroupSize.y * gl_NumWorkGroups.y; + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index bc992068f8..4a802ab1c2 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -913,9 +913,13 @@ void process_shaders() { string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}}); string_to_spv("argsort_large_f32", "argsort_large.comp", {{"A_TYPE", "float"}}); + string_to_spv("topk_argsort_f32", "topk_argsort.comp", {{"A_TYPE", "float"}}); + string_to_spv("topk_nary_search_f32", "topk_nary_search.comp", {{"A_TYPE", "float"}}); + string_to_spv("argmax_f32", "argmax.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "int"}})); string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}})); + string_to_spv("cumsum_f32", "cumsum.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); for (std::string dim_str : {"", "_3d"}) { for (bool bda : {false, true}) { diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index a5846a2393..b99345a2e9 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -990,6 +990,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "ARANGE", "TIMESTEP_EMBEDDING", "ARGSORT", + "TOP_K", "LEAKY_RELU", "TRI", "FILL", @@ -1023,7 +1024,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "GLU", }; -static_assert(GGML_OP_COUNT == 94, "GGML_OP_COUNT != 94"); +static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -1098,6 +1099,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "arange(start, stop, step)", "timestep_embedding(timesteps, dim, max_period)", "argsort(x)", + "top_k(x)", "leaky_relu(x)", "tri(x)", "fill(x, c)", @@ -1131,7 +1133,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "glu(x)", }; -static_assert(GGML_OP_COUNT == 94, "GGML_OP_COUNT != 94"); +static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -5036,28 +5038,6 @@ struct ggml_tensor * ggml_roll( return result; } -// ggml_arange - -struct ggml_tensor * ggml_arange( - struct ggml_context * ctx, - float start, - float stop, - float step) { - GGML_ASSERT(stop > start); - - const int64_t steps = (int64_t) ceilf((stop - start) / step); - - struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, steps); - - ggml_set_op_params_f32(result, 0, start); - ggml_set_op_params_f32(result, 1, stop); - ggml_set_op_params_f32(result, 2, step); - - result->op = GGML_OP_ARANGE; - - return result; -} - // ggml_timestep_embedding struct ggml_tensor * ggml_timestep_embedding( @@ -5139,6 +5119,7 @@ struct ggml_tensor * ggml_argsort( struct ggml_tensor * a, enum ggml_sort_order order) { GGML_ASSERT(a->ne[0] <= INT32_MAX); + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_I32, GGML_MAX_DIMS, a->ne); ggml_set_op_params_i32(result, 0, (int32_t) order); @@ -5149,6 +5130,24 @@ struct ggml_tensor * ggml_argsort( return result; } +// ggml_argsort_top_k + +struct ggml_tensor * ggml_argsort_top_k( + struct ggml_context * ctx, + struct ggml_tensor * a, + int k) { + GGML_ASSERT(a->ne[0] >= k); + + struct ggml_tensor * result = ggml_argsort(ctx, a, GGML_SORT_ORDER_DESC); + + result = ggml_view_4d(ctx, result, + k, result->ne[1], result->ne[2], result->ne[3], + result->nb[1], result->nb[2], result->nb[3], + 0); + + return result; +} + // ggml_top_k struct ggml_tensor * ggml_top_k( @@ -5157,12 +5156,32 @@ struct ggml_tensor * ggml_top_k( int k) { GGML_ASSERT(a->ne[0] >= k); - struct ggml_tensor * result = ggml_argsort(ctx, a, GGML_SORT_ORDER_DESC); + struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_I32, k, a->ne[1], a->ne[2], a->ne[3]); - result = ggml_view_4d(ctx, result, - k, result->ne[1], result->ne[2], result->ne[3], - result->nb[1], result->nb[2], result->nb[3], - 0); + result->op = GGML_OP_TOP_K; + result->src[0] = a; + + return result; +} + +// ggml_arange + +struct ggml_tensor * ggml_arange( + struct ggml_context * ctx, + float start, + float stop, + float step) { + GGML_ASSERT(stop > start); + + const int64_t steps = (int64_t) ceilf((stop - start) / step); + + struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, steps); + + ggml_set_op_params_f32(result, 0, start); + ggml_set_op_params_f32(result, 1, stop); + ggml_set_op_params_f32(result, 2, step); + + result->op = GGML_OP_ARANGE; return result; } diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 642ae2ae59..57ca2035fe 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -4,6 +4,7 @@ import logging import os import shutil import struct +import sys import tempfile from dataclasses import dataclass from enum import Enum, auto @@ -372,8 +373,10 @@ class GGUFWriter: self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None, raw_dtype: GGMLQuantizationType | None = None, ) -> None: - if self.endianess == GGUFEndian.BIG: - tensor.byteswap(inplace=True) + if (self.endianess == GGUFEndian.BIG and sys.byteorder != 'big') or \ + (self.endianess == GGUFEndian.LITTLE and sys.byteorder != 'little'): + # Don't byteswap inplace since lazy copies cannot handle it + tensor = tensor.byteswap(inplace=False) if self.use_temp_file and self.temp_file is None: fp = tempfile.SpooledTemporaryFile(mode="w+b", max_size=256 * 1024 * 1024) fp.seek(0) @@ -399,8 +402,10 @@ class GGUFWriter: raise ValueError(f'Expected output file to contain tensor info or weights, got {self.state}') assert self.fout is not None - if self.endianess == GGUFEndian.BIG: - tensor.byteswap(inplace=True) + if (self.endianess == GGUFEndian.BIG and sys.byteorder != 'big') or \ + (self.endianess == GGUFEndian.LITTLE and sys.byteorder != 'little'): + # Don't byteswap inplace since lazy copies cannot handle it + tensor = tensor.byteswap(inplace=False) file_id = -1 for i, tensors in enumerate(self.tensors): diff --git a/scripts/sync_vendor.py b/scripts/sync_vendor.py index 4a89d08f80..88f45862b6 100755 --- a/scripts/sync_vendor.py +++ b/scripts/sync_vendor.py @@ -16,7 +16,7 @@ vendor = { # "https://github.com/mackron/miniaudio/raw/refs/tags/0.11.23/miniaudio.h": "vendor/miniaudio/miniaudio.h", "https://github.com/mackron/miniaudio/raw/669ed3e844524fcd883231b13095baee9f6de304/miniaudio.h": "vendor/miniaudio/miniaudio.h", - "https://raw.githubusercontent.com/yhirose/cpp-httplib/refs/tags/v0.27.0/httplib.h": "vendor/cpp-httplib/httplib.h", + "https://raw.githubusercontent.com/yhirose/cpp-httplib/refs/tags/v0.28.0/httplib.h": "vendor/cpp-httplib/httplib.h", } for url, filename in vendor.items(): diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 650e40ec6f..1d012e09ab 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -961,14 +961,14 @@ ggml_tensor * llm_graph_context::build_moe_ffn( // organize experts into n_expert_groups ggml_tensor * selection_groups = ggml_reshape_3d(ctx0, selection_probs, n_exp_per_group, hparams.n_expert_groups, n_tokens); // [n_exp_per_group, n_expert_groups, n_tokens] - ggml_tensor * group_scores = ggml_top_k(ctx0, selection_groups, 2); // [2, n_expert_groups, n_tokens] + ggml_tensor * group_scores = ggml_argsort_top_k(ctx0, selection_groups, 2); // [2, n_expert_groups, n_tokens] group_scores = ggml_get_rows(ctx0, ggml_reshape_4d(ctx0, selection_groups, 1, selection_groups->ne[0], selection_groups->ne[1], selection_groups->ne[2]), group_scores); // [1, 2, n_expert_groups, n_tokens] // get top n_group_used expert groups group_scores = ggml_sum_rows(ctx0, ggml_reshape_3d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2], group_scores->ne[3])); // [1, n_expert_groups, n_tokens] group_scores = ggml_reshape_2d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2]); // [n_expert_groups, n_tokens] - ggml_tensor * expert_groups = ggml_top_k(ctx0, group_scores, hparams.n_group_used); // [n_group_used, n_tokens] + ggml_tensor * expert_groups = ggml_argsort_top_k(ctx0, group_scores, hparams.n_group_used); // [n_group_used, n_tokens] cb(expert_groups, "ffn_moe_group_topk", il); // mask out the other groups @@ -979,7 +979,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn( } // select experts - ggml_tensor * selected_experts = ggml_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens] + ggml_tensor * selected_experts = ggml_argsort_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens] cb(selected_experts->src[0], "ffn_moe_argsort", il); cb(selected_experts, "ffn_moe_topk", il); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index fd48d25475..d7ac2bc178 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -39,6 +39,7 @@ #include #include #include +#include static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) { size_t nels = ggml_nelements(tensor); @@ -269,6 +270,34 @@ static double nmse(const float * a, const float * b, size_t n) { return mse_a_b / mse_a_0; } +// difference between 2 integer sets (Jaccard distance, 0 - no difference, 1 - no overlap) +static double jdst(const int32_t * a, const int32_t * b, size_t n) { + std::unordered_map set_a; + std::unordered_map set_b; + + for (size_t i = 0; i < n; ++i) { + set_a[a[i]]++; + set_b[b[i]]++; + } + + size_t diff = 0; + + for (const auto & p : set_a) { + const int64_t na = p.second; + const int64_t nb = set_b.find(p.first) != set_b.end() ? set_b.at(p.first) : 0; + + diff += std::abs(na - nb); + } + + for (const auto & p : set_b) { + if (set_a.find(p.first) == set_a.end()) { + diff += p.second; + } + } + + return (double) diff / (2*n); +} + // maximum absolute asymmetry between a and b // asymmetry: (a - b) / (a + b) // This is more stable than relative error if one of the values fluctuates towards zero. @@ -1051,6 +1080,14 @@ struct test_case { return 1e-4; } + virtual double max_err() { + return max_nmse_err(); + } + + virtual double err(const float * a, const float * b, size_t n) { + return nmse(a, b, n); + } + virtual float grad_eps() { return 1e-1f; } @@ -1257,16 +1294,16 @@ struct test_case { // compare struct callback_userdata { bool ok; - double max_err; + test_case * tc; ggml_backend_t backend1; ggml_backend_t backend2; }; callback_userdata ud { true, - max_nmse_err(), + this, backend1, - backend2 + backend2, }; auto callback = [](int index, ggml_tensor * t1, ggml_tensor * t2, void * user_data) -> bool { @@ -1314,9 +1351,9 @@ struct test_case { } } - double err = nmse(f1.data(), f2.data(), f1.size()); - if (err > ud->max_err) { - printf("[%s] NMSE = %.9f > %.9f ", ggml_op_desc(t1), err, ud->max_err); + double err = ud->tc->err(f1.data(), f2.data(), f1.size()); + if (err > ud->tc->max_err()) { + printf("[%s] ERR = %.9f > %.9f ", ggml_op_desc(t1), err, ud->tc->max_err()); //for (int i = 0; i < (int) f1.size(); i++) { // printf("%5d %9.6f %9.6f, diff = %9.6f\n", i, f1[i], f2[i], f1[i] - f2[i]); //} @@ -4943,7 +4980,71 @@ struct test_argsort : public test_case { } }; -struct test_topk_moe: public test_case { +// GGML_OP_TOP_K +struct test_top_k : public test_case { + const ggml_type type; + const std::array ne; + const int k; + + std::string vars() override { + return VARS_TO_STR3(type, ne, k); + } + + test_top_k(ggml_type type = GGML_TYPE_F32, + std::array ne = {16, 10, 10, 10}, + int k = 4) + : type(type), ne(ne), k(k) {} + + double max_err() override { + return 0.0; + } + + double err(const float * a, const float * b, size_t n) override { + std::vector ia(n); + std::vector ib(n); + + double diff = 0.0f; + + for (size_t i = 0; i < n; i++) { + ia[i] = (int32_t) a[i]; + ib[i] = (int32_t) b[i]; + + // penalize the result if the data is not integer valued + diff += std::fabs(a[i] - ia[i]); + diff += std::fabs(b[i] - ib[i]); + } + + return diff + jdst(ia.data(), ib.data(), n); + } + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_set_name(a, "a"); + + ggml_tensor * out = ggml_top_k(ctx, a, k); + ggml_set_name(out, "out"); + + return out; + } + + void initialize_tensors(ggml_context * ctx) override { + std::random_device rd; + std::default_random_engine rng(rd()); + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { + // initialize with unique values to avoid ties + for (int64_t r = 0; r < ggml_nrows(t); r++) { + std::vector data(t->ne[0]); + for (int i = 0; i < t->ne[0]; i++) { + data[i] = i; + } + std::shuffle(data.begin(), data.end(), rng); + ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(float)); + } + } + } +}; + +struct test_topk_moe : public test_case { const std::array ne; const int n_expert_used; const bool with_norm; @@ -4976,7 +5077,7 @@ struct test_topk_moe: public test_case { ggml_tensor * logits = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne.data()); ggml_tensor * probs = delayed_softmax ? logits : ggml_soft_max(ctx, logits); - ggml_tensor * selected_experts = ggml_top_k(ctx, probs, n_expert_used); // [n_expert_used, n_tokens] + ggml_tensor * selected_experts = ggml_argsort_top_k(ctx, probs, n_expert_used); // [n_expert_used, n_tokens] ggml_tensor * out = ggml_get_rows(ctx, ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens] @@ -7534,6 +7635,31 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2, 8, 8192, 1}, order)); // bailingmoe2 (group selection) } + for (int i = 0; i < 20; ++i) { + for (int k : {1, 2, 3, 7, 15, 100, 500, 1023, 9999}) { + if (k <= 1<> make_test_cases_perf() { } test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {65000, 16, 1, 1})); + for (auto k : {1, 10, 40}) { + for (auto nrows : {1, 16}) { + for (auto cols : {k, 1000, 65000, 200000}) { + test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {cols, nrows, 1, 1}, k)); + } + } + } return test_cases; } diff --git a/vendor/cpp-httplib/CMakeLists.txt b/vendor/cpp-httplib/CMakeLists.txt index 8e0f8064f7..0fa1cd9831 100644 --- a/vendor/cpp-httplib/CMakeLists.txt +++ b/vendor/cpp-httplib/CMakeLists.txt @@ -31,13 +31,16 @@ if (LLAMA_BUILD_BORINGSSL) message(STATUS "Fetching BoringSSL version ${BORINGSSL_VERSION}") - include(FetchContent) - FetchContent_Declare( - boringssl + set(BORINGSSL_ARGS GIT_REPOSITORY ${BORINGSSL_GIT} GIT_TAG ${BORINGSSL_VERSION} - PATCH_COMMAND ${CMAKE_COMMAND} -P "${CMAKE_CURRENT_SOURCE_DIR}/patch-boringssl.cmake" ) + if(CMAKE_VERSION VERSION_GREATER_EQUAL 3.28) + list(APPEND BORINGSSL_ARGS EXCLUDE_FROM_ALL) + endif() + + include(FetchContent) + FetchContent_Declare(boringssl ${BORINGSSL_ARGS}) set(SAVED_BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS}) set(SAVED_BUILD_TESTING ${BUILD_TESTING}) @@ -45,7 +48,15 @@ if (LLAMA_BUILD_BORINGSSL) set(BUILD_SHARED_LIBS OFF) set(BUILD_TESTING OFF) - FetchContent_MakeAvailable(boringssl) + if(CMAKE_VERSION VERSION_GREATER_EQUAL 3.28) + FetchContent_MakeAvailable(boringssl) + else() + FetchContent_GetProperties(boringssl) + if(NOT boringssl_POPULATED) + FetchContent_Populate(boringssl) + add_subdirectory(${boringssl_SOURCE_DIR} ${boringssl_BINARY_DIR} EXCLUDE_FROM_ALL) + endif() + endif() set(BUILD_SHARED_LIBS ${SAVED_BUILD_SHARED_LIBS}) set(BUILD_TESTING ${SAVED_BUILD_TESTING}) diff --git a/vendor/cpp-httplib/httplib.cpp b/vendor/cpp-httplib/httplib.cpp index 5432db69b4..b86e6a2310 100644 --- a/vendor/cpp-httplib/httplib.cpp +++ b/vendor/cpp-httplib/httplib.cpp @@ -1087,22 +1087,30 @@ int getaddrinfo_with_timeout(const char *node, const char *service, // Fallback implementation using thread-based timeout for other Unix systems struct GetAddrInfoState { + ~GetAddrInfoState() { + if (info) { freeaddrinfo(info); } + } + std::mutex mutex; std::condition_variable result_cv; bool completed = false; int result = EAI_SYSTEM; - std::string node = node; - std::string service = service; - struct addrinfo hints = hints; + std::string node; + std::string service; + struct addrinfo hints; struct addrinfo *info = nullptr; }; // Allocate on the heap, so the resolver thread can keep using the data. auto state = std::make_shared(); + state->node = node; + state->service = service; + state->hints = *hints; - std::thread resolve_thread([=]() { - auto thread_result = getaddrinfo( - state->node.c_str(), state->service.c_str(), hints, &state->info); + std::thread resolve_thread([state]() { + auto thread_result = + getaddrinfo(state->node.c_str(), state->service.c_str(), &state->hints, + &state->info); std::lock_guard lock(state->mutex); state->result = thread_result; @@ -1120,6 +1128,7 @@ int getaddrinfo_with_timeout(const char *node, const char *service, // Operation completed within timeout resolve_thread.join(); *res = state->info; + state->info = nullptr; // Pass ownership to caller return state->result; } else { // Timeout occurred @@ -4970,7 +4979,8 @@ bool Server::write_response_core(Stream &strm, bool close_connection, if (need_apply_ranges) { apply_ranges(req, res, content_type, boundary); } // Prepare additional headers - if (close_connection || req.get_header_value("Connection") == "close") { + if (close_connection || req.get_header_value("Connection") == "close" || + 400 <= res.status) { // Don't leave connections open after errors res.set_header("Connection", "close"); } else { std::string s = "timeout="; @@ -5173,7 +5183,11 @@ bool Server::read_content_core( size_t /*len*/) { return receiver(buf, n); }; } - if (req.method == "DELETE" && !req.has_header("Content-Length")) { + // RFC 7230 Section 3.3.3: If this is a request message and none of the above + // are true (no Transfer-Encoding and no Content-Length), then the message + // body length is zero (no message body is present). + if (!req.has_header("Content-Length") && + !detail::is_chunked_transfer_encoding(req.headers)) { return true; } @@ -5681,8 +5695,6 @@ Server::process_request(Stream &strm, const std::string &remote_addr, // Check if the request URI doesn't exceed the limit if (req.target.size() > CPPHTTPLIB_REQUEST_URI_MAX_LENGTH) { - Headers dummy; - detail::read_headers(strm, dummy); res.status = StatusCode::UriTooLong_414; output_error_log(Error::ExceedUriMaxLength, &req); return write_response(strm, close_connection, req, res); @@ -6666,11 +6678,13 @@ bool ClientImpl::write_request(Stream &strm, Request &req, return true; } -std::unique_ptr ClientImpl::send_with_content_provider( +std::unique_ptr +ClientImpl::send_with_content_provider_and_receiver( Request &req, const char *body, size_t content_length, ContentProvider content_provider, ContentProviderWithoutLength content_provider_without_length, - const std::string &content_type, Error &error) { + const std::string &content_type, ContentReceiver content_receiver, + Error &error) { if (!content_type.empty()) { req.set_header("Content-Type", content_type); } #ifdef CPPHTTPLIB_ZLIB_SUPPORT @@ -6743,15 +6757,24 @@ std::unique_ptr ClientImpl::send_with_content_provider( } } + if (content_receiver) { + req.content_receiver = + [content_receiver](const char *data, size_t data_length, + size_t /*offset*/, size_t /*total_length*/) { + return content_receiver(data, data_length); + }; + } + auto res = detail::make_unique(); return send(req, *res, error) ? std::move(res) : nullptr; } -Result ClientImpl::send_with_content_provider( +Result ClientImpl::send_with_content_provider_and_receiver( const std::string &method, const std::string &path, const Headers &headers, const char *body, size_t content_length, ContentProvider content_provider, ContentProviderWithoutLength content_provider_without_length, - const std::string &content_type, UploadProgress progress) { + const std::string &content_type, ContentReceiver content_receiver, + UploadProgress progress) { Request req; req.method = method; req.headers = headers; @@ -6763,9 +6786,10 @@ Result ClientImpl::send_with_content_provider( auto error = Error::Success; - auto res = send_with_content_provider( + auto res = send_with_content_provider_and_receiver( req, body, content_length, std::move(content_provider), - std::move(content_provider_without_length), content_type, error); + std::move(content_provider_without_length), content_type, + std::move(content_receiver), error); #ifdef CPPHTTPLIB_OPENSSL_SUPPORT return Result{std::move(res), error, std::move(req.headers), last_ssl_error_, @@ -7094,6 +7118,15 @@ Result ClientImpl::Post(const std::string &path, size_t content_length, content_type, progress); } +Result ClientImpl::Post(const std::string &path, size_t content_length, + ContentProvider content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return Post(path, Headers(), content_length, std::move(content_provider), + content_type, std::move(content_receiver), progress); +} + Result ClientImpl::Post(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, @@ -7102,6 +7135,15 @@ Result ClientImpl::Post(const std::string &path, progress); } +Result ClientImpl::Post(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return Post(path, Headers(), std::move(content_provider), content_type, + std::move(content_receiver), progress); +} + Result ClientImpl::Post(const std::string &path, const Headers &headers, const Params ¶ms) { auto query = detail::params_to_query_str(params); @@ -7142,17 +7184,18 @@ Result ClientImpl::Post(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress) { - return send_with_content_provider("POST", path, headers, body, content_length, - nullptr, nullptr, content_type, progress); + return send_with_content_provider_and_receiver( + "POST", path, headers, body, content_length, nullptr, nullptr, + content_type, nullptr, progress); } Result ClientImpl::Post(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress) { - return send_with_content_provider("POST", path, headers, body.data(), - body.size(), nullptr, nullptr, content_type, - progress); + return send_with_content_provider_and_receiver( + "POST", path, headers, body.data(), body.size(), nullptr, nullptr, + content_type, nullptr, progress); } Result ClientImpl::Post(const std::string &path, const Headers &headers, @@ -7160,18 +7203,40 @@ Result ClientImpl::Post(const std::string &path, const Headers &headers, ContentProvider content_provider, const std::string &content_type, UploadProgress progress) { - return send_with_content_provider("POST", path, headers, nullptr, - content_length, std::move(content_provider), - nullptr, content_type, progress); + return send_with_content_provider_and_receiver( + "POST", path, headers, nullptr, content_length, + std::move(content_provider), nullptr, content_type, nullptr, progress); +} + +Result ClientImpl::Post(const std::string &path, const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + DownloadProgress progress) { + return send_with_content_provider_and_receiver( + "POST", path, headers, nullptr, content_length, + std::move(content_provider), nullptr, content_type, + std::move(content_receiver), std::move(progress)); } Result ClientImpl::Post(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress) { - return send_with_content_provider("POST", path, headers, nullptr, 0, nullptr, - std::move(content_provider), content_type, - progress); + return send_with_content_provider_and_receiver( + "POST", path, headers, nullptr, 0, nullptr, std::move(content_provider), + content_type, nullptr, progress); +} + +Result ClientImpl::Post(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + DownloadProgress progress) { + return send_with_content_provider_and_receiver( + "POST", path, headers, nullptr, 0, nullptr, std::move(content_provider), + content_type, std::move(content_receiver), std::move(progress)); } Result ClientImpl::Post(const std::string &path, const Headers &headers, @@ -7181,10 +7246,10 @@ Result ClientImpl::Post(const std::string &path, const Headers &headers, const auto &boundary = detail::make_multipart_data_boundary(); const auto &content_type = detail::serialize_multipart_formdata_get_content_type(boundary); - return send_with_content_provider( + return send_with_content_provider_and_receiver( "POST", path, headers, nullptr, 0, nullptr, get_multipart_content_provider(boundary, items, provider_items), - content_type, progress); + content_type, nullptr, progress); } Result ClientImpl::Post(const std::string &path, const Headers &headers, @@ -7246,6 +7311,15 @@ Result ClientImpl::Put(const std::string &path, size_t content_length, content_type, progress); } +Result ClientImpl::Put(const std::string &path, size_t content_length, + ContentProvider content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return Put(path, Headers(), content_length, std::move(content_provider), + content_type, std::move(content_receiver), progress); +} + Result ClientImpl::Put(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, @@ -7254,6 +7328,15 @@ Result ClientImpl::Put(const std::string &path, progress); } +Result ClientImpl::Put(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return Put(path, Headers(), std::move(content_provider), content_type, + std::move(content_receiver), progress); +} + Result ClientImpl::Put(const std::string &path, const Headers &headers, const Params ¶ms) { auto query = detail::params_to_query_str(params); @@ -7294,17 +7377,18 @@ Result ClientImpl::Put(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress) { - return send_with_content_provider("PUT", path, headers, body, content_length, - nullptr, nullptr, content_type, progress); + return send_with_content_provider_and_receiver( + "PUT", path, headers, body, content_length, nullptr, nullptr, + content_type, nullptr, progress); } Result ClientImpl::Put(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress) { - return send_with_content_provider("PUT", path, headers, body.data(), - body.size(), nullptr, nullptr, content_type, - progress); + return send_with_content_provider_and_receiver( + "PUT", path, headers, body.data(), body.size(), nullptr, nullptr, + content_type, nullptr, progress); } Result ClientImpl::Put(const std::string &path, const Headers &headers, @@ -7312,18 +7396,40 @@ Result ClientImpl::Put(const std::string &path, const Headers &headers, ContentProvider content_provider, const std::string &content_type, UploadProgress progress) { - return send_with_content_provider("PUT", path, headers, nullptr, - content_length, std::move(content_provider), - nullptr, content_type, progress); + return send_with_content_provider_and_receiver( + "PUT", path, headers, nullptr, content_length, + std::move(content_provider), nullptr, content_type, nullptr, progress); +} + +Result ClientImpl::Put(const std::string &path, const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return send_with_content_provider_and_receiver( + "PUT", path, headers, nullptr, content_length, + std::move(content_provider), nullptr, content_type, + std::move(content_receiver), progress); } Result ClientImpl::Put(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress) { - return send_with_content_provider("PUT", path, headers, nullptr, 0, nullptr, - std::move(content_provider), content_type, - progress); + return send_with_content_provider_and_receiver( + "PUT", path, headers, nullptr, 0, nullptr, std::move(content_provider), + content_type, nullptr, progress); +} + +Result ClientImpl::Put(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return send_with_content_provider_and_receiver( + "PUT", path, headers, nullptr, 0, nullptr, std::move(content_provider), + content_type, std::move(content_receiver), progress); } Result ClientImpl::Put(const std::string &path, const Headers &headers, @@ -7333,10 +7439,10 @@ Result ClientImpl::Put(const std::string &path, const Headers &headers, const auto &boundary = detail::make_multipart_data_boundary(); const auto &content_type = detail::serialize_multipart_formdata_get_content_type(boundary); - return send_with_content_provider( + return send_with_content_provider_and_receiver( "PUT", path, headers, nullptr, 0, nullptr, get_multipart_content_provider(boundary, items, provider_items), - content_type, progress); + content_type, nullptr, progress); } Result ClientImpl::Put(const std::string &path, const Headers &headers, @@ -7400,6 +7506,15 @@ Result ClientImpl::Patch(const std::string &path, size_t content_length, content_type, progress); } +Result ClientImpl::Patch(const std::string &path, size_t content_length, + ContentProvider content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return Patch(path, Headers(), content_length, std::move(content_provider), + content_type, std::move(content_receiver), progress); +} + Result ClientImpl::Patch(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, @@ -7408,6 +7523,15 @@ Result ClientImpl::Patch(const std::string &path, progress); } +Result ClientImpl::Patch(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return Patch(path, Headers(), std::move(content_provider), content_type, + std::move(content_receiver), progress); +} + Result ClientImpl::Patch(const std::string &path, const Headers &headers, const Params ¶ms) { auto query = detail::params_to_query_str(params); @@ -7448,18 +7572,18 @@ Result ClientImpl::Patch(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress) { - return send_with_content_provider("PATCH", path, headers, body, - content_length, nullptr, nullptr, - content_type, progress); + return send_with_content_provider_and_receiver( + "PATCH", path, headers, body, content_length, nullptr, nullptr, + content_type, nullptr, progress); } Result ClientImpl::Patch(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress) { - return send_with_content_provider("PATCH", path, headers, body.data(), - body.size(), nullptr, nullptr, content_type, - progress); + return send_with_content_provider_and_receiver( + "PATCH", path, headers, body.data(), body.size(), nullptr, nullptr, + content_type, nullptr, progress); } Result ClientImpl::Patch(const std::string &path, const Headers &headers, @@ -7467,18 +7591,40 @@ Result ClientImpl::Patch(const std::string &path, const Headers &headers, ContentProvider content_provider, const std::string &content_type, UploadProgress progress) { - return send_with_content_provider("PATCH", path, headers, nullptr, - content_length, std::move(content_provider), - nullptr, content_type, progress); + return send_with_content_provider_and_receiver( + "PATCH", path, headers, nullptr, content_length, + std::move(content_provider), nullptr, content_type, nullptr, progress); +} + +Result ClientImpl::Patch(const std::string &path, const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return send_with_content_provider_and_receiver( + "PATCH", path, headers, nullptr, content_length, + std::move(content_provider), nullptr, content_type, + std::move(content_receiver), progress); } Result ClientImpl::Patch(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress) { - return send_with_content_provider("PATCH", path, headers, nullptr, 0, nullptr, - std::move(content_provider), content_type, - progress); + return send_with_content_provider_and_receiver( + "PATCH", path, headers, nullptr, 0, nullptr, std::move(content_provider), + content_type, nullptr, progress); +} + +Result ClientImpl::Patch(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return send_with_content_provider_and_receiver( + "PATCH", path, headers, nullptr, 0, nullptr, std::move(content_provider), + content_type, std::move(content_receiver), progress); } Result ClientImpl::Patch(const std::string &path, const Headers &headers, @@ -7488,10 +7634,10 @@ Result ClientImpl::Patch(const std::string &path, const Headers &headers, const auto &boundary = detail::make_multipart_data_boundary(); const auto &content_type = detail::serialize_multipart_formdata_get_content_type(boundary); - return send_with_content_provider( + return send_with_content_provider_and_receiver( "PATCH", path, headers, nullptr, 0, nullptr, get_multipart_content_provider(boundary, items, provider_items), - content_type, progress); + content_type, nullptr, progress); } Result ClientImpl::Patch(const std::string &path, const Headers &headers, @@ -8883,12 +9029,28 @@ Result Client::Post(const std::string &path, size_t content_length, return cli_->Post(path, content_length, std::move(content_provider), content_type, progress); } +Result Client::Post(const std::string &path, size_t content_length, + ContentProvider content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return cli_->Post(path, content_length, std::move(content_provider), + content_type, std::move(content_receiver), progress); +} Result Client::Post(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress) { return cli_->Post(path, std::move(content_provider), content_type, progress); } +Result Client::Post(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return cli_->Post(path, std::move(content_provider), content_type, + std::move(content_receiver), progress); +} Result Client::Post(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, @@ -8897,6 +9059,15 @@ Result Client::Post(const std::string &path, const Headers &headers, return cli_->Post(path, headers, content_length, std::move(content_provider), content_type, progress); } +Result Client::Post(const std::string &path, const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + DownloadProgress progress) { + return cli_->Post(path, headers, content_length, std::move(content_provider), + content_type, std::move(content_receiver), progress); +} Result Client::Post(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, @@ -8904,6 +9075,14 @@ Result Client::Post(const std::string &path, const Headers &headers, return cli_->Post(path, headers, std::move(content_provider), content_type, progress); } +Result Client::Post(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + DownloadProgress progress) { + return cli_->Post(path, headers, std::move(content_provider), content_type, + std::move(content_receiver), progress); +} Result Client::Post(const std::string &path, const Params ¶ms) { return cli_->Post(path, params); } @@ -8938,8 +9117,8 @@ Result Client::Post(const std::string &path, const Headers &headers, const std::string &content_type, ContentReceiver content_receiver, DownloadProgress progress) { - return cli_->Post(path, headers, body, content_type, content_receiver, - progress); + return cli_->Post(path, headers, body, content_type, + std::move(content_receiver), progress); } Result Client::Put(const std::string &path) { return cli_->Put(path); } @@ -8976,12 +9155,28 @@ Result Client::Put(const std::string &path, size_t content_length, return cli_->Put(path, content_length, std::move(content_provider), content_type, progress); } +Result Client::Put(const std::string &path, size_t content_length, + ContentProvider content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return cli_->Put(path, content_length, std::move(content_provider), + content_type, std::move(content_receiver), progress); +} Result Client::Put(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress) { return cli_->Put(path, std::move(content_provider), content_type, progress); } +Result Client::Put(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return cli_->Put(path, std::move(content_provider), content_type, + std::move(content_receiver), progress); +} Result Client::Put(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, @@ -8990,6 +9185,15 @@ Result Client::Put(const std::string &path, const Headers &headers, return cli_->Put(path, headers, content_length, std::move(content_provider), content_type, progress); } +Result Client::Put(const std::string &path, const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return cli_->Put(path, headers, content_length, std::move(content_provider), + content_type, std::move(content_receiver), progress); +} Result Client::Put(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, @@ -8997,6 +9201,14 @@ Result Client::Put(const std::string &path, const Headers &headers, return cli_->Put(path, headers, std::move(content_provider), content_type, progress); } +Result Client::Put(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return cli_->Put(path, headers, std::move(content_provider), content_type, + std::move(content_receiver), progress); +} Result Client::Put(const std::string &path, const Params ¶ms) { return cli_->Put(path, params); } @@ -9072,12 +9284,28 @@ Result Client::Patch(const std::string &path, size_t content_length, return cli_->Patch(path, content_length, std::move(content_provider), content_type, progress); } +Result Client::Patch(const std::string &path, size_t content_length, + ContentProvider content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return cli_->Patch(path, content_length, std::move(content_provider), + content_type, std::move(content_receiver), progress); +} Result Client::Patch(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress) { return cli_->Patch(path, std::move(content_provider), content_type, progress); } +Result Client::Patch(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return cli_->Patch(path, std::move(content_provider), content_type, + std::move(content_receiver), progress); +} Result Client::Patch(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, @@ -9086,6 +9314,15 @@ Result Client::Patch(const std::string &path, const Headers &headers, return cli_->Patch(path, headers, content_length, std::move(content_provider), content_type, progress); } +Result Client::Patch(const std::string &path, const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return cli_->Patch(path, headers, content_length, std::move(content_provider), + content_type, std::move(content_receiver), progress); +} Result Client::Patch(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, @@ -9093,6 +9330,14 @@ Result Client::Patch(const std::string &path, const Headers &headers, return cli_->Patch(path, headers, std::move(content_provider), content_type, progress); } +Result Client::Patch(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return cli_->Patch(path, headers, std::move(content_provider), content_type, + std::move(content_receiver), progress); +} Result Client::Patch(const std::string &path, const Params ¶ms) { return cli_->Patch(path, params); } diff --git a/vendor/cpp-httplib/httplib.h b/vendor/cpp-httplib/httplib.h index 083f795036..c9bd9fd86b 100644 --- a/vendor/cpp-httplib/httplib.h +++ b/vendor/cpp-httplib/httplib.h @@ -8,8 +8,8 @@ #ifndef CPPHTTPLIB_HTTPLIB_H #define CPPHTTPLIB_HTTPLIB_H -#define CPPHTTPLIB_VERSION "0.27.0" -#define CPPHTTPLIB_VERSION_NUM "0x001B00" +#define CPPHTTPLIB_VERSION "0.28.0" +#define CPPHTTPLIB_VERSION_NUM "0x001C00" /* * Platform compatibility check @@ -257,6 +257,7 @@ using socklen_t = int; #include #ifdef __linux__ #include +#undef _res // Undefine _res macro to avoid conflicts with user code (#2278) #endif #include #include @@ -1421,14 +1422,18 @@ public: Result Post(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); Result Post(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); Result Post(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Post(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr); Result Post(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Post(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr); Result Post(const std::string &path, const Params ¶ms); Result Post(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr); Result Post(const std::string &path, const Headers &headers); Result Post(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); Result Post(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); Result Post(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Post(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, DownloadProgress progress = nullptr); Result Post(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Post(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, DownloadProgress progress = nullptr); Result Post(const std::string &path, const Headers &headers, const Params ¶ms); Result Post(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr); Result Post(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr); @@ -1439,14 +1444,18 @@ public: Result Put(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); Result Put(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); Result Put(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Put(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr); Result Put(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Put(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr); Result Put(const std::string &path, const Params ¶ms); Result Put(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr); Result Put(const std::string &path, const Headers &headers); Result Put(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); Result Put(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); Result Put(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Put(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr); Result Put(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Put(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr); Result Put(const std::string &path, const Headers &headers, const Params ¶ms); Result Put(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr); Result Put(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr); @@ -1457,14 +1466,18 @@ public: Result Patch(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); Result Patch(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); Result Patch(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Patch(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr); Result Patch(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Patch(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr); Result Patch(const std::string &path, const Params ¶ms); Result Patch(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr); Result Patch(const std::string &path, const Headers &headers, UploadProgress progress = nullptr); Result Patch(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); Result Patch(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); Result Patch(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Patch(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr); Result Patch(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Patch(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr); Result Patch(const std::string &path, const Headers &headers, const Params ¶ms); Result Patch(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr); Result Patch(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr); @@ -1712,17 +1725,19 @@ private: template void setup_redirect_client(ClientType &client); bool handle_request(Stream &strm, Request &req, Response &res, bool close_connection, Error &error); - std::unique_ptr send_with_content_provider( + std::unique_ptr send_with_content_provider_and_receiver( Request &req, const char *body, size_t content_length, ContentProvider content_provider, ContentProviderWithoutLength content_provider_without_length, - const std::string &content_type, Error &error); - Result send_with_content_provider( + const std::string &content_type, ContentReceiver content_receiver, + Error &error); + Result send_with_content_provider_and_receiver( const std::string &method, const std::string &path, const Headers &headers, const char *body, size_t content_length, ContentProvider content_provider, ContentProviderWithoutLength content_provider_without_length, - const std::string &content_type, UploadProgress progress); + const std::string &content_type, ContentReceiver content_receiver, + UploadProgress progress); ContentProviderWithoutLength get_multipart_content_provider( const std::string &boundary, const UploadFormDataItems &items, const FormDataProviderItems &provider_items) const; @@ -1775,14 +1790,18 @@ public: Result Post(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); Result Post(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); Result Post(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Post(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr); Result Post(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Post(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr); Result Post(const std::string &path, const Params ¶ms); Result Post(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr); Result Post(const std::string &path, const Headers &headers); Result Post(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); Result Post(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); Result Post(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Post(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, DownloadProgress progress = nullptr); Result Post(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Post(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, DownloadProgress progress = nullptr); Result Post(const std::string &path, const Headers &headers, const Params ¶ms); Result Post(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr); Result Post(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr); @@ -1793,14 +1812,18 @@ public: Result Put(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); Result Put(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); Result Put(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Put(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr); Result Put(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Put(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr); Result Put(const std::string &path, const Params ¶ms); Result Put(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr); Result Put(const std::string &path, const Headers &headers); Result Put(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); Result Put(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); Result Put(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Put(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr); Result Put(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Put(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr); Result Put(const std::string &path, const Headers &headers, const Params ¶ms); Result Put(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr); Result Put(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr); @@ -1811,14 +1834,18 @@ public: Result Patch(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); Result Patch(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); Result Patch(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Patch(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr); Result Patch(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Patch(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr); Result Patch(const std::string &path, const Params ¶ms); Result Patch(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr); Result Patch(const std::string &path, const Headers &headers); Result Patch(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); Result Patch(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); Result Patch(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Patch(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr); Result Patch(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Patch(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr); Result Patch(const std::string &path, const Headers &headers, const Params ¶ms); Result Patch(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr); Result Patch(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr); diff --git a/vendor/cpp-httplib/patch-boringssl.cmake b/vendor/cpp-httplib/patch-boringssl.cmake deleted file mode 100644 index 2914e1dddb..0000000000 --- a/vendor/cpp-httplib/patch-boringssl.cmake +++ /dev/null @@ -1,6 +0,0 @@ -# Remove bssl -file(READ "CMakeLists.txt" content) -string(REPLACE "add_executable(bssl" "#add_executable(bssl" content "${content}") -string(REPLACE "target_link_libraries(bssl" "#target_link_libraries(bssl" content "${content}") -string(REPLACE "install(TARGETS bssl" "#install(TARGETS bssl" content "${content}") -file(WRITE "CMakeLists.txt" "${content}")