Merge 3d578b42b9 into 95a6ebabb2
This commit is contained in:
commit
672e49615b
|
|
@ -1869,7 +1869,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
}
|
||||
).set_sparam().set_env("LLAMA_ARG_BACKEND_SAMPLING"));
|
||||
add_opt(common_arg(
|
||||
{"--pooling"}, "{none,mean,cls,last,rank}",
|
||||
{"--pooling"}, "{none,mean,cls,last,rank,max}",
|
||||
"pooling type for embeddings, use model default if unspecified",
|
||||
[](common_params & params, const std::string & value) {
|
||||
/**/ if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; }
|
||||
|
|
@ -1877,6 +1877,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
else if (value == "cls") { params.pooling_type = LLAMA_POOLING_TYPE_CLS; }
|
||||
else if (value == "last") { params.pooling_type = LLAMA_POOLING_TYPE_LAST; }
|
||||
else if (value == "rank") { params.pooling_type = LLAMA_POOLING_TYPE_RANK; }
|
||||
else if (value == "max") { params.pooling_type = LLAMA_POOLING_TYPE_MAX; }
|
||||
else { throw std::invalid_argument("invalid value"); }
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_DEBUG}).set_env("LLAMA_ARG_POOLING"));
|
||||
|
|
|
|||
|
|
@ -174,6 +174,7 @@ extern "C" {
|
|||
LLAMA_POOLING_TYPE_CLS = 2,
|
||||
LLAMA_POOLING_TYPE_LAST = 3,
|
||||
LLAMA_POOLING_TYPE_RANK = 4, // used by reranking models to attach the classification head to the graph
|
||||
LLAMA_POOLING_TYPE_MAX = 5,
|
||||
};
|
||||
|
||||
enum llama_attention_type {
|
||||
|
|
|
|||
|
|
@ -1338,6 +1338,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
|
|||
case LLAMA_POOLING_TYPE_MEAN:
|
||||
case LLAMA_POOLING_TYPE_CLS:
|
||||
case LLAMA_POOLING_TYPE_LAST:
|
||||
case LLAMA_POOLING_TYPE_MAX:
|
||||
{
|
||||
// extract sequence embeddings
|
||||
auto & embd_seq_out = embd_seq;
|
||||
|
|
@ -1767,6 +1768,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
|||
case LLAMA_POOLING_TYPE_MEAN:
|
||||
case LLAMA_POOLING_TYPE_CLS:
|
||||
case LLAMA_POOLING_TYPE_LAST:
|
||||
case LLAMA_POOLING_TYPE_MAX:
|
||||
{
|
||||
// extract sequence embeddings (cleared before processing each batch)
|
||||
auto & embd_seq_out = embd_seq;
|
||||
|
|
|
|||
|
|
@ -249,6 +249,34 @@ void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
|
|||
}
|
||||
}
|
||||
|
||||
void llm_graph_input_max::set_input(const llama_ubatch * ubatch) {
|
||||
if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MAX) {
|
||||
const int64_t n_tokens = ubatch->n_tokens;
|
||||
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
|
||||
const int64_t n_seqs_unq = ubatch->n_seqs_unq;
|
||||
|
||||
GGML_ASSERT(mask);
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(mask->buffer));
|
||||
|
||||
float * data = (float *) mask->data;
|
||||
|
||||
for (int64_t i = 0; i < n_tokens * n_seqs_unq; i++) {
|
||||
data[i] = -INFINITY;
|
||||
}
|
||||
|
||||
for (int i = 0; i < n_tokens; i += n_seq_tokens) {
|
||||
for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
|
||||
const llama_seq_id seq_id = ubatch->seq_id[i][s];
|
||||
const int32_t seq_idx = ubatch->seq_idx[seq_id];
|
||||
|
||||
for (int j = 0; j < n_seq_tokens; ++j) {
|
||||
data[seq_idx*n_tokens + i + j] = 0.0f;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
|
||||
const int64_t n_tokens = ubatch->n_tokens;
|
||||
const int64_t n_seqs_unq = ubatch->n_seqs_unq;
|
||||
|
|
@ -1748,6 +1776,19 @@ ggml_tensor * llm_graph_context::build_inp_mean() const {
|
|||
return cur;
|
||||
}
|
||||
|
||||
// Create the F32 [n_tokens, n_seqs_unq] mask tensor consumed by MAX pooling
// and register its input object so set_input() populates it per ubatch.
ggml_tensor * llm_graph_context::build_inp_max() const {
    auto inp = std::make_unique<llm_graph_input_max>(cparams);

    ggml_tensor * mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, ubatch.n_seqs_unq);
    ggml_set_input(mask);

    inp->mask = mask;

    res->add_input(std::move(inp));

    return mask;
}
|
||||
|
||||
ggml_tensor * llm_graph_context::build_inp_cls() const {
|
||||
auto inp = std::make_unique<llm_graph_input_cls>(cparams, arch);
|
||||
|
||||
|
|
@ -2676,6 +2717,29 @@ void llm_graph_context::build_pooling(
|
|||
cur = ggml_soft_max(ctx0, cur);
|
||||
}
|
||||
} break;
|
||||
case LLAMA_POOLING_TYPE_MAX:
|
||||
{
|
||||
// [n_embd, n_tokens] -> [n_embd, n_tokens, n_seqs_unq]
|
||||
ggml_tensor * inp_3d = ggml_reshape_3d(ctx0, inp, n_embd, n_tokens, 1);
|
||||
ggml_tensor * inp_expanded = ggml_repeat(ctx0, inp_3d,
|
||||
ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd, n_tokens, ubatch.n_seqs_unq));
|
||||
|
||||
ggml_tensor * inp_max_mask = build_inp_max();
|
||||
// [n_tokens, n_seqs_unq] -> reshape to [1, n_tokens, n_seqs_unq]
|
||||
ggml_tensor * mask_3d = ggml_reshape_3d(ctx0, inp_max_mask, 1, n_tokens, ubatch.n_seqs_unq);
|
||||
|
||||
// broadcast
|
||||
ggml_tensor * inp_masked = ggml_add(ctx0, inp_expanded, mask_3d);
|
||||
|
||||
// Permute to [n_tokens, n_embd, n_seqs_unq, 1] for pooling along dim 0
|
||||
ggml_tensor * inp_perm = ggml_cont(ctx0, ggml_permute(ctx0, inp_masked, 1, 0, 2, 3));
|
||||
|
||||
// Global max pool over the full token dimension -> [1, n_embd, n_seqs_unq, 1]
|
||||
cur = ggml_pool_2d(ctx0, inp_perm, GGML_OP_POOL_MAX, n_tokens, 1, n_tokens, 1, 0, 0);
|
||||
|
||||
// Reshape to [n_embd, n_seqs_unq]
|
||||
cur = ggml_reshape_2d(ctx0, cur, n_embd, ubatch.n_seqs_unq);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
GGML_ABORT("unknown pooling type");
|
||||
|
|
|
|||
|
|
@ -222,6 +222,18 @@ public:
|
|||
const llm_arch arch;
|
||||
};
|
||||
|
||||
class llm_graph_input_max : public llm_graph_input_i {
|
||||
public:
|
||||
llm_graph_input_max(const llama_cparams & cparams) : cparams(cparams) {}
|
||||
virtual ~llm_graph_input_max() = default;
|
||||
|
||||
void set_input(const llama_ubatch * ubatch) override;
|
||||
|
||||
ggml_tensor * mask; // F32 [n_tokens, n_seqs_unq]
|
||||
|
||||
const llama_cparams cparams;
|
||||
};
|
||||
|
||||
class llm_graph_input_rs : public llm_graph_input_i {
|
||||
public:
|
||||
llm_graph_input_rs(const llama_memory_recurrent_context * mctx) : mctx(mctx) {}
|
||||
|
|
@ -863,6 +875,7 @@ struct llm_graph_context {
|
|||
ggml_tensor * build_inp_out_ids() const;
|
||||
ggml_tensor * build_inp_mean() const;
|
||||
ggml_tensor * build_inp_cls() const;
|
||||
ggml_tensor * build_inp_max() const;
|
||||
|
||||
ggml_tensor * build_inp_cross_embd() const;
|
||||
ggml_tensor * build_inp_pos_bucket_enc() const;
|
||||
|
|
|
|||
|
|
@ -172,7 +172,7 @@ For the full list of features, please refer to [server's changelog](https://gith
|
|||
| `-sp, --special` | special tokens output enabled (default: false) |
|
||||
| `--warmup, --no-warmup` | whether to perform warmup with an empty run (default: enabled) |
|
||||
| `--spm-infill` | use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this. (default: disabled) |
|
||||
| `--pooling {none,mean,cls,last,rank}` | pooling type for embeddings, use model default if unspecified<br/>(env: LLAMA_ARG_POOLING) |
|
||||
| `--pooling {none,mean,cls,last,rank,max}` | pooling type for embeddings, use model default if unspecified<br/>(env: LLAMA_ARG_POOLING) |
|
||||
| `-np, --parallel N` | number of server slots (default: -1, -1 = auto)<br/>(env: LLAMA_ARG_N_PARALLEL) |
|
||||
| `-cb, --cont-batching, -nocb, --no-cont-batching` | whether to enable continuous batching (a.k.a dynamic batching) (default: enabled)<br/>(env: LLAMA_ARG_CONT_BATCHING) |
|
||||
| `-mm, --mmproj FILE` | path to a multimodal projector file. see tools/mtmd/README.md<br/>note: if -hf is used, this argument can be omitted<br/>(env: LLAMA_ARG_MMPROJ) |
|
||||
|
|
|
|||
|
|
@ -289,3 +289,57 @@ def test_embedding_openai_library_base64():
|
|||
# make sure the decoded data is the same as the original
|
||||
for x, y in zip(floats, vec0):
|
||||
assert abs(x - y) < EPSILON
|
||||
|
||||
|
||||
def test_embedding_pooling_max():
    # single prompt with max pooling: one normalized embedding comes back
    global server
    server.pooling = 'max'
    server.start()
    res = server.make_request("POST", "/v1/embeddings", data={
        "input": "I believe the meaning of life is",
    })
    assert res.status_code == 200

    data = res.body['data']
    assert len(data) == 1
    assert 'embedding' in data[0]
    emb = data[0]['embedding']
    assert len(emb) > 1

    # the embedding vector must be L2-normalized
    norm_sq = sum(x * x for x in emb)
    assert abs(norm_sq - 1) < EPSILON
|
||||
|
||||
|
||||
def test_embedding_pooling_max_multiple():
    # batched request with max pooling: one non-trivial embedding per prompt
    global server
    server.pooling = 'max'
    server.start()
    prompts = [
        "I believe the meaning of life is",
        "Write a joke about AI",
        "This is a test",
        "This is another test",
    ]
    res = server.make_request("POST", "/v1/embeddings", data={"input": prompts})
    assert res.status_code == 200
    assert len(res.body['data']) == len(prompts)
    for item in res.body['data']:
        assert 'embedding' in item
        assert len(item['embedding']) > 1
|
||||
|
||||
|
||||
def test_embedding_pooling_max_same_prompt():
    # max pooling is deterministic: identical prompts give identical embeddings
    global server
    server.pooling = 'max'
    server.start()
    prompt = "I believe the meaning of life is"
    res = server.make_request("POST", "/v1/embeddings", data={
        "input": [prompt, prompt],
    })
    assert res.status_code == 200
    assert len(res.body['data']) == 2

    first = res.body['data'][0]['embedding']
    second = res.body['data'][1]['embedding']
    for a, b in zip(first, second):
        assert abs(a - b) < EPSILON
|
||||
|
|
|
|||
Loading…
Reference in New Issue