diff --git a/common/arg.cpp b/common/arg.cpp index d1d437959b..efcc8d7079 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -2693,6 +2693,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex } ).set_examples({LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_CVECTOR_GENERATOR, LLAMA_EXAMPLE_EXPORT_LORA, LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_FINETUNE, LLAMA_EXAMPLE_RESULTS, LLAMA_EXAMPLE_EXPORT_GRAPH_OPS})); + add_opt(common_arg( + {"--with-backends"}, + "export graph ops with backend assignments (default: CPU only)", + [](common_params & params) { + params.with_backends = true; + } + ).set_examples({LLAMA_EXAMPLE_EXPORT_GRAPH_OPS})); add_opt(common_arg( {"-ofreq", "--output-frequency"}, "N", string_format("output the imatrix every N iterations (default: %d)", params.n_out_freq), diff --git a/common/common.h b/common/common.h index 20907cd212..c7c7a4a34c 100644 --- a/common/common.h +++ b/common/common.h @@ -438,6 +438,7 @@ struct common_params { int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs bool fit_params = true; // whether to fit unset model/context parameters to free device memory + bool with_backends = false; // export graph ops with backend assignments int32_t fit_params_min_ctx = 4096; // minimum context size to set when trying to reduce memory use // margin per device in bytes for fitting parameters to free memory: diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index 71ea153562..bdc6612c61 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -1526,7 +1526,7 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s sched->copy_records.push_back({ GGML_PROFILE_EVENT_COPY, copy_dir, split_backend_id, split_id, copy_start, copy_end, ggml_nbytes(input), input->name, - {input->ne[0], input->ne[1], input->ne[2], input->ne[3]}, {0}, {0} }); + {input->ne[0], input->ne[1], input->ne[2], input->ne[3]}, {0}, {0}, -1, -1, -1, -1 }); } else { ggml_backend_tensor_copy(input, input_cpy); } @@ -1647,7 +1647,7 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s sched->copy_records.push_back({ GGML_PROFILE_EVENT_COPY, copy_dir, split_backend_id, split_id, moe_copy_start, moe_copy_end, (uint64_t) total_copied_bytes, input->name, - {input->ne[0], input->ne[1], input->ne[2], input->ne[3]}, {0}, {0} }); + {input->ne[0], input->ne[1], input->ne[2], input->ne[3]}, {0}, {0}, -1, -1, -1, -1 }); } } else { // try async copy, but if not possible, we can still use a sync copy without synchronizing the dst backend, since we handle the synchronization here with multiple copies and events @@ -1684,7 +1684,7 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s sched->copy_records.push_back({ GGML_PROFILE_EVENT_COPY, copy_dir, split_backend_id, split_id, copy_start, copy_end, ggml_nbytes(input), input->name, - {input->ne[0], input->ne[1], input->ne[2], input->ne[3]}, {0}, {0} }); + {input->ne[0], input->ne[1], input->ne[2], input->ne[3]}, {0}, {0}, -1, -1, -1, -1 }); } else { ggml_backend_tensor_copy(input, input_cpy); } @@ -1705,7 +1705,7 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s sched->copy_records.push_back({ GGML_PROFILE_EVENT_COPY, copy_dir, split_backend_id, split_id, copy_start, copy_end, ggml_nbytes(input), input->name, - {input->ne[0], input->ne[1], input->ne[2], input->ne[3]}, {0}, {0} }); + {input->ne[0], input->ne[1], input->ne[2], input->ne[3]}, {0}, {0}, -1, -1, -1, -1 }); } } } diff --git a/tests/export-graph-ops.cpp b/tests/export-graph-ops.cpp index 64cf6dcea3..3628afa325 100644 --- a/tests/export-graph-ops.cpp +++ b/tests/export-graph-ops.cpp @@ -5,7 +5,6 @@ #include "../src/llama-ext.h" #include "ggml.h" #include "gguf-model-data.h" -#include "gguf.h" #include "ggml-backend.h" #include "download.h" @@ -14,7 +13,6 @@ #include #include #include -#include // Noop because weights are not needed static void set_tensor_data(struct ggml_tensor * tensor, void * userdata) { @@ -55,6 +53,7 @@ struct test_object { std::vector op_params; std::vector sources; std::string name; + std::string backend_name; void serialize(std::ostream& out) const { out << op << ' ' << type << ' '; @@ -78,16 +77,21 @@ struct test_object { out << '-'; } + if (!backend_name.empty()) { + out << ' ' << backend_name; + } + out << '\n'; } bool operator<(const test_object &b) const { - return std::tie(op, type, ne, op_params, sources) < - std::tie(b.op, b.type, b.ne, b.op_params, b.sources); + return std::tie(op, type, ne, op_params, sources, backend_name) < + std::tie(b.op, b.type, b.ne, b.op_params, b.sources, b.backend_name); } }; -static void extract_graph_ops(ggml_cgraph * cgraph, const char * label, std::set & tests) { +static void extract_graph_ops(ggml_cgraph * cgraph, const char * label, std::set & tests, + ggml_backend_sched_t sched = nullptr) { int n_nodes = ggml_graph_n_nodes(cgraph); int n_skipped = 0; int n_before = (int) tests.size(); @@ -117,6 +121,14 @@ static void extract_graph_ops(ggml_cgraph * cgraph, const char * label, std::set } test.name = node->name; + + if (sched) { + ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, node); + if (backend) { + test.backend_name = ggml_backend_name(backend); + } + } + tests.insert(test); } @@ -135,11 +147,12 @@ int main(int argc, char ** argv) { return 1; } - // Load CPU-only - ggml_backend_dev_t cpu_device = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); - params.devices = { cpu_device, nullptr }; - params.fit_params = false; - params.n_gpu_layers = 0; + if (!params.with_backends) { + ggml_backend_dev_t cpu_device = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + params.devices = { cpu_device, nullptr }; + params.fit_params = false; + params.n_gpu_layers = 0; + } params.warmup = false; @@ -195,19 +208,21 @@ int main(int argc, char ** argv) { std::set tests; + ggml_backend_sched_t sched = params.with_backends ? llama_context_get_sched(ctx) : nullptr; + auto * gf_pp = llama_graph_reserve(ctx, n_tokens, n_seqs, n_tokens); if (!gf_pp) { LOG_ERR("failed to reserve prompt processing graph\n"); return 1; } - extract_graph_ops(gf_pp, "pp", tests); + extract_graph_ops(gf_pp, "pp", tests, sched); auto * gf_tg = llama_graph_reserve(ctx, n_seqs, n_seqs, n_seqs); if (!gf_tg) { LOG_ERR("failed to reserve token generation graph\n"); return 1; } - extract_graph_ops(gf_tg, "tg", tests); + extract_graph_ops(gf_tg, "tg", tests, sched); LOG_INF("%d unique ops total\n", (int) tests.size()); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 8eca58207d..f013960790 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -20,8 +20,6 @@ #include #include -#include - #include #include #include @@ -8962,7 +8960,7 @@ static std::vector> make_test_cases_perf() { return test_cases; } -static std::vector> make_test_cases_from_file(const char * path) { +static std::vector> make_test_cases_from_file(const char * path, const char * backend_name = nullptr) { std::ifstream f(path); if (!f.is_open()) { @@ -9020,20 +9018,25 @@ static std::vector> make_test_cases_from_file(const c name = ""; } + std::string file_backend; + if (iss >> file_backend) { + if (file_backend.length() == 1 && file_backend[0] == '-') { + file_backend = ""; + } + } + + if (backend_name != nullptr && !file_backend.empty() && file_backend != backend_name) { + continue; + } + test_cases.emplace_back(new test_generic_op(op, type, ne, op_params, sources, std::move(name))); } return test_cases; } -struct profile_test_plan; - -static profile_test_plan make_test_plan_from_profile( - const char * profile_path, int top_n); - static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_names_filter, const char * params_filter, - printer * output_printer, const char * test_file_path, - std::vector> profile_test_cases = {}) { + printer * output_printer, const char * test_file_path) { auto filter_test_cases = [](std::vector> & test_cases, const char * params_filter) { if (params_filter == nullptr) { return; @@ -9054,22 +9057,18 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op std::vector> test_cases; if (test_file_path == nullptr) { - if (!profile_test_cases.empty()) { - test_cases = std::move(profile_test_cases); - } else { - switch (mode) { - case MODE_TEST: - case MODE_GRAD: - case MODE_SUPPORT: - test_cases = make_test_cases_eval(); - break; - case MODE_PERF: - test_cases = make_test_cases_perf(); - break; - } + switch (mode) { + case MODE_TEST: + case MODE_GRAD: + case MODE_SUPPORT: + test_cases = make_test_cases_eval(); + break; + case MODE_PERF: + test_cases = make_test_cases_perf(); + break; } } else { - test_cases = make_test_cases_from_file(test_file_path); + test_cases = make_test_cases_from_file(test_file_path, ggml_backend_name(backend)); } filter_test_cases(test_cases, params_filter); @@ -9252,7 +9251,7 @@ static void show_test_coverage() { static void usage(char ** argv) { printf("Usage: %s [mode] [-o ] [-b ] [-p ] [--output ] [--list-ops]", argv[0]); - printf(" [--show-coverage] [--test-file ] [--from-profile ] [--top-n ]\n"); + printf(" [--show-coverage] [--test-file ]\n"); printf(" valid modes:\n"); printf(" - test (default, compare with CPU backend for correctness)\n"); printf(" - grad (compare gradients from backpropagation with method of finite differences)\n"); @@ -9263,436 +9262,31 @@ static void usage(char ** argv) { printf(" --output specifies output format (default: console, options: console, sql, csv)\n"); printf(" --list-ops lists all available GGML operations\n"); printf(" --show-coverage shows test coverage\n"); - printf(" --test-file reads test operators from a test file generated by llama-export-graph-ops\n"); -} - -// ############################## -// ## Profiler-based perf ## -// ############################## - -static ggml_op profile_name_to_op(const std::string & name) { - static const std::unordered_map map = { - {"ADD", GGML_OP_ADD}, - {"ADD1", GGML_OP_ADD1}, - {"ARGSORT", GGML_OP_ARGSORT}, - {"CLAMP", GGML_OP_CLAMP}, - {"CONCAT", GGML_OP_CONCAT}, - {"CONT", GGML_OP_CONT}, - {"CPY", GGML_OP_CPY}, - {"DIV", GGML_OP_DIV}, - {"FLASH_ATTN_EXT", GGML_OP_FLASH_ATTN_EXT}, - {"GET_ROWS", GGML_OP_GET_ROWS}, - {"GET_ROWS_BACK", GGML_OP_GET_ROWS_BACK}, - {"GLU", GGML_OP_GLU}, - {"IM2COL_BACK", GGML_OP_IM2COL_BACK}, - {"MUL", GGML_OP_MUL}, - {"MUL_MAT", GGML_OP_MUL_MAT}, - {"MUL_MAT_ID", GGML_OP_MUL_MAT_ID}, - {"OUT_PROD", GGML_OP_OUT_PROD}, - {"POOL_2D", GGML_OP_POOL_2D}, - {"RMS_NORM", GGML_OP_RMS_NORM}, - {"SCALE", GGML_OP_SCALE}, - {"SET_ROWS", GGML_OP_SET_ROWS}, - {"SQR", GGML_OP_SQR}, - {"SSM_CONV", GGML_OP_SSM_CONV}, - {"SSM_SCAN", GGML_OP_SSM_SCAN}, - {"SUM_ROWS", GGML_OP_SUM_ROWS}, - {"UNARY", GGML_OP_UNARY}, - {"SOFT_MAX", GGML_OP_SOFT_MAX}, - }; - auto it = map.find(name); - if (it != map.end()) { - return it->second; - } - return GGML_OP_COUNT; -} - -static bool compute_output_ne(ggml_op op, - const int64_t ne0[4], const int64_t ne1[4], const int64_t ne2[4], - int64_t ne_out[4]) { - ne_out[0] = ne_out[1] = ne_out[2] = ne_out[3] = 0; - switch (op) { - case GGML_OP_MUL_MAT: - ne_out[0] = ne0[1]; ne_out[1] = ne1[1]; - ne_out[2] = std::max(ne0[2], ne1[2]); - ne_out[3] = std::max(ne0[3], ne1[3]); - return true; - case GGML_OP_MUL_MAT_ID: - ne_out[0] = ne0[1]; ne_out[1] = ne2[0]; - ne_out[2] = ne1[2]; ne_out[3] = 1; - return true; - case GGML_OP_ADD: - case GGML_OP_MUL: - case GGML_OP_DIV: - case GGML_OP_SCALE: - for (int i = 0; i < 4; i++) { - ne_out[i] = std::max(ne0[i], ne1[i]); - } - return true; - case GGML_OP_ADD1: - ne_out[0] = ne0[0]; ne_out[1] = ne0[1]; - ne_out[2] = ne0[2]; ne_out[3] = ne0[3]; - return true; - case GGML_OP_SQR: - case GGML_OP_UNARY: - case GGML_OP_SSM_SCAN: - for (int i = 0; i < 4; i++) { - ne_out[i] = ne0[i]; - } - return true; - case GGML_OP_SOFT_MAX: - ne_out[0] = ne0[0]; ne_out[1] = ne0[1]; - ne_out[2] = ne0[2]; ne_out[3] = ne0[3]; - return true; - case GGML_OP_RMS_NORM: - ne_out[0] = ne0[0]; ne_out[1] = ne0[1]; - ne_out[2] = ne0[2]; ne_out[3] = ne0[3]; - return true; - case GGML_OP_FLASH_ATTN_EXT: - ne_out[0] = ne1[1]; ne_out[1] = ne1[1]; - ne_out[2] = ne0[2]; ne_out[3] = ne0[3]; - return true; - case GGML_OP_GET_ROWS: - ne_out[0] = ne0[0]; ne_out[1] = ne1[1]; - ne_out[2] = ne1[2]; ne_out[3] = ne1[3]; - return true; - case GGML_OP_GET_ROWS_BACK: - ne_out[0] = ne0[0]; ne_out[1] = ne1[1]; - ne_out[2] = ne1[2]; ne_out[3] = ne1[3]; - return true; - case GGML_OP_SET_ROWS: - for (int i = 0; i < 4; i++) { - ne_out[i] = ne0[i]; - } - return true; - case GGML_OP_OUT_PROD: - ne_out[0] = ne0[0]; ne_out[1] = ne1[0]; - ne_out[2] = std::max(ne0[2], ne1[2]); - ne_out[3] = std::max(ne0[3], ne1[3]); - return true; - case GGML_OP_CONCAT: - ne_out[0] = ne0[0] + ne1[0]; - ne_out[1] = std::max(ne0[1], ne1[1]); - ne_out[2] = std::max(ne0[2], ne1[2]); - ne_out[3] = std::max(ne0[3], ne1[3]); - return true; - case GGML_OP_ARGSORT: - for (int i = 0; i < 4; i++) { - ne_out[i] = ne0[i]; - } - return true; - case GGML_OP_CLAMP: - for (int i = 0; i < 4; i++) { - ne_out[i] = ne0[i]; - } - return true; - case GGML_OP_CPY: - for (int i = 0; i < 4; i++) { - ne_out[i] = ne0[i]; - } - return true; - case GGML_OP_POOL_2D: - for (int i = 0; i < 4; i++) { - ne_out[i] = ne0[i]; - } - return true; - case GGML_OP_SSM_CONV: - for (int i = 0; i < 4; i++) { - ne_out[i] = ne0[i]; - } - return true; - case GGML_OP_IM2COL_BACK: - for (int i = 0; i < 4; i++) { - ne_out[i] = ne0[i]; - } - return true; - default: - return false; - } -} - -static std::vector json_get_ne(const nlohmann::json & arr) { - std::vector ne(4, 0); - if (!arr.is_array()) return ne; - for (size_t i = 0; i < std::min(arr.size(), (size_t)4); i++) { - ne[i] = arr[i].get(); - } - return ne; -} - -static bool ne_is_zero(const std::vector & ne) { - for (auto v : ne) if (v != 0) return false; - return true; -} - -struct profile_op_key { - std::string name; - int backend_id; - int type_src0; - int type_src1; - int type_src2; - int sub_op; - std::vector ne_src0; - std::vector ne_src1; - std::vector ne_src2; - - bool operator==(const profile_op_key & o) const { - return name == o.name && backend_id == o.backend_id && - type_src0 == o.type_src0 && type_src1 == o.type_src1 && type_src2 == o.type_src2 && - sub_op == o.sub_op && - ne_src0 == o.ne_src0 && ne_src1 == o.ne_src1 && ne_src2 == o.ne_src2; - } -}; - -struct profile_op_key_hash { - size_t operator()(const profile_op_key & k) const { - size_t h = std::hash{}(k.name); - h ^= std::hash{}(k.backend_id) + 0x9e3779b9 + (h << 6) + (h >> 2); - h ^= std::hash{}(k.type_src0) + 0x9e3779b9 + (h << 6) + (h >> 2); - h ^= std::hash{}(k.type_src1) + 0x9e3779b9 + (h << 6) + (h >> 2); - h ^= std::hash{}(k.type_src2) + 0x9e3779b9 + (h << 6) + (h >> 2); - h ^= std::hash{}(k.sub_op) + 0x9e3779b9 + (h << 6) + (h >> 2); - for (auto v : k.ne_src0) { h ^= std::hash{}(v) + 0x9e3779b9 + (h << 6) + (h >> 2); } - for (auto v : k.ne_src1) { h ^= std::hash{}(v) + 0x9e3779b9 + (h << 6) + (h >> 2); } - for (auto v : k.ne_src2) { h ^= std::hash{}(v) + 0x9e3779b9 + (h << 6) + (h >> 2); } - return h; - } -}; - -struct profile_op_agg { - profile_op_key key; - uint64_t total_ns; - int64_t count; - double max_ns; -}; - -struct profile_test_plan { - struct backend_plan { - int backend_id; - std::string backend_name; - std::vector> test_cases; - std::vector aggs; - }; - std::vector backends; -}; - -static profile_test_plan make_test_plan_from_profile( - const char * profile_path, int top_n) { - using json = nlohmann::json; - profile_test_plan plan; - - std::ifstream f(profile_path); - if (!f.is_open()) { - fprintf(stderr, "Error: cannot open profile file: %s\n", profile_path); - return plan; - } - - json root; - try { - root = json::parse(f); - } catch (const json::parse_error & e) { - fprintf(stderr, "Error: failed to parse profile JSON: %s\n", e.what()); - return plan; - } - - if (!root.contains("records") || !root["records"].is_array()) { - fprintf(stderr, "Error: no 'records' array found in profile\n"); - return plan; - } - - std::unordered_map backend_names; - if (root.contains("backends") && root["backends"].is_array()) { - for (const auto & be : root["backends"]) { - int id = be.value("id", -1); - std::string name = be.value("name", ""); - if (id >= 0 && !name.empty()) { - backend_names[id] = name; - } - } - } - - const auto & records = root["records"]; - std::unordered_map aggs; - - for (const auto & rec : records) { - int rec_type = rec.value("type", -1); - if (rec_type != 0) continue; - - std::string name = rec.value("name", ""); - ggml_op op = profile_name_to_op(name); - if (op == GGML_OP_COUNT) continue; - - profile_op_key key; - key.name = name; - key.backend_id = rec.value("backend_id", 0); - key.type_src0 = rec.value("type_src0", -1); - key.type_src1 = rec.value("type_src1", -1); - key.type_src2 = rec.value("type_src2", -1); - key.sub_op = rec.value("sub_op", -1); - key.ne_src0 = json_get_ne(rec.value("ne_src0", json::array())); - key.ne_src1 = json_get_ne(rec.value("ne_src1", json::array())); - key.ne_src2 = json_get_ne(rec.value("ne_src2", json::array())); - - uint64_t dur = rec.value("duration_ns", (uint64_t)0); - - auto & agg = aggs[key]; - agg.key = key; - agg.total_ns += dur; - agg.count++; - agg.max_ns = std::max(agg.max_ns, (double)dur); - } - - std::vector sorted; - sorted.reserve(aggs.size()); - for (auto & [_, agg] : aggs) { - sorted.push_back(std::move(agg)); - } - std::sort(sorted.begin(), sorted.end(), - [](const profile_op_agg & a, const profile_op_agg & b) { - return a.total_ns > b.total_ns; - }); - - uint64_t global_max_ns = sorted.empty() ? 1 : sorted[0].total_ns; - - int n = std::min(top_n, (int)sorted.size()); - if (n == 0) { - fprintf(stderr, "Warning: no matching OP records found in profile\n"); - return plan; - } - - auto make_src = [](int type_id, const std::vector & ne) -> input_tensor { - input_tensor src; - src.type = (type_id >= 0) ? (ggml_type)type_id : GGML_TYPE_F32; - for (int d = 0; d < 4; d++) { - src.ne[d] = d < (int)ne.size() ? ne[d] : 0; - src.nb[d] = 0; - } - return src; - }; - - auto make_test_from_agg = [&](const profile_op_agg & agg) -> std::unique_ptr { - ggml_op op = profile_name_to_op(agg.key.name); - - std::vector sources; - if (!ne_is_zero(agg.key.ne_src0)) { - sources.push_back(make_src(agg.key.type_src0, agg.key.ne_src0)); - } - if (!ne_is_zero(agg.key.ne_src1)) { - sources.push_back(make_src(agg.key.type_src1, agg.key.ne_src1)); - } - if (op == GGML_OP_MUL_MAT_ID) { - if (!ne_is_zero(agg.key.ne_src2)) { - sources.push_back(make_src(agg.key.type_src2, agg.key.ne_src2)); - } else if (sources.size() >= 2) { - input_tensor src; - src.type = GGML_TYPE_I32; - src.ne[0] = sources[1].ne[1]; - src.ne[1] = 1; - src.ne[2] = 1; - src.ne[3] = 1; - src.nb[0] = src.nb[1] = src.nb[2] = src.nb[3] = 0; - sources.push_back(src); - } - } - - int64_t ne0[4] = {0}, ne1[4] = {0}, ne2[4] = {0}; - if (sources.size() > 0) { for (int d = 0; d < 4; d++) ne0[d] = sources[0].ne[d]; } - if (sources.size() > 1) { for (int d = 0; d < 4; d++) ne1[d] = sources[1].ne[d]; } - if (sources.size() > 2) { for (int d = 0; d < 4; d++) ne2[d] = sources[2].ne[d]; } - - int64_t ne_out[4] = {0, 0, 0, 0}; - if (!compute_output_ne(op, ne0, ne1, ne2, ne_out)) { - return nullptr; - } - - ggml_type out_type = GGML_TYPE_F32; - std::array op_params{}; - op_params.fill(0); - - if (op == GGML_OP_MUL_MAT_ID && sources.size() >= 2) { - op_params[0] = (int32_t)sources[1].ne[1]; - } else if ((op == GGML_OP_UNARY || op == GGML_OP_GLU) && agg.key.sub_op >= 0) { - op_params[0] = (int32_t)agg.key.sub_op; - } - - std::array out_ne; - for (int d = 0; d < 4; d++) out_ne[d] = ne_out[d]; - - return std::unique_ptr(new test_generic_op(op, out_type, out_ne, op_params, sources, - agg.key.name + " [from profile]")); - }; - - printf(" Loaded %d profiler ops, running top %d:\n", (int)sorted.size(), n); - for (int i = 0; i < n; i++) { - const auto & agg = sorted[i]; - double pct = 100.0 * agg.total_ns / global_max_ns; - int bid = agg.key.backend_id; - std::string bname = backend_names.count(bid) ? backend_names[bid] : std::to_string(bid); - printf(" #%d: %s @ %s %ldx %.3fms total (%.1f%% of top)\n", - i + 1, agg.key.name.c_str(), bname.c_str(), agg.count, - agg.total_ns / 1e6, pct); - if (!ne_is_zero(agg.key.ne_src0)) { - const char * tn = agg.key.type_src0 >= 0 ? ggml_type_name((ggml_type)agg.key.type_src0) : "?"; - printf(" src0: [%lld, %lld, %lld, %lld] (%s)\n", - (long long)agg.key.ne_src0[0], (long long)agg.key.ne_src0[1], - (long long)agg.key.ne_src0[2], (long long)agg.key.ne_src0[3], tn); - } - if (!ne_is_zero(agg.key.ne_src1)) { - const char * tn = agg.key.type_src1 >= 0 ? ggml_type_name((ggml_type)agg.key.type_src1) : "?"; - printf(" src1: [%lld, %lld, %lld, %lld] (%s)\n", - (long long)agg.key.ne_src1[0], (long long)agg.key.ne_src1[1], - (long long)agg.key.ne_src1[2], (long long)agg.key.ne_src1[3], tn); - } - if (!ne_is_zero(agg.key.ne_src2)) { - const char * tn = agg.key.type_src2 >= 0 ? ggml_type_name((ggml_type)agg.key.type_src2) : "?"; - printf(" src2: [%lld, %lld, %lld, %lld] (%s)\n", - (long long)agg.key.ne_src2[0], (long long)agg.key.ne_src2[1], - (long long)agg.key.ne_src2[2], (long long)agg.key.ne_src2[3], tn); - } - - auto tc = make_test_from_agg(agg); - if (!tc) continue; - - int bid2 = agg.key.backend_id; - std::string bname2 = backend_names.count(bid2) ? backend_names[bid2] : std::to_string(bid2); - - profile_test_plan::backend_plan * bp = nullptr; - for (auto & b : plan.backends) { - if (b.backend_id == bid2 && b.backend_name == bname2) { - bp = &b; - break; - } - } - if (!bp) { - plan.backends.push_back({bid2, bname2, {}, {}}); - bp = &plan.backends.back(); - } - bp->test_cases.push_back(std::move(tc)); - bp->aggs.push_back(agg); - } - - return plan; + printf(" --test-file reads test operators from a test file generated by llama-export-graph-ops or the profiler\n"); } int main(int argc, char ** argv) { test_mode mode = MODE_TEST; + bool mode_explicit = false; output_formats output_format = CONSOLE; const char * op_names_filter = nullptr; const char * backend_filter = nullptr; const char * params_filter = nullptr; const char * test_file_path = nullptr; - const char * profile_path = nullptr; - int profile_top_n = 10; for (int i = 1; i < argc; i++) { if (strcmp(argv[i], "test") == 0) { mode = MODE_TEST; + mode_explicit = true; } else if (strcmp(argv[i], "perf") == 0) { mode = MODE_PERF; + mode_explicit = true; } else if (strcmp(argv[i], "grad") == 0) { mode = MODE_GRAD; + mode_explicit = true; } else if (strcmp(argv[i], "support") == 0) { mode = MODE_SUPPORT; + mode_explicit = true; } else if (strcmp(argv[i], "-o") == 0) { if (i + 1 < argc) { op_names_filter = argv[++i]; @@ -9737,34 +9331,19 @@ int main(int argc, char ** argv) { usage(argv); return 1; } - } else if (strcmp(argv[i], "--from-profile") == 0) { - if (i + 1 < argc) { - profile_path = argv[++i]; - } else { - usage(argv); - return 1; - } - } else if (strcmp(argv[i], "--top-n") == 0) { - if (i + 1 < argc) { - profile_top_n = atoi(argv[++i]); - if (profile_top_n <= 0) profile_top_n = 10; - } else { - usage(argv); - return 1; - } } else { usage(argv); return 1; } } - // load and enumerate backends - ggml_backend_load_all(); - - if (profile_path != nullptr) { + if (test_file_path != nullptr && !mode_explicit) { mode = MODE_PERF; } + // load and enumerate backends + ggml_backend_load_all(); + // Create printer for output format std::unique_ptr output_printer = create_printer(output_format); if (output_printer) { @@ -9773,83 +9352,6 @@ int main(int argc, char ** argv) { output_printer->print_testing_start(testing_start_info(ggml_backend_dev_count())); - if (profile_path != nullptr) { - profile_test_plan plan = make_test_plan_from_profile(profile_path, profile_top_n); - - size_t n_ok = 0; - size_t total = plan.backends.size(); - - for (size_t bi = 0; bi < plan.backends.size(); bi++) { - auto & bp = plan.backends[bi]; - - ggml_backend_dev_t dev = nullptr; - for (size_t i = 0; i < ggml_backend_dev_count(); i++) { - ggml_backend_dev_t d = ggml_backend_dev_get(i); - if (strcmp(ggml_backend_dev_name(d), bp.backend_name.c_str()) == 0) { - dev = d; - break; - } - } - - if (dev == nullptr) { - fprintf(stderr, "Warning: backend '%s' from profile not found, skipping\n", bp.backend_name.c_str()); - n_ok++; - output_printer->print_backend_init( - backend_init_info(bi, total, bp.backend_name.c_str(), true, "Not found")); - continue; - } - - if (backend_filter != NULL && strcmp(backend_filter, bp.backend_name.c_str()) != 0) { - output_printer->print_backend_init( - backend_init_info(bi, total, bp.backend_name.c_str(), true, "Skipping")); - n_ok++; - continue; - } - - ggml_backend_t backend = ggml_backend_dev_init(dev, NULL); - GGML_ASSERT(backend != NULL); - - ggml_backend_reg_t reg = ggml_backend_dev_backend_reg(dev); - auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads"); - if (ggml_backend_set_n_threads_fn) { - ggml_backend_set_n_threads_fn(backend, N_THREADS); - } - - size_t free, total_mem; - ggml_backend_dev_memory(dev, &free, &total_mem); - output_printer->print_backend_init(backend_init_info(bi, plan.backends.size(), bp.backend_name.c_str(), - false, "", ggml_backend_dev_description(dev), - total_mem / 1024 / 1024, free / 1024 / 1024, true)); - - std::vector> cases; - for (auto & tc : bp.test_cases) { - cases.push_back(std::move(tc)); - } - - bool ok = test_backend(backend, MODE_PERF, op_names_filter, params_filter, - output_printer.get(), nullptr, std::move(cases)); - - if (ok) { - n_ok++; - } - output_printer->print_backend_status( - backend_status_info(ggml_backend_name(backend), ok ? test_status_t::OK : test_status_t::FAIL)); - - ggml_backend_free(backend); - } - - ggml_quantize_free(); - - if (output_printer) { - output_printer->print_footer(); - } - - output_printer->print_overall_summary( - overall_summary_info(n_ok, total, n_ok == total)); - - return n_ok != total; - } - size_t n_ok = 0; for (size_t i = 0; i < ggml_backend_dev_count(); i++) { @@ -9862,7 +9364,7 @@ int main(int argc, char ** argv) { continue; } - if (backend_filter == NULL && + if (backend_filter == NULL && test_file_path == NULL && ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU && mode != MODE_GRAD) { output_printer->print_backend_init(backend_init_info( i, ggml_backend_dev_count(), ggml_backend_dev_name(dev), true, "Skipping CPU backend")); diff --git a/tools/profiler/profiler.py b/tools/profiler/profiler.py index 0cc3f81972..1a8d3bd550 100644 --- a/tools/profiler/profiler.py +++ b/tools/profiler/profiler.py @@ -49,6 +49,107 @@ GGML_GLU_OP_NAMES = { 4: "GEGLU_QUICK", 5: "SWIGLU_OAI", } +GGML_OP_NAMES = { + 0: "NONE", 1: "DUP", 2: "ADD", 3: "ADD_ID", 4: "ADD1", + 5: "ACC", 6: "SUB", 7: "MUL", 8: "DIV", 9: "SQR", + 10: "SQRT", 11: "LOG", 12: "SIN", 13: "COS", 14: "SUM", + 15: "SUM_ROWS", 16: "CUMSUM", 17: "MEAN", 18: "ARGMAX", + 19: "COUNT_EQUAL", 20: "REPEAT", 21: "REPEAT_BACK", 22: "CONCAT", + 23: "SILU_BACK", 24: "NORM", 25: "RMS_NORM", 26: "RMS_NORM_BACK", + 27: "GROUP_NORM", 28: "L2_NORM", 29: "MUL_MAT", 30: "MUL_MAT_ID", + 31: "OUT_PROD", 32: "SCALE", 33: "SET", 34: "CPY", 35: "CONT", + 36: "RESHAPE", 37: "VIEW", 38: "PERMUTE", 39: "TRANSPOSE", + 40: "GET_ROWS", 41: "GET_ROWS_BACK", 42: "SET_ROWS", 43: "DIAG", + 44: "DIAG_MASK_INF", 45: "DIAG_MASK_ZERO", 46: "SOFT_MAX", + 47: "SOFT_MAX_BACK", 48: "ROPE", 49: "ROPE_BACK", 50: "CLAMP", + 51: "CONV_TRANSPOSE_1D", 52: "IM2COL", 53: "IM2COL_BACK", 54: "IM2COL_3D", + 55: "CONV_2D", 56: "CONV_3D", 57: "CONV_2D_DW", 58: "CONV_TRANSPOSE_2D", + 59: "POOL_1D", 60: "POOL_2D", 61: "POOL_2D_BACK", 62: "UPSCALE", + 63: "PAD", 64: "PAD_REFLECT_1D", 65: "ROLL", 66: "ARANGE", + 67: "TIMESTEP_EMBEDDING", 68: "ARGSORT", 69: "TOP_K", 70: "LEAKY_RELU", + 71: "TRI", 72: "FILL", 73: "FLASH_ATTN_EXT", 74: "FLASH_ATTN_BACK", + 75: "SSM_CONV", 76: "SSM_SCAN", 77: "WIN_PART", 78: "WIN_UNPART", + 79: "GET_REL_POS", 80: "ADD_REL_POS", 81: "RWKV_WKV6", + 82: "GATED_LINEAR_ATTN", 83: "RWKV_WKV7", 84: "SOLVE_TRI", + 85: "GATED_DELTA_NET", 86: "UNARY", 87: "MAP_CUSTOM1", + 88: "MAP_CUSTOM2", 89: "MAP_CUSTOM3", 90: "CUSTOM", + 91: "CROSS_ENTROPY_LOSS", 92: "CROSS_ENTROPY_LOSS_BACK", + 93: "OPT_STEP_ADAMW", 94: "OPT_STEP_SGD", 95: "GLU", + 96: "COUNT", +} + +GGML_TYPE_NAMES_TO_ID = {v: k for k, v in GGML_TYPE_NAMES.items()} + +GGML_OP_NAMES_TO_ID = {v: k for k, v in GGML_OP_NAMES.items()} + + +_EXPORT_SKIP_OPS = frozenset({ + 33, # SET + 34, # CPY + 35, # CONT + 36, # RESHAPE + 37, # VIEW + 38, # PERMUTE + 39, # TRANSPOSE + 41, # GET_ROWS_BACK + 42, # SET_ROWS + 43, # DIAG + 44, # DIAG_MASK_INF + 45, # DIAG_MASK_ZERO + 47, # SOFT_MAX_BACK + 49, # ROPE_BACK + 51, # CONV_TRANSPOSE_1D + 52, # IM2COL + 53, # IM2COL_BACK + 54, # IM2COL_3D + 58, # CONV_TRANSPOSE_2D + 61, # POOL_2D_BACK + 63, # PAD + 64, # PAD_REFLECT_1D + 65, # ROLL + 66, # ARANGE + 70, # LEAKY_RELU (covered by UNARY) + 71, # TRI + 72, # FILL + 77, # WIN_PART + 78, # WIN_UNPART + 92, # CROSS_ENTROPY_LOSS_BACK + 93, # OPT_STEP_ADAMW + 94, # OPT_STEP_SGD + 96, # COUNT +}) + + +def _compute_output_ne(op_id: int, ne0: list, ne1: list, ne2: list) -> list | None: + if op_id == 29: # MUL_MAT + return [ne0[1], ne1[1], max(ne0[2], ne1[2]), max(ne0[3], ne1[3])] + if op_id == 30: # MUL_MAT_ID + return [ne0[1], ne2[0], ne1[2], 1] + if op_id in (2, 7, 8, 32): # ADD, MUL, DIV, SCALE + return [max(ne0[i], ne1[i]) for i in range(4)] + if op_id == 4: # ADD1 + return list(ne0) + if op_id in (9, 86): # SQR, UNARY + return list(ne0) + if op_id in (46, 25): # SOFT_MAX, RMS_NORM + return list(ne0) + if op_id == 73: # FLASH_ATTN_EXT + return [ne1[1], ne1[1], ne0[2], ne0[3]] + if op_id == 40: # GET_ROWS + return [ne0[0], ne1[1], ne1[2], ne1[3]] + if op_id == 41: # GET_ROWS_BACK + return [ne0[0], ne1[1], ne1[2], ne1[3]] + if op_id == 42: # SET_ROWS + return list(ne0) + if op_id == 31: # OUT_PROD + return [ne0[0], ne1[0], max(ne0[2], ne1[2]), max(ne0[3], ne1[3])] + if op_id == 22: # CONCAT + return [ne0[0] + ne1[0], max(ne0[1], ne1[1]), + max(ne0[2], ne1[2]), max(ne0[3], ne1[3])] + if op_id in (34, 35, 50, 60, 53, 68): # CPY, CONT, CLAMP, POOL_2D, IM2COL_BACK, ARGSORT + return list(ne0) + return None + @dataclass class ProfileRecord: @@ -464,6 +565,92 @@ class ProfileData: print(f"Chrome trace exported to: {filepath}") print(f"Open chrome://tracing in Chrome/Edge and load this file.") + def export_graph_ops(self, filepath: str | Path) -> None: + """Export operations in export-graph-ops format for test-backend-ops --test-file.""" + seen: set[tuple] = set() + lines: list[str] = [] + + backend_by_id: dict[int, dict] = {} + for b in self.metadata.get("backends", []): + backend_by_id[b["id"]] = b + + for rec in self.records: + if rec.type != OP_EVENT: + continue + + op_id = GGML_OP_NAMES_TO_ID.get(rec.name, -1) + if op_id < 0: + continue + + if op_id in _EXPORT_SKIP_OPS: + continue + + ne0 = rec.ne_src0 + ne1 = rec.ne_src1 + ne2 = rec.ne_src2 + + type_src0 = rec.type_src0 if rec.type_src0 >= 0 else 0 + type_src1 = rec.type_src1 if rec.type_src1 >= 0 else 0 + type_src2 = rec.type_src2 if rec.type_src2 >= 0 else 0 + + sources: list[tuple[int, list, list]] = [] + if any(v != 0 for v in ne0): + sources.append((type_src0, ne0, [0, 0, 0, 0])) + if any(v != 0 for v in ne1): + sources.append((type_src1, ne1, [0, 0, 0, 0])) + + if op_id == 30: # MUL_MAT_ID: ensure rows tensor (src2) is present + if len(sources) < 3 and any(v != 0 for v in ne2): + sources.append((type_src2, ne2, [0, 0, 0, 0])) + elif len(sources) < 3 and len(sources) >= 2: + sources.append((24, [sources[1][1][1], 1, 1, 1], [0, 0, 0, 0])) # I32 + elif any(v != 0 for v in ne2): + sources.append((type_src2, ne2, [0, 0, 0, 0])) + + src_ne0 = sources[0][1] if len(sources) > 0 else [0, 0, 0, 0] + src_ne1 = sources[1][1] if len(sources) > 1 else [0, 0, 0, 0] + src_ne2 = sources[2][1] if len(sources) > 2 else [0, 0, 0, 0] + + ne_out = _compute_output_ne(op_id, src_ne0, src_ne1, src_ne2) + if ne_out is None: + continue + + op_params: list[int] = [] + if op_id == 30 and len(sources) >= 2: # MUL_MAT_ID + op_params.append(sources[1][1][1]) + elif op_id in (86, 95) and rec.sub_op >= 0: # UNARY, GLU + op_params.append(rec.sub_op) + + bname = "" + if rec.backend_id in backend_by_id: + bname = backend_by_id[rec.backend_id].get("device", "") + if not bname or bname == "unknown": + bname = backend_by_id[rec.backend_id].get("name", "") + + key = (op_id, tuple(ne_out), tuple(op_params), tuple((s[0], tuple(s[1])) for s in sources), bname) + if key in seen: + continue + seen.add(key) + + line = f"{op_id} 0 {ne_out[0]} {ne_out[1]} {ne_out[2]} {ne_out[3]} " + line += f"{len(op_params)}" + for p in op_params: + line += f" {p}" + line += f" {len(sources)}" + for src_type, src_ne, src_nb in sources: + line += f" {src_type} {src_ne[0]} {src_ne[1]} {src_ne[2]} {src_ne[3]} {src_nb[0]} {src_nb[1]} {src_nb[2]} {src_nb[3]}" + name = rec.name if rec.name else "-" + line += f" {name}" + if bname: + line += f" {bname}" + line += "\n" + lines.append(line) + + with open(filepath, "w") as f: + f.writelines(lines) + + print(f"Exported {len(lines)} unique ops to: {filepath}") + def export_html_viewer(self, filepath: str | Path, max_records: int = 0) -> None: """Export a self-contained interactive HTML timeline viewer using Canvas.""" import json as json_mod @@ -1007,6 +1194,7 @@ Examples: python -m tools.profiler.profiler profile.json python -m tools.profiler.profiler profile.json --chrome-trace trace.json python -m tools.profiler.profiler profile.json --top-ops 20 + python -m tools.profiler.profiler profile.json --export-ops ops.txt """, ) parser.add_argument("profile", help="Path to profiler JSON file") @@ -1014,6 +1202,8 @@ Examples: help="Export as Chrome Trace Event format") parser.add_argument("--html-viewer", metavar="FILE", help="Export as interactive HTML timeline viewer") + parser.add_argument("--export-ops", metavar="FILE", + help="Export ops in export-graph-ops format (for test-backend-ops --test-file)") parser.add_argument("--html-max-records", type=int, default=0, help="Max records in HTML viewer (0=unlimited, set to downsample for huge traces)") parser.add_argument("--top-ops", type=int, default=0, @@ -1033,6 +1223,9 @@ Examples: if args.html_viewer: data.export_html_viewer(args.html_viewer, max_records=args.html_max_records) + if args.export_ops: + data.export_graph_ops(args.export_ops) + if args.top_ops > 0: print(f"\nTop {args.top_ops} operations by total time:\n") for s in data.top_operations(args.top_ops): @@ -1055,7 +1248,7 @@ Examples: f"{s.count:>6} calls {s.total_bytes / 1e6:.1f} MB") print() - if args.top_ops == 0 and args.top_kernels == 0 and not args.inefficiency and not args.chrome_trace and not args.html_viewer: + if args.top_ops == 0 and args.top_kernels == 0 and not args.inefficiency and not args.chrome_trace and not args.html_viewer and not args.export_ops: data.summary()