Compare commits

...

11 Commits

Author SHA1 Message Date
Aadeshveer Singh 58062860af
ggml : use WARP_SIZE/2 for argmax reduction offset (#18092) 2025-12-17 11:47:01 +08:00
Yuri Khrustalev 2973a65ecb
gguf-py : allow converting multi-tensor models from read-only locations (#18100) 2025-12-17 02:27:03 +01:00
Johannes Gäßler d0794e89d9
llama-fit-params: force disable mlock (#18103) 2025-12-17 00:50:12 +01:00
Johannes Gäßler 9dcac6cf9f
llama-fit-params: lower ctx size for multi GPU (#18101) 2025-12-17 00:49:34 +01:00
Johannes Gäßler 0e49a7b8b4
llama-fit-params: fix underflow for dense models (#18095) 2025-12-17 00:47:37 +01:00
Johannes Gäßler 4164596c76
llama-fit-params: QoL impr. for prints/errors (#18089) 2025-12-17 00:03:19 +01:00
Xuan-Son Nguyen ef83fb8601
model: fix LFM2 missing tensors (#18105) 2025-12-16 19:07:43 +01:00
Johannes Gäßler ec98e20021
llama: fix early stop in params_fit if ctx is set (#18070) 2025-12-16 14:24:00 +01:00
yifant-code 59977eba7b
server: fix crash when batch > ubatch with embeddings (#17912)
* server: fix crash when batch > ubatch with embeddings (#12836)

Fixes #12836 where the server crashes with GGML_ASSERT failure when
running with embeddings enabled and n_batch > n_ubatch.

Root cause: Embeddings use non-causal attention which requires all
tokens to be processed within a single ubatch. When n_batch > n_ubatch,
the server attempts to split processing, causing assertion failure.

Solution:
- Add parameter validation in main() after common_params_parse()
- When embeddings enabled and n_batch > n_ubatch:
  * Log warnings explaining the issue
  * Automatically set n_batch = n_ubatch
  * Prevent server crash

This follows the approach suggested by @ggerganov in issue #12836.

Note: This supersedes stalled PR #12940 which attempted a runtime fix
in the old examples/server/server.cpp location. This implementation
validates at startup in tools/server/server.cpp (current location).

Testing:
- Build: Compiles successfully
- Validation triggers: Warns when -b > -ub with --embedding
- Auto-correction works: Adjusts n_batch = n_ubatch
- No false positives: Valid params don't trigger warnings
- Verified on macOS M3 Pro with embedding model

* Update tools/server/server.cpp

---------

Co-authored-by: ytian218 <ytian218@bloomberg.net>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2025-12-16 14:27:36 +02:00
Daniel Bevenius 79dbae034a
model-conversion : remove -fa option in model card template [no ci] (#18088)
This commit updates the causal model card template and removes the
-fa option as it is no longer required (fa is auto detected).
2025-12-16 13:25:09 +01:00
Xuan-Son Nguyen 7f2b2f3c77
arch: refactor LLM_TENSOR_NAMES (#18051)
* arch: refactor LLM_TENSOR_NAMES

* update docs

* typo

* fix LLM_ARCH_NEMOTRON_H_MOE

* show more meaningful error message on missing tensor

* fix and tested LLM_ARCH_NEMOTRON_H_MOE
2025-12-16 13:22:30 +01:00
10 changed files with 1985 additions and 2332 deletions

View File

@ -97,7 +97,7 @@ The model params and tensors layout must be defined in `llama.cpp` source files:
1. Define a new `llm_arch` enum value in `src/llama-arch.h`.
2. In `src/llama-arch.cpp`:
- Add the architecture name to the `LLM_ARCH_NAMES` map.
- Add the tensor mappings to the `LLM_TENSOR_NAMES` map.
- Add the list of model tensors to `llm_get_tensor_names` (you may also need to update `LLM_TENSOR_NAMES`)
3. Add any non-standard metadata loading in the `llama_model_loader` constructor in `src/llama-model-loader.cpp`.
4. If the model has a RoPE operation, add a case for the architecture in `llama_model_rope_type` function in `src/llama-model.cpp`.

View File

@ -7,7 +7,7 @@ base_model:
Recommended way to run this model:
```sh
llama-server -hf {namespace}/{model_name}-GGUF -c 0 -fa
llama-server -hf {namespace}/{model_name}-GGUF -c 0
```
Then, access http://localhost:8080

View File

@ -21,7 +21,7 @@ static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __rest
}
#pragma unroll
for (int offset = 16; offset > 0; offset >>= 1) {
for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) {
const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE);
const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE);
if (val > maxval) {
@ -50,7 +50,7 @@ static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __rest
argmax = shared_argmax[lane_id];
}
#pragma unroll
for (int offset = 16; offset > 0; offset >>= 1) {
for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) {
const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE);
const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE);
if (val > maxval) {

View File

@ -288,7 +288,7 @@ class LocalTensor:
data_range: LocalTensorRange
def mmap_bytes(self) -> np.ndarray:
return np.memmap(self.data_range.filename, offset=self.data_range.offset, shape=self.data_range.size)
return np.memmap(self.data_range.filename, mode='r', offset=self.data_range.offset, shape=self.data_range.size)
class SafetensorsLocal:

File diff suppressed because it is too large Load Diff

View File

@ -3,6 +3,7 @@
#include "ggml.h" // ggml_op
#include <string>
#include <set>
//
// gguf constants (sync with gguf.py)
@ -316,6 +317,7 @@ enum llm_tensor {
LLM_TENSOR_DENSE_3_OUT,
LLM_TENSOR_OUTPUT,
LLM_TENSOR_OUTPUT_NORM,
LLM_TENSOR_OUTPUT_NORM_LFM2, // fix for wrong tensor name
LLM_TENSOR_ROPE_FREQS,
LLM_TENSOR_ROPE_FACTORS_LONG,
LLM_TENSOR_ROPE_FACTORS_SHORT,
@ -526,6 +528,10 @@ struct LLM_TN_IMPL {
const int bid;
const int xid;
const std::set<llm_tensor> model_tensors;
LLM_TN_IMPL(llm_arch arch, llm_tensor tensor, const char * suffix, int bid, int xid);
std::string str() const;
operator std::string() const {
@ -547,11 +553,11 @@ struct LLM_TN {
llm_arch arch;
LLM_TN_IMPL operator()(llm_tensor tensor, const char * suffix, int bid = -1, int xid = -1) const {
return { arch, tensor, suffix, bid, xid };
return LLM_TN_IMPL(arch, tensor, suffix, bid, xid);
}
LLM_TN_IMPL operator()(llm_tensor tensor, int bid = -1, int xid = -1) const {
return { arch, tensor, nullptr, bid, xid };
return LLM_TN_IMPL(arch, tensor, nullptr, bid, xid);
}
};

View File

@ -6236,8 +6236,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM_LFM2, "weight"), {n_embd}, 0);
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
if (output == NULL) {
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);

View File

@ -71,8 +71,9 @@ static std::vector<llama_device_memory_data> llama_get_device_memory_data(
}, &ud);
llama_model_params mparams_copy = *mparams;
mparams_copy.no_alloc = true;
mparams_copy.use_mmap = false;
mparams_copy.no_alloc = true;
mparams_copy.use_mmap = false;
mparams_copy.use_mlock = false;
llama_model * model = llama_model_load_from_file(path_model, mparams_copy);
if (model == nullptr) {
@ -180,11 +181,12 @@ static void llama_params_fit_impl(
}
}
int64_t sum_total = 0;
int64_t sum_projected_free = 0;
int64_t min_projected_free = INT64_MAX;
int64_t sum_projected_used = 0;
int64_t sum_projected_ctx = 0;
int64_t sum_total = 0;
int64_t sum_projected_free = 0;
int64_t min_projected_free = INT64_MAX;
int64_t sum_projected_used = 0;
int64_t sum_projected_model = 0;
int64_t sum_projected_ctx = 0;
if (nd > 1) {
LLAMA_LOG_INFO("%s: projected memory use with initial parameters [MiB]:\n", __func__);
@ -195,11 +197,12 @@ static void llama_params_fit_impl(
const int64_t projected_used = dmd.mb.total();
const int64_t projected_free = dmd.free - projected_used;
sum_total += dmd.total;
sum_projected_used += projected_used;
sum_projected_free += projected_free;
min_projected_free = std::min(min_projected_free, projected_free);
sum_projected_ctx += dmd.mb.context;
sum_total += dmd.total;
sum_projected_used += projected_used;
sum_projected_free += projected_free;
min_projected_free = std::min(min_projected_free, projected_free);
sum_projected_model += dmd.mb.model;
sum_projected_ctx += dmd.mb.context;
if (nd > 1) {
LLAMA_LOG_INFO("%s: - %s: %6" PRId64 " total, %6" PRId64 " used, %6" PRId64 " %s\n",
@ -234,13 +237,34 @@ static void llama_params_fit_impl(
if (cparams->n_ctx == 0) {
if (hp_nct > n_ctx_min) {
const int64_t bytes_per_ctx = sum_projected_ctx / hp_nct;
const uint32_t ctx_reduction = std::min(
uint32_t((-global_surplus + bytes_per_ctx - 1) / bytes_per_ctx), hp_nct - n_ctx_min);
int64_t memory_reduction = -global_surplus;
if (nd > 1) {
// for multiple devices we need to be more conservative in terms of how much context we think can fit:
// - for dense models only whole layers can be assigned to devices
// - for MoE models only whole tensors can be assigned to devices, which we estimate to be <= 1/3 of a layer
// - on average we expect a waste of 0.5 layers/tensors per device
// - use slightly more than the expected average for nd devices to be safe
const int64_t model_per_layer = sum_projected_model / std::min(uint32_t(mparams->n_gpu_layers), hp_ngl);
memory_reduction += (nd + 1) * model_per_layer / (hp_nex == 0 ? 2 : 6);
}
uint32_t ctx_reduction = std::min(uint32_t((memory_reduction + bytes_per_ctx - 1) / bytes_per_ctx), hp_nct - n_ctx_min);
cparams->n_ctx = hp_nct - ctx_reduction;
const int64_t memory_reduction = ctx_reduction * bytes_per_ctx;
cparams->n_ctx = std::max(cparams->n_ctx - cparams->n_ctx % 256, n_ctx_min); // round down context for CUDA backend
ctx_reduction = hp_nct - cparams->n_ctx;
memory_reduction = ctx_reduction * bytes_per_ctx;
global_surplus += memory_reduction;
LLAMA_LOG_INFO("%s: context size reduced from %" PRIu32 " to %" PRIu32 " -> need %" PRId64 " MiB less memory in total\n",
__func__, hp_nct, cparams->n_ctx, memory_reduction/MiB);
if (global_surplus >= 0) {
if (nd == 1) {
LLAMA_LOG_INFO("%s: entire model can be fit by reducing context\n", __func__);
return;
}
LLAMA_LOG_INFO("%s: entire model should be fit across devices by reducing context\n", __func__);
}
} else {
LLAMA_LOG_INFO("%s: default model context size is %" PRIu32 " which is <= the min. context size of %" PRIu32 " -> no change\n",
__func__, hp_nct, n_ctx_min);
@ -249,10 +273,6 @@ static void llama_params_fit_impl(
LLAMA_LOG_INFO("%s: context size set by user to %" PRIu32 " -> no change\n", __func__, cparams->n_ctx);
}
}
if (global_surplus >= 0) {
LLAMA_LOG_INFO("%s: entire model can be fit across devices by reducing context\n", __func__);
return;
}
}
if (mparams->n_gpu_layers != default_mparams.n_gpu_layers) {
@ -478,8 +498,13 @@ static void llama_params_fit_impl(
} else {
LLAMA_LOG_INFO("%s: filling dense-only layers back-to-front:\n", __func__);
}
uint32_t n_unassigned = hp_ngl;
for (int id = nd - 1; id >= 0; id--) {
uint32_t n_unassigned = hp_ngl;
for (size_t jd = id + 1; jd < nd; ++jd) {
assert(n_unassigned >= ngl_per_device[jd].n_layer);
n_unassigned -= ngl_per_device[jd].n_layer;
}
std::vector<ngl_t> ngl_per_device_high = ngl_per_device;
ngl_per_device_high[id].n_layer = n_unassigned;
if (hp_nex > 0) {
@ -488,7 +513,9 @@ static void llama_params_fit_impl(
if (ngl_per_device_high[id].n_layer > 0) {
std::vector<int64_t> mem_high = get_memory_for_layers(__func__, ngl_per_device_high, overflow_bufts, partial_moe);
if (mem_high[id] > targets[id]) {
assert(ngl_per_device_high[id].n_layer > ngl_per_device[id].n_layer);
uint32_t delta = ngl_per_device_high[id].n_layer - ngl_per_device[id].n_layer;
LLAMA_LOG_DEBUG("%s: start filling device %" PRIu32 ", delta=%" PRIu32 "\n", __func__, id, delta);
while (delta > 1) {
uint32_t step_size = int64_t(delta) * (targets[id] - mem[id]) / (mem_high[id] - mem[id]);
step_size = std::max(step_size, uint32_t(1));
@ -502,20 +529,19 @@ static void llama_params_fit_impl(
const std::vector<int64_t> mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts, partial_moe);
if (mem_test[id] <= targets[id]) {
ngl_per_device = ngl_per_device_test;
mem = mem_test;
n_unassigned -= ngl_per_device[id].n_layer;
ngl_per_device = ngl_per_device_test;
mem = mem_test;
LLAMA_LOG_DEBUG("%s: set ngl_per_device[%d].n_layer=%" PRIu32 "\n", __func__, id, ngl_per_device[id].n_layer);
} else {
ngl_per_device_high = ngl_per_device_test;
mem_high = mem_test;
LLAMA_LOG_DEBUG("%s: set ngl_per_device_high[%d].n_layer=%" PRIu32 "\n", __func__, id, ngl_per_device[id].n_layer);
LLAMA_LOG_DEBUG("%s: set ngl_per_device_high[%d].n_layer=%" PRIu32 "\n", __func__, id, ngl_per_device_high[id].n_layer);
}
delta = ngl_per_device_high[id].n_layer - ngl_per_device[id].n_layer;
}
} else {
ngl_per_device = ngl_per_device_high;
n_unassigned -= ngl_per_device[id].n_layer;
assert(ngl_per_device_high[id].n_layer == n_unassigned);
ngl_per_device = ngl_per_device_high;
LLAMA_LOG_DEBUG("%s: set ngl_per_device[%d].n_layer=%" PRIu32 "\n", __func__, id, ngl_per_device[id].n_layer);
}
}

View File

@ -4,7 +4,11 @@
#include "common.h"
#include "log.h"
#include <iostream>
#include <chrono>
#include <cinttypes>
#include <thread>
using namespace std::chrono_literals;
#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
@ -22,13 +26,17 @@ int main(int argc, char ** argv) {
llama_numa_init(params.numa);
auto mparams = common_model_params_to_llama(params);
auto cparams = common_context_params_to_llama(params);
llama_params_fit(params.model.path.c_str(), &mparams, &cparams,
const bool success = llama_params_fit(params.model.path.c_str(), &mparams, &cparams,
params.tensor_split, params.tensor_buft_overrides.data(), params.fit_params_target, params.fit_params_min_ctx,
params.verbosity >= 4 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_ERROR);
if (!success) {
LOG_ERR("%s: failed to fit CLI arguments to free memory, exiting...\n", __func__);
exit(1);
}
LOG_INF("Printing fitted CLI arguments to stdout...\n");
std::cout << "-c " << cparams.n_ctx;
std::cout << " -ngl " << mparams.n_gpu_layers;
LOG_INF("%s: printing fitted CLI arguments to stdout...\n", __func__);
std::this_thread::sleep_for(10ms); // to avoid a race between stderr and stdout
printf("-c %" PRIu32 " -ngl %" PRIu32, cparams.n_ctx, mparams.n_gpu_layers);
size_t nd = llama_max_devices();
while (nd > 1 && mparams.tensor_split[nd - 1] == 0.0f) {
@ -37,26 +45,22 @@ int main(int argc, char ** argv) {
if (nd > 1) {
for (size_t id = 0; id < nd; id++) {
if (id == 0) {
std::cout << " -ts ";
printf(" -ts ");
}
if (id > 0) {
std::cout << ",";
}
std::cout << mparams.tensor_split[id];
printf("%s%" PRIu32, id > 0 ? "," : "", uint32_t(mparams.tensor_split[id]));
}
}
const size_t ntbo = llama_max_tensor_buft_overrides();
bool any_tbo = false;
for (size_t itbo = 0; itbo < ntbo && mparams.tensor_buft_overrides[itbo].pattern != nullptr; itbo++) {
if (itbo == 0) {
std::cout << " -ot ";
printf(" -ot \"");
}
if (itbo > 0) {
std::cout << ",";
}
std::cout << mparams.tensor_buft_overrides[itbo].pattern << "=" << ggml_backend_buft_name(mparams.tensor_buft_overrides[itbo].buft);
printf("%s%s=%s", itbo > 0 ? "," : "", mparams.tensor_buft_overrides[itbo].pattern, ggml_backend_buft_name(mparams.tensor_buft_overrides[itbo].buft));
any_tbo = true;
}
std::cout << "\n";
printf("%s\n", any_tbo ? "\"" : "");
return 0;
}

View File

@ -73,8 +73,18 @@ int main(int argc, char ** argv, char ** envp) {
return 1;
}
// validate batch size for embeddings
// embeddings require all tokens to be processed in a single ubatch
// see https://github.com/ggml-org/llama.cpp/issues/12836
if (params.embedding && params.n_batch > params.n_ubatch) {
LOG_WRN("%s: embeddings enabled with n_batch (%d) > n_ubatch (%d)\n", __func__, params.n_batch, params.n_ubatch);
LOG_WRN("%s: setting n_batch = n_ubatch = %d to avoid assertion failure\n", __func__, params.n_ubatch);
params.n_batch = params.n_ubatch;
}
if (params.n_parallel < 0) {
LOG_INF("%s: n_parallel is set to auto, using n_parallel = 4 and kv_unified = true\n", __func__);
params.n_parallel = 4;
params.kv_unified = true;
}