llama : fix pooling assertion crash in chunked GDN detection path (#20468)

* llama : fix pooling assertion crash in chunked GDN detection path

The chunked fused Gated Delta Net detection in sched_reserve() calls
graph_reserve(16*n_seqs, n_seqs, n_outputs, ...) where n_outputs = n_seqs.
This creates a dimension mismatch in build_pooling() for embedding models
with mean/rank pooling: build_inp_mean() creates a tensor with shape
[n_tokens=16*n_seqs, ...] while t_embd is reduced to [n_outputs=n_seqs, ...]
via out_ids, causing ggml_mul_mat to assert on ggml_can_mul_mat(a, b).

Fix: pass n_tokens as n_outputs in the chunked GDN graph reservation,
matching the pattern used by the pp/tg worst-case reservations.

Regression introduced by #20340 (d28961d).
Same class of bug as #12517, fixed by #12545.

* server : add mean pooling tests to embedding test suite

Add test_embedding_pooling_mean and test_embedding_pooling_mean_multiple
to cover the --pooling mean codepath, which was previously untested.

These tests would have caught the regression introduced by #20340 where
build_pooling() crashes with a ggml_mul_mat assertion due to mismatched
dimensions in the chunked GDN detection path.

---------

Co-authored-by: Domenico Crupi <domenico@zerovolt.it>
This commit is contained in:
ZeroV0LT 2026-03-13 19:53:42 +01:00 committed by GitHub
parent d7ba99c485
commit f17b3be63f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 40 additions and 1 deletions

View File

@ -512,7 +512,12 @@ void llama_context::sched_reserve() {
if (cparams.fused_gdn_ch) {
// more than one token in the batch per sequence in order to take the chunked path
auto * gf = graph_reserve(16*n_seqs, n_seqs, n_outputs, mctx.get(), true);
// note: n_outputs must match n_tokens for embedding models with mean/rank pooling,
// because build_pooling creates inp_mean with shape [n_tokens, n_seqs] and multiplies
// it with t_embd which is reduced to [n_outputs, ...] via out_ids. if n_outputs != n_tokens,
// the ggml_mul_mat assertion fails. this matches the pp reservation below (line ~553).
const uint32_t n_tokens_ch = 16*n_seqs;
auto * gf = graph_reserve(n_tokens_ch, n_seqs, n_tokens_ch, mctx.get(), true);
if (!gf) {
throw std::runtime_error("failed to reserve graph for fused Gated Delta Net check (chunked)");
}

View File

@ -101,6 +101,40 @@ def test_embedding_mixed_input(input, is_multi_prompt: bool):
assert len(data[0]['embedding']) > 1
def test_embedding_pooling_mean():
global server
server.pooling = 'mean'
server.start()
res = server.make_request("POST", "/v1/embeddings", data={
"input": "I believe the meaning of life is",
})
assert res.status_code == 200
assert len(res.body['data']) == 1
assert 'embedding' in res.body['data'][0]
assert len(res.body['data'][0]['embedding']) > 1
# make sure embedding vector is normalized
assert abs(sum([x ** 2 for x in res.body['data'][0]['embedding']]) - 1) < EPSILON
def test_embedding_pooling_mean_multiple():
global server
server.pooling = 'mean'
server.start()
res = server.make_request("POST", "/v1/embeddings", data={
"input": [
"I believe the meaning of life is",
"Write a joke about AI",
"This is a test",
],
})
assert res.status_code == 200
assert len(res.body['data']) == 3
for d in res.body['data']:
assert 'embedding' in d
assert len(d['embedding']) > 1
def test_embedding_pooling_none():
global server
server.pooling = 'none'