diff --git a/common/arg.cpp b/common/arg.cpp index 538d2a4b0a..c152cdd876 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1869,7 +1869,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex } ).set_sparam().set_env("LLAMA_ARG_BACKEND_SAMPLING")); add_opt(common_arg( - {"--pooling"}, "{none,mean,cls,last,rank}", + {"--pooling"}, "{none,mean,cls,last,rank,max}", "pooling type for embeddings, use model default if unspecified", [](common_params & params, const std::string & value) { /**/ if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; } @@ -1877,6 +1877,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex else if (value == "cls") { params.pooling_type = LLAMA_POOLING_TYPE_CLS; } else if (value == "last") { params.pooling_type = LLAMA_POOLING_TYPE_LAST; } else if (value == "rank") { params.pooling_type = LLAMA_POOLING_TYPE_RANK; } + else if (value == "max") { params.pooling_type = LLAMA_POOLING_TYPE_MAX; } else { throw std::invalid_argument("invalid value"); } } ).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_DEBUG}).set_env("LLAMA_ARG_POOLING")); diff --git a/include/llama.h b/include/llama.h index a940f9d648..7c3ac23362 100644 --- a/include/llama.h +++ b/include/llama.h @@ -174,6 +174,7 @@ extern "C" { LLAMA_POOLING_TYPE_CLS = 2, LLAMA_POOLING_TYPE_LAST = 3, LLAMA_POOLING_TYPE_RANK = 4, // used by reranking models to attach the classification head to the graph + LLAMA_POOLING_TYPE_MAX = 5, }; enum llama_attention_type { diff --git a/src/llama-context.cpp b/src/llama-context.cpp index a808e3e454..038f2f9e33 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -1338,6 +1338,7 @@ int llama_context::encode(const llama_batch & batch_inp) { case LLAMA_POOLING_TYPE_MEAN: case LLAMA_POOLING_TYPE_CLS: case LLAMA_POOLING_TYPE_LAST: + case LLAMA_POOLING_TYPE_MAX: { // extract sequence embeddings auto & embd_seq_out = 
embd_seq; @@ -1767,6 +1768,7 @@ int llama_context::decode(const llama_batch & batch_inp) { case LLAMA_POOLING_TYPE_MEAN: case LLAMA_POOLING_TYPE_CLS: case LLAMA_POOLING_TYPE_LAST: + case LLAMA_POOLING_TYPE_MAX: { // extract sequence embeddings (cleared before processing each batch) auto & embd_seq_out = embd_seq; diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 0e7d96ca10..9794ea8863 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -249,6 +249,34 @@ void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) { } } +void llm_graph_input_max::set_input(const llama_ubatch * ubatch) { + if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MAX) { + const int64_t n_tokens = ubatch->n_tokens; + const int64_t n_seq_tokens = ubatch->n_seq_tokens; + const int64_t n_seqs_unq = ubatch->n_seqs_unq; + + GGML_ASSERT(mask); + GGML_ASSERT(ggml_backend_buffer_is_host(mask->buffer)); + + float * data = (float *) mask->data; + + for (int64_t i = 0; i < n_tokens * n_seqs_unq; i++) { + data[i] = -INFINITY; + } + + for (int i = 0; i < n_tokens; i += n_seq_tokens) { + for (int s = 0; s < ubatch->n_seq_id[i]; ++s) { + const llama_seq_id seq_id = ubatch->seq_id[i][s]; + const int32_t seq_idx = ubatch->seq_idx[seq_id]; + + for (int j = 0; j < n_seq_tokens; ++j) { + data[seq_idx*n_tokens + i + j] = 0.0f; + } + } + } + } +} + void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) { const int64_t n_tokens = ubatch->n_tokens; const int64_t n_seqs_unq = ubatch->n_seqs_unq; @@ -1748,6 +1776,19 @@ ggml_tensor * llm_graph_context::build_inp_mean() const { return cur; } +ggml_tensor * llm_graph_context::build_inp_max() const { + auto inp = std::make_unique<llm_graph_input_max>(cparams); + + auto & cur = inp->mask; + + cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, ubatch.n_seqs_unq); + ggml_set_input(cur); + + res->add_input(std::move(inp)); + + return cur; +} + ggml_tensor * llm_graph_context::build_inp_cls() const { auto inp = std::make_unique<llm_graph_input_cls>(cparams, 
arch); @@ -2676,6 +2717,29 @@ void llm_graph_context::build_pooling( cur = ggml_soft_max(ctx0, cur); } } break; + case LLAMA_POOLING_TYPE_MAX: + { + // [n_embd, n_tokens] -> [n_embd, n_tokens, n_seqs_unq] + ggml_tensor * inp_3d = ggml_reshape_3d(ctx0, inp, n_embd, n_tokens, 1); + ggml_tensor * inp_expanded = ggml_repeat(ctx0, inp_3d, + ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd, n_tokens, ubatch.n_seqs_unq)); + + ggml_tensor * inp_max_mask = build_inp_max(); + // [n_tokens, n_seqs_unq] -> reshape to [1, n_tokens, n_seqs_unq] + ggml_tensor * mask_3d = ggml_reshape_3d(ctx0, inp_max_mask, 1, n_tokens, ubatch.n_seqs_unq); + + // broadcast + ggml_tensor * inp_masked = ggml_add(ctx0, inp_expanded, mask_3d); + + // Permute to [n_tokens, n_embd, n_seqs_unq, 1] for pooling along dim 0 + ggml_tensor * inp_perm = ggml_cont(ctx0, ggml_permute(ctx0, inp_masked, 1, 0, 2, 3)); + + // Global max pool over the full token dimension -> [1, n_embd, n_seqs_unq, 1] + cur = ggml_pool_2d(ctx0, inp_perm, GGML_OP_POOL_MAX, n_tokens, 1, n_tokens, 1, 0, 0); + + // Reshape to [n_embd, n_seqs_unq] + cur = ggml_reshape_2d(ctx0, cur, n_embd, ubatch.n_seqs_unq); + } break; default: { GGML_ABORT("unknown pooling type"); diff --git a/src/llama-graph.h b/src/llama-graph.h index bb0ad75198..cb63691adb 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -222,6 +222,18 @@ public: const llm_arch arch; }; +class llm_graph_input_max : public llm_graph_input_i { +public: + llm_graph_input_max(const llama_cparams & cparams) : cparams(cparams) {} + virtual ~llm_graph_input_max() = default; + + void set_input(const llama_ubatch * ubatch) override; + + ggml_tensor * mask; // F32 [n_tokens, n_seqs_unq] + + const llama_cparams cparams; +}; + class llm_graph_input_rs : public llm_graph_input_i { public: llm_graph_input_rs(const llama_memory_recurrent_context * mctx) : mctx(mctx) {} @@ -863,6 +875,7 @@ struct llm_graph_context { ggml_tensor * build_inp_out_ids() const; ggml_tensor * build_inp_mean() 
const; ggml_tensor * build_inp_cls() const; + ggml_tensor * build_inp_max() const; ggml_tensor * build_inp_cross_embd() const; ggml_tensor * build_inp_pos_bucket_enc() const; diff --git a/tools/server/README.md b/tools/server/README.md index 1bd8201689..551dd034ea 100644 --- a/tools/server/README.md +++ b/tools/server/README.md @@ -172,7 +172,7 @@ For the full list of features, please refer to [server's changelog](https://gith | `-sp, --special` | special tokens output enabled (default: false) | | `--warmup, --no-warmup` | whether to perform warmup with an empty run (default: enabled) | | `--spm-infill` | use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this. (default: disabled) | -| `--pooling {none,mean,cls,last,rank}` | pooling type for embeddings, use model default if unspecified
(env: LLAMA_ARG_POOLING) | +| `--pooling {none,mean,cls,last,rank,max}` | pooling type for embeddings, use model default if unspecified
(env: LLAMA_ARG_POOLING) | | `-np, --parallel N` | number of server slots (default: -1, -1 = auto)
(env: LLAMA_ARG_N_PARALLEL) | | `-cb, --cont-batching, -nocb, --no-cont-batching` | whether to enable continuous batching (a.k.a dynamic batching) (default: enabled)
(env: LLAMA_ARG_CONT_BATCHING) | | `-mm, --mmproj FILE` | path to a multimodal projector file. see tools/mtmd/README.md
note: if -hf is used, this argument can be omitted
(env: LLAMA_ARG_MMPROJ) | diff --git a/tools/server/tests/unit/test_embedding.py b/tools/server/tests/unit/test_embedding.py index 17ba09554b..02ba575f06 100644 --- a/tools/server/tests/unit/test_embedding.py +++ b/tools/server/tests/unit/test_embedding.py @@ -289,3 +289,57 @@ def test_embedding_openai_library_base64(): # make sure the decoded data is the same as the original for x, y in zip(floats, vec0): assert abs(x - y) < EPSILON + + +def test_embedding_pooling_max(): + global server + server.pooling = 'max' + server.start() + res = server.make_request("POST", "/v1/embeddings", data={ + "input": "I believe the meaning of life is", + }) + assert res.status_code == 200 + assert len(res.body['data']) == 1 + assert 'embedding' in res.body['data'][0] + assert len(res.body['data'][0]['embedding']) > 1 + + # make sure embedding vector is normalized + assert abs(sum([x ** 2 for x in res.body['data'][0]['embedding']]) - 1) < EPSILON + + +def test_embedding_pooling_max_multiple(): + global server + server.pooling = 'max' + server.start() + res = server.make_request("POST", "/v1/embeddings", data={ + "input": [ + "I believe the meaning of life is", + "Write a joke about AI", + "This is a test", + "This is another test", + ], + }) + assert res.status_code == 200 + assert len(res.body['data']) == 4 + for d in res.body['data']: + assert 'embedding' in d + assert len(d['embedding']) > 1 + + +def test_embedding_pooling_max_same_prompt(): + global server + server.pooling = 'max' + server.start() + res = server.make_request("POST", "/v1/embeddings", data={ + "input": [ + "I believe the meaning of life is", + "I believe the meaning of life is", + ], + }) + assert res.status_code == 200 + assert len(res.body['data']) == 2 + # same input should give same output + v0 = res.body['data'][0]['embedding'] + v1 = res.body['data'][1]['embedding'] + for x, y in zip(v0, v1): + assert abs(x - y) < EPSILON