lookup, lookahead: fix crash when n_ctx not specified (#18729)

* lookup, lookahead: fix crash when n_ctx not specified

Since PR #16653 (Dec 15, 2025), the default n_ctx is 0 to enable automatic
GPU memory fitting. This causes llama-lookup and llama-lookahead to crash
when run without explicit -c flag:

    GGML_ASSERT(batch.seq_id[batch.n_tokens] && "llama_batch size exceeded")

Root cause: Both examples use params.n_ctx directly for batch initialization,
but params.n_ctx remains 0 even after the context is properly initialized
to n_ctx_train internally.

Bug history:
- Nov 2023: lookahead.cpp created (PR #4207) with params.n_ctx pattern
- Dec 2023: lookup.cpp created (PR #4484) with same pattern
- Nov 2024: default n_ctx changed to 4096 (PR #10136) - bug dormant
- Dec 2025: default n_ctx changed to 0 (PR #16653) - bug activated

The bug was dormant for 2+ years because params.n_ctx defaulted to 512,
then 4096. PR #16653 changed it to 0 for GPU auto-fitting, triggering
the crash.

Fix: Use llama_n_ctx(ctx) to get the actual runtime context size, matching
the pattern already used elsewhere in lookup.cpp (line 72) and in
speculative.cpp/speculative-simple.cpp.

Tested: llama-lookup now works without -c flag (12.5% acceptance on
Gemma-3-1B).

Note: llama-lookahead has a separate pre-existing issue with sequence
initialization (n_seq_max=1 vs W+G+1 needed) that is unrelated to this fix.

* lookahead: fix n_seq_max and kv_unified configuration

Lookahead decoding requires:
- W + G + 1 = 31 sequences for parallel Jacobi decoding
- Unified KV cache for coupled sequences in batch splitting

These requirements were broken after PR #14482 changed validation logic.

Consolidates fix from PR #18730 per maintainer request.

Commit message drafted with Claude.
This commit is contained in:
Daniele Pinna 2026-01-30 21:10:24 +01:00 committed by GitHub
parent 4927795810
commit 1488339138
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 8 additions and 2 deletions

View File

@ -50,6 +50,12 @@ int main(int argc, char ** argv) {
const int N = 5; // n-gram size
const int G = 15; // max verification n-grams
// lookahead requires W + G + 1 sequences for parallel Jacobi decoding
params.n_parallel = W + G + 1;
// unified KV cache is required for coupled sequences in batch splitting
params.kv_unified = true;
// init llama.cpp
llama_backend_init();
llama_numa_init(params.numa);
@ -115,7 +121,7 @@ int main(int argc, char ** argv) {
// seq_id == 0 : the current input token
// seq_id [1, W] : tokens from the past N - 1 Jacobi iterations
// seq_id [W + 1, W + G] : verification n-grams
llama_batch batch = llama_batch_init(params.n_ctx, 0, W + G + 1);
llama_batch batch = llama_batch_init(llama_n_ctx(ctx), 0, W + G + 1);
// target model sampling context
struct common_sampler * smpl = common_sampler_init(model, params.sampling);

View File

@ -106,7 +106,7 @@ int main(int argc, char ** argv){
std::vector<llama_token> draft;
llama_batch batch_tgt = llama_batch_init(params.n_ctx, 0, 1);
llama_batch batch_tgt = llama_batch_init(llama_n_ctx(ctx), 0, 1);
const auto t_dec_start = ggml_time_us();