llama: automatically set parameters not set by the user in such a way that maximizes GPU utilization (#16653)
* llama: automatically fit args to free memory llama-fit-params tool * fix CI * hints for bug reports, ensure no reallocation * fix segfault with Vulkan * add llama-fit-params to CI * fix CI * fix CI * fix CI * minor adjustments * fix assignment of 1 dense layer * fix logger not being reset on model load failure * remove --n-gpu-layer hint on model load failure * fix llama-fit-params verbosity * fix edge case * fix typo [no ci]
This commit is contained in:
parent
4aced7a631
commit
b1f3a6e5db
|
|
@ -11,7 +11,7 @@ body:
|
||||||
(i.e. the generated text) are incorrect or llama.cpp crashes during model evaluation.
|
(i.e. the generated text) are incorrect or llama.cpp crashes during model evaluation.
|
||||||
If you encountered the issue while using an external UI (e.g. ollama),
|
If you encountered the issue while using an external UI (e.g. ollama),
|
||||||
please reproduce your issue using one of the examples/binaries in this repository.
|
please reproduce your issue using one of the examples/binaries in this repository.
|
||||||
The `llama-cli` binary can be used for simple and reproducible model inference.
|
The `llama-completion` binary can be used for simple and reproducible model inference.
|
||||||
- type: textarea
|
- type: textarea
|
||||||
id: version
|
id: version
|
||||||
attributes:
|
attributes:
|
||||||
|
|
@ -74,9 +74,12 @@ body:
|
||||||
Please give us a summary of the problem and tell us how to reproduce it.
|
Please give us a summary of the problem and tell us how to reproduce it.
|
||||||
If you can narrow down the bug to specific hardware, compile flags, or command line arguments,
|
If you can narrow down the bug to specific hardware, compile flags, or command line arguments,
|
||||||
that information would be very much appreciated by us.
|
that information would be very much appreciated by us.
|
||||||
|
|
||||||
|
If possible, please try to reproduce the issue using `llama-completion` with `-fit off`.
|
||||||
|
If you can only reproduce the issue with `-fit on`, please provide logs both with and without `--verbose`.
|
||||||
placeholder: >
|
placeholder: >
|
||||||
e.g. when I run llama-cli with -ngl 99 I get garbled outputs.
|
e.g. when I run llama-completion with `-fa on` I get garbled outputs for very long prompts.
|
||||||
When I use -ngl 0 it works correctly.
|
With short prompts or `-fa off` it works correctly.
|
||||||
Here are the exact commands that I used: ...
|
Here are the exact commands that I used: ...
|
||||||
validations:
|
validations:
|
||||||
required: true
|
required: true
|
||||||
|
|
|
||||||
|
|
@ -398,6 +398,8 @@ function gg_run_qwen3_0_6b {
|
||||||
./bin/llama-quantize ${model_bf16} ${model_q5_k} q5_k $(nproc)
|
./bin/llama-quantize ${model_bf16} ${model_q5_k} q5_k $(nproc)
|
||||||
./bin/llama-quantize ${model_bf16} ${model_q6_k} q6_k $(nproc)
|
./bin/llama-quantize ${model_bf16} ${model_q6_k} q6_k $(nproc)
|
||||||
|
|
||||||
|
(time ./bin/llama-fit-params --model ${model_f16} 2>&1 | tee -a $OUT/${ci}-fp-f16.log)
|
||||||
|
|
||||||
(time ./bin/llama-completion -no-cnv --model ${model_f16} -ngl 99 -c 1024 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-f16.log
|
(time ./bin/llama-completion -no-cnv --model ${model_f16} -ngl 99 -c 1024 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-f16.log
|
||||||
(time ./bin/llama-completion -no-cnv --model ${model_bf16} -ngl 99 -c 1024 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-bf16.log
|
(time ./bin/llama-completion -no-cnv --model ${model_bf16} -ngl 99 -c 1024 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-bf16.log
|
||||||
(time ./bin/llama-completion -no-cnv --model ${model_q8_0} -ngl 99 -c 1024 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log
|
(time ./bin/llama-completion -no-cnv --model ${model_q8_0} -ngl 99 -c 1024 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log
|
||||||
|
|
@ -523,6 +525,8 @@ function gg_run_embd_bge_small {
|
||||||
|
|
||||||
./bin/llama-quantize ${model_f16} ${model_q8_0} q8_0
|
./bin/llama-quantize ${model_f16} ${model_q8_0} q8_0
|
||||||
|
|
||||||
|
(time ./bin/llama-fit-params --model ${model_f16} 2>&1 | tee -a $OUT/${ci}-fp-f16.log)
|
||||||
|
|
||||||
(time ./bin/llama-embedding --model ${model_f16} -p "I believe the meaning of life is" -ngl 99 -c 0 --no-op-offload) 2>&1 | tee -a $OUT/${ci}-tg-f16.log
|
(time ./bin/llama-embedding --model ${model_f16} -p "I believe the meaning of life is" -ngl 99 -c 0 --no-op-offload) 2>&1 | tee -a $OUT/${ci}-tg-f16.log
|
||||||
(time ./bin/llama-embedding --model ${model_q8_0} -p "I believe the meaning of life is" -ngl 99 -c 0 --no-op-offload) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log
|
(time ./bin/llama-embedding --model ${model_q8_0} -p "I believe the meaning of life is" -ngl 99 -c 0 --no-op-offload) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log
|
||||||
|
|
||||||
|
|
@ -563,6 +567,8 @@ function gg_run_rerank_tiny {
|
||||||
|
|
||||||
model_f16="${path_models}/ggml-model-f16.gguf"
|
model_f16="${path_models}/ggml-model-f16.gguf"
|
||||||
|
|
||||||
|
(time ./bin/llama-fit-params --model ${model_f16} 2>&1 | tee -a $OUT/${ci}-fp-f16.log)
|
||||||
|
|
||||||
# for this model, the SEP token is "</s>"
|
# for this model, the SEP token is "</s>"
|
||||||
(time ./bin/llama-embedding --model ${model_f16} -p "what is panda?\thi\nwhat is panda?\tit's a bear\nwhat is panda?\tThe giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China." -ngl 99 -c 0 --pooling rank --embd-normalize -1 --no-op-offload --verbose-prompt) 2>&1 | tee -a $OUT/${ci}-rk-f16.log
|
(time ./bin/llama-embedding --model ${model_f16} -p "what is panda?\thi\nwhat is panda?\tit's a bear\nwhat is panda?\tThe giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China." -ngl 99 -c 0 --pooling rank --embd-normalize -1 --no-op-offload --verbose-prompt) 2>&1 | tee -a $OUT/${ci}-rk-f16.log
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,7 @@
|
||||||
#include <nlohmann/json.hpp>
|
#include <nlohmann/json.hpp>
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
#include <cinttypes>
|
||||||
#include <climits>
|
#include <climits>
|
||||||
#include <cstdarg>
|
#include <cstdarg>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
|
|
@ -529,7 +530,9 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
|
||||||
params.kv_overrides.back().key[0] = 0;
|
params.kv_overrides.back().key[0] = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!params.tensor_buft_overrides.empty()) {
|
// pad tensor_buft_overrides for llama_params_fit:
|
||||||
|
const size_t ntbo = llama_max_tensor_buft_overrides();
|
||||||
|
while (params.tensor_buft_overrides.size() < ntbo) {
|
||||||
params.tensor_buft_overrides.push_back({nullptr, nullptr});
|
params.tensor_buft_overrides.push_back({nullptr, nullptr});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -2153,6 +2156,34 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
).set_env("LLAMA_ARG_MAIN_GPU"));
|
).set_env("LLAMA_ARG_MAIN_GPU"));
|
||||||
|
add_opt(common_arg(
|
||||||
|
{ "-fit", "--fit" }, "[on|off]",
|
||||||
|
string_format("whether to adjust unset arguments to fit in device memory ('on' or 'off', default: '%s')", params.fit_params ? "on" : "off"),
|
||||||
|
[](common_params & params, const std::string & value) {
|
||||||
|
if (is_truthy(value)) {
|
||||||
|
params.fit_params = true;
|
||||||
|
} else if (is_falsey(value)) {
|
||||||
|
params.fit_params = false;
|
||||||
|
} else {
|
||||||
|
throw std::runtime_error(
|
||||||
|
string_format("error: unkown value for --fit: '%s'\n", value.c_str()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
).set_env("LLAMA_ARG_FIT"));
|
||||||
|
add_opt(common_arg(
|
||||||
|
{ "-fitt", "--fit-target" }, "MiB",
|
||||||
|
string_format("target margin per device for --fit option, default: %zu", params.fit_params_target/(1024*1024)),
|
||||||
|
[](common_params & params, int value) {
|
||||||
|
params.fit_params_target = value * size_t(1024*1024);
|
||||||
|
}
|
||||||
|
).set_env("LLAMA_ARG_FIT_TARGET"));
|
||||||
|
add_opt(common_arg(
|
||||||
|
{ "-fitc", "--fit-ctx" }, "N",
|
||||||
|
string_format("minimum ctx size that can be set by --fit option, default: %" PRIu32, params.fit_params_min_ctx),
|
||||||
|
[](common_params & params, int value) {
|
||||||
|
params.fit_params_min_ctx = value;
|
||||||
|
}
|
||||||
|
).set_env("LLAMA_ARG_FIT_CTX"));
|
||||||
add_opt(common_arg(
|
add_opt(common_arg(
|
||||||
{"--check-tensors"},
|
{"--check-tensors"},
|
||||||
string_format("check model tensor data for invalid values (default: %s)", params.check_tensors ? "true" : "false"),
|
string_format("check model tensor data for invalid values (default: %s)", params.check_tensors ? "true" : "false"),
|
||||||
|
|
|
||||||
|
|
@ -1088,7 +1088,15 @@ struct common_init_result::impl {
|
||||||
|
|
||||||
common_init_result::common_init_result(common_params & params) :
|
common_init_result::common_init_result(common_params & params) :
|
||||||
pimpl(new impl{}) {
|
pimpl(new impl{}) {
|
||||||
const auto mparams = common_model_params_to_llama(params);
|
auto mparams = common_model_params_to_llama(params);
|
||||||
|
auto cparams = common_context_params_to_llama(params);
|
||||||
|
|
||||||
|
if (params.fit_params) {
|
||||||
|
LOG_INF("%s: fitting params to device memory, to report bugs during this step use -fit off (or --verbose if you can't)\n", __func__);
|
||||||
|
llama_params_fit(params.model.path.c_str(), &mparams, &cparams,
|
||||||
|
params.tensor_split, params.tensor_buft_overrides.data(), params.fit_params_target, params.fit_params_min_ctx,
|
||||||
|
params.verbosity >= 4 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_ERROR);
|
||||||
|
}
|
||||||
|
|
||||||
llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams);
|
llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams);
|
||||||
if (model == NULL) {
|
if (model == NULL) {
|
||||||
|
|
@ -1103,8 +1111,6 @@ common_init_result::common_init_result(common_params & params) :
|
||||||
// TODO: fix naming
|
// TODO: fix naming
|
||||||
common_init_sampler_from_model(model, params.sampling);
|
common_init_sampler_from_model(model, params.sampling);
|
||||||
|
|
||||||
auto cparams = common_context_params_to_llama(params);
|
|
||||||
|
|
||||||
if (params.sampling.ignore_eos && llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) {
|
if (params.sampling.ignore_eos && llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) {
|
||||||
LOG_WRN("%s: warning: vocab does not have an EOS token, ignoring --ignore-eos\n", __func__);
|
LOG_WRN("%s: warning: vocab does not have an EOS token, ignoring --ignore-eos\n", __func__);
|
||||||
params.sampling.ignore_eos = false;
|
params.sampling.ignore_eos = false;
|
||||||
|
|
@ -1143,8 +1149,7 @@ common_init_result::common_init_result(common_params & params) :
|
||||||
|
|
||||||
llama_context * lctx = llama_init_from_model(model, cparams);
|
llama_context * lctx = llama_init_from_model(model, cparams);
|
||||||
if (lctx == NULL) {
|
if (lctx == NULL) {
|
||||||
LOG_ERR("%s: failed to create context with model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
|
LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.path.c_str());
|
||||||
__func__, params.model.path.c_str());
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1176,15 +1181,13 @@ common_init_result_ptr common_init_from_params(common_params & params) {
|
||||||
|
|
||||||
llama_model * model = res->model();
|
llama_model * model = res->model();
|
||||||
if (model == NULL) {
|
if (model == NULL) {
|
||||||
LOG_ERR("%s: failed to load model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
|
LOG_ERR("%s: failed to load model '%s'\n", __func__, params.model.path.c_str());
|
||||||
__func__, params.model.path.c_str());
|
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_context * lctx = res->context();
|
llama_context * lctx = res->context();
|
||||||
if (lctx == NULL) {
|
if (lctx == NULL) {
|
||||||
LOG_ERR("%s: failed to create context with model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
|
LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.path.c_str());
|
||||||
__func__, params.model.path.c_str());
|
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -99,6 +99,7 @@ enum llama_example {
|
||||||
LLAMA_EXAMPLE_TTS,
|
LLAMA_EXAMPLE_TTS,
|
||||||
LLAMA_EXAMPLE_DIFFUSION,
|
LLAMA_EXAMPLE_DIFFUSION,
|
||||||
LLAMA_EXAMPLE_FINETUNE,
|
LLAMA_EXAMPLE_FINETUNE,
|
||||||
|
LLAMA_EXAMPLE_FIT_PARAMS,
|
||||||
|
|
||||||
LLAMA_EXAMPLE_COUNT,
|
LLAMA_EXAMPLE_COUNT,
|
||||||
};
|
};
|
||||||
|
|
@ -306,8 +307,8 @@ struct lr_opt {
|
||||||
struct ggml_opt_optimizer_params common_opt_lr_pars(void * userdata);
|
struct ggml_opt_optimizer_params common_opt_lr_pars(void * userdata);
|
||||||
|
|
||||||
struct common_params {
|
struct common_params {
|
||||||
int32_t n_predict = -1; // new tokens to predict
|
int32_t n_predict = -1; // max. number of new tokens to predict, -1 == no limit
|
||||||
int32_t n_ctx = 4096; // context size
|
int32_t n_ctx = 0; // context size, 0 == context the model was trained with
|
||||||
int32_t n_batch = 2048; // logical batch size for prompt processing (must be >=32 to use BLAS)
|
int32_t n_batch = 2048; // logical batch size for prompt processing (must be >=32 to use BLAS)
|
||||||
int32_t n_ubatch = 512; // physical batch size for prompt processing (must be >=32 to use BLAS)
|
int32_t n_ubatch = 512; // physical batch size for prompt processing (must be >=32 to use BLAS)
|
||||||
int32_t n_keep = 0; // number of tokens to keep from initial prompt
|
int32_t n_keep = 0; // number of tokens to keep from initial prompt
|
||||||
|
|
@ -331,6 +332,9 @@ struct common_params {
|
||||||
int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
|
int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
|
||||||
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
|
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
|
||||||
float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs
|
float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs
|
||||||
|
bool fit_params = true; // whether to fit unset model/context parameters to free device memory
|
||||||
|
size_t fit_params_target = 1024 * 1024*1024; // margin per device in bytes for fitting parameters to free memory
|
||||||
|
int32_t fit_params_min_ctx = 4096; // minimum context size to set when trying to reduce memory use
|
||||||
|
|
||||||
enum llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs
|
enum llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -53,7 +53,14 @@ GGML_API void ggml_gallocr_free(ggml_gallocr_t galloc);
|
||||||
// call with a worst-case graph to avoid buffer reallocations
|
// call with a worst-case graph to avoid buffer reallocations
|
||||||
// not strictly required for single buffer usage: ggml_gallocr_alloc_graph will reallocate the buffers automatically if needed
|
// not strictly required for single buffer usage: ggml_gallocr_alloc_graph will reallocate the buffers automatically if needed
|
||||||
// returns false if the buffer allocation failed
|
// returns false if the buffer allocation failed
|
||||||
|
// ggml_gallocr_resrve_n_size writes the buffer sizes per galloc buffer that would be allocated by ggml_gallocr_reserve_n to sizes
|
||||||
GGML_API bool ggml_gallocr_reserve(ggml_gallocr_t galloc, struct ggml_cgraph * graph);
|
GGML_API bool ggml_gallocr_reserve(ggml_gallocr_t galloc, struct ggml_cgraph * graph);
|
||||||
|
GGML_API void ggml_gallocr_reserve_n_size(
|
||||||
|
ggml_gallocr_t galloc,
|
||||||
|
struct ggml_cgraph * graph,
|
||||||
|
const int * node_buffer_ids,
|
||||||
|
const int * leaf_buffer_ids,
|
||||||
|
size_t * sizes);
|
||||||
GGML_API bool ggml_gallocr_reserve_n(
|
GGML_API bool ggml_gallocr_reserve_n(
|
||||||
ggml_gallocr_t galloc,
|
ggml_gallocr_t galloc,
|
||||||
struct ggml_cgraph * graph,
|
struct ggml_cgraph * graph,
|
||||||
|
|
@ -68,6 +75,8 @@ GGML_API size_t ggml_gallocr_get_buffer_size(ggml_gallocr_t galloc, int buffer_i
|
||||||
|
|
||||||
// Utils
|
// Utils
|
||||||
// Create a buffer and allocate all the tensors in a ggml_context
|
// Create a buffer and allocate all the tensors in a ggml_context
|
||||||
|
// ggml_backend_alloc_ctx_tensors_from_buft_size returns the size of the buffer that would be allocated by ggml_backend_alloc_ctx_tensors_from_buft
|
||||||
|
GGML_API size_t ggml_backend_alloc_ctx_tensors_from_buft_size(struct ggml_context * ctx, ggml_backend_buffer_type_t buft);
|
||||||
GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft);
|
GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft);
|
||||||
GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors(struct ggml_context * ctx, ggml_backend_t backend);
|
GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors(struct ggml_context * ctx, ggml_backend_t backend);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -307,6 +307,7 @@ extern "C" {
|
||||||
GGML_API void ggml_backend_sched_free(ggml_backend_sched_t sched);
|
GGML_API void ggml_backend_sched_free(ggml_backend_sched_t sched);
|
||||||
|
|
||||||
// Initialize backend buffers from a measure graph
|
// Initialize backend buffers from a measure graph
|
||||||
|
GGML_API void ggml_backend_sched_reserve_size(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph, size_t * sizes);
|
||||||
GGML_API bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph); // returns success
|
GGML_API bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph); // returns success
|
||||||
|
|
||||||
GGML_API int ggml_backend_sched_get_n_backends(ggml_backend_sched_t sched);
|
GGML_API int ggml_backend_sched_get_n_backends(ggml_backend_sched_t sched);
|
||||||
|
|
|
||||||
|
|
@ -2615,6 +2615,7 @@ extern "C" {
|
||||||
|
|
||||||
// Set callback for all future logging events.
|
// Set callback for all future logging events.
|
||||||
// If this is not called, or NULL is supplied, everything is output on stderr.
|
// If this is not called, or NULL is supplied, everything is output on stderr.
|
||||||
|
GGML_API void ggml_log_get(ggml_log_callback * log_callback, void ** user_data);
|
||||||
GGML_API void ggml_log_set(ggml_log_callback log_callback, void * user_data);
|
GGML_API void ggml_log_set(ggml_log_callback log_callback, void * user_data);
|
||||||
|
|
||||||
GGML_API struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor);
|
GGML_API struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor);
|
||||||
|
|
|
||||||
|
|
@ -594,7 +594,9 @@ static bool ggml_gallocr_is_own(ggml_gallocr_t galloc, struct ggml_tensor * t) {
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool ggml_gallocr_is_allocated(ggml_gallocr_t galloc, struct ggml_tensor * t) {
|
static bool ggml_gallocr_is_allocated(ggml_gallocr_t galloc, struct ggml_tensor * t) {
|
||||||
return t->data != NULL || ggml_gallocr_hash_get(galloc, t)->allocated;
|
return t->data != NULL // tensor data already set externally
|
||||||
|
|| t->buffer // tensor on external buffer (but not yet allocated)
|
||||||
|
|| ggml_gallocr_is_own(galloc, t); // tensor will be allocated by galloc
|
||||||
}
|
}
|
||||||
|
|
||||||
// free the extra space at the end if the new tensor is smaller
|
// free the extra space at the end if the new tensor is smaller
|
||||||
|
|
@ -823,7 +825,8 @@ static void ggml_gallocr_alloc_graph_impl(ggml_gallocr_t galloc, struct ggml_cgr
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, const int * node_buffer_ids, const int * leaf_buffer_ids) {
|
static bool ggml_gallocr_reserve_n_impl(
|
||||||
|
ggml_gallocr_t galloc, struct ggml_cgraph * graph, const int * node_buffer_ids, const int * leaf_buffer_ids, bool no_alloc) {
|
||||||
size_t min_hash_size = graph->n_nodes + graph->n_leafs;
|
size_t min_hash_size = graph->n_nodes + graph->n_leafs;
|
||||||
// add 25% margin to avoid hash collisions
|
// add 25% margin to avoid hash collisions
|
||||||
min_hash_size += min_hash_size / 4;
|
min_hash_size += min_hash_size / 4;
|
||||||
|
|
@ -928,12 +931,14 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c
|
||||||
size_t cur_size = galloc->buffers[i] ? ggml_vbuffer_size(galloc->buffers[i]) : 0;
|
size_t cur_size = galloc->buffers[i] ? ggml_vbuffer_size(galloc->buffers[i]) : 0;
|
||||||
if (cur_size > 0) {
|
if (cur_size > 0) {
|
||||||
GGML_LOG_DEBUG("%s: reallocating %s buffer from size %.02f MiB to %.02f MiB\n",
|
GGML_LOG_DEBUG("%s: reallocating %s buffer from size %.02f MiB to %.02f MiB\n",
|
||||||
__func__, ggml_backend_buft_name(galloc->bufts[i]),
|
__func__, ggml_backend_buft_name(galloc->bufts[i]), cur_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
|
||||||
cur_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
ggml_vbuffer_free(galloc->buffers[i]);
|
ggml_vbuffer_free(galloc->buffers[i]);
|
||||||
|
if (no_alloc) {
|
||||||
|
galloc->buffers[i] = NULL;
|
||||||
|
} else {
|
||||||
galloc->buffers[i] = ggml_vbuffer_alloc(galloc->bufts[i], galloc->buf_tallocs[i], GGML_BACKEND_BUFFER_USAGE_COMPUTE);
|
galloc->buffers[i] = ggml_vbuffer_alloc(galloc->bufts[i], galloc->buf_tallocs[i], GGML_BACKEND_BUFFER_USAGE_COMPUTE);
|
||||||
if (galloc->buffers[i] == NULL) {
|
if (galloc->buffers[i] == NULL) {
|
||||||
GGML_LOG_ERROR("%s: failed to allocate %s buffer of size %zu\n", __func__, ggml_backend_buft_name(galloc->bufts[i]), new_size);
|
GGML_LOG_ERROR("%s: failed to allocate %s buffer of size %zu\n", __func__, ggml_backend_buft_name(galloc->bufts[i]), new_size);
|
||||||
|
|
@ -941,10 +946,26 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ggml_gallocr_reserve_n_size(
|
||||||
|
ggml_gallocr_t galloc, struct ggml_cgraph * graph, const int * node_buffer_ids, const int * leaf_buffer_ids, size_t * sizes) {
|
||||||
|
GGML_ASSERT(ggml_gallocr_reserve_n_impl(galloc, graph, node_buffer_ids, leaf_buffer_ids, /*no_alloc =*/ true));
|
||||||
|
for (int i = 0; i < galloc->n_buffers; i++) {
|
||||||
|
sizes[i] = 0;
|
||||||
|
for (int c = 0; c < galloc->buf_tallocs[i]->n_chunks; c++) {
|
||||||
|
sizes[i] += galloc->buf_tallocs[i]->chunks[c]->max_size;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, const int * node_buffer_ids, const int * leaf_buffer_ids) {
|
||||||
|
return ggml_gallocr_reserve_n_impl(galloc, graph, node_buffer_ids, leaf_buffer_ids, /*no_alloc =*/ false);
|
||||||
|
}
|
||||||
|
|
||||||
bool ggml_gallocr_reserve(ggml_gallocr_t galloc, struct ggml_cgraph *graph) {
|
bool ggml_gallocr_reserve(ggml_gallocr_t galloc, struct ggml_cgraph *graph) {
|
||||||
return ggml_gallocr_reserve_n(galloc, graph, NULL, NULL);
|
return ggml_gallocr_reserve_n(galloc, graph, NULL, NULL);
|
||||||
}
|
}
|
||||||
|
|
@ -1147,7 +1168,8 @@ static bool alloc_tensor_range(struct ggml_context * ctx,
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft) {
|
static ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors_from_buft_impl(
|
||||||
|
struct ggml_context * ctx, ggml_backend_buffer_type_t buft, size_t * nbytes_total, bool no_alloc) {
|
||||||
GGML_ASSERT(ggml_get_no_alloc(ctx) == true);
|
GGML_ASSERT(ggml_get_no_alloc(ctx) == true);
|
||||||
|
|
||||||
size_t alignment = ggml_backend_buft_get_alignment(buft);
|
size_t alignment = ggml_backend_buft_get_alignment(buft);
|
||||||
|
|
@ -1155,6 +1177,7 @@ ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_conte
|
||||||
|
|
||||||
ggml_backend_buffer_t * buffers = NULL;
|
ggml_backend_buffer_t * buffers = NULL;
|
||||||
size_t n_buffers = 0;
|
size_t n_buffers = 0;
|
||||||
|
*nbytes_total = 0;
|
||||||
|
|
||||||
size_t cur_buf_size = 0;
|
size_t cur_buf_size = 0;
|
||||||
struct ggml_tensor * first = ggml_get_first_tensor(ctx);
|
struct ggml_tensor * first = ggml_get_first_tensor(ctx);
|
||||||
|
|
@ -1166,10 +1189,11 @@ ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_conte
|
||||||
|
|
||||||
if (cur_buf_size > 0 && (cur_buf_size + this_size) > max_size) {
|
if (cur_buf_size > 0 && (cur_buf_size + this_size) > max_size) {
|
||||||
// allocate tensors in the current buffer
|
// allocate tensors in the current buffer
|
||||||
if (!alloc_tensor_range(ctx, first, t, buft, cur_buf_size, &buffers, &n_buffers)) {
|
if (!no_alloc && !alloc_tensor_range(ctx, first, t, buft, cur_buf_size, &buffers, &n_buffers)) {
|
||||||
return NULL;
|
return NULL;
|
||||||
}
|
}
|
||||||
first = t;
|
first = t;
|
||||||
|
*nbytes_total += cur_buf_size;
|
||||||
cur_buf_size = this_size;
|
cur_buf_size = this_size;
|
||||||
} else {
|
} else {
|
||||||
cur_buf_size += this_size;
|
cur_buf_size += this_size;
|
||||||
|
|
@ -1178,15 +1202,21 @@ ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_conte
|
||||||
|
|
||||||
// allocate remaining tensors
|
// allocate remaining tensors
|
||||||
if (cur_buf_size > 0) {
|
if (cur_buf_size > 0) {
|
||||||
if (!alloc_tensor_range(ctx, first, NULL, buft, cur_buf_size, &buffers, &n_buffers)) {
|
*nbytes_total += cur_buf_size;
|
||||||
|
if (!no_alloc && !alloc_tensor_range(ctx, first, NULL, buft, cur_buf_size, &buffers, &n_buffers)) {
|
||||||
return NULL;
|
return NULL;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (no_alloc) {
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
|
||||||
if (n_buffers == 0) {
|
if (n_buffers == 0) {
|
||||||
#ifndef NDEBUG
|
#ifndef NDEBUG
|
||||||
GGML_LOG_DEBUG("%s: all tensors in the context are already allocated\n", __func__);
|
GGML_LOG_DEBUG("%s: all tensors in the context are already allocated\n", __func__);
|
||||||
#endif
|
#endif
|
||||||
|
GGML_ASSERT(!buffers);
|
||||||
return NULL;
|
return NULL;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1196,10 +1226,24 @@ ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_conte
|
||||||
} else {
|
} else {
|
||||||
buffer = ggml_backend_multi_buffer_alloc_buffer(buffers, n_buffers);
|
buffer = ggml_backend_multi_buffer_alloc_buffer(buffers, n_buffers);
|
||||||
}
|
}
|
||||||
free(buffers);
|
if (buffers) {
|
||||||
|
free(buffers); // can be NULL if context is empty or no_alloc
|
||||||
|
}
|
||||||
return buffer;
|
return buffer;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
size_t ggml_backend_alloc_ctx_tensors_from_buft_size(struct ggml_context * ctx, ggml_backend_buffer_type_t buft) {
|
||||||
|
size_t nbytes_total = 0;
|
||||||
|
ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft_impl(ctx, buft, &nbytes_total, /*no_alloc=*/ true);
|
||||||
|
GGML_ASSERT(!buf);
|
||||||
|
return nbytes_total;
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft) {
|
||||||
|
size_t nbytes_total = 0;
|
||||||
|
return ggml_backend_alloc_ctx_tensors_from_buft_impl(ctx, buft, &nbytes_total, /*no_alloc =*/ false);
|
||||||
|
}
|
||||||
|
|
||||||
ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors(struct ggml_context * ctx, ggml_backend_t backend) {
|
ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors(struct ggml_context * ctx, ggml_backend_t backend) {
|
||||||
return ggml_backend_alloc_ctx_tensors_from_buft(ctx, ggml_backend_get_default_buffer_type(backend));
|
return ggml_backend_alloc_ctx_tensors_from_buft(ctx, ggml_backend_get_default_buffer_type(backend));
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -36,12 +36,11 @@ const char * ggml_backend_buft_name(ggml_backend_buffer_type_t buft) {
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_backend_buffer_t ggml_backend_buft_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
|
ggml_backend_buffer_t ggml_backend_buft_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
|
||||||
|
GGML_ASSERT(buft);
|
||||||
if (size == 0) {
|
if (size == 0) {
|
||||||
// return a dummy buffer for zero-sized allocations
|
// return a dummy buffer for zero-sized allocations
|
||||||
return ggml_backend_buffer_init(buft, {}, NULL, 0);
|
return ggml_backend_buffer_init(buft, {}, NULL, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
GGML_ASSERT(buft);
|
|
||||||
return buft->iface.alloc_buffer(buft, size);
|
return buft->iface.alloc_buffer(buft, size);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -128,6 +127,12 @@ void * ggml_backend_buffer_get_base(ggml_backend_buffer_t buffer) {
|
||||||
return NULL;
|
return NULL;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FIXME JG: a multi_buffer has a non-zero size, according to the above comment get_base is not optional,
|
||||||
|
// I don't know whether the above comment is correct
|
||||||
|
if (!buffer->iface.get_base) {
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
|
||||||
void * base = buffer->iface.get_base(buffer);
|
void * base = buffer->iface.get_base(buffer);
|
||||||
|
|
||||||
GGML_ASSERT(base != NULL && "backend buffer base cannot be NULL");
|
GGML_ASSERT(base != NULL && "backend buffer base cannot be NULL");
|
||||||
|
|
@ -1727,6 +1732,20 @@ void ggml_backend_sched_reset(ggml_backend_sched_t sched) {
|
||||||
sched->is_alloc = false;
|
sched->is_alloc = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ggml_backend_sched_reserve_size(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph, size_t * sizes) {
|
||||||
|
GGML_ASSERT(sched);
|
||||||
|
GGML_ASSERT((int)sched->hash_set.size >= measure_graph->n_nodes + measure_graph->n_leafs);
|
||||||
|
GGML_ASSERT(sizes);
|
||||||
|
|
||||||
|
ggml_backend_sched_reset(sched);
|
||||||
|
|
||||||
|
ggml_backend_sched_synchronize(sched);
|
||||||
|
|
||||||
|
ggml_backend_sched_split_graph(sched, measure_graph);
|
||||||
|
|
||||||
|
ggml_gallocr_reserve_n_size(sched->galloc, &sched->graph, sched->node_backend_ids, sched->leaf_backend_ids, sizes);
|
||||||
|
}
|
||||||
|
|
||||||
bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph) {
|
bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph) {
|
||||||
GGML_ASSERT(sched);
|
GGML_ASSERT(sched);
|
||||||
GGML_ASSERT((int)sched->hash_set.size >= measure_graph->n_nodes + measure_graph->n_leafs);
|
GGML_ASSERT((int)sched->hash_set.size >= measure_graph->n_nodes + measure_graph->n_leafs);
|
||||||
|
|
|
||||||
|
|
@ -7566,6 +7566,11 @@ size_t ggml_quantize_chunk(
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
void ggml_log_get(ggml_log_callback * log_callback, void ** user_data) {
|
||||||
|
*log_callback = g_logger_state.log_callback;
|
||||||
|
*user_data = g_logger_state.log_callback_user_data;
|
||||||
|
}
|
||||||
|
|
||||||
void ggml_log_set(ggml_log_callback log_callback, void * user_data) {
|
void ggml_log_set(ggml_log_callback log_callback, void * user_data) {
|
||||||
g_logger_state.log_callback = log_callback ? log_callback : ggml_log_callback_default;
|
g_logger_state.log_callback = log_callback ? log_callback : ggml_log_callback_default;
|
||||||
g_logger_state.log_callback_user_data = user_data;
|
g_logger_state.log_callback_user_data = user_data;
|
||||||
|
|
|
||||||
|
|
@ -313,6 +313,7 @@ extern "C" {
|
||||||
bool check_tensors; // validate model tensor data
|
bool check_tensors; // validate model tensor data
|
||||||
bool use_extra_bufts; // use extra buffer types (used for weight repacking)
|
bool use_extra_bufts; // use extra buffer types (used for weight repacking)
|
||||||
bool no_host; // bypass host buffer allowing extra buffers to be used
|
bool no_host; // bypass host buffer allowing extra buffers to be used
|
||||||
|
bool no_alloc; // only load metadata and simulate memory allocations
|
||||||
};
|
};
|
||||||
|
|
||||||
// NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations
|
// NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations
|
||||||
|
|
@ -466,10 +467,24 @@ extern "C" {
|
||||||
// Frees all allocated memory
|
// Frees all allocated memory
|
||||||
LLAMA_API void llama_free(struct llama_context * ctx);
|
LLAMA_API void llama_free(struct llama_context * ctx);
|
||||||
|
|
||||||
|
// fits mparams and cparams to free device memory (assumes system memory is unlimited)
|
||||||
|
// returns true if the parameters could be successfully modified to fit device memory
|
||||||
|
// this function is NOT thread safe because it modifies the global llama logger state
|
||||||
|
LLAMA_API bool llama_params_fit(
|
||||||
|
const char * path_model,
|
||||||
|
struct llama_model_params * mparams,
|
||||||
|
struct llama_context_params * cparams,
|
||||||
|
float * tensor_split, // writable buffer for tensor split, needs at least llama_max_devices elements
|
||||||
|
struct llama_model_tensor_buft_override * tensor_buft_overrides, // writable buffer for overrides, needs at least llama_max_tensor_buft_overrides elements
|
||||||
|
size_t margin, // margin of memory to leave per device in bytes
|
||||||
|
uint32_t n_ctx_min, // minimum context size to set when trying to reduce memory use
|
||||||
|
enum ggml_log_level log_level); // minimum log level to print during fitting, lower levels go to debug log
|
||||||
|
|
||||||
LLAMA_API int64_t llama_time_us(void);
|
LLAMA_API int64_t llama_time_us(void);
|
||||||
|
|
||||||
LLAMA_API size_t llama_max_devices(void);
|
LLAMA_API size_t llama_max_devices(void);
|
||||||
LLAMA_API size_t llama_max_parallel_sequences(void);
|
LLAMA_API size_t llama_max_parallel_sequences(void);
|
||||||
|
LLAMA_API size_t llama_max_tensor_buft_overrides(void);
|
||||||
|
|
||||||
LLAMA_API bool llama_supports_mmap (void);
|
LLAMA_API bool llama_supports_mmap (void);
|
||||||
LLAMA_API bool llama_supports_mlock (void);
|
LLAMA_API bool llama_supports_mlock (void);
|
||||||
|
|
@ -1354,6 +1369,8 @@ extern "C" {
|
||||||
|
|
||||||
// Set callback for all future logging events.
|
// Set callback for all future logging events.
|
||||||
// If this is not called, or NULL is supplied, everything is output on stderr.
|
// If this is not called, or NULL is supplied, everything is output on stderr.
|
||||||
|
// The logger state is global so these functions are NOT thread safe.
|
||||||
|
LLAMA_API void llama_log_get(ggml_log_callback * log_callback, void ** user_data);
|
||||||
LLAMA_API void llama_log_set(ggml_log_callback log_callback, void * user_data);
|
LLAMA_API void llama_log_set(ggml_log_callback log_callback, void * user_data);
|
||||||
|
|
||||||
//
|
//
|
||||||
|
|
|
||||||
|
|
@ -258,6 +258,7 @@ llama_context::llama_context(
|
||||||
|
|
||||||
backend_buft.clear();
|
backend_buft.clear();
|
||||||
backend_ptrs.clear();
|
backend_ptrs.clear();
|
||||||
|
backend_buf_exp_size.clear();
|
||||||
|
|
||||||
for (auto & backend : backends) {
|
for (auto & backend : backends) {
|
||||||
auto * buft = ggml_backend_get_default_buffer_type(backend.get());
|
auto * buft = ggml_backend_get_default_buffer_type(backend.get());
|
||||||
|
|
@ -274,6 +275,7 @@ llama_context::llama_context(
|
||||||
|
|
||||||
backend_buft.push_back(buft);
|
backend_buft.push_back(buft);
|
||||||
backend_ptrs.push_back(backend.get());
|
backend_ptrs.push_back(backend.get());
|
||||||
|
backend_buf_exp_size.push_back(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
LLAMA_LOG_DEBUG("%s: backend_ptrs.size() = %zu\n", __func__, backend_ptrs.size());
|
LLAMA_LOG_DEBUG("%s: backend_ptrs.size() = %zu\n", __func__, backend_ptrs.size());
|
||||||
|
|
@ -389,7 +391,8 @@ llama_context::llama_context(
|
||||||
|
|
||||||
// reserve pp (prompt processing) graph first so that buffers are only allocated once
|
// reserve pp (prompt processing) graph first so that buffers are only allocated once
|
||||||
{
|
{
|
||||||
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
|
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(),
|
||||||
|
model.hparams.no_alloc, model.hparams.no_alloc ? backend_buf_exp_size.data() : nullptr);
|
||||||
if (!gf) {
|
if (!gf) {
|
||||||
if (pipeline_parallel) {
|
if (pipeline_parallel) {
|
||||||
LLAMA_LOG_WARN("%s: compute buffer allocation failed, retrying without pipeline parallelism\n", __func__);
|
LLAMA_LOG_WARN("%s: compute buffer allocation failed, retrying without pipeline parallelism\n", __func__);
|
||||||
|
|
@ -407,7 +410,7 @@ llama_context::llama_context(
|
||||||
|
|
||||||
// reserve with tg (token generation) graph to get the number of splits and nodes
|
// reserve with tg (token generation) graph to get the number of splits and nodes
|
||||||
{
|
{
|
||||||
auto * gf = graph_reserve(n_seqs, n_seqs, n_seqs, mctx.get());
|
auto * gf = graph_reserve(n_seqs, n_seqs, n_seqs, mctx.get(), model.hparams.no_alloc);
|
||||||
if (!gf) {
|
if (!gf) {
|
||||||
throw std::runtime_error("failed to allocate compute tg buffers");
|
throw std::runtime_error("failed to allocate compute tg buffers");
|
||||||
}
|
}
|
||||||
|
|
@ -422,7 +425,7 @@ llama_context::llama_context(
|
||||||
//
|
//
|
||||||
// auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get());
|
// auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get());
|
||||||
//
|
//
|
||||||
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
|
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(), model.hparams.no_alloc);
|
||||||
if (!gf) {
|
if (!gf) {
|
||||||
throw std::runtime_error("failed to allocate compute pp buffers");
|
throw std::runtime_error("failed to allocate compute pp buffers");
|
||||||
}
|
}
|
||||||
|
|
@ -431,11 +434,13 @@ llama_context::llama_context(
|
||||||
for (size_t i = 0; i < backend_ptrs.size(); ++i) {
|
for (size_t i = 0; i < backend_ptrs.size(); ++i) {
|
||||||
ggml_backend_t backend = backend_ptrs[i];
|
ggml_backend_t backend = backend_ptrs[i];
|
||||||
ggml_backend_buffer_type_t buft = backend_buft[i];
|
ggml_backend_buffer_type_t buft = backend_buft[i];
|
||||||
size_t size = ggml_backend_sched_get_buffer_size(sched.get(), backend);
|
if (!model.hparams.no_alloc) {
|
||||||
if (size > 1) {
|
backend_buf_exp_size[i] = ggml_backend_sched_get_buffer_size(sched.get(), backend);
|
||||||
|
}
|
||||||
|
if (backend_buf_exp_size[i] > 1) {
|
||||||
LLAMA_LOG_INFO("%s: %10s compute buffer size = %8.2f MiB\n", __func__,
|
LLAMA_LOG_INFO("%s: %10s compute buffer size = %8.2f MiB\n", __func__,
|
||||||
ggml_backend_buft_name(buft),
|
ggml_backend_buft_name(buft),
|
||||||
size / 1024.0 / 1024.0);
|
backend_buf_exp_size[i] / 1024.0 / 1024.0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -454,6 +459,23 @@ llama_context::llama_context(
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_context::~llama_context() {
|
llama_context::~llama_context() {
|
||||||
|
// FIXME this currently results in a use-after-free bug if the model is freed before the context
|
||||||
|
// if (!model.hparams.no_alloc) {
|
||||||
|
// for (size_t i = 0; i < backend_ptrs.size(); ++i) {
|
||||||
|
// ggml_backend_t backend = backend_ptrs[i];
|
||||||
|
// ggml_backend_buffer_type_t buft = backend_buft[i];
|
||||||
|
|
||||||
|
// const size_t size_exp = backend_buf_exp_size[i];
|
||||||
|
// const size_t size_act = ggml_backend_sched_get_buffer_size(sched.get(), backend);
|
||||||
|
// if (size_exp == size_act) {
|
||||||
|
// LLAMA_LOG_DEBUG("%s: %10s compute buffer size is %8.4f MiB, matches expectation of %8.4f MiB\n",
|
||||||
|
// __func__, ggml_backend_buft_name(buft), size_act / (1024.0*1024.0), size_exp / (1024.0*1024.0));
|
||||||
|
// } else {
|
||||||
|
// LLAMA_LOG_WARN("%s: %10s compute buffer size of %8.4f MiB, does not match expectation of %8.4f MiB\n",
|
||||||
|
// __func__, ggml_backend_buft_name(buft), size_act / (1024.0*1024.0), size_exp / (1024.0*1024.0));
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
ggml_opt_free(opt_ctx);
|
ggml_opt_free(opt_ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1428,7 +1450,8 @@ llm_graph_result * llama_context::get_gf_res_reserve() const {
|
||||||
return static_cast<llm_graph_result *>(gf_res_reserve.get());
|
return static_cast<llm_graph_result *>(gf_res_reserve.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only) {
|
ggml_cgraph * llama_context::graph_reserve(
|
||||||
|
uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only, size_t * sizes) {
|
||||||
LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
|
LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
|
||||||
GGML_ASSERT(n_outputs >= 1);
|
GGML_ASSERT(n_outputs >= 1);
|
||||||
|
|
||||||
|
|
@ -1465,8 +1488,13 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
|
||||||
|
|
||||||
// initialize scheduler with the specified graph
|
// initialize scheduler with the specified graph
|
||||||
if (split_only) {
|
if (split_only) {
|
||||||
|
if (sizes) {
|
||||||
|
ggml_backend_sched_reserve_size(sched.get(), gf, sizes);
|
||||||
|
} else {
|
||||||
ggml_backend_sched_split_graph(sched.get(), gf);
|
ggml_backend_sched_split_graph(sched.get(), gf);
|
||||||
|
}
|
||||||
} else if (!ggml_backend_sched_reserve(sched.get(), gf)) {
|
} else if (!ggml_backend_sched_reserve(sched.get(), gf)) {
|
||||||
|
GGML_ASSERT(!sizes);
|
||||||
LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
|
LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
@ -2088,15 +2116,26 @@ void llama_context::perf_reset() {
|
||||||
|
|
||||||
std::map<ggml_backend_buffer_type_t, llama_memory_breakdown_data> llama_context::memory_breakdown() const {
|
std::map<ggml_backend_buffer_type_t, llama_memory_breakdown_data> llama_context::memory_breakdown() const {
|
||||||
std::map<ggml_backend_buffer_type_t, llama_memory_breakdown_data> ret;
|
std::map<ggml_backend_buffer_type_t, llama_memory_breakdown_data> ret;
|
||||||
for (const auto & buft_size : model.memory_breakdown()) {
|
for (const auto & [buft, size] : model.memory_breakdown()) {
|
||||||
ret[buft_size.first].model += buft_size.second;
|
ret[buft].model += size;
|
||||||
}
|
}
|
||||||
for (const auto & buft_size : memory->memory_breakdown()) {
|
if (memory) {
|
||||||
ret[buft_size.first].context += buft_size.second;
|
for (const auto & [buft, size] : memory->memory_breakdown()) {
|
||||||
|
ret[buft].context += size;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
if (model.hparams.no_alloc) {
|
||||||
|
for (size_t i = 0; i < backends.size(); ++i) {
|
||||||
|
ggml_backend_t backend = backends[i].get();
|
||||||
|
ggml_backend_buffer_type_t buft = ggml_backend_sched_get_buffer_type(sched.get(), backend);
|
||||||
|
ret[buft].compute += backend_buf_exp_size[i];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
for (const auto & backend_ptr : backends) {
|
for (const auto & backend_ptr : backends) {
|
||||||
ggml_backend_t backend = backend_ptr.get();
|
ggml_backend_t backend = backend_ptr.get();
|
||||||
ret[ggml_backend_sched_get_buffer_type(sched.get(), backend)].compute += ggml_backend_sched_get_buffer_size(sched.get(), backend);
|
ggml_backend_buffer_type_t buft = ggml_backend_sched_get_buffer_type(sched.get(), backend);
|
||||||
|
ret[buft].compute += ggml_backend_sched_get_buffer_size(sched.get(), backend);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -26,6 +26,10 @@ struct llama_memory_breakdown_data {
|
||||||
size_t model = 0; // memory allocated for the model
|
size_t model = 0; // memory allocated for the model
|
||||||
size_t context = 0; // memory allocated for the context
|
size_t context = 0; // memory allocated for the context
|
||||||
size_t compute = 0; // memory allocated for temporary compute buffers
|
size_t compute = 0; // memory allocated for temporary compute buffers
|
||||||
|
|
||||||
|
size_t total() const {
|
||||||
|
return model + context + compute;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct llama_context {
|
struct llama_context {
|
||||||
|
|
@ -206,7 +210,8 @@ public:
|
||||||
ggml_status graph_compute(ggml_cgraph * gf, bool batched);
|
ggml_status graph_compute(ggml_cgraph * gf, bool batched);
|
||||||
|
|
||||||
// reserve a graph with a dummy ubatch of the specified size
|
// reserve a graph with a dummy ubatch of the specified size
|
||||||
ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only = false);
|
ggml_cgraph * graph_reserve(
|
||||||
|
uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only = false, size_t * sizes = nullptr);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
llm_graph_params graph_params(
|
llm_graph_params graph_params(
|
||||||
|
|
@ -281,9 +286,10 @@ private:
|
||||||
|
|
||||||
std::vector<std::pair<ggml_backend_t, ggml_backend_set_n_threads_t>> set_n_threads_fns;
|
std::vector<std::pair<ggml_backend_t, ggml_backend_set_n_threads_t>> set_n_threads_fns;
|
||||||
|
|
||||||
// buffer types used for the compute buffer of each backend
|
// pointers and buffer types used for the compute buffer of each backend
|
||||||
std::vector<ggml_backend_t> backend_ptrs;
|
std::vector<ggml_backend_t> backend_ptrs;
|
||||||
std::vector<ggml_backend_buffer_type_t> backend_buft;
|
std::vector<ggml_backend_buffer_type_t> backend_buft;
|
||||||
|
std::vector<size_t> backend_buf_exp_size; // expected buffer sizes
|
||||||
|
|
||||||
llm_graph_result_ptr gf_res_prev;
|
llm_graph_result_ptr gf_res_prev;
|
||||||
llm_graph_result_ptr gf_res_reserve;
|
llm_graph_result_ptr gf_res_reserve;
|
||||||
|
|
|
||||||
|
|
@ -34,6 +34,7 @@ struct llama_hparams_convnext {
|
||||||
|
|
||||||
struct llama_hparams {
|
struct llama_hparams {
|
||||||
bool vocab_only;
|
bool vocab_only;
|
||||||
|
bool no_alloc;
|
||||||
bool rope_finetuned;
|
bool rope_finetuned;
|
||||||
bool use_par_res;
|
bool use_par_res;
|
||||||
bool swin_norm;
|
bool swin_norm;
|
||||||
|
|
|
||||||
|
|
@ -25,6 +25,10 @@ time_meas::~time_meas() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void llama_log_get(ggml_log_callback * log_callback, void ** user_data) {
|
||||||
|
ggml_log_get(log_callback, user_data);
|
||||||
|
}
|
||||||
|
|
||||||
void llama_log_set(ggml_log_callback log_callback, void * user_data) {
|
void llama_log_set(ggml_log_callback log_callback, void * user_data) {
|
||||||
ggml_log_set(log_callback, user_data);
|
ggml_log_set(log_callback, user_data);
|
||||||
g_logger_state.log_callback = log_callback ? log_callback : llama_log_callback_default;
|
g_logger_state.log_callback = log_callback ? log_callback : llama_log_callback_default;
|
||||||
|
|
|
||||||
|
|
@ -175,7 +175,15 @@ llama_kv_cache::llama_kv_cache(
|
||||||
|
|
||||||
// allocate tensors and initialize the buffers to avoid NaNs in the padding
|
// allocate tensors and initialize the buffers to avoid NaNs in the padding
|
||||||
for (auto & [buft, ctx] : ctx_map) {
|
for (auto & [buft, ctx] : ctx_map) {
|
||||||
ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx.get(), buft);
|
ggml_backend_buffer_t buf;
|
||||||
|
if (model.hparams.no_alloc) {
|
||||||
|
buf = ggml_backend_buft_alloc_buffer(buft, /*size =*/ 0); // dummy buffer
|
||||||
|
for (ggml_tensor * t = ggml_get_first_tensor(ctx.get()); t != nullptr; t = ggml_get_next_tensor(ctx.get(), t)) {
|
||||||
|
t->buffer = buf; // set dummy buffer for KV cache so that the backend scheduler won't try to allocate it
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx.get(), buft); // real buffer
|
||||||
|
}
|
||||||
if (!buf) {
|
if (!buf) {
|
||||||
throw std::runtime_error("failed to allocate buffer for kv cache");
|
throw std::runtime_error("failed to allocate buffer for kv cache");
|
||||||
}
|
}
|
||||||
|
|
@ -482,9 +490,18 @@ llama_pos llama_kv_cache::seq_pos_max(llama_seq_id seq_id) const {
|
||||||
|
|
||||||
std::map<ggml_backend_buffer_type_t, size_t> llama_kv_cache::memory_breakdown() const {
|
std::map<ggml_backend_buffer_type_t, size_t> llama_kv_cache::memory_breakdown() const {
|
||||||
std::map<ggml_backend_buffer_type_t, size_t> ret;
|
std::map<ggml_backend_buffer_type_t, size_t> ret;
|
||||||
for (const auto & [_, buf] : ctxs_bufs) {
|
for (const auto & [ctx, buf] : ctxs_bufs) {
|
||||||
ret[ggml_backend_buffer_get_type(buf.get())] += ggml_backend_buffer_get_size(buf.get());
|
ggml_backend_buffer_type_t buft = ggml_backend_buffer_get_type(buf.get());
|
||||||
|
|
||||||
|
if (hparams.no_alloc) {
|
||||||
|
GGML_ASSERT(ggml_backend_buffer_get_base(buf.get()) == nullptr);
|
||||||
|
ret[buft] += ggml_backend_alloc_ctx_tensors_from_buft_size(ctx.get(), buft);
|
||||||
|
} else {
|
||||||
|
// GGML_ASSERT(ggml_backend_buffer_get_base(buf.get()) != nullptr); // multi_buffer does not have a defined base
|
||||||
|
ret[buft] += ggml_backend_buffer_get_size(buf.get());
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -473,6 +473,7 @@ llama_model_loader::llama_model_loader(
|
||||||
std::vector<std::string> & splits,
|
std::vector<std::string> & splits,
|
||||||
bool use_mmap,
|
bool use_mmap,
|
||||||
bool check_tensors,
|
bool check_tensors,
|
||||||
|
bool no_alloc,
|
||||||
const llama_model_kv_override * param_overrides_p,
|
const llama_model_kv_override * param_overrides_p,
|
||||||
const llama_model_tensor_buft_override * param_tensor_buft_overrides_p) {
|
const llama_model_tensor_buft_override * param_tensor_buft_overrides_p) {
|
||||||
int trace = 0;
|
int trace = 0;
|
||||||
|
|
@ -716,6 +717,7 @@ llama_model_loader::llama_model_loader(
|
||||||
|
|
||||||
this->use_mmap = use_mmap;
|
this->use_mmap = use_mmap;
|
||||||
this->check_tensors = check_tensors;
|
this->check_tensors = check_tensors;
|
||||||
|
this->no_alloc = no_alloc;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string llama_model_loader::get_arch_name() const {
|
std::string llama_model_loader::get_arch_name() const {
|
||||||
|
|
|
||||||
|
|
@ -71,6 +71,7 @@ struct llama_model_loader {
|
||||||
|
|
||||||
bool use_mmap = false;
|
bool use_mmap = false;
|
||||||
bool check_tensors;
|
bool check_tensors;
|
||||||
|
bool no_alloc;
|
||||||
|
|
||||||
llama_files files;
|
llama_files files;
|
||||||
llama_ftype ftype;
|
llama_ftype ftype;
|
||||||
|
|
@ -97,6 +98,7 @@ struct llama_model_loader {
|
||||||
std::vector<std::string> & splits, // optional, only need if the split does not follow naming scheme
|
std::vector<std::string> & splits, // optional, only need if the split does not follow naming scheme
|
||||||
bool use_mmap,
|
bool use_mmap,
|
||||||
bool check_tensors,
|
bool check_tensors,
|
||||||
|
bool no_alloc,
|
||||||
const llama_model_kv_override * param_overrides_p,
|
const llama_model_kv_override * param_overrides_p,
|
||||||
const llama_model_tensor_buft_override * param_tensor_buft_overrides_p);
|
const llama_model_tensor_buft_override * param_tensor_buft_overrides_p);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -6606,9 +6606,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||||
|
|
||||||
std::vector<ggml_backend_buffer_ptr> bufs;
|
std::vector<ggml_backend_buffer_ptr> bufs;
|
||||||
if (ml.use_mmap && use_mmap_buffer && buffer_from_host_ptr_supported && is_default_buft) {
|
if (ml.use_mmap && use_mmap_buffer && buffer_from_host_ptr_supported && is_default_buft) {
|
||||||
|
GGML_ASSERT(!ml.no_alloc);
|
||||||
for (uint32_t idx = 0; idx < ml.files.size(); idx++) {
|
for (uint32_t idx = 0; idx < ml.files.size(); idx++) {
|
||||||
// only the mmap region containing the tensors in the model is mapped to the backend buffer
|
// only the mmap region containing the tensors in the model is mapped to the backend buffer
|
||||||
// this is important for metal with apple silicon: if the entire model could be mapped to a metal buffer, then we could just use metal for all layers
|
// this is important for metal with apple silicon: if the entire model could be mapped to a metal buffer,
|
||||||
|
// then we could just use metal for all layers
|
||||||
// this allows using partial offloading when the model size exceeds the metal buffer size, but not the RAM size
|
// this allows using partial offloading when the model size exceeds the metal buffer size, but not the RAM size
|
||||||
void * addr = nullptr;
|
void * addr = nullptr;
|
||||||
size_t first, last; // NOLINT
|
size_t first, last; // NOLINT
|
||||||
|
|
@ -6624,9 +6626,16 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||||
bufs.emplace_back(buf);
|
bufs.emplace_back(buf);
|
||||||
buf_map.emplace(idx, buf);
|
buf_map.emplace(idx, buf);
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
ggml_backend_buffer_t buf;
|
||||||
|
if (ml.no_alloc) {
|
||||||
|
buf = ggml_backend_buft_alloc_buffer(buft, /*size =*/ 0); // dummy buffer
|
||||||
|
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) {
|
||||||
|
t->buffer = buf; // set dummy buffer for weights so that the backend scheduler won't try to allocate them
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); // real buffer
|
||||||
}
|
}
|
||||||
else {
|
|
||||||
ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
|
|
||||||
if (buf == nullptr) {
|
if (buf == nullptr) {
|
||||||
throw std::runtime_error(format("unable to allocate %s buffer", ggml_backend_buft_name(buft)));
|
throw std::runtime_error(format("unable to allocate %s buffer", ggml_backend_buft_name(buft)));
|
||||||
}
|
}
|
||||||
|
|
@ -6681,6 +6690,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (ml.no_alloc) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
// load tensor data
|
// load tensor data
|
||||||
for (auto & [ctx, buf_map] : ctx_buf_maps) {
|
for (auto & [ctx, buf_map] : ctx_buf_maps) {
|
||||||
if (!ml.load_all_data(ctx, buf_map, use_mlock ? &pimpl->mlock_mmaps : NULL, params.progress_callback, params.progress_callback_user_data)) {
|
if (!ml.load_all_data(ctx, buf_map, use_mlock ? &pimpl->mlock_mmaps : NULL, params.progress_callback, params.progress_callback_user_data)) {
|
||||||
|
|
@ -6723,11 +6736,20 @@ size_t llama_model::n_devices() const {
|
||||||
|
|
||||||
std::map<ggml_backend_buffer_type_t, size_t> llama_model::memory_breakdown() const {
|
std::map<ggml_backend_buffer_type_t, size_t> llama_model::memory_breakdown() const {
|
||||||
std::map<ggml_backend_buffer_type_t, size_t> ret;
|
std::map<ggml_backend_buffer_type_t, size_t> ret;
|
||||||
for (const auto & [_, bufs] : pimpl->ctxs_bufs) {
|
for (const auto & [ctx, bufs] : pimpl->ctxs_bufs) {
|
||||||
|
if (hparams.no_alloc) {
|
||||||
|
GGML_ASSERT(bufs.size() == 1);
|
||||||
|
ggml_backend_buffer_t buf = bufs[0].get();
|
||||||
|
GGML_ASSERT(ggml_backend_buffer_get_base(buf) == nullptr);
|
||||||
|
ggml_backend_buffer_type_t buft = ggml_backend_buffer_get_type(buf);
|
||||||
|
ret[buft] += ggml_backend_alloc_ctx_tensors_from_buft_size(ctx.get(), buft);
|
||||||
|
} else {
|
||||||
for (const auto & buf : bufs) {
|
for (const auto & buf : bufs) {
|
||||||
|
// GGML_ASSERT(ggml_backend_buffer_get_base(buf.get()) != nullptr); // multi_buffer does not have a defined base
|
||||||
ret[ggml_backend_buffer_get_type(buf.get())] += ggml_backend_buffer_get_size(buf.get());
|
ret[ggml_backend_buffer_get_type(buf.get())] += ggml_backend_buffer_get_size(buf.get());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -6770,6 +6792,7 @@ void llama_model::print_info() const {
|
||||||
// hparams
|
// hparams
|
||||||
LLAMA_LOG_INFO("%s: arch = %s\n", __func__, arch_name().c_str());
|
LLAMA_LOG_INFO("%s: arch = %s\n", __func__, arch_name().c_str());
|
||||||
LLAMA_LOG_INFO("%s: vocab_only = %d\n", __func__, hparams.vocab_only);
|
LLAMA_LOG_INFO("%s: vocab_only = %d\n", __func__, hparams.vocab_only);
|
||||||
|
LLAMA_LOG_INFO("%s: no_alloc = %d\n", __func__, hparams.no_alloc);
|
||||||
|
|
||||||
if (!hparams.vocab_only) {
|
if (!hparams.vocab_only) {
|
||||||
LLAMA_LOG_INFO("%s: n_ctx_train = %u\n", __func__, hparams.n_ctx_train);
|
LLAMA_LOG_INFO("%s: n_ctx_train = %u\n", __func__, hparams.n_ctx_train);
|
||||||
|
|
@ -7618,6 +7641,7 @@ llama_model_params llama_model_default_params() {
|
||||||
/*.check_tensors =*/ false,
|
/*.check_tensors =*/ false,
|
||||||
/*.use_extra_bufts =*/ true,
|
/*.use_extra_bufts =*/ true,
|
||||||
/*.no_host =*/ false,
|
/*.no_host =*/ false,
|
||||||
|
/*.no_alloc =*/ false,
|
||||||
};
|
};
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
|
|
|
||||||
|
|
@ -596,7 +596,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::string> splits = {};
|
std::vector<std::string> splits = {};
|
||||||
llama_model_loader ml(fname_inp, splits, use_mmap, /*check_tensors*/ true, kv_overrides, nullptr);
|
llama_model_loader ml(fname_inp, splits, use_mmap, /*check_tensors*/ true, /*no_alloc*/ false, kv_overrides, nullptr);
|
||||||
ml.init_mappings(false); // no prefetching
|
ml.init_mappings(false); // no prefetching
|
||||||
|
|
||||||
llama_model model(llama_model_default_params());
|
llama_model model(llama_model_default_params());
|
||||||
|
|
|
||||||
650
src/llama.cpp
650
src/llama.cpp
|
|
@ -1,6 +1,9 @@
|
||||||
|
#include "llama.h"
|
||||||
|
|
||||||
#include "llama-impl.h"
|
#include "llama-impl.h"
|
||||||
|
|
||||||
#include "llama-chat.h"
|
#include "llama-chat.h"
|
||||||
|
#include "llama-context.h"
|
||||||
#include "llama-mmap.h"
|
#include "llama-mmap.h"
|
||||||
#include "llama-vocab.h"
|
#include "llama-vocab.h"
|
||||||
#include "llama-model-loader.h"
|
#include "llama-model-loader.h"
|
||||||
|
|
@ -11,11 +14,14 @@
|
||||||
#include "ggml-backend.h"
|
#include "ggml-backend.h"
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
#include <cassert>
|
||||||
|
#include <cinttypes>
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
#include <cstdio>
|
#include <cstdio>
|
||||||
#include <cstring>
|
#include <cstring>
|
||||||
#include <ctime>
|
#include <ctime>
|
||||||
|
#include <stdexcept>
|
||||||
|
|
||||||
#if defined(_MSC_VER)
|
#if defined(_MSC_VER)
|
||||||
#pragma warning(disable: 4244 4267) // possible loss of data
|
#pragma warning(disable: 4244 4267) // possible loss of data
|
||||||
|
|
@ -37,6 +43,643 @@ const char * llama_flash_attn_type_name(enum llama_flash_attn_type flash_attn_ty
|
||||||
GGML_ABORT("fatal error");
|
GGML_ABORT("fatal error");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct llama_device_memory_data {
|
||||||
|
int64_t total;
|
||||||
|
int64_t free;
|
||||||
|
llama_memory_breakdown_data mb;
|
||||||
|
};
|
||||||
|
|
||||||
|
static std::vector<llama_device_memory_data> llama_get_device_memory_data(
|
||||||
|
const char * path_model, const llama_model_params * mparams, const llama_context_params * cparams,
|
||||||
|
std::vector<ggml_backend_dev_t> & devs, uint32_t & hp_ngl, uint32_t & hp_n_ctx_train, uint32_t & hp_n_expert,
|
||||||
|
const ggml_log_level log_level) {
|
||||||
|
struct user_data_t {
|
||||||
|
struct {
|
||||||
|
ggml_log_callback callback;
|
||||||
|
void * user_data;
|
||||||
|
} original_logger;
|
||||||
|
ggml_log_level min_level; // prints below this log level go to debug log
|
||||||
|
};
|
||||||
|
user_data_t ud;
|
||||||
|
llama_log_get(&ud.original_logger.callback, &ud.original_logger.user_data);
|
||||||
|
ud.min_level = log_level;
|
||||||
|
|
||||||
|
llama_log_set([](ggml_log_level level, const char * text, void * user_data) {
|
||||||
|
const user_data_t * ud = (const user_data_t *) user_data;
|
||||||
|
const ggml_log_level level_eff = level >= ud->min_level ? level : GGML_LOG_LEVEL_DEBUG;
|
||||||
|
ud->original_logger.callback(level_eff, text, ud->original_logger.user_data);
|
||||||
|
}, &ud);
|
||||||
|
|
||||||
|
llama_model_params mparams_copy = *mparams;
|
||||||
|
mparams_copy.no_alloc = true;
|
||||||
|
mparams_copy.use_mmap = false;
|
||||||
|
|
||||||
|
llama_model * model = llama_model_load_from_file(path_model, mparams_copy);
|
||||||
|
if (model == nullptr) {
|
||||||
|
llama_log_set(ud.original_logger.callback, ud.original_logger.user_data);
|
||||||
|
throw std::runtime_error("failed to load model");
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_context * ctx = llama_init_from_model(model, *cparams);
|
||||||
|
if (ctx == nullptr) {
|
||||||
|
llama_model_free(model);
|
||||||
|
llama_log_set(ud.original_logger.callback, ud.original_logger.user_data);
|
||||||
|
throw std::runtime_error("failed to create llama_context from model");
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<llama_device_memory_data> ret(model->devices.size());
|
||||||
|
|
||||||
|
std::map<ggml_backend_buffer_type_t, llama_memory_breakdown_data> memory_breakdown = ctx->memory_breakdown();
|
||||||
|
|
||||||
|
for (const auto & [buft, mb] : memory_breakdown) {
|
||||||
|
if (ggml_backend_buft_is_host(buft)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_backend_dev_t dev = ggml_backend_buft_get_device(buft);
|
||||||
|
if (!dev) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < ret.size(); i++) {
|
||||||
|
if (model->devices[i] == dev) {
|
||||||
|
ret[i].mb.model += mb.model;
|
||||||
|
ret[i].mb.context += mb.context;
|
||||||
|
ret[i].mb.compute += mb.compute;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < ret.size(); i++) {
|
||||||
|
size_t free, total;
|
||||||
|
ggml_backend_dev_memory(model->devices[i], &free, &total);
|
||||||
|
ret[i].free = free;
|
||||||
|
ret[i].total = total;
|
||||||
|
}
|
||||||
|
|
||||||
|
devs = model->devices;
|
||||||
|
hp_ngl = model->hparams.n_layer;
|
||||||
|
hp_n_ctx_train = model->hparams.n_ctx_train;
|
||||||
|
hp_n_expert = model->hparams.n_expert;
|
||||||
|
|
||||||
|
llama_memory_breakdown_print(ctx); // goes to debug log
|
||||||
|
|
||||||
|
llama_free(ctx);
|
||||||
|
llama_model_free(model);
|
||||||
|
llama_log_set(ud.original_logger.callback, ud.original_logger.user_data);
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
// enum to identify part of a layer for distributing its tensors:
|
||||||
|
enum layer_fraction_t {
|
||||||
|
LAYER_FRACTION_NONE = 0, // nothing
|
||||||
|
LAYER_FRACTION_ATTN = 1, // attention
|
||||||
|
LAYER_FRACTION_UP = 2, // attention + up
|
||||||
|
LAYER_FRACTION_GATE = 3, // attention + up + gate
|
||||||
|
LAYER_FRACTION_MOE = 4, // everything but sparse MoE weights
|
||||||
|
};
|
||||||
|
// this enum is only used in llama_params_fit_impl but needs to be defined outside of it to fix a Windows compilation issue
|
||||||
|
|
||||||
|
static void llama_params_fit_impl(
|
||||||
|
const char * path_model, struct llama_model_params * mparams, struct llama_context_params * cparams,
|
||||||
|
float * tensor_split, struct llama_model_tensor_buft_override * tensor_buft_overrides,
|
||||||
|
size_t margin_s, uint32_t n_ctx_min, enum ggml_log_level log_level) {
|
||||||
|
constexpr int64_t MiB = 1024*1024;
|
||||||
|
const int64_t margin = margin_s; // this function uses int64_t rather than size_t for memory sizes to more conveniently handle deficits
|
||||||
|
typedef std::vector<llama_device_memory_data> dmds_t;
|
||||||
|
const llama_model_params default_mparams = llama_model_default_params();
|
||||||
|
|
||||||
|
std::vector<ggml_backend_dev_t> devs;
|
||||||
|
uint32_t hp_ngl = 0; // hparams.n_gpu_layers
|
||||||
|
uint32_t hp_nct = 0; // hparams.n_ctx_train
|
||||||
|
uint32_t hp_nex = 0; // hparams.n_expert
|
||||||
|
|
||||||
|
// step 1: get data for default parameters and check whether any changes are necessary in the first place
|
||||||
|
|
||||||
|
LLAMA_LOG_DEBUG("%s: getting device memory data for initial parameters:\n", __func__);
|
||||||
|
const dmds_t dmds_full = llama_get_device_memory_data(path_model, mparams, cparams, devs, hp_ngl, hp_nct, hp_nex, log_level);
|
||||||
|
const size_t nd = devs.size(); // number of devices
|
||||||
|
if (nd == 0) {
|
||||||
|
LLAMA_LOG_INFO("%s: no devices with dedicated memory found\n", __func__);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::string> dev_names;
|
||||||
|
{
|
||||||
|
dev_names.reserve(nd);
|
||||||
|
size_t max_length = 0;
|
||||||
|
for (ggml_backend_dev_t dev : devs) {
|
||||||
|
std::string name = ggml_backend_dev_name(dev);
|
||||||
|
name += " (";
|
||||||
|
name += ggml_backend_dev_description(dev);
|
||||||
|
name += ")";
|
||||||
|
dev_names.push_back(name);
|
||||||
|
max_length = std::max(max_length, name.length());
|
||||||
|
}
|
||||||
|
for (std::string & dn : dev_names) {
|
||||||
|
dn.insert(dn.end(), max_length - dn.length(), ' ');
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t sum_total = 0;
|
||||||
|
int64_t sum_projected_free = 0;
|
||||||
|
int64_t min_projected_free = INT64_MAX;
|
||||||
|
int64_t sum_projected_used = 0;
|
||||||
|
int64_t sum_projected_ctx = 0;
|
||||||
|
|
||||||
|
if (nd > 1) {
|
||||||
|
LLAMA_LOG_INFO("%s: projected memory use with initial parameters [MiB]:\n", __func__);
|
||||||
|
}
|
||||||
|
for (size_t id = 0; id < nd; id++) {
|
||||||
|
const llama_device_memory_data & dmd = dmds_full[id];
|
||||||
|
|
||||||
|
const int64_t projected_used = dmd.mb.total();
|
||||||
|
const int64_t projected_free = dmd.free - projected_used;
|
||||||
|
|
||||||
|
sum_total += dmd.total;
|
||||||
|
sum_projected_used += projected_used;
|
||||||
|
sum_projected_free += projected_free;
|
||||||
|
min_projected_free = std::min(min_projected_free, projected_free);
|
||||||
|
sum_projected_ctx += dmd.mb.context;
|
||||||
|
|
||||||
|
if (nd > 1) {
|
||||||
|
LLAMA_LOG_INFO("%s: - %s: %6" PRId64 " total, %6" PRId64 " used, %6" PRId64 " %s\n",
|
||||||
|
__func__, dev_names[id].c_str(), dmd.total/MiB, projected_used/MiB, std::abs(projected_free)/MiB,
|
||||||
|
projected_free >= 0 ? "surplus" : "deficit");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert(sum_total >= 0 && sum_projected_used >= 0 && sum_projected_ctx >= 0);
|
||||||
|
assert(sum_projected_used >= sum_projected_ctx);
|
||||||
|
LLAMA_LOG_INFO("%s: projected to use %" PRId64 " MiB of device memory vs. %" PRId64 " MiB of free device memory\n",
|
||||||
|
__func__, sum_projected_used/MiB, sum_total/MiB);
|
||||||
|
if (min_projected_free >= margin) {
|
||||||
|
if (nd == 1) {
|
||||||
|
LLAMA_LOG_INFO("%s: will leave %" PRId64 " >= %" PRId64 " MiB of free device memory, no changes needed\n",
|
||||||
|
__func__, min_projected_free/MiB, margin/MiB);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
LLAMA_LOG_INFO("%s: will leave at least %" PRId64 " >= %" PRId64 " MiB of free memory on all devices, no changes needed\n",
|
||||||
|
__func__, min_projected_free/MiB, margin/MiB);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// step 2: try reducing memory use by reducing the context size
|
||||||
|
|
||||||
|
{
|
||||||
|
int64_t global_surplus = sum_projected_free - int64_t(nd)*margin;
|
||||||
|
if (global_surplus < 0) {
|
||||||
|
LLAMA_LOG_INFO(nd == 1 ?
|
||||||
|
"%s: cannot fulfill margin of %" PRId64 " MiB, need to reduce device memory by %" PRId64 " MiB\n" :
|
||||||
|
"%s: cannot fulfill margin of %" PRId64 " MiB on all devices, need to use %" PRId64 " MiB less in total\n",
|
||||||
|
__func__, margin/MiB, -global_surplus/MiB);
|
||||||
|
if (cparams->n_ctx == 0) {
|
||||||
|
if (hp_nct > n_ctx_min) {
|
||||||
|
const int64_t bytes_per_ctx = sum_projected_ctx / hp_nct;
|
||||||
|
const uint32_t ctx_reduction = std::min(
|
||||||
|
uint32_t((-global_surplus + bytes_per_ctx - 1) / bytes_per_ctx), hp_nct - n_ctx_min);
|
||||||
|
cparams->n_ctx = hp_nct - ctx_reduction;
|
||||||
|
const int64_t memory_reduction = ctx_reduction * bytes_per_ctx;
|
||||||
|
global_surplus += memory_reduction;
|
||||||
|
LLAMA_LOG_INFO("%s: context size reduced from %" PRIu32 " to %" PRIu32 " -> need %" PRId64 " MiB less memory in total\n",
|
||||||
|
__func__, hp_nct, cparams->n_ctx, memory_reduction/MiB);
|
||||||
|
} else {
|
||||||
|
LLAMA_LOG_INFO("%s: default model context size is %" PRIu32 " which is <= the min. context size of %" PRIu32 " -> no change\n",
|
||||||
|
__func__, hp_nct, n_ctx_min);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
LLAMA_LOG_INFO("%s: context size set by user to %" PRIu32 " -> no change\n", __func__, cparams->n_ctx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (global_surplus >= 0) {
|
||||||
|
LLAMA_LOG_INFO("%s: entire model can be fit across devices by reducing context\n", __func__);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (mparams->n_gpu_layers != default_mparams.n_gpu_layers) {
|
||||||
|
throw std::runtime_error("n_gpu_layers already set by user to " + std::to_string(mparams->n_gpu_layers) + ", abort");
|
||||||
|
}
|
||||||
|
if (nd > 1) {
|
||||||
|
if (!tensor_split) {
|
||||||
|
throw std::runtime_error("did not provide a buffer to write the tensor_split to, abort");
|
||||||
|
}
|
||||||
|
if (mparams->tensor_split) {
|
||||||
|
for (size_t id = 0; id < nd; id++) {
|
||||||
|
if (mparams->tensor_split[id] != 0.0f) {
|
||||||
|
throw std::runtime_error("model_params::tensor_split already set by user, abort");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (mparams->split_mode == LLAMA_SPLIT_MODE_ROW) {
|
||||||
|
throw std::runtime_error("changing weight allocation for LLAMA_SPLIT_MODE_ROW not implemented, abort");
|
||||||
|
}
|
||||||
|
if (hp_ngl < 2*nd) {
|
||||||
|
throw std::runtime_error("model has only " + std::to_string(hp_ngl) + " layers but need at least "
|
||||||
|
+ std::to_string(2*nd) + " to fit memory for " + std::to_string(nd) + " devices, abort");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!tensor_buft_overrides) {
|
||||||
|
throw std::runtime_error("did not provide buffer to set tensor_buft_overrides, abort");
|
||||||
|
}
|
||||||
|
if (mparams->tensor_buft_overrides && (mparams->tensor_buft_overrides->pattern || mparams->tensor_buft_overrides->buft)) {
|
||||||
|
throw std::runtime_error("model_params::tensor_buft_overrides already set by user, abort");
|
||||||
|
}
|
||||||
|
|
||||||
|
// step 3: iteratively fill the back to front with "dense" layers
|
||||||
|
// - for a dense model simply fill full layers, giving each device a contiguous slice of the model
|
||||||
|
// - for a MoE model, same as dense model but with all MoE tensors in system memory
|
||||||
|
|
||||||
|
// utility function that returns a static C string matching the tensors for a specific layer index and layer fraction:
|
||||||
|
auto get_overflow_pattern = [&](const size_t il, const layer_fraction_t lf) -> const char * {
|
||||||
|
constexpr size_t n_strings = 1000;
|
||||||
|
if (il >= n_strings) {
|
||||||
|
throw std::runtime_error("at most " + std::to_string(n_strings) + " model layers are supported");
|
||||||
|
}
|
||||||
|
switch (lf) {
|
||||||
|
case LAYER_FRACTION_ATTN: {
|
||||||
|
static std::array<std::string, n_strings> patterns;
|
||||||
|
if (patterns[il].empty()) {
|
||||||
|
patterns[il] = "blk\\." + std::to_string(il) + "\\.ffn_(up|gate|down).*";
|
||||||
|
}
|
||||||
|
return patterns[il].c_str();
|
||||||
|
}
|
||||||
|
case LAYER_FRACTION_UP: {
|
||||||
|
static std::array<std::string, n_strings> patterns;
|
||||||
|
if (patterns[il].empty()) {
|
||||||
|
patterns[il] = "blk\\." + std::to_string(il) + "\\.ffn_(gate|down).*";
|
||||||
|
}
|
||||||
|
return patterns[il].c_str();
|
||||||
|
}
|
||||||
|
case LAYER_FRACTION_GATE: {
|
||||||
|
static std::array<std::string, n_strings> patterns;
|
||||||
|
if (patterns[il].empty()) {
|
||||||
|
patterns[il] = "blk\\." + std::to_string(il) + "\\.ffn_down.*";
|
||||||
|
}
|
||||||
|
return patterns[il].c_str();
|
||||||
|
}
|
||||||
|
case LAYER_FRACTION_MOE: {
|
||||||
|
static std::array<std::string, n_strings> patterns;
|
||||||
|
if (patterns[il].empty()) {
|
||||||
|
patterns[il] = "blk\\." + std::to_string(il) + "\\.ffn_(up|down|gate)_(ch|)exps";
|
||||||
|
}
|
||||||
|
return patterns[il].c_str();
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
GGML_ABORT("fatal error");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct ngl_t {
|
||||||
|
uint32_t n_layer = 0; // number of total layers
|
||||||
|
uint32_t n_part = 0; // number of partial layers, <= n_layer
|
||||||
|
|
||||||
|
// for the first partial layer varying parts can overflow, all further layers use LAYER_FRACTION_MOE:
|
||||||
|
layer_fraction_t overflow_type = LAYER_FRACTION_MOE;
|
||||||
|
};
|
||||||
|
|
||||||
|
const size_t ntbo = llama_max_tensor_buft_overrides();
|
||||||
|
|
||||||
|
// utility function to set n_gpu_layers and tensor_split
|
||||||
|
auto set_ngl_tensor_split_tbo = [&](
|
||||||
|
const std::vector<ngl_t> & ngl_per_device,
|
||||||
|
const std::vector<ggml_backend_buffer_type_t> & overflow_bufts,
|
||||||
|
llama_model_params & mparams,
|
||||||
|
const bool add_nonrepeating) {
|
||||||
|
mparams.n_gpu_layers = 0;
|
||||||
|
for (size_t id = 0; id < nd; id++) {
|
||||||
|
mparams.n_gpu_layers += ngl_per_device[id].n_layer;
|
||||||
|
if (nd > 1) {
|
||||||
|
tensor_split[id] = ngl_per_device[id].n_layer;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert(uint32_t(mparams.n_gpu_layers) <= hp_ngl);
|
||||||
|
uint32_t il0 = hp_ngl - mparams.n_gpu_layers; // start index for tensor buft overrides
|
||||||
|
|
||||||
|
if (add_nonrepeating) {
|
||||||
|
mparams.n_gpu_layers += 1;
|
||||||
|
tensor_split[nd - 1] += 1;
|
||||||
|
}
|
||||||
|
mparams.tensor_split = tensor_split;
|
||||||
|
|
||||||
|
size_t itbo = 0;
|
||||||
|
for (size_t id = 0; id < nd; id++) {
|
||||||
|
il0 += ngl_per_device[id].n_layer - ngl_per_device[id].n_part;
|
||||||
|
for (uint32_t il = il0; il < il0 + ngl_per_device[id].n_part; il++) {
|
||||||
|
if (itbo + 1 >= ntbo) {
|
||||||
|
tensor_buft_overrides[itbo].pattern = nullptr;
|
||||||
|
tensor_buft_overrides[itbo].buft = nullptr;
|
||||||
|
itbo++;
|
||||||
|
mparams.tensor_buft_overrides = tensor_buft_overrides;
|
||||||
|
throw std::runtime_error("llama_params_fit_n_tensor_buft_overrides() == "
|
||||||
|
+ std::to_string(ntbo) + " is insufficient for model\n");
|
||||||
|
}
|
||||||
|
tensor_buft_overrides[itbo].pattern = get_overflow_pattern(il, il == il0 ? ngl_per_device[id].overflow_type : LAYER_FRACTION_MOE);
|
||||||
|
tensor_buft_overrides[itbo].buft = overflow_bufts[id];
|
||||||
|
itbo++;
|
||||||
|
}
|
||||||
|
il0 += ngl_per_device[id].n_part;
|
||||||
|
}
|
||||||
|
tensor_buft_overrides[itbo].pattern = nullptr;
|
||||||
|
tensor_buft_overrides[itbo].buft = nullptr;
|
||||||
|
itbo++;
|
||||||
|
mparams.tensor_buft_overrides = tensor_buft_overrides;
|
||||||
|
};
|
||||||
|
|
||||||
|
// utility function that returns the memory use per device for given numbers of layers per device
|
||||||
|
auto get_memory_for_layers = [&](
|
||||||
|
const char * func_name,
|
||||||
|
const std::vector<ngl_t> & ngl_per_device,
|
||||||
|
const std::vector<ggml_backend_buffer_type_t> & overflow_bufts,
|
||||||
|
const bool add_nonrepeating) -> std::vector<int64_t> {
|
||||||
|
llama_model_params mparams_copy = *mparams;
|
||||||
|
set_ngl_tensor_split_tbo(ngl_per_device, overflow_bufts, mparams_copy, add_nonrepeating);
|
||||||
|
|
||||||
|
const dmds_t dmd_nl = llama_get_device_memory_data(
|
||||||
|
path_model, &mparams_copy, cparams, devs, hp_ngl, hp_nct, hp_nex, log_level);
|
||||||
|
|
||||||
|
LLAMA_LOG_DEBUG("%s: memory for test allocation by device:\n", func_name);
|
||||||
|
for (size_t id = 0; id < nd; id++) {
|
||||||
|
const ngl_t & n = ngl_per_device[id];
|
||||||
|
LLAMA_LOG_DEBUG(
|
||||||
|
"%s: id=%zu, n_layer=%2" PRIu32 ", n_part=%2" PRIu32 ", overflow_type=%d, mem=%6" PRId64 " MiB\n",
|
||||||
|
func_name, id, n.n_layer, n.n_part, int(n.overflow_type), dmd_nl[id].mb.total()/MiB);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int64_t> ret;
|
||||||
|
ret.reserve(nd);
|
||||||
|
for (const llama_device_memory_data & dmd : dmd_nl) {
|
||||||
|
ret.push_back(dmd.mb.total());
|
||||||
|
}
|
||||||
|
return ret;
|
||||||
|
};
|
||||||
|
|
||||||
|
int64_t global_surplus_cpu_moe = 0;
|
||||||
|
if (hp_nex > 0) {
|
||||||
|
const static std::string pattern_moe_all = "blk\\.\\d+\\.ffn_(up|down|gate)_(ch|)exps"; // matches all MoE tensors
|
||||||
|
ggml_backend_buffer_type_t cpu_buft = ggml_backend_cpu_buffer_type();
|
||||||
|
tensor_buft_overrides[0] = {pattern_moe_all.c_str(), cpu_buft};
|
||||||
|
tensor_buft_overrides[1] = {nullptr, nullptr};
|
||||||
|
mparams->tensor_buft_overrides = tensor_buft_overrides;
|
||||||
|
|
||||||
|
LLAMA_LOG_DEBUG("%s: getting device memory data with all MoE tensors moved to system memory:\n", __func__);
|
||||||
|
const dmds_t dmds_cpu_moe = llama_get_device_memory_data(
|
||||||
|
path_model, mparams, cparams, devs, hp_ngl, hp_nct, hp_nex, log_level);
|
||||||
|
|
||||||
|
for (const llama_device_memory_data & dmd : dmds_cpu_moe) {
|
||||||
|
global_surplus_cpu_moe += dmd.free;
|
||||||
|
global_surplus_cpu_moe -= int64_t(dmd.mb.total()) + margin;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (global_surplus_cpu_moe > 0) {
|
||||||
|
LLAMA_LOG_INFO("%s: with only dense weights in device memory there is a total surplus of %" PRId64 " MiB\n",
|
||||||
|
__func__, global_surplus_cpu_moe/MiB);
|
||||||
|
} else {
|
||||||
|
LLAMA_LOG_INFO("%s: with only dense weights in device memory there is still a total deficit of %" PRId64 " MiB\n",
|
||||||
|
__func__, -global_surplus_cpu_moe/MiB);
|
||||||
|
}
|
||||||
|
|
||||||
|
// reset
|
||||||
|
tensor_buft_overrides[0] = {nullptr, nullptr};
|
||||||
|
mparams->tensor_buft_overrides = tensor_buft_overrides;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int64_t> targets; // maximum acceptable memory use per device
|
||||||
|
targets.reserve(nd);
|
||||||
|
for (size_t id = 0; id < nd; id++) {
|
||||||
|
targets.push_back(dmds_full[id].free - margin);
|
||||||
|
LLAMA_LOG_DEBUG("%s: id=%zu, target=%" PRId64 " MiB\n", __func__, id, targets[id]/MiB);
|
||||||
|
}
|
||||||
|
|
||||||
|
// whether for the optimal memory use we expect to load at least some MoE tensors:
|
||||||
|
const bool partial_moe = hp_nex > 0 && global_surplus_cpu_moe > 0;
|
||||||
|
|
||||||
|
std::vector<ggml_backend_buffer_type_t> overflow_bufts; // which bufts the partial layers of a device overflow to:
|
||||||
|
overflow_bufts.reserve(nd);
|
||||||
|
for (size_t id = 0; id < nd - 1; ++id) {
|
||||||
|
overflow_bufts.push_back(ggml_backend_dev_buffer_type(devs[id + 1]));
|
||||||
|
}
|
||||||
|
overflow_bufts.push_back(ggml_backend_cpu_buffer_type());
|
||||||
|
|
||||||
|
std::vector<ngl_t> ngl_per_device(nd);
|
||||||
|
std::vector<int64_t> mem = get_memory_for_layers(__func__, ngl_per_device, overflow_bufts, partial_moe);
|
||||||
|
if (hp_nex > 0) {
|
||||||
|
for (size_t id = 0; id < nd; id++) {
|
||||||
|
ngl_per_device[id].overflow_type = LAYER_FRACTION_MOE;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// optimize the number of layers per device using the method of false position:
|
||||||
|
// - ngl_per_device has 0 layers for each device, lower bound
|
||||||
|
// - try a "high" configuration where a device is given all unassigned layers
|
||||||
|
// - interpolate the memory use / layer between low and high linearly to get a guess where it meets our target
|
||||||
|
// - check memory use of our guess, replace either the low or high bound
|
||||||
|
// - once we only have a difference of a single layer, stop and return the lower bound that just barely still fits
|
||||||
|
if (hp_nex == 0) {
|
||||||
|
LLAMA_LOG_INFO("%s: filling dense layers back-to-front:\n", __func__);
|
||||||
|
} else {
|
||||||
|
LLAMA_LOG_INFO("%s: filling dense-only layers back-to-front:\n", __func__);
|
||||||
|
}
|
||||||
|
uint32_t n_unassigned = hp_ngl;
|
||||||
|
for (int id = nd - 1; id >= 0; id--) {
|
||||||
|
std::vector<ngl_t> ngl_per_device_high = ngl_per_device;
|
||||||
|
ngl_per_device_high[id].n_layer = n_unassigned;
|
||||||
|
if (hp_nex > 0) {
|
||||||
|
ngl_per_device_high[id].n_part = ngl_per_device_high[id].n_layer;
|
||||||
|
}
|
||||||
|
if (ngl_per_device_high[id].n_layer > 0) {
|
||||||
|
std::vector<int64_t> mem_high = get_memory_for_layers(__func__, ngl_per_device_high, overflow_bufts, partial_moe);
|
||||||
|
if (mem_high[id] > targets[id]) {
|
||||||
|
uint32_t delta = ngl_per_device_high[id].n_layer - ngl_per_device[id].n_layer;
|
||||||
|
while (delta > 1) {
|
||||||
|
uint32_t step_size = int64_t(delta) * (targets[id] - mem[id]) / (mem_high[id] - mem[id]);
|
||||||
|
step_size = std::max(step_size, uint32_t(1));
|
||||||
|
step_size = std::min(step_size, delta - 1);
|
||||||
|
|
||||||
|
std::vector<ngl_t> ngl_per_device_test = ngl_per_device;
|
||||||
|
ngl_per_device_test[id].n_layer += step_size;
|
||||||
|
if (hp_nex) {
|
||||||
|
ngl_per_device_test[id].n_part += step_size;
|
||||||
|
}
|
||||||
|
const std::vector<int64_t> mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts, partial_moe);
|
||||||
|
|
||||||
|
if (mem_test[id] <= targets[id]) {
|
||||||
|
ngl_per_device = ngl_per_device_test;
|
||||||
|
mem = mem_test;
|
||||||
|
n_unassigned -= ngl_per_device[id].n_layer;
|
||||||
|
LLAMA_LOG_DEBUG("%s: set ngl_per_device[%d].n_layer=%" PRIu32 "\n", __func__, id, ngl_per_device[id].n_layer);
|
||||||
|
} else {
|
||||||
|
ngl_per_device_high = ngl_per_device_test;
|
||||||
|
mem_high = mem_test;
|
||||||
|
LLAMA_LOG_DEBUG("%s: set ngl_per_device_high[%d].n_layer=%" PRIu32 "\n", __func__, id, ngl_per_device[id].n_layer);
|
||||||
|
}
|
||||||
|
delta = ngl_per_device_high[id].n_layer - ngl_per_device[id].n_layer;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
ngl_per_device = ngl_per_device_high;
|
||||||
|
n_unassigned -= ngl_per_device[id].n_layer;
|
||||||
|
LLAMA_LOG_DEBUG("%s: set ngl_per_device[%d].n_layer=%" PRIu32 "\n", __func__, id, ngl_per_device[id].n_layer);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const int64_t projected_margin = dmds_full[id].free - mem[id];
|
||||||
|
LLAMA_LOG_INFO(
|
||||||
|
"%s: - %s: %2" PRIu32 " layers, %6" PRId64 " MiB used, %6" PRId64 " MiB free\n",
|
||||||
|
__func__, dev_names[id].c_str(), ngl_per_device[id].n_layer, mem[id]/MiB, projected_margin/MiB);
|
||||||
|
}
|
||||||
|
if (hp_nex == 0 || global_surplus_cpu_moe <= 0) {
|
||||||
|
set_ngl_tensor_split_tbo(ngl_per_device, overflow_bufts, *mparams, partial_moe);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// step 4: for a MoE model where all dense tensors fit,
|
||||||
|
// convert the dense-only layers in the back to full layers in the front until all devices are full
|
||||||
|
// essentially the same procedure as for the dense-only layers except front-to-back
|
||||||
|
// also, try fitting at least part of one more layer to reduce waste for "small" GPUs with e.g. 24 GiB VRAM
|
||||||
|
|
||||||
|
size_t id_dense_start = nd;
|
||||||
|
for (int id = nd - 1; id >= 0; id--) {
|
||||||
|
if (ngl_per_device[id].n_layer > 0) {
|
||||||
|
id_dense_start = id;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
assert(id_dense_start < nd);
|
||||||
|
|
||||||
|
LLAMA_LOG_INFO("%s: converting dense-only layers to full layers and filling them front-to-back with overflow to next device/system memory:\n", __func__);
|
||||||
|
for (size_t id = 0; id <= id_dense_start; id++) {
|
||||||
|
std::vector<ngl_t> ngl_per_device_high = ngl_per_device;
|
||||||
|
for (size_t jd = id_dense_start; jd < nd; jd++) {
|
||||||
|
const uint32_t n_layer_move = ngl_per_device_high[jd].n_layer;
|
||||||
|
ngl_per_device_high[id].n_layer += n_layer_move;
|
||||||
|
ngl_per_device_high[jd].n_layer -= n_layer_move;
|
||||||
|
ngl_per_device_high[jd].n_part = 0;
|
||||||
|
}
|
||||||
|
size_t id_dense_start_high = nd - 1;
|
||||||
|
std::vector<int64_t> mem_high = get_memory_for_layers(__func__, ngl_per_device_high, overflow_bufts, partial_moe);
|
||||||
|
|
||||||
|
if (mem_high[id] > targets[id]) {
|
||||||
|
assert(ngl_per_device_high[id].n_layer >= ngl_per_device_high[id].n_part);
|
||||||
|
assert(ngl_per_device[id].n_layer >= ngl_per_device[id].n_part);
|
||||||
|
assert((ngl_per_device_high[id].n_layer - ngl_per_device_high[id].n_part)
|
||||||
|
>= ngl_per_device[id].n_layer - ngl_per_device[id].n_part);
|
||||||
|
uint32_t delta = (ngl_per_device_high[id].n_layer - ngl_per_device_high[id].n_part)
|
||||||
|
- (ngl_per_device[id].n_layer - ngl_per_device[id].n_part);
|
||||||
|
while (delta > 1) {
|
||||||
|
uint32_t step_size = int64_t(delta) * (targets[id] - mem[id]) / (mem_high[id] - mem[id]);
|
||||||
|
step_size = std::max(step_size, uint32_t(1));
|
||||||
|
step_size = std::min(step_size, delta - 1);
|
||||||
|
|
||||||
|
std::vector<ngl_t> ngl_per_device_test = ngl_per_device;
|
||||||
|
size_t id_dense_start_test = id_dense_start;
|
||||||
|
uint32_t n_converted_test = 0;
|
||||||
|
for (;id_dense_start_test < nd; id_dense_start_test++) {
|
||||||
|
const uint32_t n_convert_jd = std::min(step_size - n_converted_test, ngl_per_device_test[id_dense_start_test].n_part);
|
||||||
|
ngl_per_device_test[id_dense_start_test].n_layer -= n_convert_jd;
|
||||||
|
ngl_per_device_test[id_dense_start_test].n_part -= n_convert_jd;
|
||||||
|
ngl_per_device_test[id].n_layer += n_convert_jd;
|
||||||
|
n_converted_test += n_convert_jd;
|
||||||
|
|
||||||
|
if (ngl_per_device_test[id_dense_start_test].n_layer > 0) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
const std::vector<int64_t> mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts, partial_moe);
|
||||||
|
|
||||||
|
if (mem_test[id] <= targets[id]) {
|
||||||
|
ngl_per_device = ngl_per_device_test;
|
||||||
|
mem = mem_test;
|
||||||
|
id_dense_start = id_dense_start_test;
|
||||||
|
LLAMA_LOG_DEBUG("%s: set ngl_per_device[%zu].(n_layer, n_part)=(%" PRIu32 ", %" PRIu32 "), id_dense_start=%zu\n",
|
||||||
|
__func__, id, ngl_per_device[id].n_layer, ngl_per_device[id].n_part, id_dense_start);
|
||||||
|
} else {
|
||||||
|
ngl_per_device_high = ngl_per_device_test;
|
||||||
|
mem_high = mem_test;
|
||||||
|
id_dense_start_high = id_dense_start_test;
|
||||||
|
LLAMA_LOG_DEBUG("%s: set ngl_per_device_high[%zu].(n_layer, n_part)=(%" PRIu32 ", %" PRIu32 "), id_dense_start_high=%zu\n",
|
||||||
|
__func__, id, ngl_per_device_high[id].n_layer, ngl_per_device_high[id].n_part, id_dense_start_high);
|
||||||
|
}
|
||||||
|
delta = (ngl_per_device_high[id].n_layer - ngl_per_device_high[id].n_part)
|
||||||
|
- (ngl_per_device[id].n_layer - ngl_per_device[id].n_part);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
ngl_per_device = ngl_per_device_high;
|
||||||
|
id_dense_start = id_dense_start_high;
|
||||||
|
LLAMA_LOG_DEBUG("%s: set ngl_per_device[%zu].(n_layer, n_part)=(%" PRIu32 ", %" PRIu32 "), id_dense_start=%zu\n",
|
||||||
|
__func__, id, ngl_per_device[id].n_layer, ngl_per_device[id].n_part, id_dense_start);
|
||||||
|
}
|
||||||
|
|
||||||
|
// try to fit at least part of one more layer
|
||||||
|
if (ngl_per_device[id_dense_start].n_layer > 0) {
|
||||||
|
std::vector<ngl_t> ngl_per_device_test = ngl_per_device;
|
||||||
|
size_t id_dense_start_test = id_dense_start;
|
||||||
|
ngl_per_device_test[id_dense_start_test].n_layer--;
|
||||||
|
ngl_per_device_test[id_dense_start_test].n_part--;
|
||||||
|
ngl_per_device_test[id].n_layer++;
|
||||||
|
ngl_per_device_test[id].n_part++;
|
||||||
|
if (ngl_per_device_test[id_dense_start_test].n_layer == 0) {
|
||||||
|
id_dense_start_test++;
|
||||||
|
}
|
||||||
|
ngl_per_device_test[id].overflow_type = LAYER_FRACTION_UP;
|
||||||
|
LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_UP\n", __func__);
|
||||||
|
std::vector<int64_t> mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts, partial_moe);
|
||||||
|
if (mem_test[id] < targets[id]) {
|
||||||
|
ngl_per_device = ngl_per_device_test;
|
||||||
|
mem = mem_test;
|
||||||
|
id_dense_start = id_dense_start_test;
|
||||||
|
LLAMA_LOG_DEBUG("%s: set ngl_per_device[%zu].(n_layer, n_part, overflow_type)=(%" PRIu32 ", %" PRIu32 ", UP), id_dense_start=%zu\n",
|
||||||
|
__func__, id, ngl_per_device[id].n_layer, ngl_per_device[id].n_part, id_dense_start);
|
||||||
|
|
||||||
|
ngl_per_device_test[id].overflow_type = LAYER_FRACTION_GATE;
|
||||||
|
LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_GATE\n", __func__);
|
||||||
|
mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts, partial_moe);
|
||||||
|
if (mem_test[id] < targets[id]) {
|
||||||
|
ngl_per_device = ngl_per_device_test;
|
||||||
|
mem = mem_test;
|
||||||
|
id_dense_start = id_dense_start_test;
|
||||||
|
LLAMA_LOG_DEBUG("%s: set ngl_per_device[%zu].(n_layer, n_part, overflow_type)=(%" PRIu32 ", %" PRIu32 ", GATE), id_dense_start=%zu\n",
|
||||||
|
__func__, id, ngl_per_device[id].n_layer, ngl_per_device[id].n_part, id_dense_start);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
ngl_per_device_test[id].overflow_type = LAYER_FRACTION_ATTN;
|
||||||
|
LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_ATTN\n", __func__);
|
||||||
|
mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts, partial_moe);
|
||||||
|
if (mem_test[id] < targets[id]) {
|
||||||
|
ngl_per_device = ngl_per_device_test;
|
||||||
|
mem = mem_test;
|
||||||
|
id_dense_start = id_dense_start_test;
|
||||||
|
LLAMA_LOG_DEBUG("%s: set ngl_per_device[%zu].(n_layer, n_part, overflow_type)=(%" PRIu32 ", %" PRIu32 ", ATTN), id_dense_start=%zu\n",
|
||||||
|
__func__, id, ngl_per_device[id].n_layer, ngl_per_device[id].n_part, id_dense_start);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const int64_t projected_margin = dmds_full[id].free - mem[id];
|
||||||
|
LLAMA_LOG_INFO(
|
||||||
|
"%s: - %s: %2" PRIu32 " layers (%2" PRIu32 " overflowing), %6" PRId64 " MiB used, %6" PRId64 " MiB free\n",
|
||||||
|
__func__, dev_names[id].c_str(), ngl_per_device[id].n_layer, ngl_per_device[id].n_part, mem[id]/MiB, projected_margin/MiB);
|
||||||
|
}
|
||||||
|
|
||||||
|
set_ngl_tensor_split_tbo(ngl_per_device, overflow_bufts, *mparams, partial_moe);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool llama_params_fit(
|
||||||
|
const char * path_model, struct llama_model_params * mparams, struct llama_context_params * cparams,
|
||||||
|
float * tensor_split, struct llama_model_tensor_buft_override * tensor_buft_overrides,
|
||||||
|
size_t margin_s, uint32_t n_ctx_min, enum ggml_log_level log_level) {
|
||||||
|
const int64_t t0_us = llama_time_us();
|
||||||
|
bool ok = true;
|
||||||
|
try {
|
||||||
|
llama_params_fit_impl(path_model, mparams, cparams, tensor_split, tensor_buft_overrides, margin_s, n_ctx_min, log_level);
|
||||||
|
LLAMA_LOG_INFO("%s: successfully fit params to free device memory\n", __func__);
|
||||||
|
} catch (const std::runtime_error & e) {
|
||||||
|
LLAMA_LOG_WARN("%s: failed to fit params to free device memory: %s\n", __func__, e.what());
|
||||||
|
ok = false;
|
||||||
|
}
|
||||||
|
const int64_t t1_us = llama_time_us();
|
||||||
|
LLAMA_LOG_INFO("%s: fitting params to free memory took %.2f seconds\n", __func__, (t1_us - t0_us) * 1e-6);
|
||||||
|
return ok;
|
||||||
|
}
|
||||||
|
|
||||||
struct llama_sampler_chain_params llama_sampler_chain_default_params() {
|
struct llama_sampler_chain_params llama_sampler_chain_default_params() {
|
||||||
struct llama_sampler_chain_params result = {
|
struct llama_sampler_chain_params result = {
|
||||||
/*.no_perf =*/ true,
|
/*.no_perf =*/ true,
|
||||||
|
|
@ -49,6 +692,10 @@ size_t llama_max_devices(void) {
|
||||||
return 16;
|
return 16;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
size_t llama_max_tensor_buft_overrides() {
|
||||||
|
return 4096;
|
||||||
|
}
|
||||||
|
|
||||||
bool llama_supports_mmap(void) {
|
bool llama_supports_mmap(void) {
|
||||||
return llama_mmap::SUPPORTED;
|
return llama_mmap::SUPPORTED;
|
||||||
}
|
}
|
||||||
|
|
@ -108,11 +755,12 @@ static int llama_model_load(const std::string & fname, std::vector<std::string>
|
||||||
model.t_start_us = tm.t_start_us;
|
model.t_start_us = tm.t_start_us;
|
||||||
|
|
||||||
try {
|
try {
|
||||||
llama_model_loader ml(fname, splits, params.use_mmap, params.check_tensors, params.kv_overrides, params.tensor_buft_overrides);
|
llama_model_loader ml(fname, splits, params.use_mmap, params.check_tensors, params.no_alloc, params.kv_overrides, params.tensor_buft_overrides);
|
||||||
|
|
||||||
ml.print_info();
|
ml.print_info();
|
||||||
|
|
||||||
model.hparams.vocab_only = params.vocab_only;
|
model.hparams.vocab_only = params.vocab_only;
|
||||||
|
model.hparams.no_alloc = params.no_alloc;
|
||||||
|
|
||||||
try {
|
try {
|
||||||
model.load_arch(ml);
|
model.load_arch(ml);
|
||||||
|
|
|
||||||
|
|
@ -37,4 +37,5 @@ else()
|
||||||
add_subdirectory(cvector-generator)
|
add_subdirectory(cvector-generator)
|
||||||
add_subdirectory(export-lora)
|
add_subdirectory(export-lora)
|
||||||
endif()
|
endif()
|
||||||
|
add_subdirectory(fit-params)
|
||||||
endif()
|
endif()
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,8 @@
|
||||||
|
set(TARGET llama-fit-params)
|
||||||
|
add_executable(${TARGET} fit-params.cpp)
|
||||||
|
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||||
|
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||||
|
|
||||||
|
if(LLAMA_TOOLS_INSTALL)
|
||||||
|
install(TARGETS ${TARGET} RUNTIME)
|
||||||
|
endif()
|
||||||
|
|
@ -0,0 +1,55 @@
|
||||||
|
# fit-params
|
||||||
|
|
||||||
|
llama.cpp binaries can automatically fit the projected memory use of a model to the free device memory available at runtime,
|
||||||
|
this is controlled using the CLI arguments starting with `-fit`/`--fit`.
|
||||||
|
Internally the code is calling `llama_params_fit` to adjust the `llama_model_params` and `llama_context_params` structs.
|
||||||
|
`llama-fit-params` is a simple utility that prints the CLI arguments corresponding to these adjustments to stdout.
|
||||||
|
Example usage:
|
||||||
|
|
||||||
|
``` bash
|
||||||
|
# First, run llama-fit-params and store the results in a file:
|
||||||
|
> ./build/bin/llama-fit-params --model /opt/models/qwen_3-30b3a-f16.gguf | tee args.txt
|
||||||
|
ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no
|
||||||
|
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
|
||||||
|
ggml_cuda_init: found 1 CUDA devices:
|
||||||
|
Device 0: NVIDIA GeForce RTX 4090, compute capability 8.9, VMM: yes
|
||||||
|
build: 6895 (4341dc8bc) with cc (GCC) 15.2.1 20250813 for x86_64-pc-linux-gnu
|
||||||
|
llama_params_fit_impl: projected to use 61807 MiB of device memory vs. 24077 MiB of free device memory
|
||||||
|
llama_params_fit_impl: cannot fulfill margin of 1024 MiB, need to reduce device memory by 42444 MiB
|
||||||
|
llama_params_fit_impl: context size reduced from 40960 to 4096 -> need 3456 MiB less memory in total
|
||||||
|
llama_params_fit_impl: with only dense weights in device memory there is a total surplus of 16164 MiB
|
||||||
|
llama_params_fit_impl: distributing layers across devices with overflow to next device/system memory:
|
||||||
|
llama_params_fit_impl: - CUDA0 (NVIDIA GeForce RTX 4090): 48 layers (34 overflowing), 19187 MiB used, 1199 MiB free
|
||||||
|
llama_params_fit: successfully fit params to free device memory
|
||||||
|
llama_params_fit: fitting params to free memory took 1.15 seconds
|
||||||
|
Printing fitted CLI arguments to stdout...
|
||||||
|
-c 4096 -ngl 48 -ot blk\.14\.ffn_(up|down|gate)_(ch|)exps=CPU,blk\.15\.ffn_(up|down|gate)_(ch|)exps=CPU,blk\.16\.ffn_(up|down|gate)_(ch|)exps=CPU,blk\.17\.ffn_(up|down|gate)_(ch|)exps=CPU,blk\.18\.ffn_(up|down|gate)_(ch|)exps=CPU,blk\.19\.ffn_(up|down|gate)_(ch|)exps=CPU,blk\.20\.ffn_(up|down|gate)_(ch|)exps=CPU,blk\.21\.ffn_(up|down|gate)_(ch|)exps=CPU,blk\.22\.ffn_(up|down|gate)_(ch|)exps=CPU,blk\.23\.ffn_(up|down|gate)_(ch|)exps=CPU,blk\.24\.ffn_(up|down|gate)_(ch|)exps=CPU,blk\.25\.ffn_(up|down|gate)_(ch|)exps=CPU,blk\.26\.ffn_(up|down|gate)_(ch|)exps=CPU,blk\.27\.ffn_(up|down|gate)_(ch|)exps=CPU,blk\.28\.ffn_(up|down|gate)_(ch|)exps=CPU,blk\.29\.ffn_(up|down|gate)_(ch|)exps=CPU,blk\.30\.ffn_(up|down|gate)_(ch|)exps=CPU,blk\.31\.ffn_(up|down|gate)_(ch|)exps=CPU,blk\.32\.ffn_(up|down|gate)_(ch|)exps=CPU,blk\.33\.ffn_(up|down|gate)_(ch|)exps=CPU,blk\.34\.ffn_(up|down|gate)_(ch|)exps=CPU,blk\.35\.ffn_(up|down|gate)_(ch|)exps=CPU,blk\.36\.ffn_(up|down|gate)_(ch|)exps=CPU,blk\.37\.ffn_(up|down|gate)_(ch|)exps=CPU,blk\.38\.ffn_(up|down|gate)_(ch|)exps=CPU,blk\.39\.ffn_(up|down|gate)_(ch|)exps=CPU,blk\.40\.ffn_(up|down|gate)_(ch|)exps=CPU,blk\.41\.ffn_(up|down|gate)_(ch|)exps=CPU,blk\.42\.ffn_(up|down|gate)_(ch|)exps=CPU,blk\.43\.ffn_(up|down|gate)_(ch|)exps=CPU,blk\.44\.ffn_(up|down|gate)_(ch|)exps=CPU,blk\.45\.ffn_(up|down|gate)_(ch|)exps=CPU,blk\.46\.ffn_(up|down|gate)_(ch|)exps=CPU,blk\.47\.ffn_(up|down|gate)_(ch|)exps=CPU
|
||||||
|
|
||||||
|
# Next, use those results for a llama.cpp binary:
|
||||||
|
> cat args.txt | xargs ./build/bin/llama-server --model /opt/models/qwen_3-30b3a-f16.gguf
|
||||||
|
ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no
|
||||||
|
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
|
||||||
|
ggml_cuda_init: found 1 CUDA devices:
|
||||||
|
Device 0: NVIDIA GeForce RTX 4090, compute capability 8.9, VMM: yes
|
||||||
|
build: 6895 (4341dc8bc) with cc (GCC) 15.2.1 20250813 for x86_64-pc-linux-gnu
|
||||||
|
system info: n_threads = 16, n_threads_batch = 16, total_threads = 32
|
||||||
|
|
||||||
|
system_info: n_threads = 16 (n_threads_batch = 16) / 32 | CUDA : ARCHS = 890 | USE_GRAPHS = 1 | PEER_MAX_BATCH_SIZE = 128 | CPU : SSE3 = 1 | SSSE3 = 1 | AVX = 1 | AVX_VNNI = 1 | AVX2 = 1 | F16C = 1 | FMA = 1 | BMI2 = 1 | AVX512 = 1 | AVX512_VBMI = 1 | AVX512_VNNI = 1 | AVX512_BF16 = 1 | LLAMAFILE = 1 | OPENMP = 1 | REPACK = 1 |
|
||||||
|
|
||||||
|
main: binding port with default address family
|
||||||
|
main: HTTP server is listening, hostname: 127.0.0.1, port: 8080, http threads: 31
|
||||||
|
main: loading model
|
||||||
|
srv load_model: loading model '/opt/models/qwen_3-30b3a-f16.gguf'
|
||||||
|
llama_params_fit_impl: projected to use 19187 MiB of device memory vs. 24077 MiB of free device memory
|
||||||
|
llama_params_fit_impl: will leave 1199 >= 1024 MiB of free device memory, no changes needed
|
||||||
|
llama_params_fit: successfully fit params to free device memory
|
||||||
|
llama_params_fit: fitting params to free memory took 0.28 seconds
|
||||||
|
[...]
|
||||||
|
main: server is listening on http://127.0.0.1:8080 - starting the main loop
|
||||||
|
srv update_slots: all slots are idle
|
||||||
|
^Csrv operator(): operator(): cleaning up before exit...
|
||||||
|
|
||||||
|
llama_memory_breakdown_print: | memory breakdown [MiB] | total free self model context compute unaccounted |
|
||||||
|
llama_memory_breakdown_print: | - CUDA0 (RTX 4090) | 24077 = 945 + (19187 = 17904 + 384 + 898) + 3945 |
|
||||||
|
llama_memory_breakdown_print: | - Host | 58271 = 58259 + 0 + 12 |
|
||||||
|
```
|
||||||
|
|
@ -0,0 +1,62 @@
|
||||||
|
#include "llama.h"
|
||||||
|
|
||||||
|
#include "arg.h"
|
||||||
|
#include "common.h"
|
||||||
|
#include "log.h"
|
||||||
|
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
|
#if defined(_MSC_VER)
|
||||||
|
#pragma warning(disable: 4244 4267) // possible loss of data
|
||||||
|
#endif
|
||||||
|
|
||||||
|
int main(int argc, char ** argv) {
|
||||||
|
common_params params;
|
||||||
|
|
||||||
|
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
common_init();
|
||||||
|
llama_backend_init();
|
||||||
|
llama_numa_init(params.numa);
|
||||||
|
auto mparams = common_model_params_to_llama(params);
|
||||||
|
auto cparams = common_context_params_to_llama(params);
|
||||||
|
llama_params_fit(params.model.path.c_str(), &mparams, &cparams,
|
||||||
|
params.tensor_split, params.tensor_buft_overrides.data(), params.fit_params_target, params.fit_params_min_ctx,
|
||||||
|
params.verbosity >= 4 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_ERROR);
|
||||||
|
|
||||||
|
LOG_INF("Printing fitted CLI arguments to stdout...\n");
|
||||||
|
std::cout << "-c " << cparams.n_ctx;
|
||||||
|
std::cout << " -ngl " << mparams.n_gpu_layers;
|
||||||
|
|
||||||
|
size_t nd = llama_max_devices();
|
||||||
|
while (nd > 1 && mparams.tensor_split[nd - 1] == 0.0f) {
|
||||||
|
nd--;
|
||||||
|
}
|
||||||
|
if (nd > 1) {
|
||||||
|
for (size_t id = 0; id < nd; id++) {
|
||||||
|
if (id == 0) {
|
||||||
|
std::cout << " -ts ";
|
||||||
|
}
|
||||||
|
if (id > 0) {
|
||||||
|
std::cout << ",";
|
||||||
|
}
|
||||||
|
std::cout << mparams.tensor_split[id];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const size_t ntbo = llama_max_tensor_buft_overrides();
|
||||||
|
for (size_t itbo = 0; itbo < ntbo && mparams.tensor_buft_overrides[itbo].pattern != nullptr; itbo++) {
|
||||||
|
if (itbo == 0) {
|
||||||
|
std::cout << " -ot ";
|
||||||
|
}
|
||||||
|
if (itbo > 0) {
|
||||||
|
std::cout << ",";
|
||||||
|
}
|
||||||
|
std::cout << mparams.tensor_buft_overrides[itbo].pattern << "=" << ggml_backend_buft_name(mparams.tensor_buft_overrides[itbo].buft);
|
||||||
|
}
|
||||||
|
std::cout << "\n";
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue