server : host-memory prompt caching (#16391)
* minor : code style
* server : fix prompt similarity calculation
* server : initial host-memory prompt caching
* cont
* server : refactor
* cont
* cont : make the server task of the slot const
* cont : minor [no ci]
* server : cache prompts and checkpoints only for completion tasks
* server : improve prompt caching logic
* cont : fix check for number of cached prompts [no ci]
* server : improve caching logic, add -cram CLI arg
* server : print prompt mismatch info
* cont : better naming [no ci]
* server : improve prompt cache loading logic
* server : add option to debug the slot contents (#16482)
  * server : add option to debug the slot contents
  * Update tools/server/server.cpp
* server : add option to disable prompt cache

Co-authored-by: Xuan-Son Nguyen <son@huggingface.co>
parent 8328fd4bae
commit d00cbea63c
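For orientation, the sketch below illustrates the general idea behind host-memory prompt caching: recently used prompts (their tokens plus the associated context state) are kept in ordinary RAM, the longest matching prefix is reused on the next request, and old entries are evicted once a configurable RAM budget is exceeded (the new --cache-ram / cache_ram_mib setting in the diffs below). This is an illustration only; the prompt_cache_entry / prompt_cache types and the byte accounting are hypothetical, not the server's actual implementation.

    // Illustrative sketch only - names and layout are hypothetical, not llama.cpp's.
    #include <cstddef>
    #include <cstdint>
    #include <list>
    #include <vector>

    using llama_token = int32_t;

    struct prompt_cache_entry {
        std::vector<llama_token> tokens; // the cached prompt
        std::vector<uint8_t>     state;  // serialized context/KV state for that prompt
    };

    struct prompt_cache {
        std::list<prompt_cache_entry> entries; // most recently used at the back
        size_t limit_bytes;                    // derived from cache_ram_mib (see --cache-ram)

        size_t size_bytes() const {
            size_t n = 0;
            for (const auto & e : entries) {
                n += e.tokens.size() * sizeof(llama_token) + e.state.size();
            }
            return n;
        }

        // drop least recently used entries until the budget is respected
        void evict_to_fit() {
            while (!entries.empty() && size_bytes() > limit_bytes) {
                entries.pop_front();
            }
        }
    };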
@@ -1935,6 +1935,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.n_ctx_checkpoints = value;
         }
     ).set_env("LLAMA_ARG_CTX_CHECKPOINTS").set_examples({LLAMA_EXAMPLE_SERVER}));
+    add_opt(common_arg(
+        {"--cache-ram", "-cram"}, "N",
+        string_format("set the maximum cache size in MiB (default: %d, -1 - no limit, 0 - disable)\n"
+            "[(more info)](https://github.com/ggml-org/llama.cpp/pull/16391)", params.cache_ram_mib),
+        [](common_params & params, int value) {
+            params.cache_ram_mib = value;
+        }
+    ).set_env("LLAMA_ARG_CACHE_RAM").set_examples({LLAMA_EXAMPLE_SERVER}));
     add_opt(common_arg(
         {"--kv-unified", "-kvu"},
         string_format("use single unified KV buffer for the KV cache of all sequences (default: %s)\n"
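The new argument stores a size in MiB, with -1 meaning no limit and 0 disabling the cache, as stated in the help text above; it can also be set through the LLAMA_ARG_CACHE_RAM environment variable. A hedged sketch of how such a value could be translated into a byte budget; the cache_limit_bytes helper is hypothetical and only illustrates the documented semantics:

    #include <cstddef>
    #include <cstdint>
    #include <limits>

    // Hypothetical helper: translate the --cache-ram value (MiB) into a byte budget.
    // -1 -> unlimited, 0 -> caching disabled, N > 0 -> N MiB.
    static size_t cache_limit_bytes(int32_t cache_ram_mib) {
        if (cache_ram_mib < 0) {
            return std::numeric_limits<size_t>::max(); // no limit
        }
        return size_t(cache_ram_mib) * 1024u * 1024u;  // 0 yields an empty budget, i.e. caching disabled
    }

For example, --cache-ram 2048 (or LLAMA_ARG_CACHE_RAM=2048) would cap the cache at 2048 MiB, i.e. 2 GiB.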
@@ -33,8 +33,8 @@ struct common_chat_msg_content_part {
 struct common_chat_msg {
     std::string role;
     std::string content;
-    std::vector<common_chat_msg_content_part> content_parts = {};
-    std::vector<common_chat_tool_call> tool_calls = {};
+    std::vector<common_chat_msg_content_part> content_parts;
+    std::vector<common_chat_tool_call> tool_calls;
     std::string reasoning_content;
     std::string tool_name;
     std::string tool_call_id;
@@ -44,7 +44,7 @@ struct common_chat_msg {
     bool empty() const {
         return content.empty() && content_parts.empty() && tool_calls.empty() && reasoning_content.empty() && tool_name.empty() && tool_call_id.empty();
     }
-    void ensure_tool_call_ids_set(std::vector<std::string> & ids_cache, const std::function<std::string()> & gen_tool_call_id) {
+    void set_tool_call_ids(std::vector<std::string> & ids_cache, const std::function<std::string()> & gen_tool_call_id) {
         for (auto i = 0u; i < tool_calls.size(); i++) {
             if (ids_cache.size() <= i) {
                 auto id = tool_calls[i].id;
@@ -425,7 +425,8 @@ struct common_params {
     int32_t timeout_write = timeout_read; // http write timeout in seconds
     int32_t n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool)
     int32_t n_cache_reuse = 0; // min chunk size to reuse from the cache via KV shifting
-    int32_t n_ctx_checkpoints = 3; // max number of context checkpoints per slot
+    int32_t n_ctx_checkpoints = 8; // max number of context checkpoints per slot
+    int32_t cache_ram_mib = 8192; // 0 = no limit, 1 = 1 MiB, etc.
 
     std::string hostname = "127.0.0.1";
     std::string public_path = ""; // NOLINT
@@ -123,11 +123,8 @@ llama_kv_cache::llama_kv_cache(
            throw std::runtime_error("failed to create ggml context for kv cache");
        }
 
-        ggml_tensor * k;
-        ggml_tensor * v;
-
-        k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream);
-        v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream);
+        ggml_tensor * k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream);
+        ggml_tensor * v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream);
 
         ggml_format_name(k, "cache_k_l%d", il);
         ggml_format_name(v, "cache_v_l%d", il);
File diff suppressed because it is too large.
@@ -66,8 +66,7 @@ def test_server_slots():
     assert len(res.body) == server.n_slots
     assert server.n_ctx is not None and server.n_slots is not None
     assert res.body[0]["n_ctx"] == server.n_ctx / server.n_slots
-    assert "params" in res.body[0]
-    assert res.body[0]["params"]["seed"] == server.seed
+    assert "params" not in res.body[0]
 
 
 def test_load_split_model():
@@ -19,8 +19,8 @@ def create_server():
         (None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", True, None),
         (None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", True, 'chatml'),
         (None, "Book", "What is the best book", 8, "^ blue", 23, 8, "length", True, "This is not a chat template, it is"),
-        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", False, None),
-        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", True, None),
+        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 128, "length", False, None),
+        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 128, "length", True, None),
         (None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, "length", False, None),
         (None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, "length", True, None),
     ]
@@ -54,7 +54,7 @@ def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_conte
     "system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason",
     [
         ("Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length"),
-        ("You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length"),
+        ("You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 128, "length"),
     ]
 )
 def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason):
@@ -16,7 +16,7 @@ def create_server():
 
 @pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated,return_tokens", [
     ("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False, False),
-    ("Write a joke about AI from a very long prompt which will not be truncated", 256, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False, True),
+    ("Write a joke about AI from a very long prompt which will not be truncated", 64, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False, True),
 ])
 def test_completion(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool, return_tokens: bool):
     global server
@@ -41,7 +41,7 @@ def test_completion(prompt: str, n_predict: int, re_content: str, n_prompt: int,
 
 @pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated", [
     ("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False),
-    ("Write a joke about AI from a very long prompt which will not be truncated", 256, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False),
+    ("Write a joke about AI from a very long prompt which will not be truncated", 64, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False),
 ])
 def test_completion_stream(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool):
     global server
@@ -4,6 +4,12 @@ from utils import *
 server = ServerPreset.tinyllama2()
 
 
+SHORT_TEXT = """
+Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
+Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
+Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
+""".strip()
+
 LONG_TEXT = """
 Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
 Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
@@ -21,19 +27,18 @@ def create_server():
 
 
 def test_ctx_shift_enabled():
-    # the prompt is 301 tokens
+    # the prompt is 226 tokens
     # the slot context is 512/2 = 256 tokens
-    # the prompt is truncated to keep the last (301 - 256/2) = 173 tokens
     # 96 tokens are generated thanks to shifting the context when it gets full
     global server
     server.enable_ctx_shift = True
     server.start()
     res = server.make_request("POST", "/completion", data={
         "n_predict": 96,
-        "prompt": LONG_TEXT,
+        "prompt": SHORT_TEXT,
     })
     assert res.status_code == 200
-    assert res.body["timings"]["prompt_n"] == 173
+    assert res.body["timings"]["prompt_n"] == 226
     assert res.body["timings"]["predicted_n"] == 96
     assert res.body["truncated"] is True
 
@@ -31,10 +31,10 @@
 
 using json = nlohmann::ordered_json;
 
-#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
-#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
-#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
-#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
+#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
+#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
+#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
+#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
 
 #define SRV_INF(fmt, ...) LOG_INF("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
 #define SRV_WRN(fmt, ...) LOG_WRN("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
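The updated SLT_* macros no longer read (slot).id_task directly; they null-check (slot).task and log -1 when the slot has no task attached. A minimal standalone illustration of that guard pattern, using stand-in slot/task structs rather than the server's real types:

    #include <cstdio>

    struct task_t { int id; };
    struct slot_t { int id; const task_t * task; };

    // same guard as in the updated SLT_* macros: log -1 when no task is attached
    #define SLOT_LOG(slot, fmt, ...) \
        printf("slot: id %2d | task %d | " fmt, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)

    int main() {
        task_t t    = { 42 };
        slot_t busy = { 0, &t };
        slot_t idle = { 1, nullptr };

        SLOT_LOG(busy, "processing (%s)\n", "ok"); // prints task 42
        SLOT_LOG(idle, "waiting (%s)\n",    "ok"); // prints task -1 instead of dereferencing null
        return 0;
    }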
@@ -1102,6 +1102,7 @@ public:
     ~server_tokens() = default;
 
     // Prevent copying
+    // TODO: server_tokens should be copyable - remove this:
     server_tokens(const server_tokens&) = delete;
     server_tokens& operator=(const server_tokens&) = delete;
 
@@ -1119,7 +1120,7 @@ public:
         }
     }
 
-    server_tokens(llama_tokens & tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) {}
+    server_tokens(const llama_tokens & tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) {}
 
     // for debugging
     std::string str() const {
@@ -1144,9 +1145,8 @@ public:
         auto it = map_pos_to_media.find(pos);
         if (it != map_pos_to_media.end()) {
             return it->second;
-        } else {
-            throw std::runtime_error("Chunk not found");
         }
+        throw std::runtime_error("Chunk not found");
     }
 
     void push_back(llama_token tok) {
@@ -1170,7 +1170,7 @@ public:
             map_pos_to_media[start_pos] = std::move(new_chunk);
         } else if (type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
             size_t n_tokens;
-            auto text_tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens);
+            const auto * text_tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens);
             for (size_t i = 0; i < n_tokens; ++i) {
                 push_back(text_tokens[i]);
             }
@@ -1190,7 +1190,7 @@ public:
         // We could also just check, but this will prevent silently dropping MTMD data.
         GGML_ASSERT(has_mtmd);
         for (auto it = tokens.map_pos_to_media.begin(); it != tokens.map_pos_to_media.end(); ) {
-            auto chunk = tokens.map_pos_to_media[it->first].get();
+            auto * chunk = tokens.map_pos_to_media[it->first].get();
             mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk));
             map_pos_to_media[start_pos+it->first] = std::move(new_chunk);
         }
@@ -1271,33 +1271,52 @@ public:
     }
 
     size_t get_common_prefix(const server_tokens & b) const {
-        size_t max_idx = std::min(tokens.size(), b.tokens.size());
+        const size_t max_idx = std::min(tokens.size(), b.tokens.size());
+
+        if (!has_mtmd) {
+            for (size_t i = 0; i < max_idx; ++i) {
+                if (tokens[i] == b.tokens[i]) {
+                    continue;
+                }
+
+                return i;
+            }
+
+            return max_idx;
+        }
+
         for (size_t i = 0; i < max_idx; ++i) {
-            auto & ai = tokens[i];
-            auto & bi = b.tokens[i];
+            const llama_token ai = tokens[i];
+            const llama_token bi = b.tokens[i];
 
             if (ai == LLAMA_TOKEN_NULL && bi == LLAMA_TOKEN_NULL) {
-                GGML_ASSERT(has_mtmd);
                 const auto & a_chunk = find_chunk(i);
                 const auto & b_chunk = b.find_chunk(i);
 
                 GGML_ASSERT(a_chunk && b_chunk);
-                std::string ai_id = mtmd_input_chunk_get_id(a_chunk.get());
-                std::string bi_id = mtmd_input_chunk_get_id(b_chunk.get());
-                size_t a_pos = mtmd_input_chunk_get_n_pos(a_chunk.get());
-                size_t b_pos = mtmd_input_chunk_get_n_pos(b_chunk.get());
-                if (ai_id == bi_id && a_pos == b_pos) {
-                    GGML_ASSERT(a_pos > 0 && "Invalid media chunk"); // should never happen
-                    i += a_pos - 1; // will be +1 by the for loop
+
+                const std::string id_ai = mtmd_input_chunk_get_id(a_chunk.get());
+                const std::string id_bi = mtmd_input_chunk_get_id(b_chunk.get());
+
+                const size_t pos_a = mtmd_input_chunk_get_n_pos(a_chunk.get());
+                const size_t pos_b = mtmd_input_chunk_get_n_pos(b_chunk.get());
+
+                if (id_ai == id_bi && pos_a == pos_b) {
+                    GGML_ASSERT(pos_a > 0 && "Invalid media chunk"); // should never happen
+                    i += pos_a - 1; // will be +1 by the for loop
                     continue;
-                } else {
-                    return i;
                 }
-            } else if (ai == bi) {
+
+                return i;
+            }
+
+            if (ai == bi) {
                 continue;
-            } else {
-                return i;
             }
+
+            return i;
         }
+
         return max_idx; // all tokens are equal
     }
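The fast path added above compares raw token ids only; the second loop still walks multimodal (mtmd) chunks. A self-contained sketch of the token-only prefix computation, which is what determines how much of a cached prompt can be reused; the names here are illustrative and not the server's server_tokens class:

    #include <algorithm>
    #include <cstddef>
    #include <cstdint>
    #include <vector>

    using llama_token = int32_t;

    // length of the longest common prefix of two token sequences,
    // mirroring the text-only branch of get_common_prefix()
    static size_t common_prefix_len(const std::vector<llama_token> & a, const std::vector<llama_token> & b) {
        const size_t max_idx = std::min(a.size(), b.size());
        for (size_t i = 0; i < max_idx; ++i) {
            if (a[i] != b[i]) {
                return i;
            }
        }
        return max_idx; // one sequence is a prefix of the other
    }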
@@ -1308,7 +1327,7 @@ public:
         const int32_t n_vocab = llama_vocab_n_tokens(vocab);
 
         for (size_t i = 0; i < tokens.size(); ++i) {
-            auto & t = tokens[i];
+            const auto & t = tokens[i];
             if (t == LLAMA_TOKEN_NULL) {
                 try {
                     const auto & chunk = find_chunk(i);
@@ -1330,8 +1349,8 @@ public:
             mtmd_context * mctx,
             llama_pos n_past,
             int32_t seq_id,
-            llama_pos & n_pos_out) {
-        auto & chunk = find_chunk(n_past);
+            llama_pos & n_pos_out) const {
+        const auto & chunk = find_chunk(n_past);
         const char * name = mtmd_input_chunk_get_type(chunk.get()) == MTMD_INPUT_CHUNK_TYPE_IMAGE
             ? "image" : "audio";
         SRV_INF("processing %s...\n", name);