This commit is contained in:
Lorenzo 2026-04-01 23:50:56 +03:00 committed by GitHub
commit 672e49615b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 137 additions and 2 deletions

View File

@@ -1869,7 +1869,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
).set_sparam().set_env("LLAMA_ARG_BACKEND_SAMPLING"));
add_opt(common_arg(
{"--pooling"}, "{none,mean,cls,last,rank}",
{"--pooling"}, "{none,mean,cls,last,rank,max}",
"pooling type for embeddings, use model default if unspecified",
[](common_params & params, const std::string & value) {
/**/ if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; }
@@ -1877,6 +1877,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
else if (value == "cls") { params.pooling_type = LLAMA_POOLING_TYPE_CLS; }
else if (value == "last") { params.pooling_type = LLAMA_POOLING_TYPE_LAST; }
else if (value == "rank") { params.pooling_type = LLAMA_POOLING_TYPE_RANK; }
else if (value == "max") { params.pooling_type = LLAMA_POOLING_TYPE_MAX; }
else { throw std::invalid_argument("invalid value"); }
}
).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_DEBUG}).set_env("LLAMA_ARG_POOLING"));

View File

@@ -174,6 +174,7 @@ extern "C" {
LLAMA_POOLING_TYPE_CLS = 2,
LLAMA_POOLING_TYPE_LAST = 3,
LLAMA_POOLING_TYPE_RANK = 4, // used by reranking models to attach the classification head to the graph
LLAMA_POOLING_TYPE_MAX = 5,
};
enum llama_attention_type {

View File

@@ -1338,6 +1338,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
case LLAMA_POOLING_TYPE_MEAN:
case LLAMA_POOLING_TYPE_CLS:
case LLAMA_POOLING_TYPE_LAST:
case LLAMA_POOLING_TYPE_MAX:
{
// extract sequence embeddings
auto & embd_seq_out = embd_seq;
@@ -1767,6 +1768,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
case LLAMA_POOLING_TYPE_MEAN:
case LLAMA_POOLING_TYPE_CLS:
case LLAMA_POOLING_TYPE_LAST:
case LLAMA_POOLING_TYPE_MAX:
{
// extract sequence embeddings (cleared before processing each batch)
auto & embd_seq_out = embd_seq;

View File

@@ -249,6 +249,34 @@ void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
}
}
// Build the additive max-pooling mask: 0.0f where token t belongs to the
// sequence, -INFINITY everywhere else, so masked tokens never win the max.
void llm_graph_input_max::set_input(const llama_ubatch * ubatch) {
    if (!cparams.embeddings || cparams.pooling_type != LLAMA_POOLING_TYPE_MAX) {
        return;
    }

    const int64_t n_tokens     = ubatch->n_tokens;
    const int64_t n_seq_tokens = ubatch->n_seq_tokens;
    const int64_t n_seqs_unq   = ubatch->n_seqs_unq;

    GGML_ASSERT(mask);
    GGML_ASSERT(ggml_backend_buffer_is_host(mask->buffer));

    float * data = (float *) mask->data;

    // start fully masked
    for (int64_t k = 0; k < n_tokens*n_seqs_unq; ++k) {
        data[k] = -INFINITY;
    }

    // unmask the tokens of each sequence in the row of every sequence they belong to
    for (int t0 = 0; t0 < n_tokens; t0 += n_seq_tokens) {
        for (int s = 0; s < ubatch->n_seq_id[t0]; ++s) {
            const llama_seq_id seq_id  = ubatch->seq_id[t0][s];
            const int32_t      seq_idx = ubatch->seq_idx[seq_id];

            for (int t = t0; t < t0 + n_seq_tokens; ++t) {
                data[seq_idx*n_tokens + t] = 0.0f;
            }
        }
    }
}
void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
const int64_t n_tokens = ubatch->n_tokens;
const int64_t n_seqs_unq = ubatch->n_seqs_unq;
@@ -1748,6 +1776,19 @@ ggml_tensor * llm_graph_context::build_inp_mean() const {
return cur;
}
// Create the F32 [n_tokens, n_seqs_unq] mask input used by max pooling and
// register it with the graph result; returns the mask tensor.
ggml_tensor * llm_graph_context::build_inp_max() const {
    auto inp = std::make_unique<llm_graph_input_max>(cparams);

    ggml_tensor * mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, ubatch.n_seqs_unq);
    ggml_set_input(mask);

    inp->mask = mask;

    res->add_input(std::move(inp));

    return mask;
}
ggml_tensor * llm_graph_context::build_inp_cls() const {
auto inp = std::make_unique<llm_graph_input_cls>(cparams, arch);
@@ -2676,6 +2717,29 @@ void llm_graph_context::build_pooling(
cur = ggml_soft_max(ctx0, cur);
}
} break;
case LLAMA_POOLING_TYPE_MAX:
{
// [n_embd, n_tokens] -> [n_embd, n_tokens, n_seqs_unq]
ggml_tensor * inp_3d = ggml_reshape_3d(ctx0, inp, n_embd, n_tokens, 1);
ggml_tensor * inp_expanded = ggml_repeat(ctx0, inp_3d,
ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd, n_tokens, ubatch.n_seqs_unq));
ggml_tensor * inp_max_mask = build_inp_max();
// [n_tokens, n_seqs_unq] -> reshape to [1, n_tokens, n_seqs_unq]
ggml_tensor * mask_3d = ggml_reshape_3d(ctx0, inp_max_mask, 1, n_tokens, ubatch.n_seqs_unq);
// broadcast
ggml_tensor * inp_masked = ggml_add(ctx0, inp_expanded, mask_3d);
// Permute to [n_tokens, n_embd, n_seqs_unq, 1] for pooling along dim 0
ggml_tensor * inp_perm = ggml_cont(ctx0, ggml_permute(ctx0, inp_masked, 1, 0, 2, 3));
// Global max pool over the full token dimension -> [1, n_embd, n_seqs_unq, 1]
cur = ggml_pool_2d(ctx0, inp_perm, GGML_OP_POOL_MAX, n_tokens, 1, n_tokens, 1, 0, 0);
// Reshape to [n_embd, n_seqs_unq]
cur = ggml_reshape_2d(ctx0, cur, n_embd, ubatch.n_seqs_unq);
} break;
default:
{
GGML_ABORT("unknown pooling type");

View File

@@ -222,6 +222,18 @@ public:
const llm_arch arch;
};
// Graph input for LLAMA_POOLING_TYPE_MAX: an additive mask that is 0.0f where
// a token belongs to a sequence and -INFINITY elsewhere (filled by set_input).
class llm_graph_input_max : public llm_graph_input_i {
public:
    llm_graph_input_max(const llama_cparams & cparams) : cparams(cparams) {}
    virtual ~llm_graph_input_max() = default;

    // Populates `mask` from the current ubatch.
    void set_input(const llama_ubatch * ubatch) override;

    // Initialized to nullptr so a set_input() call before build_inp_max()
    // assigns it trips the GGML_ASSERT(mask) instead of reading garbage.
    ggml_tensor * mask = nullptr; // F32 [n_tokens, n_seqs_unq]

    const llama_cparams cparams;
};
class llm_graph_input_rs : public llm_graph_input_i {
public:
llm_graph_input_rs(const llama_memory_recurrent_context * mctx) : mctx(mctx) {}
@@ -863,6 +875,7 @@ struct llm_graph_context {
ggml_tensor * build_inp_out_ids() const;
ggml_tensor * build_inp_mean() const;
ggml_tensor * build_inp_cls() const;
ggml_tensor * build_inp_max() const;
ggml_tensor * build_inp_cross_embd() const;
ggml_tensor * build_inp_pos_bucket_enc() const;

View File

@@ -172,7 +172,7 @@ For the full list of features, please refer to [server's changelog](https://gith
| `-sp, --special` | special tokens output enabled (default: false) |
| `--warmup, --no-warmup` | whether to perform warmup with an empty run (default: enabled) |
| `--spm-infill` | use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this. (default: disabled) |
| `--pooling {none,mean,cls,last,rank}` | pooling type for embeddings, use model default if unspecified<br/>(env: LLAMA_ARG_POOLING) |
| `--pooling {none,mean,cls,last,rank,max}` | pooling type for embeddings, use model default if unspecified<br/>(env: LLAMA_ARG_POOLING) |
| `-np, --parallel N` | number of server slots (default: -1, -1 = auto)<br/>(env: LLAMA_ARG_N_PARALLEL) |
| `-cb, --cont-batching, -nocb, --no-cont-batching` | whether to enable continuous batching (a.k.a dynamic batching) (default: enabled)<br/>(env: LLAMA_ARG_CONT_BATCHING) |
| `-mm, --mmproj FILE` | path to a multimodal projector file. see tools/mtmd/README.md<br/>note: if -hf is used, this argument can be omitted<br/>(env: LLAMA_ARG_MMPROJ) |

View File

@@ -289,3 +289,57 @@ def test_embedding_openai_library_base64():
# make sure the decoded data is the same as the original
for x, y in zip(floats, vec0):
assert abs(x - y) < EPSILON
def test_embedding_pooling_max():
    """A single prompt with max pooling yields one normalized embedding."""
    global server
    server.pooling = 'max'
    server.start()
    res = server.make_request("POST", "/v1/embeddings", data={
        "input": "I believe the meaning of life is",
    })
    assert res.status_code == 200
    data = res.body['data']
    assert len(data) == 1
    assert 'embedding' in data[0]
    vec = data[0]['embedding']
    assert len(vec) > 1
    # make sure embedding vector is normalized (squared L2 norm ~ 1)
    sq_norm = sum(v * v for v in vec)
    assert abs(sq_norm - 1) < EPSILON
def test_embedding_pooling_max_multiple():
    """A batch of distinct prompts yields one embedding per prompt."""
    global server
    server.pooling = 'max'
    server.start()
    prompts = [
        "I believe the meaning of life is",
        "Write a joke about AI",
        "This is a test",
        "This is another test",
    ]
    res = server.make_request("POST", "/v1/embeddings", data={"input": prompts})
    assert res.status_code == 200
    assert len(res.body['data']) == len(prompts)
    for entry in res.body['data']:
        assert 'embedding' in entry
        assert len(entry['embedding']) > 1
def test_embedding_pooling_max_same_prompt():
    """Identical prompts must produce (near-)identical embeddings."""
    global server
    server.pooling = 'max'
    server.start()
    prompt = "I believe the meaning of life is"
    res = server.make_request("POST", "/v1/embeddings", data={
        "input": [prompt, prompt],
    })
    assert res.status_code == 200
    assert len(res.body['data']) == 2
    # same input should give same output
    first = res.body['data'][0]['embedding']
    second = res.body['data'][1]['embedding']
    for a, b in zip(first, second):
        assert abs(a - b) < EPSILON