llama: consistent ctx <-> buf order for KV cache (#16746)

Johannes Gäßler 2025-10-28 11:23:54 +01:00 committed by GitHub
parent 280d97be96
commit 7a0e900e36
5 changed files with 41 additions and 33 deletions
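
Before this change, llama_kv_cache and llama_memory_recurrent built their buft -> ctx std::map with the default comparator, i.e. ordered by raw ggml_backend_buffer_type_t pointer values, and stored the resulting contexts and buffers in two separate vectors. llama_model::load_tensors already used a name-based comparator, but compared the const char * pointers returned by ggml_backend_buft_name() instead of the strings themselves. This commit keys all three maps on the buffer type name via strcmp() and stores each ggml context together with the buffer allocated from it, so the ctx <-> buf correspondence is well-defined and identical in all three places. A minimal sketch (stand-in types, not llama.cpp code) of the ordering difference:

// Sketch: a std::map keyed on raw pointers iterates in address order, which
// varies from run to run, while a comparator over the backend name gives a
// stable, well-defined order.
#include <cstring>
#include <map>

struct buft_t { const char * name; };   // stand-in for ggml_backend_buffer_type

struct buft_name_cmp {
    bool operator()(const buft_t * lhs, const buft_t * rhs) const {
        return strcmp(lhs->name, rhs->name) < 0;   // compare names, not addresses
    }
};

int main() {
    buft_t cpu  {"CPU"};
    buft_t cuda0{"CUDA0"};

    // Address order: whichever object happens to sit lower in memory comes first.
    std::map<const buft_t *, int> by_address = {{&cuda0, 0}, {&cpu, 1}};

    // Name order: always CPU before CUDA0, regardless of where the objects live.
    std::map<const buft_t *, int, buft_name_cmp> by_name = {{&cuda0, 0}, {&cpu, 1}};

    (void)by_address; (void)by_name;
    return 0;
}

std::map iteration order is determined entirely by its comparator, so ordering by name makes buffer allocation order and log output reproducible across runs.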

src/llama-kv-cache.cpp

@@ -8,6 +8,7 @@
 #include <algorithm>
 #include <cassert>
 #include <cmath>
+#include <cstring>
 #include <limits>
 #include <map>
 #include <stdexcept>
@@ -37,8 +38,15 @@ llama_kv_cache::llama_kv_cache(
     const uint32_t n_layer_kv = hparams.n_layer_kv();
 
+    // define a comparator for the buft -> ctx map to ensure that the order is well-defined:
+    struct ggml_backend_buft_comparator {
+        bool operator()(const ggml_backend_buffer_type_t & lhs, const ggml_backend_buffer_type_t & rhs) const {
+            return strcmp(ggml_backend_buft_name(lhs), ggml_backend_buft_name(rhs)) < 0;
+        }
+    };
+    std::map<ggml_backend_buffer_type_t, ggml_context_ptr, ggml_backend_buft_comparator> ctx_map;
+
     // create a context for each buffer type
-    std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
     auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
         auto it = ctx_map.find(buft);
         if (it == ctx_map.end()) {
@@ -53,13 +61,12 @@ llama_kv_cache::llama_kv_cache(
                 return nullptr;
             }
 
-            ctx_map[buft] = ctx;
-            ctxs.emplace_back(ctx);
+            ctx_map.emplace(buft, ctx);
 
             return ctx;
         }
 
-        return it->second;
+        return it->second.get();
     };
 
     GGML_ASSERT(n_stream == 1 || n_stream == n_seq_max);
@@ -167,11 +174,8 @@ llama_kv_cache::llama_kv_cache(
     }
 
     // allocate tensors and initialize the buffers to avoid NaNs in the padding
-    for (auto it : ctx_map) {
-        auto * buft = it.first;
-        auto * ctx  = it.second;
-
-        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
+    for (auto & [buft, ctx] : ctx_map) {
+        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx.get(), buft);
         if (!buf) {
             throw std::runtime_error("failed to allocate buffer for kv cache");
         }
@@ -179,7 +183,7 @@ llama_kv_cache::llama_kv_cache(
         LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
 
         ggml_backend_buffer_clear(buf, 0);
-        bufs.emplace_back(buf);
+        ctxs_bufs.emplace_back(std::move(ctx), buf);
     }
 
     {
@@ -203,7 +207,7 @@ void llama_kv_cache::clear(bool data) {
     }
 
     if (data) {
-        for (auto & buf : bufs) {
+        for (auto & [_, buf] : ctxs_bufs) {
             ggml_backend_buffer_clear(buf.get(), 0);
         }
     }
@@ -472,8 +476,8 @@ llama_pos llama_kv_cache::seq_pos_max(llama_seq_id seq_id) const {
 
 std::map<ggml_backend_buffer_type_t, size_t> llama_kv_cache::memory_breakdown() const {
     std::map<ggml_backend_buffer_type_t, size_t> ret;
-    for (const ggml_backend_buffer_ptr & buf_ptr : bufs) {
-        ret[ggml_backend_buffer_get_type(buf_ptr.get())] += ggml_backend_buffer_get_size(buf_ptr.get());
+    for (const auto & [_, buf] : ctxs_bufs) {
+        ret[ggml_backend_buffer_get_type(buf.get())] += ggml_backend_buffer_get_size(buf.get());
     }
     return ret;
 }
@@ -1298,7 +1302,7 @@ void llama_kv_cache::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch
 
 size_t llama_kv_cache::total_size() const {
     size_t size = 0;
-    for (const auto & buf : bufs) {
+    for (const auto & [_, buf] : ctxs_bufs) {
        size += ggml_backend_buffer_get_size(buf.get());
     }

src/llama-kv-cache.h

@@ -217,8 +217,8 @@ private:
     // this is the SWA type of the cache - not to be confused with the model SWA type
     const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
 
-    std::vector<ggml_context_ptr>        ctxs;
-    std::vector<ggml_backend_buffer_ptr> bufs;
+    // ggml contexts for the KV cache along with the allocated backend buffers:
+    std::vector<std::pair<ggml_context_ptr, ggml_backend_buffer_ptr>> ctxs_bufs;
 
     // the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
     // note: this is not part of the KV state and it's only used to speed-up the find_slot() method

src/llama-memory-recurrent.cpp

@@ -7,6 +7,7 @@
 #include <algorithm>
 #include <cassert>
+#include <cstring>
 #include <limits>
 #include <map>
 #include <stdexcept>
@@ -32,8 +33,15 @@ llama_memory_recurrent::llama_memory_recurrent(
     cells.clear();
     cells.resize(mem_size);
 
+    // define a comparator for the buft -> ctx map to ensure that the order is well-defined:
+    struct ggml_backend_buft_comparator {
+        bool operator()(const ggml_backend_buffer_type_t & lhs, const ggml_backend_buffer_type_t & rhs) const {
+            return strcmp(ggml_backend_buft_name(lhs), ggml_backend_buft_name(rhs)) < 0;
+        }
+    };
+    std::map<ggml_backend_buffer_type_t, ggml_context_ptr, ggml_backend_buft_comparator> ctx_map;
+
     // create a context for each buffer type
-    std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
     auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
         auto it = ctx_map.find(buft);
         if (it == ctx_map.end()) {
@@ -48,13 +56,12 @@ llama_memory_recurrent::llama_memory_recurrent(
                 return nullptr;
             }
 
-            ctx_map[buft] = ctx;
-            ctxs.emplace_back(ctx);
+            ctx_map.emplace(buft, ctx);
 
             return ctx;
         }
 
-        return it->second;
+        return it->second.get();
     };
 
     r_l.resize(n_layer);
@@ -93,17 +100,14 @@ llama_memory_recurrent::llama_memory_recurrent(
     }
 
     // allocate tensors and initialize the buffers to avoid NaNs in the padding
-    for (auto it : ctx_map) {
-        auto * buft = it.first;
-        auto * ctx  = it.second;
-
-        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
+    for (auto & [buft, ctx] : ctx_map) {
+        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx.get(), buft);
         if (!buf) {
             throw std::runtime_error("failed to allocate buffer for rs cache");
         }
 
         ggml_backend_buffer_clear(buf, 0);
         LLAMA_LOG_INFO("%s: %10s RS buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
-        bufs.emplace_back(buf);
+        ctxs_bufs.emplace_back(std::move(ctx), buf);
     }
 
     {
@@ -129,7 +133,7 @@ void llama_memory_recurrent::clear(bool data) {
     used = 0;
 
     if (data) {
-        for (auto & buf : bufs) {
+        for (auto & [_, buf] : ctxs_bufs) {
             ggml_backend_buffer_clear(buf.get(), 0);
         }
     }
@@ -364,8 +368,8 @@ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const {
 
 std::map<ggml_backend_buffer_type_t, size_t> llama_memory_recurrent::memory_breakdown() const {
     std::map<ggml_backend_buffer_type_t, size_t> ret;
-    for (const ggml_backend_buffer_ptr & buf_ptr : bufs) {
-        ret[ggml_backend_buffer_get_type(buf_ptr.get())] += ggml_backend_buffer_get_size(buf_ptr.get());
+    for (const auto & [_, buf] : ctxs_bufs) {
+        ret[ggml_backend_buffer_get_type(buf.get())] += ggml_backend_buffer_get_size(buf.get());
     }
     return ret;
 }
@@ -662,7 +666,7 @@ bool llama_memory_recurrent::get_can_shift() const {
 
 size_t llama_memory_recurrent::total_size() const {
     size_t size = 0;
-    for (const auto & buf : bufs) {
+    for (const auto & [_, buf] : ctxs_bufs) {
        size += ggml_backend_buffer_get_size(buf.get());
     }

src/llama-memory-recurrent.h

@@ -109,8 +109,8 @@ private:
     const uint32_t n_seq_max = 1;
 
-    std::vector<ggml_context_ptr>        ctxs;
-    std::vector<ggml_backend_buffer_ptr> bufs;
+    // ggml contexts for the KV cache along with the allocated backend buffers:
+    std::vector<std::pair<ggml_context_ptr, ggml_backend_buffer_ptr>> ctxs_bufs;
 
     size_t total_size() const;

src/llama-model.cpp

@@ -2231,7 +2231,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
     // define a comparator for the buft -> ctx map to ensure that the order is well-defined:
     struct ggml_backend_buft_comparator {
         bool operator()(const ggml_backend_buffer_type_t & lhs, const ggml_backend_buffer_type_t & rhs) const {
-            return ggml_backend_buft_name(lhs) < ggml_backend_buft_name(rhs);
+            return strcmp(ggml_backend_buft_name(lhs), ggml_backend_buft_name(rhs)) < 0;
        }
     };
     std::map<ggml_backend_buffer_type_t, ggml_context_ptr, ggml_backend_buft_comparator> ctx_map;
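
This last hunk is the bug fix that motivates the strcmp() comparator used throughout: ggml_backend_buft_name() returns a const char *, and applying < to two such pointers compares addresses rather than characters, so the map order depended on where the name strings happened to reside in memory. A minimal standalone illustration (not llama.cpp code):

// Comparing two const char * values with `<` compares addresses, not
// characters, so the result says nothing about the names themselves.
#include <cstdio>
#include <cstring>

int main() {
    char a[] = "CUDA0";
    char b[] = "CPU";
    const char * p = a;
    const char * q = b;

    printf("address order: %d\n", p < q);               // result depends on memory layout
    printf("name order   : %d\n", strcmp(p, q) < 0);    // 0: "CUDA0" sorts after "CPU"
    return 0;
}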