diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h
index 6962e9a74c..9e667721c5 100644
--- a/ggml/include/ggml-backend.h
+++ b/ggml/include/ggml-backend.h
@@ -261,6 +261,9 @@ extern "C" {
 
     GGML_API enum ggml_backend_meta_split_state ggml_backend_meta_get_split_state(const struct ggml_tensor * tensor, bool assume_sync);
 
+    // temporary workaround to statically allocate tensors from a context in a deduplicated way:
+    GGML_API struct ggml_backend_buffer * ggml_backend_meta_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft);
+
     //
     // Backend registry
     //
diff --git a/ggml/src/ggml-alloc.c b/ggml/src/ggml-alloc.c
index 605cc6976d..41419b617b 100644
--- a/ggml/src/ggml-alloc.c
+++ b/ggml/src/ggml-alloc.c
@@ -1,6 +1,5 @@
 #include "ggml-alloc.h"
 #include "ggml-backend-impl.h"
-#include "ggml-backend.h"
 #include "ggml.h"
 #include "ggml-impl.h"
 #include <assert.h>
@@ -1241,9 +1240,6 @@ size_t ggml_backend_alloc_ctx_tensors_from_buft_size(struct ggml_context * ctx,
 }
 
 ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft) {
-    if (ggml_backend_buft_is_meta(buft)) {
-        return ggml_backend_meta_alloc_ctx_tensors_from_buft(ctx, buft);
-    }
     size_t nbytes_total = 0;
     return ggml_backend_alloc_ctx_tensors_from_buft_impl(ctx, buft, &nbytes_total, /*no_alloc =*/ false);
 }
diff --git a/ggml/src/ggml-backend-impl.h b/ggml/src/ggml-backend-impl.h
index 6d92a9c06d..cf26d580e0 100644
--- a/ggml/src/ggml-backend-impl.h
+++ b/ggml/src/ggml-backend-impl.h
@@ -254,9 +254,6 @@ extern "C" {
 #    define GGML_BACKEND_DL_SCORE_IMPL(score_fn)
 #endif
 
-    // temporary workaround to statically allocate tensors from a context in a deduplicated way:
-    GGML_API struct ggml_backend_buffer * ggml_backend_meta_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft);
-
 #ifdef __cplusplus
 }
 #endif
diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp
index cb702b2a59..5920aa9a1d 100644
--- a/src/llama-kv-cache.cpp
+++ b/src/llama-kv-cache.cpp
@@ -187,7 +187,11 @@ llama_kv_cache::llama_kv_cache(
                 t->buffer = buf; // set dummy buffer for KV cache so that the backend scheduler won't try to allocate it
             }
         } else {
-            buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx.get(), buft); // real buffer
+            if (ggml_backend_buft_is_meta(buft)) {
+                buf = ggml_backend_meta_alloc_ctx_tensors_from_buft(ctx.get(), buft);
+            } else {
+                buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx.get(), buft); // real buffer
+            }
         }
         if (!buf) {
             throw std::runtime_error("failed to allocate buffer for kv cache");
diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp
index f0038036dc..9d040da4b1 100644
--- a/src/llama-memory-recurrent.cpp
+++ b/src/llama-memory-recurrent.cpp
@@ -1,5 +1,6 @@
 #include "llama-memory-recurrent.h"
 
+#include "ggml-backend.h"
 #include "llama-impl.h"
 #include "llama-io.h"
 #include "llama-batch.h"
@@ -101,7 +102,8 @@ llama_memory_recurrent::llama_memory_recurrent(
 
     // allocate tensors and initialize the buffers to avoid NaNs in the padding
     for (auto & [buft, ctx] : ctx_map) {
-        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx.get(), buft);
+        ggml_backend_buffer_t buf = ggml_backend_buft_is_meta(buft) ?
+            ggml_backend_meta_alloc_ctx_tensors_from_buft(ctx.get(), buft) : ggml_backend_alloc_ctx_tensors_from_buft(ctx.get(), buft);
         if (!buf) {
             throw std::runtime_error("failed to allocate buffer for rs cache");
         }
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index bffd4eb99e..9376ea5631 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -7504,7 +7504,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                 t->buffer = buf; // set dummy buffer for weights so that the backend scheduler won't try to allocate them
             }
         } else {
-            buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); // real buffer
+            if (ggml_backend_buft_is_meta(buft)) {
+                buf = ggml_backend_meta_alloc_ctx_tensors_from_buft(ctx, buft); // real buffer
+            } else {
+                buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); // real buffer
+            }
         }
         if (buf == nullptr) {
             throw std::runtime_error(format("unable to allocate %s buffer", ggml_backend_buft_name(buft)));