vulkan: Fix GGML_VULKAN_CHECK_RESULTS to better handle fusion (#16919)

Jeff Bolz 2025-11-05 12:51:03 -06:00 committed by GitHub
parent 5886f4f545
commit a44d77126c
1 changed file with 318 additions and 326 deletions

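The fix replaces the hard-coded RMS_NORM+MUL special case with a generic walk over the whole fused run: every node from tensor_idx through tensor_idx + num_additional_fused_ops is rebuilt on the CPU reference graph, a cloned_tensors map lets later nodes in the run consume the clone of an earlier node's result instead of re-copying the original, and per-source allocations are tracked in a cloned_mallocs vector. The standalone sketch below illustrates that clone-and-reuse pattern with a hypothetical Node type and clone_fused_run helper (not the ggml API), assuming a fused run whose second node consumes the first node's output:

    #include <cstdio>
    #include <map>
    #include <vector>

    // Hypothetical stand-in for a graph node; the real code works on ggml_tensor.
    struct Node {
        const char *        name;
        std::vector<Node *> src;   // producers of this node's inputs
    };

    // Clone the nodes of a fused run [first, first + n_fused). Clones of nodes that
    // belong to the run are reused as sources, so the copied subgraph stays connected.
    static std::vector<Node *> clone_fused_run(const std::vector<Node *> & nodes,
                                               size_t first, size_t n_fused) {
        std::map<Node *, Node *> cloned;   // original -> clone (mirrors cloned_tensors)
        std::vector<Node *>      clones;

        for (size_t f = 0; f < n_fused; ++f) {
            Node * node  = nodes[first + f];
            Node * clone = new Node{node->name, {}};

            for (Node * s : node->src) {
                auto it = cloned.find(s);
                // Reuse the clone of an in-run producer; otherwise copy the external input.
                clone->src.push_back(it != cloned.end() ? it->second : new Node{s->name, {}});
            }

            cloned[node] = clone;          // later nodes in the run can now find this clone
            clones.push_back(clone);
        }
        return clones;                     // clones are leaked here to keep the sketch short
    }

    int main() {
        // A fused RMS_NORM + MUL run: mul consumes the rms_norm result.
        Node x{"x", {}}, w{"w", {}};
        Node rms{"rms_norm", {&x}};
        Node mul{"mul", {&rms, &w}};
        std::vector<Node *> run = {&rms, &mul};

        std::vector<Node *> clones = clone_fused_run(run, 0, run.size());

        // The cloned mul's first source is the cloned rms_norm, not a second copy of the original.
        std::printf("mul.src[0] is the cloned rms_norm: %s\n",
                    clones[1]->src[0] == clones[0] ? "yes" : "no");
        return 0;
    }

Recording only each fused node's own result clone, rather than every source, mirrors the commit's cloned_tensors map: external inputs are freshly copied per node, while edges inside the fused run stay connected to the cloned producers.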

@@ -14104,20 +14104,11 @@ size_t comp_size;
size_t comp_nb[GGML_MAX_DIMS];
size_t check_counter = 0;
static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx) {
- ggml_tensor * tensor = cgraph->nodes[tensor_idx];
+ ggml_tensor * tensor = cgraph->nodes[tensor_idx + ctx->num_additional_fused_ops];
if (tensor->op == GGML_OP_TRANSPOSE || tensor->op == GGML_OP_SET_ROWS) {
return;
}
- bool fused_rms_norm_mul = false;
- int rms_norm_idx = -1;
- if (ctx->num_additional_fused_ops == 1 &&
- tensor->op == GGML_OP_RMS_NORM &&
- cgraph->nodes[tensor_idx + 1]->op == GGML_OP_MUL) {
- fused_rms_norm_mul = true;
- tensor = cgraph->nodes[tensor_idx + 1];
- }
check_counter++;
if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) {
return;
@@ -14125,9 +14116,6 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
VK_LOG_DEBUG("ggml_vk_check_results_0(" << tensor->name << ")");
- ggml_tensor * src0 = tensor->src[0];
- ggml_tensor * src1 = tensor->src[1];
struct ggml_init_params iparams = {
/*.mem_size =*/ 2ul*1024ul*1024ul*1024ul,
/*.mem_buffer =*/ NULL,
@@ -14137,34 +14125,34 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
struct ggml_context * ggml_ctx = ggml_init(iparams);
std::array<struct ggml_tensor *, GGML_MAX_SRC> src_clone = {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr};
std::array<size_t, GGML_MAX_SRC> src_size = {};
std::array<void *, GGML_MAX_SRC> src_buffer = {};
const char * srci_name[GGML_MAX_SRC] = {"src0", "src1", "src2", "src3", "src4", "src5", "src6", "src7", "src8", "src9"};
+ std::map<ggml_tensor *, ggml_tensor *> cloned_tensors;
+ std::vector<void *> cloned_mallocs;
struct ggml_tensor * tensor_clone = nullptr;
+ for (int f = 0; f < ctx->num_additional_fused_ops + 1; ++f) {
+ tensor = cgraph->nodes[tensor_idx + f];
for (int i = 0; i < GGML_MAX_SRC; i++) {
ggml_tensor * srci = tensor->src[i];
- if (fused_rms_norm_mul) {
- rms_norm_idx = tensor->src[0]->op == GGML_OP_RMS_NORM ? 0 : 1;
- ggml_tensor *rms_norm = tensor->src[rms_norm_idx];
- switch (i) {
- case 0: srci = rms_norm->src[0]; break;
- case 1: srci = tensor->src[1 - rms_norm_idx]; break;
- default: continue;
- }
- }
if (srci == nullptr) {
continue;
}
+ // If a src tensor has been cloned, use that one
+ auto it = cloned_tensors.find(srci);
+ if (it != cloned_tensors.end()) {
+ src_clone[i] = it->second;
+ continue;
+ }
ggml_tensor * srci_clone = ggml_dup_tensor(ggml_ctx, srci);
size_t srci_size = ggml_nbytes(srci);
src_clone[i] = srci_clone;
src_size[i] = ggml_nbytes(srci);
- src_buffer[i] = malloc(srci_size);
+ void *src_buffer = malloc(srci_size);
+ cloned_mallocs.push_back(src_buffer);
- srci_clone->data = src_buffer[i];
+ srci_clone->data = src_buffer;
if (ggml_backend_buffer_is_host(srci->buffer)) {
memcpy(srci_clone->data, srci->data, srci_size);
memcpy(srci_clone->nb, srci->nb, sizeof(size_t) * GGML_MAX_DIMS);
@@ -14214,12 +14202,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
} else if (tensor->op == GGML_OP_SUB) {
tensor_clone = ggml_sub(ggml_ctx, src_clone[0], src_clone[1]);
} else if (tensor->op == GGML_OP_MUL) {
- if (fused_rms_norm_mul) {
- tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->src[rms_norm_idx]->op_params);
- tensor_clone = ggml_mul(ggml_ctx, tensor_clone, src_clone[1 - rms_norm_idx]);
- } else {
tensor_clone = ggml_mul(ggml_ctx, src_clone[0], src_clone[1]);
- }
} else if (tensor->op == GGML_OP_DIV) {
tensor_clone = ggml_div(ggml_ctx, src_clone[0], src_clone[1]);
} else if (tensor->op == GGML_OP_CONCAT) {
@@ -14267,7 +14250,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
const float eps = ((float *) tensor->op_params)[0];
tensor_clone = ggml_l2_norm(ggml_ctx, src_clone[0], eps);
} else if (tensor->op == GGML_OP_SOFT_MAX) {
- if (src1 != nullptr) {
+ if (tensor->src[1] != nullptr) {
const float * params = (const float *)tensor->op_params;
tensor_clone = ggml_soft_max_ext(ggml_ctx, src_clone[0], src_clone[1], params[0], params[1]);
} else {
@@ -14347,7 +14330,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
ggml_set_op_params_i32(tensor_clone, 2, ggml_get_op_params_i32(tensor, 2));
ggml_set_op_params_i32(tensor_clone, 3, ggml_get_op_params_i32(tensor, 3));
} else if (tensor->op == GGML_OP_CPY || tensor->op == GGML_OP_DUP) {
- if (src1 == nullptr) {
+ if (tensor->src[1] == nullptr) {
tensor_clone = ggml_dup(ggml_ctx, src_clone[0]);
tensor_clone->type = tensor->type;
} else {
@@ -14428,6 +14411,14 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
const int32_t d0 = tensor->op_params[4];
const int32_t d1 = tensor->op_params[5];
tensor_clone = ggml_conv_2d(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1);
+ } else if (tensor->op == GGML_OP_CONV_2D_DW) {
+ const int32_t s0 = tensor->op_params[0];
+ const int32_t s1 = tensor->op_params[1];
+ const int32_t p0 = tensor->op_params[2];
+ const int32_t p1 = tensor->op_params[3];
+ const int32_t d0 = tensor->op_params[4];
+ const int32_t d1 = tensor->op_params[5];
+ tensor_clone = ggml_conv_2d_dw_direct(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1);
} else if (tensor->op == GGML_OP_CONV_TRANSPOSE_2D) {
const int32_t s = tensor->op_params[0];
tensor_clone = ggml_conv_transpose_2d_p0(ggml_ctx, src_clone[0], src_clone[1], s);
@@ -14441,11 +14432,11 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
tensor_clone = ggml_rwkv_wkv7(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3],
src_clone[4], src_clone[5], src_clone[6]);
} else if (tensor->op == GGML_OP_OPT_STEP_ADAMW) {
- src_clone[0]->flags = src0->flags;
+ src_clone[0]->flags = tensor->src[0]->flags;
tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1],
src_clone[2], src_clone[3], src_clone[4]);
} else if (tensor->op == GGML_OP_OPT_STEP_SGD) {
- src_clone[0]->flags = src0->flags;
+ src_clone[0]->flags = tensor->src[0]->flags;
tensor_clone = ggml_opt_step_sgd(ggml_ctx, src_clone[0], src_clone[1],
src_clone[2]);
} else if (tensor->op == GGML_OP_ADD_ID) {
@@ -14455,11 +14446,19 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
src_clone[3], src_clone[4], src_clone[5], src_clone[6]);
} else if (tensor->op == GGML_OP_SSM_CONV) {
tensor_clone = ggml_ssm_conv(ggml_ctx, src_clone[0], src_clone[1]);
+ } else if (tensor->op == GGML_OP_ROLL) {
+ const int32_t s0 = tensor->op_params[0];
+ const int32_t s1 = tensor->op_params[1];
+ const int32_t s2 = tensor->op_params[2];
+ const int32_t s3 = tensor->op_params[3];
+ tensor_clone = ggml_roll(ggml_ctx, src_clone[0], s0, s1, s2, s3);
}
else {
std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
GGML_ABORT("fatal error");
}
+ cloned_tensors[tensor] = tensor_clone;
+ }
ggml_cgraph * cgraph_cpu = ggml_new_graph(ggml_ctx);
ggml_build_forward_expand(cgraph_cpu, tensor_clone);
@@ -14476,10 +14475,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
memcpy(comp_result, tensor_clone->data, comp_size);
memcpy(comp_nb, tensor_clone->nb, sizeof(size_t) * GGML_MAX_DIMS);
- for (int i = 0; i < GGML_MAX_SRC; i++) {
- if (src_buffer[i] != nullptr) {
- free(src_buffer[i]);
- }
- }
+ for (auto m : cloned_mallocs) {
+ free(m);
+ }
ggml_free(ggml_ctx);
@@ -14488,15 +14485,10 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
}
static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx) {
- ggml_tensor * tensor = cgraph->nodes[tensor_idx];
+ ggml_tensor * tensor = cgraph->nodes[tensor_idx + ctx->num_additional_fused_ops];
if (tensor->op == GGML_OP_TRANSPOSE || tensor->op == GGML_OP_SET_ROWS) {
return;
}
- if (ctx->num_additional_fused_ops == 1 &&
- tensor->op == GGML_OP_RMS_NORM &&
- cgraph->nodes[tensor_idx + 1]->op == GGML_OP_MUL) {
- tensor = cgraph->nodes[tensor_idx + 1];
- }
if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) {
return;