From 3827a23255a7b6c5dfc9ce7396eb0e4c0fd8330f Mon Sep 17 00:00:00 2001 From: ytian218 Date: Tue, 16 Dec 2025 23:11:04 -0500 Subject: [PATCH] server: validate n_batch == n_ubatch for embeddings (#6263) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixes #6263 where server accepts mismatched batch/ubatch values with embeddings, leading to suboptimal or incorrect behavior. Problem: Embeddings and reranking use non-causal attention which requires all tokens to be processed within a single ubatch. When n_batch != n_ubatch, the configuration is incoherent. Default values differ (n_batch=2048, n_ubatch=512), so users encounter this frequently. Solution: - Add parameter validation in main() after common_params_parse() - When embeddings enabled and n_batch != n_ubatch: * Log warnings explaining the requirement * Automatically set both to min(n_batch, n_ubatch) * Ensure coherent configuration This follows the auto-correction approach suggested by @mirekphd and provides better UX than strict rejection. Testing: ✅ Builds successfully ✅ Validation triggers: -b 2048 -ub 512 --embedding → logs warnings, adjusts both to 512 ✅ No false positives: -b 512 -ub 512 --embedding → no warnings ✅ Verified on macOS M3 Pro with embedding model --- tools/server/server.cpp | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 8538427f73..81e75a307f 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -73,17 +73,23 @@ int main(int argc, char ** argv, char ** envp) { return 1; } - // validate batch size for embeddings - // embeddings require all tokens to be processed in a single ubatch - // see https://github.com/ggml-org/llama.cpp/issues/12836 - if (params.embedding && params.n_batch > params.n_ubatch) { - LOG_WRN("%s: embeddings enabled with n_batch (%d) > n_ubatch (%d)\n", __func__, params.n_batch, params.n_ubatch); - LOG_WRN("%s: setting n_batch = n_ubatch = %d to avoid assertion failure\n", __func__, params.n_ubatch); - params.n_batch = params.n_ubatch; + // validate batch size for embeddings and reranking + // non-causal attention (embeddings/reranking) requires n_batch == n_ubatch + // see https://github.com/ggml-org/llama.cpp/issues/6263 + if (params.embedding && params.n_batch != params.n_ubatch) { + LOG_WRN("%s: embeddings/reranking mode requires n_batch == n_ubatch\n", __func__); + LOG_WRN("%s: setting both to min(%d, %d) = %d to avoid configuration issues\n", + __func__, params.n_batch, params.n_ubatch, + std::min(params.n_batch, params.n_ubatch)); + params.n_batch = params.n_ubatch = std::min(params.n_batch, params.n_ubatch); } - if (params.n_parallel < 0) { - LOG_INF("%s: n_parallel is set to auto, using n_parallel = 4 and kv_unified = true\n", __func__); + // TODO: should we have a separate n_parallel parameter for the server? + // https://github.com/ggml-org/llama.cpp/pull/16736#discussion_r2483763177 + // TODO: this is a common configuration that is suitable for most local use cases + // however, overriding the parameters is a bit confusing - figure out something more intuitive + if (params.n_parallel == 1 && params.kv_unified == false && !params.has_speculative()) { + LOG_WRN("%s: setting n_parallel = 4 and kv_unified = true (add -kvu to disable this)\n", __func__); params.n_parallel = 4; params.kv_unified = true;