Merge remote-tracking branch 'upstream' into cuda_graph_plan
This commit is contained in:
commit
c17f8b5bde
|
|
@ -146,9 +146,7 @@ void ggml_cann_op_unary_gated(
|
|||
unary_op(ctx, acl_src0, acl_dst);
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMul, acl_dst, acl_src1);
|
||||
|
||||
ggml_cann_release_resources(ctx, acl_src0, acl_dst);
|
||||
if(src1)
|
||||
ggml_cann_release_resources(ctx, acl_src1);
|
||||
ggml_cann_release_resources(ctx, acl_src0, acl_src1, acl_dst);
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -894,14 +892,13 @@ static void aclnn_fill_scalar(ggml_backend_cann_context& ctx, float scalar,
|
|||
}
|
||||
|
||||
/**
|
||||
* @brief Get or expand a cached float32 tensor filled with a scalar value.
|
||||
* @brief Get or expand a cached tensor filled with a scalar value.
|
||||
*
|
||||
* This function manages cached device memory for float32 tensors. If the current
|
||||
* This function manages cached device memory for tensors. If the current
|
||||
* cache size is insufficient for the requested tensor shape, the old memory will
|
||||
* be released and new memory will be allocated. The allocated buffer is then
|
||||
* initialized either with zeros (when @p value == 0.0f) or with the given scalar
|
||||
* value using CANN operations. Finally, an aclTensor object is created from the
|
||||
* cached memory and returned.
|
||||
* be released and new memory will be allocated. The allocated buffer is
|
||||
* initialized with the given scalar value using CANN operations.
|
||||
* Finally, an aclTensor object is created from the cached memory and returned.
|
||||
*
|
||||
* @param ctx The CANN backend context that manages device memory.
|
||||
* @param buffer A pointer to the cached device buffer (will be allocated
|
||||
|
|
@ -910,17 +907,19 @@ static void aclnn_fill_scalar(ggml_backend_cann_context& ctx, float scalar,
|
|||
* updated when the cache is expanded.
|
||||
* @param ne The tensor shape array (number of elements in each dimension).
|
||||
* @param nb The stride size for each dimension.
|
||||
* @param dtype Data type of cached tensor.
|
||||
* @param dims The number of tensor dimensions.
|
||||
* @param value The scalar value used to fill the tensor (supports zero
|
||||
* initialization via memset or arbitrary values via fill_scalar).
|
||||
* @return An aclTensor pointer created from the cached buffer.
|
||||
*/
|
||||
static aclTensor* get_f32_cache_acl_tensor(
|
||||
static aclTensor* get_cache_acl_tensor(
|
||||
ggml_backend_cann_context& ctx,
|
||||
void** buffer,
|
||||
int64_t &cache_element,
|
||||
int64_t* ne,
|
||||
size_t* nb,
|
||||
ggml_type dtype,
|
||||
int64_t dims,
|
||||
float value) {
|
||||
// Calculate total number of elements
|
||||
|
|
@ -928,7 +927,7 @@ static aclTensor* get_f32_cache_acl_tensor(
|
|||
for (int i = 0; i < dims; i++) {
|
||||
n_element *= ne[i];
|
||||
}
|
||||
size_t size = n_element * sizeof(float);
|
||||
size_t size = n_element * ggml_type_size(dtype);
|
||||
|
||||
// Allocate or expand cache if needed
|
||||
if (cache_element < n_element) {
|
||||
|
|
@ -941,19 +940,17 @@ static aclTensor* get_f32_cache_acl_tensor(
|
|||
cache_element = n_element;
|
||||
|
||||
// Initialize cache
|
||||
if (value == 0.0f) {
|
||||
ACL_CHECK(aclrtMemsetAsync(*buffer, size, 0, size, ctx.stream()));
|
||||
} else {
|
||||
int64_t pool_ne[1] = { n_element };
|
||||
size_t pool_nb[1] = { sizeof(float) };
|
||||
aclTensor* acl_value = ggml_cann_create_tensor(
|
||||
*buffer, ACL_FLOAT, sizeof(float), pool_ne, pool_nb, 1);
|
||||
aclnn_fill_scalar(ctx, 1, acl_value);
|
||||
ggml_cann_release_resources(ctx, acl_value);
|
||||
}
|
||||
int64_t pool_ne[1] = { n_element };
|
||||
size_t pool_nb[1] = { ggml_type_size(dtype) };
|
||||
aclTensor* acl_value = ggml_cann_create_tensor(
|
||||
*buffer, ggml_cann_type_mapping(dtype), ggml_type_size(dtype),
|
||||
pool_ne, pool_nb, 1);
|
||||
aclnn_fill_scalar(ctx, value, acl_value);
|
||||
ggml_cann_release_resources(ctx, acl_value);
|
||||
}
|
||||
|
||||
return ggml_cann_create_tensor(*buffer, ACL_FLOAT, sizeof(float), ne, nb, dims);
|
||||
return ggml_cann_create_tensor(*buffer, ggml_cann_type_mapping(dtype),
|
||||
ggml_type_size(dtype), ne, nb, dims);
|
||||
}
|
||||
|
||||
void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
|
||||
|
|
@ -965,35 +962,39 @@ void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
|
|||
float eps;
|
||||
memcpy(&eps, dst->op_params, sizeof(float));
|
||||
|
||||
// build gamma, one...
|
||||
// build gamma.
|
||||
size_t acl_gamma_nb[GGML_MAX_DIMS];
|
||||
acl_gamma_nb[0] = sizeof(float);
|
||||
// gamma's type is the same with dst.
|
||||
acl_gamma_nb[0] = ggml_type_size(dst->type);
|
||||
for (int i = 1; i < GGML_MAX_DIMS; i++) {
|
||||
acl_gamma_nb[i] = acl_gamma_nb[i - 1] * src->ne[i - 1];
|
||||
}
|
||||
aclTensor* acl_gamma = get_f32_cache_acl_tensor(
|
||||
aclTensor* acl_gamma = get_cache_acl_tensor(
|
||||
ctx,
|
||||
&ctx.rms_norm_one_tensor_cache.cache,
|
||||
ctx.rms_norm_one_tensor_cache.size,
|
||||
src->ne,
|
||||
acl_gamma_nb,
|
||||
dst->type,
|
||||
1, // dims
|
||||
1.0f // value
|
||||
);
|
||||
|
||||
// build rstd, zero...
|
||||
// build rstd.
|
||||
int64_t acl_rstd_ne[] = {src->ne[1], src->ne[2], src->ne[3]};
|
||||
size_t acl_rstd_nb[GGML_MAX_DIMS - 1];
|
||||
// rstd will always be F32.
|
||||
acl_rstd_nb[0] = sizeof(float);
|
||||
for (int i = 1; i < GGML_MAX_DIMS - 1; i++) {
|
||||
acl_rstd_nb[i] = acl_rstd_nb[i - 1] * acl_rstd_ne[i - 1];
|
||||
}
|
||||
aclTensor* acl_rstd = get_f32_cache_acl_tensor(
|
||||
aclTensor* acl_rstd = get_cache_acl_tensor(
|
||||
ctx,
|
||||
&ctx.rms_norm_zero_tensor_cache.cache,
|
||||
ctx.rms_norm_zero_tensor_cache.size,
|
||||
acl_rstd_ne,
|
||||
acl_rstd_nb,
|
||||
GGML_TYPE_F32,
|
||||
GGML_MAX_DIMS - 1,
|
||||
0.0f // value
|
||||
);
|
||||
|
|
@ -1765,33 +1766,35 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
|
|||
ggml_tensor* src0 = dst->src[0]; // src
|
||||
ggml_tensor* src1 = dst->src[1]; // index
|
||||
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32: {
|
||||
aclnn_index_select_4d(ctx, src0->data, src0->ne, src0->nb,
|
||||
dst->data, dst->ne, dst->nb,
|
||||
src1, dst->type);
|
||||
break;
|
||||
}
|
||||
case GGML_TYPE_F16: {
|
||||
aclTensor* acl_src0 = ggml_cann_create_tensor(src0);
|
||||
ggml_cann_pool_alloc src_buffer_allocator(
|
||||
ctx.pool(), ggml_nelements(src0) * sizeof(float));
|
||||
void* src_trans_buffer = src_buffer_allocator.get();
|
||||
size_t src_trans_nb[GGML_MAX_DIMS];
|
||||
src_trans_nb[0] = sizeof(float);
|
||||
for (int i = 1; i < GGML_MAX_DIMS; i++) {
|
||||
src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1];
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_F32:
|
||||
if(src0->type == dst->type) {
|
||||
aclnn_index_select_4d(ctx, src0->data, src0->ne, src0->nb,
|
||||
dst->data, dst->ne, dst->nb,
|
||||
src1, dst->type);
|
||||
} else {
|
||||
aclTensor* acl_src0 = ggml_cann_create_tensor(src0);
|
||||
ggml_cann_pool_alloc src_buffer_allocator(
|
||||
ctx.pool(), ggml_nelements(src0) * ggml_element_size(dst));
|
||||
void* src_trans_buffer = src_buffer_allocator.get();
|
||||
size_t src_trans_nb[GGML_MAX_DIMS];
|
||||
src_trans_nb[0] = dst->nb[0];
|
||||
for (int i = 1; i < GGML_MAX_DIMS; i++) {
|
||||
src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1];
|
||||
}
|
||||
aclTensor* src_trans_tensor = ggml_cann_create_tensor(
|
||||
src_trans_buffer, ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
|
||||
src0->ne, src_trans_nb, GGML_MAX_DIMS);
|
||||
aclnn_cast(ctx, acl_src0, src_trans_tensor, ggml_cann_type_mapping(dst->type));
|
||||
aclnn_index_select_4d(ctx, src_trans_buffer, src0->ne, src_trans_nb,
|
||||
dst->data, dst->ne, dst->nb,
|
||||
src1, dst->type);
|
||||
ggml_cann_release_resources(ctx, acl_src0, src_trans_tensor);
|
||||
}
|
||||
aclTensor* src_trans_tensor = ggml_cann_create_tensor(
|
||||
src_trans_buffer, ACL_FLOAT, ggml_type_size(dst->type),
|
||||
src0->ne, src_trans_nb, GGML_MAX_DIMS);
|
||||
aclnn_cast(ctx, acl_src0, src_trans_tensor, ggml_cann_type_mapping(dst->type));
|
||||
aclnn_index_select_4d(ctx, src_trans_buffer, src0->ne, src_trans_nb,
|
||||
dst->data, dst->ne, dst->nb,
|
||||
src1, dst->type);
|
||||
ggml_cann_release_resources(ctx, acl_src0, src_trans_tensor);
|
||||
break;
|
||||
}
|
||||
case GGML_TYPE_Q8_0: {
|
||||
// add 1 dim for bcast mul.
|
||||
size_t weight_nb[GGML_MAX_DIMS + 1], scale_nb[GGML_MAX_DIMS + 1],
|
||||
|
|
@ -1799,7 +1802,6 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
|
|||
int64_t weight_ne[GGML_MAX_DIMS + 1], scale_ne[GGML_MAX_DIMS + 1],
|
||||
*dequant_ne;
|
||||
int64_t scale_offset = 0;
|
||||
|
||||
// [3,4,5,64] -> [3,4,5,2,32]
|
||||
weight_ne[0] = QK8_0;
|
||||
weight_ne[1] = src0->ne[0] / QK8_0;
|
||||
|
|
@ -1809,7 +1811,6 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
|
|||
weight_ne[i] = src0->ne[i - 1];
|
||||
weight_nb[i] = weight_nb[i - 1] * weight_ne[i - 1];
|
||||
}
|
||||
|
||||
// [3,4,5,64] -> [3,4,5,2,1]
|
||||
scale_ne[0] = 1;
|
||||
scale_ne[1] = src0->ne[0] / QK8_0;
|
||||
|
|
@ -1819,18 +1820,15 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
|
|||
scale_ne[i] = src0->ne[i - 1];
|
||||
scale_nb[i] = scale_nb[i - 1] * scale_ne[i - 1];
|
||||
}
|
||||
|
||||
// [3,4,5,64] -> [3,4,5,2,32]
|
||||
dequant_ne = weight_ne;
|
||||
dequant_nb[0] = sizeof(float);
|
||||
dequant_nb[0] = ggml_type_size(dst->type);
|
||||
for (int i = 1; i < GGML_MAX_DIMS + 1; i++) {
|
||||
dequant_nb[i] = dequant_nb[i - 1] * dequant_ne[i - 1];
|
||||
}
|
||||
|
||||
scale_offset = ggml_nelements(src0) * sizeof(int8_t);
|
||||
ggml_cann_pool_alloc dequant_buffer_allocator(
|
||||
ctx.pool(), ggml_nelements(src0) * sizeof(float));
|
||||
|
||||
ctx.pool(), ggml_nelements(src0) * ggml_type_size(dst->type));
|
||||
aclTensor* acl_weight_tensor = ggml_cann_create_tensor(
|
||||
src0->data, ACL_INT8, sizeof(int8_t), weight_ne, weight_nb,
|
||||
GGML_MAX_DIMS + 1);
|
||||
|
|
@ -1838,22 +1836,20 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
|
|||
src0->data, ACL_FLOAT16, sizeof(uint16_t), scale_ne, scale_nb,
|
||||
GGML_MAX_DIMS + 1, ACL_FORMAT_ND, scale_offset);
|
||||
aclTensor* dequant_tensor = ggml_cann_create_tensor(
|
||||
dequant_buffer_allocator.get(), ACL_FLOAT, sizeof(float),
|
||||
dequant_buffer_allocator.get(), ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
|
||||
dequant_ne, dequant_nb, GGML_MAX_DIMS + 1);
|
||||
|
||||
aclnn_mul(ctx, acl_weight_tensor, acl_scale_tensor, dequant_tensor);
|
||||
dequant_nb[0] = sizeof(float);
|
||||
dequant_nb[0] = ggml_type_size(dst->type);
|
||||
dequant_ne = src0->ne;
|
||||
for (int i = 1; i < GGML_MAX_DIMS; i++) {
|
||||
dequant_nb[i] = dequant_nb[i - 1] * src0->ne[i - 1];
|
||||
}
|
||||
|
||||
aclnn_index_select_4d(ctx, dequant_buffer_allocator.get(),
|
||||
dequant_ne, dequant_nb,
|
||||
dst->data, dst->ne, dst->nb,
|
||||
src1, dst->type);
|
||||
|
||||
ggml_cann_release_resources(ctx, dequant_tensor);
|
||||
ggml_cann_release_resources(ctx, acl_weight_tensor, acl_scale_tensor, dequant_tensor);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
|
|
@ -1965,16 +1961,8 @@ static void ggml_cann_mat_mul_fp(ggml_backend_cann_context& ctx,
|
|||
// Only check env once.
|
||||
static bool weight_to_nz = parse_bool(get_env("GGML_CANN_WEIGHT_NZ").value_or("on"));
|
||||
if (weight_to_nz && is_matmul_weight(weight)) {
|
||||
int64_t acl_stride[2] = {1, transpose_ne[1]};
|
||||
|
||||
// Reverse ne.
|
||||
std::reverse(transpose_ne, transpose_ne + n_dims);
|
||||
|
||||
std::vector<int64_t> storageDims = {transpose_ne[0], transpose_ne[1]};
|
||||
|
||||
acl_weight_tensor = aclCreateTensor(
|
||||
transpose_ne, n_dims, ggml_cann_type_mapping(weight->type), acl_stride,
|
||||
0, ACL_FORMAT_FRACTAL_NZ, storageDims.data(), 2, weight->data);
|
||||
acl_weight_tensor =
|
||||
ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims, ACL_FORMAT_FRACTAL_NZ);
|
||||
} else {
|
||||
acl_weight_tensor =
|
||||
ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims, ACL_FORMAT_ND);
|
||||
|
|
@ -3178,7 +3166,6 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
|
|||
aclTensor* acl_src0_f16_tensor = nullptr;
|
||||
aclTensor* acl_src1_f16_tensor = nullptr;
|
||||
aclTensor* acl_src2_f16_tensor = nullptr;
|
||||
aclTensor* acl_dst_f16_tensor = nullptr;
|
||||
|
||||
// Step 1: cast the src0 (Query) to fp16 if needed
|
||||
ggml_cann_pool_alloc src0_f16_allocator(ctx.pool());
|
||||
|
|
@ -3216,22 +3203,6 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
|
|||
acl_src2_f16_tensor = ggml_cann_create_tensor(src2, src2_bsnd_ne,
|
||||
src2_bsnd_nb, GGML_MAX_DIMS);
|
||||
|
||||
ggml_cann_pool_alloc out_f16_allocator(ctx.pool());
|
||||
void* out_f16_buffer = out_f16_allocator.alloc(
|
||||
ggml_nelements(dst) * faElemSize);
|
||||
|
||||
int64_t* out_f16_ne = src0_bsnd_ne;
|
||||
size_t out_f16_nb[GGML_MAX_DIMS];
|
||||
out_f16_nb[0] = faElemSize;
|
||||
for(int i = 1; i < GGML_MAX_DIMS; ++i){
|
||||
out_f16_nb[i] = out_f16_nb[i - 1] * out_f16_ne[i - 1];
|
||||
}
|
||||
|
||||
acl_dst_f16_tensor = ggml_cann_create_tensor(
|
||||
out_f16_buffer, faDataType, faElemSize,
|
||||
out_f16_ne, out_f16_nb, GGML_MAX_DIMS
|
||||
);
|
||||
|
||||
// Step 3: create the PSEShift tensor if needed
|
||||
// this tensor is considered as mask (f16) in the llama.cpp
|
||||
aclTensor* bcast_pse_tensor = nullptr;
|
||||
|
|
@ -3317,8 +3288,8 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
|
|||
aclTensor* acl_q_tensor = acl_src0_f16_tensor;
|
||||
aclTensor* acl_k_tensors[] = {acl_src1_f16_tensor};
|
||||
aclTensor* acl_v_tensors[] = {acl_src2_f16_tensor};
|
||||
auto acl_k_tensor_list = aclCreateTensorList(acl_k_tensors, kvTensorNum);
|
||||
auto acl_v_tensor_list = aclCreateTensorList(acl_v_tensors, kvTensorNum);
|
||||
aclTensorList* acl_k_tensor_list = aclCreateTensorList(acl_k_tensors, kvTensorNum);
|
||||
aclTensorList* acl_v_tensor_list = aclCreateTensorList(acl_v_tensors, kvTensorNum);
|
||||
|
||||
int64_t numHeads = src0->ne[2]; // N
|
||||
int64_t numKeyValueHeads = src1->ne[2];
|
||||
|
|
@ -3334,8 +3305,29 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
|
|||
int64_t keyAntiquantMode = 0;
|
||||
int64_t valueAntiquantMode = 0;
|
||||
|
||||
// Step 5: launch the FusedInferAttentionScoreV2 kernel.
|
||||
// Refer to https://gitee.com/ascend/cann-ops-adv/blob/master/docs/FusedInferAttentionScoreV2.md
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
||||
aclTensor * fa_dst_tensor = nullptr;
|
||||
aclTensor * acl_dst_tensor = nullptr;
|
||||
ggml_cann_pool_alloc out_f16_allocator(ctx.pool());
|
||||
if (dst->type == GGML_TYPE_F32) {
|
||||
void* out_f16_buffer = out_f16_allocator.alloc(
|
||||
ggml_nelements(dst) * faElemSize);
|
||||
|
||||
int64_t* out_f16_ne = src0_bsnd_ne;
|
||||
size_t out_f16_nb[GGML_MAX_DIMS];
|
||||
out_f16_nb[0] = faElemSize;
|
||||
for(int i = 1; i < GGML_MAX_DIMS; ++i){
|
||||
out_f16_nb[i] = out_f16_nb[i - 1] * out_f16_ne[i - 1];
|
||||
}
|
||||
|
||||
fa_dst_tensor = ggml_cann_create_tensor(
|
||||
out_f16_buffer, faDataType, faElemSize,
|
||||
out_f16_ne, out_f16_nb, GGML_MAX_DIMS
|
||||
);
|
||||
}
|
||||
else {
|
||||
fa_dst_tensor = ggml_cann_create_tensor(dst);
|
||||
}
|
||||
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, FusedInferAttentionScoreV2,
|
||||
acl_q_tensor, acl_k_tensor_list, acl_v_tensor_list, // q, k, v
|
||||
|
|
@ -3357,23 +3349,24 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
|
|||
blockSize, antiquantMode, // blockSize, antiquantMode
|
||||
softmaxLseFlag, // softmaxLseFlag
|
||||
keyAntiquantMode, valueAntiquantMode, // keyAntiqMode, valueAntiqMode
|
||||
acl_dst_f16_tensor, // attentionOut
|
||||
fa_dst_tensor, // attentionOut
|
||||
nullptr // softmaxLse
|
||||
);
|
||||
|
||||
// Step 6: post-processing, permute and cast to f32
|
||||
aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst);
|
||||
// TODO: when dst is fp16, don't need cast
|
||||
aclnn_cast(ctx, acl_dst_f16_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type));
|
||||
ggml_cann_release_resources(ctx, acl_src0_f16_tensor,
|
||||
acl_src1_f16_tensor,
|
||||
acl_src2_f16_tensor,
|
||||
acl_dst_f16_tensor,
|
||||
acl_dst_tensor);
|
||||
if(src3 != nullptr){
|
||||
ggml_cann_release_resources(ctx, bcast_pse_tensor);
|
||||
if (dst->type == GGML_TYPE_F32) {
|
||||
// Step 6: post-processing, permute and cast to f32
|
||||
aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst);
|
||||
aclnn_cast(ctx, fa_dst_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type));
|
||||
}
|
||||
}else{
|
||||
|
||||
ggml_cann_release_resources(ctx, acl_src0_f16_tensor,
|
||||
acl_k_tensor_list,
|
||||
acl_v_tensor_list,
|
||||
fa_dst_tensor,
|
||||
acl_dst_tensor,
|
||||
bcast_pse_tensor);
|
||||
|
||||
} else {
|
||||
GGML_ABORT("Function is not implemented.");
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -68,7 +68,7 @@ struct ggml_compute_params {
|
|||
#endif // __VXE2__
|
||||
#endif // __s390x__ && __VEC__
|
||||
|
||||
#if defined(__ARM_FEATURE_SVE)
|
||||
#if defined(__ARM_FEATURE_SVE) && defined(__linux__)
|
||||
#include <sys/prctl.h>
|
||||
#endif
|
||||
|
||||
|
|
|
|||
|
|
@ -689,8 +689,13 @@ bool ggml_is_numa(void) {
|
|||
#endif
|
||||
|
||||
static void ggml_init_arm_arch_features(void) {
|
||||
#if defined(__linux__) && defined(__aarch64__) && defined(__ARM_FEATURE_SVE)
|
||||
#if defined(__aarch64__) && defined(__ARM_FEATURE_SVE)
|
||||
#if defined(__linux__)
|
||||
ggml_arm_arch_features.sve_cnt = PR_SVE_VL_LEN_MASK & prctl(PR_SVE_GET_VL);
|
||||
#else
|
||||
// TODO: add support of SVE for non-linux systems
|
||||
#error "TODO: SVE is not supported on this platform. To use SVE, sve_cnt needs to be initialized here."
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -463,9 +463,9 @@ ggml_float ggml_vec_cvar_f32(const int n, float * y, const float * x, const floa
|
|||
#endif
|
||||
for (; i < n; ++i) {
|
||||
float val = x[i] - mean;
|
||||
y[i] = val;
|
||||
val *= val;
|
||||
sum += (ggml_float)val;
|
||||
y[i] = val;
|
||||
}
|
||||
return sum/n;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -941,13 +941,6 @@ struct ggml_cuda_graph {
|
|||
std::vector<cudaGraphNode_t> nodes;
|
||||
std::vector<cudaKernelNodeParams> params;
|
||||
std::vector<ggml_graph_node_properties> ggml_graph_properties;
|
||||
bool use_cpy_indirection = false;
|
||||
std::vector<char *> cpy_dest_ptrs;
|
||||
char ** dest_ptrs_d;
|
||||
int dest_ptrs_size = 0;
|
||||
// Index to allow each cpy kernel to be aware of it's position within the graph
|
||||
// relative to other cpy nodes.
|
||||
int graph_cpynode_index = -1;
|
||||
#endif
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -8,18 +8,16 @@
|
|||
typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
|
||||
|
||||
template <cpy_kernel_t cpy_1>
|
||||
static __global__ void cpy_flt(const char * cx, char * cdst_direct, const int ne,
|
||||
static __global__ void cpy_flt(const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
|
||||
const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
|
||||
const int nb12, const int nb13) {
|
||||
const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
if (i >= ne) {
|
||||
return;
|
||||
}
|
||||
|
||||
char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index]: cdst_direct;
|
||||
|
||||
// determine indices i03/i13, i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor
|
||||
// then combine those indices with the corresponding byte offsets to get the total offsets
|
||||
const int64_t i03 = i/(ne00 * ne01 * ne02);
|
||||
|
|
@ -63,18 +61,16 @@ static __device__ void cpy_blck_q_f32(const char * cxi, char * cdsti) {
|
|||
}
|
||||
|
||||
template <cpy_kernel_t cpy_blck, int qk>
|
||||
static __global__ void cpy_f32_q(const char * cx, char * cdst_direct, const int ne,
|
||||
static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
|
||||
const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
|
||||
const int nb12, const int nb13) {
|
||||
const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;
|
||||
|
||||
if (i >= ne) {
|
||||
return;
|
||||
}
|
||||
|
||||
char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index]: cdst_direct;
|
||||
|
||||
const int i03 = i/(ne00 * ne01 * ne02);
|
||||
const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
|
||||
const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
|
||||
|
|
@ -91,18 +87,16 @@ static __global__ void cpy_f32_q(const char * cx, char * cdst_direct, const int
|
|||
}
|
||||
|
||||
template <cpy_kernel_t cpy_blck, int qk>
|
||||
static __global__ void cpy_q_f32(const char * cx, char * cdst_direct, const int ne,
|
||||
static __global__ void cpy_q_f32(const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
|
||||
const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
|
||||
const int nb12, const int nb13) {
|
||||
const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;
|
||||
|
||||
if (i >= ne) {
|
||||
return;
|
||||
}
|
||||
|
||||
char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index]: cdst_direct;
|
||||
|
||||
const int i03 = i/(ne00 * ne01 * ne02);
|
||||
const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
|
||||
const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
|
||||
|
|
@ -118,67 +112,47 @@ static __global__ void cpy_q_f32(const char * cx, char * cdst_direct, const int
|
|||
cpy_blck(cx + x_offset, cdst + dst_offset);
|
||||
}
|
||||
|
||||
// Copy destination pointers to GPU to be available when pointer indirection is in use
|
||||
|
||||
void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_dest_ptrs, const int host_dest_ptrs_size, cudaStream_t stream) {
|
||||
#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS)
|
||||
if (cuda_graph->dest_ptrs_size < host_dest_ptrs_size) { // (re-)allocate GPU memory for destination pointers
|
||||
CUDA_CHECK(cudaStreamSynchronize(stream));
|
||||
if (cuda_graph->dest_ptrs_d != nullptr) {
|
||||
CUDA_CHECK(cudaFree(cuda_graph->dest_ptrs_d));
|
||||
}
|
||||
CUDA_CHECK(cudaMalloc(&cuda_graph->dest_ptrs_d, host_dest_ptrs_size*sizeof(char *)));
|
||||
cuda_graph->dest_ptrs_size = host_dest_ptrs_size;
|
||||
}
|
||||
// copy destination pointers to GPU
|
||||
CUDA_CHECK(cudaMemcpyAsync(cuda_graph->dest_ptrs_d, host_dest_ptrs, host_dest_ptrs_size*sizeof(char *), cudaMemcpyHostToDevice, stream));
|
||||
cuda_graph->graph_cpynode_index = 0; // reset index
|
||||
#else
|
||||
GGML_UNUSED_VARS(cuda_graph, host_dest_ptrs, host_dest_ptrs_size, stream);
|
||||
#endif
|
||||
}
|
||||
|
||||
template<typename src_t, typename dst_t>
|
||||
static void ggml_cpy_flt_cuda(
|
||||
const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
|
||||
|
||||
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
|
||||
cpy_flt<cpy_1_flt<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
|
||||
static void ggml_cpy_f32_q8_0_cuda(
|
||||
const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
|
||||
|
||||
GGML_ASSERT(ne % QK8_0 == 0);
|
||||
const int num_blocks = ne / QK8_0;
|
||||
cpy_f32_q<cpy_blck_f32_q8_0, QK8_0><<<num_blocks, 1, 0, stream>>>
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
|
||||
static void ggml_cpy_q8_0_f32_cuda(
|
||||
const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
|
||||
|
||||
const int num_blocks = ne;
|
||||
cpy_q_f32<cpy_blck_q8_0_f32, QK8_0><<<num_blocks, 1, 0, stream>>>
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
|
||||
static void ggml_cpy_f32_q4_0_cuda(
|
||||
const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
|
||||
|
||||
GGML_ASSERT(ne % QK4_0 == 0);
|
||||
const int num_blocks = ne / QK4_0;
|
||||
cpy_f32_q<cpy_blck_f32_q4_0, QK4_0><<<num_blocks, 1, 0, stream>>>
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
|
||||
static void ggml_cpy_q4_0_f32_cuda(
|
||||
|
|
@ -187,22 +161,22 @@ static void ggml_cpy_q4_0_f32_cuda(
|
|||
const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12,
|
||||
const int nb10, const int nb11, const int nb12, const int nb13,
|
||||
cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
|
||||
cudaStream_t stream) {
|
||||
const int num_blocks = ne;
|
||||
cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0><<<num_blocks, 1, 0, stream>>>(
|
||||
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
|
||||
ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
|
||||
ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
|
||||
static void ggml_cpy_f32_q4_1_cuda(
|
||||
const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
|
||||
|
||||
GGML_ASSERT(ne % QK4_1 == 0);
|
||||
const int num_blocks = ne / QK4_1;
|
||||
cpy_f32_q<cpy_blck_f32_q4_1, QK4_1><<<num_blocks, 1, 0, stream>>>
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
|
||||
static void ggml_cpy_q4_1_f32_cuda(
|
||||
|
|
@ -211,22 +185,22 @@ static void ggml_cpy_q4_1_f32_cuda(
|
|||
const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12,
|
||||
const int nb10, const int nb11, const int nb12, const int nb13,
|
||||
cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
|
||||
cudaStream_t stream) {
|
||||
const int num_blocks = ne;
|
||||
cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1><<<num_blocks, 1, 0, stream>>>(
|
||||
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
|
||||
ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
|
||||
ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
|
||||
static void ggml_cpy_f32_q5_0_cuda(
|
||||
const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
|
||||
|
||||
GGML_ASSERT(ne % QK5_0 == 0);
|
||||
const int num_blocks = ne / QK5_0;
|
||||
cpy_f32_q<cpy_blck_f32_q5_0, QK5_0><<<num_blocks, 1, 0, stream>>>
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
|
||||
static void ggml_cpy_q5_0_f32_cuda(
|
||||
|
|
@ -235,22 +209,22 @@ static void ggml_cpy_q5_0_f32_cuda(
|
|||
const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12,
|
||||
const int nb10, const int nb11, const int nb12, const int nb13,
|
||||
cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
|
||||
cudaStream_t stream) {
|
||||
const int num_blocks = ne;
|
||||
cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0><<<num_blocks, 1, 0, stream>>>(
|
||||
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
|
||||
ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
|
||||
ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
|
||||
static void ggml_cpy_f32_q5_1_cuda(
|
||||
const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
|
||||
|
||||
GGML_ASSERT(ne % QK5_1 == 0);
|
||||
const int num_blocks = ne / QK5_1;
|
||||
cpy_f32_q<cpy_blck_f32_q5_1, QK5_1><<<num_blocks, 1, 0, stream>>>
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
|
||||
static void ggml_cpy_q5_1_f32_cuda(
|
||||
|
|
@ -259,25 +233,25 @@ static void ggml_cpy_q5_1_f32_cuda(
|
|||
const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12,
|
||||
const int nb10, const int nb11, const int nb12, const int nb13,
|
||||
cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
|
||||
cudaStream_t stream) {
|
||||
const int num_blocks = ne;
|
||||
cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1><<<num_blocks, 1, 0, stream>>>(
|
||||
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
|
||||
ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
|
||||
ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
|
||||
static void ggml_cpy_f32_iq4_nl_cuda(
|
||||
const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
|
||||
|
||||
GGML_ASSERT(ne % QK4_NL == 0);
|
||||
const int num_blocks = ne / QK4_NL;
|
||||
cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL><<<num_blocks, 1, 0, stream>>>
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
|
||||
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, ggml_cuda_graph * cuda_graph, const ggml_tensor * src0, ggml_tensor * src1, bool disable_indirection_for_this_node) {
|
||||
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1) {
|
||||
const int64_t ne = ggml_nelements(src0);
|
||||
GGML_ASSERT(ne == ggml_nelements(src1));
|
||||
|
||||
|
|
@ -311,16 +285,6 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, ggml_cuda_graph * cuda_graph
|
|||
char * src0_ddc = (char *) src0->data;
|
||||
char * src1_ddc = (char *) src1->data;
|
||||
|
||||
char ** dest_ptrs_d = nullptr;
|
||||
int graph_cpynode_index = -1;
|
||||
#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS)
|
||||
if (cuda_graph && cuda_graph->use_cpy_indirection && !disable_indirection_for_this_node) {
|
||||
dest_ptrs_d = cuda_graph->dest_ptrs_d;
|
||||
graph_cpynode_index = cuda_graph->graph_cpynode_index;
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED(disable_indirection_for_this_node);
|
||||
#endif
|
||||
if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
|
||||
GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
|
||||
#if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY)
|
||||
|
|
@ -329,134 +293,62 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, ggml_cuda_graph * cuda_graph
|
|||
} else
|
||||
#endif // GGML_USE_MUSA && GGML_MUSA_MUDNN_COPY
|
||||
{
|
||||
if (src0->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
} else {
|
||||
CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
|
||||
}
|
||||
CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
|
||||
}
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
|
||||
ggml_cpy_flt_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
ggml_cpy_flt_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
|
||||
ggml_cpy_flt_cuda<float, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
ggml_cpy_flt_cuda<float, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
|
||||
ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
|
||||
ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_q4_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
|
||||
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
|
||||
ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_q4_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
|
||||
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
|
||||
ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_q5_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
|
||||
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
|
||||
ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
|
||||
ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
|
||||
ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
|
||||
ggml_cpy_flt_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
ggml_cpy_flt_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_flt_cuda<half, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
ggml_cpy_flt_cuda<half, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
|
||||
ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
|
||||
ggml_cpy_flt_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
ggml_cpy_flt_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_flt_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
ggml_cpy_flt_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) {
|
||||
ggml_cpy_flt_cuda<float, int32_t> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
ggml_cpy_flt_cuda<float, int32_t> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_flt_cuda<int32_t, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
ggml_cpy_flt_cuda<int32_t, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else {
|
||||
GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
|
||||
ggml_type_name(src0->type), ggml_type_name(src1->type));
|
||||
}
|
||||
#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS)
|
||||
if (cuda_graph && cuda_graph->use_cpy_indirection && !disable_indirection_for_this_node) {
|
||||
cuda_graph->graph_cpynode_index = graph_cpynode_index;
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED(disable_indirection_for_this_node);
|
||||
#endif
|
||||
|
||||
}
|
||||
|
||||
void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_cuda_graph * cuda_graph, ggml_tensor * dst) {
|
||||
void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
bool disable_indirection = true;
|
||||
ggml_cuda_cpy(ctx, cuda_graph, src0, dst, disable_indirection);
|
||||
}
|
||||
|
||||
void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
|
||||
if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
|
||||
// Prioritize CUDA graph compatibility over direct memory copy optimization.
|
||||
// Using copy kernels here maintains graph indirection support, preventing performance regression from disabled CUDA graphs.
|
||||
if (src0->type == GGML_TYPE_F32) {
|
||||
return (void*) cpy_flt<cpy_1_flt<float, float>>;
|
||||
} else {
|
||||
return nullptr;
|
||||
}
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
|
||||
return (void*) cpy_flt<cpy_1_flt<float, float>>;
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
|
||||
return (void*) cpy_flt<cpy_1_flt<float, nv_bfloat16>>;
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
|
||||
return (void*) cpy_flt<cpy_1_flt<float, half>>;
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
|
||||
return (void*) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
|
||||
} else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
|
||||
return (void*) cpy_q_f32<cpy_blck_q8_0_f32, QK8_0>;
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
|
||||
return (void*) cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>;
|
||||
} else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
|
||||
return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0>;
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
|
||||
return (void*) cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>;
|
||||
} else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
|
||||
return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1>;
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
|
||||
return (void*) cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>;
|
||||
} else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
|
||||
return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0>;
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
|
||||
return (void*) cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>;
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
|
||||
return (void*) cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>;
|
||||
} else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
|
||||
return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1>;
|
||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
|
||||
return (void*) cpy_flt<cpy_1_flt<half, half>>;
|
||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
|
||||
return (void*) cpy_flt<cpy_1_flt<half, nv_bfloat16>>;
|
||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
|
||||
return (void*) cpy_flt<cpy_1_flt<half, float>>;
|
||||
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
|
||||
return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, half>>;
|
||||
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
|
||||
return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, nv_bfloat16>>;
|
||||
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
|
||||
return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, float>>;
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) {
|
||||
return (void*) cpy_flt<cpy_1_flt<float, int32_t>>;
|
||||
} else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) {
|
||||
return (void*) cpy_flt<cpy_1_flt<int32_t, float>>;
|
||||
} else {
|
||||
GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
|
||||
ggml_type_name(src0->type), ggml_type_name(src1->type));
|
||||
}
|
||||
ggml_cuda_cpy(ctx, src0, dst);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,10 +2,6 @@
|
|||
|
||||
#define CUDA_CPY_BLOCK_SIZE 64
|
||||
|
||||
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, ggml_cuda_graph * cuda_graph, const ggml_tensor * src0, ggml_tensor * src1, bool disable_indirection = false);
|
||||
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1);
|
||||
|
||||
void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_cuda_graph * cuda_graph, ggml_tensor * dst);
|
||||
|
||||
void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1);
|
||||
|
||||
void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_dest_ptrs, const int host_dest_ptrs_size, cudaStream_t stream);
|
||||
void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
|
|
|||
|
|
@ -540,10 +540,12 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
|
|||
KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] = logit_softcap * tanhf(KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]);
|
||||
}
|
||||
|
||||
KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] += (ncols2 > 1 || mask) && (!oob_check || i_KQ < k_VKQ_sup) ?
|
||||
slope*__half2float(mask[j*stride_mask + k_VKQ_0 + i_KQ]) : 0.0f;
|
||||
if (!oob_check || i_KQ < k_VKQ_sup) {
|
||||
KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] += (ncols2 > 1 || mask) ?
|
||||
slope*__half2float(mask[j*stride_mask + k_VKQ_0 + i_KQ]) : 0.0f;
|
||||
|
||||
KQ_max_new[jc0] = fmaxf(KQ_max_new[jc0], KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]);
|
||||
KQ_max_new[jc0] = fmaxf(KQ_max_new[jc0], KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]);
|
||||
}
|
||||
}
|
||||
|
||||
KQ_max_new[jc0] = warp_reduce_max<warp_size>(KQ_max_new[jc0]);
|
||||
|
|
@ -581,10 +583,9 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
|
|||
float KQ_sum_add = 0.0f;
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < nbatch_fa; i0 += np*warp_size) {
|
||||
const float val = expf(KQ_acc[(i0/(np*warp_size))*cpw + jc] - KQ_max[jc]);
|
||||
if (!oob_check || i0 + (threadIdx.y % np)*warp_size + threadIdx.x < k_VKQ_sup) {
|
||||
KQ_sum_add += val;
|
||||
}
|
||||
const float val = !oob_check || i0 + (threadIdx.y % np)*warp_size + threadIdx.x < k_VKQ_sup ?
|
||||
expf(KQ_acc[(i0/(np*warp_size))*cpw + jc] - KQ_max[jc]) : 0.0f;
|
||||
KQ_sum_add += val;
|
||||
tmp[i0/(np*warp_size)][jc1] = val;
|
||||
}
|
||||
KQ_sum[jc] = KQ_sum[jc]*KQ_max_scale + KQ_sum_add;
|
||||
|
|
@ -975,26 +976,6 @@ static __global__ void flash_attn_tile(
|
|||
}
|
||||
}
|
||||
|
||||
if (gridDim.y == 1) {
|
||||
#pragma unroll
|
||||
for (int jc0 = 0; jc0 < cpw; ++jc0) {
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
const half2 KQ_sum_jc_inv = make_half2(1.0f/KQ_sum[jc0], 1.0f/KQ_sum[jc0]);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < (DVp/2)/warp_size; ++i) {
|
||||
VKQ[jc0*((DVp/2)/warp_size) + i] *= KQ_sum_jc_inv;
|
||||
}
|
||||
#else
|
||||
const float KQ_sum_jc_inv = 1.0f/KQ_sum[jc0];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < (DVp/2)/warp_size; ++i) {
|
||||
VKQ[jc0*((DVp/2)/warp_size) + i].x *= KQ_sum_jc_inv;
|
||||
VKQ[jc0*((DVp/2)/warp_size) + i].y *= KQ_sum_jc_inv;
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
}
|
||||
}
|
||||
|
||||
// Write back results:
|
||||
#pragma unroll
|
||||
for (int jc0 = 0; jc0 < cpw; ++jc0) {
|
||||
|
|
@ -1007,6 +988,8 @@ static __global__ void flash_attn_tile(
|
|||
return;
|
||||
}
|
||||
|
||||
const float scale = gridDim.y == 1 ? 1.0f/KQ_sum[jc0] : 1.0f;
|
||||
|
||||
const int j_dst_unrolled = ((sequence*ne01 + col_Q_0 + j)*ne02 + head0 + c)*gridDim.y + blockIdx.y;
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
|
|
@ -1017,6 +1000,8 @@ static __global__ void flash_attn_tile(
|
|||
#pragma unroll
|
||||
for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
|
||||
tmp[i1] = __half22float2(VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size + i1]);
|
||||
tmp[i1].x *= scale;
|
||||
tmp[i1].y *= scale;
|
||||
}
|
||||
if (i0 + warp_size*cpy_ne_D <= DV/2 || i0 + threadIdx.x*cpy_ne_D < DV/2) {
|
||||
ggml_cuda_memcpy_1<sizeof(tmp)>(&dst[j_dst_unrolled*DV + 2*i0 + threadIdx.x*(2*cpy_ne_D)], tmp);
|
||||
|
|
@ -1027,6 +1012,11 @@ static __global__ void flash_attn_tile(
|
|||
#pragma unroll
|
||||
for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) {
|
||||
if (i0 + warp_size*cpy_ne_D <= DV || i0 + threadIdx.x*cpy_ne_D < DV) {
|
||||
#pragma unroll
|
||||
for (int i1 = 0; i1 < cpy_ne_D/2; ++i1) {
|
||||
VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size) + i1].x *= scale;
|
||||
VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size) + i1].y *= scale;
|
||||
}
|
||||
ggml_cuda_memcpy_1<cpy_ne_D*4>(
|
||||
&dst[j_dst_unrolled*DV + i0 + threadIdx.x*cpy_ne_D],
|
||||
&VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size)]);
|
||||
|
|
|
|||
|
|
@ -516,8 +516,8 @@ void ggml_cuda_flash_attn_ext_vec_case_impl(ggml_backend_cuda_context & ctx, ggm
|
|||
const int nthreads = ggml_cuda_fattn_vec_get_nthreads_host(cc);
|
||||
const int nwarps = nthreads / WARP_SIZE;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_ext_vec<D, cols_per_block, type_K, type_V, use_logit_softcap>;
|
||||
constexpr bool need_f16_K = false;
|
||||
constexpr bool need_f16_V = false;
|
||||
const bool need_f16_K = type_K == GGML_TYPE_F16;
|
||||
const bool need_f16_V = type_V == GGML_TYPE_F16;
|
||||
constexpr size_t nbytes_shared = 0;
|
||||
launch_fattn<D, cols_per_block, 1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false);
|
||||
}
|
||||
|
|
@ -526,11 +526,6 @@ template <int D, ggml_type type_K, ggml_type type_V>
|
|||
void ggml_cuda_flash_attn_ext_vec_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * KQV = dst;
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
const ggml_tensor * K = dst->src[1];
|
||||
const ggml_tensor * V = dst->src[2];
|
||||
|
||||
GGML_ASSERT(K->type == type_K);
|
||||
GGML_ASSERT(V->type == type_V);
|
||||
|
||||
float logit_softcap;
|
||||
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
|
||||
|
|
|
|||
|
|
@ -116,11 +116,15 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
|
|||
}
|
||||
}
|
||||
|
||||
#define FATTN_VEC_CASE(D, type_K, type_V) \
|
||||
if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) { \
|
||||
ggml_cuda_flash_attn_ext_vec_case<D, type_K, type_V>(ctx, dst); \
|
||||
return; \
|
||||
} \
|
||||
#define FATTN_VEC_CASE(D, type_K, type_V) \
|
||||
{ \
|
||||
const bool type_K_okay = K->type == (type_K) || (K->type == GGML_TYPE_F32 && (type_K) == GGML_TYPE_F16); \
|
||||
const bool type_V_okay = V->type == (type_V) || (V->type == GGML_TYPE_F32 && (type_V) == GGML_TYPE_F16); \
|
||||
if (Q->ne[0] == (D) && type_K_okay && type_V_okay) { \
|
||||
ggml_cuda_flash_attn_ext_vec_case<D, type_K, type_V>(ctx, dst); \
|
||||
return; \
|
||||
} \
|
||||
} \
|
||||
|
||||
#define FATTN_VEC_CASES_ALL_D(type_K, type_V) \
|
||||
FATTN_VEC_CASE( 64, type_K, type_V) \
|
||||
|
|
@ -247,6 +251,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
|||
#endif // GGML_CUDA_FA_ALL_QUANTS
|
||||
|
||||
switch (K->type) {
|
||||
case GGML_TYPE_F32:
|
||||
case GGML_TYPE_F16:
|
||||
break;
|
||||
case GGML_TYPE_Q4_1:
|
||||
|
|
@ -272,7 +277,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
|||
// If Turing tensor cores available, use them:
|
||||
if (turing_mma_available(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40) {
|
||||
if (can_use_vector_kernel) {
|
||||
if (K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16) {
|
||||
if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
|
||||
if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && K->ne[1] >= 8192)) {
|
||||
return BEST_FATTN_KERNEL_VEC;
|
||||
}
|
||||
|
|
@ -305,7 +310,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
|||
|
||||
// If there are no tensor cores available, use the generic tile kernel:
|
||||
if (can_use_vector_kernel) {
|
||||
if (K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16) {
|
||||
if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
|
||||
if (Q->ne[1] == 1) {
|
||||
if (!gqa_opt_applies) {
|
||||
return BEST_FATTN_KERNEL_VEC;
|
||||
|
|
|
|||
|
|
@ -2260,13 +2260,13 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, ggml_cuda
|
|||
ggml_cuda_op_set_rows(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_DUP:
|
||||
ggml_cuda_dup(ctx, cuda_graph, dst);
|
||||
ggml_cuda_dup(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_CPY:
|
||||
ggml_cuda_cpy(ctx, cuda_graph, dst->src[0], dst->src[1]);
|
||||
ggml_cuda_cpy(ctx, dst->src[0], dst->src[1]);
|
||||
break;
|
||||
case GGML_OP_CONT:
|
||||
ggml_cuda_dup(ctx, cuda_graph, dst);
|
||||
ggml_cuda_dup(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_ADD:
|
||||
case GGML_OP_ADD1: // TODO: more efficient implementation
|
||||
|
|
@ -2633,11 +2633,10 @@ static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
|
|||
}
|
||||
|
||||
#ifdef USE_CUDA_GRAPH
|
||||
static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_cuda_graph * cuda_graph, const ggml_cgraph * cgraph, cudaStream_t stream,
|
||||
static bool check_node_graph_compatibility(const ggml_cgraph * cgraph,
|
||||
bool use_cuda_graph) {
|
||||
|
||||
// Loop over nodes in GGML graph to obtain info needed for CUDA graph
|
||||
cuda_graph->cpy_dest_ptrs.clear();
|
||||
|
||||
const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected";
|
||||
const std::string gemma3n_per_layer_proj_src1_name = "per_layer_proj";
|
||||
|
|
@ -2688,33 +2687,11 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_cuda_graph
|
|||
#endif
|
||||
}
|
||||
|
||||
if (node->op == GGML_OP_CPY) {
|
||||
|
||||
// Store the pointers which are updated for each token, such that these can be sent
|
||||
// to the device and accessed using indirection from CUDA graph
|
||||
cuda_graph->cpy_dest_ptrs.push_back((char *) node->src[1]->data);
|
||||
|
||||
// store a pointer to each copy op CUDA kernel to identify it later
|
||||
void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
|
||||
if (!ptr) {
|
||||
use_cuda_graph = false;
|
||||
#ifndef NDEBUG
|
||||
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported copy op\n", __func__);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
if (!use_cuda_graph) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (use_cuda_graph) {
|
||||
cuda_graph->use_cpy_indirection = true;
|
||||
// copy pointers to GPU so they can be accessed via indirection within CUDA graph
|
||||
ggml_cuda_cpy_dest_ptrs_copy(cuda_graph, cuda_graph->cpy_dest_ptrs.data(), cuda_graph->cpy_dest_ptrs.size(), stream);
|
||||
}
|
||||
|
||||
return use_cuda_graph;
|
||||
}
|
||||
|
||||
|
|
@ -2733,7 +2710,6 @@ static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_p
|
|||
|
||||
static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
|
||||
if (node->data != graph_node_properties->node_address &&
|
||||
node->op != GGML_OP_CPY &&
|
||||
node->op != GGML_OP_VIEW) {
|
||||
return false;
|
||||
}
|
||||
|
|
@ -2754,7 +2730,6 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
|
|||
for (int i = 0; i < GGML_MAX_SRC; i++) {
|
||||
if (node->src[i] &&
|
||||
node->src[i]->data != graph_node_properties->src_address[i] &&
|
||||
node->op != GGML_OP_CPY &&
|
||||
node->op != GGML_OP_VIEW
|
||||
) {
|
||||
return false;
|
||||
|
|
@ -2901,7 +2876,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
|
|||
}
|
||||
|
||||
//if rms norm is the B operand, then we don't handle broadcast
|
||||
if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm->src[1])) {
|
||||
if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -3105,13 +3080,11 @@ static ggml_backend_graph_plan_t ggml_backend_cuda_graph_plan_create(ggml_backen
|
|||
}
|
||||
|
||||
if (use_cuda_graph) {
|
||||
use_cuda_graph = check_node_graph_compatibility_and_refresh_copy_ops(cuda_graph, cgraph, cuda_ctx->stream(), use_cuda_graph);
|
||||
use_cuda_graph = check_node_graph_compatibility(cgraph, use_cuda_graph);
|
||||
}
|
||||
|
||||
if (use_cuda_graph) {
|
||||
capture_cuda_graph(cuda_ctx, cuda_graph, cgraph);
|
||||
} else {
|
||||
cuda_graph->use_cpy_indirection = false;
|
||||
}
|
||||
|
||||
return cuda_graph;
|
||||
|
|
@ -3138,7 +3111,7 @@ static void ggml_backend_cuda_graph_plan_update(ggml_backend_t backend, ggml_bac
|
|||
|
||||
bool use_cuda_graph = true;
|
||||
|
||||
use_cuda_graph = check_node_graph_compatibility_and_refresh_copy_ops(cuda_graph, cgraph, cuda_ctx->stream(), use_cuda_graph);
|
||||
use_cuda_graph = check_node_graph_compatibility(cgraph, use_cuda_graph);
|
||||
|
||||
if (!use_cuda_graph) {
|
||||
if (cuda_graph->instance != nullptr) {
|
||||
|
|
@ -3149,7 +3122,6 @@ static void ggml_backend_cuda_graph_plan_update(ggml_backend_t backend, ggml_bac
|
|||
}
|
||||
cuda_graph->instance = nullptr;
|
||||
cuda_graph->graph = nullptr;
|
||||
cuda_graph->use_cpy_indirection = false;
|
||||
}
|
||||
|
||||
if (is_cuda_graph_update_required(cuda_graph, cgraph)) {
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
#include "ggml.h"
|
||||
#include "mmf.cuh"
|
||||
#include "mmid.cuh"
|
||||
|
||||
|
||||
void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
|
||||
GGML_ASSERT( src1->type == GGML_TYPE_F32);
|
||||
|
|
@ -37,6 +39,12 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
|
|||
const int64_t ids_s0 = ids ? ids->nb[0] / ggml_type_size(ids->type) : 0;
|
||||
const int64_t ids_s1 = ids ? ids->nb[1] / ggml_type_size(ids->type) : 0;
|
||||
|
||||
mmf_ids_data ids_info{};
|
||||
mmf_ids_data * ids_info_ptr = nullptr;
|
||||
ggml_cuda_pool_alloc<int32_t> ids_src_compact_dev;
|
||||
ggml_cuda_pool_alloc<int32_t> ids_dst_compact_dev;
|
||||
ggml_cuda_pool_alloc<int32_t> expert_bounds_dev;
|
||||
|
||||
// For MUL_MAT_ID the memory layout is different than for MUL_MAT:
|
||||
const int64_t ncols_dst = ids ? ne2 : ne1;
|
||||
const int64_t nchannels_dst = ids ? ne1 : ne2;
|
||||
|
|
@ -54,6 +62,33 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
|
|||
nchannels_y = ids->ne[0];
|
||||
}
|
||||
|
||||
if (ids && ncols_dst > 16) {
|
||||
const int64_t n_expert_used = ids->ne[0];
|
||||
const int64_t n_experts = ne02;
|
||||
const int64_t n_tokens = ne12;
|
||||
const int64_t ne_get_rows = n_tokens * n_expert_used;
|
||||
|
||||
ids_src_compact_dev.alloc(ctx.pool(), ne_get_rows);
|
||||
ids_dst_compact_dev.alloc(ctx.pool(), ne_get_rows);
|
||||
expert_bounds_dev.alloc(ctx.pool(), n_experts + 1);
|
||||
|
||||
const int si1 = static_cast<int>(ids_s1);
|
||||
const int sis1 = static_cast<int>(src1->nb[2] / src1->nb[1]);
|
||||
|
||||
GGML_ASSERT(sis1 > 0);
|
||||
|
||||
ggml_cuda_launch_mm_ids_helper(ids_d, ids_src_compact_dev.get(), ids_dst_compact_dev.get(), expert_bounds_dev.get(),
|
||||
static_cast<int>(n_experts), static_cast<int>(n_tokens), static_cast<int>(n_expert_used), static_cast<int>(ne11), si1, sis1, ctx.stream());
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
ids_info.ids_src_compact = ids_src_compact_dev.get();
|
||||
ids_info.ids_dst_compact = ids_dst_compact_dev.get();
|
||||
ids_info.expert_bounds_dev = expert_bounds_dev.get();
|
||||
ids_info.n_experts = static_cast<int>(n_experts);
|
||||
ids_info.sis1 = sis1;
|
||||
ids_info_ptr = &ids_info;
|
||||
}
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32: {
|
||||
const float * src0_d = (const float *) src0->data;
|
||||
|
|
@ -61,7 +96,7 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
|
|||
mul_mat_f_switch_cols_per_block(
|
||||
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
|
||||
ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
|
||||
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
|
||||
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
|
||||
} break;
|
||||
case GGML_TYPE_F16: {
|
||||
const half2 * src0_d = (const half2 *) src0->data;
|
||||
|
|
@ -69,7 +104,7 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
|
|||
mul_mat_f_switch_cols_per_block(
|
||||
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
|
||||
ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
|
||||
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
|
||||
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
|
||||
} break;
|
||||
case GGML_TYPE_BF16: {
|
||||
const nv_bfloat162 * src0_d = (const nv_bfloat162 *) src0->data;
|
||||
|
|
@ -77,7 +112,7 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
|
|||
mul_mat_f_switch_cols_per_block(
|
||||
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
|
||||
ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
|
||||
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
|
||||
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
|
||||
} break;
|
||||
default:
|
||||
GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
|
||||
|
|
@ -98,10 +133,9 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
|
|||
}
|
||||
|
||||
if (mul_mat_id) {
|
||||
if (type == GGML_TYPE_F32 && src1_ncols > 32) {
|
||||
if (src0_ne[1] <= 1024 && src1_ncols > 512) {
|
||||
return false;
|
||||
}
|
||||
if ((type == GGML_TYPE_F16 || type == GGML_TYPE_BF16) && src1_ncols > 64) {
|
||||
} else if(src0_ne[1] > 1024 && src1_ncols > 128) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
|
|
|
|||
|
|
@ -7,6 +7,14 @@ using namespace ggml_cuda_mma;
|
|||
|
||||
#define MMF_ROWS_PER_BLOCK 32
|
||||
|
||||
struct mmf_ids_data {
|
||||
const int32_t * ids_src_compact = nullptr;
|
||||
const int32_t * ids_dst_compact = nullptr;
|
||||
const int32_t * expert_bounds_dev = nullptr;
|
||||
int n_experts = 0;
|
||||
int sis1 = 0;
|
||||
};
|
||||
|
||||
void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
|
||||
|
||||
bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, const int src1_ncols, bool mul_mat_id);
|
||||
|
|
@ -224,6 +232,250 @@ static __global__ void mul_mat_f(
|
|||
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
|
||||
}
|
||||
|
||||
|
||||
//This kernel is for larger batch sizes of mul_mat_id
|
||||
template <typename T, int rows_per_block, int cols_per_block, int nwarps>
|
||||
__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
|
||||
static __global__ void mul_mat_f_ids(
|
||||
const T * __restrict__ x, const float * __restrict__ y,
|
||||
const int32_t * __restrict__ ids_src_compact, const int32_t * __restrict__ ids_dst_compact,
|
||||
const int32_t * __restrict__ expert_bounds, float * __restrict__ dst,
|
||||
const int ncols, const int ncols_dst_total, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst,
|
||||
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
|
||||
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
|
||||
const uint3 sis1_fd, const uint3 nch_fd) {
|
||||
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
|
||||
typedef tile<16, 8, T> tile_A;
|
||||
typedef tile< 8, 8, T> tile_B;
|
||||
typedef tile<16, 8, float> tile_C;
|
||||
|
||||
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||
constexpr int tile_k_padded = warp_size + 4;
|
||||
constexpr int ntA = rows_per_block / tile_A::I;
|
||||
constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I;
|
||||
|
||||
const int row0 = blockIdx.x * rows_per_block;
|
||||
|
||||
const int expert_idx = blockIdx.y;
|
||||
const int expert_start = expert_bounds[expert_idx];
|
||||
const int expert_end = expert_bounds[expert_idx + 1];
|
||||
const int ncols_expert = expert_end - expert_start;
|
||||
|
||||
const int tiles_for_expert = (ncols_expert + cols_per_block - 1) / cols_per_block;
|
||||
const int tile_idx = blockIdx.z;
|
||||
if (tile_idx >= tiles_for_expert) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int col_base = tile_idx * cols_per_block;
|
||||
|
||||
GGML_UNUSED(channel_ratio);
|
||||
|
||||
const int channel_x = expert_idx;
|
||||
const int sample_dst = 0;
|
||||
const int sample_x = sample_dst / sample_ratio;
|
||||
const int sample_y = sample_dst;
|
||||
|
||||
x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row0*stride_row;
|
||||
y += int64_t(sample_y) *stride_sample_y;
|
||||
dst += int64_t(sample_dst)*stride_sample_dst;
|
||||
|
||||
const int32_t * ids_src_expert = ids_src_compact + expert_start;
|
||||
const int32_t * ids_dst_expert = ids_dst_compact + expert_start;
|
||||
|
||||
extern __shared__ char data_mmv[];
|
||||
char * compute_base = data_mmv;
|
||||
|
||||
//const float2 * y2 = (const float2 *) y;
|
||||
|
||||
tile_C C[ntA][ntB];
|
||||
|
||||
T * tile_xy = (T *) compute_base + threadIdx.y*(tile_A::I * tile_k_padded);
|
||||
|
||||
for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) {
|
||||
tile_A A[ntA][warp_size / tile_A::J];
|
||||
#pragma unroll
|
||||
for (int itA = 0; itA < ntA; ++itA) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < tile_A::I; ++i) {
|
||||
tile_xy[i*tile_k_padded + threadIdx.x] = x[(itA*tile_A::I + i)*stride_row + col];
|
||||
}
|
||||
#pragma unroll
|
||||
for (int k0 = 0; k0 < warp_size; k0 += tile_A::J) {
|
||||
load_ldmatrix(A[itA][k0/tile_A::J], tile_xy + k0, tile_k_padded);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (std::is_same_v<T, float>) {
|
||||
float vals_buf[2][tile_B::I];
|
||||
auto gather_tile = [&](int tile_idx_local, float *vals) {
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < tile_B::I; ++j0) {
|
||||
const int j = j0 + tile_idx_local*tile_B::I;
|
||||
const int global_j = col_base + j;
|
||||
float val = 0.0f;
|
||||
if (j < cols_per_block && global_j < ncols_expert) {
|
||||
const int src_entry = ids_src_expert[global_j];
|
||||
const uint2 qrm = fast_div_modulo((uint32_t) src_entry, sis1_fd);
|
||||
const int token = (int) qrm.x;
|
||||
const int channel = (int) qrm.y;
|
||||
if (token < ncols_dst_total) {
|
||||
val = y[channel*stride_channel_y + token*stride_col_y + col];
|
||||
}
|
||||
}
|
||||
vals[j0] = val;
|
||||
}
|
||||
};
|
||||
|
||||
gather_tile(0, vals_buf[0]);
|
||||
|
||||
int curr_buf = 0;
|
||||
int next_buf = 1;
|
||||
#pragma unroll
|
||||
for (int itB = 0; itB < ntB; ++itB) {
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < tile_B::I; ++j0) {
|
||||
tile_xy[j0*tile_k_padded + threadIdx.x] = vals_buf[curr_buf][j0];
|
||||
}
|
||||
|
||||
if (itB + 1 < ntB) {
|
||||
gather_tile(itB + 1, vals_buf[next_buf]);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) {
|
||||
tile_B B;
|
||||
load_ldmatrix(B, tile_xy + k0, tile_k_padded);
|
||||
#pragma unroll
|
||||
for (int itA = 0; itA < ntA; ++itA) {
|
||||
mma(C[itA][itB], A[itA][k0/tile_B::J], B);
|
||||
}
|
||||
}
|
||||
|
||||
if (itB + 1 < ntB) {
|
||||
curr_buf ^= 1;
|
||||
next_buf ^= 1;
|
||||
}
|
||||
}
|
||||
} else if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
|
||||
float2 vals_buf[2][tile_B::I];
|
||||
auto gather_tile = [&](int tile_idx_local, float2 *vals) {
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < tile_B::I; ++j0) {
|
||||
const int j = j0 + tile_idx_local*tile_B::I;
|
||||
const int global_j = col_base + j;
|
||||
float2 tmp = make_float2(0.0f, 0.0f);
|
||||
if (j < cols_per_block && global_j < ncols_expert) {
|
||||
const int src_entry = ids_src_expert[global_j];
|
||||
const uint2 qrm = fast_div_modulo((uint32_t) src_entry, sis1_fd);
|
||||
const int token = (int) qrm.x;
|
||||
const int channel = (int) qrm.y;
|
||||
if (token < ncols_dst_total) {
|
||||
tmp = *(const float2*) &y[channel*stride_channel_y + 2*(token*stride_col_y + col)];
|
||||
}
|
||||
}
|
||||
vals[j0] = tmp;
|
||||
}
|
||||
};
|
||||
|
||||
if (ntB > 0) {
|
||||
gather_tile(0, vals_buf[0]);
|
||||
}
|
||||
|
||||
int curr_buf = 0;
|
||||
int next_buf = 1;
|
||||
#pragma unroll
|
||||
for (int itB = 0; itB < ntB; ++itB) {
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < tile_B::I; ++j0) {
|
||||
const float2 tmp = vals_buf[curr_buf][j0];
|
||||
tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
|
||||
}
|
||||
|
||||
if (itB + 1 < ntB) {
|
||||
gather_tile(itB + 1, vals_buf[next_buf]);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) {
|
||||
tile_B B;
|
||||
load_ldmatrix(B, tile_xy + k0, tile_k_padded);
|
||||
#pragma unroll
|
||||
for (int itA = 0; itA < ntA; ++itA) {
|
||||
mma(C[itA][itB], A[itA][k0/tile_B::J], B);
|
||||
}
|
||||
}
|
||||
|
||||
if (itB + 1 < ntB) {
|
||||
curr_buf ^= 1;
|
||||
next_buf ^= 1;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
static_assert(std::is_same_v<T, void>, "unsupported type");
|
||||
}
|
||||
}
|
||||
|
||||
float * buf_iw = (float *) compute_base;
|
||||
constexpr int kiw = nwarps*rows_per_block + 4;
|
||||
|
||||
if (nwarps > 1) {
|
||||
__syncthreads();
|
||||
}
|
||||
#pragma unroll
|
||||
for (int itB = 0; itB < ntB; ++itB) {
|
||||
#pragma unroll
|
||||
for (int itA = 0; itA < ntA; ++itA) {
|
||||
#pragma unroll
|
||||
for (int l = 0; l < tile_C::ne; ++l) {
|
||||
const int i = threadIdx.y*rows_per_block + itA*tile_C::I + tile_C::get_i(l);
|
||||
const int j = itB*tile_C::J + tile_C::get_j(l);
|
||||
buf_iw[j*kiw + i] = C[itA][itB].x[l];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (nwarps > 1) {
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
|
||||
const int j = j0 + threadIdx.y;
|
||||
|
||||
if (j0 + nwarps > cols_per_block && j >= cols_per_block) {
|
||||
return;
|
||||
}
|
||||
|
||||
float sum = 0.0f;
|
||||
static_assert(rows_per_block == warp_size, "need loop/check");
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
||||
sum += buf_iw[j*kiw + i];
|
||||
}
|
||||
|
||||
const int global_j = col_base + j;
|
||||
if (j < cols_per_block && global_j < ncols_expert && nchannels_dst > 0) {
|
||||
const int dst_entry = ids_dst_expert[global_j];
|
||||
const uint2 qrm = fast_div_modulo((uint32_t) dst_entry, nch_fd);
|
||||
const int token = (int) qrm.x;
|
||||
if (token < ncols_dst_total) {
|
||||
const int slot = (int) qrm.y;
|
||||
dst[slot*stride_channel_dst + token*stride_col_dst + row0 + threadIdx.x] = sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED_VARS(x, y, ids_src_compact, ids_dst_compact, expert_bounds, dst,
|
||||
ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, sis1_fd, nch_fd);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
|
||||
}
|
||||
|
||||
template<typename T, int cols_per_block, int nwarps>
|
||||
static inline void mul_mat_f_switch_ids(
|
||||
const T * x, const float * y, const int32_t * ids, float * dst,
|
||||
|
|
@ -232,13 +484,35 @@ static inline void mul_mat_f_switch_ids(
|
|||
const int64_t stride_col_id, const int64_t stride_row_id,
|
||||
const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
|
||||
const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
|
||||
const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream) {
|
||||
if (ids) {
|
||||
const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream,
|
||||
const mmf_ids_data * ids_data) {
|
||||
const bool has_ids_data = ids_data && ids_data->ids_src_compact;
|
||||
|
||||
// Use the compact-ids kernel only for larger tiles; for small ncols_dst (< 16)
|
||||
// we prefer the normal mul_mat_f path with has_ids=true.
|
||||
if (has_ids_data && ncols_dst > 16) {
|
||||
const int max_tiles = (int) ((ncols_dst + cols_per_block - 1) / cols_per_block);
|
||||
if (max_tiles == 0) {
|
||||
return;
|
||||
}
|
||||
dim3 block_nums_ids(block_nums.x, ids_data->n_experts, max_tiles);
|
||||
|
||||
const uint3 sis1_fd = ids_data->sis1 > 0 ? init_fastdiv_values((uint32_t) ids_data->sis1) : make_uint3(0, 0, 1);
|
||||
const uint3 nch_fd = init_fastdiv_values((uint32_t) nchannels_dst);
|
||||
|
||||
mul_mat_f_ids<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
|
||||
(x, y, ids_data->ids_src_compact, ids_data->ids_dst_compact, ids_data->expert_bounds_dev, dst,
|
||||
ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
sis1_fd, nch_fd);
|
||||
} else if (ids) {
|
||||
const int64_t col_tiles = (ncols_dst + cols_per_block - 1) / cols_per_block;
|
||||
dim3 block_nums_ids = block_nums;
|
||||
block_nums_ids.y *= col_tiles;
|
||||
|
||||
mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
|
||||
(x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
(x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
} else {
|
||||
|
|
@ -258,7 +532,7 @@ void mul_mat_f_cuda(
|
|||
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
|
||||
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
|
||||
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
|
||||
cudaStream_t stream) {
|
||||
cudaStream_t stream, const mmf_ids_data * ids_data) {
|
||||
typedef tile<16, 8, T> tile_A;
|
||||
typedef tile< 8, 8, T> tile_B;
|
||||
|
||||
|
|
@ -290,7 +564,7 @@ void mul_mat_f_cuda(
|
|||
const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine);
|
||||
const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0;
|
||||
const int nbytes_shared_total = nbytes_shared + nbytes_slotmap;
|
||||
const int64_t grid_y = ids ? nchannels_x : nchannels_dst; // per expert when ids present
|
||||
const int64_t grid_y = ids ? nchannels_x : nchannels_dst;
|
||||
|
||||
const dim3 block_nums(nrows_x/rows_per_block, grid_y, nsamples_dst);
|
||||
const dim3 block_dims(warp_size, nwarps_best, 1);
|
||||
|
|
@ -300,49 +574,57 @@ void mul_mat_f_cuda(
|
|||
mul_mat_f_switch_ids<T, cols_per_block, 1>(
|
||||
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
|
||||
ids_data);
|
||||
} break;
|
||||
case 2: {
|
||||
mul_mat_f_switch_ids<T, cols_per_block, 2>(
|
||||
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
|
||||
ids_data);
|
||||
} break;
|
||||
case 3: {
|
||||
mul_mat_f_switch_ids<T, cols_per_block, 3>(
|
||||
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
|
||||
ids_data);
|
||||
} break;
|
||||
case 4: {
|
||||
mul_mat_f_switch_ids<T, cols_per_block, 4>(
|
||||
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
|
||||
ids_data);
|
||||
} break;
|
||||
case 5: {
|
||||
mul_mat_f_switch_ids<T, cols_per_block, 5>(
|
||||
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
|
||||
ids_data);
|
||||
} break;
|
||||
case 6: {
|
||||
mul_mat_f_switch_ids<T, cols_per_block, 6>(
|
||||
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
|
||||
ids_data);
|
||||
} break;
|
||||
case 7: {
|
||||
mul_mat_f_switch_ids<T, cols_per_block, 7>(
|
||||
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
|
||||
ids_data);
|
||||
} break;
|
||||
case 8: {
|
||||
mul_mat_f_switch_ids<T, cols_per_block, 8>(
|
||||
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
|
||||
ids_data);
|
||||
} break;
|
||||
default: {
|
||||
GGML_ABORT("fatal error");
|
||||
|
|
@ -361,7 +643,7 @@ static void mul_mat_f_switch_cols_per_block(
|
|||
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
|
||||
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
|
||||
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
|
||||
cudaStream_t stream) {
|
||||
cudaStream_t stream, const mmf_ids_data * ids_data) {
|
||||
|
||||
const int ncols_case = (ids && ncols_dst > 16) ? 16 : ncols_dst;
|
||||
|
||||
|
|
@ -371,82 +653,82 @@ static void mul_mat_f_switch_cols_per_block(
|
|||
case 1: {
|
||||
mul_mat_f_cuda<T, 1>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 2: {
|
||||
mul_mat_f_cuda<T, 2>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 3: {
|
||||
mul_mat_f_cuda<T, 3>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 4: {
|
||||
mul_mat_f_cuda<T, 4>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 5: {
|
||||
mul_mat_f_cuda<T, 5>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 6: {
|
||||
mul_mat_f_cuda<T, 6>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 7: {
|
||||
mul_mat_f_cuda<T, 7>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 8: {
|
||||
mul_mat_f_cuda<T, 8>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 9: {
|
||||
mul_mat_f_cuda<T, 9>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 10: {
|
||||
mul_mat_f_cuda<T, 10>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 11: {
|
||||
mul_mat_f_cuda<T, 11>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 12: {
|
||||
mul_mat_f_cuda<T, 12>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 13: {
|
||||
mul_mat_f_cuda<T, 13>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 14: {
|
||||
mul_mat_f_cuda<T, 14>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 15: {
|
||||
mul_mat_f_cuda<T, 15>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 16: {
|
||||
mul_mat_f_cuda<T, 16>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
default: {
|
||||
GGML_ABORT("fatal error");
|
||||
|
|
@ -462,7 +744,7 @@ static void mul_mat_f_switch_cols_per_block(
|
|||
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, \
|
||||
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,\
|
||||
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, \
|
||||
cudaStream_t stream);
|
||||
cudaStream_t stream, const mmf_ids_data * ids_data);
|
||||
|
||||
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
|
||||
#define DECL_MMF_CASE_EXTERN(ncols_dst) \
|
||||
|
|
|
|||
|
|
@ -0,0 +1,164 @@
|
|||
#include "common.cuh"
|
||||
#include "mmid.cuh"
|
||||
|
||||
// To reduce shared memory use, store "it" and "iex_used" with 22/10 bits each.
|
||||
struct mm_ids_helper_store {
|
||||
uint32_t data;
|
||||
|
||||
__device__ mm_ids_helper_store(const uint32_t it, const uint32_t iex_used) {
|
||||
data = (it & 0x003FFFFF) | (iex_used << 22);
|
||||
}
|
||||
|
||||
__device__ uint32_t it() const {
|
||||
return data & 0x003FFFFF;
|
||||
}
|
||||
|
||||
__device__ uint32_t iex_used() const {
|
||||
return data >> 22;
|
||||
}
|
||||
};
|
||||
static_assert(sizeof(mm_ids_helper_store) == 4, "unexpected size for mm_ids_helper_store");
|
||||
|
||||
// Helper function for mul_mat_id, converts ids to a more convenient format.
|
||||
// ids_src1 describes how to permute the flattened column indices of src1 in order to get a compact src1 tensor sorted by expert.
|
||||
// ids_dst describes the same mapping but for the dst tensor.
|
||||
// The upper and lower bounds for the ith expert in the compact src1 tensor are stored in expert_bounds[i:i+1].
|
||||
template <int n_expert_used_template>
|
||||
__launch_bounds__(ggml_cuda_get_physical_warp_size(), 1)
|
||||
static __global__ void mm_ids_helper(
|
||||
const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
|
||||
const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1) {
|
||||
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||
const int n_expert_used = n_expert_used_template == 0 ? n_expert_used_var : n_expert_used_template;
|
||||
const int expert = blockIdx.x;
|
||||
|
||||
extern __shared__ char data_mm_ids_helper[];
|
||||
mm_ids_helper_store * store = (mm_ids_helper_store *) data_mm_ids_helper;
|
||||
|
||||
int nex_prev = 0; // Number of columns for experts with a lower index.
|
||||
int it_compact = 0; // Running index for the compact slice of this expert.
|
||||
|
||||
if constexpr (n_expert_used_template == 0) {
|
||||
// Generic implementation:
|
||||
for (int it = 0; it < n_tokens; ++it) {
|
||||
int iex_used = -1; // The index at which the expert is used, if any.
|
||||
for (int iex = threadIdx.x; iex < n_expert_used; iex += warp_size) {
|
||||
const int expert_used = ids[it*si1 + iex];
|
||||
nex_prev += expert_used < expert;
|
||||
if (expert_used == expert) {
|
||||
iex_used = iex;
|
||||
}
|
||||
}
|
||||
|
||||
if (iex_used != -1) {
|
||||
store[it_compact] = mm_ids_helper_store(it, iex_used);
|
||||
}
|
||||
|
||||
if (warp_reduce_any<warp_size>(iex_used != -1)) {
|
||||
it_compact++;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Implementation optimized for specific numbers of experts used:
|
||||
static_assert(n_expert_used == 6 || warp_size % n_expert_used == 0, "bad n_expert_used");
|
||||
const int neu_padded = n_expert_used == 6 ? 8 : n_expert_used; // Padded to next higher power of 2.
|
||||
for (int it0 = 0; it0 < n_tokens; it0 += warp_size/neu_padded) {
|
||||
const int it = it0 + threadIdx.x / neu_padded;
|
||||
|
||||
const int iex = threadIdx.x % neu_padded; // The index at which the expert is used, if any.
|
||||
const int expert_used = (neu_padded == n_expert_used || iex < n_expert_used) && it < n_tokens ?
|
||||
ids[it*si1 + iex] : INT_MAX;
|
||||
const int iex_used = expert_used == expert ? iex : -1;
|
||||
nex_prev += expert_used < expert;
|
||||
|
||||
// Whether the threads at this token position have used the expert:
|
||||
const int it_compact_add_self = warp_reduce_any<neu_padded>(iex_used != -1);
|
||||
|
||||
// Do a scan over threads at lower token positions in warp to get the correct index for writing data:
|
||||
int it_compact_add_lower = 0;
|
||||
#pragma unroll
|
||||
for (int offset = neu_padded; offset < warp_size; offset += neu_padded) {
|
||||
const int tmp = __shfl_up_sync(0xFFFFFFFF, it_compact_add_self, offset, warp_size);
|
||||
if (threadIdx.x >= static_cast<unsigned int>(offset)) {
|
||||
it_compact_add_lower += tmp;
|
||||
}
|
||||
}
|
||||
|
||||
if (iex_used != -1) {
|
||||
store[it_compact + it_compact_add_lower] = mm_ids_helper_store(it, iex_used);
|
||||
}
|
||||
|
||||
// The thread with the highest index in the warp always has the sum over the whole warp, use it to increment all threads:
|
||||
it_compact += __shfl_sync(0xFFFFFFFF, it_compact_add_lower + it_compact_add_self, warp_size - 1, warp_size);
|
||||
}
|
||||
}
|
||||
nex_prev = warp_reduce_sum<warp_size>(nex_prev);
|
||||
|
||||
for (int itc = threadIdx.x; itc < it_compact; itc += warp_size) {
|
||||
const mm_ids_helper_store store_it = store[itc];
|
||||
const int it = store_it.it();
|
||||
const int iex_used = store_it.iex_used();
|
||||
ids_src1[nex_prev + itc] = it*sis1 + iex_used % nchannels_y;
|
||||
ids_dst [nex_prev + itc] = it*n_expert_used + iex_used;
|
||||
}
|
||||
|
||||
if (threadIdx.x != 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
expert_bounds[expert] = nex_prev;
|
||||
|
||||
if (expert < static_cast<int>(gridDim.x) - 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
expert_bounds[gridDim.x] = nex_prev + it_compact;
|
||||
}
|
||||
|
||||
template <int n_expert_used_template>
|
||||
static void launch_mm_ids_helper(
|
||||
const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
|
||||
const int n_experts, const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) {
|
||||
GGML_ASSERT(n_tokens < (1 << 22) && "too few bits in mm_ids_helper_store");
|
||||
GGML_ASSERT(n_expert_used_var < (1 << 10) && "too few bits in mm_ids_helper_store");
|
||||
|
||||
const int id = ggml_cuda_get_device();
|
||||
const int warp_size = ggml_cuda_info().devices[id].warp_size;
|
||||
const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
|
||||
CUDA_SET_SHARED_MEMORY_LIMIT(mm_ids_helper<n_expert_used_template>, smpbo);
|
||||
|
||||
const dim3 num_blocks(n_experts, 1, 1);
|
||||
const dim3 block_size(warp_size, 1, 1);
|
||||
const size_t nbytes_shared = n_tokens*sizeof(mm_ids_helper_store);
|
||||
GGML_ASSERT(nbytes_shared <= smpbo);
|
||||
mm_ids_helper<n_expert_used_template><<<num_blocks, block_size, nbytes_shared, stream>>>
|
||||
(ids, ids_src1, ids_dst, expert_bounds, n_tokens, n_expert_used_var, nchannels_y, si1, sis1);
|
||||
}
|
||||
|
||||
void ggml_cuda_launch_mm_ids_helper(
|
||||
const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
|
||||
const int n_experts, const int n_tokens, const int n_expert_used, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) {
|
||||
switch (n_expert_used) {
|
||||
case 2:
|
||||
launch_mm_ids_helper< 2>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
|
||||
break;
|
||||
case 4:
|
||||
launch_mm_ids_helper< 4>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
|
||||
break;
|
||||
case 6:
|
||||
launch_mm_ids_helper< 6>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
|
||||
break;
|
||||
case 8:
|
||||
launch_mm_ids_helper< 8>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
|
||||
break;
|
||||
case 16:
|
||||
launch_mm_ids_helper<16>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
|
||||
break;
|
||||
case 32:
|
||||
launch_mm_ids_helper<32>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
|
||||
break;
|
||||
default:
|
||||
launch_mm_ids_helper< 0>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
#pragma once
|
||||
|
||||
void ggml_cuda_launch_mm_ids_helper(
|
||||
const int32_t * ids, int32_t * ids_src1, int32_t * ids_dst, int32_t * expert_bounds,
|
||||
int n_experts, int n_tokens, int n_expert_used, int nchannels_y, int si1, int sis1, cudaStream_t stream);
|
||||
|
|
@ -1,141 +1,6 @@
|
|||
#include "mmq.cuh"
|
||||
#include "quantize.cuh"
|
||||
|
||||
#include <vector>
|
||||
|
||||
// To reduce shared memory use, store "it" and "iex_used" with 22/10 bits each.
|
||||
struct mmq_ids_helper_store {
|
||||
uint32_t data;
|
||||
|
||||
__device__ mmq_ids_helper_store(const uint32_t it, const uint32_t iex_used) {
|
||||
data = (it & 0x003FFFFF) | (iex_used << 22);
|
||||
}
|
||||
|
||||
__device__ uint32_t it() const {
|
||||
return data & 0x003FFFFF;
|
||||
}
|
||||
|
||||
__device__ uint32_t iex_used() const {
|
||||
return data >> 22;
|
||||
}
|
||||
};
|
||||
static_assert(sizeof(mmq_ids_helper_store) == 4, "unexpected size for mmq_ids_helper_store");
|
||||
|
||||
// Helper function for mul_mat_id, converts ids to a more convenient format.
|
||||
// ids_src1 describes how to permute the flattened column indices of src1 in order to get a compact src1 tensor sorted by expert.
|
||||
// ids_dst describes the same mapping but for the dst tensor.
|
||||
// The upper and lower bounds for the ith expert in the compact src1 tensor are stored in expert_bounds[i:i+1].
|
||||
template <int n_expert_used_template>
|
||||
__launch_bounds__(ggml_cuda_get_physical_warp_size(), 1)
|
||||
static __global__ void mmq_ids_helper(
|
||||
const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
|
||||
const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1) {
|
||||
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||
const int n_expert_used = n_expert_used_template == 0 ? n_expert_used_var : n_expert_used_template;
|
||||
const int expert = blockIdx.x;
|
||||
|
||||
extern __shared__ char data_mmq_ids_helper[];
|
||||
mmq_ids_helper_store * store = (mmq_ids_helper_store *) data_mmq_ids_helper;
|
||||
|
||||
int nex_prev = 0; // Number of columns for experts with a lower index.
|
||||
int it_compact = 0; // Running index for the compact slice of this expert.
|
||||
|
||||
if constexpr (n_expert_used_template == 0) {
|
||||
// Generic implementation:
|
||||
for (int it = 0; it < n_tokens; ++it) {
|
||||
int iex_used = -1; // The index at which the expert is used, if any.
|
||||
for (int iex = threadIdx.x; iex < n_expert_used; iex += warp_size) {
|
||||
const int expert_used = ids[it*si1 + iex];
|
||||
nex_prev += expert_used < expert;
|
||||
if (expert_used == expert) {
|
||||
iex_used = iex;
|
||||
}
|
||||
}
|
||||
|
||||
if (iex_used != -1) {
|
||||
store[it_compact] = mmq_ids_helper_store(it, iex_used);
|
||||
}
|
||||
|
||||
if (warp_reduce_any<warp_size>(iex_used != -1)) {
|
||||
it_compact++;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Implementation optimized for specific numbers of experts used:
|
||||
static_assert(n_expert_used == 6 || warp_size % n_expert_used == 0, "bad n_expert_used");
|
||||
const int neu_padded = n_expert_used == 6 ? 8 : n_expert_used; // Padded to next higher power of 2.
|
||||
for (int it0 = 0; it0 < n_tokens; it0 += warp_size/neu_padded) {
|
||||
const int it = it0 + threadIdx.x / neu_padded;
|
||||
|
||||
const int iex = threadIdx.x % neu_padded; // The index at which the expert is used, if any.
|
||||
const int expert_used = (neu_padded == n_expert_used || iex < n_expert_used) && it < n_tokens ?
|
||||
ids[it*si1 + iex] : INT_MAX;
|
||||
const int iex_used = expert_used == expert ? iex : -1;
|
||||
nex_prev += expert_used < expert;
|
||||
|
||||
// Whether the threads at this token position have used the expert:
|
||||
const int it_compact_add_self = warp_reduce_any<neu_padded>(iex_used != -1);
|
||||
|
||||
// Do a scan over threads at lower token positions in warp to get the correct index for writing data:
|
||||
int it_compact_add_lower = 0;
|
||||
#pragma unroll
|
||||
for (int offset = neu_padded; offset < warp_size; offset += neu_padded) {
|
||||
const int tmp = __shfl_up_sync(0xFFFFFFFF, it_compact_add_self, offset, warp_size);
|
||||
if (threadIdx.x >= static_cast<unsigned int>(offset)) {
|
||||
it_compact_add_lower += tmp;
|
||||
}
|
||||
}
|
||||
|
||||
if (iex_used != -1) {
|
||||
store[it_compact + it_compact_add_lower] = mmq_ids_helper_store(it, iex_used);
|
||||
}
|
||||
|
||||
// The thread with the highest index in the warp always has the sum over the whole warp, use it to increment all threads:
|
||||
it_compact += __shfl_sync(0xFFFFFFFF, it_compact_add_lower + it_compact_add_self, warp_size - 1, warp_size);
|
||||
}
|
||||
}
|
||||
nex_prev = warp_reduce_sum<warp_size>(nex_prev);
|
||||
|
||||
for (int itc = threadIdx.x; itc < it_compact; itc += warp_size) {
|
||||
const mmq_ids_helper_store store_it = store[itc];
|
||||
const int it = store_it.it();
|
||||
const int iex_used = store_it.iex_used();
|
||||
ids_src1[nex_prev + itc] = it*sis1 + iex_used % nchannels_y;
|
||||
ids_dst [nex_prev + itc] = it*n_expert_used + iex_used;
|
||||
}
|
||||
|
||||
if (threadIdx.x != 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
expert_bounds[expert] = nex_prev;
|
||||
|
||||
if (expert < static_cast<int>(gridDim.x) - 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
expert_bounds[gridDim.x] = nex_prev + it_compact;
|
||||
}
|
||||
|
||||
template <int n_expert_used_template>
|
||||
static void launch_mmq_ids_helper(
|
||||
const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
|
||||
const int n_experts, const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) {
|
||||
GGML_ASSERT(n_tokens < (1 << 22) && "too few bits in mmq_ids_helper_store");
|
||||
GGML_ASSERT(n_expert_used_var < (1 << 10) && "too few bits in mmq_ids_helper_store");
|
||||
|
||||
const int id = ggml_cuda_get_device();
|
||||
const int warp_size = ggml_cuda_info().devices[id].warp_size;
|
||||
const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
|
||||
CUDA_SET_SHARED_MEMORY_LIMIT(mmq_ids_helper<n_expert_used_template>, smpbo);
|
||||
|
||||
const dim3 num_blocks(n_experts, 1, 1);
|
||||
const dim3 block_size(warp_size, 1, 1);
|
||||
const size_t nbytes_shared = n_tokens*sizeof(mmq_ids_helper_store);
|
||||
GGML_ASSERT(nbytes_shared <= smpbo);
|
||||
mmq_ids_helper<n_expert_used_template><<<num_blocks, block_size, nbytes_shared, stream>>>
|
||||
(ids, ids_src1, ids_dst, expert_bounds, n_tokens, n_expert_used_var, nchannels_y, si1, sis1);
|
||||
}
|
||||
#include "mmid.cuh"
|
||||
|
||||
static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
|
||||
switch (args.type_x) {
|
||||
|
|
@ -293,36 +158,8 @@ void ggml_cuda_mul_mat_q(
|
|||
const int si1 = ids->nb[1] / ggml_element_size(ids);
|
||||
const int sis1 = nb12 / nb11;
|
||||
|
||||
switch (n_expert_used) {
|
||||
case 2:
|
||||
launch_mmq_ids_helper< 2> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
|
||||
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
|
||||
break;
|
||||
case 4:
|
||||
launch_mmq_ids_helper< 4> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
|
||||
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
|
||||
break;
|
||||
case 6:
|
||||
launch_mmq_ids_helper< 6> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
|
||||
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
|
||||
break;
|
||||
case 8:
|
||||
launch_mmq_ids_helper< 8> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
|
||||
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
|
||||
break;
|
||||
case 16:
|
||||
launch_mmq_ids_helper<16> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
|
||||
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
|
||||
break;
|
||||
case 32:
|
||||
launch_mmq_ids_helper<32> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
|
||||
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
|
||||
break;
|
||||
default:
|
||||
launch_mmq_ids_helper< 0> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
|
||||
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
|
||||
break;
|
||||
}
|
||||
ggml_cuda_launch_mm_ids_helper((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
|
||||
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -7,14 +7,14 @@ template <typename T, typename type_acc, int ncols_dst, int block_size>
|
|||
static __global__ void mul_mat_vec_f(
|
||||
const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
|
||||
const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst,
|
||||
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
|
||||
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
|
||||
const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
|
||||
const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
|
||||
const int row = blockIdx.x;
|
||||
const int channel_dst = blockIdx.y;
|
||||
const int channel_x = ids ? ids[channel_dst] : channel_dst / channel_ratio;
|
||||
const int channel_x = ids ? ids[channel_dst] : fastdiv((uint32_t) channel_dst, channel_ratio);
|
||||
const int channel_y = ids ? channel_dst % nchannels_y : channel_dst;
|
||||
const int sample_dst = blockIdx.z;
|
||||
const int sample_x = sample_dst / sample_ratio;
|
||||
const int sample_x = fastdiv((uint32_t) sample_dst, sample_ratio);
|
||||
const int sample_y = sample_dst;
|
||||
const int tid = threadIdx.x;
|
||||
|
||||
|
|
@ -47,8 +47,8 @@ static __global__ void mul_mat_vec_f(
|
|||
#pragma unroll
|
||||
for (int j = 0; j < ncols_dst; ++j) {
|
||||
const float2 tmpy = y2[j*stride_col_y2 + col2];
|
||||
sumf[j] += tmpx.x*tmpy.x;
|
||||
sumf[j] += tmpx.y*tmpy.y;
|
||||
ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
|
||||
ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
|
||||
}
|
||||
}
|
||||
} else if constexpr (std::is_same_v<T, half>) {
|
||||
|
|
@ -61,8 +61,8 @@ static __global__ void mul_mat_vec_f(
|
|||
#pragma unroll
|
||||
for (int j = 0; j < ncols_dst; ++j) {
|
||||
const float2 tmpy = y2[j*stride_col_y2 + col2];
|
||||
sumf[j] += tmpx.x * tmpy.x;
|
||||
sumf[j] += tmpx.y * tmpy.y;
|
||||
ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
|
||||
ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
|
|
@ -88,16 +88,32 @@ static __global__ void mul_mat_vec_f(
|
|||
#endif // FP16_AVAILABLE
|
||||
}
|
||||
} else if constexpr (std::is_same_v<T, nv_bfloat16>) {
|
||||
//TODO: add support for ggml_cuda_mad for hip_bfloat162
|
||||
#if defined(GGML_USE_HIP)
|
||||
const int * x2 = (const int *) x;
|
||||
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
|
||||
const int tmpx = x2[col2];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols_dst; ++j) {
|
||||
const float2 tmpy = y2[j*stride_col_y2 + col2];
|
||||
sumf[j] += ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]) * tmpy.x;
|
||||
sumf[j] += ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]) * tmpy.y;
|
||||
const float tmpx0 = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]);
|
||||
const float tmpx1 = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]);
|
||||
ggml_cuda_mad(sumf[j], tmpx0, tmpy.x);
|
||||
ggml_cuda_mad(sumf[j], tmpx1, tmpy.y);
|
||||
}
|
||||
}
|
||||
#else
|
||||
const nv_bfloat162 * x2 = (const nv_bfloat162 *) x;
|
||||
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
|
||||
const nv_bfloat162 tmpx = x2[col2];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols_dst; ++j) {
|
||||
const float2 tmpy = y2[j*stride_col_y2 + col2];
|
||||
ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
|
||||
ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
} else {
|
||||
static_assert(std::is_same_v<T, void>, "unsupported type");
|
||||
}
|
||||
|
|
@ -140,8 +156,8 @@ static void launch_mul_mat_vec_f_cuda(
|
|||
GGML_ASSERT(stride_col_y % 2 == 0);
|
||||
GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
|
||||
GGML_ASSERT( nsamples_dst % nsamples_x == 0);
|
||||
const int64_t channel_ratio = nchannels_dst / nchannels_x;
|
||||
const int64_t sample_ratio = nsamples_dst / nsamples_x;
|
||||
const uint3 channel_ratio_fd = ids ? make_uint3(0, 0, 0) : init_fastdiv_values(nchannels_dst / nchannels_x);
|
||||
const uint3 sample_ratio_fd = init_fastdiv_values(nsamples_dst / nsamples_x);
|
||||
|
||||
const int device = ggml_cuda_get_device();
|
||||
const int warp_size = ggml_cuda_info().devices[device].warp_size;
|
||||
|
|
@ -167,50 +183,50 @@ static void launch_mul_mat_vec_f_cuda(
|
|||
case 32: {
|
||||
mul_mat_vec_f<T, type_acc, ncols_dst, 32><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
} break;
|
||||
case 64: {
|
||||
mul_mat_vec_f<T, type_acc, ncols_dst, 64><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
} break;
|
||||
case 96: {
|
||||
mul_mat_vec_f<T, type_acc, ncols_dst, 96><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
} break;
|
||||
case 128: {
|
||||
mul_mat_vec_f<T, type_acc, ncols_dst, 128><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
} break;
|
||||
case 160: {
|
||||
mul_mat_vec_f<T, type_acc, ncols_dst, 160><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
} break;
|
||||
case 192: {
|
||||
mul_mat_vec_f<T, type_acc, ncols_dst, 192><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
} break;
|
||||
case 224: {
|
||||
mul_mat_vec_f<T, type_acc, ncols_dst, 224><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
} break;
|
||||
case 256: {
|
||||
mul_mat_vec_f<T, type_acc, ncols_dst, 256><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
} break;
|
||||
default: {
|
||||
GGML_ABORT("fatal error");
|
||||
|
|
|
|||
|
|
@ -268,6 +268,25 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu(ggml_metal_library_t l
|
|||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||
assert(op->op == GGML_OP_SUM);
|
||||
|
||||
char base[256];
|
||||
char name[256];
|
||||
|
||||
snprintf(base, 256, "kernel_op_sum_%s", ggml_type_name(op->src[0]->type));
|
||||
snprintf(name, 256, "%s", base);
|
||||
|
||||
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
|
||||
if (res) {
|
||||
return res;
|
||||
}
|
||||
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||
GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type));
|
||||
|
||||
|
|
@ -1482,3 +1501,40 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_me
|
|||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_adamw(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||
assert(op->op == GGML_OP_OPT_STEP_ADAMW);
|
||||
|
||||
char base[256];
|
||||
char name[256];
|
||||
|
||||
snprintf(base, 256, "kernel_opt_step_adamw_%s", ggml_type_name(op->src[0]->type));
|
||||
snprintf(name, 256, "%s", base);
|
||||
|
||||
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
|
||||
if (res) {
|
||||
return res;
|
||||
}
|
||||
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_sgd(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||
assert(op->op == GGML_OP_OPT_STEP_SGD);
|
||||
|
||||
char base[256];
|
||||
char name[256];
|
||||
|
||||
snprintf(base, 256, "kernel_opt_step_sgd_%s", ggml_type_name(op->src[0]->type));
|
||||
snprintf(name, 256, "%s", base);
|
||||
|
||||
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
|
||||
if (res) {
|
||||
return res;
|
||||
}
|
||||
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -109,6 +109,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_set_rows (ggml_me
|
|||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_repeat (ggml_metal_library_t lib, enum ggml_type tsrc);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
|
|
@ -134,6 +135,8 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad (ggml_me
|
|||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_arange (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_adamw (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_sgd (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
|
||||
ggml_metal_library_t lib,
|
||||
|
|
|
|||
|
|
@ -656,6 +656,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
|||
case GGML_OP_COS:
|
||||
case GGML_OP_LOG:
|
||||
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_SUM:
|
||||
case GGML_OP_SUM_ROWS:
|
||||
case GGML_OP_MEAN:
|
||||
case GGML_OP_SOFT_MAX:
|
||||
|
|
@ -692,7 +693,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
|||
return true;
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
// for new head sizes, add checks here
|
||||
if (op->src[0]->ne[0] != 40 &&
|
||||
if (op->src[0]->ne[0] != 32 &&
|
||||
op->src[0]->ne[0] != 40 &&
|
||||
op->src[0]->ne[0] != 64 &&
|
||||
op->src[0]->ne[0] != 80 &&
|
||||
op->src[0]->ne[0] != 96 &&
|
||||
|
|
@ -798,6 +800,9 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
|||
return false;
|
||||
};
|
||||
}
|
||||
case GGML_OP_OPT_STEP_ADAMW:
|
||||
case GGML_OP_OPT_STEP_SGD:
|
||||
return has_simdgroup_reduction;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -544,6 +544,10 @@ typedef struct{
|
|||
float limit;
|
||||
} ggml_metal_kargs_glu;
|
||||
|
||||
typedef struct {
|
||||
uint64_t np;
|
||||
} ggml_metal_kargs_sum;
|
||||
|
||||
typedef struct {
|
||||
int64_t ne00;
|
||||
int64_t ne01;
|
||||
|
|
@ -773,4 +777,12 @@ typedef struct {
|
|||
uint64_t nb01;
|
||||
} ggml_metal_kargs_argmax;
|
||||
|
||||
typedef struct {
|
||||
int64_t np;
|
||||
} ggml_metal_kargs_opt_step_adamw;
|
||||
|
||||
typedef struct {
|
||||
int64_t np;
|
||||
} ggml_metal_kargs_opt_step_sgd;
|
||||
|
||||
#endif // GGML_METAL_IMPL
|
||||
|
|
|
|||
|
|
@ -301,6 +301,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
|||
{
|
||||
n_fuse = ggml_metal_op_glu(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_SUM:
|
||||
{
|
||||
n_fuse = ggml_metal_op_sum(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_SUM_ROWS:
|
||||
case GGML_OP_MEAN:
|
||||
{
|
||||
|
|
@ -410,6 +414,14 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
|||
{
|
||||
n_fuse = ggml_metal_op_argmax(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_OPT_STEP_ADAMW:
|
||||
{
|
||||
n_fuse = ggml_metal_op_opt_step_adamw(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_OPT_STEP_SGD:
|
||||
{
|
||||
n_fuse = ggml_metal_op_opt_step_sgd(ctx, idx);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(node->op));
|
||||
|
|
@ -840,6 +852,30 @@ int ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) {
|
|||
return 1;
|
||||
}
|
||||
|
||||
int ggml_metal_op_sum(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
||||
const uint64_t n = (uint64_t) ggml_nelements(op->src[0]);
|
||||
|
||||
ggml_metal_kargs_sum args = {
|
||||
/*.np =*/ n,
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_sum(lib, op);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, 1, 1, 1);
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
|
|
@ -3401,3 +3437,73 @@ int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) {
|
|||
|
||||
return 1;
|
||||
}
|
||||
|
||||
int ggml_metal_op_opt_step_adamw(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_adamw(lib, op);
|
||||
|
||||
const int64_t np = ggml_nelements(op->src[0]);
|
||||
ggml_metal_kargs_opt_step_adamw args = {
|
||||
/*.np =*/ np,
|
||||
};
|
||||
|
||||
int ida = 0;
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), ida++);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), ida++);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), ida++);
|
||||
|
||||
const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
|
||||
const int64_t n = (np + nth - 1) / nth;
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, nth, 1, 1);
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
int ggml_metal_op_opt_step_sgd(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_sgd(lib, op);
|
||||
|
||||
const int64_t np = ggml_nelements(op->src[0]);
|
||||
ggml_metal_kargs_opt_step_sgd args = {
|
||||
/*.np =*/ np,
|
||||
};
|
||||
|
||||
int ida = 0;
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), ida++);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++);
|
||||
|
||||
const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
|
||||
const int64_t n = (np + nth - 1) / nth;
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, nth, 1, 1);
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -50,6 +50,7 @@ int ggml_metal_op_scale (ggml_metal_op_t ctx, int idx);
|
|||
int ggml_metal_op_clamp (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_unary (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_glu (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_sum (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_sum_rows (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_get_rows (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_set_rows (ggml_metal_op_t ctx, int idx);
|
||||
|
|
@ -78,6 +79,8 @@ int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx);
|
|||
int ggml_metal_op_argmax (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_argsort (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_leaky_relu (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_opt_step_adamw (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_opt_step_sgd (ggml_metal_op_t ctx, int idx);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1723,6 +1723,24 @@ kernel void kernel_geglu_quick_f32(
|
|||
}
|
||||
}
|
||||
|
||||
kernel void kernel_op_sum_f32(
|
||||
constant ggml_metal_kargs_sum & args,
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
ushort tiitg[[thread_index_in_threadgroup]]) {
|
||||
|
||||
if (tiitg != 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
float acc = 0.0f;
|
||||
for (ulong i = 0; i < args.np; ++i) {
|
||||
acc += src0[i];
|
||||
}
|
||||
|
||||
dst[0] = acc;
|
||||
}
|
||||
|
||||
template <bool norm>
|
||||
kernel void kernel_sum_rows(
|
||||
constant ggml_metal_kargs_sum_rows & args,
|
||||
|
|
@ -5195,8 +5213,30 @@ kernel void kernel_flash_attn_ext(
|
|||
half, half4, simdgroup_half8x8
|
||||
//float, float4, simdgroup_float8x8
|
||||
|
||||
#define FA_TYPES_F32 \
|
||||
half, half4, simdgroup_half8x8, \
|
||||
float, float4x4, simdgroup_float8x8, \
|
||||
float, float4x4, simdgroup_float8x8, \
|
||||
float, simdgroup_float8x8, \
|
||||
float, float2, simdgroup_float8x8, \
|
||||
float, float4, simdgroup_float8x8
|
||||
//half, half4, simdgroup_half8x8
|
||||
|
||||
typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>) flash_attn_ext_t;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_f32_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 32, 32>;
|
||||
template [[host_name("kernel_flash_attn_ext_f32_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 40, 40>;
|
||||
template [[host_name("kernel_flash_attn_ext_f32_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_f32_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 80, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_f32_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 96, 96>;
|
||||
template [[host_name("kernel_flash_attn_ext_f32_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 112, 112>;
|
||||
template [[host_name("kernel_flash_attn_ext_f32_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 128, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_f32_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 192, 192>;
|
||||
template [[host_name("kernel_flash_attn_ext_f32_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 192, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_f32_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 256, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_f32_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 576, 512>;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_f16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 32, 32>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 40, 40>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 80, 80>;
|
||||
|
|
@ -5209,6 +5249,7 @@ template [[host_name("kernel_flash_attn_ext_f16_dk256_dv256")]] kernel flash_at
|
|||
template [[host_name("kernel_flash_attn_ext_f16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 576, 512>;
|
||||
|
||||
#if defined(GGML_METAL_HAS_BF16)
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 32, 32>;
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 40, 40>;
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80, 80>;
|
||||
|
|
@ -5221,6 +5262,7 @@ template [[host_name("kernel_flash_attn_ext_bf16_dk256_dv256")]] kernel flash_at
|
|||
template [[host_name("kernel_flash_attn_ext_bf16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 576, 512>;
|
||||
#endif
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 32, 32>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 40, 40>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 80, 80>;
|
||||
|
|
@ -5232,6 +5274,7 @@ template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv128")]] kernel flash_at
|
|||
template [[host_name("kernel_flash_attn_ext_q4_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 576, 512>;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 32, 32>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 40, 40>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 80, 80>;
|
||||
|
|
@ -5243,6 +5286,7 @@ template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv128")]] kernel flash_at
|
|||
template [[host_name("kernel_flash_attn_ext_q4_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 576, 512>;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 32, 32>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 40, 40>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 80, 80>;
|
||||
|
|
@ -5254,6 +5298,7 @@ template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv128")]] kernel flash_at
|
|||
template [[host_name("kernel_flash_attn_ext_q5_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 576, 512>;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 32, 32>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 40, 40>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 80, 80>;
|
||||
|
|
@ -5265,6 +5310,7 @@ template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv128")]] kernel flash_at
|
|||
template [[host_name("kernel_flash_attn_ext_q5_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 576, 512>;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 32, 32>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 40, 40>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 80, 80>;
|
||||
|
|
@ -5800,77 +5846,103 @@ kernel void kernel_flash_attn_ext_vec(
|
|||
float, float4, \
|
||||
float4
|
||||
|
||||
#define FA_TYPES_F32 \
|
||||
half4, \
|
||||
float4, \
|
||||
float4, \
|
||||
float, \
|
||||
float, float4, \
|
||||
float4
|
||||
|
||||
typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f16_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 64, 64, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f32_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 32, 32, 4>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f16_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 32, 32, 4>;
|
||||
#if defined(GGML_METAL_HAS_BF16)
|
||||
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 64, 64, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 32, 32, 4>;
|
||||
#endif
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 64, 64, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 64, 64, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 64, 64, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 64, 64, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 64, 64, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 32, 32, 4>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 32, 32, 4>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 32, 32, 4>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 32, 32, 4>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 32, 32, 4>;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f16_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 96, 96, 4>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f32_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 64, 64, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f16_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 64, 64, 2>;
|
||||
#if defined(GGML_METAL_HAS_BF16)
|
||||
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 96, 96, 4>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 64, 64, 2>;
|
||||
#endif
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 96, 96, 4>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 96, 96, 4>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 96, 96, 4>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 96, 96, 4>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 96, 96, 4>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 64, 64, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 64, 64, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 64, 64, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 64, 64, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 64, 64, 2>;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f16_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 1>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f32_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 96, 96, 4>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f16_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 96, 96, 4>;
|
||||
#if defined(GGML_METAL_HAS_BF16)
|
||||
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 128, 128, 1>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 96, 96, 4>;
|
||||
#endif
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 128, 128, 1>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 128, 128, 1>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 128, 128, 1>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 128, 128, 1>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 128, 128, 1>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 96, 96, 4>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 96, 96, 4>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 96, 96, 4>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 96, 96, 4>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 96, 96, 4>;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 192, 192, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f32_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 128, 128, 1>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f16_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 1>;
|
||||
#if defined(GGML_METAL_HAS_BF16)
|
||||
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 192, 192, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 128, 128, 1>;
|
||||
#endif
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 192, 192, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 192, 192, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 192, 192, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 192, 192, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 192, 192, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 128, 128, 1>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 128, 128, 1>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 128, 128, 1>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 128, 128, 1>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 128, 128, 1>;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 192, 128, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f32_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 192, 192, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 192, 192, 2>;
|
||||
#if defined(GGML_METAL_HAS_BF16)
|
||||
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 192, 128, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 192, 192, 2>;
|
||||
#endif
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 192, 128, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 192, 128, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 192, 128, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 192, 128, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 192, 128, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 192, 192, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 192, 192, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 192, 192, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 192, 192, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 192, 192, 2>;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f16_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 256, 256, 1>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f32_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 192, 128, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 192, 128, 2>;
|
||||
#if defined(GGML_METAL_HAS_BF16)
|
||||
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 256, 256, 1>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 192, 128, 2>;
|
||||
#endif
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 256, 256, 1>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 256, 256, 1>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 256, 256, 1>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 256, 256, 1>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 256, 256, 1>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 192, 128, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 192, 128, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 192, 128, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 192, 128, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 192, 128, 2>;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 576, 512, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f32_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 256, 256, 1>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f16_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 256, 256, 1>;
|
||||
#if defined(GGML_METAL_HAS_BF16)
|
||||
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 576, 512, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 256, 256, 1>;
|
||||
#endif
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 576, 512, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 576, 512, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 576, 512, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 576, 512, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 576, 512, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 256, 256, 1>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 256, 256, 1>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 256, 256, 1>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 256, 256, 1>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 256, 256, 1>;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f32_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 576, 512, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 576, 512, 2>;
|
||||
#if defined(GGML_METAL_HAS_BF16)
|
||||
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 576, 512, 2>;
|
||||
#endif
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 576, 512, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 576, 512, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 576, 512, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 576, 512, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 576, 512, 2>;
|
||||
|
||||
#undef FA_TYPES
|
||||
|
||||
|
|
@ -8754,3 +8826,51 @@ kernel void kernel_pool_2d_avg_f32(
|
|||
|
||||
o_ptr[cur_oh * args.OW + cur_ow] = res;
|
||||
}
|
||||
|
||||
kernel void kernel_opt_step_adamw_f32(
|
||||
constant ggml_metal_kargs_opt_step_adamw & args,
|
||||
device float * x,
|
||||
device const float * g,
|
||||
device float * g_m,
|
||||
device float * g_v,
|
||||
device const float * pars,
|
||||
uint gid[[thread_position_in_grid]]) {
|
||||
|
||||
if (gid >= args.np) {
|
||||
return;
|
||||
}
|
||||
|
||||
const float alpha = pars[0];
|
||||
const float beta1 = pars[1];
|
||||
const float beta2 = pars[2];
|
||||
const float eps = pars[3];
|
||||
const float wd = pars[4];
|
||||
const float beta1h = pars[5];
|
||||
const float beta2h = pars[6];
|
||||
|
||||
const float gi = g[gid];
|
||||
const float gmi = g_m[gid] * beta1 + gi * (1.0f - beta1);
|
||||
const float gvi = g_v[gid] * beta2 + gi * gi * (1.0f - beta2);
|
||||
|
||||
g_m[gid] = gmi;
|
||||
g_v[gid] = gvi;
|
||||
|
||||
const float mh = gmi * beta1h;
|
||||
const float vh = sqrt(gvi * beta2h) + eps;
|
||||
|
||||
x[gid] = x[gid] * (1.0f - alpha * wd) - alpha * mh / vh;
|
||||
}
|
||||
|
||||
kernel void kernel_opt_step_sgd_f32(
|
||||
constant ggml_metal_kargs_opt_step_sgd & args,
|
||||
device float * x,
|
||||
device const float * g,
|
||||
device const float * pars,
|
||||
uint gid[[thread_position_in_grid]]) {
|
||||
|
||||
if (gid >= args.np) {
|
||||
return;
|
||||
}
|
||||
|
||||
x[gid] = x[gid] * (1.0f - pars[0] * pars[1]) - pars[0] * g[gid];
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2348,8 +2348,13 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
|
|||
svm_caps & CL_DEVICE_SVM_ATOMICS ? "true" : "false");
|
||||
|
||||
if (opencl_c_version.major >= 3) {
|
||||
// Assume it is not available for 3.0, since it is optional in 3.0.
|
||||
// If compiling against 3.0, then we can query.
|
||||
backend_ctx->non_uniform_workgroups = false;
|
||||
#if CL_TARGET_OPENCL_VERSION >= 300
|
||||
CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_NON_UNIFORM_WORK_GROUP_SUPPORT, sizeof(cl_bool),
|
||||
&backend_ctx->non_uniform_workgroups, 0));
|
||||
#endif
|
||||
} else {
|
||||
GGML_ASSERT(opencl_c_version.major == 2);
|
||||
// Non-uniform workgroup sizes is mandatory feature in v2.x.
|
||||
|
|
@ -2681,7 +2686,7 @@ static bool ggml_opencl_can_fuse(const struct ggml_cgraph * cgraph, int node_idx
|
|||
|
||||
// if rms_norm is the B operand, then we don't handle broadcast
|
||||
if (rms_norm == mul->src[1] &&
|
||||
!ggml_are_same_shape(mul->src[0], rms_norm->src[1])) {
|
||||
!ggml_are_same_shape(mul->src[0], rms_norm)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,18 @@
|
|||
cmake_minimum_required(VERSION 3.19)
|
||||
cmake_policy(SET CMP0114 NEW)
|
||||
cmake_policy(SET CMP0116 NEW)
|
||||
if (POLICY CMP0147)
|
||||
# Parallel build custom build steps
|
||||
cmake_policy(SET CMP0147 NEW)
|
||||
endif()
|
||||
|
||||
find_package(Vulkan COMPONENTS glslc REQUIRED)
|
||||
|
||||
if (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC")
|
||||
# Parallel build object files
|
||||
add_definitions(/MP)
|
||||
endif()
|
||||
|
||||
function(detect_host_compiler)
|
||||
if (CMAKE_HOST_SYSTEM_NAME STREQUAL "Windows")
|
||||
find_program(HOST_C_COMPILER NAMES cl gcc clang NO_CMAKE_FIND_ROOT_PATH)
|
||||
|
|
|
|||
|
|
@ -2649,11 +2649,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
} \
|
||||
}
|
||||
|
||||
CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, )
|
||||
#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
|
||||
if (device->coopmat1_fa_support) {
|
||||
CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT1, _cm1)
|
||||
CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT1, _cm1)
|
||||
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT1, _cm1)
|
||||
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT1, _cm1)
|
||||
|
|
@ -2661,6 +2663,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
#endif
|
||||
#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
||||
if (device->coopmat2) {
|
||||
CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT2, _cm2)
|
||||
CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT2, _cm2)
|
||||
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT2, _cm2)
|
||||
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_COOPMAT2, _cm2)
|
||||
|
|
@ -7457,8 +7460,16 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|||
}
|
||||
|
||||
const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type));
|
||||
const uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type));
|
||||
const uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type));
|
||||
uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type));
|
||||
uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type));
|
||||
|
||||
// For F32, the shader treats it as a block of size 4 (for vec4 loads)
|
||||
if (k->type == GGML_TYPE_F32) {
|
||||
k_stride /= 4;
|
||||
}
|
||||
if (v->type == GGML_TYPE_F32) {
|
||||
v_stride /= 4;
|
||||
}
|
||||
|
||||
uint32_t alignment = fa_align(path, HSK, HSV, k->type, small_rows);
|
||||
bool aligned = (KV % alignment) == 0 &&
|
||||
|
|
@ -12660,6 +12671,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|||
}
|
||||
switch (op->src[1]->type) {
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_F32:
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q8_0:
|
||||
// supported in scalar and coopmat2 paths
|
||||
|
|
|
|||
|
|
@ -1,6 +1,18 @@
|
|||
|
||||
#include "types.glsl"
|
||||
|
||||
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufF32 {
|
||||
vec4 block;
|
||||
};
|
||||
|
||||
float16_t dequantFuncF32(const in decodeBufF32 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
{
|
||||
const vec4 v = bl.block;
|
||||
const uint idx = coordInBlock[1];
|
||||
const f16vec4 vf16 = f16vec4(v);
|
||||
return vf16[idx];
|
||||
}
|
||||
|
||||
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ4_0 {
|
||||
block_q4_0_packed16 block;
|
||||
};
|
||||
|
|
@ -717,4 +729,6 @@ float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords
|
|||
#define dequantFuncA dequantFuncIQ4_NL
|
||||
#elif defined(DATA_A_MXFP4)
|
||||
#define dequantFuncA dequantFuncMXFP4
|
||||
#elif defined(DATA_A_F32)
|
||||
#define dequantFuncA dequantFuncF32
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -64,13 +64,31 @@ layout (binding = 4) readonly buffer S {float data_s[];};
|
|||
|
||||
layout (binding = 5) writeonly buffer O {D_TYPE data_o[];};
|
||||
|
||||
#if defined(A_TYPE_PACKED16)
|
||||
#define BINDING_IDX_K 0
|
||||
#define BINDING_IDX_V 1
|
||||
#if defined(DATA_A_F32)
|
||||
layout (binding = 1) readonly buffer K_PACKED {vec4 k_data_packed[];} k_packed;
|
||||
layout (binding = 2) readonly buffer V_PACKED {vec4 v_data_packed[];} v_packed;
|
||||
#elif defined(A_TYPE_PACKED16)
|
||||
layout (binding = 1) readonly buffer K_PACKED16 {A_TYPE_PACKED16 k_data_packed16[];} k_packed;
|
||||
layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16[];} v_packed;
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_F32)
|
||||
#undef BLOCK_SIZE
|
||||
#define BLOCK_SIZE 4
|
||||
#define BLOCK_BYTE_SIZE 16
|
||||
|
||||
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
|
||||
// iqs is currently always zero in the flash attention shaders
|
||||
if (binding_idx == BINDING_IDX_K) {
|
||||
return k_packed.k_data_packed[a_offset + ib];
|
||||
} else {
|
||||
return v_packed.v_data_packed[a_offset + ib];
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q4_0)
|
||||
#define BLOCK_BYTE_SIZE 18
|
||||
|
||||
|
|
|
|||
|
|
@ -313,12 +313,12 @@ void main() {
|
|||
sums[i] = coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0f);
|
||||
}
|
||||
#else
|
||||
ACC_TYPE sums[WMITER * TM * WNITER * TN];
|
||||
ACC_TYPE_VEC2 sums[WMITER * TM * WNITER * TN/2];
|
||||
FLOAT_TYPE_VEC2 cache_a[WMITER * TM];
|
||||
FLOAT_TYPE_VEC2 cache_b[TN];
|
||||
FLOAT_TYPE_VEC2 cache_b;
|
||||
|
||||
[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
|
||||
sums[i] = ACC_TYPE(0.0f);
|
||||
[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN/2; i++) {
|
||||
sums[i] = ACC_TYPE_VEC2(0.0f, 0.0f);
|
||||
}
|
||||
#endif
|
||||
|
||||
|
|
@ -360,20 +360,22 @@ void main() {
|
|||
cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * SHMEM_STRIDE + i];
|
||||
}
|
||||
}
|
||||
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
|
||||
[[unroll]] for (uint j = 0; j < TN; j++) {
|
||||
cache_b[j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * SHMEM_STRIDE + i];
|
||||
}
|
||||
|
||||
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
|
||||
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
|
||||
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
|
||||
const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
|
||||
sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr].x), ACC_TYPE(cache_b[cc].x), fma(ACC_TYPE(cache_a[wsir * TM + cr].y), ACC_TYPE(cache_b[cc].y), sums[sums_idx]));
|
||||
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
|
||||
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
|
||||
cache_b = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + cc) * SHMEM_STRIDE + i];
|
||||
|
||||
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
|
||||
[[unroll]] for (uint cr = 0; cr < TM / 2; cr++) {
|
||||
// [WNITER][TN][WMITER][TM / 2] -> [wsic][cc][wsir][cr]
|
||||
const uint sums_idx = (wsic * TN + cc) * WMITER * (TM / 2) + wsir * (TM / 2) + cr;
|
||||
sums[sums_idx].x = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].y), ACC_TYPE(cache_b.y), sums[sums_idx].x));
|
||||
sums[sums_idx].y = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].y), ACC_TYPE(cache_b.y), sums[sums_idx].y));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
#endif
|
||||
|
||||
|
|
@ -388,8 +390,9 @@ void main() {
|
|||
}
|
||||
}
|
||||
#else
|
||||
[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
|
||||
sums[i] = clamp(sums[i], -ACC_TYPE_MAX, ACC_TYPE_MAX);
|
||||
[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN/2; i++) {
|
||||
sums[i].x = clamp(sums[i].x, -ACC_TYPE_MAX, ACC_TYPE_MAX);
|
||||
sums[i].y = clamp(sums[i].y, -ACC_TYPE_MAX, ACC_TYPE_MAX);
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
|
|
@ -463,14 +466,21 @@ void main() {
|
|||
|
||||
const u16vec2 row_idx = row_ids[row_i - ic * BN];
|
||||
#endif // MUL_MAT_ID
|
||||
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
|
||||
[[unroll]] for (uint cr = 0; cr < TM / 2; cr++) {
|
||||
const uint sums_idx = (wsic * TN + cc) * WMITER * (TM / 2) + wsir * (TM / 2) + cr;
|
||||
#ifdef MUL_MAT_ID
|
||||
if (dr_warp + cr < p.M) {
|
||||
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
|
||||
if (dr_warp + 2 * cr < p.M) {
|
||||
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + 2 * cr] = D_TYPE(sums[sums_idx].x);
|
||||
}
|
||||
if (dr_warp + 2 * cr + 1 < p.M) {
|
||||
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + 2 * cr + 1] = D_TYPE(sums[sums_idx].y);
|
||||
}
|
||||
#else
|
||||
if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
|
||||
data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
|
||||
if (dr_warp + 2 * cr < p.M && dc_warp + cc < p.N) {
|
||||
data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + 2 * cr] = D_TYPE(sums[sums_idx].x);
|
||||
}
|
||||
if (dr_warp + 2 * cr + 1 < p.M && dc_warp + cc < p.N) {
|
||||
data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + 2 * cr + 1] = D_TYPE(sums[sums_idx].y);
|
||||
}
|
||||
#endif // MUL_MAT_ID
|
||||
}
|
||||
|
|
|
|||
|
|
@ -611,9 +611,6 @@ void process_shaders() {
|
|||
}
|
||||
|
||||
for (const auto& tname : type_names) {
|
||||
if (tname == "f32") {
|
||||
continue;
|
||||
}
|
||||
if (tname == "bf16") continue;
|
||||
|
||||
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
||||
|
|
@ -630,7 +627,7 @@ void process_shaders() {
|
|||
if (tname == "f16") {
|
||||
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
|
||||
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"COOPMAT", "1"}}), true, true, false, f16acc);
|
||||
} else if (tname == "q4_0" || tname == "q8_0") {
|
||||
} else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") {
|
||||
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
|
||||
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
|
||||
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc);
|
||||
|
|
@ -639,7 +636,7 @@ void process_shaders() {
|
|||
if (tname == "f16") {
|
||||
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
|
||||
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, false, f16acc);
|
||||
} else if (tname == "q4_0" || tname == "q8_0") {
|
||||
} else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") {
|
||||
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
|
||||
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
|
||||
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc);
|
||||
|
|
|
|||
|
|
@ -261,12 +261,17 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
|
|||
}
|
||||
}
|
||||
|
||||
static void print_mask(float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
|
||||
static void print_mask(const float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
|
||||
LLAMA_LOG_DEBUG("%s: === Attention mask ===\n", __func__);
|
||||
const char * swa_type_str = (swa_type == LLAMA_SWA_TYPE_NONE) ? "LLAMA_SWA_TYPE_NONE" :
|
||||
(swa_type == LLAMA_SWA_TYPE_STANDARD) ? "LLAMA_SWA_TYPE_STANDARD" :
|
||||
(swa_type == LLAMA_SWA_TYPE_CHUNKED) ? "LLAMA_SWA_TYPE_CHUNKED" :
|
||||
(swa_type == LLAMA_SWA_TYPE_SYMMETRIC) ? "LLAMA_SWA_TYPE_SYMMETRIC" : "unknown";
|
||||
const char * swa_type_str = "unknown";
|
||||
|
||||
switch (swa_type) {
|
||||
case LLAMA_SWA_TYPE_NONE: swa_type_str = "LLAMA_SWA_TYPE_NONE"; break;
|
||||
case LLAMA_SWA_TYPE_STANDARD: swa_type_str = "LLAMA_SWA_TYPE_STANDARD"; break;
|
||||
case LLAMA_SWA_TYPE_CHUNKED: swa_type_str = "LLAMA_SWA_TYPE_CHUNKED"; break;
|
||||
case LLAMA_SWA_TYPE_SYMMETRIC: swa_type_str = "LLAMA_SWA_TYPE_SYMMETRIC"; break;
|
||||
};
|
||||
|
||||
LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swq_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str);
|
||||
LLAMA_LOG_DEBUG("%s: '0' = can attend, '∞' = masked\n", __func__);
|
||||
LLAMA_LOG_DEBUG("%s: Rows = query tokens, Columns = key/value tokens\n\n", __func__);
|
||||
|
|
@ -295,50 +300,67 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
|
|||
const int64_t n_kv = ubatch->n_tokens;
|
||||
const int64_t n_tokens = ubatch->n_tokens;
|
||||
|
||||
GGML_ASSERT(kq_mask);
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
|
||||
const auto fill_mask = [&](float * data, int n_swa, llama_swa_type swa_type) {
|
||||
for (int h = 0; h < 1; ++h) {
|
||||
for (int i1 = 0; i1 < n_tokens; ++i1) {
|
||||
const llama_seq_id s1 = ubatch->seq_id[i1][0];
|
||||
const llama_pos p1 = ubatch->pos[i1];
|
||||
|
||||
float * data = (float *) kq_mask->data;
|
||||
const uint64_t idst = h*(n_kv*n_tokens) + i1*n_kv;
|
||||
|
||||
// [TAG_NO_CACHE_ISWA]
|
||||
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "TODO: implement");
|
||||
|
||||
for (int h = 0; h < 1; ++h) {
|
||||
for (int i1 = 0; i1 < n_tokens; ++i1) {
|
||||
const llama_seq_id s1 = ubatch->seq_id[i1][0];
|
||||
|
||||
for (int i0 = 0; i0 < n_tokens; ++i0) {
|
||||
float f = -INFINITY;
|
||||
|
||||
for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) {
|
||||
for (int i0 = 0; i0 < n_tokens; ++i0) {
|
||||
const llama_seq_id s0 = ubatch->seq_id[i0][0];
|
||||
const llama_pos p0 = ubatch->pos[i0];
|
||||
|
||||
// mask different sequences
|
||||
if (s0 != s1) {
|
||||
continue; // skip different sequences
|
||||
continue;
|
||||
}
|
||||
|
||||
if (cparams.causal_attn && ubatch->pos[i0] > ubatch->pos[i1]) {
|
||||
continue; // skip future tokens for causal attention
|
||||
// mask future tokens
|
||||
if (cparams.causal_attn && p0 > p1) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// TODO: this does not take into account that some layers are SWA and others are note (i.e. iSWA) [TAG_NO_CACHE_ISWA]
|
||||
//if (hparams.is_masked_swa(ubatch->pos[i0], ubatch->pos[i1])) {
|
||||
// continue; // skip masked tokens for SWA
|
||||
//}
|
||||
|
||||
// TODO: reimplement this like in llama_kv_cache_unified
|
||||
if (hparams.use_alibi) {
|
||||
f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
|
||||
} else {
|
||||
f = 0.0f;
|
||||
// apply SWA if any
|
||||
if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
data[idst + i0] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
|
||||
}
|
||||
data[h*(n_kv*n_tokens) + i1*n_kv + i0] = f;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
{
|
||||
GGML_ASSERT(self_kq_mask);
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
|
||||
|
||||
float * data = (float *) self_kq_mask->data;
|
||||
|
||||
std::fill(data, data + ggml_nelements(self_kq_mask), -INFINITY);
|
||||
|
||||
fill_mask(data, 0, LLAMA_SWA_TYPE_NONE);
|
||||
|
||||
if (debug) {
|
||||
print_mask(data, n_tokens, n_kv, 0, LLAMA_SWA_TYPE_NONE);
|
||||
}
|
||||
}
|
||||
if (debug) {
|
||||
print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type);
|
||||
|
||||
if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
|
||||
GGML_ASSERT(self_kq_mask_swa);
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
|
||||
|
||||
float * data = (float *) self_kq_mask_swa->data;
|
||||
|
||||
std::fill(data, data + ggml_nelements(self_kq_mask_swa), -INFINITY);
|
||||
|
||||
fill_mask(data, hparams.n_swa, hparams.swa_type);
|
||||
|
||||
if (debug) {
|
||||
print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1299,12 +1321,9 @@ ggml_tensor * llm_graph_context::build_attn_mha(
|
|||
k = ggml_permute(ctx0, k, 0, 2, 1, 3);
|
||||
v = ggml_permute(ctx0, v, 0, 2, 1, 3);
|
||||
|
||||
const auto n_kv = k->ne[1];
|
||||
|
||||
ggml_tensor * cur;
|
||||
|
||||
// TODO: replace hardcoded padding with ggml-provided padding
|
||||
if (cparams.flash_attn && (n_kv % 256 == 0) && kq_b == nullptr) {
|
||||
if (cparams.flash_attn && kq_b == nullptr) {
|
||||
GGML_ASSERT(kq_b == nullptr && "Flash attention does not support KQ bias yet");
|
||||
|
||||
if (v_trans) {
|
||||
|
|
@ -1419,10 +1438,20 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con
|
|||
auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);
|
||||
|
||||
// note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
|
||||
inp->kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
|
||||
ggml_set_input(inp->kq_mask);
|
||||
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
|
||||
ggml_set_input(inp->self_kq_mask);
|
||||
|
||||
inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask;
|
||||
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
||||
|
||||
if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
|
||||
inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
|
||||
ggml_set_input(inp->self_kq_mask_swa);
|
||||
|
||||
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
|
||||
} else {
|
||||
inp->self_kq_mask_swa = nullptr;
|
||||
inp->self_kq_mask_swa_cnv = nullptr;
|
||||
}
|
||||
|
||||
return (llm_graph_input_attn_no_cache *) res->add_input(std::move(inp));
|
||||
}
|
||||
|
|
@ -1447,7 +1476,9 @@ ggml_tensor * llm_graph_context::build_attn(
|
|||
ggml_build_forward_expand(gf, k_cur);
|
||||
ggml_build_forward_expand(gf, v_cur);
|
||||
|
||||
const auto & kq_mask = inp->get_kq_mask();
|
||||
const bool is_swa = hparams.is_swa(il);
|
||||
|
||||
const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
|
||||
|
||||
// [TAG_NO_CACHE_PAD]
|
||||
// TODO: if ubatch.equal_seqs() == true, we can split the three tensors below into ubatch.n_seqs_unq streams
|
||||
|
|
|
|||
|
|
@ -257,10 +257,14 @@ public:
|
|||
|
||||
void set_input(const llama_ubatch * ubatch) override;
|
||||
|
||||
ggml_tensor * get_kq_mask() const { return kq_mask_cnv; }
|
||||
ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
|
||||
ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
|
||||
|
||||
ggml_tensor * kq_mask = nullptr; // F32 [n_tokens, n_batch, 1, 1]
|
||||
ggml_tensor * kq_mask_cnv = nullptr; // [n_tokens, n_batch, 1, 1]
|
||||
// n_tokens == n_batch
|
||||
ggml_tensor * self_kq_mask = nullptr; // F32 [n_tokens, n_batch/n_stream, 1, n_stream]
|
||||
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_tokens, n_batch/n_stream, 1, n_stream]
|
||||
ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_tokens, n_batch/n_stream, 1, n_stream]
|
||||
ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_tokens, n_batch/n_stream, 1, n_stream]
|
||||
|
||||
const llama_hparams hparams;
|
||||
const llama_cparams cparams;
|
||||
|
|
|
|||
|
|
@ -11358,8 +11358,8 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
|
|||
}
|
||||
};
|
||||
|
||||
struct llm_build_gemma_embedding_iswa : public llm_graph_context {
|
||||
llm_build_gemma_embedding_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
struct llm_build_gemma_embedding : public llm_graph_context {
|
||||
llm_build_gemma_embedding(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_k;
|
||||
|
||||
ggml_tensor * cur;
|
||||
|
|
@ -11376,8 +11376,7 @@ struct llm_build_gemma_embedding_iswa : public llm_graph_context {
|
|||
// inp_pos - contains the positions
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
// TODO: support cacheless iSWA embeddings [TAG_NO_CACHE_ISWA]
|
||||
auto * inp_attn = build_attn_inp_kv_iswa();
|
||||
auto * inp_attn = build_attn_inp_no_cache();
|
||||
|
||||
ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
|
||||
|
|
@ -19378,7 +19377,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
|||
case LLM_ARCH_NOMIC_BERT_MOE:
|
||||
case LLM_ARCH_NEO_BERT:
|
||||
case LLM_ARCH_WAVTOKENIZER_DEC:
|
||||
//case LLM_ARCH_GEMMA_EMBEDDING: // TODO: disabled until the cacheless SWA logic is fixed [TAG_NO_CACHE_ISWA]
|
||||
case LLM_ARCH_GEMMA_EMBEDDING:
|
||||
case LLM_ARCH_DREAM:
|
||||
case LLM_ARCH_LLADA:
|
||||
case LLM_ARCH_LLADA_MOE:
|
||||
|
|
@ -19671,7 +19670,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
|
|||
} break;
|
||||
case LLM_ARCH_GEMMA_EMBEDDING:
|
||||
{
|
||||
llm = std::make_unique<llm_build_gemma_embedding_iswa>(*this, params);
|
||||
llm = std::make_unique<llm_build_gemma_embedding>(*this, params);
|
||||
} break;
|
||||
case LLM_ARCH_STARCODER2:
|
||||
{
|
||||
|
|
|
|||
|
|
@ -312,6 +312,7 @@ struct llama_model * llama_model_load_from_splits(
|
|||
LLAMA_LOG_ERROR("%s: list of splits is empty\n", __func__);
|
||||
return nullptr;
|
||||
}
|
||||
splits.reserve(n_paths);
|
||||
for (size_t i = 0; i < n_paths; ++i) {
|
||||
splits.push_back(paths[i]);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6779,7 +6779,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
|||
for (int nb : { 1, 3, 32, 35, }) {
|
||||
for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) {
|
||||
if (hsk != 128 && prec == GGML_PREC_DEFAULT) continue;
|
||||
for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
|
||||
for (ggml_type type_KV : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
|
||||
test_cases.emplace_back(new test_flash_attn_ext(
|
||||
hsk, hsv, nh, {nr2, nr3}, kv, nb, mask, sinks, max_bias, logit_softcap, prec, type_KV));
|
||||
// run fewer test cases permuted
|
||||
|
|
@ -6911,7 +6911,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
|
|||
}
|
||||
|
||||
// qwen3-30b-a3b
|
||||
for (int bs : {1, 4, 8, 32, 64, 128, 512}) {
|
||||
for (int bs : {1, 4, 8, 32, 64, 128, 256, 512}) {
|
||||
for (ggml_type type_a : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0, GGML_TYPE_Q4_K, GGML_TYPE_Q6_K, GGML_TYPE_IQ2_XS}) {
|
||||
for (ggml_type type_b : {GGML_TYPE_F32}) {
|
||||
test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, 128, 8, false, 768, bs, 2048, 1));
|
||||
|
|
@ -6919,6 +6919,15 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
|
|||
}
|
||||
}
|
||||
|
||||
for (int bs : {1, 4, 8, 32, 64, 128, 256, 512}) {
|
||||
for (ggml_type type_a : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0, GGML_TYPE_Q4_K, GGML_TYPE_Q6_K, GGML_TYPE_IQ2_XS}) {
|
||||
for (ggml_type type_b : {GGML_TYPE_F32}) {
|
||||
test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, 32, 4, false, 1792, bs, 2048, 1));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// gpt-oss-20b
|
||||
for (int bs : {1, 4, 8, 512}) {
|
||||
for (ggml_type type_a : {GGML_TYPE_MXFP4}) {
|
||||
|
|
|
|||
Binary file not shown.
|
|
@ -1585,23 +1585,31 @@ struct server_prompt_cache {
|
|||
}
|
||||
}
|
||||
|
||||
// average size per token
|
||||
const float size_per_token = std::max<float>(1.0f, float(size()) / (std::max<size_t>(1, n_tokens())));
|
||||
|
||||
// dynamically increase the token limit if it can fit in the memory limit
|
||||
const size_t limit_tokens_cur = limit_size > 0 ? std::max<size_t>(limit_tokens, limit_size/size_per_token) : limit_tokens;
|
||||
|
||||
if (limit_tokens > 0) {
|
||||
while (states.size() > 1 && n_tokens() > limit_tokens) {
|
||||
while (states.size() > 1 && n_tokens() > limit_tokens_cur) {
|
||||
if (states.empty()) {
|
||||
break;
|
||||
}
|
||||
|
||||
SRV_WRN(" - cache token limit reached, removing oldest entry (size = %.3f MiB)\n", states.front().size() / (1024.0 * 1024.0));
|
||||
SRV_WRN(" - cache token limit (%zu, est: %zu) reached, removing oldest entry (size = %.3f MiB)\n",
|
||||
limit_tokens, limit_tokens_cur, states.front().size() / (1024.0 * 1024.0));
|
||||
|
||||
states.pop_front();
|
||||
}
|
||||
}
|
||||
|
||||
SRV_WRN(" - cache state: %zu prompts, %.3f MiB (limits: %.3f MiB, %zu tokens)\n",
|
||||
states.size(), size() / (1024.0 * 1024.0), limit_size / (1024.0 * 1024.0), limit_tokens);
|
||||
SRV_WRN(" - cache state: %zu prompts, %.3f MiB (limits: %.3f MiB, %zu tokens, %zu est)\n",
|
||||
states.size(), size() / (1024.0 * 1024.0), limit_size / (1024.0 * 1024.0), limit_tokens, limit_tokens_cur);
|
||||
|
||||
for (const auto & state : states) {
|
||||
SRV_WRN(" - prompt %p: %7d tokens, checkpoints: %2zu, %9.3f MiB\n", (const void *)&state, state.n_tokens(), state.checkpoints.size(), state.size() / (1024.0 * 1024.0));
|
||||
SRV_WRN(" - prompt %p: %7d tokens, checkpoints: %2zu, %9.3f MiB\n",
|
||||
(const void *)&state, state.n_tokens(), state.checkpoints.size(), state.size() / (1024.0 * 1024.0));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
|
|
|||
|
|
@ -50,6 +50,7 @@
|
|||
"eslint-plugin-svelte": "^3.0.0",
|
||||
"fflate": "^0.8.2",
|
||||
"globals": "^16.0.0",
|
||||
"mdast": "^3.0.0",
|
||||
"mdsvex": "^0.12.3",
|
||||
"playwright": "^1.53.0",
|
||||
"prettier": "^3.4.2",
|
||||
|
|
@ -66,6 +67,7 @@
|
|||
"tw-animate-css": "^1.3.5",
|
||||
"typescript": "^5.0.0",
|
||||
"typescript-eslint": "^8.20.0",
|
||||
"unified": "^11.0.5",
|
||||
"uuid": "^13.0.0",
|
||||
"vite": "^7.0.4",
|
||||
"vite-plugin-devtools-json": "^0.2.0",
|
||||
|
|
@ -2128,6 +2130,66 @@
|
|||
"node": ">=14.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@tailwindcss/oxide-wasm32-wasi/node_modules/@emnapi/core": {
|
||||
"version": "1.4.3",
|
||||
"dev": true,
|
||||
"inBundle": true,
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"dependencies": {
|
||||
"@emnapi/wasi-threads": "1.0.2",
|
||||
"tslib": "^2.4.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@tailwindcss/oxide-wasm32-wasi/node_modules/@emnapi/runtime": {
|
||||
"version": "1.4.3",
|
||||
"dev": true,
|
||||
"inBundle": true,
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"dependencies": {
|
||||
"tslib": "^2.4.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@tailwindcss/oxide-wasm32-wasi/node_modules/@emnapi/wasi-threads": {
|
||||
"version": "1.0.2",
|
||||
"dev": true,
|
||||
"inBundle": true,
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"dependencies": {
|
||||
"tslib": "^2.4.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@tailwindcss/oxide-wasm32-wasi/node_modules/@napi-rs/wasm-runtime": {
|
||||
"version": "0.2.11",
|
||||
"dev": true,
|
||||
"inBundle": true,
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"dependencies": {
|
||||
"@emnapi/core": "^1.4.3",
|
||||
"@emnapi/runtime": "^1.4.3",
|
||||
"@tybys/wasm-util": "^0.9.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@tailwindcss/oxide-wasm32-wasi/node_modules/@tybys/wasm-util": {
|
||||
"version": "0.9.0",
|
||||
"dev": true,
|
||||
"inBundle": true,
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"dependencies": {
|
||||
"tslib": "^2.4.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@tailwindcss/oxide-wasm32-wasi/node_modules/tslib": {
|
||||
"version": "2.8.0",
|
||||
"dev": true,
|
||||
"inBundle": true,
|
||||
"license": "0BSD",
|
||||
"optional": true
|
||||
},
|
||||
"node_modules/@tailwindcss/oxide-win32-arm64-msvc": {
|
||||
"version": "4.1.11",
|
||||
"resolved": "https://registry.npmjs.org/@tailwindcss/oxide-win32-arm64-msvc/-/oxide-win32-arm64-msvc-4.1.11.tgz",
|
||||
|
|
@ -4946,6 +5008,13 @@
|
|||
"url": "https://github.com/sponsors/wooorm"
|
||||
}
|
||||
},
|
||||
"node_modules/mdast": {
|
||||
"version": "3.0.0",
|
||||
"resolved": "https://registry.npmjs.org/mdast/-/mdast-3.0.0.tgz",
|
||||
"integrity": "sha512-xySmf8g4fPKMeC07jXGz971EkLbWAJ83s4US2Tj9lEdnZ142UP5grN73H1Xd3HzrdbU5o9GYYP/y8F9ZSwLE9g==",
|
||||
"dev": true,
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/mdast-util-find-and-replace": {
|
||||
"version": "3.0.2",
|
||||
"resolved": "https://registry.npmjs.org/mdast-util-find-and-replace/-/mdast-util-find-and-replace-3.0.2.tgz",
|
||||
|
|
|
|||
|
|
@ -52,6 +52,7 @@
|
|||
"eslint-plugin-svelte": "^3.0.0",
|
||||
"fflate": "^0.8.2",
|
||||
"globals": "^16.0.0",
|
||||
"mdast": "^3.0.0",
|
||||
"mdsvex": "^0.12.3",
|
||||
"playwright": "^1.53.0",
|
||||
"prettier": "^3.4.2",
|
||||
|
|
@ -68,6 +69,7 @@
|
|||
"tw-animate-css": "^1.3.5",
|
||||
"typescript": "^5.0.0",
|
||||
"typescript-eslint": "^8.20.0",
|
||||
"unified": "^11.0.5",
|
||||
"uuid": "^13.0.0",
|
||||
"vite": "^7.0.4",
|
||||
"vite-plugin-devtools-json": "^0.2.0",
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@
|
|||
ChatMessages,
|
||||
ChatProcessingInfo,
|
||||
EmptyFileAlertDialog,
|
||||
ChatErrorDialog,
|
||||
ServerErrorSplash,
|
||||
ServerInfo,
|
||||
ServerLoadingSplash,
|
||||
|
|
@ -22,10 +23,11 @@
|
|||
activeMessages,
|
||||
activeConversation,
|
||||
deleteConversation,
|
||||
dismissErrorDialog,
|
||||
errorDialog,
|
||||
isLoading,
|
||||
sendMessage,
|
||||
stopGeneration,
|
||||
setMaxContextError
|
||||
stopGeneration
|
||||
} from '$lib/stores/chat.svelte';
|
||||
import {
|
||||
supportsVision,
|
||||
|
|
@ -34,7 +36,6 @@
|
|||
serverWarning,
|
||||
serverStore
|
||||
} from '$lib/stores/server.svelte';
|
||||
import { contextService } from '$lib/services';
|
||||
import { parseFilesToMessageExtras } from '$lib/utils/convert-files-to-extra';
|
||||
import { isFileTypeSupported } from '$lib/utils/file-type';
|
||||
import { filterFilesByModalities } from '$lib/utils/modality-file-validation';
|
||||
|
|
@ -79,6 +80,7 @@
|
|||
showCenteredEmpty && !activeConversation() && activeMessages().length === 0 && !isLoading()
|
||||
);
|
||||
|
||||
let activeErrorDialog = $derived(errorDialog());
|
||||
let isServerLoading = $derived(serverLoading());
|
||||
|
||||
async function handleDeleteConfirm() {
|
||||
|
|
@ -105,6 +107,12 @@
|
|||
}
|
||||
}
|
||||
|
||||
function handleErrorDialogOpenChange(open: boolean) {
|
||||
if (!open) {
|
||||
dismissErrorDialog();
|
||||
}
|
||||
}
|
||||
|
||||
function handleDragOver(event: DragEvent) {
|
||||
event.preventDefault();
|
||||
}
|
||||
|
|
@ -183,21 +191,6 @@
|
|||
|
||||
const extras = result?.extras;
|
||||
|
||||
// Check context limit using real-time slots data
|
||||
const contextCheck = await contextService.checkContextLimit();
|
||||
|
||||
if (contextCheck && contextCheck.wouldExceed) {
|
||||
const errorMessage = contextService.getContextErrorMessage(contextCheck);
|
||||
|
||||
setMaxContextError({
|
||||
message: errorMessage,
|
||||
estimatedTokens: contextCheck.currentUsage,
|
||||
maxContext: contextCheck.maxContext
|
||||
});
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
// Enable autoscroll for user-initiated message sending
|
||||
userScrolledUp = false;
|
||||
autoScrollEnabled = true;
|
||||
|
|
@ -461,6 +454,13 @@
|
|||
}}
|
||||
/>
|
||||
|
||||
<ChatErrorDialog
|
||||
message={activeErrorDialog?.message ?? ''}
|
||||
onOpenChange={handleErrorDialogOpenChange}
|
||||
open={Boolean(activeErrorDialog)}
|
||||
type={activeErrorDialog?.type ?? 'server'}
|
||||
/>
|
||||
|
||||
<style>
|
||||
.conversation-chat-form {
|
||||
position: relative;
|
||||
|
|
|
|||
|
|
@ -0,0 +1,60 @@
|
|||
<script lang="ts">
|
||||
import * as AlertDialog from '$lib/components/ui/alert-dialog';
|
||||
import { AlertTriangle, TimerOff } from '@lucide/svelte';
|
||||
|
||||
interface Props {
|
||||
open: boolean;
|
||||
type: 'timeout' | 'server';
|
||||
message: string;
|
||||
onOpenChange?: (open: boolean) => void;
|
||||
}
|
||||
|
||||
let { open = $bindable(), type, message, onOpenChange }: Props = $props();
|
||||
|
||||
const isTimeout = $derived(type === 'timeout');
|
||||
const title = $derived(isTimeout ? 'TCP Timeout' : 'Server Error');
|
||||
const description = $derived(
|
||||
isTimeout
|
||||
? 'The request did not receive a response from the server before timing out.'
|
||||
: 'The server responded with an error message. Review the details below.'
|
||||
);
|
||||
const iconClass = $derived(isTimeout ? 'text-destructive' : 'text-amber-500');
|
||||
const badgeClass = $derived(
|
||||
isTimeout
|
||||
? 'border-destructive/40 bg-destructive/10 text-destructive'
|
||||
: 'border-amber-500/40 bg-amber-500/10 text-amber-600 dark:text-amber-400'
|
||||
);
|
||||
|
||||
function handleOpenChange(newOpen: boolean) {
|
||||
open = newOpen;
|
||||
onOpenChange?.(newOpen);
|
||||
}
|
||||
</script>
|
||||
|
||||
<AlertDialog.Root {open} onOpenChange={handleOpenChange}>
|
||||
<AlertDialog.Content>
|
||||
<AlertDialog.Header>
|
||||
<AlertDialog.Title class="flex items-center gap-2">
|
||||
{#if isTimeout}
|
||||
<TimerOff class={`h-5 w-5 ${iconClass}`} />
|
||||
{:else}
|
||||
<AlertTriangle class={`h-5 w-5 ${iconClass}`} />
|
||||
{/if}
|
||||
|
||||
{title}
|
||||
</AlertDialog.Title>
|
||||
|
||||
<AlertDialog.Description>
|
||||
{description}
|
||||
</AlertDialog.Description>
|
||||
</AlertDialog.Header>
|
||||
|
||||
<div class={`rounded-lg border px-4 py-3 text-sm ${badgeClass}`}>
|
||||
<p class="font-medium">{message}</p>
|
||||
</div>
|
||||
|
||||
<AlertDialog.Footer>
|
||||
<AlertDialog.Action onclick={() => handleOpenChange(false)}>Close</AlertDialog.Action>
|
||||
</AlertDialog.Footer>
|
||||
</AlertDialog.Content>
|
||||
</AlertDialog.Root>
|
||||
|
|
@ -1,66 +0,0 @@
|
|||
<script lang="ts">
|
||||
import { AlertTriangle } from '@lucide/svelte';
|
||||
import * as AlertDialog from '$lib/components/ui/alert-dialog';
|
||||
import { maxContextError, clearMaxContextError } from '$lib/stores/chat.svelte';
|
||||
</script>
|
||||
|
||||
<AlertDialog.Root
|
||||
open={maxContextError() !== null}
|
||||
onOpenChange={(open) => !open && clearMaxContextError()}
|
||||
>
|
||||
<AlertDialog.Content>
|
||||
<AlertDialog.Header>
|
||||
<AlertDialog.Title class="flex items-center gap-2">
|
||||
<AlertTriangle class="h-5 w-5 text-destructive" />
|
||||
|
||||
Message Too Long
|
||||
</AlertDialog.Title>
|
||||
|
||||
<AlertDialog.Description>
|
||||
Your message exceeds the model's context window and cannot be processed.
|
||||
</AlertDialog.Description>
|
||||
</AlertDialog.Header>
|
||||
|
||||
{#if maxContextError()}
|
||||
<div class="space-y-3 text-sm">
|
||||
<div class="rounded-lg bg-muted p-3">
|
||||
<div class="mb-2 font-medium">Token Usage:</div>
|
||||
|
||||
<div class="space-y-1 text-muted-foreground">
|
||||
<div>
|
||||
Estimated tokens:
|
||||
|
||||
<span class="font-mono">
|
||||
{maxContextError()?.estimatedTokens.toLocaleString()}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
Context window:
|
||||
|
||||
<span class="font-mono">
|
||||
{maxContextError()?.maxContext.toLocaleString()}
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<div class="mb-2 font-medium">Suggestions:</div>
|
||||
|
||||
<ul class="list-inside list-disc space-y-1 text-muted-foreground">
|
||||
<li>Shorten your message</li>
|
||||
|
||||
<li>Remove some file attachments</li>
|
||||
|
||||
<li>Start a new conversation</li>
|
||||
</ul>
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<AlertDialog.Footer>
|
||||
<AlertDialog.Action onclick={() => clearMaxContextError()}>Got it</AlertDialog.Action>
|
||||
</AlertDialog.Footer>
|
||||
</AlertDialog.Content>
|
||||
</AlertDialog.Root>
|
||||
|
|
@ -30,12 +30,11 @@ export { default as ChatSidebar } from './chat/ChatSidebar/ChatSidebar.svelte';
|
|||
export { default as ChatSidebarConversationItem } from './chat/ChatSidebar/ChatSidebarConversationItem.svelte';
|
||||
export { default as ChatSidebarSearch } from './chat/ChatSidebar/ChatSidebarSearch.svelte';
|
||||
|
||||
export { default as ChatErrorDialog } from './dialogs/ChatErrorDialog.svelte';
|
||||
export { default as EmptyFileAlertDialog } from './dialogs/EmptyFileAlertDialog.svelte';
|
||||
|
||||
export { default as ConversationTitleUpdateDialog } from './dialogs/ConversationTitleUpdateDialog.svelte';
|
||||
|
||||
export { default as MaximumContextAlertDialog } from './dialogs/MaximumContextAlertDialog.svelte';
|
||||
|
||||
export { default as KeyboardShortcutInfo } from './misc/KeyboardShortcutInfo.svelte';
|
||||
|
||||
export { default as MarkdownContent } from './misc/MarkdownContent.svelte';
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@
|
|||
import githubDarkCss from 'highlight.js/styles/github-dark.css?inline';
|
||||
import githubLightCss from 'highlight.js/styles/github.css?inline';
|
||||
import { mode } from 'mode-watcher';
|
||||
import { remarkLiteralHtml } from '$lib/markdown/literal-html';
|
||||
|
||||
interface Props {
|
||||
content: string;
|
||||
|
|
@ -50,36 +51,59 @@
|
|||
.use(remarkGfm) // GitHub Flavored Markdown
|
||||
.use(remarkMath) // Parse $inline$ and $$block$$ math
|
||||
.use(remarkBreaks) // Convert line breaks to <br>
|
||||
.use(remarkRehype) // Convert to rehype (HTML AST)
|
||||
.use(remarkLiteralHtml) // Treat raw HTML as literal text with preserved indentation
|
||||
.use(remarkRehype) // Convert Markdown AST to rehype
|
||||
.use(rehypeKatex) // Render math using KaTeX
|
||||
.use(rehypeHighlight) // Add syntax highlighting
|
||||
.use(rehypeStringify); // Convert to HTML string
|
||||
});
|
||||
|
||||
function enhanceLinks(html: string): string {
|
||||
if (!html.includes('<a')) {
|
||||
return html;
|
||||
}
|
||||
|
||||
const tempDiv = document.createElement('div');
|
||||
tempDiv.innerHTML = html;
|
||||
|
||||
// Make all links open in new tabs
|
||||
const linkElements = tempDiv.querySelectorAll('a[href]');
|
||||
let mutated = false;
|
||||
|
||||
for (const link of linkElements) {
|
||||
const target = link.getAttribute('target');
|
||||
const rel = link.getAttribute('rel');
|
||||
|
||||
if (target !== '_blank' || rel !== 'noopener noreferrer') {
|
||||
mutated = true;
|
||||
}
|
||||
|
||||
link.setAttribute('target', '_blank');
|
||||
link.setAttribute('rel', 'noopener noreferrer');
|
||||
}
|
||||
|
||||
return tempDiv.innerHTML;
|
||||
return mutated ? tempDiv.innerHTML : html;
|
||||
}
|
||||
|
||||
function enhanceCodeBlocks(html: string): string {
|
||||
if (!html.includes('<pre')) {
|
||||
return html;
|
||||
}
|
||||
|
||||
const tempDiv = document.createElement('div');
|
||||
tempDiv.innerHTML = html;
|
||||
|
||||
const preElements = tempDiv.querySelectorAll('pre');
|
||||
let mutated = false;
|
||||
|
||||
for (const [index, pre] of Array.from(preElements).entries()) {
|
||||
const codeElement = pre.querySelector('code');
|
||||
|
||||
if (!codeElement) continue;
|
||||
if (!codeElement) {
|
||||
continue;
|
||||
}
|
||||
|
||||
mutated = true;
|
||||
|
||||
let language = 'text';
|
||||
const classList = Array.from(codeElement.classList);
|
||||
|
|
@ -127,7 +151,7 @@
|
|||
pre.parentNode?.replaceChild(wrapper, pre);
|
||||
}
|
||||
|
||||
return tempDiv.innerHTML;
|
||||
return mutated ? tempDiv.innerHTML : html;
|
||||
}
|
||||
|
||||
async function processMarkdown(text: string): Promise<string> {
|
||||
|
|
|
|||
|
|
@ -0,0 +1,15 @@
|
|||
export const LINE_BREAK = /\r?\n/;
|
||||
|
||||
export const PHRASE_PARENTS = new Set([
|
||||
'paragraph',
|
||||
'heading',
|
||||
'emphasis',
|
||||
'strong',
|
||||
'delete',
|
||||
'link',
|
||||
'linkReference',
|
||||
'tableCell'
|
||||
]);
|
||||
|
||||
export const NBSP = '\u00a0';
|
||||
export const TAB_AS_SPACES = NBSP.repeat(4);
|
||||
|
|
@ -0,0 +1,121 @@
|
|||
import type { Plugin } from 'unified';
|
||||
import { visit } from 'unist-util-visit';
|
||||
import type { Break, Content, Paragraph, PhrasingContent, Root, Text } from 'mdast';
|
||||
import { LINE_BREAK, NBSP, PHRASE_PARENTS, TAB_AS_SPACES } from '$lib/constants/literal-html';
|
||||
|
||||
/**
|
||||
* remark plugin that rewrites raw HTML nodes into plain-text equivalents.
|
||||
*
|
||||
* remark parses inline HTML into `html` nodes even when we do not want to render
|
||||
* them. We turn each of those nodes into regular text (plus `<br>` break markers)
|
||||
* so the downstream rehype pipeline escapes the characters instead of executing
|
||||
* them. Leading spaces and tab characters are converted to non‑breaking spaces to
|
||||
* keep indentation identical to the original author input.
|
||||
*/
|
||||
|
||||
function preserveIndent(line: string): string {
|
||||
let index = 0;
|
||||
let output = '';
|
||||
|
||||
while (index < line.length) {
|
||||
const char = line[index];
|
||||
|
||||
if (char === ' ') {
|
||||
output += NBSP;
|
||||
index += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (char === '\t') {
|
||||
output += TAB_AS_SPACES;
|
||||
index += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
return output + line.slice(index);
|
||||
}
|
||||
|
||||
function createLiteralChildren(value: string): PhrasingContent[] {
|
||||
const lines = value.split(LINE_BREAK);
|
||||
const nodes: PhrasingContent[] = [];
|
||||
|
||||
for (const [lineIndex, rawLine] of lines.entries()) {
|
||||
if (lineIndex > 0) {
|
||||
nodes.push({ type: 'break' } as Break as unknown as PhrasingContent);
|
||||
}
|
||||
|
||||
nodes.push({
|
||||
type: 'text',
|
||||
value: preserveIndent(rawLine)
|
||||
} as Text as unknown as PhrasingContent);
|
||||
}
|
||||
|
||||
if (!nodes.length) {
|
||||
nodes.push({ type: 'text', value: '' } as Text as unknown as PhrasingContent);
|
||||
}
|
||||
|
||||
return nodes;
|
||||
}
|
||||
|
||||
export const remarkLiteralHtml: Plugin<[], Root> = () => {
|
||||
return (tree) => {
|
||||
visit(tree, 'html', (node, index, parent) => {
|
||||
if (!parent || typeof index !== 'number') {
|
||||
return;
|
||||
}
|
||||
|
||||
const replacement = createLiteralChildren(node.value);
|
||||
|
||||
if (!PHRASE_PARENTS.has(parent.type as string)) {
|
||||
const paragraph: Paragraph = {
|
||||
type: 'paragraph',
|
||||
children: replacement as Paragraph['children'],
|
||||
data: { literalHtml: true }
|
||||
};
|
||||
|
||||
const siblings = parent.children as unknown as Content[];
|
||||
siblings.splice(index, 1, paragraph as unknown as Content);
|
||||
|
||||
if (index > 0) {
|
||||
const previous = siblings[index - 1] as Paragraph | undefined;
|
||||
|
||||
if (
|
||||
previous?.type === 'paragraph' &&
|
||||
(previous.data as { literalHtml?: boolean } | undefined)?.literalHtml
|
||||
) {
|
||||
const prevChildren = previous.children as unknown as PhrasingContent[];
|
||||
|
||||
if (prevChildren.length) {
|
||||
const lastChild = prevChildren[prevChildren.length - 1];
|
||||
|
||||
if (lastChild.type !== 'break') {
|
||||
prevChildren.push({
|
||||
type: 'break'
|
||||
} as Break as unknown as PhrasingContent);
|
||||
}
|
||||
}
|
||||
|
||||
prevChildren.push(...(paragraph.children as unknown as PhrasingContent[]));
|
||||
|
||||
siblings.splice(index, 1);
|
||||
|
||||
return index;
|
||||
}
|
||||
}
|
||||
|
||||
return index + 1;
|
||||
}
|
||||
|
||||
(parent.children as unknown as PhrasingContent[]).splice(
|
||||
index,
|
||||
1,
|
||||
...(replacement as unknown as PhrasingContent[])
|
||||
);
|
||||
|
||||
return index + replacement.length;
|
||||
});
|
||||
};
|
||||
};
|
||||
|
|
@ -13,7 +13,7 @@ import { slotsService } from './slots';
|
|||
* - Manages streaming and non-streaming response parsing
|
||||
* - Provides request abortion capabilities
|
||||
* - Converts database messages to API format
|
||||
* - Handles error translation and context detection
|
||||
* - Handles error translation for server responses
|
||||
*
|
||||
* - **ChatStore**: Stateful orchestration and UI state management
|
||||
* - Uses ChatService for all AI model communication
|
||||
|
|
@ -26,7 +26,6 @@ import { slotsService } from './slots';
|
|||
* - Streaming response handling with real-time callbacks
|
||||
* - Reasoning content extraction and processing
|
||||
* - File attachment processing (images, PDFs, audio, text)
|
||||
* - Context error detection and reporting
|
||||
* - Request lifecycle management (abort, cleanup)
|
||||
*/
|
||||
export class ChatService {
|
||||
|
|
@ -209,10 +208,13 @@ export class ChatService {
|
|||
userFriendlyError = new Error(
|
||||
'Unable to connect to server - please check if the server is running'
|
||||
);
|
||||
userFriendlyError.name = 'NetworkError';
|
||||
} else if (error.message.includes('ECONNREFUSED')) {
|
||||
userFriendlyError = new Error('Connection refused - server may be offline');
|
||||
userFriendlyError.name = 'NetworkError';
|
||||
} else if (error.message.includes('ETIMEDOUT')) {
|
||||
userFriendlyError = new Error('Request timeout - server may be overloaded');
|
||||
userFriendlyError = new Error('Request timed out - the server took too long to respond');
|
||||
userFriendlyError.name = 'TimeoutError';
|
||||
} else {
|
||||
userFriendlyError = error;
|
||||
}
|
||||
|
|
@ -262,6 +264,7 @@ export class ChatService {
|
|||
let fullReasoningContent = '';
|
||||
let hasReceivedData = false;
|
||||
let lastTimings: ChatMessageTimings | undefined;
|
||||
let streamFinished = false;
|
||||
|
||||
try {
|
||||
let chunk = '';
|
||||
|
|
@ -277,18 +280,8 @@ export class ChatService {
|
|||
if (line.startsWith('data: ')) {
|
||||
const data = line.slice(6);
|
||||
if (data === '[DONE]') {
|
||||
if (!hasReceivedData && aggregatedContent.length === 0) {
|
||||
const contextError = new Error(
|
||||
'The request exceeds the available context size. Try increasing the context size or enable context shift.'
|
||||
);
|
||||
contextError.name = 'ContextError';
|
||||
onError?.(contextError);
|
||||
return;
|
||||
}
|
||||
|
||||
onComplete?.(aggregatedContent, fullReasoningContent || undefined, lastTimings);
|
||||
|
||||
return;
|
||||
streamFinished = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
try {
|
||||
|
|
@ -326,13 +319,13 @@ export class ChatService {
|
|||
}
|
||||
}
|
||||
|
||||
if (!hasReceivedData && aggregatedContent.length === 0) {
|
||||
const contextError = new Error(
|
||||
'The request exceeds the available context size. Try increasing the context size or enable context shift.'
|
||||
);
|
||||
contextError.name = 'ContextError';
|
||||
onError?.(contextError);
|
||||
return;
|
||||
if (streamFinished) {
|
||||
if (!hasReceivedData && aggregatedContent.length === 0) {
|
||||
const noResponseError = new Error('No response received from server. Please try again.');
|
||||
throw noResponseError;
|
||||
}
|
||||
|
||||
onComplete?.(aggregatedContent, fullReasoningContent || undefined, lastTimings);
|
||||
}
|
||||
} catch (error) {
|
||||
const err = error instanceof Error ? error : new Error('Stream error');
|
||||
|
|
@ -368,12 +361,8 @@ export class ChatService {
|
|||
const responseText = await response.text();
|
||||
|
||||
if (!responseText.trim()) {
|
||||
const contextError = new Error(
|
||||
'The request exceeds the available context size. Try increasing the context size or enable context shift.'
|
||||
);
|
||||
contextError.name = 'ContextError';
|
||||
onError?.(contextError);
|
||||
throw contextError;
|
||||
const noResponseError = new Error('No response received from server. Please try again.');
|
||||
throw noResponseError;
|
||||
}
|
||||
|
||||
const data: ApiChatCompletionResponse = JSON.parse(responseText);
|
||||
|
|
@ -385,22 +374,14 @@ export class ChatService {
|
|||
}
|
||||
|
||||
if (!content.trim()) {
|
||||
const contextError = new Error(
|
||||
'The request exceeds the available context size. Try increasing the context size or enable context shift.'
|
||||
);
|
||||
contextError.name = 'ContextError';
|
||||
onError?.(contextError);
|
||||
throw contextError;
|
||||
const noResponseError = new Error('No response received from server. Please try again.');
|
||||
throw noResponseError;
|
||||
}
|
||||
|
||||
onComplete?.(content, reasoningContent);
|
||||
|
||||
return content;
|
||||
} catch (error) {
|
||||
if (error instanceof Error && error.name === 'ContextError') {
|
||||
throw error;
|
||||
}
|
||||
|
||||
const err = error instanceof Error ? error : new Error('Parse error');
|
||||
|
||||
onError?.(err);
|
||||
|
|
@ -594,37 +575,19 @@ export class ChatService {
|
|||
const errorText = await response.text();
|
||||
const errorData: ApiErrorResponse = JSON.parse(errorText);
|
||||
|
||||
if (errorData.error?.type === 'exceed_context_size_error') {
|
||||
const contextError = errorData.error as ApiContextSizeError;
|
||||
const error = new Error(contextError.message);
|
||||
error.name = 'ContextError';
|
||||
// Attach structured context information
|
||||
(
|
||||
error as Error & {
|
||||
contextInfo?: { promptTokens: number; maxContext: number; estimatedTokens: number };
|
||||
}
|
||||
).contextInfo = {
|
||||
promptTokens: contextError.n_prompt_tokens,
|
||||
maxContext: contextError.n_ctx,
|
||||
estimatedTokens: contextError.n_prompt_tokens
|
||||
};
|
||||
return error;
|
||||
}
|
||||
|
||||
// Fallback for other error types
|
||||
const message = errorData.error?.message || 'Unknown server error';
|
||||
return new Error(message);
|
||||
const error = new Error(message);
|
||||
error.name = response.status === 400 ? 'ServerError' : 'HttpError';
|
||||
|
||||
return error;
|
||||
} catch {
|
||||
// If we can't parse the error response, return a generic error
|
||||
return new Error(`Server error (${response.status}): ${response.statusText}`);
|
||||
const fallback = new Error(`Server error (${response.status}): ${response.statusText}`);
|
||||
fallback.name = 'HttpError';
|
||||
return fallback;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Updates the processing state with timing information from the server response
|
||||
* @param timings - Timing data from the API response
|
||||
* @param promptProgress - Progress data from the API response
|
||||
*/
|
||||
private updateProcessingState(
|
||||
timings?: ChatMessageTimings,
|
||||
promptProgress?: ChatMessagePromptProgress
|
||||
|
|
|
|||
|
|
@ -1,102 +0,0 @@
|
|||
import { slotsService } from './slots';
|
||||
|
||||
export interface ContextCheckResult {
|
||||
wouldExceed: boolean;
|
||||
currentUsage: number;
|
||||
maxContext: number;
|
||||
availableTokens: number;
|
||||
reservedTokens: number;
|
||||
}
|
||||
|
||||
/**
|
||||
* ContextService - Context window management and limit checking
|
||||
*
|
||||
* This service provides context window monitoring and limit checking using real-time
|
||||
* server data from the slots service. It helps prevent context overflow by tracking
|
||||
* current usage and calculating available space for new content.
|
||||
*
|
||||
* **Architecture & Relationships:**
|
||||
* - **ContextService** (this class): Context limit monitoring
|
||||
* - Uses SlotsService for real-time context usage data
|
||||
* - Calculates available tokens with configurable reserves
|
||||
* - Provides context limit checking and error messaging
|
||||
* - Helps prevent context window overflow
|
||||
*
|
||||
* - **SlotsService**: Provides current context usage from server slots
|
||||
* - **ChatStore**: Uses context checking before sending messages
|
||||
* - **UI Components**: Display context usage warnings and limits
|
||||
*
|
||||
* **Key Features:**
|
||||
* - **Real-time Context Checking**: Uses live server data for accuracy
|
||||
* - **Token Reservation**: Reserves tokens for response generation
|
||||
* - **Limit Detection**: Prevents context window overflow
|
||||
* - **Usage Reporting**: Detailed context usage statistics
|
||||
* - **Error Messaging**: User-friendly context limit messages
|
||||
* - **Configurable Reserves**: Adjustable token reservation for responses
|
||||
*
|
||||
* **Context Management:**
|
||||
* - Monitors current context usage from active slots
|
||||
* - Calculates available space considering reserved tokens
|
||||
* - Provides early warning before context limits are reached
|
||||
* - Helps optimize conversation length and content
|
||||
*/
|
||||
export class ContextService {
|
||||
private reserveTokens: number;
|
||||
|
||||
constructor(reserveTokens = 512) {
|
||||
this.reserveTokens = reserveTokens;
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks if the context limit would be exceeded
|
||||
*
|
||||
* @returns {Promise<ContextCheckResult | null>} Promise that resolves to the context check result or null if an error occurs
|
||||
*/
|
||||
async checkContextLimit(): Promise<ContextCheckResult | null> {
|
||||
try {
|
||||
const currentState = await slotsService.getCurrentState();
|
||||
|
||||
if (!currentState) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const maxContext = currentState.contextTotal;
|
||||
const currentUsage = currentState.contextUsed;
|
||||
const availableTokens = maxContext - currentUsage - this.reserveTokens;
|
||||
const wouldExceed = availableTokens <= 0;
|
||||
|
||||
return {
|
||||
wouldExceed,
|
||||
currentUsage,
|
||||
maxContext,
|
||||
availableTokens: Math.max(0, availableTokens),
|
||||
reservedTokens: this.reserveTokens
|
||||
};
|
||||
} catch (error) {
|
||||
console.warn('Error checking context limit:', error);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a formatted error message for context limit exceeded
|
||||
*
|
||||
* @param {ContextCheckResult} result - Context check result
|
||||
* @returns {string} Formatted error message
|
||||
*/
|
||||
getContextErrorMessage(result: ContextCheckResult): string {
|
||||
const usagePercent = Math.round((result.currentUsage / result.maxContext) * 100);
|
||||
return `Context window is nearly full. Current usage: ${result.currentUsage.toLocaleString()}/${result.maxContext.toLocaleString()} tokens (${usagePercent}%). Available space: ${result.availableTokens.toLocaleString()} tokens (${result.reservedTokens} reserved for response).`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the number of tokens to reserve for response generation
|
||||
*
|
||||
* @param {number} tokens - Number of tokens to reserve
|
||||
*/
|
||||
setReserveTokens(tokens: number): void {
|
||||
this.reserveTokens = tokens;
|
||||
}
|
||||
}
|
||||
|
||||
export const contextService = new ContextService();
|
||||
|
|
@ -1,3 +1,2 @@
|
|||
export { chatService } from './chat';
|
||||
export { contextService } from './context';
|
||||
export { slotsService } from './slots';
|
||||
|
|
|
|||
|
|
@ -39,7 +39,6 @@ import type { ExportedConversations } from '$lib/types/database';
|
|||
* - Conversation branching for exploring different response paths
|
||||
* - Streaming AI responses with real-time content updates
|
||||
* - File attachment support (images, PDFs, text files, audio)
|
||||
* - Context window management with error recovery
|
||||
* - Partial response saving when generation is interrupted
|
||||
* - Message editing with automatic response regeneration
|
||||
*/
|
||||
|
|
@ -48,11 +47,9 @@ class ChatStore {
|
|||
activeMessages = $state<DatabaseMessage[]>([]);
|
||||
conversations = $state<DatabaseConversation[]>([]);
|
||||
currentResponse = $state('');
|
||||
errorDialogState = $state<{ type: 'timeout' | 'server'; message: string } | null>(null);
|
||||
isInitialized = $state(false);
|
||||
isLoading = $state(false);
|
||||
maxContextError = $state<{ message: string; estimatedTokens: number; maxContext: number } | null>(
|
||||
null
|
||||
);
|
||||
titleUpdateConfirmationCallback?: (currentTitle: string, newTitle: string) => Promise<boolean>;
|
||||
|
||||
constructor() {
|
||||
|
|
@ -69,8 +66,6 @@ class ChatStore {
|
|||
try {
|
||||
await this.loadConversations();
|
||||
|
||||
this.maxContextError = null;
|
||||
|
||||
this.isInitialized = true;
|
||||
} catch (error) {
|
||||
console.error('Failed to initialize chat store:', error);
|
||||
|
|
@ -99,8 +94,6 @@ class ChatStore {
|
|||
this.activeConversation = conversation;
|
||||
this.activeMessages = [];
|
||||
|
||||
this.maxContextError = null;
|
||||
|
||||
await goto(`#/chat/${conversation.id}`);
|
||||
|
||||
return conversation.id;
|
||||
|
|
@ -133,8 +126,6 @@ class ChatStore {
|
|||
this.activeMessages = await DatabaseStore.getConversationMessages(convId);
|
||||
}
|
||||
|
||||
this.maxContextError = null;
|
||||
|
||||
return true;
|
||||
} catch (error) {
|
||||
console.error('Failed to load conversation:', error);
|
||||
|
|
@ -418,56 +409,6 @@ class ChatStore {
|
|||
return;
|
||||
}
|
||||
|
||||
if (error.name === 'ContextError') {
|
||||
console.warn('Context error detected:', error.message);
|
||||
this.isLoading = false;
|
||||
this.currentResponse = '';
|
||||
|
||||
const messageIndex = this.activeMessages.findIndex(
|
||||
(m: DatabaseMessage) => m.id === assistantMessage.id
|
||||
);
|
||||
|
||||
if (messageIndex !== -1) {
|
||||
this.activeMessages.splice(messageIndex, 1);
|
||||
DatabaseStore.deleteMessage(assistantMessage.id).catch(console.error);
|
||||
}
|
||||
|
||||
// Use structured context info from new exceed_context_size_error format if available
|
||||
const contextInfo = (
|
||||
error as Error & {
|
||||
contextInfo?: { promptTokens: number; maxContext: number; estimatedTokens: number };
|
||||
}
|
||||
).contextInfo;
|
||||
let estimatedTokens = 0;
|
||||
let maxContext = serverStore.serverProps?.default_generation_settings.n_ctx || 8192;
|
||||
|
||||
if (contextInfo) {
|
||||
// Use precise token counts from server response
|
||||
estimatedTokens = contextInfo.promptTokens;
|
||||
maxContext = contextInfo.maxContext;
|
||||
} else {
|
||||
// Fallback to estimation for older error format
|
||||
try {
|
||||
// Rough estimation: ~4 characters per token
|
||||
const messageContent = JSON.stringify(messages);
|
||||
estimatedTokens = Math.ceil(messageContent.length / 4);
|
||||
} catch {
|
||||
estimatedTokens = 0;
|
||||
}
|
||||
}
|
||||
|
||||
this.maxContextError = {
|
||||
message: error.message,
|
||||
estimatedTokens,
|
||||
maxContext
|
||||
};
|
||||
|
||||
if (onError) {
|
||||
onError(error);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
console.error('Streaming error:', error);
|
||||
this.isLoading = false;
|
||||
this.currentResponse = '';
|
||||
|
|
@ -477,9 +418,19 @@ class ChatStore {
|
|||
);
|
||||
|
||||
if (messageIndex !== -1) {
|
||||
this.activeMessages[messageIndex].content = `Error: ${error.message}`;
|
||||
const [failedMessage] = this.activeMessages.splice(messageIndex, 1);
|
||||
|
||||
if (failedMessage) {
|
||||
DatabaseStore.deleteMessage(failedMessage.id).catch((cleanupError) => {
|
||||
console.error('Failed to remove assistant message after error:', cleanupError);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
const dialogType = error.name === 'TimeoutError' ? 'timeout' : 'server';
|
||||
|
||||
this.showErrorDialog(dialogType, error.message);
|
||||
|
||||
if (onError) {
|
||||
onError(error);
|
||||
}
|
||||
|
|
@ -487,6 +438,14 @@ class ChatStore {
|
|||
});
|
||||
}
|
||||
|
||||
private showErrorDialog(type: 'timeout' | 'server', message: string): void {
|
||||
this.errorDialogState = { type, message };
|
||||
}
|
||||
|
||||
dismissErrorDialog(): void {
|
||||
this.errorDialogState = null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks if an error is an abort error (user cancelled operation)
|
||||
* @param error - The error to check
|
||||
|
|
@ -574,6 +533,7 @@ class ChatStore {
|
|||
return;
|
||||
}
|
||||
|
||||
this.errorDialogState = null;
|
||||
this.isLoading = true;
|
||||
this.currentResponse = '';
|
||||
|
||||
|
|
@ -603,37 +563,23 @@ class ChatStore {
|
|||
|
||||
const conversationContext = this.activeMessages.slice(0, -1);
|
||||
|
||||
await this.streamChatCompletion(
|
||||
conversationContext,
|
||||
assistantMessage,
|
||||
undefined,
|
||||
(error: Error) => {
|
||||
if (error.name === 'ContextError' && userMessage) {
|
||||
const userMessageIndex = this.findMessageIndex(userMessage.id);
|
||||
|
||||
if (userMessageIndex !== -1) {
|
||||
this.activeMessages.splice(userMessageIndex, 1);
|
||||
DatabaseStore.deleteMessage(userMessage.id).catch(console.error);
|
||||
}
|
||||
}
|
||||
}
|
||||
);
|
||||
await this.streamChatCompletion(conversationContext, assistantMessage);
|
||||
} catch (error) {
|
||||
if (this.isAbortError(error)) {
|
||||
this.isLoading = false;
|
||||
return;
|
||||
}
|
||||
|
||||
if (error instanceof Error && error.name === 'ContextError' && userMessage) {
|
||||
const userMessageIndex = this.findMessageIndex(userMessage.id);
|
||||
if (userMessageIndex !== -1) {
|
||||
this.activeMessages.splice(userMessageIndex, 1);
|
||||
DatabaseStore.deleteMessage(userMessage.id).catch(console.error);
|
||||
}
|
||||
}
|
||||
|
||||
console.error('Failed to send message:', error);
|
||||
this.isLoading = false;
|
||||
if (!this.errorDialogState) {
|
||||
if (error instanceof Error) {
|
||||
const dialogType = error.name === 'TimeoutError' ? 'timeout' : 'server';
|
||||
this.showErrorDialog(dialogType, error.message);
|
||||
} else {
|
||||
this.showErrorDialog('server', 'Unknown error occurred while sending message');
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -662,24 +608,6 @@ class ChatStore {
|
|||
this.currentResponse = '';
|
||||
}
|
||||
|
||||
/**
|
||||
* Clears the max context error state
|
||||
* Removes any displayed context limit warnings
|
||||
*/
|
||||
clearMaxContextError(): void {
|
||||
this.maxContextError = null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the max context error state
|
||||
* @param error - The context error details or null to clear
|
||||
*/
|
||||
setMaxContextError(
|
||||
error: { message: string; estimatedTokens: number; maxContext: number } | null
|
||||
): void {
|
||||
this.maxContextError = error;
|
||||
}
|
||||
|
||||
/**
|
||||
* Saves partial response if generation was interrupted
|
||||
* Preserves user's partial content and timing data when generation is stopped early
|
||||
|
|
@ -1250,7 +1178,6 @@ class ChatStore {
|
|||
this.activeMessages = [];
|
||||
this.currentResponse = '';
|
||||
this.isLoading = false;
|
||||
this.maxContextError = null;
|
||||
}
|
||||
|
||||
/** Refreshes active messages based on currNode after branch navigation */
|
||||
|
|
@ -1538,6 +1465,7 @@ class ChatStore {
|
|||
private async generateResponseForMessage(userMessageId: string): Promise<void> {
|
||||
if (!this.activeConversation) return;
|
||||
|
||||
this.errorDialogState = null;
|
||||
this.isLoading = true;
|
||||
this.currentResponse = '';
|
||||
|
||||
|
|
@ -1584,7 +1512,7 @@ export const activeMessages = () => chatStore.activeMessages;
|
|||
export const isLoading = () => chatStore.isLoading;
|
||||
export const currentResponse = () => chatStore.currentResponse;
|
||||
export const isInitialized = () => chatStore.isInitialized;
|
||||
export const maxContextError = () => chatStore.maxContextError;
|
||||
export const errorDialog = () => chatStore.errorDialogState;
|
||||
|
||||
export const createConversation = chatStore.createConversation.bind(chatStore);
|
||||
export const downloadConversation = chatStore.downloadConversation.bind(chatStore);
|
||||
|
|
@ -1592,9 +1520,9 @@ export const exportAllConversations = chatStore.exportAllConversations.bind(chat
|
|||
export const importConversations = chatStore.importConversations.bind(chatStore);
|
||||
export const deleteConversation = chatStore.deleteConversation.bind(chatStore);
|
||||
export const sendMessage = chatStore.sendMessage.bind(chatStore);
|
||||
export const dismissErrorDialog = chatStore.dismissErrorDialog.bind(chatStore);
|
||||
|
||||
export const gracefulStop = chatStore.gracefulStop.bind(chatStore);
|
||||
export const clearMaxContextError = chatStore.clearMaxContextError.bind(chatStore);
|
||||
export const setMaxContextError = chatStore.setMaxContextError.bind(chatStore);
|
||||
|
||||
// Branching operations
|
||||
export const refreshActiveMessages = chatStore.refreshActiveMessages.bind(chatStore);
|
||||
|
|
|
|||
|
|
@ -197,7 +197,7 @@ class ServerStore {
|
|||
errorMessage = 'Server not found - check server address';
|
||||
isOfflineLikeError = true;
|
||||
} else if (error.message.includes('ETIMEDOUT')) {
|
||||
errorMessage = 'Connection timeout - server may be overloaded';
|
||||
errorMessage = 'Request timed out - the server took too long to respond';
|
||||
isOfflineLikeError = true;
|
||||
} else if (error.message.includes('503')) {
|
||||
errorMessage = 'Server temporarily unavailable - try again shortly';
|
||||
|
|
|
|||
|
|
@ -1,11 +1,7 @@
|
|||
<script lang="ts">
|
||||
import '../app.css';
|
||||
import { page } from '$app/state';
|
||||
import {
|
||||
ChatSidebar,
|
||||
ConversationTitleUpdateDialog,
|
||||
MaximumContextAlertDialog
|
||||
} from '$lib/components/app';
|
||||
import { ChatSidebar, ConversationTitleUpdateDialog } from '$lib/components/app';
|
||||
import {
|
||||
activeMessages,
|
||||
isLoading,
|
||||
|
|
@ -145,8 +141,6 @@
|
|||
|
||||
<Toaster richColors />
|
||||
|
||||
<MaximumContextAlertDialog />
|
||||
|
||||
<ConversationTitleUpdateDialog
|
||||
bind:open={titleUpdateDialogOpen}
|
||||
currentTitle={titleUpdateCurrentTitle}
|
||||
|
|
|
|||
Loading…
Reference in New Issue