diff --git a/tests/export-graph-ops.cpp b/tests/export-graph-ops.cpp index e37855eee6..2d75a27960 100644 --- a/tests/export-graph-ops.cpp +++ b/tests/export-graph-ops.cpp @@ -159,7 +159,7 @@ int main(int argc, char ** argv) { hf_quant = "Q4_K_M"; } - gguf_context * gguf_ctx = gguf_fetch_gguf_ctx(hf_repo, hf_quant); + gguf_context_ptr gguf_ctx = gguf_fetch_gguf_ctx(hf_repo, hf_quant); if (!gguf_ctx) { LOG_ERR("failed to fetch GGUF metadata from %s\n", hf_repo.c_str()); return 1; @@ -168,8 +168,7 @@ int main(int argc, char ** argv) { llama_model_params model_params = llama_model_default_params(); model_params.devices = params.devices.data(); - model.reset(llama_model_init_from_user(gguf_ctx, set_tensor_data, nullptr, model_params)); - gguf_free(gguf_ctx); + model.reset(llama_model_init_from_user(gguf_ctx.get(), set_tensor_data, nullptr, model_params)); if (!model) { LOG_ERR("failed to create llama_model from %s\n", hf_repo.c_str()); diff --git a/tests/gguf-model-data.cpp b/tests/gguf-model-data.cpp index 343f86d918..adfd6bec68 100644 --- a/tests/gguf-model-data.cpp +++ b/tests/gguf-model-data.cpp @@ -4,6 +4,7 @@ #include "gguf-model-data.h" #include "common.h" +#include "ggml-cpp.h" #include "gguf.h" #include @@ -616,7 +617,7 @@ std::optional gguf_fetch_model_meta( return model_opt; } -gguf_context * gguf_fetch_gguf_ctx( +gguf_context_ptr gguf_fetch_gguf_ctx( const std::string & repo, const std::string & quant, const std::string & cache_dir) { @@ -640,13 +641,14 @@ gguf_context * gguf_fetch_gguf_ctx( const std::string cache_path = get_cache_file_path(cdir, repo_part, filename); - ggml_context * ggml_ctx; + ggml_context_ptr ggml_ctx_ptr; + ggml_context * ggml_ctx{}; gguf_init_params params{true, &ggml_ctx}; - gguf_context * ctx = gguf_init_from_file(cache_path.c_str(), params); + gguf_context_ptr ctx{gguf_init_from_file(cache_path.c_str(), params)}; + ggml_ctx_ptr.reset(ggml_ctx); if (ctx == nullptr) { fprintf(stderr, "gguf_fetch: gguf_init_from_file failed\n"); - ggml_free(ggml_ctx); return nullptr; } @@ -654,8 +656,6 @@ gguf_context * gguf_fetch_gguf_ctx( if (model.n_split > 1) { if (split_prefix.empty()) { fprintf(stderr, "gguf_fetch: model reports %u splits but filename has no split pattern\n", model.n_split); - gguf_free(ctx); - ggml_free(ggml_ctx); return nullptr; } @@ -671,37 +671,29 @@ gguf_context * gguf_fetch_gguf_ctx( auto shard = fetch_or_cached(repo, shard_name, cdir, repo_part); if (!shard.has_value()) { fprintf(stderr, "gguf_fetch: failed to fetch shard %d: %s\n", i, shard_name.c_str()); - gguf_free(ctx); - ggml_free(ggml_ctx); return nullptr; } // Load tensors from shard and add to main gguf_context const std::string shard_path = get_cache_file_path(cdir, repo_part, shard_name); - ggml_context * shard_ggml_ctx; + ggml_context_ptr shard_ggml_ctx_ptr; + ggml_context * shard_ggml_ctx{}; gguf_init_params shard_params{true, &shard_ggml_ctx}; - gguf_context * shard_ctx = gguf_init_from_file(shard_path.c_str(), shard_params); + gguf_context_ptr shard_ctx{gguf_init_from_file(shard_path.c_str(), shard_params)}; + shard_ggml_ctx_ptr.reset(shard_ggml_ctx); if (shard_ctx == nullptr) { fprintf(stderr, "gguf_fetch: shard gguf_init_from_file failed\n"); - ggml_free(shard_ggml_ctx); - gguf_free(ctx); - ggml_free(ggml_ctx); return nullptr; } for (ggml_tensor * t = ggml_get_first_tensor(shard_ggml_ctx); t; t = ggml_get_next_tensor(shard_ggml_ctx, t)) { - gguf_add_tensor(ctx, t); + gguf_add_tensor(ctx.get(), t); } - - gguf_free(shard_ctx); - ggml_free(shard_ggml_ctx); } - gguf_set_val_u16(ctx, "split.count", 1); + gguf_set_val_u16(ctx.get(), "split.count", 1); } - ggml_free(ggml_ctx); - return ctx; } diff --git a/tests/gguf-model-data.h b/tests/gguf-model-data.h index 9c2ff02513..61ce24bb05 100644 --- a/tests/gguf-model-data.h +++ b/tests/gguf-model-data.h @@ -1,6 +1,6 @@ #pragma once -#include "ggml.h" +#include "ggml-cpp.h" #include "gguf.h" #include @@ -42,7 +42,7 @@ std::optional gguf_fetch_model_meta( const std::string & quant = "Q8_0", const std::string & cache_dir = ""); // empty = default -gguf_context * gguf_fetch_gguf_ctx( +gguf_context_ptr gguf_fetch_gguf_ctx( const std::string & repo, const std::string & quant = "Q8_0", const std::string & cache_dir = "");