tests: allow exporting graph ops from HF file without downloading weights

This commit is contained in:
Ruben Ortlam 2026-03-30 14:54:46 +02:00
parent 389c7d4955
commit 829c250073
5 changed files with 177 additions and 11 deletions

View File

@ -537,9 +537,11 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
} catch (const std::exception & e) {
LOG_WRN("HF cache migration failed: %s\n", e.what());
}
// export_graph_ops loads only metadata
const bool skip_model_download = ctx_arg.ex == LLAMA_EXAMPLE_EXPORT_GRAPH_OPS;
// maybe handle remote preset
if (!params.model.hf_repo.empty()) {
if (!params.model.hf_repo.empty() && !skip_model_download) {
std::string cli_hf_repo = params.model.hf_repo;
bool has_preset = common_params_handle_remote_preset(params, ctx_arg.ex);
@ -570,7 +572,7 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
}
// handle model and download
{
if (!skip_model_download) {
auto res = common_params_handle_model(params.model, params.hf_token, params.offline);
if (params.no_mmproj) {
params.mmproj = {};
@ -591,7 +593,7 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
// model is required (except for server)
// TODO @ngxson : maybe show a list of available models in CLI in this case
if (params.model.path.empty() && ctx_arg.ex != LLAMA_EXAMPLE_SERVER && !params.usage && !params.completion) {
if (params.model.path.empty() && ctx_arg.ex != LLAMA_EXAMPLE_SERVER && !skip_model_download&& !params.usage && !params.completion) {
throw std::invalid_argument("error: --model is required\n");
}

View File

@ -287,3 +287,7 @@ target_include_directories(test-alloc PRIVATE ${PROJECT_SOURCE_DIR}/ggml/src)
llama_build(export-graph-ops.cpp)
target_include_directories(export-graph-ops PRIVATE ${PROJECT_SOURCE_DIR}/ggml/src)
if (TARGET gguf-model-data)
target_link_libraries(export-graph-ops PRIVATE gguf-model-data)
target_compile_definitions(export-graph-ops PRIVATE LLAMA_HF_FETCH)
endif()

View File

@ -1,15 +1,26 @@
#include "arg.h"
#include "common.h"
#include "log.h"
#include "llama.h"
#include "llama-cpp.h"
#include "../src/llama-ext.h"
#include "ggml.h"
#include "gguf-model-data.h"
#include "gguf.h"
#include "ggml-backend.h"
#include "download.h"
#include <array>
#include <vector>
#include <set>
#include <fstream>
#include <iostream>
#include <random>
// Noop because weights are not needed
static void set_tensor_data(struct ggml_tensor * tensor, void * userdata) {
GGML_UNUSED(tensor);
GGML_UNUSED(userdata);
}
struct input_tensor {
ggml_type type;
@ -132,9 +143,50 @@ int main(int argc, char ** argv) {
params.warmup = false;
auto init_result = common_init_from_params(params);
llama_context * ctx;
common_init_result_ptr init_result;
llama_model_ptr model;
llama_context * ctx = init_result->context();
if (params.model.hf_repo.empty()) {
init_result = common_init_from_params(params);
ctx = init_result->context();
} else {
#ifdef LLAMA_HF_FETCH
auto [hf_repo, hf_quant] = common_download_split_repo_tag(params.model.hf_repo);
if (hf_quant.empty() || hf_quant == "latest") {
hf_quant = "Q4_K_M";
}
gguf_context * gguf_ctx = gguf_fetch_gguf_ctx(hf_repo, hf_quant);
if (!gguf_ctx) {
LOG_ERR("failed to fetch GGUF metadata from %s\n", hf_repo.c_str());
return 1;
}
llama_model_params model_params = llama_model_default_params();
model_params.devices = params.devices.data();
model.reset(llama_model_init_from_user(gguf_ctx, set_tensor_data, nullptr, model_params));
gguf_free(gguf_ctx);
if (!model) {
LOG_ERR("failed to create llama_model from %s\n", hf_repo.c_str());
return 1;
}
llama_context_params ctx_params = llama_context_default_params();
ctx = llama_init_from_model(model.get(), ctx_params);
if (!ctx) {
LOG_ERR("failed to create llama_context\n");
return 1;
}
#else
LOG_ERR("export-graph-ops compiled without HF fetch support\n");
return 1;
#endif
}
const uint32_t n_seqs = llama_n_seq_max(ctx);
const uint32_t n_tokens = std::min(llama_n_ctx(ctx), llama_n_ubatch(ctx));
@ -143,13 +195,15 @@ int main(int argc, char ** argv) {
auto * gf_pp = llama_graph_reserve(ctx, n_tokens, n_seqs, n_tokens);
if (!gf_pp) {
throw std::runtime_error("failed to reserve prompt processing graph");
LOG_ERR("failed to reserve prompt processing graph\n");
return 1;
}
extract_graph_ops(gf_pp, "pp", tests);
auto * gf_tg = llama_graph_reserve(ctx, n_seqs, n_seqs, n_seqs);
if (!gf_tg) {
throw std::runtime_error("failed to reserve token generation graph");
LOG_ERR("failed to reserve token generation graph\n");
return 1;
}
extract_graph_ops(gf_tg, "tg", tests);
@ -158,12 +212,18 @@ int main(int argc, char ** argv) {
std::ofstream f(params.out_file);
if (!f.is_open()) {
throw std::runtime_error("Unable to open output file");
LOG_ERR("unable to open output file: %s\n", params.out_file.c_str());
return 1;
}
for (const auto& test : tests) {
test.serialize(f);
}
if (!params.model.hf_repo.empty()) {
// Context is not owned by common_init_result in this case
llama_free(ctx);
}
return 0;
}

View File

@ -531,14 +531,18 @@ static std::optional<gguf_remote_model> fetch_and_parse(
return std::nullopt;
}
static std::string get_cache_file_path(const std::string& cdir, const std::string& repo_part, const std::string& filename) {
std::string fname_part = sanitize_for_path(filename);
return cdir + "/" + repo_part + "--" + fname_part + ".partial";
}
// Try cache first, then fetch and parse a single GGUF shard.
static std::optional<gguf_remote_model> fetch_or_cached(
const std::string & repo,
const std::string & filename,
const std::string & cdir,
const std::string & repo_part) {
std::string fname_part = sanitize_for_path(filename);
std::string cache_path = cdir + "/" + repo_part + "--" + fname_part + ".partial";
std::string cache_path = get_cache_file_path(cdir, repo_part, filename);
{
std::vector<char> cached;
@ -611,3 +615,93 @@ std::optional<gguf_remote_model> gguf_fetch_model_meta(
return model_opt;
}
gguf_context * gguf_fetch_gguf_ctx(
const std::string & repo,
const std::string & quant,
const std::string & cache_dir) {
std::string cdir = cache_dir.empty() ? get_default_cache_dir() : cache_dir;
std::string repo_part = sanitize_for_path(repo);
std::string split_prefix;
std::string filename = detect_gguf_filename(repo, quant, split_prefix);
if (filename.empty()) {
return nullptr;
}
auto model_opt = fetch_or_cached(repo, filename, cdir, repo_part);
if (!model_opt.has_value()) {
fprintf(stderr, "gguf_fetch: failed to fetch %s\n", filename.c_str());
return nullptr;
}
auto & model = model_opt.value();
const std::string cache_path = get_cache_file_path(cdir, repo_part, filename);
ggml_context * ggml_ctx;
gguf_init_params params{true, &ggml_ctx};
gguf_context * ctx = gguf_init_from_file(cache_path.c_str(), params);
if (ctx == nullptr) {
fprintf(stderr, "gguf_fetch: gguf_init_from_file failed\n");
ggml_free(ggml_ctx);
return nullptr;
}
// If the model is split across multiple files we need to fetch the remaining shards metadata
if (model.n_split > 1) {
if (split_prefix.empty()) {
fprintf(stderr, "gguf_fetch: model reports %u splits but filename has no split pattern\n", model.n_split);
gguf_free(ctx);
ggml_free(ggml_ctx);
return nullptr;
}
fprintf(stderr, "gguf_fetch: split model with %u shards, fetching remaining %u...\n",
model.n_split, model.n_split - 1);
for (int i = 2; i <= model.n_split; i++) {
char num_buf[6], total_buf[6];
snprintf(num_buf, sizeof(num_buf), "%05d", i);
snprintf(total_buf, sizeof(total_buf), "%05d", (int)model.n_split);
std::string shard_name = split_prefix + "-" + num_buf + "-of-" + total_buf + ".gguf";
auto shard = fetch_or_cached(repo, shard_name, cdir, repo_part);
if (!shard.has_value()) {
fprintf(stderr, "gguf_fetch: failed to fetch shard %d: %s\n", i, shard_name.c_str());
gguf_free(ctx);
ggml_free(ggml_ctx);
return nullptr;
}
// Load tensors from shard and add to main gguf_context
const std::string shard_path = get_cache_file_path(cdir, repo_part, shard_name);
ggml_context * shard_ggml_ctx;
gguf_init_params shard_params{true, &shard_ggml_ctx};
gguf_context * shard_ctx = gguf_init_from_file(shard_path.c_str(), shard_params);
if (shard_ctx == nullptr) {
fprintf(stderr, "gguf_fetch: shard gguf_init_from_file failed\n");
ggml_free(shard_ggml_ctx);
gguf_free(ctx);
ggml_free(ggml_ctx);
return nullptr;
}
for (ggml_tensor * t = ggml_get_first_tensor(shard_ggml_ctx); t; t = ggml_get_next_tensor(shard_ggml_ctx, t)) {
gguf_add_tensor(ctx, t);
}
gguf_free(shard_ctx);
ggml_free(shard_ggml_ctx);
}
gguf_set_val_u16(ctx, "split.count", 1);
}
ggml_free(ggml_ctx);
return ctx;
}

View File

@ -1,6 +1,7 @@
#pragma once
#include "ggml.h"
#include "gguf.h"
#include <cstdint>
#include <optional>
@ -40,3 +41,8 @@ std::optional<gguf_remote_model> gguf_fetch_model_meta(
const std::string & repo,
const std::string & quant = "Q8_0",
const std::string & cache_dir = ""); // empty = default
gguf_context * gguf_fetch_gguf_ctx(
const std::string & repo,
const std::string & quant = "Q8_0",
const std::string & cache_dir = "");