diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index 434023dd22..2857e080b4 100755 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -146,9 +146,7 @@ void ggml_cann_op_unary_gated( unary_op(ctx, acl_src0, acl_dst); GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMul, acl_dst, acl_src1); - ggml_cann_release_resources(ctx, acl_src0, acl_dst); - if(src1) - ggml_cann_release_resources(ctx, acl_src1); + ggml_cann_release_resources(ctx, acl_src0, acl_src1, acl_dst); } /** @@ -894,14 +892,13 @@ static void aclnn_fill_scalar(ggml_backend_cann_context& ctx, float scalar, } /** - * @brief Get or expand a cached float32 tensor filled with a scalar value. + * @brief Get or expand a cached tensor filled with a scalar value. * - * This function manages cached device memory for float32 tensors. If the current + * This function manages cached device memory for tensors. If the current * cache size is insufficient for the requested tensor shape, the old memory will - * be released and new memory will be allocated. The allocated buffer is then - * initialized either with zeros (when @p value == 0.0f) or with the given scalar - * value using CANN operations. Finally, an aclTensor object is created from the - * cached memory and returned. + * be released and new memory will be allocated. The allocated buffer is + * initialized with the given scalar value using CANN operations. + * Finally, an aclTensor object is created from the cached memory and returned. * * @param ctx The CANN backend context that manages device memory. * @param buffer A pointer to the cached device buffer (will be allocated @@ -910,17 +907,19 @@ static void aclnn_fill_scalar(ggml_backend_cann_context& ctx, float scalar, * updated when the cache is expanded. * @param ne The tensor shape array (number of elements in each dimension). * @param nb The stride size for each dimension. + * @param dtype Data type of cached tensor. * @param dims The number of tensor dimensions. * @param value The scalar value used to fill the tensor (supports zero * initialization via memset or arbitrary values via fill_scalar). * @return An aclTensor pointer created from the cached buffer. */ -static aclTensor* get_f32_cache_acl_tensor( +static aclTensor* get_cache_acl_tensor( ggml_backend_cann_context& ctx, void** buffer, int64_t &cache_element, int64_t* ne, size_t* nb, + ggml_type dtype, int64_t dims, float value) { // Calculate total number of elements @@ -928,7 +927,7 @@ static aclTensor* get_f32_cache_acl_tensor( for (int i = 0; i < dims; i++) { n_element *= ne[i]; } - size_t size = n_element * sizeof(float); + size_t size = n_element * ggml_type_size(dtype); // Allocate or expand cache if needed if (cache_element < n_element) { @@ -941,19 +940,17 @@ static aclTensor* get_f32_cache_acl_tensor( cache_element = n_element; // Initialize cache - if (value == 0.0f) { - ACL_CHECK(aclrtMemsetAsync(*buffer, size, 0, size, ctx.stream())); - } else { - int64_t pool_ne[1] = { n_element }; - size_t pool_nb[1] = { sizeof(float) }; - aclTensor* acl_value = ggml_cann_create_tensor( - *buffer, ACL_FLOAT, sizeof(float), pool_ne, pool_nb, 1); - aclnn_fill_scalar(ctx, 1, acl_value); - ggml_cann_release_resources(ctx, acl_value); - } + int64_t pool_ne[1] = { n_element }; + size_t pool_nb[1] = { ggml_type_size(dtype) }; + aclTensor* acl_value = ggml_cann_create_tensor( + *buffer, ggml_cann_type_mapping(dtype), ggml_type_size(dtype), + pool_ne, pool_nb, 1); + aclnn_fill_scalar(ctx, value, acl_value); + ggml_cann_release_resources(ctx, acl_value); } - return ggml_cann_create_tensor(*buffer, ACL_FLOAT, sizeof(float), ne, nb, dims); + return ggml_cann_create_tensor(*buffer, ggml_cann_type_mapping(dtype), + ggml_type_size(dtype), ne, nb, dims); } void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) { @@ -965,35 +962,39 @@ void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) { float eps; memcpy(&eps, dst->op_params, sizeof(float)); - // build gamma, one... + // build gamma. size_t acl_gamma_nb[GGML_MAX_DIMS]; - acl_gamma_nb[0] = sizeof(float); + // gamma's type is the same with dst. + acl_gamma_nb[0] = ggml_type_size(dst->type); for (int i = 1; i < GGML_MAX_DIMS; i++) { acl_gamma_nb[i] = acl_gamma_nb[i - 1] * src->ne[i - 1]; } - aclTensor* acl_gamma = get_f32_cache_acl_tensor( + aclTensor* acl_gamma = get_cache_acl_tensor( ctx, &ctx.rms_norm_one_tensor_cache.cache, ctx.rms_norm_one_tensor_cache.size, src->ne, acl_gamma_nb, + dst->type, 1, // dims 1.0f // value ); - // build rstd, zero... + // build rstd. int64_t acl_rstd_ne[] = {src->ne[1], src->ne[2], src->ne[3]}; size_t acl_rstd_nb[GGML_MAX_DIMS - 1]; + // rstd will always be F32. acl_rstd_nb[0] = sizeof(float); for (int i = 1; i < GGML_MAX_DIMS - 1; i++) { acl_rstd_nb[i] = acl_rstd_nb[i - 1] * acl_rstd_ne[i - 1]; } - aclTensor* acl_rstd = get_f32_cache_acl_tensor( + aclTensor* acl_rstd = get_cache_acl_tensor( ctx, &ctx.rms_norm_zero_tensor_cache.cache, ctx.rms_norm_zero_tensor_cache.size, acl_rstd_ne, acl_rstd_nb, + GGML_TYPE_F32, GGML_MAX_DIMS - 1, 0.0f // value ); @@ -1765,33 +1766,35 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) { ggml_tensor* src0 = dst->src[0]; // src ggml_tensor* src1 = dst->src[1]; // index + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); + switch (src0->type) { - case GGML_TYPE_F32: { - aclnn_index_select_4d(ctx, src0->data, src0->ne, src0->nb, - dst->data, dst->ne, dst->nb, - src1, dst->type); - break; - } - case GGML_TYPE_F16: { - aclTensor* acl_src0 = ggml_cann_create_tensor(src0); - ggml_cann_pool_alloc src_buffer_allocator( - ctx.pool(), ggml_nelements(src0) * sizeof(float)); - void* src_trans_buffer = src_buffer_allocator.get(); - size_t src_trans_nb[GGML_MAX_DIMS]; - src_trans_nb[0] = sizeof(float); - for (int i = 1; i < GGML_MAX_DIMS; i++) { - src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1]; + case GGML_TYPE_F16: + case GGML_TYPE_F32: + if(src0->type == dst->type) { + aclnn_index_select_4d(ctx, src0->data, src0->ne, src0->nb, + dst->data, dst->ne, dst->nb, + src1, dst->type); + } else { + aclTensor* acl_src0 = ggml_cann_create_tensor(src0); + ggml_cann_pool_alloc src_buffer_allocator( + ctx.pool(), ggml_nelements(src0) * ggml_element_size(dst)); + void* src_trans_buffer = src_buffer_allocator.get(); + size_t src_trans_nb[GGML_MAX_DIMS]; + src_trans_nb[0] = dst->nb[0]; + for (int i = 1; i < GGML_MAX_DIMS; i++) { + src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1]; + } + aclTensor* src_trans_tensor = ggml_cann_create_tensor( + src_trans_buffer, ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), + src0->ne, src_trans_nb, GGML_MAX_DIMS); + aclnn_cast(ctx, acl_src0, src_trans_tensor, ggml_cann_type_mapping(dst->type)); + aclnn_index_select_4d(ctx, src_trans_buffer, src0->ne, src_trans_nb, + dst->data, dst->ne, dst->nb, + src1, dst->type); + ggml_cann_release_resources(ctx, acl_src0, src_trans_tensor); } - aclTensor* src_trans_tensor = ggml_cann_create_tensor( - src_trans_buffer, ACL_FLOAT, ggml_type_size(dst->type), - src0->ne, src_trans_nb, GGML_MAX_DIMS); - aclnn_cast(ctx, acl_src0, src_trans_tensor, ggml_cann_type_mapping(dst->type)); - aclnn_index_select_4d(ctx, src_trans_buffer, src0->ne, src_trans_nb, - dst->data, dst->ne, dst->nb, - src1, dst->type); - ggml_cann_release_resources(ctx, acl_src0, src_trans_tensor); break; - } case GGML_TYPE_Q8_0: { // add 1 dim for bcast mul. size_t weight_nb[GGML_MAX_DIMS + 1], scale_nb[GGML_MAX_DIMS + 1], @@ -1799,7 +1802,6 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) { int64_t weight_ne[GGML_MAX_DIMS + 1], scale_ne[GGML_MAX_DIMS + 1], *dequant_ne; int64_t scale_offset = 0; - // [3,4,5,64] -> [3,4,5,2,32] weight_ne[0] = QK8_0; weight_ne[1] = src0->ne[0] / QK8_0; @@ -1809,7 +1811,6 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) { weight_ne[i] = src0->ne[i - 1]; weight_nb[i] = weight_nb[i - 1] * weight_ne[i - 1]; } - // [3,4,5,64] -> [3,4,5,2,1] scale_ne[0] = 1; scale_ne[1] = src0->ne[0] / QK8_0; @@ -1819,18 +1820,15 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) { scale_ne[i] = src0->ne[i - 1]; scale_nb[i] = scale_nb[i - 1] * scale_ne[i - 1]; } - // [3,4,5,64] -> [3,4,5,2,32] dequant_ne = weight_ne; - dequant_nb[0] = sizeof(float); + dequant_nb[0] = ggml_type_size(dst->type); for (int i = 1; i < GGML_MAX_DIMS + 1; i++) { dequant_nb[i] = dequant_nb[i - 1] * dequant_ne[i - 1]; } - scale_offset = ggml_nelements(src0) * sizeof(int8_t); ggml_cann_pool_alloc dequant_buffer_allocator( - ctx.pool(), ggml_nelements(src0) * sizeof(float)); - + ctx.pool(), ggml_nelements(src0) * ggml_type_size(dst->type)); aclTensor* acl_weight_tensor = ggml_cann_create_tensor( src0->data, ACL_INT8, sizeof(int8_t), weight_ne, weight_nb, GGML_MAX_DIMS + 1); @@ -1838,22 +1836,20 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) { src0->data, ACL_FLOAT16, sizeof(uint16_t), scale_ne, scale_nb, GGML_MAX_DIMS + 1, ACL_FORMAT_ND, scale_offset); aclTensor* dequant_tensor = ggml_cann_create_tensor( - dequant_buffer_allocator.get(), ACL_FLOAT, sizeof(float), + dequant_buffer_allocator.get(), ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), dequant_ne, dequant_nb, GGML_MAX_DIMS + 1); - aclnn_mul(ctx, acl_weight_tensor, acl_scale_tensor, dequant_tensor); - dequant_nb[0] = sizeof(float); + dequant_nb[0] = ggml_type_size(dst->type); dequant_ne = src0->ne; for (int i = 1; i < GGML_MAX_DIMS; i++) { dequant_nb[i] = dequant_nb[i - 1] * src0->ne[i - 1]; } - aclnn_index_select_4d(ctx, dequant_buffer_allocator.get(), dequant_ne, dequant_nb, dst->data, dst->ne, dst->nb, src1, dst->type); - ggml_cann_release_resources(ctx, dequant_tensor); + ggml_cann_release_resources(ctx, acl_weight_tensor, acl_scale_tensor, dequant_tensor); break; } default: @@ -1965,16 +1961,8 @@ static void ggml_cann_mat_mul_fp(ggml_backend_cann_context& ctx, // Only check env once. static bool weight_to_nz = parse_bool(get_env("GGML_CANN_WEIGHT_NZ").value_or("on")); if (weight_to_nz && is_matmul_weight(weight)) { - int64_t acl_stride[2] = {1, transpose_ne[1]}; - - // Reverse ne. - std::reverse(transpose_ne, transpose_ne + n_dims); - - std::vector storageDims = {transpose_ne[0], transpose_ne[1]}; - - acl_weight_tensor = aclCreateTensor( - transpose_ne, n_dims, ggml_cann_type_mapping(weight->type), acl_stride, - 0, ACL_FORMAT_FRACTAL_NZ, storageDims.data(), 2, weight->data); + acl_weight_tensor = + ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims, ACL_FORMAT_FRACTAL_NZ); } else { acl_weight_tensor = ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims, ACL_FORMAT_ND); @@ -3178,7 +3166,6 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ aclTensor* acl_src0_f16_tensor = nullptr; aclTensor* acl_src1_f16_tensor = nullptr; aclTensor* acl_src2_f16_tensor = nullptr; - aclTensor* acl_dst_f16_tensor = nullptr; // Step 1: cast the src0 (Query) to fp16 if needed ggml_cann_pool_alloc src0_f16_allocator(ctx.pool()); @@ -3216,22 +3203,6 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ acl_src2_f16_tensor = ggml_cann_create_tensor(src2, src2_bsnd_ne, src2_bsnd_nb, GGML_MAX_DIMS); - ggml_cann_pool_alloc out_f16_allocator(ctx.pool()); - void* out_f16_buffer = out_f16_allocator.alloc( - ggml_nelements(dst) * faElemSize); - - int64_t* out_f16_ne = src0_bsnd_ne; - size_t out_f16_nb[GGML_MAX_DIMS]; - out_f16_nb[0] = faElemSize; - for(int i = 1; i < GGML_MAX_DIMS; ++i){ - out_f16_nb[i] = out_f16_nb[i - 1] * out_f16_ne[i - 1]; - } - - acl_dst_f16_tensor = ggml_cann_create_tensor( - out_f16_buffer, faDataType, faElemSize, - out_f16_ne, out_f16_nb, GGML_MAX_DIMS - ); - // Step 3: create the PSEShift tensor if needed // this tensor is considered as mask (f16) in the llama.cpp aclTensor* bcast_pse_tensor = nullptr; @@ -3317,8 +3288,8 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ aclTensor* acl_q_tensor = acl_src0_f16_tensor; aclTensor* acl_k_tensors[] = {acl_src1_f16_tensor}; aclTensor* acl_v_tensors[] = {acl_src2_f16_tensor}; - auto acl_k_tensor_list = aclCreateTensorList(acl_k_tensors, kvTensorNum); - auto acl_v_tensor_list = aclCreateTensorList(acl_v_tensors, kvTensorNum); + aclTensorList* acl_k_tensor_list = aclCreateTensorList(acl_k_tensors, kvTensorNum); + aclTensorList* acl_v_tensor_list = aclCreateTensorList(acl_v_tensors, kvTensorNum); int64_t numHeads = src0->ne[2]; // N int64_t numKeyValueHeads = src1->ne[2]; @@ -3334,8 +3305,29 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ int64_t keyAntiquantMode = 0; int64_t valueAntiquantMode = 0; - // Step 5: launch the FusedInferAttentionScoreV2 kernel. - // Refer to https://gitee.com/ascend/cann-ops-adv/blob/master/docs/FusedInferAttentionScoreV2.md + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); + aclTensor * fa_dst_tensor = nullptr; + aclTensor * acl_dst_tensor = nullptr; + ggml_cann_pool_alloc out_f16_allocator(ctx.pool()); + if (dst->type == GGML_TYPE_F32) { + void* out_f16_buffer = out_f16_allocator.alloc( + ggml_nelements(dst) * faElemSize); + + int64_t* out_f16_ne = src0_bsnd_ne; + size_t out_f16_nb[GGML_MAX_DIMS]; + out_f16_nb[0] = faElemSize; + for(int i = 1; i < GGML_MAX_DIMS; ++i){ + out_f16_nb[i] = out_f16_nb[i - 1] * out_f16_ne[i - 1]; + } + + fa_dst_tensor = ggml_cann_create_tensor( + out_f16_buffer, faDataType, faElemSize, + out_f16_ne, out_f16_nb, GGML_MAX_DIMS + ); + } + else { + fa_dst_tensor = ggml_cann_create_tensor(dst); + } GGML_CANN_CALL_ACLNN_OP(ctx, FusedInferAttentionScoreV2, acl_q_tensor, acl_k_tensor_list, acl_v_tensor_list, // q, k, v @@ -3357,23 +3349,24 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ blockSize, antiquantMode, // blockSize, antiquantMode softmaxLseFlag, // softmaxLseFlag keyAntiquantMode, valueAntiquantMode, // keyAntiqMode, valueAntiqMode - acl_dst_f16_tensor, // attentionOut + fa_dst_tensor, // attentionOut nullptr // softmaxLse ); - // Step 6: post-processing, permute and cast to f32 - aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst); - // TODO: when dst is fp16, don't need cast - aclnn_cast(ctx, acl_dst_f16_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type)); - ggml_cann_release_resources(ctx, acl_src0_f16_tensor, - acl_src1_f16_tensor, - acl_src2_f16_tensor, - acl_dst_f16_tensor, - acl_dst_tensor); - if(src3 != nullptr){ - ggml_cann_release_resources(ctx, bcast_pse_tensor); + if (dst->type == GGML_TYPE_F32) { + // Step 6: post-processing, permute and cast to f32 + aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst); + aclnn_cast(ctx, fa_dst_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type)); } - }else{ + + ggml_cann_release_resources(ctx, acl_src0_f16_tensor, + acl_k_tensor_list, + acl_v_tensor_list, + fa_dst_tensor, + acl_dst_tensor, + bcast_pse_tensor); + + } else { GGML_ABORT("Function is not implemented."); } } diff --git a/ggml/src/ggml-cpu/ggml-cpu-impl.h b/ggml/src/ggml-cpu/ggml-cpu-impl.h index 799e2b1187..713bf85e5a 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-impl.h +++ b/ggml/src/ggml-cpu/ggml-cpu-impl.h @@ -68,7 +68,7 @@ struct ggml_compute_params { #endif // __VXE2__ #endif // __s390x__ && __VEC__ -#if defined(__ARM_FEATURE_SVE) +#if defined(__ARM_FEATURE_SVE) && defined(__linux__) #include #endif diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index eded6eb77e..ba2a36d999 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -689,8 +689,13 @@ bool ggml_is_numa(void) { #endif static void ggml_init_arm_arch_features(void) { -#if defined(__linux__) && defined(__aarch64__) && defined(__ARM_FEATURE_SVE) +#if defined(__aarch64__) && defined(__ARM_FEATURE_SVE) +#if defined(__linux__) ggml_arm_arch_features.sve_cnt = PR_SVE_VL_LEN_MASK & prctl(PR_SVE_GET_VL); +#else + // TODO: add support of SVE for non-linux systems +#error "TODO: SVE is not supported on this platform. To use SVE, sve_cnt needs to be initialized here." +#endif #endif } diff --git a/ggml/src/ggml-cpu/vec.cpp b/ggml/src/ggml-cpu/vec.cpp index b8e37052d3..43dc7537c3 100644 --- a/ggml/src/ggml-cpu/vec.cpp +++ b/ggml/src/ggml-cpu/vec.cpp @@ -463,9 +463,9 @@ ggml_float ggml_vec_cvar_f32(const int n, float * y, const float * x, const floa #endif for (; i < n; ++i) { float val = x[i] - mean; + y[i] = val; val *= val; sum += (ggml_float)val; - y[i] = val; } return sum/n; } diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index f6ec263feb..4eb90cbbb3 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -941,13 +941,6 @@ struct ggml_cuda_graph { std::vector nodes; std::vector params; std::vector ggml_graph_properties; - bool use_cpy_indirection = false; - std::vector cpy_dest_ptrs; - char ** dest_ptrs_d; - int dest_ptrs_size = 0; - // Index to allow each cpy kernel to be aware of it's position within the graph - // relative to other cpy nodes. - int graph_cpynode_index = -1; #endif }; diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index afea5ff8c5..12d5bf7763 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -8,18 +8,16 @@ typedef void (*cpy_kernel_t)(const char * cx, char * cdst); template -static __global__ void cpy_flt(const char * cx, char * cdst_direct, const int ne, +static __global__ void cpy_flt(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, - const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) { + const int nb12, const int nb13) { const int64_t i = blockDim.x*blockIdx.x + threadIdx.x; if (i >= ne) { return; } - char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index]: cdst_direct; - // determine indices i03/i13, i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor // then combine those indices with the corresponding byte offsets to get the total offsets const int64_t i03 = i/(ne00 * ne01 * ne02); @@ -63,18 +61,16 @@ static __device__ void cpy_blck_q_f32(const char * cxi, char * cdsti) { } template -static __global__ void cpy_f32_q(const char * cx, char * cdst_direct, const int ne, +static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, - const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) { + const int nb12, const int nb13) { const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk; if (i >= ne) { return; } - char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index]: cdst_direct; - const int i03 = i/(ne00 * ne01 * ne02); const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01); const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00; @@ -91,18 +87,16 @@ static __global__ void cpy_f32_q(const char * cx, char * cdst_direct, const int } template -static __global__ void cpy_q_f32(const char * cx, char * cdst_direct, const int ne, +static __global__ void cpy_q_f32(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, - const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) { + const int nb12, const int nb13) { const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk; if (i >= ne) { return; } - char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index]: cdst_direct; - const int i03 = i/(ne00 * ne01 * ne02); const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01); const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00; @@ -118,67 +112,47 @@ static __global__ void cpy_q_f32(const char * cx, char * cdst_direct, const int cpy_blck(cx + x_offset, cdst + dst_offset); } -// Copy destination pointers to GPU to be available when pointer indirection is in use - -void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_dest_ptrs, const int host_dest_ptrs_size, cudaStream_t stream) { -#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS) - if (cuda_graph->dest_ptrs_size < host_dest_ptrs_size) { // (re-)allocate GPU memory for destination pointers - CUDA_CHECK(cudaStreamSynchronize(stream)); - if (cuda_graph->dest_ptrs_d != nullptr) { - CUDA_CHECK(cudaFree(cuda_graph->dest_ptrs_d)); - } - CUDA_CHECK(cudaMalloc(&cuda_graph->dest_ptrs_d, host_dest_ptrs_size*sizeof(char *))); - cuda_graph->dest_ptrs_size = host_dest_ptrs_size; - } - // copy destination pointers to GPU - CUDA_CHECK(cudaMemcpyAsync(cuda_graph->dest_ptrs_d, host_dest_ptrs, host_dest_ptrs_size*sizeof(char *), cudaMemcpyHostToDevice, stream)); - cuda_graph->graph_cpynode_index = 0; // reset index -#else - GGML_UNUSED_VARS(cuda_graph, host_dest_ptrs, host_dest_ptrs_size, stream); -#endif -} - template static void ggml_cpy_flt_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; cpy_flt><<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } static void ggml_cpy_f32_q8_0_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { GGML_ASSERT(ne % QK8_0 == 0); const int num_blocks = ne / QK8_0; cpy_f32_q<<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } static void ggml_cpy_q8_0_f32_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { const int num_blocks = ne; cpy_q_f32<<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } static void ggml_cpy_f32_q4_0_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { GGML_ASSERT(ne % QK4_0 == 0); const int num_blocks = ne / QK4_0; cpy_f32_q<<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } static void ggml_cpy_q4_0_f32_cuda( @@ -187,22 +161,22 @@ static void ggml_cpy_q4_0_f32_cuda( const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, - cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { + cudaStream_t stream) { const int num_blocks = ne; cpy_q_f32, QK4_0><<>>( cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, - ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); + ne10, ne11, ne12, nb10, nb11, nb12, nb13); } static void ggml_cpy_f32_q4_1_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { GGML_ASSERT(ne % QK4_1 == 0); const int num_blocks = ne / QK4_1; cpy_f32_q<<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } static void ggml_cpy_q4_1_f32_cuda( @@ -211,22 +185,22 @@ static void ggml_cpy_q4_1_f32_cuda( const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, - cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { + cudaStream_t stream) { const int num_blocks = ne; cpy_q_f32, QK4_1><<>>( cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, - ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); + ne10, ne11, ne12, nb10, nb11, nb12, nb13); } static void ggml_cpy_f32_q5_0_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { GGML_ASSERT(ne % QK5_0 == 0); const int num_blocks = ne / QK5_0; cpy_f32_q<<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } static void ggml_cpy_q5_0_f32_cuda( @@ -235,22 +209,22 @@ static void ggml_cpy_q5_0_f32_cuda( const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, - cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { + cudaStream_t stream) { const int num_blocks = ne; cpy_q_f32, QK5_0><<>>( cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, - ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); + ne10, ne11, ne12, nb10, nb11, nb12, nb13); } static void ggml_cpy_f32_q5_1_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { GGML_ASSERT(ne % QK5_1 == 0); const int num_blocks = ne / QK5_1; cpy_f32_q<<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } static void ggml_cpy_q5_1_f32_cuda( @@ -259,25 +233,25 @@ static void ggml_cpy_q5_1_f32_cuda( const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, - cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { + cudaStream_t stream) { const int num_blocks = ne; cpy_q_f32, QK5_1><<>>( cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, - ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); + ne10, ne11, ne12, nb10, nb11, nb12, nb13); } static void ggml_cpy_f32_iq4_nl_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { GGML_ASSERT(ne % QK4_NL == 0); const int num_blocks = ne / QK4_NL; cpy_f32_q<<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } -void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, ggml_cuda_graph * cuda_graph, const ggml_tensor * src0, ggml_tensor * src1, bool disable_indirection_for_this_node) { +void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1) { const int64_t ne = ggml_nelements(src0); GGML_ASSERT(ne == ggml_nelements(src1)); @@ -311,16 +285,6 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, ggml_cuda_graph * cuda_graph char * src0_ddc = (char *) src0->data; char * src1_ddc = (char *) src1->data; - char ** dest_ptrs_d = nullptr; - int graph_cpynode_index = -1; -#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS) - if (cuda_graph && cuda_graph->use_cpy_indirection && !disable_indirection_for_this_node) { - dest_ptrs_d = cuda_graph->dest_ptrs_d; - graph_cpynode_index = cuda_graph->graph_cpynode_index; - } -#else - GGML_UNUSED(disable_indirection_for_this_node); -#endif if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) { GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1)); #if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY) @@ -329,134 +293,62 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, ggml_cuda_graph * cuda_graph } else #endif // GGML_USE_MUSA && GGML_MUSA_MUDNN_COPY { - if (src0->type == GGML_TYPE_F32) { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); - } else { - CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream)); - } + CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream)); } } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) { - ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) { - ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) { - ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) { ggml_cpy_q4_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, - nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) { - ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) { ggml_cpy_q4_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, - nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) { - ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) { ggml_cpy_q5_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, - nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) { - ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) { - ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) { - ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else { GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__, ggml_type_name(src0->type), ggml_type_name(src1->type)); } -#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS) - if (cuda_graph && cuda_graph->use_cpy_indirection && !disable_indirection_for_this_node) { - cuda_graph->graph_cpynode_index = graph_cpynode_index; - } -#else - GGML_UNUSED(disable_indirection_for_this_node); -#endif - } -void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_cuda_graph * cuda_graph, ggml_tensor * dst) { +void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; - bool disable_indirection = true; - ggml_cuda_cpy(ctx, cuda_graph, src0, dst, disable_indirection); -} - -void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) { - if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) { - // Prioritize CUDA graph compatibility over direct memory copy optimization. - // Using copy kernels here maintains graph indirection support, preventing performance regression from disabled CUDA graphs. - if (src0->type == GGML_TYPE_F32) { - return (void*) cpy_flt>; - } else { - return nullptr; - } - } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { - return (void*) cpy_flt>; - } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) { - return (void*) cpy_flt>; - } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) { - return (void*) cpy_flt>; - } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) { - return (void*) cpy_f32_q; - } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) { - return (void*) cpy_q_f32; - } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) { - return (void*) cpy_f32_q; - } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) { - return (void*) cpy_q_f32, QK4_0>; - } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) { - return (void*) cpy_f32_q; - } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) { - return (void*) cpy_q_f32, QK4_1>; - } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) { - return (void*) cpy_f32_q; - } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) { - return (void*) cpy_q_f32, QK5_0>; - } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) { - return (void*) cpy_f32_q; - } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) { - return (void*) cpy_f32_q; - } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) { - return (void*) cpy_q_f32, QK5_1>; - } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) { - return (void*) cpy_flt>; - } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) { - return (void*) cpy_flt>; - } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) { - return (void*) cpy_flt>; - } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) { - return (void*) cpy_flt>; - } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) { - return (void*) cpy_flt>; - } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) { - return (void*) cpy_flt>; - } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) { - return (void*) cpy_flt>; - } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) { - return (void*) cpy_flt>; - } else { - GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__, - ggml_type_name(src0->type), ggml_type_name(src1->type)); - } + ggml_cuda_cpy(ctx, src0, dst); } diff --git a/ggml/src/ggml-cuda/cpy.cuh b/ggml/src/ggml-cuda/cpy.cuh index 8b2774665f..a7a87d8fcf 100644 --- a/ggml/src/ggml-cuda/cpy.cuh +++ b/ggml/src/ggml-cuda/cpy.cuh @@ -2,10 +2,6 @@ #define CUDA_CPY_BLOCK_SIZE 64 -void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, ggml_cuda_graph * cuda_graph, const ggml_tensor * src0, ggml_tensor * src1, bool disable_indirection = false); +void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1); -void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_cuda_graph * cuda_graph, ggml_tensor * dst); - -void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1); - -void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_dest_ptrs, const int host_dest_ptrs_size, cudaStream_t stream); +void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh index 2efc9cc880..2b60b3bb13 100644 --- a/ggml/src/ggml-cuda/fattn-tile.cuh +++ b/ggml/src/ggml-cuda/fattn-tile.cuh @@ -540,10 +540,12 @@ static __device__ __forceinline__ void flash_attn_tile_iter( KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] = logit_softcap * tanhf(KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]); } - KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] += (ncols2 > 1 || mask) && (!oob_check || i_KQ < k_VKQ_sup) ? - slope*__half2float(mask[j*stride_mask + k_VKQ_0 + i_KQ]) : 0.0f; + if (!oob_check || i_KQ < k_VKQ_sup) { + KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] += (ncols2 > 1 || mask) ? + slope*__half2float(mask[j*stride_mask + k_VKQ_0 + i_KQ]) : 0.0f; - KQ_max_new[jc0] = fmaxf(KQ_max_new[jc0], KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]); + KQ_max_new[jc0] = fmaxf(KQ_max_new[jc0], KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]); + } } KQ_max_new[jc0] = warp_reduce_max(KQ_max_new[jc0]); @@ -581,10 +583,9 @@ static __device__ __forceinline__ void flash_attn_tile_iter( float KQ_sum_add = 0.0f; #pragma unroll for (int i0 = 0; i0 < nbatch_fa; i0 += np*warp_size) { - const float val = expf(KQ_acc[(i0/(np*warp_size))*cpw + jc] - KQ_max[jc]); - if (!oob_check || i0 + (threadIdx.y % np)*warp_size + threadIdx.x < k_VKQ_sup) { - KQ_sum_add += val; - } + const float val = !oob_check || i0 + (threadIdx.y % np)*warp_size + threadIdx.x < k_VKQ_sup ? + expf(KQ_acc[(i0/(np*warp_size))*cpw + jc] - KQ_max[jc]) : 0.0f; + KQ_sum_add += val; tmp[i0/(np*warp_size)][jc1] = val; } KQ_sum[jc] = KQ_sum[jc]*KQ_max_scale + KQ_sum_add; @@ -975,26 +976,6 @@ static __global__ void flash_attn_tile( } } - if (gridDim.y == 1) { -#pragma unroll - for (int jc0 = 0; jc0 < cpw; ++jc0) { -#ifdef FAST_FP16_AVAILABLE - const half2 KQ_sum_jc_inv = make_half2(1.0f/KQ_sum[jc0], 1.0f/KQ_sum[jc0]); -#pragma unroll - for (int i = 0; i < (DVp/2)/warp_size; ++i) { - VKQ[jc0*((DVp/2)/warp_size) + i] *= KQ_sum_jc_inv; - } -#else - const float KQ_sum_jc_inv = 1.0f/KQ_sum[jc0]; -#pragma unroll - for (int i = 0; i < (DVp/2)/warp_size; ++i) { - VKQ[jc0*((DVp/2)/warp_size) + i].x *= KQ_sum_jc_inv; - VKQ[jc0*((DVp/2)/warp_size) + i].y *= KQ_sum_jc_inv; - } -#endif // FAST_FP16_AVAILABLE - } - } - // Write back results: #pragma unroll for (int jc0 = 0; jc0 < cpw; ++jc0) { @@ -1007,6 +988,8 @@ static __global__ void flash_attn_tile( return; } + const float scale = gridDim.y == 1 ? 1.0f/KQ_sum[jc0] : 1.0f; + const int j_dst_unrolled = ((sequence*ne01 + col_Q_0 + j)*ne02 + head0 + c)*gridDim.y + blockIdx.y; #ifdef FAST_FP16_AVAILABLE @@ -1017,6 +1000,8 @@ static __global__ void flash_attn_tile( #pragma unroll for (int i1 = 0; i1 < cpy_ne_D; ++i1) { tmp[i1] = __half22float2(VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size + i1]); + tmp[i1].x *= scale; + tmp[i1].y *= scale; } if (i0 + warp_size*cpy_ne_D <= DV/2 || i0 + threadIdx.x*cpy_ne_D < DV/2) { ggml_cuda_memcpy_1(&dst[j_dst_unrolled*DV + 2*i0 + threadIdx.x*(2*cpy_ne_D)], tmp); @@ -1027,6 +1012,11 @@ static __global__ void flash_attn_tile( #pragma unroll for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) { if (i0 + warp_size*cpy_ne_D <= DV || i0 + threadIdx.x*cpy_ne_D < DV) { +#pragma unroll + for (int i1 = 0; i1 < cpy_ne_D/2; ++i1) { + VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size) + i1].x *= scale; + VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size) + i1].y *= scale; + } ggml_cuda_memcpy_1( &dst[j_dst_unrolled*DV + i0 + threadIdx.x*cpy_ne_D], &VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size)]); diff --git a/ggml/src/ggml-cuda/fattn-vec.cuh b/ggml/src/ggml-cuda/fattn-vec.cuh index 89ab0f1638..e1838fdded 100644 --- a/ggml/src/ggml-cuda/fattn-vec.cuh +++ b/ggml/src/ggml-cuda/fattn-vec.cuh @@ -516,8 +516,8 @@ void ggml_cuda_flash_attn_ext_vec_case_impl(ggml_backend_cuda_context & ctx, ggm const int nthreads = ggml_cuda_fattn_vec_get_nthreads_host(cc); const int nwarps = nthreads / WARP_SIZE; fattn_kernel_t fattn_kernel = flash_attn_ext_vec; - constexpr bool need_f16_K = false; - constexpr bool need_f16_V = false; + const bool need_f16_K = type_K == GGML_TYPE_F16; + const bool need_f16_V = type_V == GGML_TYPE_F16; constexpr size_t nbytes_shared = 0; launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false); } @@ -526,11 +526,6 @@ template void ggml_cuda_flash_attn_ext_vec_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * KQV = dst; const ggml_tensor * Q = dst->src[0]; - const ggml_tensor * K = dst->src[1]; - const ggml_tensor * V = dst->src[2]; - - GGML_ASSERT(K->type == type_K); - GGML_ASSERT(V->type == type_V); float logit_softcap; memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index fe970adaec..7dee032c29 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -116,11 +116,15 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg } } -#define FATTN_VEC_CASE(D, type_K, type_V) \ - if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) { \ - ggml_cuda_flash_attn_ext_vec_case(ctx, dst); \ - return; \ - } \ +#define FATTN_VEC_CASE(D, type_K, type_V) \ + { \ + const bool type_K_okay = K->type == (type_K) || (K->type == GGML_TYPE_F32 && (type_K) == GGML_TYPE_F16); \ + const bool type_V_okay = V->type == (type_V) || (V->type == GGML_TYPE_F32 && (type_V) == GGML_TYPE_F16); \ + if (Q->ne[0] == (D) && type_K_okay && type_V_okay) { \ + ggml_cuda_flash_attn_ext_vec_case(ctx, dst); \ + return; \ + } \ + } \ #define FATTN_VEC_CASES_ALL_D(type_K, type_V) \ FATTN_VEC_CASE( 64, type_K, type_V) \ @@ -247,6 +251,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const #endif // GGML_CUDA_FA_ALL_QUANTS switch (K->type) { + case GGML_TYPE_F32: case GGML_TYPE_F16: break; case GGML_TYPE_Q4_1: @@ -272,7 +277,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const // If Turing tensor cores available, use them: if (turing_mma_available(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40) { if (can_use_vector_kernel) { - if (K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16) { + if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) { if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && K->ne[1] >= 8192)) { return BEST_FATTN_KERNEL_VEC; } @@ -305,7 +310,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const // If there are no tensor cores available, use the generic tile kernel: if (can_use_vector_kernel) { - if (K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16) { + if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) { if (Q->ne[1] == 1) { if (!gqa_opt_applies) { return BEST_FATTN_KERNEL_VEC; diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 099763f5ea..5efd3e5f4e 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2260,13 +2260,13 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, ggml_cuda ggml_cuda_op_set_rows(ctx, dst); break; case GGML_OP_DUP: - ggml_cuda_dup(ctx, cuda_graph, dst); + ggml_cuda_dup(ctx, dst); break; case GGML_OP_CPY: - ggml_cuda_cpy(ctx, cuda_graph, dst->src[0], dst->src[1]); + ggml_cuda_cpy(ctx, dst->src[0], dst->src[1]); break; case GGML_OP_CONT: - ggml_cuda_dup(ctx, cuda_graph, dst); + ggml_cuda_dup(ctx, dst); break; case GGML_OP_ADD: case GGML_OP_ADD1: // TODO: more efficient implementation @@ -2633,11 +2633,10 @@ static void ggml_backend_cuda_synchronize(ggml_backend_t backend) { } #ifdef USE_CUDA_GRAPH -static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_cuda_graph * cuda_graph, const ggml_cgraph * cgraph, cudaStream_t stream, +static bool check_node_graph_compatibility(const ggml_cgraph * cgraph, bool use_cuda_graph) { // Loop over nodes in GGML graph to obtain info needed for CUDA graph - cuda_graph->cpy_dest_ptrs.clear(); const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected"; const std::string gemma3n_per_layer_proj_src1_name = "per_layer_proj"; @@ -2688,33 +2687,11 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_cuda_graph #endif } - if (node->op == GGML_OP_CPY) { - - // Store the pointers which are updated for each token, such that these can be sent - // to the device and accessed using indirection from CUDA graph - cuda_graph->cpy_dest_ptrs.push_back((char *) node->src[1]->data); - - // store a pointer to each copy op CUDA kernel to identify it later - void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]); - if (!ptr) { - use_cuda_graph = false; -#ifndef NDEBUG - GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported copy op\n", __func__); -#endif - } - } - if (!use_cuda_graph) { break; } } - if (use_cuda_graph) { - cuda_graph->use_cpy_indirection = true; - // copy pointers to GPU so they can be accessed via indirection within CUDA graph - ggml_cuda_cpy_dest_ptrs_copy(cuda_graph, cuda_graph->cpy_dest_ptrs.data(), cuda_graph->cpy_dest_ptrs.size(), stream); - } - return use_cuda_graph; } @@ -2733,7 +2710,6 @@ static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_p static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) { if (node->data != graph_node_properties->node_address && - node->op != GGML_OP_CPY && node->op != GGML_OP_VIEW) { return false; } @@ -2754,7 +2730,6 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra for (int i = 0; i < GGML_MAX_SRC; i++) { if (node->src[i] && node->src[i]->data != graph_node_properties->src_address[i] && - node->op != GGML_OP_CPY && node->op != GGML_OP_VIEW ) { return false; @@ -2901,7 +2876,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, } //if rms norm is the B operand, then we don't handle broadcast - if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm->src[1])) { + if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm)) { return false; } @@ -3105,13 +3080,11 @@ static ggml_backend_graph_plan_t ggml_backend_cuda_graph_plan_create(ggml_backen } if (use_cuda_graph) { - use_cuda_graph = check_node_graph_compatibility_and_refresh_copy_ops(cuda_graph, cgraph, cuda_ctx->stream(), use_cuda_graph); + use_cuda_graph = check_node_graph_compatibility(cgraph, use_cuda_graph); } if (use_cuda_graph) { capture_cuda_graph(cuda_ctx, cuda_graph, cgraph); - } else { - cuda_graph->use_cpy_indirection = false; } return cuda_graph; @@ -3138,7 +3111,7 @@ static void ggml_backend_cuda_graph_plan_update(ggml_backend_t backend, ggml_bac bool use_cuda_graph = true; - use_cuda_graph = check_node_graph_compatibility_and_refresh_copy_ops(cuda_graph, cgraph, cuda_ctx->stream(), use_cuda_graph); + use_cuda_graph = check_node_graph_compatibility(cgraph, use_cuda_graph); if (!use_cuda_graph) { if (cuda_graph->instance != nullptr) { @@ -3149,7 +3122,6 @@ static void ggml_backend_cuda_graph_plan_update(ggml_backend_t backend, ggml_bac } cuda_graph->instance = nullptr; cuda_graph->graph = nullptr; - cuda_graph->use_cpy_indirection = false; } if (is_cuda_graph_update_required(cuda_graph, cgraph)) { diff --git a/ggml/src/ggml-cuda/mmf.cu b/ggml/src/ggml-cuda/mmf.cu index 599e085ee9..9e2aaf52d6 100644 --- a/ggml/src/ggml-cuda/mmf.cu +++ b/ggml/src/ggml-cuda/mmf.cu @@ -1,5 +1,7 @@ #include "ggml.h" #include "mmf.cuh" +#include "mmid.cuh" + void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) { GGML_ASSERT( src1->type == GGML_TYPE_F32); @@ -37,6 +39,12 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr const int64_t ids_s0 = ids ? ids->nb[0] / ggml_type_size(ids->type) : 0; const int64_t ids_s1 = ids ? ids->nb[1] / ggml_type_size(ids->type) : 0; + mmf_ids_data ids_info{}; + mmf_ids_data * ids_info_ptr = nullptr; + ggml_cuda_pool_alloc ids_src_compact_dev; + ggml_cuda_pool_alloc ids_dst_compact_dev; + ggml_cuda_pool_alloc expert_bounds_dev; + // For MUL_MAT_ID the memory layout is different than for MUL_MAT: const int64_t ncols_dst = ids ? ne2 : ne1; const int64_t nchannels_dst = ids ? ne1 : ne2; @@ -54,6 +62,33 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr nchannels_y = ids->ne[0]; } + if (ids && ncols_dst > 16) { + const int64_t n_expert_used = ids->ne[0]; + const int64_t n_experts = ne02; + const int64_t n_tokens = ne12; + const int64_t ne_get_rows = n_tokens * n_expert_used; + + ids_src_compact_dev.alloc(ctx.pool(), ne_get_rows); + ids_dst_compact_dev.alloc(ctx.pool(), ne_get_rows); + expert_bounds_dev.alloc(ctx.pool(), n_experts + 1); + + const int si1 = static_cast(ids_s1); + const int sis1 = static_cast(src1->nb[2] / src1->nb[1]); + + GGML_ASSERT(sis1 > 0); + + ggml_cuda_launch_mm_ids_helper(ids_d, ids_src_compact_dev.get(), ids_dst_compact_dev.get(), expert_bounds_dev.get(), + static_cast(n_experts), static_cast(n_tokens), static_cast(n_expert_used), static_cast(ne11), si1, sis1, ctx.stream()); + CUDA_CHECK(cudaGetLastError()); + + ids_info.ids_src_compact = ids_src_compact_dev.get(); + ids_info.ids_dst_compact = ids_dst_compact_dev.get(); + ids_info.expert_bounds_dev = expert_bounds_dev.get(); + ids_info.n_experts = static_cast(n_experts); + ids_info.sis1 = sis1; + ids_info_ptr = &ids_info; + } + switch (src0->type) { case GGML_TYPE_F32: { const float * src0_d = (const float *) src0->data; @@ -61,7 +96,7 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr mul_mat_f_switch_cols_per_block( src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst, ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, - ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream()); + ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr); } break; case GGML_TYPE_F16: { const half2 * src0_d = (const half2 *) src0->data; @@ -69,7 +104,7 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr mul_mat_f_switch_cols_per_block( src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst, ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, - ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream()); + ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr); } break; case GGML_TYPE_BF16: { const nv_bfloat162 * src0_d = (const nv_bfloat162 *) src0->data; @@ -77,7 +112,7 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr mul_mat_f_switch_cols_per_block( src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst, ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, - ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream()); + ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr); } break; default: GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type)); @@ -98,10 +133,9 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const } if (mul_mat_id) { - if (type == GGML_TYPE_F32 && src1_ncols > 32) { + if (src0_ne[1] <= 1024 && src1_ncols > 512) { return false; - } - if ((type == GGML_TYPE_F16 || type == GGML_TYPE_BF16) && src1_ncols > 64) { + } else if(src0_ne[1] > 1024 && src1_ncols > 128) { return false; } } else { diff --git a/ggml/src/ggml-cuda/mmf.cuh b/ggml/src/ggml-cuda/mmf.cuh index a6c3adfcf1..49d5295be0 100644 --- a/ggml/src/ggml-cuda/mmf.cuh +++ b/ggml/src/ggml-cuda/mmf.cuh @@ -7,6 +7,14 @@ using namespace ggml_cuda_mma; #define MMF_ROWS_PER_BLOCK 32 +struct mmf_ids_data { + const int32_t * ids_src_compact = nullptr; + const int32_t * ids_dst_compact = nullptr; + const int32_t * expert_bounds_dev = nullptr; + int n_experts = 0; + int sis1 = 0; +}; + void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst); bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, const int src1_ncols, bool mul_mat_id); @@ -224,6 +232,250 @@ static __global__ void mul_mat_f( #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) } + +//This kernel is for larger batch sizes of mul_mat_id +template +__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1) +static __global__ void mul_mat_f_ids( + const T * __restrict__ x, const float * __restrict__ y, + const int32_t * __restrict__ ids_src_compact, const int32_t * __restrict__ ids_dst_compact, + const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, + const int ncols, const int ncols_dst_total, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst, + const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, + const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, + const uint3 sis1_fd, const uint3 nch_fd) { +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + typedef tile<16, 8, T> tile_A; + typedef tile< 8, 8, T> tile_B; + typedef tile<16, 8, float> tile_C; + + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + constexpr int tile_k_padded = warp_size + 4; + constexpr int ntA = rows_per_block / tile_A::I; + constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I; + + const int row0 = blockIdx.x * rows_per_block; + + const int expert_idx = blockIdx.y; + const int expert_start = expert_bounds[expert_idx]; + const int expert_end = expert_bounds[expert_idx + 1]; + const int ncols_expert = expert_end - expert_start; + + const int tiles_for_expert = (ncols_expert + cols_per_block - 1) / cols_per_block; + const int tile_idx = blockIdx.z; + if (tile_idx >= tiles_for_expert) { + return; + } + + const int col_base = tile_idx * cols_per_block; + + GGML_UNUSED(channel_ratio); + + const int channel_x = expert_idx; + const int sample_dst = 0; + const int sample_x = sample_dst / sample_ratio; + const int sample_y = sample_dst; + + x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row0*stride_row; + y += int64_t(sample_y) *stride_sample_y; + dst += int64_t(sample_dst)*stride_sample_dst; + + const int32_t * ids_src_expert = ids_src_compact + expert_start; + const int32_t * ids_dst_expert = ids_dst_compact + expert_start; + + extern __shared__ char data_mmv[]; + char * compute_base = data_mmv; + + //const float2 * y2 = (const float2 *) y; + + tile_C C[ntA][ntB]; + + T * tile_xy = (T *) compute_base + threadIdx.y*(tile_A::I * tile_k_padded); + + for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) { + tile_A A[ntA][warp_size / tile_A::J]; +#pragma unroll + for (int itA = 0; itA < ntA; ++itA) { +#pragma unroll + for (int i = 0; i < tile_A::I; ++i) { + tile_xy[i*tile_k_padded + threadIdx.x] = x[(itA*tile_A::I + i)*stride_row + col]; + } +#pragma unroll + for (int k0 = 0; k0 < warp_size; k0 += tile_A::J) { + load_ldmatrix(A[itA][k0/tile_A::J], tile_xy + k0, tile_k_padded); + } + } + + if constexpr (std::is_same_v) { + float vals_buf[2][tile_B::I]; + auto gather_tile = [&](int tile_idx_local, float *vals) { +#pragma unroll + for (int j0 = 0; j0 < tile_B::I; ++j0) { + const int j = j0 + tile_idx_local*tile_B::I; + const int global_j = col_base + j; + float val = 0.0f; + if (j < cols_per_block && global_j < ncols_expert) { + const int src_entry = ids_src_expert[global_j]; + const uint2 qrm = fast_div_modulo((uint32_t) src_entry, sis1_fd); + const int token = (int) qrm.x; + const int channel = (int) qrm.y; + if (token < ncols_dst_total) { + val = y[channel*stride_channel_y + token*stride_col_y + col]; + } + } + vals[j0] = val; + } + }; + + gather_tile(0, vals_buf[0]); + + int curr_buf = 0; + int next_buf = 1; +#pragma unroll + for (int itB = 0; itB < ntB; ++itB) { +#pragma unroll + for (int j0 = 0; j0 < tile_B::I; ++j0) { + tile_xy[j0*tile_k_padded + threadIdx.x] = vals_buf[curr_buf][j0]; + } + + if (itB + 1 < ntB) { + gather_tile(itB + 1, vals_buf[next_buf]); + } + +#pragma unroll + for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) { + tile_B B; + load_ldmatrix(B, tile_xy + k0, tile_k_padded); +#pragma unroll + for (int itA = 0; itA < ntA; ++itA) { + mma(C[itA][itB], A[itA][k0/tile_B::J], B); + } + } + + if (itB + 1 < ntB) { + curr_buf ^= 1; + next_buf ^= 1; + } + } + } else if constexpr (std::is_same_v || std::is_same_v) { + float2 vals_buf[2][tile_B::I]; + auto gather_tile = [&](int tile_idx_local, float2 *vals) { +#pragma unroll + for (int j0 = 0; j0 < tile_B::I; ++j0) { + const int j = j0 + tile_idx_local*tile_B::I; + const int global_j = col_base + j; + float2 tmp = make_float2(0.0f, 0.0f); + if (j < cols_per_block && global_j < ncols_expert) { + const int src_entry = ids_src_expert[global_j]; + const uint2 qrm = fast_div_modulo((uint32_t) src_entry, sis1_fd); + const int token = (int) qrm.x; + const int channel = (int) qrm.y; + if (token < ncols_dst_total) { + tmp = *(const float2*) &y[channel*stride_channel_y + 2*(token*stride_col_y + col)]; + } + } + vals[j0] = tmp; + } + }; + + if (ntB > 0) { + gather_tile(0, vals_buf[0]); + } + + int curr_buf = 0; + int next_buf = 1; +#pragma unroll + for (int itB = 0; itB < ntB; ++itB) { +#pragma unroll + for (int j0 = 0; j0 < tile_B::I; ++j0) { + const float2 tmp = vals_buf[curr_buf][j0]; + tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y}; + } + + if (itB + 1 < ntB) { + gather_tile(itB + 1, vals_buf[next_buf]); + } + +#pragma unroll + for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) { + tile_B B; + load_ldmatrix(B, tile_xy + k0, tile_k_padded); +#pragma unroll + for (int itA = 0; itA < ntA; ++itA) { + mma(C[itA][itB], A[itA][k0/tile_B::J], B); + } + } + + if (itB + 1 < ntB) { + curr_buf ^= 1; + next_buf ^= 1; + } + } + } else { + static_assert(std::is_same_v, "unsupported type"); + } + } + + float * buf_iw = (float *) compute_base; + constexpr int kiw = nwarps*rows_per_block + 4; + + if (nwarps > 1) { + __syncthreads(); + } +#pragma unroll + for (int itB = 0; itB < ntB; ++itB) { +#pragma unroll + for (int itA = 0; itA < ntA; ++itA) { +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + const int i = threadIdx.y*rows_per_block + itA*tile_C::I + tile_C::get_i(l); + const int j = itB*tile_C::J + tile_C::get_j(l); + buf_iw[j*kiw + i] = C[itA][itB].x[l]; + } + } + } + + if (nwarps > 1) { + __syncthreads(); + } + +#pragma unroll + for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) { + const int j = j0 + threadIdx.y; + + if (j0 + nwarps > cols_per_block && j >= cols_per_block) { + return; + } + + float sum = 0.0f; + static_assert(rows_per_block == warp_size, "need loop/check"); +#pragma unroll + for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) { + const int i = i0 + threadIdx.x; + + sum += buf_iw[j*kiw + i]; + } + + const int global_j = col_base + j; + if (j < cols_per_block && global_j < ncols_expert && nchannels_dst > 0) { + const int dst_entry = ids_dst_expert[global_j]; + const uint2 qrm = fast_div_modulo((uint32_t) dst_entry, nch_fd); + const int token = (int) qrm.x; + if (token < ncols_dst_total) { + const int slot = (int) qrm.y; + dst[slot*stride_channel_dst + token*stride_col_dst + row0 + threadIdx.x] = sum; + } + } + } +#else + GGML_UNUSED_VARS(x, y, ids_src_compact, ids_dst_compact, expert_bounds, dst, + ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, sis1_fd, nch_fd); + NO_DEVICE_CODE; +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) +} + template static inline void mul_mat_f_switch_ids( const T * x, const float * y, const int32_t * ids, float * dst, @@ -232,13 +484,35 @@ static inline void mul_mat_f_switch_ids( const int64_t stride_col_id, const int64_t stride_row_id, const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, - const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream) { - if (ids) { + const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream, + const mmf_ids_data * ids_data) { + const bool has_ids_data = ids_data && ids_data->ids_src_compact; + + // Use the compact-ids kernel only for larger tiles; for small ncols_dst (< 16) + // we prefer the normal mul_mat_f path with has_ids=true. + if (has_ids_data && ncols_dst > 16) { + const int max_tiles = (int) ((ncols_dst + cols_per_block - 1) / cols_per_block); + if (max_tiles == 0) { + return; + } + dim3 block_nums_ids(block_nums.x, ids_data->n_experts, max_tiles); + + const uint3 sis1_fd = ids_data->sis1 > 0 ? init_fastdiv_values((uint32_t) ids_data->sis1) : make_uint3(0, 0, 1); + const uint3 nch_fd = init_fastdiv_values((uint32_t) nchannels_dst); + + mul_mat_f_ids<<>> + (x, y, ids_data->ids_src_compact, ids_data->ids_dst_compact, ids_data->expert_bounds_dev, dst, + ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, + sis1_fd, nch_fd); + } else if (ids) { const int64_t col_tiles = (ncols_dst + cols_per_block - 1) / cols_per_block; dim3 block_nums_ids = block_nums; block_nums_ids.y *= col_tiles; + mul_mat_f<<>> - (x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + (x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); } else { @@ -258,7 +532,7 @@ void mul_mat_f_cuda( const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, - cudaStream_t stream) { + cudaStream_t stream, const mmf_ids_data * ids_data) { typedef tile<16, 8, T> tile_A; typedef tile< 8, 8, T> tile_B; @@ -290,7 +564,7 @@ void mul_mat_f_cuda( const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine); const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0; const int nbytes_shared_total = nbytes_shared + nbytes_slotmap; - const int64_t grid_y = ids ? nchannels_x : nchannels_dst; // per expert when ids present + const int64_t grid_y = ids ? nchannels_x : nchannels_dst; const dim3 block_nums(nrows_x/rows_per_block, grid_y, nsamples_dst); const dim3 block_dims(warp_size, nwarps_best, 1); @@ -300,49 +574,57 @@ void mul_mat_f_cuda( mul_mat_f_switch_ids( x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, + ids_data); } break; case 2: { mul_mat_f_switch_ids( x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, + ids_data); } break; case 3: { mul_mat_f_switch_ids( x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, + ids_data); } break; case 4: { mul_mat_f_switch_ids( x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, + ids_data); } break; case 5: { mul_mat_f_switch_ids( x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, + ids_data); } break; case 6: { mul_mat_f_switch_ids( x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, + ids_data); } break; case 7: { mul_mat_f_switch_ids( x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, + ids_data); } break; case 8: { mul_mat_f_switch_ids( x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, + ids_data); } break; default: { GGML_ABORT("fatal error"); @@ -361,7 +643,7 @@ static void mul_mat_f_switch_cols_per_block( const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, - cudaStream_t stream) { + cudaStream_t stream, const mmf_ids_data * ids_data) { const int ncols_case = (ids && ncols_dst > 16) ? 16 : ncols_dst; @@ -371,82 +653,82 @@ static void mul_mat_f_switch_cols_per_block( case 1: { mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 2: { mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 3: { mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 4: { mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 5: { mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 6: { mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 7: { mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 8: { mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 9: { mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 10: { mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 11: { mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 12: { mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 13: { mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 14: { mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 15: { mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 16: { mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; default: { GGML_ABORT("fatal error"); @@ -462,7 +744,7 @@ static void mul_mat_f_switch_cols_per_block( const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, \ const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,\ const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, \ - cudaStream_t stream); + cudaStream_t stream, const mmf_ids_data * ids_data); #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) #define DECL_MMF_CASE_EXTERN(ncols_dst) \ diff --git a/ggml/src/ggml-cuda/mmid.cu b/ggml/src/ggml-cuda/mmid.cu new file mode 100644 index 0000000000..3c61e4595a --- /dev/null +++ b/ggml/src/ggml-cuda/mmid.cu @@ -0,0 +1,164 @@ +#include "common.cuh" +#include "mmid.cuh" + +// To reduce shared memory use, store "it" and "iex_used" with 22/10 bits each. +struct mm_ids_helper_store { + uint32_t data; + + __device__ mm_ids_helper_store(const uint32_t it, const uint32_t iex_used) { + data = (it & 0x003FFFFF) | (iex_used << 22); + } + + __device__ uint32_t it() const { + return data & 0x003FFFFF; + } + + __device__ uint32_t iex_used() const { + return data >> 22; + } +}; +static_assert(sizeof(mm_ids_helper_store) == 4, "unexpected size for mm_ids_helper_store"); + +// Helper function for mul_mat_id, converts ids to a more convenient format. +// ids_src1 describes how to permute the flattened column indices of src1 in order to get a compact src1 tensor sorted by expert. +// ids_dst describes the same mapping but for the dst tensor. +// The upper and lower bounds for the ith expert in the compact src1 tensor are stored in expert_bounds[i:i+1]. +template +__launch_bounds__(ggml_cuda_get_physical_warp_size(), 1) +static __global__ void mm_ids_helper( + const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds, + const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1) { + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + const int n_expert_used = n_expert_used_template == 0 ? n_expert_used_var : n_expert_used_template; + const int expert = blockIdx.x; + + extern __shared__ char data_mm_ids_helper[]; + mm_ids_helper_store * store = (mm_ids_helper_store *) data_mm_ids_helper; + + int nex_prev = 0; // Number of columns for experts with a lower index. + int it_compact = 0; // Running index for the compact slice of this expert. + + if constexpr (n_expert_used_template == 0) { + // Generic implementation: + for (int it = 0; it < n_tokens; ++it) { + int iex_used = -1; // The index at which the expert is used, if any. + for (int iex = threadIdx.x; iex < n_expert_used; iex += warp_size) { + const int expert_used = ids[it*si1 + iex]; + nex_prev += expert_used < expert; + if (expert_used == expert) { + iex_used = iex; + } + } + + if (iex_used != -1) { + store[it_compact] = mm_ids_helper_store(it, iex_used); + } + + if (warp_reduce_any(iex_used != -1)) { + it_compact++; + } + } + } else { + // Implementation optimized for specific numbers of experts used: + static_assert(n_expert_used == 6 || warp_size % n_expert_used == 0, "bad n_expert_used"); + const int neu_padded = n_expert_used == 6 ? 8 : n_expert_used; // Padded to next higher power of 2. + for (int it0 = 0; it0 < n_tokens; it0 += warp_size/neu_padded) { + const int it = it0 + threadIdx.x / neu_padded; + + const int iex = threadIdx.x % neu_padded; // The index at which the expert is used, if any. + const int expert_used = (neu_padded == n_expert_used || iex < n_expert_used) && it < n_tokens ? + ids[it*si1 + iex] : INT_MAX; + const int iex_used = expert_used == expert ? iex : -1; + nex_prev += expert_used < expert; + + // Whether the threads at this token position have used the expert: + const int it_compact_add_self = warp_reduce_any(iex_used != -1); + + // Do a scan over threads at lower token positions in warp to get the correct index for writing data: + int it_compact_add_lower = 0; +#pragma unroll + for (int offset = neu_padded; offset < warp_size; offset += neu_padded) { + const int tmp = __shfl_up_sync(0xFFFFFFFF, it_compact_add_self, offset, warp_size); + if (threadIdx.x >= static_cast(offset)) { + it_compact_add_lower += tmp; + } + } + + if (iex_used != -1) { + store[it_compact + it_compact_add_lower] = mm_ids_helper_store(it, iex_used); + } + + // The thread with the highest index in the warp always has the sum over the whole warp, use it to increment all threads: + it_compact += __shfl_sync(0xFFFFFFFF, it_compact_add_lower + it_compact_add_self, warp_size - 1, warp_size); + } + } + nex_prev = warp_reduce_sum(nex_prev); + + for (int itc = threadIdx.x; itc < it_compact; itc += warp_size) { + const mm_ids_helper_store store_it = store[itc]; + const int it = store_it.it(); + const int iex_used = store_it.iex_used(); + ids_src1[nex_prev + itc] = it*sis1 + iex_used % nchannels_y; + ids_dst [nex_prev + itc] = it*n_expert_used + iex_used; + } + + if (threadIdx.x != 0) { + return; + } + + expert_bounds[expert] = nex_prev; + + if (expert < static_cast(gridDim.x) - 1) { + return; + } + + expert_bounds[gridDim.x] = nex_prev + it_compact; +} + +template +static void launch_mm_ids_helper( + const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds, + const int n_experts, const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) { + GGML_ASSERT(n_tokens < (1 << 22) && "too few bits in mm_ids_helper_store"); + GGML_ASSERT(n_expert_used_var < (1 << 10) && "too few bits in mm_ids_helper_store"); + + const int id = ggml_cuda_get_device(); + const int warp_size = ggml_cuda_info().devices[id].warp_size; + const size_t smpbo = ggml_cuda_info().devices[id].smpbo; + CUDA_SET_SHARED_MEMORY_LIMIT(mm_ids_helper, smpbo); + + const dim3 num_blocks(n_experts, 1, 1); + const dim3 block_size(warp_size, 1, 1); + const size_t nbytes_shared = n_tokens*sizeof(mm_ids_helper_store); + GGML_ASSERT(nbytes_shared <= smpbo); + mm_ids_helper<<>> + (ids, ids_src1, ids_dst, expert_bounds, n_tokens, n_expert_used_var, nchannels_y, si1, sis1); +} + +void ggml_cuda_launch_mm_ids_helper( + const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds, + const int n_experts, const int n_tokens, const int n_expert_used, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) { + switch (n_expert_used) { + case 2: + launch_mm_ids_helper< 2>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream); + break; + case 4: + launch_mm_ids_helper< 4>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream); + break; + case 6: + launch_mm_ids_helper< 6>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream); + break; + case 8: + launch_mm_ids_helper< 8>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream); + break; + case 16: + launch_mm_ids_helper<16>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream); + break; + case 32: + launch_mm_ids_helper<32>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream); + break; + default: + launch_mm_ids_helper< 0>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream); + break; + } +} diff --git a/ggml/src/ggml-cuda/mmid.cuh b/ggml/src/ggml-cuda/mmid.cuh new file mode 100644 index 0000000000..ac090aea9e --- /dev/null +++ b/ggml/src/ggml-cuda/mmid.cuh @@ -0,0 +1,5 @@ +#pragma once + +void ggml_cuda_launch_mm_ids_helper( + const int32_t * ids, int32_t * ids_src1, int32_t * ids_dst, int32_t * expert_bounds, + int n_experts, int n_tokens, int n_expert_used, int nchannels_y, int si1, int sis1, cudaStream_t stream); diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index 12bdc629bd..a2c8760abe 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -1,141 +1,6 @@ #include "mmq.cuh" #include "quantize.cuh" - -#include - -// To reduce shared memory use, store "it" and "iex_used" with 22/10 bits each. -struct mmq_ids_helper_store { - uint32_t data; - - __device__ mmq_ids_helper_store(const uint32_t it, const uint32_t iex_used) { - data = (it & 0x003FFFFF) | (iex_used << 22); - } - - __device__ uint32_t it() const { - return data & 0x003FFFFF; - } - - __device__ uint32_t iex_used() const { - return data >> 22; - } -}; -static_assert(sizeof(mmq_ids_helper_store) == 4, "unexpected size for mmq_ids_helper_store"); - -// Helper function for mul_mat_id, converts ids to a more convenient format. -// ids_src1 describes how to permute the flattened column indices of src1 in order to get a compact src1 tensor sorted by expert. -// ids_dst describes the same mapping but for the dst tensor. -// The upper and lower bounds for the ith expert in the compact src1 tensor are stored in expert_bounds[i:i+1]. -template -__launch_bounds__(ggml_cuda_get_physical_warp_size(), 1) -static __global__ void mmq_ids_helper( - const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds, - const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1) { - constexpr int warp_size = ggml_cuda_get_physical_warp_size(); - const int n_expert_used = n_expert_used_template == 0 ? n_expert_used_var : n_expert_used_template; - const int expert = blockIdx.x; - - extern __shared__ char data_mmq_ids_helper[]; - mmq_ids_helper_store * store = (mmq_ids_helper_store *) data_mmq_ids_helper; - - int nex_prev = 0; // Number of columns for experts with a lower index. - int it_compact = 0; // Running index for the compact slice of this expert. - - if constexpr (n_expert_used_template == 0) { - // Generic implementation: - for (int it = 0; it < n_tokens; ++it) { - int iex_used = -1; // The index at which the expert is used, if any. - for (int iex = threadIdx.x; iex < n_expert_used; iex += warp_size) { - const int expert_used = ids[it*si1 + iex]; - nex_prev += expert_used < expert; - if (expert_used == expert) { - iex_used = iex; - } - } - - if (iex_used != -1) { - store[it_compact] = mmq_ids_helper_store(it, iex_used); - } - - if (warp_reduce_any(iex_used != -1)) { - it_compact++; - } - } - } else { - // Implementation optimized for specific numbers of experts used: - static_assert(n_expert_used == 6 || warp_size % n_expert_used == 0, "bad n_expert_used"); - const int neu_padded = n_expert_used == 6 ? 8 : n_expert_used; // Padded to next higher power of 2. - for (int it0 = 0; it0 < n_tokens; it0 += warp_size/neu_padded) { - const int it = it0 + threadIdx.x / neu_padded; - - const int iex = threadIdx.x % neu_padded; // The index at which the expert is used, if any. - const int expert_used = (neu_padded == n_expert_used || iex < n_expert_used) && it < n_tokens ? - ids[it*si1 + iex] : INT_MAX; - const int iex_used = expert_used == expert ? iex : -1; - nex_prev += expert_used < expert; - - // Whether the threads at this token position have used the expert: - const int it_compact_add_self = warp_reduce_any(iex_used != -1); - - // Do a scan over threads at lower token positions in warp to get the correct index for writing data: - int it_compact_add_lower = 0; -#pragma unroll - for (int offset = neu_padded; offset < warp_size; offset += neu_padded) { - const int tmp = __shfl_up_sync(0xFFFFFFFF, it_compact_add_self, offset, warp_size); - if (threadIdx.x >= static_cast(offset)) { - it_compact_add_lower += tmp; - } - } - - if (iex_used != -1) { - store[it_compact + it_compact_add_lower] = mmq_ids_helper_store(it, iex_used); - } - - // The thread with the highest index in the warp always has the sum over the whole warp, use it to increment all threads: - it_compact += __shfl_sync(0xFFFFFFFF, it_compact_add_lower + it_compact_add_self, warp_size - 1, warp_size); - } - } - nex_prev = warp_reduce_sum(nex_prev); - - for (int itc = threadIdx.x; itc < it_compact; itc += warp_size) { - const mmq_ids_helper_store store_it = store[itc]; - const int it = store_it.it(); - const int iex_used = store_it.iex_used(); - ids_src1[nex_prev + itc] = it*sis1 + iex_used % nchannels_y; - ids_dst [nex_prev + itc] = it*n_expert_used + iex_used; - } - - if (threadIdx.x != 0) { - return; - } - - expert_bounds[expert] = nex_prev; - - if (expert < static_cast(gridDim.x) - 1) { - return; - } - - expert_bounds[gridDim.x] = nex_prev + it_compact; -} - -template -static void launch_mmq_ids_helper( - const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds, - const int n_experts, const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) { - GGML_ASSERT(n_tokens < (1 << 22) && "too few bits in mmq_ids_helper_store"); - GGML_ASSERT(n_expert_used_var < (1 << 10) && "too few bits in mmq_ids_helper_store"); - - const int id = ggml_cuda_get_device(); - const int warp_size = ggml_cuda_info().devices[id].warp_size; - const size_t smpbo = ggml_cuda_info().devices[id].smpbo; - CUDA_SET_SHARED_MEMORY_LIMIT(mmq_ids_helper, smpbo); - - const dim3 num_blocks(n_experts, 1, 1); - const dim3 block_size(warp_size, 1, 1); - const size_t nbytes_shared = n_tokens*sizeof(mmq_ids_helper_store); - GGML_ASSERT(nbytes_shared <= smpbo); - mmq_ids_helper<<>> - (ids, ids_src1, ids_dst, expert_bounds, n_tokens, n_expert_used_var, nchannels_y, si1, sis1); -} +#include "mmid.cuh" static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) { switch (args.type_x) { @@ -293,36 +158,8 @@ void ggml_cuda_mul_mat_q( const int si1 = ids->nb[1] / ggml_element_size(ids); const int sis1 = nb12 / nb11; - switch (n_expert_used) { - case 2: - launch_mmq_ids_helper< 2> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(), - ne02, ne12, n_expert_used, ne11, si1, sis1, stream); - break; - case 4: - launch_mmq_ids_helper< 4> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(), - ne02, ne12, n_expert_used, ne11, si1, sis1, stream); - break; - case 6: - launch_mmq_ids_helper< 6> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(), - ne02, ne12, n_expert_used, ne11, si1, sis1, stream); - break; - case 8: - launch_mmq_ids_helper< 8> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(), - ne02, ne12, n_expert_used, ne11, si1, sis1, stream); - break; - case 16: - launch_mmq_ids_helper<16> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(), - ne02, ne12, n_expert_used, ne11, si1, sis1, stream); - break; - case 32: - launch_mmq_ids_helper<32> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(), - ne02, ne12, n_expert_used, ne11, si1, sis1, stream); - break; - default: - launch_mmq_ids_helper< 0> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(), - ne02, ne12, n_expert_used, ne11, si1, sis1, stream); - break; - } + ggml_cuda_launch_mm_ids_helper((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(), + ne02, ne12, n_expert_used, ne11, si1, sis1, stream); CUDA_CHECK(cudaGetLastError()); } diff --git a/ggml/src/ggml-cuda/mmvf.cu b/ggml/src/ggml-cuda/mmvf.cu index 5b21ef05b3..57ab839393 100644 --- a/ggml/src/ggml-cuda/mmvf.cu +++ b/ggml/src/ggml-cuda/mmvf.cu @@ -7,14 +7,14 @@ template static __global__ void mul_mat_vec_f( const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst, const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst, - const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, - const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) { + const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, + const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) { const int row = blockIdx.x; const int channel_dst = blockIdx.y; - const int channel_x = ids ? ids[channel_dst] : channel_dst / channel_ratio; + const int channel_x = ids ? ids[channel_dst] : fastdiv((uint32_t) channel_dst, channel_ratio); const int channel_y = ids ? channel_dst % nchannels_y : channel_dst; const int sample_dst = blockIdx.z; - const int sample_x = sample_dst / sample_ratio; + const int sample_x = fastdiv((uint32_t) sample_dst, sample_ratio); const int sample_y = sample_dst; const int tid = threadIdx.x; @@ -47,8 +47,8 @@ static __global__ void mul_mat_vec_f( #pragma unroll for (int j = 0; j < ncols_dst; ++j) { const float2 tmpy = y2[j*stride_col_y2 + col2]; - sumf[j] += tmpx.x*tmpy.x; - sumf[j] += tmpx.y*tmpy.y; + ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x); + ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y); } } } else if constexpr (std::is_same_v) { @@ -61,8 +61,8 @@ static __global__ void mul_mat_vec_f( #pragma unroll for (int j = 0; j < ncols_dst; ++j) { const float2 tmpy = y2[j*stride_col_y2 + col2]; - sumf[j] += tmpx.x * tmpy.x; - sumf[j] += tmpx.y * tmpy.y; + ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x); + ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y); } } } else { @@ -88,16 +88,32 @@ static __global__ void mul_mat_vec_f( #endif // FP16_AVAILABLE } } else if constexpr (std::is_same_v) { +//TODO: add support for ggml_cuda_mad for hip_bfloat162 +#if defined(GGML_USE_HIP) const int * x2 = (const int *) x; for (int col2 = tid; col2 < ncols2; col2 += block_size) { const int tmpx = x2[col2]; #pragma unroll for (int j = 0; j < ncols_dst; ++j) { const float2 tmpy = y2[j*stride_col_y2 + col2]; - sumf[j] += ggml_cuda_cast(reinterpret_cast(&tmpx)[0]) * tmpy.x; - sumf[j] += ggml_cuda_cast(reinterpret_cast(&tmpx)[1]) * tmpy.y; + const float tmpx0 = ggml_cuda_cast(reinterpret_cast(&tmpx)[0]); + const float tmpx1 = ggml_cuda_cast(reinterpret_cast(&tmpx)[1]); + ggml_cuda_mad(sumf[j], tmpx0, tmpy.x); + ggml_cuda_mad(sumf[j], tmpx1, tmpy.y); } } +#else + const nv_bfloat162 * x2 = (const nv_bfloat162 *) x; + for (int col2 = tid; col2 < ncols2; col2 += block_size) { + const nv_bfloat162 tmpx = x2[col2]; +#pragma unroll + for (int j = 0; j < ncols_dst; ++j) { + const float2 tmpy = y2[j*stride_col_y2 + col2]; + ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x); + ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y); + } + } +#endif } else { static_assert(std::is_same_v, "unsupported type"); } @@ -140,8 +156,8 @@ static void launch_mul_mat_vec_f_cuda( GGML_ASSERT(stride_col_y % 2 == 0); GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0); GGML_ASSERT( nsamples_dst % nsamples_x == 0); - const int64_t channel_ratio = nchannels_dst / nchannels_x; - const int64_t sample_ratio = nsamples_dst / nsamples_x; + const uint3 channel_ratio_fd = ids ? make_uint3(0, 0, 0) : init_fastdiv_values(nchannels_dst / nchannels_x); + const uint3 sample_ratio_fd = init_fastdiv_values(nsamples_dst / nsamples_x); const int device = ggml_cuda_get_device(); const int warp_size = ggml_cuda_info().devices[device].warp_size; @@ -167,50 +183,50 @@ static void launch_mul_mat_vec_f_cuda( case 32: { mul_mat_vec_f<<>> (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); } break; case 64: { mul_mat_vec_f<<>> (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); } break; case 96: { mul_mat_vec_f<<>> (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); } break; case 128: { mul_mat_vec_f<<>> (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); } break; case 160: { mul_mat_vec_f<<>> (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); } break; case 192: { mul_mat_vec_f<<>> (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); } break; case 224: { mul_mat_vec_f<<>> (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); } break; case 256: { mul_mat_vec_f<<>> (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); } break; default: { GGML_ABORT("fatal error"); diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index e23abdda97..866cd2da58 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -268,6 +268,25 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu(ggml_metal_library_t l return res; } +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum(ggml_metal_library_t lib, const ggml_tensor * op) { + assert(op->op == GGML_OP_SUM); + + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_op_sum_%s", ggml_type_name(op->src[0]->type)); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + return res; +} + ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows(ggml_metal_library_t lib, const ggml_tensor * op) { GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type)); @@ -1482,3 +1501,40 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_me return res; } +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_adamw(ggml_metal_library_t lib, const ggml_tensor * op) { + assert(op->op == GGML_OP_OPT_STEP_ADAMW); + + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_opt_step_adamw_%s", ggml_type_name(op->src[0]->type)); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_sgd(ggml_metal_library_t lib, const ggml_tensor * op) { + assert(op->op == GGML_OP_OPT_STEP_SGD); + + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_opt_step_sgd_%s", ggml_type_name(op->src[0]->type)); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + return res; +} diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index 1034e4bbf6..28ae2e1765 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -109,6 +109,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_set_rows (ggml_me ggml_metal_pipeline_t ggml_metal_library_get_pipeline_repeat (ggml_metal_library_t lib, enum ggml_type tsrc); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu (ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op); @@ -134,6 +135,8 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad (ggml_me ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_arange (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_adamw (ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_sgd (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad( ggml_metal_library_t lib, diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 9527973015..c3fe8f4e91 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -656,6 +656,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_OP_COS: case GGML_OP_LOG: return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_SUM: case GGML_OP_SUM_ROWS: case GGML_OP_MEAN: case GGML_OP_SOFT_MAX: @@ -692,7 +693,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te return true; case GGML_OP_FLASH_ATTN_EXT: // for new head sizes, add checks here - if (op->src[0]->ne[0] != 40 && + if (op->src[0]->ne[0] != 32 && + op->src[0]->ne[0] != 40 && op->src[0]->ne[0] != 64 && op->src[0]->ne[0] != 80 && op->src[0]->ne[0] != 96 && @@ -798,6 +800,9 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te return false; }; } + case GGML_OP_OPT_STEP_ADAMW: + case GGML_OP_OPT_STEP_SGD: + return has_simdgroup_reduction; default: return false; } diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index c9dff87305..a448c14f66 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -544,6 +544,10 @@ typedef struct{ float limit; } ggml_metal_kargs_glu; +typedef struct { + uint64_t np; +} ggml_metal_kargs_sum; + typedef struct { int64_t ne00; int64_t ne01; @@ -773,4 +777,12 @@ typedef struct { uint64_t nb01; } ggml_metal_kargs_argmax; +typedef struct { + int64_t np; +} ggml_metal_kargs_opt_step_adamw; + +typedef struct { + int64_t np; +} ggml_metal_kargs_opt_step_sgd; + #endif // GGML_METAL_IMPL diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 5f9370449b..a61ea8fb5a 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -301,6 +301,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { { n_fuse = ggml_metal_op_glu(ctx, idx); } break; + case GGML_OP_SUM: + { + n_fuse = ggml_metal_op_sum(ctx, idx); + } break; case GGML_OP_SUM_ROWS: case GGML_OP_MEAN: { @@ -410,6 +414,14 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { { n_fuse = ggml_metal_op_argmax(ctx, idx); } break; + case GGML_OP_OPT_STEP_ADAMW: + { + n_fuse = ggml_metal_op_opt_step_adamw(ctx, idx); + } break; + case GGML_OP_OPT_STEP_SGD: + { + n_fuse = ggml_metal_op_opt_step_sgd(ctx, idx); + } break; default: { GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(node->op)); @@ -840,6 +852,30 @@ int ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) { return 1; } +int ggml_metal_op_sum(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + const uint64_t n = (uint64_t) ggml_nelements(op->src[0]); + + ggml_metal_kargs_sum args = { + /*.np =*/ n, + }; + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_sum(lib, op); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + + ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, 1, 1, 1); + + return 1; +} + int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) { ggml_tensor * op = ctx->node(idx); @@ -3401,3 +3437,73 @@ int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) { return 1; } + +int ggml_metal_op_opt_step_adamw(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_adamw(lib, op); + + const int64_t np = ggml_nelements(op->src[0]); + ggml_metal_kargs_opt_step_adamw args = { + /*.np =*/ np, + }; + + int ida = 0; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), ida++); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), ida++); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), ida++); + + const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0); + const int64_t n = (np + nth - 1) / nth; + + ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, nth, 1, 1); + + return 1; +} + +int ggml_metal_op_opt_step_sgd(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_sgd(lib, op); + + const int64_t np = ggml_nelements(op->src[0]); + ggml_metal_kargs_opt_step_sgd args = { + /*.np =*/ np, + }; + + int ida = 0; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), ida++); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++); + + const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0); + const int64_t n = (np + nth - 1) / nth; + + ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, nth, 1, 1); + + return 1; +} diff --git a/ggml/src/ggml-metal/ggml-metal-ops.h b/ggml/src/ggml-metal/ggml-metal-ops.h index d4cb944621..f352738698 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.h +++ b/ggml/src/ggml-metal/ggml-metal-ops.h @@ -50,6 +50,7 @@ int ggml_metal_op_scale (ggml_metal_op_t ctx, int idx); int ggml_metal_op_clamp (ggml_metal_op_t ctx, int idx); int ggml_metal_op_unary (ggml_metal_op_t ctx, int idx); int ggml_metal_op_glu (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_sum (ggml_metal_op_t ctx, int idx); int ggml_metal_op_sum_rows (ggml_metal_op_t ctx, int idx); int ggml_metal_op_get_rows (ggml_metal_op_t ctx, int idx); int ggml_metal_op_set_rows (ggml_metal_op_t ctx, int idx); @@ -78,6 +79,8 @@ int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx); int ggml_metal_op_argmax (ggml_metal_op_t ctx, int idx); int ggml_metal_op_argsort (ggml_metal_op_t ctx, int idx); int ggml_metal_op_leaky_relu (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_opt_step_adamw (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_opt_step_sgd (ggml_metal_op_t ctx, int idx); #ifdef __cplusplus } diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index ddc285042d..1029cf8f9a 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1723,6 +1723,24 @@ kernel void kernel_geglu_quick_f32( } } +kernel void kernel_op_sum_f32( + constant ggml_metal_kargs_sum & args, + device const float * src0, + device float * dst, + ushort tiitg[[thread_index_in_threadgroup]]) { + + if (tiitg != 0) { + return; + } + + float acc = 0.0f; + for (ulong i = 0; i < args.np; ++i) { + acc += src0[i]; + } + + dst[0] = acc; +} + template kernel void kernel_sum_rows( constant ggml_metal_kargs_sum_rows & args, @@ -5195,8 +5213,30 @@ kernel void kernel_flash_attn_ext( half, half4, simdgroup_half8x8 //float, float4, simdgroup_float8x8 +#define FA_TYPES_F32 \ + half, half4, simdgroup_half8x8, \ + float, float4x4, simdgroup_float8x8, \ + float, float4x4, simdgroup_float8x8, \ + float, simdgroup_float8x8, \ + float, float2, simdgroup_float8x8, \ + float, float4, simdgroup_float8x8 + //half, half4, simdgroup_half8x8 + typedef decltype(kernel_flash_attn_ext) flash_attn_ext_t; +template [[host_name("kernel_flash_attn_ext_f32_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f32_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f32_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f32_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f32_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f32_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f32_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f32_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f32_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f32_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f32_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; + +template [[host_name("kernel_flash_attn_ext_f16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -5209,6 +5249,7 @@ template [[host_name("kernel_flash_attn_ext_f16_dk256_dv256")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_f16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; #if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_flash_attn_ext_bf16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_bf16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_bf16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_bf16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -5221,6 +5262,7 @@ template [[host_name("kernel_flash_attn_ext_bf16_dk256_dv256")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_bf16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; #endif +template [[host_name("kernel_flash_attn_ext_q4_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -5232,6 +5274,7 @@ template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv128")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q4_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -5243,6 +5286,7 @@ template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv128")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q4_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -5254,6 +5298,7 @@ template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv128")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q5_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -5265,6 +5310,7 @@ template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv128")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q5_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q8_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q8_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q8_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -5800,77 +5846,103 @@ kernel void kernel_flash_attn_ext_vec( float, float4, \ float4 +#define FA_TYPES_F32 \ + half4, \ + float4, \ + float4, \ + float, \ + float, float4, \ + float4 + typedef decltype(kernel_flash_attn_ext_vec) flash_attn_ext_vec_t; -template [[host_name("kernel_flash_attn_ext_vec_f16_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f32_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f16_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #if defined(GGML_METAL_HAS_BF16) -template [[host_name("kernel_flash_attn_ext_vec_bf16_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_bf16_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #endif -template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_f16_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f32_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f16_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #if defined(GGML_METAL_HAS_BF16) -template [[host_name("kernel_flash_attn_ext_vec_bf16_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_bf16_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #endif -template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_f16_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f32_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f16_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #if defined(GGML_METAL_HAS_BF16) -template [[host_name("kernel_flash_attn_ext_vec_bf16_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_bf16_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #endif -template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f32_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f16_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #if defined(GGML_METAL_HAS_BF16) -template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_bf16_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #endif -template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f32_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #if defined(GGML_METAL_HAS_BF16) -template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #endif -template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_f16_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f32_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #if defined(GGML_METAL_HAS_BF16) -template [[host_name("kernel_flash_attn_ext_vec_bf16_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #endif -template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_f16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f32_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f16_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #if defined(GGML_METAL_HAS_BF16) -template [[host_name("kernel_flash_attn_ext_vec_bf16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_bf16_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #endif -template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; + +template [[host_name("kernel_flash_attn_ext_vec_f32_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_flash_attn_ext_vec_bf16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#endif +template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #undef FA_TYPES @@ -8754,3 +8826,51 @@ kernel void kernel_pool_2d_avg_f32( o_ptr[cur_oh * args.OW + cur_ow] = res; } + +kernel void kernel_opt_step_adamw_f32( + constant ggml_metal_kargs_opt_step_adamw & args, + device float * x, + device const float * g, + device float * g_m, + device float * g_v, + device const float * pars, + uint gid[[thread_position_in_grid]]) { + + if (gid >= args.np) { + return; + } + + const float alpha = pars[0]; + const float beta1 = pars[1]; + const float beta2 = pars[2]; + const float eps = pars[3]; + const float wd = pars[4]; + const float beta1h = pars[5]; + const float beta2h = pars[6]; + + const float gi = g[gid]; + const float gmi = g_m[gid] * beta1 + gi * (1.0f - beta1); + const float gvi = g_v[gid] * beta2 + gi * gi * (1.0f - beta2); + + g_m[gid] = gmi; + g_v[gid] = gvi; + + const float mh = gmi * beta1h; + const float vh = sqrt(gvi * beta2h) + eps; + + x[gid] = x[gid] * (1.0f - alpha * wd) - alpha * mh / vh; +} + +kernel void kernel_opt_step_sgd_f32( + constant ggml_metal_kargs_opt_step_sgd & args, + device float * x, + device const float * g, + device const float * pars, + uint gid[[thread_position_in_grid]]) { + + if (gid >= args.np) { + return; + } + + x[gid] = x[gid] * (1.0f - pars[0] * pars[1]) - pars[0] * g[gid]; +} diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 79d2148744..0693d38d80 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -2348,8 +2348,13 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { svm_caps & CL_DEVICE_SVM_ATOMICS ? "true" : "false"); if (opencl_c_version.major >= 3) { + // Assume it is not available for 3.0, since it is optional in 3.0. + // If compiling against 3.0, then we can query. + backend_ctx->non_uniform_workgroups = false; +#if CL_TARGET_OPENCL_VERSION >= 300 CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_NON_UNIFORM_WORK_GROUP_SUPPORT, sizeof(cl_bool), &backend_ctx->non_uniform_workgroups, 0)); +#endif } else { GGML_ASSERT(opencl_c_version.major == 2); // Non-uniform workgroup sizes is mandatory feature in v2.x. @@ -2681,7 +2686,7 @@ static bool ggml_opencl_can_fuse(const struct ggml_cgraph * cgraph, int node_idx // if rms_norm is the B operand, then we don't handle broadcast if (rms_norm == mul->src[1] && - !ggml_are_same_shape(mul->src[0], rms_norm->src[1])) { + !ggml_are_same_shape(mul->src[0], rms_norm)) { return false; } diff --git a/ggml/src/ggml-vulkan/CMakeLists.txt b/ggml/src/ggml-vulkan/CMakeLists.txt index 83a83887b5..de01336cd3 100644 --- a/ggml/src/ggml-vulkan/CMakeLists.txt +++ b/ggml/src/ggml-vulkan/CMakeLists.txt @@ -1,9 +1,18 @@ cmake_minimum_required(VERSION 3.19) cmake_policy(SET CMP0114 NEW) cmake_policy(SET CMP0116 NEW) +if (POLICY CMP0147) + # Parallel build custom build steps + cmake_policy(SET CMP0147 NEW) +endif() find_package(Vulkan COMPONENTS glslc REQUIRED) +if (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC") + # Parallel build object files + add_definitions(/MP) +endif() + function(detect_host_compiler) if (CMAKE_HOST_SYSTEM_NAME STREQUAL "Windows") find_program(HOST_C_COMPILER NAMES cl gcc clang NO_CMAKE_FIND_ROOT_PATH) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 3cd89c7116..1674dc66ab 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -2649,11 +2649,13 @@ static void ggml_vk_load_shaders(vk_device& device) { } \ } + CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, ) CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, ) CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, ) CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, ) #if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) if (device->coopmat1_fa_support) { + CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT1, _cm1) CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT1, _cm1) CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT1, _cm1) CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT1, _cm1) @@ -2661,6 +2663,7 @@ static void ggml_vk_load_shaders(vk_device& device) { #endif #if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) if (device->coopmat2) { + CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT2, _cm2) CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT2, _cm2) CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT2, _cm2) CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_COOPMAT2, _cm2) @@ -7457,8 +7460,16 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx } const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type)); - const uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type)); - const uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type)); + uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type)); + uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type)); + + // For F32, the shader treats it as a block of size 4 (for vec4 loads) + if (k->type == GGML_TYPE_F32) { + k_stride /= 4; + } + if (v->type == GGML_TYPE_F32) { + v_stride /= 4; + } uint32_t alignment = fa_align(path, HSK, HSV, k->type, small_rows); bool aligned = (KV % alignment) == 0 && @@ -12660,6 +12671,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm } switch (op->src[1]->type) { case GGML_TYPE_F16: + case GGML_TYPE_F32: case GGML_TYPE_Q4_0: case GGML_TYPE_Q8_0: // supported in scalar and coopmat2 paths diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl index 6a5bb4574d..67baedf7c6 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl @@ -1,6 +1,18 @@ #include "types.glsl" +layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufF32 { + vec4 block; +}; + +float16_t dequantFuncF32(const in decodeBufF32 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const vec4 v = bl.block; + const uint idx = coordInBlock[1]; + const f16vec4 vf16 = f16vec4(v); + return vf16[idx]; +} + layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ4_0 { block_q4_0_packed16 block; }; @@ -717,4 +729,6 @@ float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords #define dequantFuncA dequantFuncIQ4_NL #elif defined(DATA_A_MXFP4) #define dequantFuncA dequantFuncMXFP4 +#elif defined(DATA_A_F32) +#define dequantFuncA dequantFuncF32 #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl index 9b1f153bf7..eb93903c46 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl @@ -64,13 +64,31 @@ layout (binding = 4) readonly buffer S {float data_s[];}; layout (binding = 5) writeonly buffer O {D_TYPE data_o[];}; -#if defined(A_TYPE_PACKED16) #define BINDING_IDX_K 0 #define BINDING_IDX_V 1 +#if defined(DATA_A_F32) +layout (binding = 1) readonly buffer K_PACKED {vec4 k_data_packed[];} k_packed; +layout (binding = 2) readonly buffer V_PACKED {vec4 v_data_packed[];} v_packed; +#elif defined(A_TYPE_PACKED16) layout (binding = 1) readonly buffer K_PACKED16 {A_TYPE_PACKED16 k_data_packed16[];} k_packed; layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16[];} v_packed; #endif +#if defined(DATA_A_F32) +#undef BLOCK_SIZE +#define BLOCK_SIZE 4 +#define BLOCK_BYTE_SIZE 16 + +vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { + // iqs is currently always zero in the flash attention shaders + if (binding_idx == BINDING_IDX_K) { + return k_packed.k_data_packed[a_offset + ib]; + } else { + return v_packed.v_data_packed[a_offset + ib]; + } +} +#endif + #if defined(DATA_A_Q4_0) #define BLOCK_BYTE_SIZE 18 diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp index 85400ac5fc..a20788c4b5 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp @@ -313,12 +313,12 @@ void main() { sums[i] = coopmat(0.0f); } #else - ACC_TYPE sums[WMITER * TM * WNITER * TN]; + ACC_TYPE_VEC2 sums[WMITER * TM * WNITER * TN/2]; FLOAT_TYPE_VEC2 cache_a[WMITER * TM]; - FLOAT_TYPE_VEC2 cache_b[TN]; + FLOAT_TYPE_VEC2 cache_b; - [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) { - sums[i] = ACC_TYPE(0.0f); + [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN/2; i++) { + sums[i] = ACC_TYPE_VEC2(0.0f, 0.0f); } #endif @@ -360,20 +360,22 @@ void main() { cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * SHMEM_STRIDE + i]; } } - [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { - [[unroll]] for (uint j = 0; j < TN; j++) { - cache_b[j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * SHMEM_STRIDE + i]; - } - [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { - [[unroll]] for (uint cc = 0; cc < TN; cc++) { - [[unroll]] for (uint cr = 0; cr < TM; cr++) { - const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr; - sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr].x), ACC_TYPE(cache_b[cc].x), fma(ACC_TYPE(cache_a[wsir * TM + cr].y), ACC_TYPE(cache_b[cc].y), sums[sums_idx])); + [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { + [[unroll]] for (uint cc = 0; cc < TN; cc++) { + cache_b = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + cc) * SHMEM_STRIDE + i]; + + [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { + [[unroll]] for (uint cr = 0; cr < TM / 2; cr++) { + // [WNITER][TN][WMITER][TM / 2] -> [wsic][cc][wsir][cr] + const uint sums_idx = (wsic * TN + cc) * WMITER * (TM / 2) + wsir * (TM / 2) + cr; + sums[sums_idx].x = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].y), ACC_TYPE(cache_b.y), sums[sums_idx].x)); + sums[sums_idx].y = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].y), ACC_TYPE(cache_b.y), sums[sums_idx].y)); } } } } + } #endif @@ -388,8 +390,9 @@ void main() { } } #else - [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) { - sums[i] = clamp(sums[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); + [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN/2; i++) { + sums[i].x = clamp(sums[i].x, -ACC_TYPE_MAX, ACC_TYPE_MAX); + sums[i].y = clamp(sums[i].y, -ACC_TYPE_MAX, ACC_TYPE_MAX); } #endif #endif @@ -463,14 +466,21 @@ void main() { const u16vec2 row_idx = row_ids[row_i - ic * BN]; #endif // MUL_MAT_ID - [[unroll]] for (uint cr = 0; cr < TM; cr++) { + [[unroll]] for (uint cr = 0; cr < TM / 2; cr++) { + const uint sums_idx = (wsic * TN + cc) * WMITER * (TM / 2) + wsir * (TM / 2) + cr; #ifdef MUL_MAT_ID - if (dr_warp + cr < p.M) { - data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]); + if (dr_warp + 2 * cr < p.M) { + data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + 2 * cr] = D_TYPE(sums[sums_idx].x); + } + if (dr_warp + 2 * cr + 1 < p.M) { + data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + 2 * cr + 1] = D_TYPE(sums[sums_idx].y); } #else - if (dr_warp + cr < p.M && dc_warp + cc < p.N) { - data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]); + if (dr_warp + 2 * cr < p.M && dc_warp + cc < p.N) { + data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + 2 * cr] = D_TYPE(sums[sums_idx].x); + } + if (dr_warp + 2 * cr + 1 < p.M && dc_warp + cc < p.N) { + data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + 2 * cr + 1] = D_TYPE(sums[sums_idx].y); } #endif // MUL_MAT_ID } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index f0cc24ff31..184f3f3a7d 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -611,9 +611,6 @@ void process_shaders() { } for (const auto& tname : type_names) { - if (tname == "f32") { - continue; - } if (tname == "bf16") continue; #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) @@ -630,7 +627,7 @@ void process_shaders() { if (tname == "f16") { string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"COOPMAT", "1"}}), true, true, false, f16acc); - } else if (tname == "q4_0" || tname == "q8_0") { + } else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") { std::string data_a_key = "DATA_A_" + to_uppercase(tname); string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc); @@ -639,7 +636,7 @@ void process_shaders() { if (tname == "f16") { string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, false, f16acc); - } else if (tname == "q4_0" || tname == "q8_0") { + } else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") { std::string data_a_key = "DATA_A_" + to_uppercase(tname); string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc); diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index a24853c63a..f29a1e98c9 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -261,12 +261,17 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) { } } -static void print_mask(float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) { +static void print_mask(const float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) { LLAMA_LOG_DEBUG("%s: === Attention mask ===\n", __func__); - const char * swa_type_str = (swa_type == LLAMA_SWA_TYPE_NONE) ? "LLAMA_SWA_TYPE_NONE" : - (swa_type == LLAMA_SWA_TYPE_STANDARD) ? "LLAMA_SWA_TYPE_STANDARD" : - (swa_type == LLAMA_SWA_TYPE_CHUNKED) ? "LLAMA_SWA_TYPE_CHUNKED" : - (swa_type == LLAMA_SWA_TYPE_SYMMETRIC) ? "LLAMA_SWA_TYPE_SYMMETRIC" : "unknown"; + const char * swa_type_str = "unknown"; + + switch (swa_type) { + case LLAMA_SWA_TYPE_NONE: swa_type_str = "LLAMA_SWA_TYPE_NONE"; break; + case LLAMA_SWA_TYPE_STANDARD: swa_type_str = "LLAMA_SWA_TYPE_STANDARD"; break; + case LLAMA_SWA_TYPE_CHUNKED: swa_type_str = "LLAMA_SWA_TYPE_CHUNKED"; break; + case LLAMA_SWA_TYPE_SYMMETRIC: swa_type_str = "LLAMA_SWA_TYPE_SYMMETRIC"; break; + }; + LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swq_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str); LLAMA_LOG_DEBUG("%s: '0' = can attend, '∞' = masked\n", __func__); LLAMA_LOG_DEBUG("%s: Rows = query tokens, Columns = key/value tokens\n\n", __func__); @@ -295,50 +300,67 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) { const int64_t n_kv = ubatch->n_tokens; const int64_t n_tokens = ubatch->n_tokens; - GGML_ASSERT(kq_mask); - GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer)); + const auto fill_mask = [&](float * data, int n_swa, llama_swa_type swa_type) { + for (int h = 0; h < 1; ++h) { + for (int i1 = 0; i1 < n_tokens; ++i1) { + const llama_seq_id s1 = ubatch->seq_id[i1][0]; + const llama_pos p1 = ubatch->pos[i1]; - float * data = (float *) kq_mask->data; + const uint64_t idst = h*(n_kv*n_tokens) + i1*n_kv; - // [TAG_NO_CACHE_ISWA] - GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "TODO: implement"); - - for (int h = 0; h < 1; ++h) { - for (int i1 = 0; i1 < n_tokens; ++i1) { - const llama_seq_id s1 = ubatch->seq_id[i1][0]; - - for (int i0 = 0; i0 < n_tokens; ++i0) { - float f = -INFINITY; - - for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) { + for (int i0 = 0; i0 < n_tokens; ++i0) { const llama_seq_id s0 = ubatch->seq_id[i0][0]; + const llama_pos p0 = ubatch->pos[i0]; + // mask different sequences if (s0 != s1) { - continue; // skip different sequences + continue; } - if (cparams.causal_attn && ubatch->pos[i0] > ubatch->pos[i1]) { - continue; // skip future tokens for causal attention + // mask future tokens + if (cparams.causal_attn && p0 > p1) { + continue; } - // TODO: this does not take into account that some layers are SWA and others are note (i.e. iSWA) [TAG_NO_CACHE_ISWA] - //if (hparams.is_masked_swa(ubatch->pos[i0], ubatch->pos[i1])) { - // continue; // skip masked tokens for SWA - //} - - // TODO: reimplement this like in llama_kv_cache_unified - if (hparams.use_alibi) { - f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]); - } else { - f = 0.0f; + // apply SWA if any + if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) { + continue; } + + data[idst + i0] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f; } - data[h*(n_kv*n_tokens) + i1*n_kv + i0] = f; } } + }; + + { + GGML_ASSERT(self_kq_mask); + GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer)); + + float * data = (float *) self_kq_mask->data; + + std::fill(data, data + ggml_nelements(self_kq_mask), -INFINITY); + + fill_mask(data, 0, LLAMA_SWA_TYPE_NONE); + + if (debug) { + print_mask(data, n_tokens, n_kv, 0, LLAMA_SWA_TYPE_NONE); + } } - if (debug) { - print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type); + + if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { + GGML_ASSERT(self_kq_mask_swa); + GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer)); + + float * data = (float *) self_kq_mask_swa->data; + + std::fill(data, data + ggml_nelements(self_kq_mask_swa), -INFINITY); + + fill_mask(data, hparams.n_swa, hparams.swa_type); + + if (debug) { + print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type); + } } } @@ -1299,12 +1321,9 @@ ggml_tensor * llm_graph_context::build_attn_mha( k = ggml_permute(ctx0, k, 0, 2, 1, 3); v = ggml_permute(ctx0, v, 0, 2, 1, 3); - const auto n_kv = k->ne[1]; - ggml_tensor * cur; - // TODO: replace hardcoded padding with ggml-provided padding - if (cparams.flash_attn && (n_kv % 256 == 0) && kq_b == nullptr) { + if (cparams.flash_attn && kq_b == nullptr) { GGML_ASSERT(kq_b == nullptr && "Flash attention does not support KQ bias yet"); if (v_trans) { @@ -1419,10 +1438,20 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con auto inp = std::make_unique(hparams, cparams); // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch - inp->kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1); - ggml_set_input(inp->kq_mask); + inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1); + ggml_set_input(inp->self_kq_mask); - inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask; + inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; + + if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { + inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1); + ggml_set_input(inp->self_kq_mask_swa); + + inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa; + } else { + inp->self_kq_mask_swa = nullptr; + inp->self_kq_mask_swa_cnv = nullptr; + } return (llm_graph_input_attn_no_cache *) res->add_input(std::move(inp)); } @@ -1447,7 +1476,9 @@ ggml_tensor * llm_graph_context::build_attn( ggml_build_forward_expand(gf, k_cur); ggml_build_forward_expand(gf, v_cur); - const auto & kq_mask = inp->get_kq_mask(); + const bool is_swa = hparams.is_swa(il); + + const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask(); // [TAG_NO_CACHE_PAD] // TODO: if ubatch.equal_seqs() == true, we can split the three tensors below into ubatch.n_seqs_unq streams diff --git a/src/llama-graph.h b/src/llama-graph.h index dc84b79428..d0c3934f67 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -257,10 +257,14 @@ public: void set_input(const llama_ubatch * ubatch) override; - ggml_tensor * get_kq_mask() const { return kq_mask_cnv; } + ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; } + ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; } - ggml_tensor * kq_mask = nullptr; // F32 [n_tokens, n_batch, 1, 1] - ggml_tensor * kq_mask_cnv = nullptr; // [n_tokens, n_batch, 1, 1] + // n_tokens == n_batch + ggml_tensor * self_kq_mask = nullptr; // F32 [n_tokens, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_cnv = nullptr; // [n_tokens, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_tokens, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_tokens, n_batch/n_stream, 1, n_stream] const llama_hparams hparams; const llama_cparams cparams; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 36d495d6cf..0cdad9babd 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -11358,8 +11358,8 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { } }; -struct llm_build_gemma_embedding_iswa : public llm_graph_context { - llm_build_gemma_embedding_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +struct llm_build_gemma_embedding : public llm_graph_context { + llm_build_gemma_embedding(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_k; ggml_tensor * cur; @@ -11376,8 +11376,7 @@ struct llm_build_gemma_embedding_iswa : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - // TODO: support cacheless iSWA embeddings [TAG_NO_CACHE_ISWA] - auto * inp_attn = build_attn_inp_kv_iswa(); + auto * inp_attn = build_attn_inp_no_cache(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -19378,7 +19377,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, case LLM_ARCH_NOMIC_BERT_MOE: case LLM_ARCH_NEO_BERT: case LLM_ARCH_WAVTOKENIZER_DEC: - //case LLM_ARCH_GEMMA_EMBEDDING: // TODO: disabled until the cacheless SWA logic is fixed [TAG_NO_CACHE_ISWA] + case LLM_ARCH_GEMMA_EMBEDDING: case LLM_ARCH_DREAM: case LLM_ARCH_LLADA: case LLM_ARCH_LLADA_MOE: @@ -19671,7 +19670,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { } break; case LLM_ARCH_GEMMA_EMBEDDING: { - llm = std::make_unique(*this, params); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_STARCODER2: { diff --git a/src/llama.cpp b/src/llama.cpp index fe5a7a8354..38700f97a0 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -312,6 +312,7 @@ struct llama_model * llama_model_load_from_splits( LLAMA_LOG_ERROR("%s: list of splits is empty\n", __func__); return nullptr; } + splits.reserve(n_paths); for (size_t i = 0; i < n_paths; ++i) { splits.push_back(paths[i]); } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 2fa16b497a..14472dcf12 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -6779,7 +6779,7 @@ static std::vector> make_test_cases_eval() { for (int nb : { 1, 3, 32, 35, }) { for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) { if (hsk != 128 && prec == GGML_PREC_DEFAULT) continue; - for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) { + for (ggml_type type_KV : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) { test_cases.emplace_back(new test_flash_attn_ext( hsk, hsv, nh, {nr2, nr3}, kv, nb, mask, sinks, max_bias, logit_softcap, prec, type_KV)); // run fewer test cases permuted @@ -6911,7 +6911,7 @@ static std::vector> make_test_cases_perf() { } // qwen3-30b-a3b - for (int bs : {1, 4, 8, 32, 64, 128, 512}) { + for (int bs : {1, 4, 8, 32, 64, 128, 256, 512}) { for (ggml_type type_a : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0, GGML_TYPE_Q4_K, GGML_TYPE_Q6_K, GGML_TYPE_IQ2_XS}) { for (ggml_type type_b : {GGML_TYPE_F32}) { test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, 128, 8, false, 768, bs, 2048, 1)); @@ -6919,6 +6919,15 @@ static std::vector> make_test_cases_perf() { } } + for (int bs : {1, 4, 8, 32, 64, 128, 256, 512}) { + for (ggml_type type_a : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0, GGML_TYPE_Q4_K, GGML_TYPE_Q6_K, GGML_TYPE_IQ2_XS}) { + for (ggml_type type_b : {GGML_TYPE_F32}) { + test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, 32, 4, false, 1792, bs, 2048, 1)); + } + } + } + + // gpt-oss-20b for (int bs : {1, 4, 8, 512}) { for (ggml_type type_a : {GGML_TYPE_MXFP4}) { diff --git a/tools/server/public/index.html.gz b/tools/server/public/index.html.gz index d0c44534ef..c026f36c48 100644 Binary files a/tools/server/public/index.html.gz and b/tools/server/public/index.html.gz differ diff --git a/tools/server/server.cpp b/tools/server/server.cpp index cf12805b49..77969d24e1 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -1585,23 +1585,31 @@ struct server_prompt_cache { } } + // average size per token + const float size_per_token = std::max(1.0f, float(size()) / (std::max(1, n_tokens()))); + + // dynamically increase the token limit if it can fit in the memory limit + const size_t limit_tokens_cur = limit_size > 0 ? std::max(limit_tokens, limit_size/size_per_token) : limit_tokens; + if (limit_tokens > 0) { - while (states.size() > 1 && n_tokens() > limit_tokens) { + while (states.size() > 1 && n_tokens() > limit_tokens_cur) { if (states.empty()) { break; } - SRV_WRN(" - cache token limit reached, removing oldest entry (size = %.3f MiB)\n", states.front().size() / (1024.0 * 1024.0)); + SRV_WRN(" - cache token limit (%zu, est: %zu) reached, removing oldest entry (size = %.3f MiB)\n", + limit_tokens, limit_tokens_cur, states.front().size() / (1024.0 * 1024.0)); states.pop_front(); } } - SRV_WRN(" - cache state: %zu prompts, %.3f MiB (limits: %.3f MiB, %zu tokens)\n", - states.size(), size() / (1024.0 * 1024.0), limit_size / (1024.0 * 1024.0), limit_tokens); + SRV_WRN(" - cache state: %zu prompts, %.3f MiB (limits: %.3f MiB, %zu tokens, %zu est)\n", + states.size(), size() / (1024.0 * 1024.0), limit_size / (1024.0 * 1024.0), limit_tokens, limit_tokens_cur); for (const auto & state : states) { - SRV_WRN(" - prompt %p: %7d tokens, checkpoints: %2zu, %9.3f MiB\n", (const void *)&state, state.n_tokens(), state.checkpoints.size(), state.size() / (1024.0 * 1024.0)); + SRV_WRN(" - prompt %p: %7d tokens, checkpoints: %2zu, %9.3f MiB\n", + (const void *)&state, state.n_tokens(), state.checkpoints.size(), state.size() / (1024.0 * 1024.0)); } } }; diff --git a/tools/server/webui/package-lock.json b/tools/server/webui/package-lock.json index c300ecaa77..9cd6ef9138 100644 --- a/tools/server/webui/package-lock.json +++ b/tools/server/webui/package-lock.json @@ -50,6 +50,7 @@ "eslint-plugin-svelte": "^3.0.0", "fflate": "^0.8.2", "globals": "^16.0.0", + "mdast": "^3.0.0", "mdsvex": "^0.12.3", "playwright": "^1.53.0", "prettier": "^3.4.2", @@ -66,6 +67,7 @@ "tw-animate-css": "^1.3.5", "typescript": "^5.0.0", "typescript-eslint": "^8.20.0", + "unified": "^11.0.5", "uuid": "^13.0.0", "vite": "^7.0.4", "vite-plugin-devtools-json": "^0.2.0", @@ -2128,6 +2130,66 @@ "node": ">=14.0.0" } }, + "node_modules/@tailwindcss/oxide-wasm32-wasi/node_modules/@emnapi/core": { + "version": "1.4.3", + "dev": true, + "inBundle": true, + "license": "MIT", + "optional": true, + "dependencies": { + "@emnapi/wasi-threads": "1.0.2", + "tslib": "^2.4.0" + } + }, + "node_modules/@tailwindcss/oxide-wasm32-wasi/node_modules/@emnapi/runtime": { + "version": "1.4.3", + "dev": true, + "inBundle": true, + "license": "MIT", + "optional": true, + "dependencies": { + "tslib": "^2.4.0" + } + }, + "node_modules/@tailwindcss/oxide-wasm32-wasi/node_modules/@emnapi/wasi-threads": { + "version": "1.0.2", + "dev": true, + "inBundle": true, + "license": "MIT", + "optional": true, + "dependencies": { + "tslib": "^2.4.0" + } + }, + "node_modules/@tailwindcss/oxide-wasm32-wasi/node_modules/@napi-rs/wasm-runtime": { + "version": "0.2.11", + "dev": true, + "inBundle": true, + "license": "MIT", + "optional": true, + "dependencies": { + "@emnapi/core": "^1.4.3", + "@emnapi/runtime": "^1.4.3", + "@tybys/wasm-util": "^0.9.0" + } + }, + "node_modules/@tailwindcss/oxide-wasm32-wasi/node_modules/@tybys/wasm-util": { + "version": "0.9.0", + "dev": true, + "inBundle": true, + "license": "MIT", + "optional": true, + "dependencies": { + "tslib": "^2.4.0" + } + }, + "node_modules/@tailwindcss/oxide-wasm32-wasi/node_modules/tslib": { + "version": "2.8.0", + "dev": true, + "inBundle": true, + "license": "0BSD", + "optional": true + }, "node_modules/@tailwindcss/oxide-win32-arm64-msvc": { "version": "4.1.11", "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-win32-arm64-msvc/-/oxide-win32-arm64-msvc-4.1.11.tgz", @@ -4946,6 +5008,13 @@ "url": "https://github.com/sponsors/wooorm" } }, + "node_modules/mdast": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/mdast/-/mdast-3.0.0.tgz", + "integrity": "sha512-xySmf8g4fPKMeC07jXGz971EkLbWAJ83s4US2Tj9lEdnZ142UP5grN73H1Xd3HzrdbU5o9GYYP/y8F9ZSwLE9g==", + "dev": true, + "license": "MIT" + }, "node_modules/mdast-util-find-and-replace": { "version": "3.0.2", "resolved": "https://registry.npmjs.org/mdast-util-find-and-replace/-/mdast-util-find-and-replace-3.0.2.tgz", diff --git a/tools/server/webui/package.json b/tools/server/webui/package.json index 7bf21bf57c..e073cd32f0 100644 --- a/tools/server/webui/package.json +++ b/tools/server/webui/package.json @@ -52,6 +52,7 @@ "eslint-plugin-svelte": "^3.0.0", "fflate": "^0.8.2", "globals": "^16.0.0", + "mdast": "^3.0.0", "mdsvex": "^0.12.3", "playwright": "^1.53.0", "prettier": "^3.4.2", @@ -68,6 +69,7 @@ "tw-animate-css": "^1.3.5", "typescript": "^5.0.0", "typescript-eslint": "^8.20.0", + "unified": "^11.0.5", "uuid": "^13.0.0", "vite": "^7.0.4", "vite-plugin-devtools-json": "^0.2.0", diff --git a/tools/server/webui/src/lib/components/app/chat/ChatScreen/ChatScreen.svelte b/tools/server/webui/src/lib/components/app/chat/ChatScreen/ChatScreen.svelte index 666febf0d2..374eb05ab0 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatScreen/ChatScreen.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatScreen/ChatScreen.svelte @@ -7,6 +7,7 @@ ChatMessages, ChatProcessingInfo, EmptyFileAlertDialog, + ChatErrorDialog, ServerErrorSplash, ServerInfo, ServerLoadingSplash, @@ -22,10 +23,11 @@ activeMessages, activeConversation, deleteConversation, + dismissErrorDialog, + errorDialog, isLoading, sendMessage, - stopGeneration, - setMaxContextError + stopGeneration } from '$lib/stores/chat.svelte'; import { supportsVision, @@ -34,7 +36,6 @@ serverWarning, serverStore } from '$lib/stores/server.svelte'; - import { contextService } from '$lib/services'; import { parseFilesToMessageExtras } from '$lib/utils/convert-files-to-extra'; import { isFileTypeSupported } from '$lib/utils/file-type'; import { filterFilesByModalities } from '$lib/utils/modality-file-validation'; @@ -79,6 +80,7 @@ showCenteredEmpty && !activeConversation() && activeMessages().length === 0 && !isLoading() ); + let activeErrorDialog = $derived(errorDialog()); let isServerLoading = $derived(serverLoading()); async function handleDeleteConfirm() { @@ -105,6 +107,12 @@ } } + function handleErrorDialogOpenChange(open: boolean) { + if (!open) { + dismissErrorDialog(); + } + } + function handleDragOver(event: DragEvent) { event.preventDefault(); } @@ -183,21 +191,6 @@ const extras = result?.extras; - // Check context limit using real-time slots data - const contextCheck = await contextService.checkContextLimit(); - - if (contextCheck && contextCheck.wouldExceed) { - const errorMessage = contextService.getContextErrorMessage(contextCheck); - - setMaxContextError({ - message: errorMessage, - estimatedTokens: contextCheck.currentUsage, - maxContext: contextCheck.maxContext - }); - - return false; - } - // Enable autoscroll for user-initiated message sending userScrolledUp = false; autoScrollEnabled = true; @@ -461,6 +454,13 @@ }} /> + +