commit d29057584b
Author: borebot
Date: 2026-02-05 03:34:44 +02:00 (committed by GitHub)
GPG Key ID: B5690EEEBB952194 (no known key found for this signature in database)
3 changed files with 316 additions and 0 deletions


@@ -1367,6 +1367,7 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {
mparams.use_direct_io = params.use_direct_io;
mparams.use_mlock = params.use_mlock;
mparams.check_tensors = params.check_tensors;
mparams.requested_n_ctx = params.n_ctx;
mparams.use_extra_bufts = !params.no_extra_bufts;
mparams.no_host = params.no_host;
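
The hunk above copies the caller's requested context size (params.n_ctx, presumably set from the CLI -c flag in the common layer) into the new requested_n_ctx model parameter, so the loader can plan KV-cache memory before placing layers. A minimal sketch, assuming the common helpers declared in common.h (the include name and field layout are assumptions, not shown in this commit):

    #include "common.h"   // assumed location of common_params / common_model_params_to_llama
    #include "llama.h"

    int main() {
        common_params params;
        params.n_ctx = 32768; // e.g. what -c 32768 would set
        llama_model_params mparams = common_model_params_to_llama(params);
        // after this commit, mparams.requested_n_ctx == 32768
        return 0;
    }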


@@ -306,6 +306,9 @@ extern "C" {
// override key-value pairs of the model meta data
const struct llama_model_kv_override * kv_overrides;
// expected context size for memory allocation planning (0 = auto)
uint32_t requested_n_ctx;
// Keep the booleans together to avoid misalignment during copy-by-value.
bool vocab_only; // only load the vocabulary, no weights
bool use_mmap; // use mmap if possible
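
For callers that use the C API directly rather than the common helpers, the new field can be set on llama_model_params before loading. A hedged sketch: llama_model_default_params() and llama_model_load_from_file() come from llama.h, while the path and layer count below are placeholders:

    #include "llama.h"

    llama_model * load_with_planned_ctx(const char * model_path, uint32_t n_ctx) {
        llama_model_params mparams = llama_model_default_params();
        mparams.requested_n_ctx = n_ctx; // 0 leaves the loader on its conservative default
        mparams.n_gpu_layers    = 999;   // placeholder: offload as many layers as fit
        return llama_model_load_from_file(model_path, mparams);
    }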


@@ -18,10 +18,12 @@
#include <algorithm>
#include <cassert>
#include <cfloat>
#include <cstdlib>
#include <cstring>
#include <cmath>
#include <functional>
#include <map>
#include <numeric>
#include <regex>
#include <sstream>
#include <stdexcept>
@@ -2321,6 +2323,311 @@ void llama_model::load_hparams(llama_model_loader & ml) {
hparams.n_no_rope_layer_step = hparams.n_layer;
}
// KV-cache aware layer distribution for heterogeneous GPUs
if (all_zero && n_devices() > 1 && split_mode == LLAMA_SPLIT_MODE_LAYER) {
// Determine context size for memory planning
uint32_t n_ctx_for_kv = 0;
if (params.requested_n_ctx > 0) {
// Use the explicitly requested context size from model params
n_ctx_for_kv = params.requested_n_ctx;
LLAMA_LOG_INFO("%s: Using requested_n_ctx=%u for KV cache calculation\n",
__func__, n_ctx_for_kv);
} else {
// Use a conservative default for memory planning
n_ctx_for_kv = std::min(32768u, hparams.n_ctx_train);
LLAMA_LOG_INFO("%s: Using default n_ctx=%u for KV cache calculation (training context: %u)\n",
__func__, n_ctx_for_kv, hparams.n_ctx_train);
LLAMA_LOG_INFO("%s: (set requested_n_ctx in model params to match your actual context size)\n", __func__);
}
// Only apply KV-aware distribution if we have a valid context size
if (n_ctx_for_kv > 0 && n_gpu_layers > 0) {
LLAMA_LOG_INFO("%s: Implementing KV-cache aware layer distribution\n", __func__);
// Calculate memory requirements per layer
const int64_t n_head_kv = hparams.n_head_kv();
const int64_t n_embd_head = hparams.n_embd_head_k;
const int64_t n_embd_kv = n_embd_head * n_head_kv;
// KV cache element size (typically f16 = 2 bytes, but can be quantized)
const size_t kv_size_element = 2; // sizeof(ggml_fp16_t)
// Total KV cache size for all layers (K and V)
// KV cache = 2 (K+V) * n_ctx * n_layers * n_embd_kv * element_size
const size_t kv_cache_size_total = 2ULL * n_ctx_for_kv * n_layer * n_embd_kv * kv_size_element;
// Estimate model weight size per layer
const size_t model_size_total = ml.n_bytes;
const size_t weight_size_per_layer = model_size_total / n_layer;
// Calculate actual compute buffer size based on attention matrix requirements
// Attention matrix: n_kv × n_ubatch × n_head × sizeof(float)
// This is the dominant memory consumer during inference
const int64_t n_head = hparams.n_head();
const size_t n_ubatch = 512; // Default physical batch size (from context params)
const size_t compute_buffer_size = n_ctx_for_kv * n_ubatch * n_head * sizeof(float);
const size_t min_overhead = 512ULL * 1024 * 1024; // 512MB base overhead
LLAMA_LOG_INFO("%s: Compute buffer size: %.2f MB (context=%u, ubatch=%zu, heads=%lld)\n",
__func__,
compute_buffer_size / 1024.0 / 1024.0,
n_ctx_for_kv, n_ubatch, (long long)n_head);
// For memory calculation, we need to account for KV cache being shared across layers on each device
// We'll calculate this dynamically during layer assignment
LLAMA_LOG_INFO("%s: Per-layer memory: weights=%.2f MB\n",
__func__,
weight_size_per_layer / 1024.0 / 1024.0);
LLAMA_LOG_INFO("%s: Total KV cache size: %.2f MB\n",
__func__,
kv_cache_size_total / 1024.0 / 1024.0);
// Get memory info and calculate layer assignments
std::vector<int> layers_per_gpu(n_devices(), 0);
std::vector<size_t> gpu_free_memory(n_devices());
// Get free memory for each device and check if they can handle compute buffers
std::vector<bool> device_excluded(n_devices(), false);
for (size_t i = 0; i < n_devices(); ++i) {
ggml_backend_dev_t dev = devices[i];
size_t total, free;
ggml_backend_dev_memory(dev, &free, &total);
gpu_free_memory[i] = free;
// Check if device can handle minimum requirements (1 layer + compute buffer + KV cache)
size_t min_kv_cache = kv_cache_size_total / n_devices(); // Conservative estimate
size_t min_required = weight_size_per_layer + min_kv_cache + compute_buffer_size + min_overhead;
if (free < min_required) {
device_excluded[i] = true;
LLAMA_LOG_WARN("%s: Device %zu [%s]: %.2f MB free - excluding (needs %.2f MB minimum)\n",
__func__, i, ggml_backend_dev_name(dev),
free / 1024.0 / 1024.0, min_required / 1024.0 / 1024.0);
}
}
// Estimate total memory requirements and warn if insufficient
size_t total_gpu_memory = 0;
for (size_t i = 0; i < n_devices(); ++i) {
total_gpu_memory += gpu_free_memory[i];
}
// Rough estimate: KV cache + model weights + compute buffers (conservative estimate)
size_t estimated_compute_buffers = kv_cache_size_total; // Compute buffers often similar to KV cache size
size_t estimated_total_needed = kv_cache_size_total + model_size_total + estimated_compute_buffers;
if (estimated_total_needed > total_gpu_memory) {
LLAMA_LOG_WARN("%s: Memory estimate: %.2f GB needed vs %.2f GB available\n",
__func__,
estimated_total_needed / 1024.0 / 1024.0 / 1024.0,
total_gpu_memory / 1024.0 / 1024.0 / 1024.0);
LLAMA_LOG_WARN("%s: Context size may be too large for available memory\n", __func__);
}
// Sort devices by available memory (largest first), excluding unusable devices
std::vector<size_t> gpu_indices;
for (size_t i = 0; i < n_devices(); ++i) {
if (!device_excluded[i]) {
gpu_indices.push_back(i);
}
}
std::sort(gpu_indices.begin(), gpu_indices.end(),
[&gpu_free_memory](size_t a, size_t b) {
return gpu_free_memory[a] > gpu_free_memory[b];
});
if (gpu_indices.empty()) {
LLAMA_LOG_ERROR("%s: No GPUs have sufficient memory for compute buffers\n", __func__);
// Fall back to original allocation
return true;
}
// Assign layers greedily to GPUs with most memory first
int act_gpu_layers = n_gpu_layers; // Local copy that can be modified
int remaining_layers = act_gpu_layers;
// First pass: assign layers based on weights only (KV cache and compute buffers handled separately)
size_t weight_per_layer = weight_size_per_layer;
for (size_t idx : gpu_indices) {
// Reserve memory for compute buffer and base overhead
size_t reserved = compute_buffer_size + min_overhead;
if (gpu_free_memory[idx] <= reserved) {
LLAMA_LOG_WARN("%s: Device %zu [%s]: %zu MB free, can't fit compute buffer (%.2f MB)\n",
__func__, idx, ggml_backend_dev_name(devices[idx]),
gpu_free_memory[idx] / 1024 / 1024,
reserved / 1024.0 / 1024.0);
continue;
}
size_t available_for_model = gpu_free_memory[idx] - reserved;
int layers_that_fit = available_for_model / weight_per_layer;
if (layers_that_fit > 0 && remaining_layers > 0) {
int layers_to_assign = std::min(layers_that_fit, remaining_layers);
layers_per_gpu[idx] = layers_to_assign;
remaining_layers -= layers_to_assign;
LLAMA_LOG_INFO("%s: Device %zu [%s]: %zu MB free, assigned %d layers (%.2f MB weights, %.2f MB compute buffer)\n",
__func__, idx, ggml_backend_dev_name(devices[idx]),
gpu_free_memory[idx] / 1024 / 1024,
layers_per_gpu[idx],
(layers_to_assign * weight_per_layer) / 1024.0 / 1024.0,
compute_buffer_size / 1024.0 / 1024.0);
} else {
LLAMA_LOG_WARN("%s: Device %zu [%s]: %zu MB free, assigned 0 layers (need %.2f MB per layer + %.2f MB compute buffer)\n",
__func__, idx, ggml_backend_dev_name(devices[idx]),
gpu_free_memory[idx] / 1024 / 1024,
weight_per_layer / 1024.0 / 1024.0,
compute_buffer_size / 1024.0 / 1024.0);
}
}
// Second pass: iteratively check if KV cache can fit proportionally
bool kv_fit_check_needed = (remaining_layers == 0);
int iterations = 0;
const int max_iterations = 10;
while (kv_fit_check_needed && iterations < max_iterations) {
kv_fit_check_needed = false;
iterations++;
// Calculate current total assigned layers
int total_assigned = 0;
for (size_t idx = 0; idx < n_devices(); ++idx) {
total_assigned += layers_per_gpu[idx];
}
if (total_assigned == 0) break;
// Check KV cache distribution for each device
for (size_t idx = 0; idx < n_devices(); ++idx) {
if (layers_per_gpu[idx] > 0) {
double layer_ratio = (double)layers_per_gpu[idx] / total_assigned;
size_t kv_cache_for_device = (size_t)(kv_cache_size_total * layer_ratio);
size_t weights = layers_per_gpu[idx] * weight_per_layer;
size_t total_memory_needed = weights + kv_cache_for_device + compute_buffer_size + min_overhead;
if (total_memory_needed > gpu_free_memory[idx]) {
// Device can't fit current allocation, reduce layers
size_t available_memory = gpu_free_memory[idx];
if (available_memory > min_overhead + kv_cache_for_device + compute_buffer_size) {
size_t available_for_weights = available_memory - min_overhead - kv_cache_for_device - compute_buffer_size;
int new_layer_count = available_for_weights / weight_per_layer;
new_layer_count = std::max(0, new_layer_count);
if (new_layer_count < layers_per_gpu[idx]) {
LLAMA_LOG_WARN("%s: Device %zu: Reducing layers from %d to %d due to KV cache requirements (%.2f MB KV cache)\n",
__func__, idx, layers_per_gpu[idx], new_layer_count,
kv_cache_for_device / 1024.0 / 1024.0);
remaining_layers += layers_per_gpu[idx] - new_layer_count;
layers_per_gpu[idx] = new_layer_count;
kv_fit_check_needed = true;
}
} else {
// Device can't even fit the minimum requirements
LLAMA_LOG_WARN("%s: Device %zu: Removing all %d layers - insufficient memory for KV cache\n",
__func__, idx, layers_per_gpu[idx]);
remaining_layers += layers_per_gpu[idx];
layers_per_gpu[idx] = 0;
kv_fit_check_needed = true;
}
}
}
}
}
// Third pass: redistribute any remaining layers to devices with available capacity
if (remaining_layers > 0) {
LLAMA_LOG_INFO("%s: Attempting to redistribute %d remaining layers\n", __func__, remaining_layers);
// Calculate current memory usage for each device that has layers assigned
for (size_t idx : gpu_indices) {
if (layers_per_gpu[idx] > 0 && remaining_layers > 0) {
// Calculate current memory usage
int current_assigned = 0;
for (size_t i = 0; i < n_devices(); ++i) {
current_assigned += layers_per_gpu[i];
}
double layer_ratio = (double)layers_per_gpu[idx] / current_assigned;
size_t current_kv_cache = (size_t)(kv_cache_size_total * layer_ratio);
size_t current_weights = layers_per_gpu[idx] * weight_per_layer;
size_t current_usage = current_weights + current_kv_cache + compute_buffer_size + min_overhead;
if (gpu_free_memory[idx] > current_usage) {
// Calculate how many additional layers could fit
// We need to account for proportional increase in KV cache
int additional_layers = 0;
for (int test_layers = 1; test_layers <= remaining_layers; test_layers++) {
int new_total_layers = layers_per_gpu[idx] + test_layers;
int new_total_assigned = current_assigned + test_layers;
double new_layer_ratio = (double)new_total_layers / new_total_assigned;
size_t new_kv_cache = (size_t)(kv_cache_size_total * new_layer_ratio);
size_t new_weights = new_total_layers * weight_per_layer;
size_t new_total_usage = new_weights + new_kv_cache + compute_buffer_size + min_overhead;
if (new_total_usage <= gpu_free_memory[idx]) {
additional_layers = test_layers;
} else {
break;
}
}
if (additional_layers > 0) {
int layers_to_add = std::min(additional_layers, remaining_layers);
layers_per_gpu[idx] += layers_to_add;
remaining_layers -= layers_to_add;
LLAMA_LOG_INFO("%s: Device %zu [%s]: redistributed %d additional layers (total now %d)\n",
__func__, idx, ggml_backend_dev_name(devices[idx]),
layers_to_add, layers_per_gpu[idx]);
}
}
}
}
}
// Warn if we couldn't place all layers
if (remaining_layers > 0) {
LLAMA_LOG_ERROR("%s: WARNING: Could not assign %d layers to GPUs. Consider:\n",
__func__, remaining_layers);
LLAMA_LOG_ERROR("%s: - Reducing context size (current: %u)\n",
__func__, n_ctx_for_kv);
LLAMA_LOG_ERROR("%s: - Using fewer layers (-ngl)\n", __func__);
LLAMA_LOG_ERROR("%s: - Adding more GPU memory\n", __func__);
// Put remaining layers on CPU (will be updated below)
}
// Convert layer counts to split ratios
splits.clear();
splits.resize(n_devices());
float cumsum = 0.0f;
// Calculate total layers actually assigned
int total_assigned_layers = 0;
for (size_t i = 0; i < n_devices(); ++i) {
total_assigned_layers += layers_per_gpu[i];
}
// Update act_gpu_layers to match what we actually assigned
act_gpu_layers = total_assigned_layers;
for (size_t i = 0; i < n_devices(); ++i) {
cumsum += (float)layers_per_gpu[i] / act_gpu_layers;
splits[i] = cumsum;
}
LLAMA_LOG_INFO("%s: Final split ratios: ", __func__);
for (size_t i = 0; i < n_devices(); ++i) {
LLAMA_LOG_CONT("%.3f ", splits[i]);
}
LLAMA_LOG_CONT("\n");
}
}
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false);
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false);
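
To make the arithmetic in the hunk above concrete: the planner sizes the total KV cache as 2 (K and V) * n_ctx * n_layer * n_embd_kv * bytes_per_element, sizes the attention compute buffer as n_ctx * n_ubatch * n_head * sizeof(float), and finally turns per-device layer counts into cumulative split ratios. A standalone sketch with made-up, roughly 70B-class hparams (none of these numbers come from the commit):

    #include <cstddef>
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main() {
        const uint32_t n_ctx       = 32768; // requested_n_ctx
        const int64_t  n_layer     = 80;
        const int64_t  n_head      = 64;
        const int64_t  n_head_kv   = 8;     // GQA
        const int64_t  n_embd_head = 128;
        const int64_t  n_embd_kv   = n_embd_head * n_head_kv; // 1024

        // KV cache: K and V for every layer, per context position, f16 elements
        const size_t kv_total = 2ULL * n_ctx * n_layer * n_embd_kv * 2 /* bytes per f16 */;
        // attention scores for one micro-batch: n_ctx x n_ubatch x n_head floats
        const size_t n_ubatch    = 512;
        const size_t compute_buf = (size_t) n_ctx * n_ubatch * n_head * sizeof(float);
        std::printf("KV cache total: %.2f GiB, compute buffer: %.2f GiB\n",
                    kv_total / (1024.0 * 1024.0 * 1024.0),     // 10.00 GiB with these numbers
                    compute_buf / (1024.0 * 1024.0 * 1024.0)); // 4.00 GiB with these numbers

        // converting per-device layer counts into cumulative split ratios,
        // as the end of the hunk does: 50 + 30 assigned layers -> 0.625, 1.000
        const std::vector<int> layers_per_gpu = { 50, 30 };
        int total = 0;
        for (int l : layers_per_gpu) total += l;
        float cumsum = 0.0f;
        for (int l : layers_per_gpu) {
            cumsum += (float) l / (float) total;
            std::printf("split: %.3f\n", cumsum);
        }
        return 0;
    }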
@@ -2496,6 +2803,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
if (cpu_dev == nullptr) {
throw std::runtime_error(format("%s: no CPU backend found", __func__));
}
const int i_gpu_start = std::max((int) hparams.n_layer - n_gpu_layers, (int) 0);
const int act_gpu_layers = devices.empty() ? 0 : std::min(n_gpu_layers, (int)n_layer + 1);
auto get_layer_buft_list = [&](int il) -> llama_model::impl::layer_dev {
const bool is_swa = il < (int) hparams.n_layer && hparams.is_swa(il);
// calculate the split points
bool all_zero = tensor_split == nullptr || std::all_of(tensor_split, tensor_split + n_devices(), [](float x) { return x == 0.0f; });
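
The context around this hunk (i_gpu_start, act_gpu_layers, get_layer_buft_list) is where the cumulative split ratios are consumed when assigning each layer to a device. As an illustration of how such a lookup typically works (an upper_bound over cumulative ratios; the helper name and exact behaviour below are illustrative, not the code in this commit):

    #include <algorithm>
    #include <vector>

    static int pick_device(int il, int i_gpu_start, int act_gpu_layers,
                           const std::vector<float> & splits) {
        if (il < i_gpu_start || act_gpu_layers <= 0 || splits.empty()) {
            return -1; // layer stays on the CPU / no GPU devices
        }
        // normalized position of this layer within the GPU-offloaded range
        const float pos = (float) (il - i_gpu_start) / (float) act_gpu_layers;
        // first device whose cumulative ratio exceeds that position
        const auto it  = std::upper_bound(splits.begin(), splits.end(), pos);
        const int  dev = (int) (it - splits.begin());
        return std::min(dev, (int) splits.size() - 1); // clamp in case pos rounds up to 1.0
    }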
@@ -8123,6 +8434,7 @@ llama_model_params llama_model_default_params() {
/*.progress_callback =*/ nullptr,
/*.progress_callback_user_data =*/ nullptr,
/*.kv_overrides =*/ nullptr,
/*.requested_n_ctx =*/ 0,
/*.vocab_only =*/ false,
/*.use_mmap =*/ true,
/*.use_direct_io =*/ false,
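
The default of 0 for requested_n_ctx means the planner falls back to a conservative value, as in the layer-distribution hunk above. A tiny illustrative helper restating that fallback (the name is local to this sketch):

    #include <algorithm>
    #include <cstdint>

    // If the caller did not set requested_n_ctx, plan for min(32768, n_ctx_train),
    // matching the fallback used in the layer-distribution code above.
    static uint32_t planned_n_ctx(uint32_t requested_n_ctx, uint32_t n_ctx_train) {
        return requested_n_ctx > 0 ? requested_n_ctx
                                   : std::min(32768u, n_ctx_train);
    }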