From d35b766819c0637446102989ef2da9f96352c170 Mon Sep 17 00:00:00 2001 From: Piotr Wilkin Date: Fri, 3 Apr 2026 17:41:57 +0200 Subject: [PATCH] Add missing op parameters to the profiler; add support for test-backend-ops to run performance tests with exactly the tensor shapes from the run --- ggml/include/ggml-cpu.h | 8 +- ggml/include/ggml-profiler.h | 4 + ggml/src/ggml-backend.cpp | 6 + ggml/src/ggml-blas/ggml-blas.cpp | 10 + ggml/src/ggml-cpu/ggml-cpu.c | 11 +- ggml/src/ggml-cpu/ggml-cpu.cpp | 28 +- ggml/src/ggml-cuda/ggml-cuda.cu | 19 +- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 27 ++ tests/test-backend-ops.cpp | 546 ++++++++++++++++++++++++++- tools/profiler/profiler.py | 70 +++- 10 files changed, 695 insertions(+), 34 deletions(-) diff --git a/ggml/include/ggml-cpu.h b/ggml/include/ggml-cpu.h index cbbac0b663..a9a7cc6801 100644 --- a/ggml/include/ggml-cpu.h +++ b/ggml/include/ggml-cpu.h @@ -28,7 +28,7 @@ extern "C" { void * profiling_context; // callback for recording a profile record from C code (set by backend when profiling) - // params: context, type, name, split_id, start_ns, end_ns, bytes, extra, ne_src0[4], ne_src1[4], ne_src2[4] + // params: context, type, name, split_id, start_ns, end_ns, bytes, extra, ne_src0[4], ne_src1[4], ne_src2[4], type_src0, type_src1, type_src2, sub_op void (*profiling_record_fn)(void * context, int type, const char * name, @@ -39,7 +39,11 @@ extern "C" { const char * extra, const int64_t ne_src0[4], const int64_t ne_src1[4], - const int64_t ne_src2[4]); + const int64_t ne_src2[4], + int type_src0, + int type_src1, + int type_src2, + int sub_op); }; // numa strategies diff --git a/ggml/include/ggml-profiler.h b/ggml/include/ggml-profiler.h index 2328f6b49f..46820e75a1 100644 --- a/ggml/include/ggml-profiler.h +++ b/ggml/include/ggml-profiler.h @@ -30,6 +30,10 @@ typedef struct ggml_profile_record { int64_t ne_src0[4]; // src[0] tensor dimensions (e.g. weight matrix for MUL_MAT) int64_t ne_src1[4]; // src[1] tensor dimensions (e.g. input matrix for MUL_MAT) int64_t ne_src2[4]; // src[2] tensor dimensions (e.g. ids for MUL_MAT_ID) + int type_src0; // src[0] tensor type (ggml_type), -1 if N/A + int type_src1; // src[1] tensor type (ggml_type), -1 if N/A + int type_src2; // src[2] tensor type (ggml_type), -1 if N/A + int sub_op; // sub-operation (ggml_unary_op or ggml_glu_op), -1 if N/A } ggml_profile_record; // Backend profiler interface - each backend optionally implements this diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index fa955b68bb..71ea153562 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -2670,6 +2670,12 @@ int ggml_backend_sched_write_profiling_json(ggml_backend_sched_t sched, FILE * f fprintf(fp, ", \"ne_src2\": [%lld, %lld, %lld, %lld]", (long long) rec.ne_src2[0], (long long) rec.ne_src2[1], (long long) rec.ne_src2[2], (long long) rec.ne_src2[3]); + // Tensor types (quantization) + fprintf(fp, ", \"type_src0\": %d", rec.type_src0); + fprintf(fp, ", \"type_src1\": %d", rec.type_src1); + fprintf(fp, ", \"type_src2\": %d", rec.type_src2); + fprintf(fp, ", \"sub_op\": %d", rec.sub_op); + fprintf(fp, "}%s\n", (i < (int) sched->profiling_records.size() - 1) ? "," : ""); } diff --git a/ggml/src/ggml-blas/ggml-blas.cpp b/ggml/src/ggml-blas/ggml-blas.cpp index b68f45452e..4c75a38a49 100644 --- a/ggml/src/ggml-blas/ggml-blas.cpp +++ b/ggml/src/ggml-blas/ggml-blas.cpp @@ -274,6 +274,16 @@ static enum ggml_status ggml_backend_blas_graph_compute(ggml_backend_t backend, rec.end_ns = t_end; rec.bytes = ggml_nbytes(node); rec.extra = NULL; + rec.type_src0 = node->src[0] ? (int)node->src[0]->type : -1; + rec.type_src1 = node->src[1] ? (int)node->src[1]->type : -1; + rec.type_src2 = (node->op == GGML_OP_MUL_MAT_ID && node->src[2]) ? (int)node->src[2]->type : -1; + int sub_op = -1; + if (node->op == GGML_OP_UNARY) { + sub_op = (int)ggml_get_unary_op(node); + } else if (node->op == GGML_OP_GLU) { + sub_op = (int)ggml_get_glu_op(node); + } + rec.sub_op = sub_op; if (node->src[0]) { memcpy(rec.ne_src0, node->src[0]->ne, sizeof(rec.ne_src0)); } else { memset(rec.ne_src0, 0, sizeof(rec.ne_src0)); } if (node->src[1]) { memcpy(rec.ne_src1, node->src[1]->ne, sizeof(rec.ne_src1)); } else { memset(rec.ne_src1, 0, sizeof(rec.ne_src1)); } if (node->op == GGML_OP_MUL_MAT_ID && node->src[2]) { memcpy(rec.ne_src2, node->src[2]->ne, sizeof(rec.ne_src2)); } else { memset(rec.ne_src2, 0, sizeof(rec.ne_src2)); } diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index d0fc4228ff..4c6c9abfd3 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -3010,9 +3010,18 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { const int64_t * src0_ne = node->src[0] ? node->src[0]->ne : zero_ne; const int64_t * src1_ne = node->src[1] ? node->src[1]->ne : zero_ne; const int64_t * src2_ne = (node->op == GGML_OP_MUL_MAT_ID && node->src[2]) ? node->src[2]->ne : zero_ne; + int type_src0 = node->src[0] ? (int)node->src[0]->type : -1; + int type_src1 = node->src[1] ? (int)node->src[1]->type : -1; + int type_src2 = (node->op == GGML_OP_MUL_MAT_ID && node->src[2]) ? (int)node->src[2]->type : -1; + int sub_op = -1; + if (node->op == GGML_OP_UNARY) { + sub_op = (int)ggml_get_unary_op(node); + } else if (node->op == GGML_OP_GLU) { + sub_op = (int)ggml_get_glu_op(node); + } cplan->profiling_record_fn(cplan->profiling_context, 0 /* GGML_PROFILE_EVENT_OP */, ggml_op_name(node->op), -1, t_start, t_end, ggml_nbytes(node), NULL, - src0_ne, src1_ne, src2_ne); + src0_ne, src1_ne, src2_ne, type_src0, type_src1, type_src2, sub_op); } } diff --git a/ggml/src/ggml-cpu/ggml-cpu.cpp b/ggml/src/ggml-cpu/ggml-cpu.cpp index 6fb9e1987c..c333e01e09 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.cpp +++ b/ggml/src/ggml-cpu/ggml-cpu.cpp @@ -175,16 +175,20 @@ static enum ggml_status ggml_backend_cpu_graph_plan_compute(ggml_backend_t backe // Callback function for recording CPU profiling events from C code (ggml-cpu.c) static void ggml_cpu_profiler_record_callback(void * context, - int type, - const char * name, - int split_id, - uint64_t start_ns, - uint64_t end_ns, - uint64_t bytes, - const char * extra, - const int64_t ne_src0[4], - const int64_t ne_src1[4], - const int64_t ne_src2[4]) { + int type, + const char * name, + int split_id, + uint64_t start_ns, + uint64_t end_ns, + uint64_t bytes, + const char * extra, + const int64_t ne_src0[4], + const int64_t ne_src1[4], + const int64_t ne_src2[4], + int type_src0, + int type_src1, + int type_src2, + int sub_op) { auto * cpu_ctx = (ggml_backend_cpu_context *) context; ggml_profile_record rec; rec.type = (enum ggml_profile_event_type) type; @@ -195,6 +199,10 @@ static void ggml_cpu_profiler_record_callback(void * context, rec.end_ns = end_ns; rec.bytes = bytes; rec.extra = extra; + rec.type_src0 = type_src0; + rec.type_src1 = type_src1; + rec.type_src2 = type_src2; + rec.sub_op = sub_op; if (ne_src0) { memcpy(rec.ne_src0, ne_src0, sizeof(rec.ne_src0)); } else { diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 85d1941dca..3355b666ae 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -138,7 +138,8 @@ struct ggml_cuda_profiler_state { } void record_end(const char * name, int backend_id, int split_id, uint64_t bytes, const char * extra, - const int64_t ne_src0[4], const int64_t ne_src1[4], const int64_t ne_src2[4]) { + const int64_t ne_src0[4], const int64_t ne_src1[4], const int64_t ne_src2[4], + int type_src0, int type_src1, int type_src2, int sub_op = -1) { cudaEvent_t ev; (void) cudaEventCreate(&ev); (void) cudaEventRecord(ev, stream); @@ -154,6 +155,10 @@ struct ggml_cuda_profiler_state { rec.end_ns = 0; rec.bytes = bytes; rec.extra = extra; + rec.type_src0 = type_src0; + rec.type_src1 = type_src1; + rec.type_src2 = type_src2; + rec.sub_op = sub_op; if (ne_src0) { memcpy(rec.ne_src0, ne_src0, sizeof(rec.ne_src0)); } else { memset(rec.ne_src0, 0, sizeof(rec.ne_src0)); } if (ne_src1) { memcpy(rec.ne_src1, ne_src1, sizeof(rec.ne_src1)); } else { memset(rec.ne_src1, 0, sizeof(rec.ne_src1)); } if (ne_src2) { memcpy(rec.ne_src2, ne_src2, sizeof(rec.ne_src2)); } else { memset(rec.ne_src2, 0, sizeof(rec.ne_src2)); } @@ -4131,6 +4136,12 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud bool ok = ggml_cuda_compute_forward(*cuda_ctx, node); if (cuda_ctx->profiler_state != nullptr && cuda_ctx->profiler_state->enabled) { + int sub_op = -1; + if (node->op == GGML_OP_UNARY) { + sub_op = (int)ggml_get_unary_op(node); + } else if (node->op == GGML_OP_GLU) { + sub_op = (int)ggml_get_glu_op(node); + } cuda_ctx->profiler_state->record_end( ggml_op_name(node->op), -1, @@ -4139,7 +4150,11 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud nullptr, node->src[0] ? node->src[0]->ne : nullptr, node->src[1] ? node->src[1]->ne : nullptr, - (node->op == GGML_OP_MUL_MAT_ID && node->src[2]) ? node->src[2]->ne : nullptr + (node->op == GGML_OP_MUL_MAT_ID && node->src[2]) ? node->src[2]->ne : nullptr, + node->src[0] ? (int)node->src[0]->type : -1, + node->src[1] ? (int)node->src[1]->type : -1, + (node->op == GGML_OP_MUL_MAT_ID && node->src[2]) ? (int)node->src[2]->type : -1, + sub_op ); } diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 499e384b9d..e2f56a8ad4 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -14695,6 +14695,18 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg rec.end_ns = cpu_ts + duration_ns; rec.bytes = ggml_nbytes(node); rec.extra = name; // fusion name or NULL + rec.type_src0 = node->src[0] ? (int)node->src[0]->type : -1; + rec.type_src1 = node->src[1] ? (int)node->src[1]->type : -1; + rec.type_src2 = (node->op == GGML_OP_MUL_MAT_ID && node->src[2]) ? (int)node->src[2]->type : -1; + { + int sub_op = -1; + if (node->op == GGML_OP_UNARY) { + sub_op = (int)ggml_get_unary_op(node); + } else if (node->op == GGML_OP_GLU) { + sub_op = (int)ggml_get_glu_op(node); + } + rec.sub_op = sub_op; + } memcpy(rec.ne_src0, src0_ne, sizeof(rec.ne_src0)); memcpy(rec.ne_src1, src1_ne, sizeof(rec.ne_src1)); memcpy(rec.ne_src2, src2_ne, sizeof(rec.ne_src2)); @@ -14747,6 +14759,18 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg rec.end_ns = cpu_ts + duration_ns; rec.bytes = total_bytes; rec.extra = names[0]; // fusion name of first op, or NULL + rec.type_src0 = node->src[0] ? (int)node->src[0]->type : -1; + rec.type_src1 = node->src[1] ? (int)node->src[1]->type : -1; + rec.type_src2 = (node->op == GGML_OP_MUL_MAT_ID && node->src[2]) ? (int)node->src[2]->type : -1; + { + int sub_op = -1; + if (node->op == GGML_OP_UNARY) { + sub_op = (int)ggml_get_unary_op(node); + } else if (node->op == GGML_OP_GLU) { + sub_op = (int)ggml_get_glu_op(node); + } + rec.sub_op = sub_op; + } memcpy(rec.ne_src0, src0_ne, sizeof(rec.ne_src0)); memcpy(rec.ne_src1, src1_ne, sizeof(rec.ne_src1)); memcpy(rec.ne_src2, src2_ne, sizeof(rec.ne_src2)); @@ -14754,6 +14778,9 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg } } } + } + } + } if (ctx->perf_logger) { ctx->perf_logger->print_timings(); } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 781c621d93..9ed14763f8 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -20,6 +20,8 @@ #include #include +#include + #include #include #include @@ -9059,8 +9061,14 @@ static std::vector> make_test_cases_from_file(const c return test_cases; } +struct profile_test_plan; + +static profile_test_plan make_test_plan_from_profile( + const char * profile_path, int top_n); + static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_names_filter, const char * params_filter, - printer * output_printer, const char * test_file_path) { + printer * output_printer, const char * test_file_path, + std::vector> profile_test_cases = {}) { auto filter_test_cases = [](std::vector> & test_cases, const char * params_filter) { if (params_filter == nullptr) { return; @@ -9081,15 +9089,19 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op std::vector> test_cases; if (test_file_path == nullptr) { - switch (mode) { - case MODE_TEST: - case MODE_GRAD: - case MODE_SUPPORT: - test_cases = make_test_cases_eval(); - break; - case MODE_PERF: - test_cases = make_test_cases_perf(); - break; + if (!profile_test_cases.empty()) { + test_cases = std::move(profile_test_cases); + } else { + switch (mode) { + case MODE_TEST: + case MODE_GRAD: + case MODE_SUPPORT: + test_cases = make_test_cases_eval(); + break; + case MODE_PERF: + test_cases = make_test_cases_perf(); + break; + } } } else { test_cases = make_test_cases_from_file(test_file_path); @@ -9275,7 +9287,7 @@ static void show_test_coverage() { static void usage(char ** argv) { printf("Usage: %s [mode] [-o ] [-b ] [-p ] [--output ] [--list-ops]", argv[0]); - printf(" [--show-coverage] [--test-file ]\n"); + printf(" [--show-coverage] [--test-file ] [--from-profile ] [--top-n ]\n"); printf(" valid modes:\n"); printf(" - test (default, compare with CPU backend for correctness)\n"); printf(" - grad (compare gradients from backpropagation with method of finite differences)\n"); @@ -9289,6 +9301,414 @@ static void usage(char ** argv) { printf(" --test-file reads test operators from a test file generated by llama-export-graph-ops\n"); } +// ############################## +// ## Profiler-based perf ## +// ############################## + +static ggml_op profile_name_to_op(const std::string & name) { + static const std::unordered_map map = { + {"ADD", GGML_OP_ADD}, + {"ADD1", GGML_OP_ADD1}, + {"ARGSORT", GGML_OP_ARGSORT}, + {"CLAMP", GGML_OP_CLAMP}, + {"CONCAT", GGML_OP_CONCAT}, + {"CONT", GGML_OP_CONT}, + {"CPY", GGML_OP_CPY}, + {"DIV", GGML_OP_DIV}, + {"FLASH_ATTN_EXT", GGML_OP_FLASH_ATTN_EXT}, + {"GET_ROWS", GGML_OP_GET_ROWS}, + {"GET_ROWS_BACK", GGML_OP_GET_ROWS_BACK}, + {"GLU", GGML_OP_GLU}, + {"IM2COL_BACK", GGML_OP_IM2COL_BACK}, + {"MUL", GGML_OP_MUL}, + {"MUL_MAT", GGML_OP_MUL_MAT}, + {"MUL_MAT_ID", GGML_OP_MUL_MAT_ID}, + {"OUT_PROD", GGML_OP_OUT_PROD}, + {"POOL_2D", GGML_OP_POOL_2D}, + {"RMS_NORM", GGML_OP_RMS_NORM}, + {"SCALE", GGML_OP_SCALE}, + {"SET_ROWS", GGML_OP_SET_ROWS}, + {"SQR", GGML_OP_SQR}, + {"SSM_CONV", GGML_OP_SSM_CONV}, + {"SSM_SCAN", GGML_OP_SSM_SCAN}, + {"SUM_ROWS", GGML_OP_SUM_ROWS}, + {"UNARY", GGML_OP_UNARY}, + {"SOFT_MAX", GGML_OP_SOFT_MAX}, + }; + auto it = map.find(name); + if (it != map.end()) { + return it->second; + } + return GGML_OP_COUNT; +} + +static bool compute_output_ne(ggml_op op, + const int64_t ne0[4], const int64_t ne1[4], const int64_t ne2[4], + int64_t ne_out[4]) { + ne_out[0] = ne_out[1] = ne_out[2] = ne_out[3] = 0; + switch (op) { + case GGML_OP_MUL_MAT: + ne_out[0] = ne0[1]; ne_out[1] = ne1[1]; + ne_out[2] = std::max(ne0[2], ne1[2]); + ne_out[3] = std::max(ne0[3], ne1[3]); + return true; + case GGML_OP_MUL_MAT_ID: + ne_out[0] = ne0[1]; ne_out[1] = ne2[0]; + ne_out[2] = ne1[2]; ne_out[3] = 1; + return true; + case GGML_OP_ADD: + case GGML_OP_MUL: + case GGML_OP_DIV: + case GGML_OP_SCALE: + for (int i = 0; i < 4; i++) { + ne_out[i] = std::max(ne0[i], ne1[i]); + } + return true; + case GGML_OP_ADD1: + ne_out[0] = ne0[0]; ne_out[1] = ne0[1]; + ne_out[2] = ne0[2]; ne_out[3] = ne0[3]; + return true; + case GGML_OP_SQR: + case GGML_OP_UNARY: + case GGML_OP_SSM_SCAN: + for (int i = 0; i < 4; i++) { + ne_out[i] = ne0[i]; + } + return true; + case GGML_OP_SOFT_MAX: + ne_out[0] = ne0[0]; ne_out[1] = ne0[1]; + ne_out[2] = ne0[2]; ne_out[3] = ne0[3]; + return true; + case GGML_OP_RMS_NORM: + ne_out[0] = ne0[0]; ne_out[1] = ne0[1]; + ne_out[2] = ne0[2]; ne_out[3] = ne0[3]; + return true; + case GGML_OP_FLASH_ATTN_EXT: + ne_out[0] = ne1[1]; ne_out[1] = ne1[1]; + ne_out[2] = ne0[2]; ne_out[3] = ne0[3]; + return true; + case GGML_OP_GET_ROWS: + ne_out[0] = ne0[0]; ne_out[1] = ne1[1]; + ne_out[2] = ne1[2]; ne_out[3] = ne1[3]; + return true; + case GGML_OP_GET_ROWS_BACK: + ne_out[0] = ne0[0]; ne_out[1] = ne1[1]; + ne_out[2] = ne1[2]; ne_out[3] = ne1[3]; + return true; + case GGML_OP_SET_ROWS: + for (int i = 0; i < 4; i++) { + ne_out[i] = ne0[i]; + } + return true; + case GGML_OP_OUT_PROD: + ne_out[0] = ne0[0]; ne_out[1] = ne1[0]; + ne_out[2] = std::max(ne0[2], ne1[2]); + ne_out[3] = std::max(ne0[3], ne1[3]); + return true; + case GGML_OP_CONCAT: + ne_out[0] = ne0[0] + ne1[0]; + ne_out[1] = std::max(ne0[1], ne1[1]); + ne_out[2] = std::max(ne0[2], ne1[2]); + ne_out[3] = std::max(ne0[3], ne1[3]); + return true; + case GGML_OP_ARGSORT: + for (int i = 0; i < 4; i++) { + ne_out[i] = ne0[i]; + } + return true; + case GGML_OP_CLAMP: + for (int i = 0; i < 4; i++) { + ne_out[i] = ne0[i]; + } + return true; + case GGML_OP_CPY: + for (int i = 0; i < 4; i++) { + ne_out[i] = ne0[i]; + } + return true; + case GGML_OP_POOL_2D: + for (int i = 0; i < 4; i++) { + ne_out[i] = ne0[i]; + } + return true; + case GGML_OP_SSM_CONV: + for (int i = 0; i < 4; i++) { + ne_out[i] = ne0[i]; + } + return true; + case GGML_OP_IM2COL_BACK: + for (int i = 0; i < 4; i++) { + ne_out[i] = ne0[i]; + } + return true; + default: + return false; + } +} + +static std::vector json_get_ne(const nlohmann::json & arr) { + std::vector ne(4, 0); + if (!arr.is_array()) return ne; + for (size_t i = 0; i < std::min(arr.size(), (size_t)4); i++) { + ne[i] = arr[i].get(); + } + return ne; +} + +static bool ne_is_zero(const std::vector & ne) { + for (auto v : ne) if (v != 0) return false; + return true; +} + +struct profile_op_key { + std::string name; + int backend_id; + int type_src0; + int type_src1; + int type_src2; + int sub_op; + std::vector ne_src0; + std::vector ne_src1; + std::vector ne_src2; + + bool operator==(const profile_op_key & o) const { + return name == o.name && backend_id == o.backend_id && + type_src0 == o.type_src0 && type_src1 == o.type_src1 && type_src2 == o.type_src2 && + sub_op == o.sub_op && + ne_src0 == o.ne_src0 && ne_src1 == o.ne_src1 && ne_src2 == o.ne_src2; + } +}; + +struct profile_op_key_hash { + size_t operator()(const profile_op_key & k) const { + size_t h = std::hash{}(k.name); + h ^= std::hash{}(k.backend_id) + 0x9e3779b9 + (h << 6) + (h >> 2); + h ^= std::hash{}(k.type_src0) + 0x9e3779b9 + (h << 6) + (h >> 2); + h ^= std::hash{}(k.type_src1) + 0x9e3779b9 + (h << 6) + (h >> 2); + h ^= std::hash{}(k.type_src2) + 0x9e3779b9 + (h << 6) + (h >> 2); + h ^= std::hash{}(k.sub_op) + 0x9e3779b9 + (h << 6) + (h >> 2); + for (auto v : k.ne_src0) { h ^= std::hash{}(v) + 0x9e3779b9 + (h << 6) + (h >> 2); } + for (auto v : k.ne_src1) { h ^= std::hash{}(v) + 0x9e3779b9 + (h << 6) + (h >> 2); } + for (auto v : k.ne_src2) { h ^= std::hash{}(v) + 0x9e3779b9 + (h << 6) + (h >> 2); } + return h; + } +}; + +struct profile_op_agg { + profile_op_key key; + uint64_t total_ns; + int64_t count; + double max_ns; +}; + +struct profile_test_plan { + struct backend_plan { + int backend_id; + std::string backend_name; + std::vector> test_cases; + std::vector aggs; + }; + std::vector backends; +}; + +static profile_test_plan make_test_plan_from_profile( + const char * profile_path, int top_n) { + using json = nlohmann::json; + profile_test_plan plan; + + std::ifstream f(profile_path); + if (!f.is_open()) { + fprintf(stderr, "Error: cannot open profile file: %s\n", profile_path); + return plan; + } + + json root; + try { + root = json::parse(f); + } catch (const json::parse_error & e) { + fprintf(stderr, "Error: failed to parse profile JSON: %s\n", e.what()); + return plan; + } + + if (!root.contains("records") || !root["records"].is_array()) { + fprintf(stderr, "Error: no 'records' array found in profile\n"); + return plan; + } + + std::unordered_map backend_names; + if (root.contains("backends") && root["backends"].is_array()) { + for (const auto & be : root["backends"]) { + int id = be.value("id", -1); + std::string name = be.value("name", ""); + if (id >= 0 && !name.empty()) { + backend_names[id] = name; + } + } + } + + const auto & records = root["records"]; + std::unordered_map aggs; + + for (const auto & rec : records) { + int rec_type = rec.value("type", -1); + if (rec_type != 0) continue; + + std::string name = rec.value("name", ""); + ggml_op op = profile_name_to_op(name); + if (op == GGML_OP_COUNT) continue; + + profile_op_key key; + key.name = name; + key.backend_id = rec.value("backend_id", 0); + key.type_src0 = rec.value("type_src0", -1); + key.type_src1 = rec.value("type_src1", -1); + key.type_src2 = rec.value("type_src2", -1); + key.sub_op = rec.value("sub_op", -1); + key.ne_src0 = json_get_ne(rec.value("ne_src0", json::array())); + key.ne_src1 = json_get_ne(rec.value("ne_src1", json::array())); + key.ne_src2 = json_get_ne(rec.value("ne_src2", json::array())); + + uint64_t dur = rec.value("duration_ns", (uint64_t)0); + + auto & agg = aggs[key]; + agg.key = key; + agg.total_ns += dur; + agg.count++; + agg.max_ns = std::max(agg.max_ns, (double)dur); + } + + std::vector sorted; + sorted.reserve(aggs.size()); + for (auto & [_, agg] : aggs) { + sorted.push_back(std::move(agg)); + } + std::sort(sorted.begin(), sorted.end(), + [](const profile_op_agg & a, const profile_op_agg & b) { + return a.total_ns > b.total_ns; + }); + + uint64_t global_max_ns = sorted.empty() ? 1 : sorted[0].total_ns; + + int n = std::min(top_n, (int)sorted.size()); + if (n == 0) { + fprintf(stderr, "Warning: no matching OP records found in profile\n"); + return plan; + } + + auto make_src = [](int type_id, const std::vector & ne) -> input_tensor { + input_tensor src; + src.type = (type_id >= 0) ? (ggml_type)type_id : GGML_TYPE_F32; + for (int d = 0; d < 4; d++) { + src.ne[d] = d < (int)ne.size() ? ne[d] : 0; + src.nb[d] = 0; + } + return src; + }; + + auto make_test_from_agg = [&](const profile_op_agg & agg) -> std::unique_ptr { + ggml_op op = profile_name_to_op(agg.key.name); + + std::vector sources; + if (!ne_is_zero(agg.key.ne_src0)) { + sources.push_back(make_src(agg.key.type_src0, agg.key.ne_src0)); + } + if (!ne_is_zero(agg.key.ne_src1)) { + sources.push_back(make_src(agg.key.type_src1, agg.key.ne_src1)); + } + if (op == GGML_OP_MUL_MAT_ID) { + if (!ne_is_zero(agg.key.ne_src2)) { + sources.push_back(make_src(agg.key.type_src2, agg.key.ne_src2)); + } else if (sources.size() >= 2) { + input_tensor src; + src.type = GGML_TYPE_I32; + src.ne[0] = sources[1].ne[1]; + src.ne[1] = 1; + src.ne[2] = 1; + src.ne[3] = 1; + src.nb[0] = src.nb[1] = src.nb[2] = src.nb[3] = 0; + sources.push_back(src); + } + } + + int64_t ne0[4] = {0}, ne1[4] = {0}, ne2[4] = {0}; + if (sources.size() > 0) { for (int d = 0; d < 4; d++) ne0[d] = sources[0].ne[d]; } + if (sources.size() > 1) { for (int d = 0; d < 4; d++) ne1[d] = sources[1].ne[d]; } + if (sources.size() > 2) { for (int d = 0; d < 4; d++) ne2[d] = sources[2].ne[d]; } + + int64_t ne_out[4] = {0, 0, 0, 0}; + if (!compute_output_ne(op, ne0, ne1, ne2, ne_out)) { + return nullptr; + } + + ggml_type out_type = GGML_TYPE_F32; + std::array op_params{}; + op_params.fill(0); + + if (op == GGML_OP_MUL_MAT_ID && sources.size() >= 2) { + op_params[0] = (int32_t)sources[1].ne[1]; + } else if ((op == GGML_OP_UNARY || op == GGML_OP_GLU) && agg.key.sub_op >= 0) { + op_params[0] = (int32_t)agg.key.sub_op; + } + + std::array out_ne; + for (int d = 0; d < 4; d++) out_ne[d] = ne_out[d]; + + return std::unique_ptr(new test_generic_op(op, out_type, out_ne, op_params, sources, + agg.key.name + " [from profile]")); + }; + + printf(" Loaded %d profiler ops, running top %d:\n", (int)sorted.size(), n); + for (int i = 0; i < n; i++) { + const auto & agg = sorted[i]; + double pct = 100.0 * agg.total_ns / global_max_ns; + int bid = agg.key.backend_id; + std::string bname = backend_names.count(bid) ? backend_names[bid] : std::to_string(bid); + printf(" #%d: %s @ %s %ldx %.3fms total (%.1f%% of top)\n", + i + 1, agg.key.name.c_str(), bname.c_str(), agg.count, + agg.total_ns / 1e6, pct); + if (!ne_is_zero(agg.key.ne_src0)) { + const char * tn = agg.key.type_src0 >= 0 ? ggml_type_name((ggml_type)agg.key.type_src0) : "?"; + printf(" src0: [%lld, %lld, %lld, %lld] (%s)\n", + (long long)agg.key.ne_src0[0], (long long)agg.key.ne_src0[1], + (long long)agg.key.ne_src0[2], (long long)agg.key.ne_src0[3], tn); + } + if (!ne_is_zero(agg.key.ne_src1)) { + const char * tn = agg.key.type_src1 >= 0 ? ggml_type_name((ggml_type)agg.key.type_src1) : "?"; + printf(" src1: [%lld, %lld, %lld, %lld] (%s)\n", + (long long)agg.key.ne_src1[0], (long long)agg.key.ne_src1[1], + (long long)agg.key.ne_src1[2], (long long)agg.key.ne_src1[3], tn); + } + if (!ne_is_zero(agg.key.ne_src2)) { + const char * tn = agg.key.type_src2 >= 0 ? ggml_type_name((ggml_type)agg.key.type_src2) : "?"; + printf(" src2: [%lld, %lld, %lld, %lld] (%s)\n", + (long long)agg.key.ne_src2[0], (long long)agg.key.ne_src2[1], + (long long)agg.key.ne_src2[2], (long long)agg.key.ne_src2[3], tn); + } + + auto tc = make_test_from_agg(agg); + if (!tc) continue; + + int bid2 = agg.key.backend_id; + std::string bname2 = backend_names.count(bid2) ? backend_names[bid2] : std::to_string(bid2); + + profile_test_plan::backend_plan * bp = nullptr; + for (auto & b : plan.backends) { + if (b.backend_id == bid2 && b.backend_name == bname2) { + bp = &b; + break; + } + } + if (!bp) { + plan.backends.push_back({bid2, bname2, {}, {}}); + bp = &plan.backends.back(); + } + bp->test_cases.push_back(std::move(tc)); + bp->aggs.push_back(agg); + } + + return plan; +} + int main(int argc, char ** argv) { test_mode mode = MODE_TEST; output_formats output_format = CONSOLE; @@ -9296,6 +9716,8 @@ int main(int argc, char ** argv) { const char * backend_filter = nullptr; const char * params_filter = nullptr; const char * test_file_path = nullptr; + const char * profile_path = nullptr; + int profile_top_n = 10; for (int i = 1; i < argc; i++) { if (strcmp(argv[i], "test") == 0) { @@ -9350,6 +9772,21 @@ int main(int argc, char ** argv) { usage(argv); return 1; } + } else if (strcmp(argv[i], "--from-profile") == 0) { + if (i + 1 < argc) { + profile_path = argv[++i]; + } else { + usage(argv); + return 1; + } + } else if (strcmp(argv[i], "--top-n") == 0) { + if (i + 1 < argc) { + profile_top_n = atoi(argv[++i]); + if (profile_top_n <= 0) profile_top_n = 10; + } else { + usage(argv); + return 1; + } } else { usage(argv); return 1; @@ -9359,6 +9796,10 @@ int main(int argc, char ** argv) { // load and enumerate backends ggml_backend_load_all(); + if (profile_path != nullptr) { + mode = MODE_PERF; + } + // Create printer for output format std::unique_ptr output_printer = create_printer(output_format); if (output_printer) { @@ -9367,6 +9808,83 @@ int main(int argc, char ** argv) { output_printer->print_testing_start(testing_start_info(ggml_backend_dev_count())); + if (profile_path != nullptr) { + profile_test_plan plan = make_test_plan_from_profile(profile_path, profile_top_n); + + size_t n_ok = 0; + size_t total = plan.backends.size(); + + for (size_t bi = 0; bi < plan.backends.size(); bi++) { + auto & bp = plan.backends[bi]; + + ggml_backend_dev_t dev = nullptr; + for (size_t i = 0; i < ggml_backend_dev_count(); i++) { + ggml_backend_dev_t d = ggml_backend_dev_get(i); + if (strcmp(ggml_backend_dev_name(d), bp.backend_name.c_str()) == 0) { + dev = d; + break; + } + } + + if (dev == nullptr) { + fprintf(stderr, "Warning: backend '%s' from profile not found, skipping\n", bp.backend_name.c_str()); + n_ok++; + output_printer->print_backend_init( + backend_init_info(bi, total, bp.backend_name.c_str(), true, "Not found")); + continue; + } + + if (backend_filter != NULL && strcmp(backend_filter, bp.backend_name.c_str()) != 0) { + output_printer->print_backend_init( + backend_init_info(bi, total, bp.backend_name.c_str(), true, "Skipping")); + n_ok++; + continue; + } + + ggml_backend_t backend = ggml_backend_dev_init(dev, NULL); + GGML_ASSERT(backend != NULL); + + ggml_backend_reg_t reg = ggml_backend_dev_backend_reg(dev); + auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads"); + if (ggml_backend_set_n_threads_fn) { + ggml_backend_set_n_threads_fn(backend, N_THREADS); + } + + size_t free, total_mem; + ggml_backend_dev_memory(dev, &free, &total_mem); + output_printer->print_backend_init(backend_init_info(bi, plan.backends.size(), bp.backend_name.c_str(), + false, "", ggml_backend_dev_description(dev), + total_mem / 1024 / 1024, free / 1024 / 1024, true)); + + std::vector> cases; + for (auto & tc : bp.test_cases) { + cases.push_back(std::move(tc)); + } + + bool ok = test_backend(backend, MODE_PERF, op_names_filter, params_filter, + output_printer.get(), nullptr, std::move(cases)); + + if (ok) { + n_ok++; + } + output_printer->print_backend_status( + backend_status_info(ggml_backend_name(backend), ok ? test_status_t::OK : test_status_t::FAIL)); + + ggml_backend_free(backend); + } + + ggml_quantize_free(); + + if (output_printer) { + output_printer->print_footer(); + } + + output_printer->print_overall_summary( + overall_summary_info(n_ok, total, n_ok == total)); + + return n_ok != total; + } + size_t n_ok = 0; for (size_t i = 0; i < ggml_backend_dev_count(); i++) { @@ -9379,7 +9897,8 @@ int main(int argc, char ** argv) { continue; } - if (backend_filter == NULL && ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU && mode != MODE_GRAD) { + if (backend_filter == NULL && + ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU && mode != MODE_GRAD) { output_printer->print_backend_init(backend_init_info( i, ggml_backend_dev_count(), ggml_backend_dev_name(dev), true, "Skipping CPU backend")); n_ok++; @@ -9392,11 +9911,10 @@ int main(int argc, char ** argv) { ggml_backend_reg_t reg = ggml_backend_dev_backend_reg(dev); auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads"); if (ggml_backend_set_n_threads_fn) { - // TODO: better value for n_threads ggml_backend_set_n_threads_fn(backend, N_THREADS); } - size_t free, total; // NOLINT + size_t free, total; ggml_backend_dev_memory(dev, &free, &total); output_printer->print_backend_init(backend_init_info(i, ggml_backend_dev_count(), ggml_backend_dev_name(dev), false, "", ggml_backend_dev_description(dev), diff --git a/tools/profiler/profiler.py b/tools/profiler/profiler.py index a739a830f4..0cc3f81972 100644 --- a/tools/profiler/profiler.py +++ b/tools/profiler/profiler.py @@ -20,6 +20,35 @@ COPY_EVENT = 1 TYPE_NAMES = {0: "OP", 1: "COPY"} +GGML_TYPE_NAMES = { + 0: "F32", 1: "F16", 2: "Q4_0", 3: "Q4_1", + # 4, 5 removed + 6: "Q5_0", 7: "Q5_1", 8: "Q8_0", 9: "Q8_1", + 10: "Q2_K", 11: "Q3_K", 12: "Q4_K", 13: "Q5_K", + 14: "Q6_K", 15: "Q8_K", 16: "IQ2_XXS", 17: "IQ2_XS", + 18: "IQ3_XXS", 19: "IQ1_S", 20: "IQ4_NL", 21: "IQ3_S", + 22: "IQ2_S", 23: "IQ4_XS", 24: "I8", 25: "I16", 26: "I32", + 27: "I64", 28: "F64", 29: "IQ1_M", 30: "BF16", + # 31-33 removed + 34: "TQ1_0", 35: "TQ2_0", + # 36-38 removed + 39: "MXFP4", 40: "NVFP4", +} + +GGML_UNARY_OP_NAMES = { + 0: "ABS", 1: "SGN", 2: "NEG", 3: "STEP", + 4: "TANH", 5: "ELU", 6: "RELU", 7: "SIGMOID", + 8: "GELU", 9: "GELU_QUICK", 10: "SILU", 11: "HARDSWISH", + 12: "HARDSIGMOID", 13: "EXP", 14: "EXPM1", 15: "SOFTPLUS", + 16: "GELU_ERF", 17: "XIELU", 18: "FLOOR", 19: "CEIL", + 20: "ROUND", 21: "TRUNC", +} + +GGML_GLU_OP_NAMES = { + 0: "REGLU", 1: "GEGLU", 2: "SWIGLU", 3: "GEGLU_ERF", + 4: "GEGLU_QUICK", 5: "SWIGLU_OAI", +} + @dataclass class ProfileRecord: @@ -34,6 +63,20 @@ class ProfileRecord: ne_src0: list[int] = field(default_factory=lambda: [0, 0, 0, 0]) ne_src1: list[int] = field(default_factory=lambda: [0, 0, 0, 0]) ne_src2: list[int] = field(default_factory=lambda: [0, 0, 0, 0]) + type_src0: int = -1 + type_src1: int = -1 + type_src2: int = -1 + sub_op: int = -1 + + @property + def sub_op_name(self) -> str: + if self.sub_op < 0: + return "" + if self.name == "UNARY": + return GGML_UNARY_OP_NAMES.get(self.sub_op, f"UNARY_OP({self.sub_op})") + if self.name == "GLU": + return GGML_GLU_OP_NAMES.get(self.sub_op, f"GLU_OP({self.sub_op})") + return str(self.sub_op) @property def type_name(self) -> str: @@ -64,11 +107,20 @@ class ProfileRecord: @property def shape_str(self) -> str: """Human-readable tensor shapes, e.g. '[4096, 4096] x [4096, 1] x [8, 1]'.""" - s0 = self._fmt_ne(self.ne_src0) - s1 = self._fmt_ne(self.ne_src1) - s2 = self._fmt_ne(self.ne_src2) - parts = [s for s in (s0, s1, s2) if s] - return " x ".join(parts) + parts = [] + for ne, gt in [(self.ne_src0, self.type_src0), + (self.ne_src1, self.type_src1), + (self.ne_src2, self.type_src2)]: + s = self._fmt_ne(ne) + if s: + type_name = GGML_TYPE_NAMES.get(gt, None) + if type_name: + s = f"{s} ({type_name})" + parts.append(s) + result = " x ".join(parts) + if self.sub_op_name: + result = f"[{self.sub_op_name}] {result}" + return result def to_dict(self) -> dict: return { @@ -83,6 +135,10 @@ class ProfileRecord: "ne_src0": self.ne_src0, "ne_src1": self.ne_src1, "ne_src2": self.ne_src2, + "type_src0": self.type_src0, + "type_src1": self.type_src1, + "type_src2": self.type_src2, + "sub_op": self.sub_op, } @@ -175,6 +231,10 @@ class ProfileData: ne_src0=ne_src0, ne_src1=ne_src1, ne_src2=ne_src2, + type_src0=r.get("type_src0", -1), + type_src1=r.get("type_src1", -1), + type_src2=r.get("type_src2", -1), + sub_op=r.get("sub_op", -1), )) backends_raw = data.get("backends", [])