diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 14dccac5b5..1f7a52d789 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -512,7 +512,12 @@ void llama_context::sched_reserve() {
 
         if (cparams.fused_gdn_ch) {
             // more than one token in the batch per sequence in order to take the chunked path
-            auto * gf = graph_reserve(16*n_seqs, n_seqs, n_outputs, mctx.get(), true);
+            // note: reserve with n_outputs == n_tokens. for embedding models with mean/rank pooling,
+            // build_pooling creates inp_mean with shape [n_tokens, n_seqs] and multiplies it with t_embd,
+            // which is reduced to [n_outputs, ...] via out_ids. if n_outputs != n_tokens, the ggml_mul_mat
+            // assertion fails. passing n_tokens_ch for both arguments matches the pp reservation further below.
+            const uint32_t n_tokens_ch = 16*n_seqs;
+            auto * gf = graph_reserve(n_tokens_ch, n_seqs, n_tokens_ch, mctx.get(), true);
             if (!gf) {
                 throw std::runtime_error("failed to reserve graph for fused Gated Delta Net check (chunked)");
             }
diff --git a/tools/server/tests/unit/test_embedding.py b/tools/server/tests/unit/test_embedding.py
index 50601b8396..17ba09554b 100644
--- a/tools/server/tests/unit/test_embedding.py
+++ b/tools/server/tests/unit/test_embedding.py
@@ -101,6 +101,40 @@ def test_embedding_mixed_input(input, is_multi_prompt: bool):
         assert len(data[0]['embedding']) > 1
 
 
+def test_embedding_pooling_mean():
+    global server
+    server.pooling = 'mean'
+    server.start()
+    res = server.make_request("POST", "/v1/embeddings", data={
+        "input": "I believe the meaning of life is",
+    })
+    assert res.status_code == 200
+    assert len(res.body['data']) == 1
+    assert 'embedding' in res.body['data'][0]
+    assert len(res.body['data'][0]['embedding']) > 1
+
+    # the returned embedding must be normalized: its sum of squares is within EPSILON of 1
+    assert abs(sum([x ** 2 for x in res.body['data'][0]['embedding']]) - 1) < EPSILON
+
+
+def test_embedding_pooling_mean_multiple():
+    global server
+    server.pooling = 'mean'
+    server.start()
+    res = server.make_request("POST", "/v1/embeddings", data={
+        "input": [
+            "I believe the meaning of life is",
+            "Write a joke about AI",
+            "This is a test",
+        ],
+    })
+    assert res.status_code == 200
+    assert len(res.body['data']) == 3
+    for d in res.body['data']:
+        assert 'embedding' in d
+        assert len(d['embedding']) > 1
+
+
 def test_embedding_pooling_none():
     global server
     server.pooling = 'none'