Merge remote-tracking branch 'upstream' into cuda_graph_plan

Xiangyan Sun 2025-10-14 10:40:18 -07:00
commit c17f8b5bde
53 changed files with 1709 additions and 1102 deletions

View File

@@ -146,9 +146,7 @@ void ggml_cann_op_unary_gated(
     unary_op(ctx, acl_src0, acl_dst);
     GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMul, acl_dst, acl_src1);
-    ggml_cann_release_resources(ctx, acl_src0, acl_dst);
-    if(src1)
-        ggml_cann_release_resources(ctx, acl_src1);
+    ggml_cann_release_resources(ctx, acl_src0, acl_src1, acl_dst);
 }
 
 /**
@@ -894,14 +892,13 @@ static void aclnn_fill_scalar(ggml_backend_cann_context& ctx, float scalar,
 }
 
 /**
- * @brief Get or expand a cached float32 tensor filled with a scalar value.
+ * @brief Get or expand a cached tensor filled with a scalar value.
  *
- * This function manages cached device memory for float32 tensors. If the current
+ * This function manages cached device memory for tensors. If the current
  * cache size is insufficient for the requested tensor shape, the old memory will
- * be released and new memory will be allocated. The allocated buffer is then
- * initialized either with zeros (when @p value == 0.0f) or with the given scalar
- * value using CANN operations. Finally, an aclTensor object is created from the
- * cached memory and returned.
+ * be released and new memory will be allocated. The allocated buffer is
+ * initialized with the given scalar value using CANN operations.
+ * Finally, an aclTensor object is created from the cached memory and returned.
  *
  * @param ctx The CANN backend context that manages device memory.
  * @param buffer A pointer to the cached device buffer (will be allocated
@@ -910,17 +907,19 @@ static void aclnn_fill_scalar(ggml_backend_cann_context& ctx, float scalar,
  *        updated when the cache is expanded.
  * @param ne The tensor shape array (number of elements in each dimension).
  * @param nb The stride size for each dimension.
+ * @param dtype Data type of cached tensor.
  * @param dims The number of tensor dimensions.
  * @param value The scalar value used to fill the tensor (supports zero
  *        initialization via memset or arbitrary values via fill_scalar).
  * @return An aclTensor pointer created from the cached buffer.
  */
-static aclTensor* get_f32_cache_acl_tensor(
+static aclTensor* get_cache_acl_tensor(
     ggml_backend_cann_context& ctx,
     void** buffer,
     int64_t &cache_element,
     int64_t* ne,
     size_t* nb,
+    ggml_type dtype,
     int64_t dims,
     float value) {
     // Calculate total number of elements
@@ -928,7 +927,7 @@ static aclTensor* get_f32_cache_acl_tensor(
     for (int i = 0; i < dims; i++) {
         n_element *= ne[i];
     }
-    size_t size = n_element * sizeof(float);
+    size_t size = n_element * ggml_type_size(dtype);
 
     // Allocate or expand cache if needed
     if (cache_element < n_element) {
@@ -941,19 +940,17 @@ static aclTensor* get_f32_cache_acl_tensor(
         cache_element = n_element;
 
         // Initialize cache
-        if (value == 0.0f) {
-            ACL_CHECK(aclrtMemsetAsync(*buffer, size, 0, size, ctx.stream()));
-        } else {
-            int64_t pool_ne[1] = { n_element };
-            size_t pool_nb[1] = { sizeof(float) };
-            aclTensor* acl_value = ggml_cann_create_tensor(
-                *buffer, ACL_FLOAT, sizeof(float), pool_ne, pool_nb, 1);
-            aclnn_fill_scalar(ctx, 1, acl_value);
-            ggml_cann_release_resources(ctx, acl_value);
-        }
+        int64_t pool_ne[1] = { n_element };
+        size_t pool_nb[1] = { ggml_type_size(dtype) };
+        aclTensor* acl_value = ggml_cann_create_tensor(
+            *buffer, ggml_cann_type_mapping(dtype), ggml_type_size(dtype),
+            pool_ne, pool_nb, 1);
+        aclnn_fill_scalar(ctx, value, acl_value);
+        ggml_cann_release_resources(ctx, acl_value);
     }
-    return ggml_cann_create_tensor(*buffer, ACL_FLOAT, sizeof(float), ne, nb, dims);
+    return ggml_cann_create_tensor(*buffer, ggml_cann_type_mapping(dtype),
+                                   ggml_type_size(dtype), ne, nb, dims);
 }
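Aside: a minimal standalone sketch of the grow-only cache policy that the rewritten get_cache_acl_tensor documents above, with plain malloc/free standing in for the ACL device allocator (an assumption for illustration only; the real helper allocates through the CANN context and refills the buffer via aclnn_fill_scalar):

    #include <cstdint>
    #include <cstdlib>

    // Sketch only: reallocate the cached buffer when the requested element
    // count exceeds the cached capacity; the byte size now scales with the
    // element type instead of being hard-coded to sizeof(float).
    static void * grow_scalar_cache(void ** buffer, int64_t & cache_element,
                                    int64_t n_element, size_t type_size) {
        if (cache_element < n_element) {   // grow only, never shrink
            std::free(*buffer);            // stands in for releasing device memory
            *buffer = std::malloc(static_cast<size_t>(n_element) * type_size);
            cache_element = n_element;
            // the caller then refills the new buffer with the scalar value
        }
        return *buffer;
    }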
@@ -965,35 +962,39 @@ void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));
 
-    // build gamma, one...
+    // build gamma.
     size_t acl_gamma_nb[GGML_MAX_DIMS];
-    acl_gamma_nb[0] = sizeof(float);
+    // gamma's type is the same with dst.
+    acl_gamma_nb[0] = ggml_type_size(dst->type);
     for (int i = 1; i < GGML_MAX_DIMS; i++) {
         acl_gamma_nb[i] = acl_gamma_nb[i - 1] * src->ne[i - 1];
     }
-    aclTensor* acl_gamma = get_f32_cache_acl_tensor(
+    aclTensor* acl_gamma = get_cache_acl_tensor(
         ctx,
         &ctx.rms_norm_one_tensor_cache.cache,
         ctx.rms_norm_one_tensor_cache.size,
         src->ne,
         acl_gamma_nb,
+        dst->type,
         1,  // dims
         1.0f  // value
     );
 
-    // build rstd, zero...
+    // build rstd.
     int64_t acl_rstd_ne[] = {src->ne[1], src->ne[2], src->ne[3]};
     size_t acl_rstd_nb[GGML_MAX_DIMS - 1];
+    // rstd will always be F32.
     acl_rstd_nb[0] = sizeof(float);
     for (int i = 1; i < GGML_MAX_DIMS - 1; i++) {
         acl_rstd_nb[i] = acl_rstd_nb[i - 1] * acl_rstd_ne[i - 1];
     }
-    aclTensor* acl_rstd = get_f32_cache_acl_tensor(
+    aclTensor* acl_rstd = get_cache_acl_tensor(
         ctx,
         &ctx.rms_norm_zero_tensor_cache.cache,
         ctx.rms_norm_zero_tensor_cache.size,
        	acl_rstd_ne,
         acl_rstd_nb,
+        GGML_TYPE_F32,
         GGML_MAX_DIMS - 1,
         0.0f  // value
     );
@@ -1765,33 +1766,35 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     ggml_tensor* src0 = dst->src[0];  // src
     ggml_tensor* src1 = dst->src[1];  // index
 
+    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
+
     switch (src0->type) {
-        case GGML_TYPE_F32: {
+        case GGML_TYPE_F16:
+        case GGML_TYPE_F32:
+            if(src0->type == dst->type) {
                 aclnn_index_select_4d(ctx, src0->data, src0->ne, src0->nb,
                                       dst->data, dst->ne, dst->nb,
                                       src1, dst->type);
-            break;
-        }
-        case GGML_TYPE_F16: {
+            } else {
                 aclTensor* acl_src0 = ggml_cann_create_tensor(src0);
                 ggml_cann_pool_alloc src_buffer_allocator(
-                    ctx.pool(), ggml_nelements(src0) * sizeof(float));
+                    ctx.pool(), ggml_nelements(src0) * ggml_element_size(dst));
                 void* src_trans_buffer = src_buffer_allocator.get();
                 size_t src_trans_nb[GGML_MAX_DIMS];
-                src_trans_nb[0] = sizeof(float);
+                src_trans_nb[0] = dst->nb[0];
                 for (int i = 1; i < GGML_MAX_DIMS; i++) {
                     src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1];
                 }
                 aclTensor* src_trans_tensor = ggml_cann_create_tensor(
-                    src_trans_buffer, ACL_FLOAT, ggml_type_size(dst->type),
+                    src_trans_buffer, ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
                     src0->ne, src_trans_nb, GGML_MAX_DIMS);
                 aclnn_cast(ctx, acl_src0, src_trans_tensor, ggml_cann_type_mapping(dst->type));
                 aclnn_index_select_4d(ctx, src_trans_buffer, src0->ne, src_trans_nb,
                                       dst->data, dst->ne, dst->nb,
                                       src1, dst->type);
                 ggml_cann_release_resources(ctx, acl_src0, src_trans_tensor);
-            break;
-        }
+            }
+            break;
         case GGML_TYPE_Q8_0: {
             // add 1 dim for bcast mul.
             size_t weight_nb[GGML_MAX_DIMS + 1], scale_nb[GGML_MAX_DIMS + 1],
@@ -1799,7 +1802,6 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
             int64_t weight_ne[GGML_MAX_DIMS + 1], scale_ne[GGML_MAX_DIMS + 1],
                 *dequant_ne;
             int64_t scale_offset = 0;
-
             // [3,4,5,64] -> [3,4,5,2,32]
             weight_ne[0] = QK8_0;
             weight_ne[1] = src0->ne[0] / QK8_0;
@@ -1809,7 +1811,6 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
                 weight_ne[i] = src0->ne[i - 1];
                 weight_nb[i] = weight_nb[i - 1] * weight_ne[i - 1];
             }
-
             // [3,4,5,64] -> [3,4,5,2,1]
             scale_ne[0] = 1;
             scale_ne[1] = src0->ne[0] / QK8_0;
@@ -1819,18 +1820,15 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
                 scale_ne[i] = src0->ne[i - 1];
                 scale_nb[i] = scale_nb[i - 1] * scale_ne[i - 1];
             }
-
             // [3,4,5,64] -> [3,4,5,2,32]
             dequant_ne = weight_ne;
-            dequant_nb[0] = sizeof(float);
+            dequant_nb[0] = ggml_type_size(dst->type);
             for (int i = 1; i < GGML_MAX_DIMS + 1; i++) {
                 dequant_nb[i] = dequant_nb[i - 1] * dequant_ne[i - 1];
             }
-
             scale_offset = ggml_nelements(src0) * sizeof(int8_t);
             ggml_cann_pool_alloc dequant_buffer_allocator(
-                ctx.pool(), ggml_nelements(src0) * sizeof(float));
-
+                ctx.pool(), ggml_nelements(src0) * ggml_type_size(dst->type));
             aclTensor* acl_weight_tensor = ggml_cann_create_tensor(
                 src0->data, ACL_INT8, sizeof(int8_t), weight_ne, weight_nb,
                 GGML_MAX_DIMS + 1);
@@ -1838,22 +1836,20 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
                 src0->data, ACL_FLOAT16, sizeof(uint16_t), scale_ne, scale_nb,
                 GGML_MAX_DIMS + 1, ACL_FORMAT_ND, scale_offset);
             aclTensor* dequant_tensor = ggml_cann_create_tensor(
-                dequant_buffer_allocator.get(), ACL_FLOAT, sizeof(float),
+                dequant_buffer_allocator.get(), ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
                 dequant_ne, dequant_nb, GGML_MAX_DIMS + 1);
             aclnn_mul(ctx, acl_weight_tensor, acl_scale_tensor, dequant_tensor);
-
-            dequant_nb[0] = sizeof(float);
+            dequant_nb[0] = ggml_type_size(dst->type);
             dequant_ne = src0->ne;
             for (int i = 1; i < GGML_MAX_DIMS; i++) {
                 dequant_nb[i] = dequant_nb[i - 1] * src0->ne[i - 1];
             }
-
             aclnn_index_select_4d(ctx, dequant_buffer_allocator.get(),
                                   dequant_ne, dequant_nb,
                                   dst->data, dst->ne, dst->nb,
                                   src1, dst->type);
-            ggml_cann_release_resources(ctx, dequant_tensor);
+            ggml_cann_release_resources(ctx, acl_weight_tensor, acl_scale_tensor, dequant_tensor);
             break;
         }
         default:
@@ -1965,16 +1961,8 @@ static void ggml_cann_mat_mul_fp(ggml_backend_cann_context& ctx,
     // Only check env once.
     static bool weight_to_nz = parse_bool(get_env("GGML_CANN_WEIGHT_NZ").value_or("on"));
     if (weight_to_nz && is_matmul_weight(weight)) {
-        int64_t acl_stride[2] = {1, transpose_ne[1]};
-
-        // Reverse ne.
-        std::reverse(transpose_ne, transpose_ne + n_dims);
-
-        std::vector<int64_t> storageDims = {transpose_ne[0], transpose_ne[1]};
-        acl_weight_tensor = aclCreateTensor(
-            transpose_ne, n_dims, ggml_cann_type_mapping(weight->type), acl_stride,
-            0, ACL_FORMAT_FRACTAL_NZ, storageDims.data(), 2, weight->data);
+        acl_weight_tensor =
+            ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims, ACL_FORMAT_FRACTAL_NZ);
     } else {
         acl_weight_tensor =
             ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims, ACL_FORMAT_ND);
@@ -3178,7 +3166,6 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
         aclTensor* acl_src0_f16_tensor = nullptr;
         aclTensor* acl_src1_f16_tensor = nullptr;
         aclTensor* acl_src2_f16_tensor = nullptr;
-        aclTensor* acl_dst_f16_tensor = nullptr;
 
         // Step 1: cast the src0 (Query) to fp16 if needed
         ggml_cann_pool_alloc src0_f16_allocator(ctx.pool());
@@ -3216,22 +3203,6 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
         acl_src2_f16_tensor = ggml_cann_create_tensor(src2, src2_bsnd_ne,
                                                       src2_bsnd_nb, GGML_MAX_DIMS);
 
-        ggml_cann_pool_alloc out_f16_allocator(ctx.pool());
-        void* out_f16_buffer = out_f16_allocator.alloc(
-            ggml_nelements(dst) * faElemSize);
-
-        int64_t* out_f16_ne = src0_bsnd_ne;
-        size_t out_f16_nb[GGML_MAX_DIMS];
-        out_f16_nb[0] = faElemSize;
-        for(int i = 1; i < GGML_MAX_DIMS; ++i){
-            out_f16_nb[i] = out_f16_nb[i - 1] * out_f16_ne[i - 1];
-        }
-
-        acl_dst_f16_tensor = ggml_cann_create_tensor(
-            out_f16_buffer, faDataType, faElemSize,
-            out_f16_ne, out_f16_nb, GGML_MAX_DIMS
-        );
-
         // Step 3: create the PSEShift tensor if needed
         // this tensor is considered as mask (f16) in the llama.cpp
         aclTensor* bcast_pse_tensor = nullptr;
@@ -3317,8 +3288,8 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
         aclTensor* acl_q_tensor = acl_src0_f16_tensor;
         aclTensor* acl_k_tensors[] = {acl_src1_f16_tensor};
         aclTensor* acl_v_tensors[] = {acl_src2_f16_tensor};
-        auto acl_k_tensor_list = aclCreateTensorList(acl_k_tensors, kvTensorNum);
-        auto acl_v_tensor_list = aclCreateTensorList(acl_v_tensors, kvTensorNum);
+        aclTensorList* acl_k_tensor_list = aclCreateTensorList(acl_k_tensors, kvTensorNum);
+        aclTensorList* acl_v_tensor_list = aclCreateTensorList(acl_v_tensors, kvTensorNum);
 
         int64_t numHeads = src0->ne[2];  // N
         int64_t numKeyValueHeads = src1->ne[2];
@@ -3334,8 +3305,29 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
         int64_t keyAntiquantMode = 0;
         int64_t valueAntiquantMode = 0;
 
-        // Step 5: launch the FusedInferAttentionScoreV2 kernel.
-        // Refer to https://gitee.com/ascend/cann-ops-adv/blob/master/docs/FusedInferAttentionScoreV2.md
+        GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
+        aclTensor * fa_dst_tensor = nullptr;
+        aclTensor * acl_dst_tensor = nullptr;
+        ggml_cann_pool_alloc out_f16_allocator(ctx.pool());
+        if (dst->type == GGML_TYPE_F32) {
+            void* out_f16_buffer = out_f16_allocator.alloc(
+                ggml_nelements(dst) * faElemSize);
+
+            int64_t* out_f16_ne = src0_bsnd_ne;
+            size_t out_f16_nb[GGML_MAX_DIMS];
+            out_f16_nb[0] = faElemSize;
+            for(int i = 1; i < GGML_MAX_DIMS; ++i){
+                out_f16_nb[i] = out_f16_nb[i - 1] * out_f16_ne[i - 1];
+            }
+
+            fa_dst_tensor = ggml_cann_create_tensor(
+                out_f16_buffer, faDataType, faElemSize,
+                out_f16_ne, out_f16_nb, GGML_MAX_DIMS
+            );
+        }
+        else {
+            fa_dst_tensor = ggml_cann_create_tensor(dst);
+        }
 
         GGML_CANN_CALL_ACLNN_OP(ctx, FusedInferAttentionScoreV2,
             acl_q_tensor, acl_k_tensor_list, acl_v_tensor_list,  // q, k, v
@@ -3357,23 +3349,24 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
             blockSize, antiquantMode,  // blockSize, antiquantMode
             softmaxLseFlag,  // softmaxLseFlag
             keyAntiquantMode, valueAntiquantMode,  // keyAntiqMode, valueAntiqMode
-            acl_dst_f16_tensor,  // attentionOut
+            fa_dst_tensor,  // attentionOut
             nullptr  // softmaxLse
         );
 
-        // Step 6: post-processing, permute and cast to f32
-        aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst);
-        // TODO: when dst is fp16, don't need cast
-        aclnn_cast(ctx, acl_dst_f16_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type));
-        ggml_cann_release_resources(ctx, acl_src0_f16_tensor,
-                                    acl_src1_f16_tensor,
-                                    acl_src2_f16_tensor,
-                                    acl_dst_f16_tensor,
-                                    acl_dst_tensor);
-        if(src3 != nullptr){
-            ggml_cann_release_resources(ctx, bcast_pse_tensor);
-        }
-    }else{
+        if (dst->type == GGML_TYPE_F32) {
+            // Step 6: post-processing, permute and cast to f32
+            acl_dst_tensor = ggml_cann_create_tensor(dst);
+            aclnn_cast(ctx, fa_dst_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type));
+        }
+
+        ggml_cann_release_resources(ctx, acl_src0_f16_tensor,
+                                    acl_k_tensor_list,
+                                    acl_v_tensor_list,
+                                    fa_dst_tensor,
+                                    acl_dst_tensor,
+                                    bcast_pse_tensor);
+    } else {
         GGML_ABORT("Function is not implemented.");
     }
 }

View File

@@ -68,7 +68,7 @@ struct ggml_compute_params {
 #endif // __VXE2__
 #endif // __s390x__ && __VEC__
 
-#if defined(__ARM_FEATURE_SVE)
+#if defined(__ARM_FEATURE_SVE) && defined(__linux__)
 #include <sys/prctl.h>
 #endif

View File

@@ -689,8 +689,13 @@ bool ggml_is_numa(void) {
 #endif
 
 static void ggml_init_arm_arch_features(void) {
-#if defined(__linux__) && defined(__aarch64__) && defined(__ARM_FEATURE_SVE)
+#if defined(__aarch64__) && defined(__ARM_FEATURE_SVE)
+#if defined(__linux__)
     ggml_arm_arch_features.sve_cnt = PR_SVE_VL_LEN_MASK & prctl(PR_SVE_GET_VL);
+#else
+// TODO: add support of SVE for non-linux systems
+#error "TODO: SVE is not supported on this platform. To use SVE, sve_cnt needs to be initialized here."
+#endif
 #endif
 }

View File

@@ -463,9 +463,9 @@ ggml_float ggml_vec_cvar_f32(const int n, float * y, const float * x, const floa
 #endif
     for (; i < n; ++i) {
         float val = x[i] - mean;
-        y[i] = val;
         val *= val;
         sum += (ggml_float)val;
+        y[i] = val;
     }
     return sum/n;
 }
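For clarity, a standalone sketch of what the reordered scalar loop now computes (an approximation; the real ggml_vec_cvar_f32 also has SIMD paths and uses the ggml_float accumulator type). Note that after this change y[i] holds the squared centered value rather than the centered value itself:

    // Sketch only: scalar reference for the tail loop above.
    static double vec_cvar_tail_ref(int n, float * y, const float * x, float mean) {
        double sum = 0.0;
        for (int i = 0; i < n; ++i) {
            float val = x[i] - mean;   // center
            val *= val;                // square
            sum += (double) val;       // accumulate squared deviation
            y[i] = val;                // store the squared value (previously: x[i] - mean)
        }
        return sum / n;                // centered variance
    }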

View File

@@ -941,13 +941,6 @@ struct ggml_cuda_graph {
     std::vector<cudaGraphNode_t> nodes;
     std::vector<cudaKernelNodeParams> params;
     std::vector<ggml_graph_node_properties> ggml_graph_properties;
-    bool use_cpy_indirection = false;
-    std::vector<char *> cpy_dest_ptrs;
-    char ** dest_ptrs_d;
-    int dest_ptrs_size = 0;
-    // Index to allow each cpy kernel to be aware of it's position within the graph
-    // relative to other cpy nodes.
-    int graph_cpynode_index = -1;
 #endif
 };

View File

@@ -8,18 +8,16 @@
 typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
 
 template <cpy_kernel_t cpy_1>
-static __global__ void cpy_flt(const char * cx, char * cdst_direct, const int ne,
+static __global__ void cpy_flt(const char * cx, char * cdst, const int ne,
                                const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
                                const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
-                               const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
+                               const int nb12, const int nb13) {
     const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= ne) {
         return;
     }
 
-    char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index]: cdst_direct;
-
     // determine indices i03/i13, i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor
     // then combine those indices with the corresponding byte offsets to get the total offsets
     const int64_t i03 = i/(ne00 * ne01 * ne02);
@@ -63,18 +61,16 @@ static __device__ void cpy_blck_q_f32(const char * cxi, char * cdsti) {
 }
 
 template <cpy_kernel_t cpy_blck, int qk>
-static __global__ void cpy_f32_q(const char * cx, char * cdst_direct, const int ne,
+static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
                                  const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
                                  const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
-                                 const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
+                                 const int nb12, const int nb13) {
     const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;
 
     if (i >= ne) {
         return;
     }
 
-    char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index]: cdst_direct;
-
     const int i03 = i/(ne00 * ne01 * ne02);
     const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
     const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
@@ -91,18 +87,16 @@ static __global__ void cpy_f32_q(const char * cx, char * cdst_direct, const int
 }
 
 template <cpy_kernel_t cpy_blck, int qk>
-static __global__ void cpy_q_f32(const char * cx, char * cdst_direct, const int ne,
+static __global__ void cpy_q_f32(const char * cx, char * cdst, const int ne,
                                  const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
                                  const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
-                                 const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
+                                 const int nb12, const int nb13) {
     const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;
 
     if (i >= ne) {
         return;
     }
 
-    char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index]: cdst_direct;
-
     const int i03 = i/(ne00 * ne01 * ne02);
     const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
     const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
@@ -118,67 +112,47 @@ static __global__ void cpy_q_f32(const char * cx, char * cdst_direct, const int
     cpy_blck(cx + x_offset, cdst + dst_offset);
 }
 
-// Copy destination pointers to GPU to be available when pointer indirection is in use
-void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_dest_ptrs, const int host_dest_ptrs_size, cudaStream_t stream) {
-#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS)
-    if (cuda_graph->dest_ptrs_size < host_dest_ptrs_size) { // (re-)allocate GPU memory for destination pointers
-        CUDA_CHECK(cudaStreamSynchronize(stream));
-        if (cuda_graph->dest_ptrs_d != nullptr) {
-            CUDA_CHECK(cudaFree(cuda_graph->dest_ptrs_d));
-        }
-        CUDA_CHECK(cudaMalloc(&cuda_graph->dest_ptrs_d, host_dest_ptrs_size*sizeof(char *)));
-        cuda_graph->dest_ptrs_size = host_dest_ptrs_size;
-    }
-    // copy destination pointers to GPU
-    CUDA_CHECK(cudaMemcpyAsync(cuda_graph->dest_ptrs_d, host_dest_ptrs, host_dest_ptrs_size*sizeof(char *), cudaMemcpyHostToDevice, stream));
-    cuda_graph->graph_cpynode_index = 0; // reset index
-#else
-    GGML_UNUSED_VARS(cuda_graph, host_dest_ptrs, host_dest_ptrs_size, stream);
-#endif
-}
-
 template<typename src_t, typename dst_t>
 static void ggml_cpy_flt_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
 
     const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
     cpy_flt<cpy_1_flt<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_f32_q8_0_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
 
     GGML_ASSERT(ne % QK8_0 == 0);
     const int num_blocks = ne / QK8_0;
     cpy_f32_q<cpy_blck_f32_q8_0, QK8_0><<<num_blocks, 1, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_q8_0_f32_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
 
     const int num_blocks = ne;
     cpy_q_f32<cpy_blck_q8_0_f32, QK8_0><<<num_blocks, 1, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_f32_q4_0_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
 
     GGML_ASSERT(ne % QK4_0 == 0);
     const int num_blocks = ne / QK4_0;
     cpy_f32_q<cpy_blck_f32_q4_0, QK4_0><<<num_blocks, 1, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_q4_0_f32_cuda(
@@ -187,22 +161,22 @@ static void ggml_cpy_q4_0_f32_cuda(
     const int nb00, const int nb01, const int nb02,
     const int nb03, const int ne10, const int ne11, const int ne12,
     const int nb10, const int nb11, const int nb12, const int nb13,
-    cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+    cudaStream_t stream) {
 
     const int num_blocks = ne;
     cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0><<<num_blocks, 1, 0, stream>>>(
         cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
-        ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
+        ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_f32_q4_1_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
 
     GGML_ASSERT(ne % QK4_1 == 0);
     const int num_blocks = ne / QK4_1;
     cpy_f32_q<cpy_blck_f32_q4_1, QK4_1><<<num_blocks, 1, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_q4_1_f32_cuda(
@@ -211,22 +185,22 @@ static void ggml_cpy_q4_1_f32_cuda(
     const int nb00, const int nb01, const int nb02,
     const int nb03, const int ne10, const int ne11, const int ne12,
     const int nb10, const int nb11, const int nb12, const int nb13,
-    cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+    cudaStream_t stream) {
 
     const int num_blocks = ne;
     cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1><<<num_blocks, 1, 0, stream>>>(
         cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
-        ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
+        ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_f32_q5_0_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
 
     GGML_ASSERT(ne % QK5_0 == 0);
     const int num_blocks = ne / QK5_0;
     cpy_f32_q<cpy_blck_f32_q5_0, QK5_0><<<num_blocks, 1, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_q5_0_f32_cuda(
@@ -235,22 +209,22 @@ static void ggml_cpy_q5_0_f32_cuda(
     const int nb00, const int nb01, const int nb02,
     const int nb03, const int ne10, const int ne11, const int ne12,
     const int nb10, const int nb11, const int nb12, const int nb13,
-    cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+    cudaStream_t stream) {
 
     const int num_blocks = ne;
     cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0><<<num_blocks, 1, 0, stream>>>(
         cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
-        ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
+        ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_f32_q5_1_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
 
     GGML_ASSERT(ne % QK5_1 == 0);
     const int num_blocks = ne / QK5_1;
     cpy_f32_q<cpy_blck_f32_q5_1, QK5_1><<<num_blocks, 1, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_q5_1_f32_cuda(
@@ -259,25 +233,25 @@ static void ggml_cpy_q5_1_f32_cuda(
     const int nb00, const int nb01, const int nb02,
     const int nb03, const int ne10, const int ne11, const int ne12,
     const int nb10, const int nb11, const int nb12, const int nb13,
-    cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+    cudaStream_t stream) {
 
     const int num_blocks = ne;
     cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1><<<num_blocks, 1, 0, stream>>>(
         cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
-        ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
+        ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_f32_iq4_nl_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
 
     GGML_ASSERT(ne % QK4_NL == 0);
     const int num_blocks = ne / QK4_NL;
     cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL><<<num_blocks, 1, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
-void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, ggml_cuda_graph * cuda_graph, const ggml_tensor * src0, ggml_tensor * src1, bool disable_indirection_for_this_node) {
+void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1) {
     const int64_t ne = ggml_nelements(src0);
     GGML_ASSERT(ne == ggml_nelements(src1));
@@ -311,16 +285,6 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, ggml_cuda_graph * cuda_graph
     char * src0_ddc = (char *) src0->data;
     char * src1_ddc = (char *) src1->data;
 
-    char ** dest_ptrs_d = nullptr;
-    int graph_cpynode_index = -1;
-#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS)
-    if (cuda_graph && cuda_graph->use_cpy_indirection && !disable_indirection_for_this_node) {
-        dest_ptrs_d = cuda_graph->dest_ptrs_d;
-        graph_cpynode_index = cuda_graph->graph_cpynode_index;
-    }
-#else
-    GGML_UNUSED(disable_indirection_for_this_node);
-#endif
-
     if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
         GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
 #if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY)
@ -329,134 +293,62 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, ggml_cuda_graph * cuda_graph
} else } else
#endif // GGML_USE_MUSA && GGML_MUSA_MUDNN_COPY #endif // GGML_USE_MUSA && GGML_MUSA_MUDNN_COPY
{ {
if (src0->type == GGML_TYPE_F32) {
ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
} else {
CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream)); CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
} }
}
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) { } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
ggml_cpy_flt_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); ggml_cpy_flt_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) { } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
ggml_cpy_flt_cuda<float, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); ggml_cpy_flt_cuda<float, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) { } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) { } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) { } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) { } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
ggml_cpy_q4_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, ggml_cpy_q4_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) { } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) { } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
ggml_cpy_q4_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, ggml_cpy_q4_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) { } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) { } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
ggml_cpy_q5_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, ggml_cpy_q5_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) { } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) { } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) { } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) { } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) { } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
ggml_cpy_flt_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); ggml_cpy_flt_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) { } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
ggml_cpy_flt_cuda<half, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); ggml_cpy_flt_cuda<half, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) { } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) { } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
ggml_cpy_flt_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); ggml_cpy_flt_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) { } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
ggml_cpy_flt_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); ggml_cpy_flt_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) { } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) {
ggml_cpy_flt_cuda<float, int32_t> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); ggml_cpy_flt_cuda<float, int32_t> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) { } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) {
ggml_cpy_flt_cuda<int32_t, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); ggml_cpy_flt_cuda<int32_t, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else { } else {
GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__, GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
ggml_type_name(src0->type), ggml_type_name(src1->type)); ggml_type_name(src0->type), ggml_type_name(src1->type));
} }
#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS)
if (cuda_graph && cuda_graph->use_cpy_indirection && !disable_indirection_for_this_node) {
cuda_graph->graph_cpynode_index = graph_cpynode_index;
}
#else
GGML_UNUSED(disable_indirection_for_this_node);
#endif
} }
void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_cuda_graph * cuda_graph, ggml_tensor * dst) { void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src0 = dst->src[0];
bool disable_indirection = true; ggml_cuda_cpy(ctx, src0, dst);
ggml_cuda_cpy(ctx, cuda_graph, src0, dst, disable_indirection);
}
void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
// Prioritize CUDA graph compatibility over direct memory copy optimization.
// Using copy kernels here maintains graph indirection support, preventing performance regression from disabled CUDA graphs.
if (src0->type == GGML_TYPE_F32) {
return (void*) cpy_flt<cpy_1_flt<float, float>>;
} else {
return nullptr;
}
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
return (void*) cpy_flt<cpy_1_flt<float, float>>;
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
return (void*) cpy_flt<cpy_1_flt<float, nv_bfloat16>>;
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
return (void*) cpy_flt<cpy_1_flt<float, half>>;
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
return (void*) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
} else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
return (void*) cpy_q_f32<cpy_blck_q8_0_f32, QK8_0>;
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
return (void*) cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>;
} else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0>;
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
return (void*) cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>;
} else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1>;
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
return (void*) cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>;
} else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0>;
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
return (void*) cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>;
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
return (void*) cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>;
} else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1>;
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
return (void*) cpy_flt<cpy_1_flt<half, half>>;
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
return (void*) cpy_flt<cpy_1_flt<half, nv_bfloat16>>;
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
return (void*) cpy_flt<cpy_1_flt<half, float>>;
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, half>>;
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, nv_bfloat16>>;
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, float>>;
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) {
return (void*) cpy_flt<cpy_1_flt<float, int32_t>>;
} else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) {
return (void*) cpy_flt<cpy_1_flt<int32_t, float>>;
} else {
GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
ggml_type_name(src0->type), ggml_type_name(src1->type));
}
} }
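// Context for this removal (visible in the graph-compatibility hunk further
// down): the returned kernel pointer was only used to identify CPY nodes whose
// destination pointers were patched via indirection inside captured CUDA
// graphs; with use_cpy_indirection gone, the lookup has no remaining caller.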

View File

@ -2,10 +2,6 @@
#define CUDA_CPY_BLOCK_SIZE 64 #define CUDA_CPY_BLOCK_SIZE 64
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, ggml_cuda_graph * cuda_graph, const ggml_tensor * src0, ggml_tensor * src1, bool disable_indirection = false); void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1);
void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_cuda_graph * cuda_graph, ggml_tensor * dst); void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1);
void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_dest_ptrs, const int host_dest_ptrs_size, cudaStream_t stream);

View File

@ -540,11 +540,13 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] = logit_softcap * tanhf(KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]); KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] = logit_softcap * tanhf(KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]);
} }
KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] += (ncols2 > 1 || mask) && (!oob_check || i_KQ < k_VKQ_sup) ? if (!oob_check || i_KQ < k_VKQ_sup) {
KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] += (ncols2 > 1 || mask) ?
slope*__half2float(mask[j*stride_mask + k_VKQ_0 + i_KQ]) : 0.0f; slope*__half2float(mask[j*stride_mask + k_VKQ_0 + i_KQ]) : 0.0f;
KQ_max_new[jc0] = fmaxf(KQ_max_new[jc0], KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]); KQ_max_new[jc0] = fmaxf(KQ_max_new[jc0], KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]);
} }
}
KQ_max_new[jc0] = warp_reduce_max<warp_size>(KQ_max_new[jc0]); KQ_max_new[jc0] = warp_reduce_max<warp_size>(KQ_max_new[jc0]);
} }
@ -581,10 +583,9 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
float KQ_sum_add = 0.0f; float KQ_sum_add = 0.0f;
#pragma unroll #pragma unroll
for (int i0 = 0; i0 < nbatch_fa; i0 += np*warp_size) { for (int i0 = 0; i0 < nbatch_fa; i0 += np*warp_size) {
const float val = expf(KQ_acc[(i0/(np*warp_size))*cpw + jc] - KQ_max[jc]); const float val = !oob_check || i0 + (threadIdx.y % np)*warp_size + threadIdx.x < k_VKQ_sup ?
if (!oob_check || i0 + (threadIdx.y % np)*warp_size + threadIdx.x < k_VKQ_sup) { expf(KQ_acc[(i0/(np*warp_size))*cpw + jc] - KQ_max[jc]) : 0.0f;
KQ_sum_add += val; KQ_sum_add += val;
}
tmp[i0/(np*warp_size)][jc1] = val; tmp[i0/(np*warp_size)][jc1] = val;
} }
KQ_sum[jc] = KQ_sum[jc]*KQ_max_scale + KQ_sum_add; KQ_sum[jc] = KQ_sum[jc]*KQ_max_scale + KQ_sum_add;
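// Sketch of the online-softmax bookkeeping above, assuming (as elsewhere in
// this kernel) KQ_max_scale == expf(KQ_max_old - KQ_max_new): when the running
// maximum grows from m to m', the accumulated sum over expf(KQ - m) must be
// rescaled by expf(m - m') before the new tile's terms are added, so that
// KQ_sum stays equal to the sum of expf(KQ - m') over all columns seen so far.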
@ -975,26 +976,6 @@ static __global__ void flash_attn_tile(
} }
} }
if (gridDim.y == 1) {
#pragma unroll
for (int jc0 = 0; jc0 < cpw; ++jc0) {
#ifdef FAST_FP16_AVAILABLE
const half2 KQ_sum_jc_inv = make_half2(1.0f/KQ_sum[jc0], 1.0f/KQ_sum[jc0]);
#pragma unroll
for (int i = 0; i < (DVp/2)/warp_size; ++i) {
VKQ[jc0*((DVp/2)/warp_size) + i] *= KQ_sum_jc_inv;
}
#else
const float KQ_sum_jc_inv = 1.0f/KQ_sum[jc0];
#pragma unroll
for (int i = 0; i < (DVp/2)/warp_size; ++i) {
VKQ[jc0*((DVp/2)/warp_size) + i].x *= KQ_sum_jc_inv;
VKQ[jc0*((DVp/2)/warp_size) + i].y *= KQ_sum_jc_inv;
}
#endif // FAST_FP16_AVAILABLE
}
}
// Write back results: // Write back results:
#pragma unroll #pragma unroll
for (int jc0 = 0; jc0 < cpw; ++jc0) { for (int jc0 = 0; jc0 < cpw; ++jc0) {
@ -1007,6 +988,8 @@ static __global__ void flash_attn_tile(
return; return;
} }
const float scale = gridDim.y == 1 ? 1.0f/KQ_sum[jc0] : 1.0f;
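// When gridDim.y > 1 each block now writes an unnormalized partial result
// (scale = 1.0f); dividing by KQ_sum is deferred to the pass that combines the
// per-block partials, which is why only the gridDim.y == 1 case scales here.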
const int j_dst_unrolled = ((sequence*ne01 + col_Q_0 + j)*ne02 + head0 + c)*gridDim.y + blockIdx.y; const int j_dst_unrolled = ((sequence*ne01 + col_Q_0 + j)*ne02 + head0 + c)*gridDim.y + blockIdx.y;
#ifdef FAST_FP16_AVAILABLE #ifdef FAST_FP16_AVAILABLE
@ -1017,6 +1000,8 @@ static __global__ void flash_attn_tile(
#pragma unroll #pragma unroll
for (int i1 = 0; i1 < cpy_ne_D; ++i1) { for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
tmp[i1] = __half22float2(VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size + i1]); tmp[i1] = __half22float2(VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size + i1]);
tmp[i1].x *= scale;
tmp[i1].y *= scale;
} }
if (i0 + warp_size*cpy_ne_D <= DV/2 || i0 + threadIdx.x*cpy_ne_D < DV/2) { if (i0 + warp_size*cpy_ne_D <= DV/2 || i0 + threadIdx.x*cpy_ne_D < DV/2) {
ggml_cuda_memcpy_1<sizeof(tmp)>(&dst[j_dst_unrolled*DV + 2*i0 + threadIdx.x*(2*cpy_ne_D)], tmp); ggml_cuda_memcpy_1<sizeof(tmp)>(&dst[j_dst_unrolled*DV + 2*i0 + threadIdx.x*(2*cpy_ne_D)], tmp);
@ -1027,6 +1012,11 @@ static __global__ void flash_attn_tile(
#pragma unroll #pragma unroll
for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) { for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) {
if (i0 + warp_size*cpy_ne_D <= DV || i0 + threadIdx.x*cpy_ne_D < DV) { if (i0 + warp_size*cpy_ne_D <= DV || i0 + threadIdx.x*cpy_ne_D < DV) {
#pragma unroll
for (int i1 = 0; i1 < cpy_ne_D/2; ++i1) {
VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size) + i1].x *= scale;
VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size) + i1].y *= scale;
}
ggml_cuda_memcpy_1<cpy_ne_D*4>( ggml_cuda_memcpy_1<cpy_ne_D*4>(
&dst[j_dst_unrolled*DV + i0 + threadIdx.x*cpy_ne_D], &dst[j_dst_unrolled*DV + i0 + threadIdx.x*cpy_ne_D],
&VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size)]); &VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size)]);

View File

@ -516,8 +516,8 @@ void ggml_cuda_flash_attn_ext_vec_case_impl(ggml_backend_cuda_context & ctx, ggm
const int nthreads = ggml_cuda_fattn_vec_get_nthreads_host(cc); const int nthreads = ggml_cuda_fattn_vec_get_nthreads_host(cc);
const int nwarps = nthreads / WARP_SIZE; const int nwarps = nthreads / WARP_SIZE;
fattn_kernel_t fattn_kernel = flash_attn_ext_vec<D, cols_per_block, type_K, type_V, use_logit_softcap>; fattn_kernel_t fattn_kernel = flash_attn_ext_vec<D, cols_per_block, type_K, type_V, use_logit_softcap>;
constexpr bool need_f16_K = false; const bool need_f16_K = type_K == GGML_TYPE_F16;
constexpr bool need_f16_V = false; const bool need_f16_V = type_V == GGML_TYPE_F16;
constexpr size_t nbytes_shared = 0; constexpr size_t nbytes_shared = 0;
launch_fattn<D, cols_per_block, 1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false); launch_fattn<D, cols_per_block, 1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false);
} }
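// need_f16_K/V now request an on-the-fly conversion only when the kernel
// instance is typed F16 but the tensor arrives as F32; quantized K/V are still
// consumed directly, which is why these flags used to be constexpr false for
// every vec case.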
@ -526,11 +526,6 @@ template <int D, ggml_type type_K, ggml_type type_V>
void ggml_cuda_flash_attn_ext_vec_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { void ggml_cuda_flash_attn_ext_vec_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * KQV = dst; const ggml_tensor * KQV = dst;
const ggml_tensor * Q = dst->src[0]; const ggml_tensor * Q = dst->src[0];
const ggml_tensor * K = dst->src[1];
const ggml_tensor * V = dst->src[2];
GGML_ASSERT(K->type == type_K);
GGML_ASSERT(V->type == type_V);
float logit_softcap; float logit_softcap;
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));

View File

@ -117,10 +117,14 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
} }
#define FATTN_VEC_CASE(D, type_K, type_V) \ #define FATTN_VEC_CASE(D, type_K, type_V) \
if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) { \ { \
const bool type_K_okay = K->type == (type_K) || (K->type == GGML_TYPE_F32 && (type_K) == GGML_TYPE_F16); \
const bool type_V_okay = V->type == (type_V) || (V->type == GGML_TYPE_F32 && (type_V) == GGML_TYPE_F16); \
if (Q->ne[0] == (D) && type_K_okay && type_V_okay) { \
ggml_cuda_flash_attn_ext_vec_case<D, type_K, type_V>(ctx, dst); \ ggml_cuda_flash_attn_ext_vec_case<D, type_K, type_V>(ctx, dst); \
return; \ return; \
} \ } \
} \
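// Illustration: FATTN_VEC_CASE(64, GGML_TYPE_F16, GGML_TYPE_F16) now also
// fires when K or V comes in as GGML_TYPE_F32, pairing the F32 tensor with the
// existing F16 kernel instance (converted at launch, per the vec launcher
// change above).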
#define FATTN_VEC_CASES_ALL_D(type_K, type_V) \ #define FATTN_VEC_CASES_ALL_D(type_K, type_V) \
FATTN_VEC_CASE( 64, type_K, type_V) \ FATTN_VEC_CASE( 64, type_K, type_V) \
@ -247,6 +251,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
#endif // GGML_CUDA_FA_ALL_QUANTS #endif // GGML_CUDA_FA_ALL_QUANTS
switch (K->type) { switch (K->type) {
case GGML_TYPE_F32:
case GGML_TYPE_F16: case GGML_TYPE_F16:
break; break;
case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_1:
@ -272,7 +277,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
// If Turing tensor cores available, use them: // If Turing tensor cores available, use them:
if (turing_mma_available(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40) { if (turing_mma_available(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40) {
if (can_use_vector_kernel) { if (can_use_vector_kernel) {
if (K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16) { if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && K->ne[1] >= 8192)) { if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && K->ne[1] >= 8192)) {
return BEST_FATTN_KERNEL_VEC; return BEST_FATTN_KERNEL_VEC;
} }
@ -305,7 +310,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
// If there are no tensor cores available, use the generic tile kernel: // If there are no tensor cores available, use the generic tile kernel:
if (can_use_vector_kernel) { if (can_use_vector_kernel) {
if (K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16) { if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
if (Q->ne[1] == 1) { if (Q->ne[1] == 1) {
if (!gqa_opt_applies) { if (!gqa_opt_applies) {
return BEST_FATTN_KERNEL_VEC; return BEST_FATTN_KERNEL_VEC;

View File

@ -2260,13 +2260,13 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, ggml_cuda
ggml_cuda_op_set_rows(ctx, dst); ggml_cuda_op_set_rows(ctx, dst);
break; break;
case GGML_OP_DUP: case GGML_OP_DUP:
ggml_cuda_dup(ctx, cuda_graph, dst); ggml_cuda_dup(ctx, dst);
break; break;
case GGML_OP_CPY: case GGML_OP_CPY:
ggml_cuda_cpy(ctx, cuda_graph, dst->src[0], dst->src[1]); ggml_cuda_cpy(ctx, dst->src[0], dst->src[1]);
break; break;
case GGML_OP_CONT: case GGML_OP_CONT:
ggml_cuda_dup(ctx, cuda_graph, dst); ggml_cuda_dup(ctx, dst);
break; break;
case GGML_OP_ADD: case GGML_OP_ADD:
case GGML_OP_ADD1: // TODO: more efficient implementation case GGML_OP_ADD1: // TODO: more efficient implementation
@ -2633,11 +2633,10 @@ static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
} }
#ifdef USE_CUDA_GRAPH #ifdef USE_CUDA_GRAPH
static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_cuda_graph * cuda_graph, const ggml_cgraph * cgraph, cudaStream_t stream, static bool check_node_graph_compatibility(const ggml_cgraph * cgraph,
bool use_cuda_graph) { bool use_cuda_graph) {
// Loop over nodes in GGML graph to obtain info needed for CUDA graph // Loop over nodes in GGML graph to obtain info needed for CUDA graph
cuda_graph->cpy_dest_ptrs.clear();
const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected"; const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected";
const std::string gemma3n_per_layer_proj_src1_name = "per_layer_proj"; const std::string gemma3n_per_layer_proj_src1_name = "per_layer_proj";
@ -2688,33 +2687,11 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_cuda_graph
#endif #endif
} }
if (node->op == GGML_OP_CPY) {
// Store the pointers which are updated for each token, such that these can be sent
// to the device and accessed using indirection from CUDA graph
cuda_graph->cpy_dest_ptrs.push_back((char *) node->src[1]->data);
// store a pointer to each copy op CUDA kernel to identify it later
void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
if (!ptr) {
use_cuda_graph = false;
#ifndef NDEBUG
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported copy op\n", __func__);
#endif
}
}
if (!use_cuda_graph) { if (!use_cuda_graph) {
break; break;
} }
} }
if (use_cuda_graph) {
cuda_graph->use_cpy_indirection = true;
// copy pointers to GPU so they can be accessed via indirection within CUDA graph
ggml_cuda_cpy_dest_ptrs_copy(cuda_graph, cuda_graph->cpy_dest_ptrs.data(), cuda_graph->cpy_dest_ptrs.size(), stream);
}
return use_cuda_graph; return use_cuda_graph;
} }
@ -2733,7 +2710,6 @@ static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_p
static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) { static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
if (node->data != graph_node_properties->node_address && if (node->data != graph_node_properties->node_address &&
node->op != GGML_OP_CPY &&
node->op != GGML_OP_VIEW) { node->op != GGML_OP_VIEW) {
return false; return false;
} }
@ -2754,7 +2730,6 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
for (int i = 0; i < GGML_MAX_SRC; i++) { for (int i = 0; i < GGML_MAX_SRC; i++) {
if (node->src[i] && if (node->src[i] &&
node->src[i]->data != graph_node_properties->src_address[i] && node->src[i]->data != graph_node_properties->src_address[i] &&
node->op != GGML_OP_CPY &&
node->op != GGML_OP_VIEW node->op != GGML_OP_VIEW
) { ) {
return false; return false;
@ -2901,7 +2876,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
} }
// if rms_norm is the B operand, then we don't handle broadcast // if rms_norm is the B operand, then we don't handle broadcast
if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm->src[1])) { if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm)) {
return false; return false;
} }
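// RMS_NORM is a unary op, so rms_norm->src[1] is a null pointer; the shape
// test has to run against the norm's own output, which is the tensor the
// fused kernel would actually broadcast against.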
@ -3105,13 +3080,11 @@ static ggml_backend_graph_plan_t ggml_backend_cuda_graph_plan_create(ggml_backen
} }
if (use_cuda_graph) { if (use_cuda_graph) {
use_cuda_graph = check_node_graph_compatibility_and_refresh_copy_ops(cuda_graph, cgraph, cuda_ctx->stream(), use_cuda_graph); use_cuda_graph = check_node_graph_compatibility(cgraph, use_cuda_graph);
} }
if (use_cuda_graph) { if (use_cuda_graph) {
capture_cuda_graph(cuda_ctx, cuda_graph, cgraph); capture_cuda_graph(cuda_ctx, cuda_graph, cgraph);
} else {
cuda_graph->use_cpy_indirection = false;
} }
return cuda_graph; return cuda_graph;
@ -3138,7 +3111,7 @@ static void ggml_backend_cuda_graph_plan_update(ggml_backend_t backend, ggml_bac
bool use_cuda_graph = true; bool use_cuda_graph = true;
use_cuda_graph = check_node_graph_compatibility_and_refresh_copy_ops(cuda_graph, cgraph, cuda_ctx->stream(), use_cuda_graph); use_cuda_graph = check_node_graph_compatibility(cgraph, use_cuda_graph);
if (!use_cuda_graph) { if (!use_cuda_graph) {
if (cuda_graph->instance != nullptr) { if (cuda_graph->instance != nullptr) {
@ -3149,7 +3122,6 @@ static void ggml_backend_cuda_graph_plan_update(ggml_backend_t backend, ggml_bac
} }
cuda_graph->instance = nullptr; cuda_graph->instance = nullptr;
cuda_graph->graph = nullptr; cuda_graph->graph = nullptr;
cuda_graph->use_cpy_indirection = false;
} }
if (is_cuda_graph_update_required(cuda_graph, cgraph)) { if (is_cuda_graph_update_required(cuda_graph, cgraph)) {

View File

@ -1,5 +1,7 @@
#include "ggml.h" #include "ggml.h"
#include "mmf.cuh" #include "mmf.cuh"
#include "mmid.cuh"
void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) { void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
GGML_ASSERT( src1->type == GGML_TYPE_F32); GGML_ASSERT( src1->type == GGML_TYPE_F32);
@ -37,6 +39,12 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
const int64_t ids_s0 = ids ? ids->nb[0] / ggml_type_size(ids->type) : 0; const int64_t ids_s0 = ids ? ids->nb[0] / ggml_type_size(ids->type) : 0;
const int64_t ids_s1 = ids ? ids->nb[1] / ggml_type_size(ids->type) : 0; const int64_t ids_s1 = ids ? ids->nb[1] / ggml_type_size(ids->type) : 0;
mmf_ids_data ids_info{};
mmf_ids_data * ids_info_ptr = nullptr;
ggml_cuda_pool_alloc<int32_t> ids_src_compact_dev;
ggml_cuda_pool_alloc<int32_t> ids_dst_compact_dev;
ggml_cuda_pool_alloc<int32_t> expert_bounds_dev;
// For MUL_MAT_ID the memory layout is different than for MUL_MAT: // For MUL_MAT_ID the memory layout is different than for MUL_MAT:
const int64_t ncols_dst = ids ? ne2 : ne1; const int64_t ncols_dst = ids ? ne2 : ne1;
const int64_t nchannels_dst = ids ? ne1 : ne2; const int64_t nchannels_dst = ids ? ne1 : ne2;
@ -54,6 +62,33 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
nchannels_y = ids->ne[0]; nchannels_y = ids->ne[0];
} }
if (ids && ncols_dst > 16) {
const int64_t n_expert_used = ids->ne[0];
const int64_t n_experts = ne02;
const int64_t n_tokens = ne12;
const int64_t ne_get_rows = n_tokens * n_expert_used;
ids_src_compact_dev.alloc(ctx.pool(), ne_get_rows);
ids_dst_compact_dev.alloc(ctx.pool(), ne_get_rows);
expert_bounds_dev.alloc(ctx.pool(), n_experts + 1);
const int si1 = static_cast<int>(ids_s1);
const int sis1 = static_cast<int>(src1->nb[2] / src1->nb[1]);
GGML_ASSERT(sis1 > 0);
ggml_cuda_launch_mm_ids_helper(ids_d, ids_src_compact_dev.get(), ids_dst_compact_dev.get(), expert_bounds_dev.get(),
static_cast<int>(n_experts), static_cast<int>(n_tokens), static_cast<int>(n_expert_used), static_cast<int>(ne11), si1, sis1, ctx.stream());
CUDA_CHECK(cudaGetLastError());
ids_info.ids_src_compact = ids_src_compact_dev.get();
ids_info.ids_dst_compact = ids_dst_compact_dev.get();
ids_info.expert_bounds_dev = expert_bounds_dev.get();
ids_info.n_experts = static_cast<int>(n_experts);
ids_info.sis1 = sis1;
ids_info_ptr = &ids_info;
}
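// Data-flow sketch, following the contract documented in mmid.cu (new file
// below): ids_src_compact/ids_dst_compact list, sorted by expert, which
// flattened src1 column to read and which dst column to write, and
// expert_bounds[e]..expert_bounds[e+1] delimits expert e's slice, letting each
// block of the mul_mat_f_ids kernel handle one (expert, column-tile) pair.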
switch (src0->type) { switch (src0->type) {
case GGML_TYPE_F32: { case GGML_TYPE_F32: {
const float * src0_d = (const float *) src0->data; const float * src0_d = (const float *) src0->data;
@ -61,7 +96,7 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
mul_mat_f_switch_cols_per_block( mul_mat_f_switch_cols_per_block(
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst, src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream()); ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
} break; } break;
case GGML_TYPE_F16: { case GGML_TYPE_F16: {
const half2 * src0_d = (const half2 *) src0->data; const half2 * src0_d = (const half2 *) src0->data;
@ -69,7 +104,7 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
mul_mat_f_switch_cols_per_block( mul_mat_f_switch_cols_per_block(
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst, src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream()); ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
} break; } break;
case GGML_TYPE_BF16: { case GGML_TYPE_BF16: {
const nv_bfloat162 * src0_d = (const nv_bfloat162 *) src0->data; const nv_bfloat162 * src0_d = (const nv_bfloat162 *) src0->data;
@ -77,7 +112,7 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
mul_mat_f_switch_cols_per_block( mul_mat_f_switch_cols_per_block(
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst, src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream()); ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
} break; } break;
default: default:
GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type)); GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
@ -98,10 +133,9 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
} }
if (mul_mat_id) { if (mul_mat_id) {
if (type == GGML_TYPE_F32 && src1_ncols > 32) { if (src0_ne[1] <= 1024 && src1_ncols > 512) {
return false; return false;
} } else if (src0_ne[1] > 1024 && src1_ncols > 128) {
if ((type == GGML_TYPE_F16 || type == GGML_TYPE_BF16) && src1_ncols > 64) {
return false; return false;
} }
} else { } else {

View File

@ -7,6 +7,14 @@ using namespace ggml_cuda_mma;
#define MMF_ROWS_PER_BLOCK 32 #define MMF_ROWS_PER_BLOCK 32
struct mmf_ids_data {
const int32_t * ids_src_compact = nullptr;
const int32_t * ids_dst_compact = nullptr;
const int32_t * expert_bounds_dev = nullptr;
int n_experts = 0;
int sis1 = 0;
};
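// sis1 is the distance in src1 columns between consecutive tokens (set from
// src1->nb[2] / src1->nb[1] in mmf.cu); the kernel uses it with fastdiv to
// split a compact src1 index back into (token, channel).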
void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst); void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, const int src1_ncols, bool mul_mat_id); bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, const int src1_ncols, bool mul_mat_id);
@ -224,6 +232,250 @@ static __global__ void mul_mat_f(
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
} }
// This kernel is for larger batch sizes of mul_mat_id.
template <typename T, int rows_per_block, int cols_per_block, int nwarps>
__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
static __global__ void mul_mat_f_ids(
const T * __restrict__ x, const float * __restrict__ y,
const int32_t * __restrict__ ids_src_compact, const int32_t * __restrict__ ids_dst_compact,
const int32_t * __restrict__ expert_bounds, float * __restrict__ dst,
const int ncols, const int ncols_dst_total, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst,
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
const uint3 sis1_fd, const uint3 nch_fd) {
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
typedef tile<16, 8, T> tile_A;
typedef tile< 8, 8, T> tile_B;
typedef tile<16, 8, float> tile_C;
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
constexpr int tile_k_padded = warp_size + 4;
constexpr int ntA = rows_per_block / tile_A::I;
constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I;
const int row0 = blockIdx.x * rows_per_block;
const int expert_idx = blockIdx.y;
const int expert_start = expert_bounds[expert_idx];
const int expert_end = expert_bounds[expert_idx + 1];
const int ncols_expert = expert_end - expert_start;
const int tiles_for_expert = (ncols_expert + cols_per_block - 1) / cols_per_block;
const int tile_idx = blockIdx.z;
if (tile_idx >= tiles_for_expert) {
return;
}
const int col_base = tile_idx * cols_per_block;
GGML_UNUSED(channel_ratio);
const int channel_x = expert_idx;
const int sample_dst = 0;
const int sample_x = sample_dst / sample_ratio;
const int sample_y = sample_dst;
x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row0*stride_row;
y += int64_t(sample_y) *stride_sample_y;
dst += int64_t(sample_dst)*stride_sample_dst;
const int32_t * ids_src_expert = ids_src_compact + expert_start;
const int32_t * ids_dst_expert = ids_dst_compact + expert_start;
extern __shared__ char data_mmv[];
char * compute_base = data_mmv;
tile_C C[ntA][ntB];
T * tile_xy = (T *) compute_base + threadIdx.y*(tile_A::I * tile_k_padded);
for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) {
tile_A A[ntA][warp_size / tile_A::J];
#pragma unroll
for (int itA = 0; itA < ntA; ++itA) {
#pragma unroll
for (int i = 0; i < tile_A::I; ++i) {
tile_xy[i*tile_k_padded + threadIdx.x] = x[(itA*tile_A::I + i)*stride_row + col];
}
#pragma unroll
for (int k0 = 0; k0 < warp_size; k0 += tile_A::J) {
load_ldmatrix(A[itA][k0/tile_A::J], tile_xy + k0, tile_k_padded);
}
}
if constexpr (std::is_same_v<T, float>) {
float vals_buf[2][tile_B::I];
auto gather_tile = [&](int tile_idx_local, float *vals) {
#pragma unroll
for (int j0 = 0; j0 < tile_B::I; ++j0) {
const int j = j0 + tile_idx_local*tile_B::I;
const int global_j = col_base + j;
float val = 0.0f;
if (j < cols_per_block && global_j < ncols_expert) {
const int src_entry = ids_src_expert[global_j];
const uint2 qrm = fast_div_modulo((uint32_t) src_entry, sis1_fd);
const int token = (int) qrm.x;
const int channel = (int) qrm.y;
if (token < ncols_dst_total) {
val = y[channel*stride_channel_y + token*stride_col_y + col];
}
}
vals[j0] = val;
}
};
gather_tile(0, vals_buf[0]);
int curr_buf = 0;
int next_buf = 1;
#pragma unroll
for (int itB = 0; itB < ntB; ++itB) {
#pragma unroll
for (int j0 = 0; j0 < tile_B::I; ++j0) {
tile_xy[j0*tile_k_padded + threadIdx.x] = vals_buf[curr_buf][j0];
}
if (itB + 1 < ntB) {
gather_tile(itB + 1, vals_buf[next_buf]);
}
#pragma unroll
for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) {
tile_B B;
load_ldmatrix(B, tile_xy + k0, tile_k_padded);
#pragma unroll
for (int itA = 0; itA < ntA; ++itA) {
mma(C[itA][itB], A[itA][k0/tile_B::J], B);
}
}
if (itB + 1 < ntB) {
curr_buf ^= 1;
next_buf ^= 1;
}
}
} else if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
float2 vals_buf[2][tile_B::I];
auto gather_tile = [&](int tile_idx_local, float2 *vals) {
#pragma unroll
for (int j0 = 0; j0 < tile_B::I; ++j0) {
const int j = j0 + tile_idx_local*tile_B::I;
const int global_j = col_base + j;
float2 tmp = make_float2(0.0f, 0.0f);
if (j < cols_per_block && global_j < ncols_expert) {
const int src_entry = ids_src_expert[global_j];
const uint2 qrm = fast_div_modulo((uint32_t) src_entry, sis1_fd);
const int token = (int) qrm.x;
const int channel = (int) qrm.y;
if (token < ncols_dst_total) {
tmp = *(const float2*) &y[channel*stride_channel_y + 2*(token*stride_col_y + col)];
}
}
vals[j0] = tmp;
}
};
if (ntB > 0) {
gather_tile(0, vals_buf[0]);
}
int curr_buf = 0;
int next_buf = 1;
#pragma unroll
for (int itB = 0; itB < ntB; ++itB) {
#pragma unroll
for (int j0 = 0; j0 < tile_B::I; ++j0) {
const float2 tmp = vals_buf[curr_buf][j0];
tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
}
if (itB + 1 < ntB) {
gather_tile(itB + 1, vals_buf[next_buf]);
}
#pragma unroll
for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) {
tile_B B;
load_ldmatrix(B, tile_xy + k0, tile_k_padded);
#pragma unroll
for (int itA = 0; itA < ntA; ++itA) {
mma(C[itA][itB], A[itA][k0/tile_B::J], B);
}
}
if (itB + 1 < ntB) {
curr_buf ^= 1;
next_buf ^= 1;
}
}
} else {
static_assert(std::is_same_v<T, void>, "unsupported type");
}
}
float * buf_iw = (float *) compute_base;
constexpr int kiw = nwarps*rows_per_block + 4;
if (nwarps > 1) {
__syncthreads();
}
#pragma unroll
for (int itB = 0; itB < ntB; ++itB) {
#pragma unroll
for (int itA = 0; itA < ntA; ++itA) {
#pragma unroll
for (int l = 0; l < tile_C::ne; ++l) {
const int i = threadIdx.y*rows_per_block + itA*tile_C::I + tile_C::get_i(l);
const int j = itB*tile_C::J + tile_C::get_j(l);
buf_iw[j*kiw + i] = C[itA][itB].x[l];
}
}
}
if (nwarps > 1) {
__syncthreads();
}
#pragma unroll
for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
const int j = j0 + threadIdx.y;
if (j0 + nwarps > cols_per_block && j >= cols_per_block) {
return;
}
float sum = 0.0f;
static_assert(rows_per_block == warp_size, "need loop/check");
#pragma unroll
for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) {
const int i = i0 + threadIdx.x;
sum += buf_iw[j*kiw + i];
}
const int global_j = col_base + j;
if (j < cols_per_block && global_j < ncols_expert && nchannels_dst > 0) {
const int dst_entry = ids_dst_expert[global_j];
const uint2 qrm = fast_div_modulo((uint32_t) dst_entry, nch_fd);
const int token = (int) qrm.x;
if (token < ncols_dst_total) {
const int slot = (int) qrm.y;
dst[slot*stride_channel_dst + token*stride_col_dst + row0 + threadIdx.x] = sum;
}
}
}
#else
GGML_UNUSED_VARS(x, y, ids_src_compact, ids_dst_compact, expert_bounds, dst,
ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, sis1_fd, nch_fd);
NO_DEVICE_CODE;
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
}
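// Grid mapping, for reference: blockIdx.x selects the row tile of x,
// blockIdx.y the expert, and blockIdx.z the column tile within that expert's
// compact slice; gridDim.z is sized for the worst-case expert, so blocks whose
// tile_idx exceeds tiles_for_expert exit immediately.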
template<typename T, int cols_per_block, int nwarps> template<typename T, int cols_per_block, int nwarps>
static inline void mul_mat_f_switch_ids( static inline void mul_mat_f_switch_ids(
const T * x, const float * y, const int32_t * ids, float * dst, const T * x, const float * y, const int32_t * ids, float * dst,
@ -232,11 +484,33 @@ static inline void mul_mat_f_switch_ids(
const int64_t stride_col_id, const int64_t stride_row_id, const int64_t stride_col_id, const int64_t stride_row_id,
const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream) { const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream,
if (ids) { const mmf_ids_data * ids_data) {
const bool has_ids_data = ids_data && ids_data->ids_src_compact;
// Use the compact-ids kernel only for larger tiles; for small ncols_dst (< 16)
// we prefer the normal mul_mat_f path with has_ids=true.
if (has_ids_data && ncols_dst > 16) {
const int max_tiles = (int) ((ncols_dst + cols_per_block - 1) / cols_per_block);
if (max_tiles == 0) {
return;
}
dim3 block_nums_ids(block_nums.x, ids_data->n_experts, max_tiles);
const uint3 sis1_fd = ids_data->sis1 > 0 ? init_fastdiv_values((uint32_t) ids_data->sis1) : make_uint3(0, 0, 1);
const uint3 nch_fd = init_fastdiv_values((uint32_t) nchannels_dst);
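// init_fastdiv_values precomputes a multiplier/shift pair so fast_div_modulo
// can split a compact index into quotient and remainder without hardware
// integer division (helpers from common.cuh); make_uint3(0, 0, 1) is a benign
// placeholder, since sis1 > 0 is asserted before ids_data is populated in mmf.cu.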
mul_mat_f_ids<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
(x, y, ids_data->ids_src_compact, ids_data->ids_dst_compact, ids_data->expert_bounds_dev, dst,
ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst,
sis1_fd, nch_fd);
} else if (ids) {
const int64_t col_tiles = (ncols_dst + cols_per_block - 1) / cols_per_block; const int64_t col_tiles = (ncols_dst + cols_per_block - 1) / cols_per_block;
dim3 block_nums_ids = block_nums; dim3 block_nums_ids = block_nums;
block_nums_ids.y *= col_tiles; block_nums_ids.y *= col_tiles;
mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>> mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
(x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, (x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
@ -258,7 +532,7 @@ void mul_mat_f_cuda(
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
cudaStream_t stream) { cudaStream_t stream, const mmf_ids_data * ids_data) {
typedef tile<16, 8, T> tile_A; typedef tile<16, 8, T> tile_A;
typedef tile< 8, 8, T> tile_B; typedef tile< 8, 8, T> tile_B;
@ -290,7 +564,7 @@ void mul_mat_f_cuda(
const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine); const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine);
const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0; const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0;
const int nbytes_shared_total = nbytes_shared + nbytes_slotmap; const int nbytes_shared_total = nbytes_shared + nbytes_slotmap;
const int64_t grid_y = ids ? nchannels_x : nchannels_dst; // per expert when ids present const int64_t grid_y = ids ? nchannels_x : nchannels_dst;
const dim3 block_nums(nrows_x/rows_per_block, grid_y, nsamples_dst); const dim3 block_nums(nrows_x/rows_per_block, grid_y, nsamples_dst);
const dim3 block_dims(warp_size, nwarps_best, 1); const dim3 block_dims(warp_size, nwarps_best, 1);
@ -300,49 +574,57 @@ void mul_mat_f_cuda(
mul_mat_f_switch_ids<T, cols_per_block, 1>( mul_mat_f_switch_ids<T, cols_per_block, 1>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
ids_data);
} break; } break;
case 2: { case 2: {
mul_mat_f_switch_ids<T, cols_per_block, 2>( mul_mat_f_switch_ids<T, cols_per_block, 2>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
ids_data);
} break; } break;
case 3: { case 3: {
mul_mat_f_switch_ids<T, cols_per_block, 3>( mul_mat_f_switch_ids<T, cols_per_block, 3>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
ids_data);
} break; } break;
case 4: { case 4: {
mul_mat_f_switch_ids<T, cols_per_block, 4>( mul_mat_f_switch_ids<T, cols_per_block, 4>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
ids_data);
} break; } break;
case 5: { case 5: {
mul_mat_f_switch_ids<T, cols_per_block, 5>( mul_mat_f_switch_ids<T, cols_per_block, 5>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
ids_data);
} break; } break;
case 6: { case 6: {
mul_mat_f_switch_ids<T, cols_per_block, 6>( mul_mat_f_switch_ids<T, cols_per_block, 6>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
ids_data);
} break; } break;
case 7: { case 7: {
mul_mat_f_switch_ids<T, cols_per_block, 7>( mul_mat_f_switch_ids<T, cols_per_block, 7>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
ids_data);
} break; } break;
case 8: { case 8: {
mul_mat_f_switch_ids<T, cols_per_block, 8>( mul_mat_f_switch_ids<T, cols_per_block, 8>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
ids_data);
} break; } break;
default: { default: {
GGML_ABORT("fatal error"); GGML_ABORT("fatal error");
@ -361,7 +643,7 @@ static void mul_mat_f_switch_cols_per_block(
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
cudaStream_t stream) { cudaStream_t stream, const mmf_ids_data * ids_data) {
const int ncols_case = (ids && ncols_dst > 16) ? 16 : ncols_dst; const int ncols_case = (ids && ncols_dst > 16) ? 16 : ncols_dst;
@ -371,82 +653,82 @@ static void mul_mat_f_switch_cols_per_block(
case 1: { case 1: {
mul_mat_f_cuda<T, 1>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, mul_mat_f_cuda<T, 1>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break; } break;
case 2: { case 2: {
mul_mat_f_cuda<T, 2>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, mul_mat_f_cuda<T, 2>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break; } break;
case 3: { case 3: {
mul_mat_f_cuda<T, 3>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, mul_mat_f_cuda<T, 3>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break; } break;
case 4: { case 4: {
mul_mat_f_cuda<T, 4>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, mul_mat_f_cuda<T, 4>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break; } break;
case 5: { case 5: {
mul_mat_f_cuda<T, 5>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, mul_mat_f_cuda<T, 5>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break; } break;
case 6: { case 6: {
mul_mat_f_cuda<T, 6>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, mul_mat_f_cuda<T, 6>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break; } break;
case 7: { case 7: {
mul_mat_f_cuda<T, 7>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, mul_mat_f_cuda<T, 7>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break; } break;
case 8: { case 8: {
mul_mat_f_cuda<T, 8>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, mul_mat_f_cuda<T, 8>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break; } break;
case 9: { case 9: {
mul_mat_f_cuda<T, 9>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, mul_mat_f_cuda<T, 9>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break; } break;
case 10: { case 10: {
mul_mat_f_cuda<T, 10>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, mul_mat_f_cuda<T, 10>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break; } break;
case 11: { case 11: {
mul_mat_f_cuda<T, 11>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, mul_mat_f_cuda<T, 11>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break; } break;
case 12: { case 12: {
mul_mat_f_cuda<T, 12>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, mul_mat_f_cuda<T, 12>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break; } break;
case 13: { case 13: {
mul_mat_f_cuda<T, 13>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, mul_mat_f_cuda<T, 13>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break; } break;
case 14: { case 14: {
mul_mat_f_cuda<T, 14>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, mul_mat_f_cuda<T, 14>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break; } break;
case 15: { case 15: {
mul_mat_f_cuda<T, 15>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, mul_mat_f_cuda<T, 15>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break; } break;
case 16: { case 16: {
mul_mat_f_cuda<T, 16>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, mul_mat_f_cuda<T, 16>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break; } break;
default: { default: {
GGML_ABORT("fatal error"); GGML_ABORT("fatal error");
@ -462,7 +744,7 @@ static void mul_mat_f_switch_cols_per_block(
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, \ const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, \
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,\ const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,\
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, \ const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, \
cudaStream_t stream); cudaStream_t stream, const mmf_ids_data * ids_data);
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
#define DECL_MMF_CASE_EXTERN(ncols_dst) \ #define DECL_MMF_CASE_EXTERN(ncols_dst) \

164 ggml/src/ggml-cuda/mmid.cu Normal file
View File

@ -0,0 +1,164 @@
#include "common.cuh"
#include "mmid.cuh"
// To reduce shared memory use, store "it" and "iex_used" in 22 and 10 bits, respectively.
struct mm_ids_helper_store {
uint32_t data;
__device__ mm_ids_helper_store(const uint32_t it, const uint32_t iex_used) {
data = (it & 0x003FFFFF) | (iex_used << 22);
}
__device__ uint32_t it() const {
return data & 0x003FFFFF;
}
__device__ uint32_t iex_used() const {
return data >> 22;
}
};
static_assert(sizeof(mm_ids_helper_store) == 4, "unexpected size for mm_ids_helper_store");
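// Packing example: mm_ids_helper_store s(/*it=*/1234567, /*iex_used=*/3)
// keeps 1234567 in the low 22 bits and 3 in the high 10 bits, so
// s.it() == 1234567 and s.iex_used() == 3; the launch-time asserts below
// guarantee both fields fit.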
// Helper function for mul_mat_id, converts ids to a more convenient format.
// ids_src1 describes how to permute the flattened column indices of src1 in order to get a compact src1 tensor sorted by expert.
// ids_dst describes the same mapping but for the dst tensor.
// The upper and lower bounds for the ith expert in the compact src1 tensor are stored in expert_bounds[i:i+1].
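// Worked example: with n_tokens = 3, n_expert_used = 2 and
// ids = {{0, 2}, {1, 0}, {2, 1}} (row it = experts used by token it),
// expert 0 occurs at (it=0, iex=0) and (it=1, iex=1); expert 1 at (it=1, iex=0)
// and (it=2, iex=1); expert 2 at (it=0, iex=1) and (it=2, iex=0). Hence:
//   expert_bounds = {0, 2, 4, 6}
//   ids_dst       = {it*n_expert_used + iex_used, grouped by expert}
//                 = {0, 3, 2, 5, 1, 4}
// (ids_src1 lists the same pairs but encodes src1 strides via sis1.)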
template <int n_expert_used_template>
__launch_bounds__(ggml_cuda_get_physical_warp_size(), 1)
static __global__ void mm_ids_helper(
const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1) {
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
const int n_expert_used = n_expert_used_template == 0 ? n_expert_used_var : n_expert_used_template;
const int expert = blockIdx.x;
extern __shared__ char data_mm_ids_helper[];
mm_ids_helper_store * store = (mm_ids_helper_store *) data_mm_ids_helper;
int nex_prev = 0; // Number of columns for experts with a lower index.
int it_compact = 0; // Running index for the compact slice of this expert.
if constexpr (n_expert_used_template == 0) {
// Generic implementation:
for (int it = 0; it < n_tokens; ++it) {
int iex_used = -1; // The index at which the expert is used, if any.
for (int iex = threadIdx.x; iex < n_expert_used; iex += warp_size) {
const int expert_used = ids[it*si1 + iex];
nex_prev += expert_used < expert;
if (expert_used == expert) {
iex_used = iex;
}
}
if (iex_used != -1) {
store[it_compact] = mm_ids_helper_store(it, iex_used);
}
if (warp_reduce_any<warp_size>(iex_used != -1)) {
it_compact++;
}
}
} else {
// Implementation optimized for specific numbers of experts used:
static_assert(n_expert_used == 6 || warp_size % n_expert_used == 0, "bad n_expert_used");
const int neu_padded = n_expert_used == 6 ? 8 : n_expert_used; // Padded to next higher power of 2.
for (int it0 = 0; it0 < n_tokens; it0 += warp_size/neu_padded) {
const int it = it0 + threadIdx.x / neu_padded;
const int iex = threadIdx.x % neu_padded; // The index at which the expert is used, if any.
const int expert_used = (neu_padded == n_expert_used || iex < n_expert_used) && it < n_tokens ?
ids[it*si1 + iex] : INT_MAX;
const int iex_used = expert_used == expert ? iex : -1;
nex_prev += expert_used < expert;
// Whether the threads at this token position have used the expert:
const int it_compact_add_self = warp_reduce_any<neu_padded>(iex_used != -1);
// Do a scan over threads at lower token positions in warp to get the correct index for writing data:
int it_compact_add_lower = 0;
#pragma unroll
for (int offset = neu_padded; offset < warp_size; offset += neu_padded) {
const int tmp = __shfl_up_sync(0xFFFFFFFF, it_compact_add_self, offset, warp_size);
if (threadIdx.x >= static_cast<unsigned int>(offset)) {
it_compact_add_lower += tmp;
}
}
if (iex_used != -1) {
store[it_compact + it_compact_add_lower] = mm_ids_helper_store(it, iex_used);
}
// The thread with the highest index in the warp always has the sum over the whole warp, use it to increment all threads:
it_compact += __shfl_sync(0xFFFFFFFF, it_compact_add_lower + it_compact_add_self, warp_size - 1, warp_size);
}
}
nex_prev = warp_reduce_sum<warp_size>(nex_prev);
for (int itc = threadIdx.x; itc < it_compact; itc += warp_size) {
const mm_ids_helper_store store_it = store[itc];
const int it = store_it.it();
const int iex_used = store_it.iex_used();
ids_src1[nex_prev + itc] = it*sis1 + iex_used % nchannels_y;
ids_dst [nex_prev + itc] = it*n_expert_used + iex_used;
}
if (threadIdx.x != 0) {
return;
}
expert_bounds[expert] = nex_prev;
if (expert < static_cast<int>(gridDim.x) - 1) {
return;
}
expert_bounds[gridDim.x] = nex_prev + it_compact;
}
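For reference, a sequential CPU sketch of the mapping this kernel produces (illustration only, not repo code; it assumes each expert appears at most once per token, as is the case for typical MoE router output):

#include <cstdint>
#include <vector>

// Reference for the ids -> (ids_src1, ids_dst, expert_bounds) conversion.
// Argument names follow the kernel.
static void mm_ids_helper_ref(
        const int32_t * ids, int n_experts, int n_tokens, int n_expert_used,
        int nchannels_y, int si1, int sis1,
        std::vector<int32_t> & ids_src1, std::vector<int32_t> & ids_dst,
        std::vector<int32_t> & expert_bounds) {
    expert_bounds.assign(n_experts + 1, 0);
    int pos = 0;
    for (int expert = 0; expert < n_experts; ++expert) {
        expert_bounds[expert] = pos; // start of this expert's compact slice
        for (int it = 0; it < n_tokens; ++it) {
            for (int iex = 0; iex < n_expert_used; ++iex) {
                if (ids[it*si1 + iex] == expert) {
                    ids_src1.push_back(it*sis1 + iex % nchannels_y);
                    ids_dst .push_back(it*n_expert_used + iex);
                    ++pos;
                    break; // at most one compact entry per (token, expert)
                }
            }
        }
    }
    expert_bounds[n_experts] = pos; // total number of compact columns
}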
template <int n_expert_used_template>
static void launch_mm_ids_helper(
const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
const int n_experts, const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) {
GGML_ASSERT(n_tokens < (1 << 22) && "too few bits in mm_ids_helper_store");
GGML_ASSERT(n_expert_used_var < (1 << 10) && "too few bits in mm_ids_helper_store");
const int id = ggml_cuda_get_device();
const int warp_size = ggml_cuda_info().devices[id].warp_size;
const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
CUDA_SET_SHARED_MEMORY_LIMIT(mm_ids_helper<n_expert_used_template>, smpbo);
const dim3 num_blocks(n_experts, 1, 1);
const dim3 block_size(warp_size, 1, 1);
const size_t nbytes_shared = n_tokens*sizeof(mm_ids_helper_store);
GGML_ASSERT(nbytes_shared <= smpbo);
mm_ids_helper<n_expert_used_template><<<num_blocks, block_size, nbytes_shared, stream>>>
(ids, ids_src1, ids_dst, expert_bounds, n_tokens, n_expert_used_var, nchannels_y, si1, sis1);
}
void ggml_cuda_launch_mm_ids_helper(
const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
const int n_experts, const int n_tokens, const int n_expert_used, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) {
switch (n_expert_used) {
case 2:
launch_mm_ids_helper< 2>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
break;
case 4:
launch_mm_ids_helper< 4>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
break;
case 6:
launch_mm_ids_helper< 6>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
break;
case 8:
launch_mm_ids_helper< 8>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
break;
case 16:
launch_mm_ids_helper<16>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
break;
case 32:
launch_mm_ids_helper<32>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
break;
default:
launch_mm_ids_helper< 0>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
break;
}
}
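The specialized cases mirror the template above; 6 is accepted even though it does not divide the warp size because the kernel pads it to 8. A quick arithmetic check of the lane mapping (numbers illustrative):

#include <cstdio>

int main() {
    const int warp_size = 32;
    for (int neu : {2, 4, 6, 8, 16, 32}) {
        const int neu_padded = neu == 6 ? 8 : neu; // pad 6 to the next power of 2
        printf("n_expert_used=%2d -> neu_padded=%2d, tokens per warp iteration=%2d\n",
               neu, neu_padded, warp_size/neu_padded);
    }
    return 0;
}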
View File
@ -0,0 +1,5 @@
#pragma once
void ggml_cuda_launch_mm_ids_helper(
const int32_t * ids, int32_t * ids_src1, int32_t * ids_dst, int32_t * expert_bounds,
int n_experts, int n_tokens, int n_expert_used, int nchannels_y, int si1, int sis1, cudaStream_t stream);
View File
@ -1,141 +1,6 @@
#include "mmq.cuh" #include "mmq.cuh"
#include "quantize.cuh" #include "quantize.cuh"
#include "mmid.cuh"
#include <vector>
// To reduce shared memory use, store "it" and "iex_used" with 22/10 bits each.
struct mmq_ids_helper_store {
uint32_t data;
__device__ mmq_ids_helper_store(const uint32_t it, const uint32_t iex_used) {
data = (it & 0x003FFFFF) | (iex_used << 22);
}
__device__ uint32_t it() const {
return data & 0x003FFFFF;
}
__device__ uint32_t iex_used() const {
return data >> 22;
}
};
static_assert(sizeof(mmq_ids_helper_store) == 4, "unexpected size for mmq_ids_helper_store");
// Helper function for mul_mat_id, converts ids to a more convenient format.
// ids_src1 describes how to permute the flattened column indices of src1 in order to get a compact src1 tensor sorted by expert.
// ids_dst describes the same mapping but for the dst tensor.
// The upper and lower bounds for the ith expert in the compact src1 tensor are stored in expert_bounds[i:i+1].
template <int n_expert_used_template>
__launch_bounds__(ggml_cuda_get_physical_warp_size(), 1)
static __global__ void mmq_ids_helper(
const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1) {
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
const int n_expert_used = n_expert_used_template == 0 ? n_expert_used_var : n_expert_used_template;
const int expert = blockIdx.x;
extern __shared__ char data_mmq_ids_helper[];
mmq_ids_helper_store * store = (mmq_ids_helper_store *) data_mmq_ids_helper;
int nex_prev = 0; // Number of columns for experts with a lower index.
int it_compact = 0; // Running index for the compact slice of this expert.
if constexpr (n_expert_used_template == 0) {
// Generic implementation:
for (int it = 0; it < n_tokens; ++it) {
int iex_used = -1; // The index at which the expert is used, if any.
for (int iex = threadIdx.x; iex < n_expert_used; iex += warp_size) {
const int expert_used = ids[it*si1 + iex];
nex_prev += expert_used < expert;
if (expert_used == expert) {
iex_used = iex;
}
}
if (iex_used != -1) {
store[it_compact] = mmq_ids_helper_store(it, iex_used);
}
if (warp_reduce_any<warp_size>(iex_used != -1)) {
it_compact++;
}
}
} else {
// Implementation optimized for specific numbers of experts used:
static_assert(n_expert_used == 6 || warp_size % n_expert_used == 0, "bad n_expert_used");
const int neu_padded = n_expert_used == 6 ? 8 : n_expert_used; // Padded to next higher power of 2.
for (int it0 = 0; it0 < n_tokens; it0 += warp_size/neu_padded) {
const int it = it0 + threadIdx.x / neu_padded;
const int iex = threadIdx.x % neu_padded; // The index at which the expert is used, if any.
const int expert_used = (neu_padded == n_expert_used || iex < n_expert_used) && it < n_tokens ?
ids[it*si1 + iex] : INT_MAX;
const int iex_used = expert_used == expert ? iex : -1;
nex_prev += expert_used < expert;
// Whether the threads at this token position have used the expert:
const int it_compact_add_self = warp_reduce_any<neu_padded>(iex_used != -1);
// Do a scan over threads at lower token positions in warp to get the correct index for writing data:
int it_compact_add_lower = 0;
#pragma unroll
for (int offset = neu_padded; offset < warp_size; offset += neu_padded) {
const int tmp = __shfl_up_sync(0xFFFFFFFF, it_compact_add_self, offset, warp_size);
if (threadIdx.x >= static_cast<unsigned int>(offset)) {
it_compact_add_lower += tmp;
}
}
if (iex_used != -1) {
store[it_compact + it_compact_add_lower] = mmq_ids_helper_store(it, iex_used);
}
// The thread with the highest index in the warp always has the sum over the whole warp, use it to increment all threads:
it_compact += __shfl_sync(0xFFFFFFFF, it_compact_add_lower + it_compact_add_self, warp_size - 1, warp_size);
}
}
nex_prev = warp_reduce_sum<warp_size>(nex_prev);
for (int itc = threadIdx.x; itc < it_compact; itc += warp_size) {
const mmq_ids_helper_store store_it = store[itc];
const int it = store_it.it();
const int iex_used = store_it.iex_used();
ids_src1[nex_prev + itc] = it*sis1 + iex_used % nchannels_y;
ids_dst [nex_prev + itc] = it*n_expert_used + iex_used;
}
if (threadIdx.x != 0) {
return;
}
expert_bounds[expert] = nex_prev;
if (expert < static_cast<int>(gridDim.x) - 1) {
return;
}
expert_bounds[gridDim.x] = nex_prev + it_compact;
}
template <int n_expert_used_template>
static void launch_mmq_ids_helper(
const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
const int n_experts, const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) {
GGML_ASSERT(n_tokens < (1 << 22) && "too few bits in mmq_ids_helper_store");
GGML_ASSERT(n_expert_used_var < (1 << 10) && "too few bits in mmq_ids_helper_store");
const int id = ggml_cuda_get_device();
const int warp_size = ggml_cuda_info().devices[id].warp_size;
const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
CUDA_SET_SHARED_MEMORY_LIMIT(mmq_ids_helper<n_expert_used_template>, smpbo);
const dim3 num_blocks(n_experts, 1, 1);
const dim3 block_size(warp_size, 1, 1);
const size_t nbytes_shared = n_tokens*sizeof(mmq_ids_helper_store);
GGML_ASSERT(nbytes_shared <= smpbo);
mmq_ids_helper<n_expert_used_template><<<num_blocks, block_size, nbytes_shared, stream>>>
(ids, ids_src1, ids_dst, expert_bounds, n_tokens, n_expert_used_var, nchannels_y, si1, sis1);
}
static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) { static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
switch (args.type_x) { switch (args.type_x) {
@ -293,36 +158,8 @@ void ggml_cuda_mul_mat_q(
const int si1 = ids->nb[1] / ggml_element_size(ids); const int si1 = ids->nb[1] / ggml_element_size(ids);
const int sis1 = nb12 / nb11; const int sis1 = nb12 / nb11;
switch (n_expert_used) { ggml_cuda_launch_mm_ids_helper((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
case 2:
launch_mmq_ids_helper< 2> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
ne02, ne12, n_expert_used, ne11, si1, sis1, stream); ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
break;
case 4:
launch_mmq_ids_helper< 4> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
break;
case 6:
launch_mmq_ids_helper< 6> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
break;
case 8:
launch_mmq_ids_helper< 8> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
break;
case 16:
launch_mmq_ids_helper<16> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
break;
case 32:
launch_mmq_ids_helper<32> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
break;
default:
launch_mmq_ids_helper< 0> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
break;
}
CUDA_CHECK(cudaGetLastError()); CUDA_CHECK(cudaGetLastError());
} }
View File
@ -7,14 +7,14 @@ template <typename T, typename type_acc, int ncols_dst, int block_size>
static __global__ void mul_mat_vec_f( static __global__ void mul_mat_vec_f(
const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst, const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst, const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst,
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) { const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
const int row = blockIdx.x; const int row = blockIdx.x;
const int channel_dst = blockIdx.y; const int channel_dst = blockIdx.y;
const int channel_x = ids ? ids[channel_dst] : channel_dst / channel_ratio; const int channel_x = ids ? ids[channel_dst] : fastdiv((uint32_t) channel_dst, channel_ratio);
const int channel_y = ids ? channel_dst % nchannels_y : channel_dst; const int channel_y = ids ? channel_dst % nchannels_y : channel_dst;
const int sample_dst = blockIdx.z; const int sample_dst = blockIdx.z;
const int sample_x = sample_dst / sample_ratio; const int sample_x = fastdiv((uint32_t) sample_dst, sample_ratio);
const int sample_y = sample_dst; const int sample_y = sample_dst;
const int tid = threadIdx.x; const int tid = threadIdx.x;
@ -47,8 +47,8 @@ static __global__ void mul_mat_vec_f(
#pragma unroll #pragma unroll
for (int j = 0; j < ncols_dst; ++j) { for (int j = 0; j < ncols_dst; ++j) {
const float2 tmpy = y2[j*stride_col_y2 + col2]; const float2 tmpy = y2[j*stride_col_y2 + col2];
sumf[j] += tmpx.x*tmpy.x; ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
sumf[j] += tmpx.y*tmpy.y; ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
} }
} }
} else if constexpr (std::is_same_v<T, half>) { } else if constexpr (std::is_same_v<T, half>) {
@ -61,8 +61,8 @@ static __global__ void mul_mat_vec_f(
#pragma unroll #pragma unroll
for (int j = 0; j < ncols_dst; ++j) { for (int j = 0; j < ncols_dst; ++j) {
const float2 tmpy = y2[j*stride_col_y2 + col2]; const float2 tmpy = y2[j*stride_col_y2 + col2];
sumf[j] += tmpx.x * tmpy.x; ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
sumf[j] += tmpx.y * tmpy.y; ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
} }
} }
} else { } else {
@ -88,16 +88,32 @@ static __global__ void mul_mat_vec_f(
#endif // FP16_AVAILABLE #endif // FP16_AVAILABLE
} }
} else if constexpr (std::is_same_v<T, nv_bfloat16>) { } else if constexpr (std::is_same_v<T, nv_bfloat16>) {
//TODO: add support for ggml_cuda_mad for hip_bfloat162
#if defined(GGML_USE_HIP)
const int * x2 = (const int *) x; const int * x2 = (const int *) x;
for (int col2 = tid; col2 < ncols2; col2 += block_size) { for (int col2 = tid; col2 < ncols2; col2 += block_size) {
const int tmpx = x2[col2]; const int tmpx = x2[col2];
#pragma unroll #pragma unroll
for (int j = 0; j < ncols_dst; ++j) { for (int j = 0; j < ncols_dst; ++j) {
const float2 tmpy = y2[j*stride_col_y2 + col2]; const float2 tmpy = y2[j*stride_col_y2 + col2];
sumf[j] += ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]) * tmpy.x; const float tmpx0 = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]);
sumf[j] += ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]) * tmpy.y; const float tmpx1 = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]);
ggml_cuda_mad(sumf[j], tmpx0, tmpy.x);
ggml_cuda_mad(sumf[j], tmpx1, tmpy.y);
} }
} }
#else
const nv_bfloat162 * x2 = (const nv_bfloat162 *) x;
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
const nv_bfloat162 tmpx = x2[col2];
#pragma unroll
for (int j = 0; j < ncols_dst; ++j) {
const float2 tmpy = y2[j*stride_col_y2 + col2];
ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
}
}
#endif
} else { } else {
static_assert(std::is_same_v<T, void>, "unsupported type"); static_assert(std::is_same_v<T, void>, "unsupported type");
} }
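The repeated `sumf[j] += a*b` pairs above are replaced with ggml_cuda_mad, which accumulates a product into its first argument. A minimal sketch of the intended semantics (assumption: the real helper lives in common.cuh and may lower differently per architecture; this is not the repo implementation):

// Illustrative semantics of a mad-style accumulator.
__device__ __forceinline__ void mad_sketch(float & acc, const float a, const float b) {
    acc = fmaf(a, b, acc); // fused multiply-add: acc += a*b with a single rounding
}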
@ -140,8 +156,8 @@ static void launch_mul_mat_vec_f_cuda(
GGML_ASSERT(stride_col_y % 2 == 0); GGML_ASSERT(stride_col_y % 2 == 0);
GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0); GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
GGML_ASSERT( nsamples_dst % nsamples_x == 0); GGML_ASSERT( nsamples_dst % nsamples_x == 0);
const int64_t channel_ratio = nchannels_dst / nchannels_x; const uint3 channel_ratio_fd = ids ? make_uint3(0, 0, 0) : init_fastdiv_values(nchannels_dst / nchannels_x);
const int64_t sample_ratio = nsamples_dst / nsamples_x; const uint3 sample_ratio_fd = init_fastdiv_values(nsamples_dst / nsamples_x);
const int device = ggml_cuda_get_device(); const int device = ggml_cuda_get_device();
const int warp_size = ggml_cuda_info().devices[device].warp_size; const int warp_size = ggml_cuda_info().devices[device].warp_size;
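init_fastdiv_values precomputes constants so the kernel can replace `/ channel_ratio` and `/ sample_ratio` with a multiply and shift, avoiding hardware integer division. A generic sketch of the standard magic-number technique (assumption: ggml's actual uint3 encoding in common.cuh differs in layout and uses a multiply-high instruction; this shows the idea, not the repo code):

#include <cassert>
#include <cstdint>

struct fastdiv_sketch { uint64_t m; uint32_t s; };

// m = ceil(2^(32+s)/d) with 2^s >= d guarantees (n*m) >> (32+s) == n/d
// for all 32-bit n (Granlund-Montgomery round-up method).
static fastdiv_sketch make_fastdiv(uint32_t d) {
    uint32_t s = 0;
    while ((1ull << s) < d) ++s;
    const uint64_t m = (uint64_t) ((((unsigned __int128) 1 << (32 + s)) + d - 1) / d);
    return { m, s };
}

static uint32_t fastdiv_apply(uint32_t n, fastdiv_sketch f) {
    return (uint32_t) (((unsigned __int128) n * f.m) >> (32 + f.s));
}

int main() {
    const fastdiv_sketch f = make_fastdiv(6);
    for (uint32_t n : {0u, 1u, 5u, 6u, 7u, 123456789u}) {
        assert(fastdiv_apply(n, f) == n / 6);
    }
    return 0;
}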
@ -167,50 +183,50 @@ static void launch_mul_mat_vec_f_cuda(
case 32: { case 32: {
mul_mat_vec_f<T, type_acc, ncols_dst, 32><<<block_nums, block_dims, nbytes_shared, stream>>> mul_mat_vec_f<T, type_acc, ncols_dst, 32><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
} break; } break;
case 64: { case 64: {
mul_mat_vec_f<T, type_acc, ncols_dst, 64><<<block_nums, block_dims, nbytes_shared, stream>>> mul_mat_vec_f<T, type_acc, ncols_dst, 64><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
} break; } break;
case 96: { case 96: {
mul_mat_vec_f<T, type_acc, ncols_dst, 96><<<block_nums, block_dims, nbytes_shared, stream>>> mul_mat_vec_f<T, type_acc, ncols_dst, 96><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
} break; } break;
case 128: { case 128: {
mul_mat_vec_f<T, type_acc, ncols_dst, 128><<<block_nums, block_dims, nbytes_shared, stream>>> mul_mat_vec_f<T, type_acc, ncols_dst, 128><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
} break; } break;
case 160: { case 160: {
mul_mat_vec_f<T, type_acc, ncols_dst, 160><<<block_nums, block_dims, nbytes_shared, stream>>> mul_mat_vec_f<T, type_acc, ncols_dst, 160><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
} break; } break;
case 192: { case 192: {
mul_mat_vec_f<T, type_acc, ncols_dst, 192><<<block_nums, block_dims, nbytes_shared, stream>>> mul_mat_vec_f<T, type_acc, ncols_dst, 192><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
} break; } break;
case 224: { case 224: {
mul_mat_vec_f<T, type_acc, ncols_dst, 224><<<block_nums, block_dims, nbytes_shared, stream>>> mul_mat_vec_f<T, type_acc, ncols_dst, 224><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
} break; } break;
case 256: { case 256: {
mul_mat_vec_f<T, type_acc, ncols_dst, 256><<<block_nums, block_dims, nbytes_shared, stream>>> mul_mat_vec_f<T, type_acc, ncols_dst, 256><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
} break; } break;
default: { default: {
GGML_ABORT("fatal error"); GGML_ABORT("fatal error");
View File
@ -268,6 +268,25 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu(ggml_metal_library_t l
return res; return res;
} }
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum(ggml_metal_library_t lib, const ggml_tensor * op) {
assert(op->op == GGML_OP_SUM);
char base[256];
char name[256];
snprintf(base, 256, "kernel_op_sum_%s", ggml_type_name(op->src[0]->type));
snprintf(name, 256, "%s", base);
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
if (res) {
return res;
}
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
return res;
}
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows(ggml_metal_library_t lib, const ggml_tensor * op) { ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows(ggml_metal_library_t lib, const ggml_tensor * op) {
GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type)); GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type));
@ -1482,3 +1501,40 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_me
return res; return res;
} }
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_adamw(ggml_metal_library_t lib, const ggml_tensor * op) {
assert(op->op == GGML_OP_OPT_STEP_ADAMW);
char base[256];
char name[256];
snprintf(base, 256, "kernel_opt_step_adamw_%s", ggml_type_name(op->src[0]->type));
snprintf(name, 256, "%s", base);
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
if (res) {
return res;
}
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
return res;
}
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_sgd(ggml_metal_library_t lib, const ggml_tensor * op) {
assert(op->op == GGML_OP_OPT_STEP_SGD);
char base[256];
char name[256];
snprintf(base, 256, "kernel_opt_step_sgd_%s", ggml_type_name(op->src[0]->type));
snprintf(name, 256, "%s", base);
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
if (res) {
return res;
}
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
return res;
}
View File
@ -109,6 +109,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_set_rows (ggml_me
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_repeat (ggml_metal_library_t lib, enum ggml_type tsrc); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_repeat (ggml_metal_library_t lib, enum ggml_type tsrc);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op);
@ -134,6 +135,8 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad (ggml_me
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_arange (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_arange (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_adamw (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_sgd (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad( ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
ggml_metal_library_t lib, ggml_metal_library_t lib,
View File
@ -656,6 +656,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_OP_COS: case GGML_OP_COS:
case GGML_OP_LOG: case GGML_OP_LOG:
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_SUM:
case GGML_OP_SUM_ROWS: case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN: case GGML_OP_MEAN:
case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX:
@ -692,7 +693,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
return true; return true;
case GGML_OP_FLASH_ATTN_EXT: case GGML_OP_FLASH_ATTN_EXT:
// for new head sizes, add checks here // for new head sizes, add checks here
if (op->src[0]->ne[0] != 40 && if (op->src[0]->ne[0] != 32 &&
op->src[0]->ne[0] != 40 &&
op->src[0]->ne[0] != 64 && op->src[0]->ne[0] != 64 &&
op->src[0]->ne[0] != 80 && op->src[0]->ne[0] != 80 &&
op->src[0]->ne[0] != 96 && op->src[0]->ne[0] != 96 &&
@ -798,6 +800,9 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
return false; return false;
}; };
} }
case GGML_OP_OPT_STEP_ADAMW:
case GGML_OP_OPT_STEP_SGD:
return has_simdgroup_reduction;
default: default:
return false; return false;
} }
View File
@ -544,6 +544,10 @@ typedef struct{
float limit; float limit;
} ggml_metal_kargs_glu; } ggml_metal_kargs_glu;
typedef struct {
uint64_t np;
} ggml_metal_kargs_sum;
typedef struct { typedef struct {
int64_t ne00; int64_t ne00;
int64_t ne01; int64_t ne01;
@ -773,4 +777,12 @@ typedef struct {
uint64_t nb01; uint64_t nb01;
} ggml_metal_kargs_argmax; } ggml_metal_kargs_argmax;
typedef struct {
int64_t np;
} ggml_metal_kargs_opt_step_adamw;
typedef struct {
int64_t np;
} ggml_metal_kargs_opt_step_sgd;
#endif // GGML_METAL_IMPL #endif // GGML_METAL_IMPL
View File
@ -301,6 +301,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
{ {
n_fuse = ggml_metal_op_glu(ctx, idx); n_fuse = ggml_metal_op_glu(ctx, idx);
} break; } break;
case GGML_OP_SUM:
{
n_fuse = ggml_metal_op_sum(ctx, idx);
} break;
case GGML_OP_SUM_ROWS: case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN: case GGML_OP_MEAN:
{ {
@ -410,6 +414,14 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
{ {
n_fuse = ggml_metal_op_argmax(ctx, idx); n_fuse = ggml_metal_op_argmax(ctx, idx);
} break; } break;
case GGML_OP_OPT_STEP_ADAMW:
{
n_fuse = ggml_metal_op_opt_step_adamw(ctx, idx);
} break;
case GGML_OP_OPT_STEP_SGD:
{
n_fuse = ggml_metal_op_opt_step_sgd(ctx, idx);
} break;
default: default:
{ {
GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(node->op)); GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(node->op));
@ -840,6 +852,30 @@ int ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) {
return 1; return 1;
} }
int ggml_metal_op_sum(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
const uint64_t n = (uint64_t) ggml_nelements(op->src[0]);
ggml_metal_kargs_sum args = {
/*.np =*/ n,
};
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_sum(lib, op);
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, 1, 1, 1);
return 1;
}
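GGML_OP_SUM reduces the whole tensor to a single scalar, so the encoder dispatches one threadgroup with one thread and the kernel (kernel_op_sum_f32 further down) accumulates serially. A host-side reference of what it computes (illustration only):

#include <cstddef>

float sum_ref(const float * src0, size_t np) {
    float acc = 0.0f;
    for (size_t i = 0; i < np; ++i) {
        acc += src0[i]; // same sequential accumulation order as the kernel
    }
    return acc;
}

Serial accumulation is presumably acceptable here because the op is a full reduction to one element and is rarely on the hot path.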
int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) { int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx); ggml_tensor * op = ctx->node(idx);
@ -3401,3 +3437,73 @@ int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) {
return 1; return 1;
} }
int ggml_metal_op_opt_step_adamw(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_adamw(lib, op);
const int64_t np = ggml_nelements(op->src[0]);
ggml_metal_kargs_opt_step_adamw args = {
/*.np =*/ np,
};
int ida = 0;
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), ida++);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), ida++);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), ida++);
const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
const int64_t n = (np + nth - 1) / nth;
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, nth, 1, 1);
return 1;
}
int ggml_metal_op_opt_step_sgd(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_sgd(lib, op);
const int64_t np = ggml_nelements(op->src[0]);
ggml_metal_kargs_opt_step_sgd args = {
/*.np =*/ np,
};
int ida = 0;
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), ida++);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++);
const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
const int64_t n = (np + nth - 1) / nth;
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, nth, 1, 1);
return 1;
}
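Both optimizer encoders use the same grid math: nth threads per threadgroup, capped by the pipeline limit, and n = ceil(np/nth) threadgroups so every parameter is covered. A worked check of the ceiling division (numbers illustrative):

#include <cassert>
#include <cstdint>

int main() {
    const int64_t np  = 1000003;              // parameter count (hypothetical)
    const int     nth = 1024;                 // threads per threadgroup (hypothetical cap)
    const int64_t n   = (np + nth - 1) / nth; // ceiling division, as in the encoders
    assert(n == 977);                         // 977*1024 = 1000448 >= 1000003
    assert((n - 1)*nth < np);                 // the last threadgroup still has work
    return 0;
}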
View File
@ -50,6 +50,7 @@ int ggml_metal_op_scale (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_clamp (ggml_metal_op_t ctx, int idx); int ggml_metal_op_clamp (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_unary (ggml_metal_op_t ctx, int idx); int ggml_metal_op_unary (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_glu (ggml_metal_op_t ctx, int idx); int ggml_metal_op_glu (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_sum (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_sum_rows (ggml_metal_op_t ctx, int idx); int ggml_metal_op_sum_rows (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_get_rows (ggml_metal_op_t ctx, int idx); int ggml_metal_op_get_rows (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_set_rows (ggml_metal_op_t ctx, int idx); int ggml_metal_op_set_rows (ggml_metal_op_t ctx, int idx);
@ -78,6 +79,8 @@ int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx);
int ggml_metal_op_argmax (ggml_metal_op_t ctx, int idx); int ggml_metal_op_argmax (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_argsort (ggml_metal_op_t ctx, int idx); int ggml_metal_op_argsort (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_leaky_relu (ggml_metal_op_t ctx, int idx); int ggml_metal_op_leaky_relu (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_opt_step_adamw (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_opt_step_sgd (ggml_metal_op_t ctx, int idx);
#ifdef __cplusplus #ifdef __cplusplus
} }
View File
@ -1723,6 +1723,24 @@ kernel void kernel_geglu_quick_f32(
} }
} }
kernel void kernel_op_sum_f32(
constant ggml_metal_kargs_sum & args,
device const float * src0,
device float * dst,
ushort tiitg[[thread_index_in_threadgroup]]) {
if (tiitg != 0) {
return;
}
float acc = 0.0f;
for (ulong i = 0; i < args.np; ++i) {
acc += src0[i];
}
dst[0] = acc;
}
template <bool norm> template <bool norm>
kernel void kernel_sum_rows( kernel void kernel_sum_rows(
constant ggml_metal_kargs_sum_rows & args, constant ggml_metal_kargs_sum_rows & args,
@ -5195,8 +5213,30 @@ kernel void kernel_flash_attn_ext(
half, half4, simdgroup_half8x8 half, half4, simdgroup_half8x8
//float, float4, simdgroup_float8x8 //float, float4, simdgroup_float8x8
#define FA_TYPES_F32 \
half, half4, simdgroup_half8x8, \
float, float4x4, simdgroup_float8x8, \
float, float4x4, simdgroup_float8x8, \
float, simdgroup_float8x8, \
float, float2, simdgroup_float8x8, \
float, float4, simdgroup_float8x8
//half, half4, simdgroup_half8x8
typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>) flash_attn_ext_t; typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>) flash_attn_ext_t;
template [[host_name("kernel_flash_attn_ext_f32_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 32, 32>;
template [[host_name("kernel_flash_attn_ext_f32_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 40, 40>;
template [[host_name("kernel_flash_attn_ext_f32_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 64, 64>;
template [[host_name("kernel_flash_attn_ext_f32_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 80, 80>;
template [[host_name("kernel_flash_attn_ext_f32_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 96, 96>;
template [[host_name("kernel_flash_attn_ext_f32_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 112, 112>;
template [[host_name("kernel_flash_attn_ext_f32_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 128, 128>;
template [[host_name("kernel_flash_attn_ext_f32_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 192, 192>;
template [[host_name("kernel_flash_attn_ext_f32_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 192, 128>;
template [[host_name("kernel_flash_attn_ext_f32_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 256, 256>;
template [[host_name("kernel_flash_attn_ext_f32_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 576, 512>;
template [[host_name("kernel_flash_attn_ext_f16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 32, 32>;
template [[host_name("kernel_flash_attn_ext_f16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 40, 40>; template [[host_name("kernel_flash_attn_ext_f16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 40, 40>;
template [[host_name("kernel_flash_attn_ext_f16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>; template [[host_name("kernel_flash_attn_ext_f16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>;
template [[host_name("kernel_flash_attn_ext_f16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 80, 80>; template [[host_name("kernel_flash_attn_ext_f16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 80, 80>;
@ -5209,6 +5249,7 @@ template [[host_name("kernel_flash_attn_ext_f16_dk256_dv256")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_f16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 576, 512>; template [[host_name("kernel_flash_attn_ext_f16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 576, 512>;
#if defined(GGML_METAL_HAS_BF16) #if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_bf16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 32, 32>;
template [[host_name("kernel_flash_attn_ext_bf16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 40, 40>; template [[host_name("kernel_flash_attn_ext_bf16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 40, 40>;
template [[host_name("kernel_flash_attn_ext_bf16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64, 64>; template [[host_name("kernel_flash_attn_ext_bf16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64, 64>;
template [[host_name("kernel_flash_attn_ext_bf16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80, 80>; template [[host_name("kernel_flash_attn_ext_bf16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80, 80>;
@ -5221,6 +5262,7 @@ template [[host_name("kernel_flash_attn_ext_bf16_dk256_dv256")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_bf16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 576, 512>; template [[host_name("kernel_flash_attn_ext_bf16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 576, 512>;
#endif #endif
template [[host_name("kernel_flash_attn_ext_q4_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 32, 32>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 40, 40>; template [[host_name("kernel_flash_attn_ext_q4_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 40, 40>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64, 64>; template [[host_name("kernel_flash_attn_ext_q4_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 80, 80>; template [[host_name("kernel_flash_attn_ext_q4_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 80, 80>;
@ -5232,6 +5274,7 @@ template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv128")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_q4_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256, 256>; template [[host_name("kernel_flash_attn_ext_q4_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256, 256>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 576, 512>; template [[host_name("kernel_flash_attn_ext_q4_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 576, 512>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 32, 32>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 40, 40>; template [[host_name("kernel_flash_attn_ext_q4_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 40, 40>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 64, 64>; template [[host_name("kernel_flash_attn_ext_q4_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 80, 80>; template [[host_name("kernel_flash_attn_ext_q4_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 80, 80>;
@ -5243,6 +5286,7 @@ template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv128")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_q4_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256, 256>; template [[host_name("kernel_flash_attn_ext_q4_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256, 256>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 576, 512>; template [[host_name("kernel_flash_attn_ext_q4_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 576, 512>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 32, 32>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 40, 40>; template [[host_name("kernel_flash_attn_ext_q5_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 40, 40>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 64, 64>; template [[host_name("kernel_flash_attn_ext_q5_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 80, 80>; template [[host_name("kernel_flash_attn_ext_q5_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 80, 80>;
@ -5254,6 +5298,7 @@ template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv128")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_q5_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256, 256>; template [[host_name("kernel_flash_attn_ext_q5_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256, 256>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 576, 512>; template [[host_name("kernel_flash_attn_ext_q5_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 576, 512>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 32, 32>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 40, 40>; template [[host_name("kernel_flash_attn_ext_q5_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 40, 40>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 64, 64>; template [[host_name("kernel_flash_attn_ext_q5_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 80, 80>; template [[host_name("kernel_flash_attn_ext_q5_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 80, 80>;
@ -5265,6 +5310,7 @@ template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv128")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_q5_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256, 256>; template [[host_name("kernel_flash_attn_ext_q5_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256, 256>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 576, 512>; template [[host_name("kernel_flash_attn_ext_q5_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 576, 512>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 32, 32>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 40, 40>; template [[host_name("kernel_flash_attn_ext_q8_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 40, 40>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 64, 64>; template [[host_name("kernel_flash_attn_ext_q8_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 80, 80>; template [[host_name("kernel_flash_attn_ext_q8_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 80, 80>;
@ -5800,8 +5846,28 @@ kernel void kernel_flash_attn_ext_vec(
float, float4, \ float, float4, \
float4 float4
#define FA_TYPES_F32 \
half4, \
float4, \
float4, \
float, \
float, float4, \
float4
typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t; typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t;
template [[host_name("kernel_flash_attn_ext_vec_f32_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 32, 32, 4>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 32, 32, 4>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 32, 32, 4>;
#endif
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 32, 32, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 32, 32, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 32, 32, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 32, 32, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 32, 32, 4>;
template [[host_name("kernel_flash_attn_ext_vec_f32_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 64, 64, 2>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 64, 64, 2>; template [[host_name("kernel_flash_attn_ext_vec_f16_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 64, 64, 2>;
#if defined(GGML_METAL_HAS_BF16) #if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 64, 64, 2>; template [[host_name("kernel_flash_attn_ext_vec_bf16_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 64, 64, 2>;
@ -5812,6 +5878,7 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk64_dv64")]] kernel flas
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 64, 64, 2>; template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 64, 64, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 64, 64, 2>; template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 64, 64, 2>;
template [[host_name("kernel_flash_attn_ext_vec_f32_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 96, 96, 4>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 96, 96, 4>; template [[host_name("kernel_flash_attn_ext_vec_f16_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 96, 96, 4>;
#if defined(GGML_METAL_HAS_BF16) #if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 96, 96, 4>; template [[host_name("kernel_flash_attn_ext_vec_bf16_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 96, 96, 4>;
@ -5822,6 +5889,7 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk96_dv96")]] kernel flas
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 96, 96, 4>; template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 96, 96, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 96, 96, 4>; template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 96, 96, 4>;
template [[host_name("kernel_flash_attn_ext_vec_f32_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 128, 128, 1>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 1>; template [[host_name("kernel_flash_attn_ext_vec_f16_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 1>;
#if defined(GGML_METAL_HAS_BF16) #if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 128, 128, 1>; template [[host_name("kernel_flash_attn_ext_vec_bf16_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 128, 128, 1>;
@ -5832,6 +5900,7 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk128_dv128")]] kernel flas
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 128, 128, 1>; template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 128, 128, 1>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 128, 128, 1>; template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 128, 128, 1>;
template [[host_name("kernel_flash_attn_ext_vec_f32_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 192, 192, 2>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 192, 192, 2>; template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 192, 192, 2>;
#if defined(GGML_METAL_HAS_BF16) #if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 192, 192, 2>; template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 192, 192, 2>;
@ -5842,6 +5911,7 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk192_dv192")]] kernel flas
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 192, 192, 2>; template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 192, 192, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 192, 192, 2>; template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 192, 192, 2>;
template [[host_name("kernel_flash_attn_ext_vec_f32_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 192, 128, 2>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 192, 128, 2>; template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 192, 128, 2>;
#if defined(GGML_METAL_HAS_BF16) #if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 192, 128, 2>; template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 192, 128, 2>;
@ -5852,6 +5922,7 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk192_dv128")]] kernel flas
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 192, 128, 2>; template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 192, 128, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 192, 128, 2>; template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 192, 128, 2>;
template [[host_name("kernel_flash_attn_ext_vec_f32_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 256, 256, 1>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 256, 256, 1>; template [[host_name("kernel_flash_attn_ext_vec_f16_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 256, 256, 1>;
#if defined(GGML_METAL_HAS_BF16) #if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 256, 256, 1>; template [[host_name("kernel_flash_attn_ext_vec_bf16_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 256, 256, 1>;
@ -5862,6 +5933,7 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk256_dv256")]] kernel flas
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 256, 256, 1>; template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 256, 256, 1>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 256, 256, 1>; template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 256, 256, 1>;
template [[host_name("kernel_flash_attn_ext_vec_f32_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 576, 512, 2>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 576, 512, 2>; template [[host_name("kernel_flash_attn_ext_vec_f16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 576, 512, 2>;
#if defined(GGML_METAL_HAS_BF16) #if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 576, 512, 2>; template [[host_name("kernel_flash_attn_ext_vec_bf16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 576, 512, 2>;
@ -8754,3 +8826,51 @@ kernel void kernel_pool_2d_avg_f32(
o_ptr[cur_oh * args.OW + cur_ow] = res; o_ptr[cur_oh * args.OW + cur_ow] = res;
} }
kernel void kernel_opt_step_adamw_f32(
constant ggml_metal_kargs_opt_step_adamw & args,
device float * x,
device const float * g,
device float * g_m,
device float * g_v,
device const float * pars,
uint gid[[thread_position_in_grid]]) {
if (gid >= args.np) {
return;
}
const float alpha = pars[0];
const float beta1 = pars[1];
const float beta2 = pars[2];
const float eps = pars[3];
const float wd = pars[4];
const float beta1h = pars[5];
const float beta2h = pars[6];
const float gi = g[gid];
const float gmi = g_m[gid] * beta1 + gi * (1.0f - beta1);
const float gvi = g_v[gid] * beta2 + gi * gi * (1.0f - beta2);
g_m[gid] = gmi;
g_v[gid] = gvi;
const float mh = gmi * beta1h;
const float vh = sqrt(gvi * beta2h) + eps;
x[gid] = x[gid] * (1.0f - alpha * wd) - alpha * mh / vh;
}
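The kernel above is a per-element AdamW step; pars packs (alpha, beta1, beta2, eps, wd, beta1h, beta2h), where beta1h and beta2h are most likely the host-precomputed bias corrections 1/(1 - beta1^t) and 1/(1 - beta2^t). A scalar C++ sketch of the same update, under that assumption:

#include <cmath>

// One AdamW step for a single parameter, mirroring kernel_opt_step_adamw_f32.
// beta1h/beta2h are assumed to be 1/(1 - beta1^t) and 1/(1 - beta2^t).
void adamw_step(float &x, float g, float &m, float &v, const float pars[7]) {
    const float alpha = pars[0], beta1 = pars[1], beta2 = pars[2];
    const float eps = pars[3], wd = pars[4], beta1h = pars[5], beta2h = pars[6];
    m = m * beta1 + g * (1.0f - beta1);            // first moment (EMA of g)
    v = v * beta2 + g * g * (1.0f - beta2);        // second moment (EMA of g^2)
    const float mh = m * beta1h;                   // bias-corrected mean
    const float vh = std::sqrt(v * beta2h) + eps;  // bias-corrected RMS
    x = x * (1.0f - alpha * wd) - alpha * mh / vh; // decoupled weight decay + step
}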
kernel void kernel_opt_step_sgd_f32(
constant ggml_metal_kargs_opt_step_sgd & args,
device float * x,
device const float * g,
device const float * pars,
uint gid[[thread_position_in_grid]]) {
if (gid >= args.np) {
return;
}
x[gid] = x[gid] * (1.0f - pars[0] * pars[1]) - pars[0] * g[gid];
}
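The SGD kernel above is the degenerate case: reading pars[0] as the learning rate and pars[1] as the weight decay (an inference from the expression, mirroring the AdamW parameter layout), it computes x <- x*(1 - lr*wd) - lr*g. Equivalent scalar C++:

// SGD with decoupled weight decay, mirroring kernel_opt_step_sgd_f32.
// pars[0] = learning rate and pars[1] = weight decay are inferred, not documented.
inline void sgd_step(float &x, float g, float lr, float wd) {
    x = x * (1.0f - lr * wd) - lr * g;
}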

View File

@ -2348,8 +2348,13 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
svm_caps & CL_DEVICE_SVM_ATOMICS ? "true" : "false"); svm_caps & CL_DEVICE_SVM_ATOMICS ? "true" : "false");
if (opencl_c_version.major >= 3) { if (opencl_c_version.major >= 3) {
// Assume it is not available for 3.0, since it is optional in 3.0.
// If compiling against 3.0, then we can query.
backend_ctx->non_uniform_workgroups = false;
#if CL_TARGET_OPENCL_VERSION >= 300
CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_NON_UNIFORM_WORK_GROUP_SUPPORT, sizeof(cl_bool), CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_NON_UNIFORM_WORK_GROUP_SUPPORT, sizeof(cl_bool),
&backend_ctx->non_uniform_workgroups, 0)); &backend_ctx->non_uniform_workgroups, 0));
#endif
} else { } else {
GGML_ASSERT(opencl_c_version.major == 2); GGML_ASSERT(opencl_c_version.major == 2);
// Non-uniform workgroup sizes are a mandatory feature in v2.x. // Non-uniform workgroup sizes are a mandatory feature in v2.x.
@ -2681,7 +2686,7 @@ static bool ggml_opencl_can_fuse(const struct ggml_cgraph * cgraph, int node_idx
// if rms_norm is the B operand, then we don't handle broadcast // if rms_norm is the B operand, then we don't handle broadcast
if (rms_norm == mul->src[1] && if (rms_norm == mul->src[1] &&
!ggml_are_same_shape(mul->src[0], rms_norm->src[1])) { !ggml_are_same_shape(mul->src[0], rms_norm)) {
return false; return false;
} }

View File

@ -1,9 +1,18 @@
cmake_minimum_required(VERSION 3.19) cmake_minimum_required(VERSION 3.19)
cmake_policy(SET CMP0114 NEW) cmake_policy(SET CMP0114 NEW)
cmake_policy(SET CMP0116 NEW) cmake_policy(SET CMP0116 NEW)
if (POLICY CMP0147)
# Build custom build steps in parallel
cmake_policy(SET CMP0147 NEW)
endif()
find_package(Vulkan COMPONENTS glslc REQUIRED) find_package(Vulkan COMPONENTS glslc REQUIRED)
if (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC")
# Build object files in parallel
add_definitions(/MP)
endif()
function(detect_host_compiler) function(detect_host_compiler)
if (CMAKE_HOST_SYSTEM_NAME STREQUAL "Windows") if (CMAKE_HOST_SYSTEM_NAME STREQUAL "Windows")
find_program(HOST_C_COMPILER NAMES cl gcc clang NO_CMAKE_FIND_ROOT_PATH) find_program(HOST_C_COMPILER NAMES cl gcc clang NO_CMAKE_FIND_ROOT_PATH)

View File

@ -2649,11 +2649,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
} \ } \
} }
CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, )
CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, ) CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, ) CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, ) CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, )
#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) #if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
if (device->coopmat1_fa_support) { if (device->coopmat1_fa_support) {
CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT1, _cm1)
CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT1, _cm1) CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT1, _cm1)
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT1, _cm1) CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT1, _cm1)
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT1, _cm1) CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT1, _cm1)
@ -2661,6 +2663,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
#endif #endif
#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) #if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
if (device->coopmat2) { if (device->coopmat2) {
CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT2, _cm2)
CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT2, _cm2) CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT2, _cm2)
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT2, _cm2) CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT2, _cm2)
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_COOPMAT2, _cm2) CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_COOPMAT2, _cm2)
@ -7457,8 +7460,16 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
} }
const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type)); const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type));
const uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type)); uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type));
const uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type)); uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type));
// For F32, the shader treats it as a block of size 4 (for vec4 loads)
if (k->type == GGML_TYPE_F32) {
k_stride /= 4;
}
if (v->type == GGML_TYPE_F32) {
v_stride /= 4;
}
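    // Illustration of the F32 adjustment above (numbers are hypothetical): with
    // nbk1 = 4096 bytes, nbk1 / ggml_type_size(GGML_TYPE_F32) = 1024 elements per
    // row, and since the shader addresses F32 K/V as vec4 "blocks" the stride
    // handed down becomes 1024 / 4 = 256. Quantized types need no extra division
    // because ggml_type_size() there is already the byte size of a whole block.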
uint32_t alignment = fa_align(path, HSK, HSV, k->type, small_rows); uint32_t alignment = fa_align(path, HSK, HSV, k->type, small_rows);
bool aligned = (KV % alignment) == 0 && bool aligned = (KV % alignment) == 0 &&
@ -12660,6 +12671,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
} }
switch (op->src[1]->type) { switch (op->src[1]->type) {
case GGML_TYPE_F16: case GGML_TYPE_F16:
case GGML_TYPE_F32:
case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_0:
case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_0:
// supported in scalar and coopmat2 paths // supported in scalar and coopmat2 paths

View File

@ -1,6 +1,18 @@
#include "types.glsl" #include "types.glsl"
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufF32 {
vec4 block;
};
float16_t dequantFuncF32(const in decodeBufF32 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const vec4 v = bl.block;
const uint idx = coordInBlock[1];
const f16vec4 vf16 = f16vec4(v);
return vf16[idx];
}
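// Note: for F32 the "block" is just one 16-byte vec4. dequantFuncF32 loads it,
// narrows it to f16vec4, and returns the lane selected by coordInBlock[1]; there
// is no real dequantization, only a float -> float16_t conversion per element.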
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ4_0 { layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ4_0 {
block_q4_0_packed16 block; block_q4_0_packed16 block;
}; };
@ -717,4 +729,6 @@ float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords
#define dequantFuncA dequantFuncIQ4_NL #define dequantFuncA dequantFuncIQ4_NL
#elif defined(DATA_A_MXFP4) #elif defined(DATA_A_MXFP4)
#define dequantFuncA dequantFuncMXFP4 #define dequantFuncA dequantFuncMXFP4
#elif defined(DATA_A_F32)
#define dequantFuncA dequantFuncF32
#endif #endif

View File

@ -64,13 +64,31 @@ layout (binding = 4) readonly buffer S {float data_s[];};
layout (binding = 5) writeonly buffer O {D_TYPE data_o[];}; layout (binding = 5) writeonly buffer O {D_TYPE data_o[];};
#if defined(A_TYPE_PACKED16)
#define BINDING_IDX_K 0 #define BINDING_IDX_K 0
#define BINDING_IDX_V 1 #define BINDING_IDX_V 1
#if defined(DATA_A_F32)
layout (binding = 1) readonly buffer K_PACKED {vec4 k_data_packed[];} k_packed;
layout (binding = 2) readonly buffer V_PACKED {vec4 v_data_packed[];} v_packed;
#elif defined(A_TYPE_PACKED16)
layout (binding = 1) readonly buffer K_PACKED16 {A_TYPE_PACKED16 k_data_packed16[];} k_packed; layout (binding = 1) readonly buffer K_PACKED16 {A_TYPE_PACKED16 k_data_packed16[];} k_packed;
layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16[];} v_packed; layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16[];} v_packed;
#endif #endif
#if defined(DATA_A_F32)
#undef BLOCK_SIZE
#define BLOCK_SIZE 4
#define BLOCK_BYTE_SIZE 16
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
// iqs is currently always zero in the flash attention shaders
if (binding_idx == BINDING_IDX_K) {
return k_packed.k_data_packed[a_offset + ib];
} else {
return v_packed.v_data_packed[a_offset + ib];
}
}
#endif
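// With DATA_A_F32 the flash-attention shader treats K/V as blocks of
// BLOCK_SIZE = 4 floats (BLOCK_BYTE_SIZE = 16), so dequantize4 reduces to a
// plain indexed vec4 load from the K or V binding chosen by binding_idx. This
// matches the host side dividing k_stride/v_stride by 4 for F32 tensors.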
#if defined(DATA_A_Q4_0) #if defined(DATA_A_Q4_0)
#define BLOCK_BYTE_SIZE 18 #define BLOCK_BYTE_SIZE 18

View File

@ -313,12 +313,12 @@ void main() {
sums[i] = coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0f); sums[i] = coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0f);
} }
#else #else
ACC_TYPE sums[WMITER * TM * WNITER * TN]; ACC_TYPE_VEC2 sums[WMITER * TM * WNITER * TN/2];
FLOAT_TYPE_VEC2 cache_a[WMITER * TM]; FLOAT_TYPE_VEC2 cache_a[WMITER * TM];
FLOAT_TYPE_VEC2 cache_b[TN]; FLOAT_TYPE_VEC2 cache_b;
[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) { [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN/2; i++) {
sums[i] = ACC_TYPE(0.0f); sums[i] = ACC_TYPE_VEC2(0.0f, 0.0f);
} }
#endif #endif
@ -360,20 +360,22 @@ void main() {
cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * SHMEM_STRIDE + i]; cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * SHMEM_STRIDE + i];
} }
} }
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
[[unroll]] for (uint j = 0; j < TN; j++) { [[unroll]] for (uint cc = 0; cc < TN; cc++) {
cache_b[j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * SHMEM_STRIDE + i]; cache_b = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + cc) * SHMEM_STRIDE + i];
}
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
[[unroll]] for (uint cc = 0; cc < TN; cc++) { [[unroll]] for (uint cr = 0; cr < TM / 2; cr++) {
[[unroll]] for (uint cr = 0; cr < TM; cr++) { // [WNITER][TN][WMITER][TM / 2] -> [wsic][cc][wsir][cr]
const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr; const uint sums_idx = (wsic * TN + cc) * WMITER * (TM / 2) + wsir * (TM / 2) + cr;
sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr].x), ACC_TYPE(cache_b[cc].x), fma(ACC_TYPE(cache_a[wsir * TM + cr].y), ACC_TYPE(cache_b[cc].y), sums[sums_idx])); sums[sums_idx].x = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].y), ACC_TYPE(cache_b.y), sums[sums_idx].x));
sums[sums_idx].y = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].y), ACC_TYPE(cache_b.y), sums[sums_idx].y));
} }
} }
} }
} }
} }
#endif #endif
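The scalar (non-coopmat) path now stores accumulators as ACC_TYPE_VEC2, pairing adjacent TM rows so each cached B value feeds two fused multiply-adds and the sums array shrinks by half. A small C++ sketch of the new accumulator indexing, using the shader's own names:

// Flattened index of the vec2 accumulator for output column cc and row pair cr.
// Layout: [WNITER][TN][WMITER][TM/2] -> [wsic][cc][wsir][cr]; .x holds row 2*cr
// and .y holds row 2*cr + 1.
constexpr unsigned sums_index(unsigned wsic, unsigned cc, unsigned wsir, unsigned cr,
                              unsigned TN, unsigned WMITER, unsigned TM) {
    return (wsic * TN + cc) * WMITER * (TM / 2) + wsir * (TM / 2) + cr;
}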
@ -388,8 +390,9 @@ void main() {
} }
} }
#else #else
[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) { [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN/2; i++) {
sums[i] = clamp(sums[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); sums[i].x = clamp(sums[i].x, -ACC_TYPE_MAX, ACC_TYPE_MAX);
sums[i].y = clamp(sums[i].y, -ACC_TYPE_MAX, ACC_TYPE_MAX);
} }
#endif #endif
#endif #endif
@ -463,14 +466,21 @@ void main() {
const u16vec2 row_idx = row_ids[row_i - ic * BN]; const u16vec2 row_idx = row_ids[row_i - ic * BN];
#endif // MUL_MAT_ID #endif // MUL_MAT_ID
[[unroll]] for (uint cr = 0; cr < TM; cr++) { [[unroll]] for (uint cr = 0; cr < TM / 2; cr++) {
const uint sums_idx = (wsic * TN + cc) * WMITER * (TM / 2) + wsir * (TM / 2) + cr;
#ifdef MUL_MAT_ID #ifdef MUL_MAT_ID
if (dr_warp + cr < p.M) { if (dr_warp + 2 * cr < p.M) {
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]); data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + 2 * cr] = D_TYPE(sums[sums_idx].x);
}
if (dr_warp + 2 * cr + 1 < p.M) {
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + 2 * cr + 1] = D_TYPE(sums[sums_idx].y);
} }
#else #else
if (dr_warp + cr < p.M && dc_warp + cc < p.N) { if (dr_warp + 2 * cr < p.M && dc_warp + cc < p.N) {
data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]); data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + 2 * cr] = D_TYPE(sums[sums_idx].x);
}
if (dr_warp + 2 * cr + 1 < p.M && dc_warp + cc < p.N) {
data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + 2 * cr + 1] = D_TYPE(sums[sums_idx].y);
} }
#endif // MUL_MAT_ID #endif // MUL_MAT_ID
} }

View File

@ -611,9 +611,6 @@ void process_shaders() {
} }
for (const auto& tname : type_names) { for (const auto& tname : type_names) {
if (tname == "f32") {
continue;
}
if (tname == "bf16") continue; if (tname == "bf16") continue;
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
@ -630,7 +627,7 @@ void process_shaders() {
if (tname == "f16") { if (tname == "f16") {
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"COOPMAT", "1"}}), true, true, false, f16acc); merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"COOPMAT", "1"}}), true, true, false, f16acc);
} else if (tname == "q4_0" || tname == "q8_0") { } else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") {
std::string data_a_key = "DATA_A_" + to_uppercase(tname); std::string data_a_key = "DATA_A_" + to_uppercase(tname);
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc); merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc);
@ -639,7 +636,7 @@ void process_shaders() {
if (tname == "f16") { if (tname == "f16") {
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, false, f16acc); merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, false, f16acc);
} else if (tname == "q4_0" || tname == "q8_0") { } else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") {
std::string data_a_key = "DATA_A_" + to_uppercase(tname); std::string data_a_key = "DATA_A_" + to_uppercase(tname);
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc); merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc);

View File

@ -261,12 +261,17 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
} }
} }
static void print_mask(float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) { static void print_mask(const float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
LLAMA_LOG_DEBUG("%s: === Attention mask ===\n", __func__); LLAMA_LOG_DEBUG("%s: === Attention mask ===\n", __func__);
const char * swa_type_str = (swa_type == LLAMA_SWA_TYPE_NONE) ? "LLAMA_SWA_TYPE_NONE" : const char * swa_type_str = "unknown";
(swa_type == LLAMA_SWA_TYPE_STANDARD) ? "LLAMA_SWA_TYPE_STANDARD" :
(swa_type == LLAMA_SWA_TYPE_CHUNKED) ? "LLAMA_SWA_TYPE_CHUNKED" : switch (swa_type) {
(swa_type == LLAMA_SWA_TYPE_SYMMETRIC) ? "LLAMA_SWA_TYPE_SYMMETRIC" : "unknown"; case LLAMA_SWA_TYPE_NONE: swa_type_str = "LLAMA_SWA_TYPE_NONE"; break;
case LLAMA_SWA_TYPE_STANDARD: swa_type_str = "LLAMA_SWA_TYPE_STANDARD"; break;
case LLAMA_SWA_TYPE_CHUNKED: swa_type_str = "LLAMA_SWA_TYPE_CHUNKED"; break;
case LLAMA_SWA_TYPE_SYMMETRIC: swa_type_str = "LLAMA_SWA_TYPE_SYMMETRIC"; break;
};
LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swq_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str); LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swq_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str);
LLAMA_LOG_DEBUG("%s: '0' = can attend, '∞' = masked\n", __func__); LLAMA_LOG_DEBUG("%s: '0' = can attend, '∞' = masked\n", __func__);
LLAMA_LOG_DEBUG("%s: Rows = query tokens, Columns = key/value tokens\n\n", __func__); LLAMA_LOG_DEBUG("%s: Rows = query tokens, Columns = key/value tokens\n\n", __func__);
@ -295,51 +300,68 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
const int64_t n_kv = ubatch->n_tokens; const int64_t n_kv = ubatch->n_tokens;
const int64_t n_tokens = ubatch->n_tokens; const int64_t n_tokens = ubatch->n_tokens;
GGML_ASSERT(kq_mask); const auto fill_mask = [&](float * data, int n_swa, llama_swa_type swa_type) {
GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
float * data = (float *) kq_mask->data;
// [TAG_NO_CACHE_ISWA]
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "TODO: implement");
for (int h = 0; h < 1; ++h) { for (int h = 0; h < 1; ++h) {
for (int i1 = 0; i1 < n_tokens; ++i1) { for (int i1 = 0; i1 < n_tokens; ++i1) {
const llama_seq_id s1 = ubatch->seq_id[i1][0]; const llama_seq_id s1 = ubatch->seq_id[i1][0];
const llama_pos p1 = ubatch->pos[i1];
const uint64_t idst = h*(n_kv*n_tokens) + i1*n_kv;
for (int i0 = 0; i0 < n_tokens; ++i0) { for (int i0 = 0; i0 < n_tokens; ++i0) {
float f = -INFINITY;
for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) {
const llama_seq_id s0 = ubatch->seq_id[i0][0]; const llama_seq_id s0 = ubatch->seq_id[i0][0];
const llama_pos p0 = ubatch->pos[i0];
// mask different sequences
if (s0 != s1) { if (s0 != s1) {
continue; // skip different sequences continue;
} }
if (cparams.causal_attn && ubatch->pos[i0] > ubatch->pos[i1]) { // mask future tokens
continue; // skip future tokens for causal attention if (cparams.causal_attn && p0 > p1) {
continue;
} }
// TODO: this does not take into account that some layers are SWA and others are not (i.e. iSWA) [TAG_NO_CACHE_ISWA] // apply SWA if any
//if (hparams.is_masked_swa(ubatch->pos[i0], ubatch->pos[i1])) { if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) {
// continue; // skip masked tokens for SWA continue;
//} }
// TODO: reimplement this like in llama_kv_cache_unified data[idst + i0] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
if (hparams.use_alibi) {
f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
} else {
f = 0.0f;
}
}
data[h*(n_kv*n_tokens) + i1*n_kv + i0] = f;
} }
} }
} }
};
{
GGML_ASSERT(self_kq_mask);
GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
float * data = (float *) self_kq_mask->data;
std::fill(data, data + ggml_nelements(self_kq_mask), -INFINITY);
fill_mask(data, 0, LLAMA_SWA_TYPE_NONE);
if (debug) {
print_mask(data, n_tokens, n_kv, 0, LLAMA_SWA_TYPE_NONE);
}
}
if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
GGML_ASSERT(self_kq_mask_swa);
GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
float * data = (float *) self_kq_mask_swa->data;
std::fill(data, data + ggml_nelements(self_kq_mask_swa), -INFINITY);
fill_mask(data, hparams.n_swa, hparams.swa_type);
if (debug) { if (debug) {
print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type); print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type);
} }
}
} }
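The refactor folds the old inline loop into a fill_mask lambda: the buffer is pre-filled with -INFINITY, and an entry is written only when key and query share a sequence, the key is not in the causal future, and the SWA window (if any) keeps it visible; with ALiBi the written value is -|p0 - p1| instead of 0. For the STANDARD window the masking predicate behaves roughly like this sketch (an approximation, not the exact llama_hparams implementation):

#include <cstdint>

// Sketch of the standard sliding-window check: a key at position p0 is masked
// for a query at p1 once it trails by n_swa positions or more.
static bool is_masked_swa_standard(int32_t n_swa, int32_t p0, int32_t p1) {
    return n_swa > 0 && (p1 - p0) >= n_swa;
}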
void llm_graph_input_attn_kv::set_input(const llama_ubatch * ubatch) { void llm_graph_input_attn_kv::set_input(const llama_ubatch * ubatch) {
@ -1299,12 +1321,9 @@ ggml_tensor * llm_graph_context::build_attn_mha(
k = ggml_permute(ctx0, k, 0, 2, 1, 3); k = ggml_permute(ctx0, k, 0, 2, 1, 3);
v = ggml_permute(ctx0, v, 0, 2, 1, 3); v = ggml_permute(ctx0, v, 0, 2, 1, 3);
const auto n_kv = k->ne[1];
ggml_tensor * cur; ggml_tensor * cur;
// TODO: replace hardcoded padding with ggml-provided padding if (cparams.flash_attn && kq_b == nullptr) {
if (cparams.flash_attn && (n_kv % 256 == 0) && kq_b == nullptr) {
GGML_ASSERT(kq_b == nullptr && "Flash attention does not support KQ bias yet"); GGML_ASSERT(kq_b == nullptr && "Flash attention does not support KQ bias yet");
if (v_trans) { if (v_trans) {
@ -1419,10 +1438,20 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con
auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams); auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);
// note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
inp->kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1); inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
ggml_set_input(inp->kq_mask); ggml_set_input(inp->self_kq_mask);
inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask; inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
ggml_set_input(inp->self_kq_mask_swa);
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
} else {
inp->self_kq_mask_swa = nullptr;
inp->self_kq_mask_swa_cnv = nullptr;
}
return (llm_graph_input_attn_no_cache *) res->add_input(std::move(inp)); return (llm_graph_input_attn_no_cache *) res->add_input(std::move(inp));
} }
@ -1447,7 +1476,9 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_build_forward_expand(gf, k_cur); ggml_build_forward_expand(gf, k_cur);
ggml_build_forward_expand(gf, v_cur); ggml_build_forward_expand(gf, v_cur);
const auto & kq_mask = inp->get_kq_mask(); const bool is_swa = hparams.is_swa(il);
const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
// [TAG_NO_CACHE_PAD] // [TAG_NO_CACHE_PAD]
// TODO: if ubatch.equal_seqs() == true, we can split the three tensors below into ubatch.n_seqs_unq streams // TODO: if ubatch.equal_seqs() == true, we can split the three tensors below into ubatch.n_seqs_unq streams

View File

@ -257,10 +257,14 @@ public:
void set_input(const llama_ubatch * ubatch) override; void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * get_kq_mask() const { return kq_mask_cnv; } ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
ggml_tensor * kq_mask = nullptr; // F32 [n_tokens, n_batch, 1, 1] // n_tokens == n_batch
ggml_tensor * kq_mask_cnv = nullptr; // [n_tokens, n_batch, 1, 1] ggml_tensor * self_kq_mask = nullptr; // F32 [n_tokens, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_tokens, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_tokens, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_tokens, n_batch/n_stream, 1, n_stream]
const llama_hparams hparams; const llama_hparams hparams;
const llama_cparams cparams; const llama_cparams cparams;

View File

@ -11358,8 +11358,8 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
} }
}; };
struct llm_build_gemma_embedding_iswa : public llm_graph_context { struct llm_build_gemma_embedding : public llm_graph_context {
llm_build_gemma_embedding_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { llm_build_gemma_embedding(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_k; const int64_t n_embd_head = hparams.n_embd_head_k;
ggml_tensor * cur; ggml_tensor * cur;
@ -11376,8 +11376,7 @@ struct llm_build_gemma_embedding_iswa : public llm_graph_context {
// inp_pos - contains the positions // inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos(); ggml_tensor * inp_pos = build_inp_pos();
// TODO: support cacheless iSWA embeddings [TAG_NO_CACHE_ISWA] auto * inp_attn = build_attn_inp_no_cache();
auto * inp_attn = build_attn_inp_kv_iswa();
ggml_tensor * inp_out_ids = build_inp_out_ids(); ggml_tensor * inp_out_ids = build_inp_out_ids();
@ -19378,7 +19377,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
case LLM_ARCH_NOMIC_BERT_MOE: case LLM_ARCH_NOMIC_BERT_MOE:
case LLM_ARCH_NEO_BERT: case LLM_ARCH_NEO_BERT:
case LLM_ARCH_WAVTOKENIZER_DEC: case LLM_ARCH_WAVTOKENIZER_DEC:
//case LLM_ARCH_GEMMA_EMBEDDING: // TODO: disabled until the cacheless SWA logic is fixed [TAG_NO_CACHE_ISWA] case LLM_ARCH_GEMMA_EMBEDDING:
case LLM_ARCH_DREAM: case LLM_ARCH_DREAM:
case LLM_ARCH_LLADA: case LLM_ARCH_LLADA:
case LLM_ARCH_LLADA_MOE: case LLM_ARCH_LLADA_MOE:
@ -19671,7 +19670,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
} break; } break;
case LLM_ARCH_GEMMA_EMBEDDING: case LLM_ARCH_GEMMA_EMBEDDING:
{ {
llm = std::make_unique<llm_build_gemma_embedding_iswa>(*this, params); llm = std::make_unique<llm_build_gemma_embedding>(*this, params);
} break; } break;
case LLM_ARCH_STARCODER2: case LLM_ARCH_STARCODER2:
{ {

View File

@ -312,6 +312,7 @@ struct llama_model * llama_model_load_from_splits(
LLAMA_LOG_ERROR("%s: list of splits is empty\n", __func__); LLAMA_LOG_ERROR("%s: list of splits is empty\n", __func__);
return nullptr; return nullptr;
} }
splits.reserve(n_paths);
for (size_t i = 0; i < n_paths; ++i) { for (size_t i = 0; i < n_paths; ++i) {
splits.push_back(paths[i]); splits.push_back(paths[i]);
} }

View File

@ -6779,7 +6779,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
for (int nb : { 1, 3, 32, 35, }) { for (int nb : { 1, 3, 32, 35, }) {
for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) { for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) {
if (hsk != 128 && prec == GGML_PREC_DEFAULT) continue; if (hsk != 128 && prec == GGML_PREC_DEFAULT) continue;
for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) { for (ggml_type type_KV : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
test_cases.emplace_back(new test_flash_attn_ext( test_cases.emplace_back(new test_flash_attn_ext(
hsk, hsv, nh, {nr2, nr3}, kv, nb, mask, sinks, max_bias, logit_softcap, prec, type_KV)); hsk, hsv, nh, {nr2, nr3}, kv, nb, mask, sinks, max_bias, logit_softcap, prec, type_KV));
// run fewer test cases permuted // run fewer test cases permuted
@ -6911,7 +6911,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
} }
// qwen3-30b-a3b // qwen3-30b-a3b
for (int bs : {1, 4, 8, 32, 64, 128, 512}) { for (int bs : {1, 4, 8, 32, 64, 128, 256, 512}) {
for (ggml_type type_a : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0, GGML_TYPE_Q4_K, GGML_TYPE_Q6_K, GGML_TYPE_IQ2_XS}) { for (ggml_type type_a : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0, GGML_TYPE_Q4_K, GGML_TYPE_Q6_K, GGML_TYPE_IQ2_XS}) {
for (ggml_type type_b : {GGML_TYPE_F32}) { for (ggml_type type_b : {GGML_TYPE_F32}) {
test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, 128, 8, false, 768, bs, 2048, 1)); test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, 128, 8, false, 768, bs, 2048, 1));
@ -6919,6 +6919,15 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
} }
} }
for (int bs : {1, 4, 8, 32, 64, 128, 256, 512}) {
for (ggml_type type_a : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0, GGML_TYPE_Q4_K, GGML_TYPE_Q6_K, GGML_TYPE_IQ2_XS}) {
for (ggml_type type_b : {GGML_TYPE_F32}) {
test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, 32, 4, false, 1792, bs, 2048, 1));
}
}
}
// gpt-oss-20b // gpt-oss-20b
for (int bs : {1, 4, 8, 512}) { for (int bs : {1, 4, 8, 512}) {
for (ggml_type type_a : {GGML_TYPE_MXFP4}) { for (ggml_type type_a : {GGML_TYPE_MXFP4}) {

Binary file not shown.

View File

@ -1585,23 +1585,31 @@ struct server_prompt_cache {
} }
} }
// average size per token
const float size_per_token = std::max<float>(1.0f, float(size()) / (std::max<size_t>(1, n_tokens())));
// dynamically increase the token limit if it can fit in the memory limit
const size_t limit_tokens_cur = limit_size > 0 ? std::max<size_t>(limit_tokens, limit_size/size_per_token) : limit_tokens;
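// Illustrative numbers: with limit_size = 2 GiB and an observed size_per_token
// of 512 KiB, limit_size / size_per_token = 4096, so limit_tokens_cur rises to
// 4096 even when the configured limit_tokens is smaller; the byte limit still
// bounds the total cache size.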
if (limit_tokens > 0) { if (limit_tokens > 0) {
while (states.size() > 1 && n_tokens() > limit_tokens) { while (states.size() > 1 && n_tokens() > limit_tokens_cur) {
if (states.empty()) { if (states.empty()) {
break; break;
} }
SRV_WRN(" - cache token limit reached, removing oldest entry (size = %.3f MiB)\n", states.front().size() / (1024.0 * 1024.0)); SRV_WRN(" - cache token limit (%zu, est: %zu) reached, removing oldest entry (size = %.3f MiB)\n",
limit_tokens, limit_tokens_cur, states.front().size() / (1024.0 * 1024.0));
states.pop_front(); states.pop_front();
} }
} }
SRV_WRN(" - cache state: %zu prompts, %.3f MiB (limits: %.3f MiB, %zu tokens)\n", SRV_WRN(" - cache state: %zu prompts, %.3f MiB (limits: %.3f MiB, %zu tokens, %zu est)\n",
states.size(), size() / (1024.0 * 1024.0), limit_size / (1024.0 * 1024.0), limit_tokens); states.size(), size() / (1024.0 * 1024.0), limit_size / (1024.0 * 1024.0), limit_tokens, limit_tokens_cur);
for (const auto & state : states) { for (const auto & state : states) {
SRV_WRN(" - prompt %p: %7d tokens, checkpoints: %2zu, %9.3f MiB\n", (const void *)&state, state.n_tokens(), state.checkpoints.size(), state.size() / (1024.0 * 1024.0)); SRV_WRN(" - prompt %p: %7d tokens, checkpoints: %2zu, %9.3f MiB\n",
(const void *)&state, state.n_tokens(), state.checkpoints.size(), state.size() / (1024.0 * 1024.0));
} }
} }
}; };

View File

@ -50,6 +50,7 @@
"eslint-plugin-svelte": "^3.0.0", "eslint-plugin-svelte": "^3.0.0",
"fflate": "^0.8.2", "fflate": "^0.8.2",
"globals": "^16.0.0", "globals": "^16.0.0",
"mdast": "^3.0.0",
"mdsvex": "^0.12.3", "mdsvex": "^0.12.3",
"playwright": "^1.53.0", "playwright": "^1.53.0",
"prettier": "^3.4.2", "prettier": "^3.4.2",
@ -66,6 +67,7 @@
"tw-animate-css": "^1.3.5", "tw-animate-css": "^1.3.5",
"typescript": "^5.0.0", "typescript": "^5.0.0",
"typescript-eslint": "^8.20.0", "typescript-eslint": "^8.20.0",
"unified": "^11.0.5",
"uuid": "^13.0.0", "uuid": "^13.0.0",
"vite": "^7.0.4", "vite": "^7.0.4",
"vite-plugin-devtools-json": "^0.2.0", "vite-plugin-devtools-json": "^0.2.0",
@ -2128,6 +2130,66 @@
"node": ">=14.0.0" "node": ">=14.0.0"
} }
}, },
"node_modules/@tailwindcss/oxide-wasm32-wasi/node_modules/@emnapi/core": {
"version": "1.4.3",
"dev": true,
"inBundle": true,
"license": "MIT",
"optional": true,
"dependencies": {
"@emnapi/wasi-threads": "1.0.2",
"tslib": "^2.4.0"
}
},
"node_modules/@tailwindcss/oxide-wasm32-wasi/node_modules/@emnapi/runtime": {
"version": "1.4.3",
"dev": true,
"inBundle": true,
"license": "MIT",
"optional": true,
"dependencies": {
"tslib": "^2.4.0"
}
},
"node_modules/@tailwindcss/oxide-wasm32-wasi/node_modules/@emnapi/wasi-threads": {
"version": "1.0.2",
"dev": true,
"inBundle": true,
"license": "MIT",
"optional": true,
"dependencies": {
"tslib": "^2.4.0"
}
},
"node_modules/@tailwindcss/oxide-wasm32-wasi/node_modules/@napi-rs/wasm-runtime": {
"version": "0.2.11",
"dev": true,
"inBundle": true,
"license": "MIT",
"optional": true,
"dependencies": {
"@emnapi/core": "^1.4.3",
"@emnapi/runtime": "^1.4.3",
"@tybys/wasm-util": "^0.9.0"
}
},
"node_modules/@tailwindcss/oxide-wasm32-wasi/node_modules/@tybys/wasm-util": {
"version": "0.9.0",
"dev": true,
"inBundle": true,
"license": "MIT",
"optional": true,
"dependencies": {
"tslib": "^2.4.0"
}
},
"node_modules/@tailwindcss/oxide-wasm32-wasi/node_modules/tslib": {
"version": "2.8.0",
"dev": true,
"inBundle": true,
"license": "0BSD",
"optional": true
},
"node_modules/@tailwindcss/oxide-win32-arm64-msvc": { "node_modules/@tailwindcss/oxide-win32-arm64-msvc": {
"version": "4.1.11", "version": "4.1.11",
"resolved": "https://registry.npmjs.org/@tailwindcss/oxide-win32-arm64-msvc/-/oxide-win32-arm64-msvc-4.1.11.tgz", "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-win32-arm64-msvc/-/oxide-win32-arm64-msvc-4.1.11.tgz",
@ -4946,6 +5008,13 @@
"url": "https://github.com/sponsors/wooorm" "url": "https://github.com/sponsors/wooorm"
} }
}, },
"node_modules/mdast": {
"version": "3.0.0",
"resolved": "https://registry.npmjs.org/mdast/-/mdast-3.0.0.tgz",
"integrity": "sha512-xySmf8g4fPKMeC07jXGz971EkLbWAJ83s4US2Tj9lEdnZ142UP5grN73H1Xd3HzrdbU5o9GYYP/y8F9ZSwLE9g==",
"dev": true,
"license": "MIT"
},
"node_modules/mdast-util-find-and-replace": { "node_modules/mdast-util-find-and-replace": {
"version": "3.0.2", "version": "3.0.2",
"resolved": "https://registry.npmjs.org/mdast-util-find-and-replace/-/mdast-util-find-and-replace-3.0.2.tgz", "resolved": "https://registry.npmjs.org/mdast-util-find-and-replace/-/mdast-util-find-and-replace-3.0.2.tgz",

View File

@ -52,6 +52,7 @@
"eslint-plugin-svelte": "^3.0.0", "eslint-plugin-svelte": "^3.0.0",
"fflate": "^0.8.2", "fflate": "^0.8.2",
"globals": "^16.0.0", "globals": "^16.0.0",
"mdast": "^3.0.0",
"mdsvex": "^0.12.3", "mdsvex": "^0.12.3",
"playwright": "^1.53.0", "playwright": "^1.53.0",
"prettier": "^3.4.2", "prettier": "^3.4.2",
@ -68,6 +69,7 @@
"tw-animate-css": "^1.3.5", "tw-animate-css": "^1.3.5",
"typescript": "^5.0.0", "typescript": "^5.0.0",
"typescript-eslint": "^8.20.0", "typescript-eslint": "^8.20.0",
"unified": "^11.0.5",
"uuid": "^13.0.0", "uuid": "^13.0.0",
"vite": "^7.0.4", "vite": "^7.0.4",
"vite-plugin-devtools-json": "^0.2.0", "vite-plugin-devtools-json": "^0.2.0",

View File

@ -7,6 +7,7 @@
ChatMessages, ChatMessages,
ChatProcessingInfo, ChatProcessingInfo,
EmptyFileAlertDialog, EmptyFileAlertDialog,
ChatErrorDialog,
ServerErrorSplash, ServerErrorSplash,
ServerInfo, ServerInfo,
ServerLoadingSplash, ServerLoadingSplash,
@ -22,10 +23,11 @@
activeMessages, activeMessages,
activeConversation, activeConversation,
deleteConversation, deleteConversation,
dismissErrorDialog,
errorDialog,
isLoading, isLoading,
sendMessage, sendMessage,
stopGeneration, stopGeneration
setMaxContextError
} from '$lib/stores/chat.svelte'; } from '$lib/stores/chat.svelte';
import { import {
supportsVision, supportsVision,
@ -34,7 +36,6 @@
serverWarning, serverWarning,
serverStore serverStore
} from '$lib/stores/server.svelte'; } from '$lib/stores/server.svelte';
import { contextService } from '$lib/services';
import { parseFilesToMessageExtras } from '$lib/utils/convert-files-to-extra'; import { parseFilesToMessageExtras } from '$lib/utils/convert-files-to-extra';
import { isFileTypeSupported } from '$lib/utils/file-type'; import { isFileTypeSupported } from '$lib/utils/file-type';
import { filterFilesByModalities } from '$lib/utils/modality-file-validation'; import { filterFilesByModalities } from '$lib/utils/modality-file-validation';
@ -79,6 +80,7 @@
showCenteredEmpty && !activeConversation() && activeMessages().length === 0 && !isLoading() showCenteredEmpty && !activeConversation() && activeMessages().length === 0 && !isLoading()
); );
let activeErrorDialog = $derived(errorDialog());
let isServerLoading = $derived(serverLoading()); let isServerLoading = $derived(serverLoading());
async function handleDeleteConfirm() { async function handleDeleteConfirm() {
@ -105,6 +107,12 @@
} }
} }
function handleErrorDialogOpenChange(open: boolean) {
if (!open) {
dismissErrorDialog();
}
}
function handleDragOver(event: DragEvent) { function handleDragOver(event: DragEvent) {
event.preventDefault(); event.preventDefault();
} }
@ -183,21 +191,6 @@
const extras = result?.extras; const extras = result?.extras;
// Check context limit using real-time slots data
const contextCheck = await contextService.checkContextLimit();
if (contextCheck && contextCheck.wouldExceed) {
const errorMessage = contextService.getContextErrorMessage(contextCheck);
setMaxContextError({
message: errorMessage,
estimatedTokens: contextCheck.currentUsage,
maxContext: contextCheck.maxContext
});
return false;
}
// Enable autoscroll for user-initiated message sending // Enable autoscroll for user-initiated message sending
userScrolledUp = false; userScrolledUp = false;
autoScrollEnabled = true; autoScrollEnabled = true;
@ -461,6 +454,13 @@
}} }}
/> />
<ChatErrorDialog
message={activeErrorDialog?.message ?? ''}
onOpenChange={handleErrorDialogOpenChange}
open={Boolean(activeErrorDialog)}
type={activeErrorDialog?.type ?? 'server'}
/>
<style> <style>
.conversation-chat-form { .conversation-chat-form {
position: relative; position: relative;

View File

@ -0,0 +1,60 @@
<script lang="ts">
import * as AlertDialog from '$lib/components/ui/alert-dialog';
import { AlertTriangle, TimerOff } from '@lucide/svelte';
interface Props {
open: boolean;
type: 'timeout' | 'server';
message: string;
onOpenChange?: (open: boolean) => void;
}
let { open = $bindable(), type, message, onOpenChange }: Props = $props();
const isTimeout = $derived(type === 'timeout');
const title = $derived(isTimeout ? 'TCP Timeout' : 'Server Error');
const description = $derived(
isTimeout
? 'The request did not receive a response from the server before timing out.'
: 'The server responded with an error message. Review the details below.'
);
const iconClass = $derived(isTimeout ? 'text-destructive' : 'text-amber-500');
const badgeClass = $derived(
isTimeout
? 'border-destructive/40 bg-destructive/10 text-destructive'
: 'border-amber-500/40 bg-amber-500/10 text-amber-600 dark:text-amber-400'
);
function handleOpenChange(newOpen: boolean) {
open = newOpen;
onOpenChange?.(newOpen);
}
</script>
<AlertDialog.Root {open} onOpenChange={handleOpenChange}>
<AlertDialog.Content>
<AlertDialog.Header>
<AlertDialog.Title class="flex items-center gap-2">
{#if isTimeout}
<TimerOff class={`h-5 w-5 ${iconClass}`} />
{:else}
<AlertTriangle class={`h-5 w-5 ${iconClass}`} />
{/if}
{title}
</AlertDialog.Title>
<AlertDialog.Description>
{description}
</AlertDialog.Description>
</AlertDialog.Header>
<div class={`rounded-lg border px-4 py-3 text-sm ${badgeClass}`}>
<p class="font-medium">{message}</p>
</div>
<AlertDialog.Footer>
<AlertDialog.Action onclick={() => handleOpenChange(false)}>Close</AlertDialog.Action>
</AlertDialog.Footer>
</AlertDialog.Content>
</AlertDialog.Root>

View File

@ -1,66 +0,0 @@
<script lang="ts">
import { AlertTriangle } from '@lucide/svelte';
import * as AlertDialog from '$lib/components/ui/alert-dialog';
import { maxContextError, clearMaxContextError } from '$lib/stores/chat.svelte';
</script>
<AlertDialog.Root
open={maxContextError() !== null}
onOpenChange={(open) => !open && clearMaxContextError()}
>
<AlertDialog.Content>
<AlertDialog.Header>
<AlertDialog.Title class="flex items-center gap-2">
<AlertTriangle class="h-5 w-5 text-destructive" />
Message Too Long
</AlertDialog.Title>
<AlertDialog.Description>
Your message exceeds the model's context window and cannot be processed.
</AlertDialog.Description>
</AlertDialog.Header>
{#if maxContextError()}
<div class="space-y-3 text-sm">
<div class="rounded-lg bg-muted p-3">
<div class="mb-2 font-medium">Token Usage:</div>
<div class="space-y-1 text-muted-foreground">
<div>
Estimated tokens:
<span class="font-mono">
{maxContextError()?.estimatedTokens.toLocaleString()}
</span>
</div>
<div>
Context window:
<span class="font-mono">
{maxContextError()?.maxContext.toLocaleString()}
</span>
</div>
</div>
</div>
<div>
<div class="mb-2 font-medium">Suggestions:</div>
<ul class="list-inside list-disc space-y-1 text-muted-foreground">
<li>Shorten your message</li>
<li>Remove some file attachments</li>
<li>Start a new conversation</li>
</ul>
</div>
</div>
{/if}
<AlertDialog.Footer>
<AlertDialog.Action onclick={() => clearMaxContextError()}>Got it</AlertDialog.Action>
</AlertDialog.Footer>
</AlertDialog.Content>
</AlertDialog.Root>

View File

@ -30,12 +30,11 @@ export { default as ChatSidebar } from './chat/ChatSidebar/ChatSidebar.svelte';
export { default as ChatSidebarConversationItem } from './chat/ChatSidebar/ChatSidebarConversationItem.svelte'; export { default as ChatSidebarConversationItem } from './chat/ChatSidebar/ChatSidebarConversationItem.svelte';
export { default as ChatSidebarSearch } from './chat/ChatSidebar/ChatSidebarSearch.svelte'; export { default as ChatSidebarSearch } from './chat/ChatSidebar/ChatSidebarSearch.svelte';
export { default as ChatErrorDialog } from './dialogs/ChatErrorDialog.svelte';
export { default as EmptyFileAlertDialog } from './dialogs/EmptyFileAlertDialog.svelte'; export { default as EmptyFileAlertDialog } from './dialogs/EmptyFileAlertDialog.svelte';
export { default as ConversationTitleUpdateDialog } from './dialogs/ConversationTitleUpdateDialog.svelte'; export { default as ConversationTitleUpdateDialog } from './dialogs/ConversationTitleUpdateDialog.svelte';
export { default as MaximumContextAlertDialog } from './dialogs/MaximumContextAlertDialog.svelte';
export { default as KeyboardShortcutInfo } from './misc/KeyboardShortcutInfo.svelte'; export { default as KeyboardShortcutInfo } from './misc/KeyboardShortcutInfo.svelte';
export { default as MarkdownContent } from './misc/MarkdownContent.svelte'; export { default as MarkdownContent } from './misc/MarkdownContent.svelte';

View File

@ -14,6 +14,7 @@
import githubDarkCss from 'highlight.js/styles/github-dark.css?inline'; import githubDarkCss from 'highlight.js/styles/github-dark.css?inline';
import githubLightCss from 'highlight.js/styles/github.css?inline'; import githubLightCss from 'highlight.js/styles/github.css?inline';
import { mode } from 'mode-watcher'; import { mode } from 'mode-watcher';
import { remarkLiteralHtml } from '$lib/markdown/literal-html';
interface Props { interface Props {
content: string; content: string;
@ -50,36 +51,59 @@
.use(remarkGfm) // GitHub Flavored Markdown .use(remarkGfm) // GitHub Flavored Markdown
.use(remarkMath) // Parse $inline$ and $$block$$ math .use(remarkMath) // Parse $inline$ and $$block$$ math
.use(remarkBreaks) // Convert line breaks to <br> .use(remarkBreaks) // Convert line breaks to <br>
.use(remarkRehype) // Convert to rehype (HTML AST) .use(remarkLiteralHtml) // Treat raw HTML as literal text with preserved indentation
.use(remarkRehype) // Convert Markdown AST to rehype
.use(rehypeKatex) // Render math using KaTeX .use(rehypeKatex) // Render math using KaTeX
.use(rehypeHighlight) // Add syntax highlighting .use(rehypeHighlight) // Add syntax highlighting
.use(rehypeStringify); // Convert to HTML string .use(rehypeStringify); // Convert to HTML string
}); });
function enhanceLinks(html: string): string { function enhanceLinks(html: string): string {
if (!html.includes('<a')) {
return html;
}
const tempDiv = document.createElement('div'); const tempDiv = document.createElement('div');
tempDiv.innerHTML = html; tempDiv.innerHTML = html;
// Make all links open in new tabs // Make all links open in new tabs
const linkElements = tempDiv.querySelectorAll('a[href]'); const linkElements = tempDiv.querySelectorAll('a[href]');
let mutated = false;
for (const link of linkElements) { for (const link of linkElements) {
const target = link.getAttribute('target');
const rel = link.getAttribute('rel');
if (target !== '_blank' || rel !== 'noopener noreferrer') {
mutated = true;
}
link.setAttribute('target', '_blank'); link.setAttribute('target', '_blank');
link.setAttribute('rel', 'noopener noreferrer'); link.setAttribute('rel', 'noopener noreferrer');
} }
return tempDiv.innerHTML; return mutated ? tempDiv.innerHTML : html;
} }
function enhanceCodeBlocks(html: string): string { function enhanceCodeBlocks(html: string): string {
if (!html.includes('<pre')) {
return html;
}
const tempDiv = document.createElement('div'); const tempDiv = document.createElement('div');
tempDiv.innerHTML = html; tempDiv.innerHTML = html;
const preElements = tempDiv.querySelectorAll('pre'); const preElements = tempDiv.querySelectorAll('pre');
let mutated = false;
for (const [index, pre] of Array.from(preElements).entries()) { for (const [index, pre] of Array.from(preElements).entries()) {
const codeElement = pre.querySelector('code'); const codeElement = pre.querySelector('code');
if (!codeElement) continue; if (!codeElement) {
continue;
}
mutated = true;
let language = 'text'; let language = 'text';
const classList = Array.from(codeElement.classList); const classList = Array.from(codeElement.classList);
@ -127,7 +151,7 @@
pre.parentNode?.replaceChild(wrapper, pre); pre.parentNode?.replaceChild(wrapper, pre);
} }
return tempDiv.innerHTML; return mutated ? tempDiv.innerHTML : html;
} }
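Both helpers above follow the same defensive pattern: skip DOM work entirely when the HTML cannot contain the target element, and hand back the original string when nothing was actually mutated, so downstream comparisons keep the original string identity. The distilled shape of that pattern (helper name is illustrative, not part of the component):

// Distilled form of the enhanceLinks/enhanceCodeBlocks optimization (illustrative).
function enhanceIfPresent(
    html: string,
    marker: string,
    mutate: (root: HTMLElement) => boolean
): string {
    if (!html.includes(marker)) return html; // cheap pre-check before any DOM work
    const root = document.createElement('div');
    root.innerHTML = html;
    const mutated = mutate(root); // must return true only if the tree changed
    return mutated ? root.innerHTML : html; // preserve string identity when untouched
}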
async function processMarkdown(text: string): Promise<string> { async function processMarkdown(text: string): Promise<string> {

View File

@ -0,0 +1,15 @@
export const LINE_BREAK = /\r?\n/;
export const PHRASE_PARENTS = new Set([
'paragraph',
'heading',
'emphasis',
'strong',
'delete',
'link',
'linkReference',
'tableCell'
]);
export const NBSP = '\u00a0';
export const TAB_AS_SPACES = NBSP.repeat(4);

View File

@ -0,0 +1,121 @@
import type { Plugin } from 'unified';
import { visit } from 'unist-util-visit';
import type { Break, Content, Paragraph, PhrasingContent, Root, Text } from 'mdast';
import { LINE_BREAK, NBSP, PHRASE_PARENTS, TAB_AS_SPACES } from '$lib/constants/literal-html';
/**
* remark plugin that rewrites raw HTML nodes into plain-text equivalents.
*
* remark parses inline HTML into `html` nodes even when we do not want to render
 * them. We turn each of those nodes into regular text (plus `break` nodes that
 * render as `<br>`) so the downstream rehype pipeline escapes the characters
 * instead of interpreting them as markup. Leading spaces and tabs are converted
 * to non-breaking spaces to keep indentation identical to the author's original input.
*/
function preserveIndent(line: string): string {
let index = 0;
let output = '';
while (index < line.length) {
const char = line[index];
if (char === ' ') {
output += NBSP;
index += 1;
continue;
}
if (char === '\t') {
output += TAB_AS_SPACES;
index += 1;
continue;
}
break;
}
return output + line.slice(index);
}
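For intuition, the function only rewrites the leading whitespace run; interior whitespace is left untouched so normal word wrapping still applies:

// Illustrative behavior (using the constants imported above):
//   preserveIndent('\t  foo bar') === TAB_AS_SPACES + NBSP + NBSP + 'foo bar'
//   preserveIndent('foo  bar')    === 'foo  bar'  // interior whitespace untouched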
function createLiteralChildren(value: string): PhrasingContent[] {
const lines = value.split(LINE_BREAK);
const nodes: PhrasingContent[] = [];
for (const [lineIndex, rawLine] of lines.entries()) {
if (lineIndex > 0) {
nodes.push({ type: 'break' } as Break as unknown as PhrasingContent);
}
nodes.push({
type: 'text',
value: preserveIndent(rawLine)
} as Text as unknown as PhrasingContent);
}
if (!nodes.length) {
nodes.push({ type: 'text', value: '' } as Text as unknown as PhrasingContent);
}
return nodes;
}
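The resulting node shape, illustratively:

// createLiteralChildren('<div>\n  x') produces
//   [ { type: 'text', value: '<div>' },
//     { type: 'break' },
//     { type: 'text', value: NBSP + NBSP + 'x' } ]
// i.e. one text node per source line, joined by hard break nodes.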
export const remarkLiteralHtml: Plugin<[], Root> = () => {
return (tree) => {
visit(tree, 'html', (node, index, parent) => {
if (!parent || typeof index !== 'number') {
return;
}
const replacement = createLiteralChildren(node.value);
if (!PHRASE_PARENTS.has(parent.type as string)) {
const paragraph: Paragraph = {
type: 'paragraph',
children: replacement as Paragraph['children'],
data: { literalHtml: true }
};
const siblings = parent.children as unknown as Content[];
siblings.splice(index, 1, paragraph as unknown as Content);
if (index > 0) {
const previous = siblings[index - 1] as Paragraph | undefined;
if (
previous?.type === 'paragraph' &&
(previous.data as { literalHtml?: boolean } | undefined)?.literalHtml
) {
const prevChildren = previous.children as unknown as PhrasingContent[];
if (prevChildren.length) {
const lastChild = prevChildren[prevChildren.length - 1];
if (lastChild.type !== 'break') {
prevChildren.push({
type: 'break'
} as Break as unknown as PhrasingContent);
}
}
prevChildren.push(...(paragraph.children as unknown as PhrasingContent[]));
siblings.splice(index, 1);
return index;
}
}
return index + 1;
}
(parent.children as unknown as PhrasingContent[]).splice(
index,
1,
...(replacement as unknown as PhrasingContent[])
);
return index + replacement.length;
});
};
};
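To make the behavior concrete, a minimal before/after (assuming a pipeline like the one in MarkdownContent above; the exact escape sequences depend on rehype-stringify defaults):

// Input (would normally become a live <img> element):
//   <img src=x onerror=alert(1)>
// Output with remarkLiteralHtml in the pipeline (roughly):
//   <p>&#x3C;img src=x onerror=alert(1)></p>
// The tag is escaped and shown as literal text instead of being parsed as markup.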

View File

@ -13,7 +13,7 @@ import { slotsService } from './slots';
* - Manages streaming and non-streaming response parsing * - Manages streaming and non-streaming response parsing
* - Provides request abortion capabilities * - Provides request abortion capabilities
* - Converts database messages to API format * - Converts database messages to API format
* - Handles error translation and context detection * - Handles error translation for server responses
* *
* - **ChatStore**: Stateful orchestration and UI state management * - **ChatStore**: Stateful orchestration and UI state management
* - Uses ChatService for all AI model communication * - Uses ChatService for all AI model communication
@ -26,7 +26,6 @@ import { slotsService } from './slots';
* - Streaming response handling with real-time callbacks * - Streaming response handling with real-time callbacks
* - Reasoning content extraction and processing * - Reasoning content extraction and processing
* - File attachment processing (images, PDFs, audio, text) * - File attachment processing (images, PDFs, audio, text)
* - Context error detection and reporting
* - Request lifecycle management (abort, cleanup) * - Request lifecycle management (abort, cleanup)
*/ */
export class ChatService { export class ChatService {
@ -209,10 +208,13 @@ export class ChatService {
userFriendlyError = new Error( userFriendlyError = new Error(
'Unable to connect to server - please check if the server is running' 'Unable to connect to server - please check if the server is running'
); );
userFriendlyError.name = 'NetworkError';
} else if (error.message.includes('ECONNREFUSED')) { } else if (error.message.includes('ECONNREFUSED')) {
userFriendlyError = new Error('Connection refused - server may be offline'); userFriendlyError = new Error('Connection refused - server may be offline');
userFriendlyError.name = 'NetworkError';
} else if (error.message.includes('ETIMEDOUT')) { } else if (error.message.includes('ETIMEDOUT')) {
userFriendlyError = new Error('Request timeout - server may be overloaded'); userFriendlyError = new Error('Request timed out - the server took too long to respond');
userFriendlyError.name = 'TimeoutError';
} else { } else {
userFriendlyError = error; userFriendlyError = error;
} }
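The `name` values assigned here form a small contract with the UI layer: the chat store below keys on `TimeoutError` to pick the timeout dialog and treats everything else as a server error. Reduced to its essence (function name is illustrative):

// Hypothetical consumer of the error-name contract set above.
function dialogTypeFor(error: Error): 'timeout' | 'server' {
    return error.name === 'TimeoutError' ? 'timeout' : 'server';
}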
@ -262,6 +264,7 @@ export class ChatService {
let fullReasoningContent = ''; let fullReasoningContent = '';
let hasReceivedData = false; let hasReceivedData = false;
let lastTimings: ChatMessageTimings | undefined; let lastTimings: ChatMessageTimings | undefined;
let streamFinished = false;
try { try {
let chunk = ''; let chunk = '';
@ -277,18 +280,8 @@ export class ChatService {
if (line.startsWith('data: ')) { if (line.startsWith('data: ')) {
const data = line.slice(6); const data = line.slice(6);
if (data === '[DONE]') { if (data === '[DONE]') {
if (!hasReceivedData && aggregatedContent.length === 0) { streamFinished = true;
const contextError = new Error( continue;
'The request exceeds the available context size. Try increasing the context size or enable context shift.'
);
contextError.name = 'ContextError';
onError?.(contextError);
return;
}
onComplete?.(aggregatedContent, fullReasoningContent || undefined, lastTimings);
return;
} }
try { try {
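For context, the streaming path speaks the OpenAI-style server-sent-events convention: each payload line is prefixed with `data: ` and the stream ends with a `[DONE]` sentinel, which after this change is only recorded in `streamFinished` so completion handling happens once, after the read loop. A standalone sketch of that parsing shape (names are illustrative, not the service's internals):

// Minimal SSE line handler sketch; onDelta/onDone are hypothetical callbacks.
function handleSseLine(
    line: string,
    onDelta: (payload: unknown) => void,
    onDone: () => void
): void {
    if (!line.startsWith('data: ')) return;
    const data = line.slice(6); // strip the "data: " prefix
    if (data === '[DONE]') {
        onDone(); // sentinel only marks completion; flushing happens after the loop
        return;
    }
    try {
        onDelta(JSON.parse(data));
    } catch {
        // partial JSON usually means the event straddles two network chunks
    }
}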
@ -326,13 +319,13 @@ export class ChatService {
} }
} }
if (streamFinished) {
if (!hasReceivedData && aggregatedContent.length === 0) { if (!hasReceivedData && aggregatedContent.length === 0) {
const contextError = new Error( const noResponseError = new Error('No response received from server. Please try again.');
'The request exceeds the available context size. Try increasing the context size or enable context shift.' throw noResponseError;
); }
contextError.name = 'ContextError';
onError?.(contextError); onComplete?.(aggregatedContent, fullReasoningContent || undefined, lastTimings);
return;
} }
} catch (error) { } catch (error) {
const err = error instanceof Error ? error : new Error('Stream error'); const err = error instanceof Error ? error : new Error('Stream error');
@ -368,12 +361,8 @@ export class ChatService {
const responseText = await response.text(); const responseText = await response.text();
if (!responseText.trim()) { if (!responseText.trim()) {
const contextError = new Error( const noResponseError = new Error('No response received from server. Please try again.');
'The request exceeds the available context size. Try increasing the context size or enable context shift.' throw noResponseError;
);
contextError.name = 'ContextError';
onError?.(contextError);
throw contextError;
} }
const data: ApiChatCompletionResponse = JSON.parse(responseText); const data: ApiChatCompletionResponse = JSON.parse(responseText);
@ -385,22 +374,14 @@ export class ChatService {
} }
if (!content.trim()) { if (!content.trim()) {
const contextError = new Error( const noResponseError = new Error('No response received from server. Please try again.');
'The request exceeds the available context size. Try increasing the context size or enable context shift.' throw noResponseError;
);
contextError.name = 'ContextError';
onError?.(contextError);
throw contextError;
} }
onComplete?.(content, reasoningContent); onComplete?.(content, reasoningContent);
return content; return content;
} catch (error) { } catch (error) {
if (error instanceof Error && error.name === 'ContextError') {
throw error;
}
const err = error instanceof Error ? error : new Error('Parse error'); const err = error instanceof Error ? error : new Error('Parse error');
onError?.(err); onError?.(err);
@ -594,37 +575,19 @@ export class ChatService {
const errorText = await response.text(); const errorText = await response.text();
const errorData: ApiErrorResponse = JSON.parse(errorText); const errorData: ApiErrorResponse = JSON.parse(errorText);
if (errorData.error?.type === 'exceed_context_size_error') {
const contextError = errorData.error as ApiContextSizeError;
const error = new Error(contextError.message);
error.name = 'ContextError';
// Attach structured context information
(
error as Error & {
contextInfo?: { promptTokens: number; maxContext: number; estimatedTokens: number };
}
).contextInfo = {
promptTokens: contextError.n_prompt_tokens,
maxContext: contextError.n_ctx,
estimatedTokens: contextError.n_prompt_tokens
};
return error;
}
// Fallback for other error types
const message = errorData.error?.message || 'Unknown server error'; const message = errorData.error?.message || 'Unknown server error';
return new Error(message); const error = new Error(message);
error.name = response.status === 400 ? 'ServerError' : 'HttpError';
return error;
} catch { } catch {
// If we can't parse the error response, return a generic error // If we can't parse the error response, return a generic error
return new Error(`Server error (${response.status}): ${response.statusText}`); const fallback = new Error(`Server error (${response.status}): ${response.statusText}`);
fallback.name = 'HttpError';
return fallback;
} }
} }
/**
* Updates the processing state with timing information from the server response
* @param timings - Timing data from the API response
* @param promptProgress - Progress data from the API response
*/
private updateProcessingState( private updateProcessingState(
timings?: ChatMessageTimings, timings?: ChatMessageTimings,
promptProgress?: ChatMessagePromptProgress promptProgress?: ChatMessagePromptProgress

View File

@ -1,102 +0,0 @@
import { slotsService } from './slots';
export interface ContextCheckResult {
wouldExceed: boolean;
currentUsage: number;
maxContext: number;
availableTokens: number;
reservedTokens: number;
}
/**
* ContextService - Context window management and limit checking
*
* This service provides context window monitoring and limit checking using real-time
* server data from the slots service. It helps prevent context overflow by tracking
* current usage and calculating available space for new content.
*
* **Architecture & Relationships:**
* - **ContextService** (this class): Context limit monitoring
* - Uses SlotsService for real-time context usage data
* - Calculates available tokens with configurable reserves
* - Provides context limit checking and error messaging
* - Helps prevent context window overflow
*
* - **SlotsService**: Provides current context usage from server slots
* - **ChatStore**: Uses context checking before sending messages
* - **UI Components**: Display context usage warnings and limits
*
* **Key Features:**
* - **Real-time Context Checking**: Uses live server data for accuracy
* - **Token Reservation**: Reserves tokens for response generation
* - **Limit Detection**: Prevents context window overflow
* - **Usage Reporting**: Detailed context usage statistics
* - **Error Messaging**: User-friendly context limit messages
* - **Configurable Reserves**: Adjustable token reservation for responses
*
* **Context Management:**
* - Monitors current context usage from active slots
* - Calculates available space considering reserved tokens
* - Provides early warning before context limits are reached
* - Helps optimize conversation length and content
*/
export class ContextService {
private reserveTokens: number;
constructor(reserveTokens = 512) {
this.reserveTokens = reserveTokens;
}
/**
* Checks if the context limit would be exceeded
*
* @returns {Promise<ContextCheckResult | null>} Promise that resolves to the context check result or null if an error occurs
*/
async checkContextLimit(): Promise<ContextCheckResult | null> {
try {
const currentState = await slotsService.getCurrentState();
if (!currentState) {
return null;
}
const maxContext = currentState.contextTotal;
const currentUsage = currentState.contextUsed;
const availableTokens = maxContext - currentUsage - this.reserveTokens;
const wouldExceed = availableTokens <= 0;
return {
wouldExceed,
currentUsage,
maxContext,
availableTokens: Math.max(0, availableTokens),
reservedTokens: this.reserveTokens
};
} catch (error) {
console.warn('Error checking context limit:', error);
return null;
}
}
/**
* Returns a formatted error message for context limit exceeded
*
* @param {ContextCheckResult} result - Context check result
* @returns {string} Formatted error message
*/
getContextErrorMessage(result: ContextCheckResult): string {
const usagePercent = Math.round((result.currentUsage / result.maxContext) * 100);
return `Context window is nearly full. Current usage: ${result.currentUsage.toLocaleString()}/${result.maxContext.toLocaleString()} tokens (${usagePercent}%). Available space: ${result.availableTokens.toLocaleString()} tokens (${result.reservedTokens} reserved for response).`;
}
/**
* Sets the number of tokens to reserve for response generation
*
* @param {number} tokens - Number of tokens to reserve
*/
setReserveTokens(tokens: number): void {
this.reserveTokens = tokens;
}
}
export const contextService = new ContextService();

View File

@ -1,3 +1,2 @@
export { chatService } from './chat'; export { chatService } from './chat';
export { contextService } from './context';
export { slotsService } from './slots'; export { slotsService } from './slots';

View File

@ -39,7 +39,6 @@ import type { ExportedConversations } from '$lib/types/database';
* - Conversation branching for exploring different response paths * - Conversation branching for exploring different response paths
* - Streaming AI responses with real-time content updates * - Streaming AI responses with real-time content updates
* - File attachment support (images, PDFs, text files, audio) * - File attachment support (images, PDFs, text files, audio)
* - Context window management with error recovery
* - Partial response saving when generation is interrupted * - Partial response saving when generation is interrupted
* - Message editing with automatic response regeneration * - Message editing with automatic response regeneration
*/ */
@ -48,11 +47,9 @@ class ChatStore {
activeMessages = $state<DatabaseMessage[]>([]); activeMessages = $state<DatabaseMessage[]>([]);
conversations = $state<DatabaseConversation[]>([]); conversations = $state<DatabaseConversation[]>([]);
currentResponse = $state(''); currentResponse = $state('');
errorDialogState = $state<{ type: 'timeout' | 'server'; message: string } | null>(null);
isInitialized = $state(false); isInitialized = $state(false);
isLoading = $state(false); isLoading = $state(false);
maxContextError = $state<{ message: string; estimatedTokens: number; maxContext: number } | null>(
null
);
titleUpdateConfirmationCallback?: (currentTitle: string, newTitle: string) => Promise<boolean>; titleUpdateConfirmationCallback?: (currentTitle: string, newTitle: string) => Promise<boolean>;
constructor() { constructor() {
@ -69,8 +66,6 @@ class ChatStore {
try { try {
await this.loadConversations(); await this.loadConversations();
this.maxContextError = null;
this.isInitialized = true; this.isInitialized = true;
} catch (error) { } catch (error) {
console.error('Failed to initialize chat store:', error); console.error('Failed to initialize chat store:', error);
@ -99,8 +94,6 @@ class ChatStore {
this.activeConversation = conversation; this.activeConversation = conversation;
this.activeMessages = []; this.activeMessages = [];
this.maxContextError = null;
await goto(`#/chat/${conversation.id}`); await goto(`#/chat/${conversation.id}`);
return conversation.id; return conversation.id;
@ -133,8 +126,6 @@ class ChatStore {
this.activeMessages = await DatabaseStore.getConversationMessages(convId); this.activeMessages = await DatabaseStore.getConversationMessages(convId);
} }
this.maxContextError = null;
return true; return true;
} catch (error) { } catch (error) {
console.error('Failed to load conversation:', error); console.error('Failed to load conversation:', error);
@ -418,56 +409,6 @@ class ChatStore {
return; return;
} }
if (error.name === 'ContextError') {
console.warn('Context error detected:', error.message);
this.isLoading = false;
this.currentResponse = '';
const messageIndex = this.activeMessages.findIndex(
(m: DatabaseMessage) => m.id === assistantMessage.id
);
if (messageIndex !== -1) {
this.activeMessages.splice(messageIndex, 1);
DatabaseStore.deleteMessage(assistantMessage.id).catch(console.error);
}
// Use structured context info from new exceed_context_size_error format if available
const contextInfo = (
error as Error & {
contextInfo?: { promptTokens: number; maxContext: number; estimatedTokens: number };
}
).contextInfo;
let estimatedTokens = 0;
let maxContext = serverStore.serverProps?.default_generation_settings.n_ctx || 8192;
if (contextInfo) {
// Use precise token counts from server response
estimatedTokens = contextInfo.promptTokens;
maxContext = contextInfo.maxContext;
} else {
// Fallback to estimation for older error format
try {
// Rough estimation: ~4 characters per token
const messageContent = JSON.stringify(messages);
estimatedTokens = Math.ceil(messageContent.length / 4);
} catch {
estimatedTokens = 0;
}
}
this.maxContextError = {
message: error.message,
estimatedTokens,
maxContext
};
if (onError) {
onError(error);
}
return;
}
console.error('Streaming error:', error); console.error('Streaming error:', error);
this.isLoading = false; this.isLoading = false;
this.currentResponse = ''; this.currentResponse = '';
@ -477,8 +418,18 @@ class ChatStore {
); );
if (messageIndex !== -1) { if (messageIndex !== -1) {
this.activeMessages[messageIndex].content = `Error: ${error.message}`; const [failedMessage] = this.activeMessages.splice(messageIndex, 1);
if (failedMessage) {
DatabaseStore.deleteMessage(failedMessage.id).catch((cleanupError) => {
console.error('Failed to remove assistant message after error:', cleanupError);
});
} }
}
const dialogType = error.name === 'TimeoutError' ? 'timeout' : 'server';
this.showErrorDialog(dialogType, error.message);
if (onError) { if (onError) {
onError(error); onError(error);
@ -487,6 +438,14 @@ class ChatStore {
}); });
} }
private showErrorDialog(type: 'timeout' | 'server', message: string): void {
this.errorDialogState = { type, message };
}
dismissErrorDialog(): void {
this.errorDialogState = null;
}
/** /**
* Checks if an error is an abort error (user cancelled operation) * Checks if an error is an abort error (user cancelled operation)
* @param error - The error to check * @param error - The error to check
@ -574,6 +533,7 @@ class ChatStore {
return; return;
} }
this.errorDialogState = null;
this.isLoading = true; this.isLoading = true;
this.currentResponse = ''; this.currentResponse = '';
@ -603,37 +563,23 @@ class ChatStore {
const conversationContext = this.activeMessages.slice(0, -1); const conversationContext = this.activeMessages.slice(0, -1);
await this.streamChatCompletion( await this.streamChatCompletion(conversationContext, assistantMessage);
conversationContext,
assistantMessage,
undefined,
(error: Error) => {
if (error.name === 'ContextError' && userMessage) {
const userMessageIndex = this.findMessageIndex(userMessage.id);
if (userMessageIndex !== -1) {
this.activeMessages.splice(userMessageIndex, 1);
DatabaseStore.deleteMessage(userMessage.id).catch(console.error);
}
}
}
);
} catch (error) { } catch (error) {
if (this.isAbortError(error)) { if (this.isAbortError(error)) {
this.isLoading = false; this.isLoading = false;
return; return;
} }
if (error instanceof Error && error.name === 'ContextError' && userMessage) {
const userMessageIndex = this.findMessageIndex(userMessage.id);
if (userMessageIndex !== -1) {
this.activeMessages.splice(userMessageIndex, 1);
DatabaseStore.deleteMessage(userMessage.id).catch(console.error);
}
}
console.error('Failed to send message:', error); console.error('Failed to send message:', error);
this.isLoading = false; this.isLoading = false;
if (!this.errorDialogState) {
if (error instanceof Error) {
const dialogType = error.name === 'TimeoutError' ? 'timeout' : 'server';
this.showErrorDialog(dialogType, error.message);
} else {
this.showErrorDialog('server', 'Unknown error occurred while sending message');
}
}
} }
} }
@ -662,24 +608,6 @@ class ChatStore {
this.currentResponse = ''; this.currentResponse = '';
} }
/**
* Clears the max context error state
* Removes any displayed context limit warnings
*/
clearMaxContextError(): void {
this.maxContextError = null;
}
/**
* Sets the max context error state
* @param error - The context error details or null to clear
*/
setMaxContextError(
error: { message: string; estimatedTokens: number; maxContext: number } | null
): void {
this.maxContextError = error;
}
/** /**
* Saves partial response if generation was interrupted * Saves partial response if generation was interrupted
* Preserves user's partial content and timing data when generation is stopped early * Preserves user's partial content and timing data when generation is stopped early
@ -1250,7 +1178,6 @@ class ChatStore {
this.activeMessages = []; this.activeMessages = [];
this.currentResponse = ''; this.currentResponse = '';
this.isLoading = false; this.isLoading = false;
this.maxContextError = null;
} }
/** Refreshes active messages based on currNode after branch navigation */ /** Refreshes active messages based on currNode after branch navigation */
@ -1538,6 +1465,7 @@ class ChatStore {
private async generateResponseForMessage(userMessageId: string): Promise<void> { private async generateResponseForMessage(userMessageId: string): Promise<void> {
if (!this.activeConversation) return; if (!this.activeConversation) return;
this.errorDialogState = null;
this.isLoading = true; this.isLoading = true;
this.currentResponse = ''; this.currentResponse = '';
@ -1584,7 +1512,7 @@ export const activeMessages = () => chatStore.activeMessages;
export const isLoading = () => chatStore.isLoading; export const isLoading = () => chatStore.isLoading;
export const currentResponse = () => chatStore.currentResponse; export const currentResponse = () => chatStore.currentResponse;
export const isInitialized = () => chatStore.isInitialized; export const isInitialized = () => chatStore.isInitialized;
export const maxContextError = () => chatStore.maxContextError; export const errorDialog = () => chatStore.errorDialogState;
export const createConversation = chatStore.createConversation.bind(chatStore); export const createConversation = chatStore.createConversation.bind(chatStore);
export const downloadConversation = chatStore.downloadConversation.bind(chatStore); export const downloadConversation = chatStore.downloadConversation.bind(chatStore);
@ -1592,9 +1520,9 @@ export const exportAllConversations = chatStore.exportAllConversations.bind(chat
export const importConversations = chatStore.importConversations.bind(chatStore); export const importConversations = chatStore.importConversations.bind(chatStore);
export const deleteConversation = chatStore.deleteConversation.bind(chatStore); export const deleteConversation = chatStore.deleteConversation.bind(chatStore);
export const sendMessage = chatStore.sendMessage.bind(chatStore); export const sendMessage = chatStore.sendMessage.bind(chatStore);
export const dismissErrorDialog = chatStore.dismissErrorDialog.bind(chatStore);
export const gracefulStop = chatStore.gracefulStop.bind(chatStore); export const gracefulStop = chatStore.gracefulStop.bind(chatStore);
export const clearMaxContextError = chatStore.clearMaxContextError.bind(chatStore);
export const setMaxContextError = chatStore.setMaxContextError.bind(chatStore);
// Branching operations // Branching operations
export const refreshActiveMessages = chatStore.refreshActiveMessages.bind(chatStore); export const refreshActiveMessages = chatStore.refreshActiveMessages.bind(chatStore);
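With maxContextError gone, failures now surface through the single errorDialogState. A hypothetical consumer sketch (the import path, component props, and markup are illustrative, not the app's actual wiring):

// Hypothetical Svelte 5 consumer of the new error dialog state.
import { errorDialog, dismissErrorDialog } from '$lib/stores/chat.svelte';

const dialog = $derived(errorDialog());
// In markup (sketch):
//   {#if dialog}
//     <ChatErrorDialog type={dialog.type} message={dialog.message} onClose={dismissErrorDialog} />
//   {/if}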

View File

@ -197,7 +197,7 @@ class ServerStore {
errorMessage = 'Server not found - check server address'; errorMessage = 'Server not found - check server address';
isOfflineLikeError = true; isOfflineLikeError = true;
} else if (error.message.includes('ETIMEDOUT')) { } else if (error.message.includes('ETIMEDOUT')) {
errorMessage = 'Connection timeout - server may be overloaded'; errorMessage = 'Request timed out - the server took too long to respond';
isOfflineLikeError = true; isOfflineLikeError = true;
} else if (error.message.includes('503')) { } else if (error.message.includes('503')) {
errorMessage = 'Server temporarily unavailable - try again shortly'; errorMessage = 'Server temporarily unavailable - try again shortly';

View File

@ -1,11 +1,7 @@
<script lang="ts"> <script lang="ts">
import '../app.css'; import '../app.css';
import { page } from '$app/state'; import { page } from '$app/state';
import { import { ChatSidebar, ConversationTitleUpdateDialog } from '$lib/components/app';
ChatSidebar,
ConversationTitleUpdateDialog,
MaximumContextAlertDialog
} from '$lib/components/app';
import { import {
activeMessages, activeMessages,
isLoading, isLoading,
@ -145,8 +141,6 @@
<Toaster richColors /> <Toaster richColors />
<MaximumContextAlertDialog />
<ConversationTitleUpdateDialog <ConversationTitleUpdateDialog
bind:open={titleUpdateDialogOpen} bind:open={titleUpdateDialogOpen}
currentTitle={titleUpdateCurrentTitle} currentTitle={titleUpdateCurrentTitle}