CUDA: cache intermediate tensors

Aman Gupta 2026-01-01 20:59:35 +08:00
parent c8a3798041
commit 23d04b313c
3 changed files with 238 additions and 44 deletions

View File

@ -5,6 +5,7 @@
#include "ggml-cuda.h"
#include <cstdint>
#include <limits>
#include <memory>
#if defined(GGML_USE_HIP)
@ -1218,6 +1219,104 @@ struct ggml_cuda_stream_context {
}
};
// cache to extend the lifetimes of ggml_cuda_pool_alloc allocations; ggml_cuda_pool expects memory to be allocated and freed in LIFO order,
// hence this cache works like a stack
struct ggml_cuda_cache {
struct pool_alloc {
ggml_cuda_pool * pool;
void * ptr;
size_t actual_size;
pool_alloc():
pool{nullptr}
, ptr{nullptr}
, actual_size{0}
{}
template<typename T>
pool_alloc(ggml_cuda_pool_alloc<T> && other) {
pool = other.pool;
ptr = (void *)other.ptr;
actual_size = other.actual_size;
other.ptr = nullptr;
other.pool = nullptr;
other.actual_size = 0;
}
pool_alloc(pool_alloc && other) {
pool = other.pool;
ptr = (void *) other.ptr;
actual_size = other.actual_size;
other.ptr = nullptr;
other.pool = nullptr;
other.actual_size = 0;
}
~pool_alloc() {
if (ptr != nullptr) {
pool->free(ptr, actual_size);
}
}
};
struct cache_entry {
int layout; // mmq_q8_1_ds_layout value
std::vector<pool_alloc> pool_ptrs;
size_t ttl_nodes{}; // node_count at the time the entry was added; the entry expires 10 nodes later
cache_entry() = default;
cache_entry(cache_entry && other) = default;
cache_entry& operator=(cache_entry && other) = default;
cache_entry(const cache_entry &) = delete;
cache_entry& operator=(const cache_entry &) = delete;
~cache_entry() {
// Free pool allocations in reverse order (LIFO)
while (!pool_ptrs.empty()) {
pool_ptrs.pop_back();
}
}
};
void clear_cache() {
remove_expired(std::numeric_limits<size_t>::max());
entries.clear();
}
void remove_expired(size_t node_count) {
// maximum lifetime of a cache entry: 10 nodes after it was added
// pop from the back only, so the underlying pool allocations are freed in LIFO order
while (!entries.empty() && entries.back().second.ttl_nodes + 10 <= node_count) {
entries.pop_back();
}
}
cache_entry * find(const ggml_tensor * node, int layout) {
for (auto & entry: entries) {
if (entry.first == node && entry.second.layout == layout) {
return &entry.second;
}
}
return nullptr;
}
~ggml_cuda_cache() {
while (!entries.empty()) {
entries.pop_back();
}
}
void add_entry(const ggml_tensor * node, cache_entry && entry) {
entries.emplace_back(node, std::move(entry));
}
std::vector<std::pair<const ggml_tensor *, cache_entry>> entries;
};
struct ggml_backend_cuda_context {
int device;
std::string name;
@ -1229,6 +1328,7 @@ struct ggml_backend_cuda_context {
std::unique_ptr<ggml_cuda_graph> cuda_graph;
int curr_stream_no = 0;
size_t node_count = 0;
explicit ggml_backend_cuda_context(int device) :
device(device),
@ -1266,6 +1366,7 @@ struct ggml_backend_cuda_context {
// pool
std::unique_ptr<ggml_cuda_pool> pools[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS];
std::unique_ptr<ggml_cuda_cache> caches[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS]{{}};
static std::unique_ptr<ggml_cuda_pool> new_pool_for_device(int device, int stream_no);
@ -1276,6 +1377,27 @@ struct ggml_backend_cuda_context {
return *pools[device][curr_stream_no];
}
ggml_cuda_cache & cache(int device, int stream) {
if (caches[device][stream] == nullptr) {
caches[device][stream] = std::unique_ptr<ggml_cuda_cache>(new ggml_cuda_cache());
}
return *caches[device][stream];
}
ggml_cuda_cache & cache() {
return cache(device, curr_stream_no);
}
void clear_cache() {
for (int i = 0; i < GGML_CUDA_MAX_DEVICES; ++i) {
for (int j = 0; j < GGML_CUDA_MAX_STREAMS; ++j) {
if (caches[i][j]) {
caches[i][j]->clear_cache();
}
}
}
}
ggml_cuda_pool & pool() {
return pool(device);
}
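
The cache above only ever grows at the back and shrinks from the back: entries are appended as graph nodes execute, and remove_expired() pops from the back once an entry is at least 10 nodes old, so the underlying ggml_cuda_pool always sees frees in the reverse of allocation order. A minimal standalone model of that behaviour is sketched below; mock_pool, mock_entry and the sizes are made up for illustration and are not part of the ggml API, only the stack-like ordering and the 10-node expiry rule are taken from the code above.

#include <cassert>
#include <cstdio>
#include <vector>

// hypothetical stand-in for ggml_cuda_pool: frees must arrive in reverse allocation order
struct mock_pool {
    std::vector<size_t> live; // sizes of live allocations, in allocation order
    void alloc(size_t size) { live.push_back(size); }
    void free_last(size_t size) { assert(!live.empty() && live.back() == size); live.pop_back(); }
};

// hypothetical stand-in for ggml_cuda_cache::cache_entry
struct mock_entry {
    size_t size;      // pool allocation owned by this entry
    size_t ttl_nodes; // node_count at creation; expires 10 nodes later
};

int main() {
    mock_pool pool;
    std::vector<mock_entry> entries; // oldest at the front, newest at the back
    size_t node_count = 0;

    auto add = [&](size_t size) {
        pool.alloc(size);
        entries.push_back({size, node_count});
    };
    auto remove_expired = [&]() {
        // pop from the back only: even if an older entry at the front expired first,
        // it cannot be freed before the newer allocations on top of it
        while (!entries.empty() && entries.back().ttl_nodes + 10 <= node_count) {
            pool.free_last(entries.back().size);
            entries.pop_back();
        }
    };

    add(256);       // e.g. a quantized src1 buffer cached while running node 0
    node_count = 3;
    add(512);       // a second buffer cached while running node 3
    for (; node_count <= 20; ++node_count) {
        remove_expired();
    }
    printf("live entries: %zu\n", entries.size()); // 0: both expired, freed back to front
    return 0;
}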

View File

@ -3232,6 +3232,8 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
ggml_cuda_concurrent_event * concurrent_event = nullptr;
bool should_launch_concurrent_events = false;
cuda_ctx->clear_cache();
const auto try_launch_concurrent_event = [&](const ggml_tensor * node) {
if (stream_ctx.concurrent_events.find(node) != stream_ctx.concurrent_events.end()) {
concurrent_event = &stream_ctx.concurrent_events[node];
@ -3662,6 +3664,10 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
}
GGML_ASSERT(ok);
// Increment node counter and expire old cache entries
cuda_ctx->node_count++;
cuda_ctx->cache().remove_expired(cuda_ctx->node_count);
if (!is_concurrent_event_active) {
try_launch_concurrent_event(node);
}
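
Here node_count is incremented after each executed node and remove_expired() is called with the new value, so an entry stored with ttl_nodes = N stays findable while node_count is in [N, N+9] and is dropped once node_count reaches N+10. The snippet below just restates that arithmetic as a standalone predicate for readability; expired() is a sketch, not a ggml function.

#include <cstddef>

// the expiry condition used by ggml_cuda_cache::remove_expired, restated on its own
constexpr bool expired(std::size_t ttl_nodes, std::size_t node_count) {
    return ttl_nodes + 10 <= node_count;
}

static_assert(!expired(5, 5),  "usable by the node that created it");
static_assert(!expired(5, 14), "still alive 9 nodes later");
static_assert( expired(5, 15), "dropped once 10 more nodes have executed");

int main() { return 0; }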

View File

@ -1,4 +1,5 @@
#include "common.cuh"
#include "ggml.h"
#include "mmq.cuh"
#include "quantize.cuh"
#include "mmid.cuh"
@ -118,25 +119,53 @@ void ggml_cuda_mul_mat_q(
// TODO: tighter pool buffer size vs q8 path
const bool use_native_mxfp4 = blackwell_mma_available(cc) && src0->type == GGML_TYPE_MXFP4;
ggml_cuda_cache & cache = ctx.cache();
if (!ids) {
const size_t nbytes_src1_q8_1 = ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1 +
get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq);
const int layout = use_native_mxfp4 ? -1 : mmq_get_q8_1_ds_layout(src0->type);
void * src1_ptr = nullptr;
ggml_cuda_cache::cache_entry * entry = cache.find(src1, layout);
if (entry != nullptr) {
GGML_ASSERT(entry->pool_ptrs.size() == 1);
size_t expected_size = ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1 + get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq);
GGML_ASSERT(entry->pool_ptrs[0].actual_size >= expected_size);
src1_ptr = entry->pool_ptrs[0].ptr;
GGML_ASSERT(src1_ptr != nullptr);
} else {
ggml_cuda_pool_alloc<char> src1_q8_1(ctx.pool(), nbytes_src1_q8_1);
{
const int64_t s11 = src1->nb[1] / ts_src1;
const int64_t s12 = src1->nb[2] / ts_src1;
const int64_t s13 = src1->nb[3] / ts_src1;
if (use_native_mxfp4) {
static_assert(sizeof(block_fp4_mmq) == 4 * sizeof(block_q8_1));
quantize_mmq_mxfp4_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded,
ne11, ne12, ne13, stream);
} else {
quantize_mmq_q8_1_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded,
ne11, ne12, ne13, stream);
}
CUDA_CHECK(cudaGetLastError());
}
src1_ptr = src1_q8_1.get();
std::vector<ggml_cuda_cache::pool_alloc> allocs;
allocs.emplace_back(ggml_cuda_cache::pool_alloc(std::move(src1_q8_1)));
cache.add_entry(
src1,
ggml_cuda_cache::cache_entry{
layout,
std::move(allocs),
ctx.node_count
});
}
// Stride depends on quantization format
@ -148,7 +177,7 @@ void ggml_cuda_mul_mat_q(
const int64_t s13 = ne12*s12;
const mmq_args args = {
src0_d, src0->type, (const int *) src1_ptr, nullptr, nullptr, dst_d,
ne00, ne01, ne1, s01, ne11, s1,
ne02, ne12, s02, s12, s2,
ne03, ne13, s03, s13, s3,
@ -165,41 +194,78 @@ void ggml_cuda_mul_mat_q(
const int64_t ne_get_rows = ne12 * n_expert_used;
GGML_ASSERT(ne1 == n_expert_used);
const int layout = use_native_mxfp4 ? -1 : mmq_get_q8_1_ds_layout(src0->type);
void * ids_dst_ptr = nullptr;
void * expert_bounds_ptr = nullptr;
void * src1_q8_1_ptr = nullptr;
ggml_cuda_cache::cache_entry * entry = cache.find(src1, layout);
if (entry != nullptr) {
GGML_ASSERT(entry->pool_ptrs.size() == 4);
ids_dst_ptr = entry->pool_ptrs[1].ptr;
expert_bounds_ptr = entry->pool_ptrs[2].ptr;
src1_q8_1_ptr = entry->pool_ptrs[3].ptr;
size_t expected_q8_1_size = ne12*n_expert_used*ne10_padded * sizeof(block_q8_1)/QK8_1 + get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq);
GGML_ASSERT(entry->pool_ptrs[3].actual_size >= expected_q8_1_size);
} else {
ggml_cuda_pool_alloc<int32_t> ids_src1(ctx.pool(), ne_get_rows);
ggml_cuda_pool_alloc<int32_t> ids_dst(ctx.pool(), ne_get_rows);
ggml_cuda_pool_alloc<int32_t> expert_bounds(ctx.pool(), ne02 + 1);
{
GGML_ASSERT(ids->nb[0] == ggml_element_size(ids));
const int si1 = ids->nb[1] / ggml_element_size(ids);
const int sis1 = nb12 / nb11;
ggml_cuda_launch_mm_ids_helper((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
CUDA_CHECK(cudaGetLastError());
}
const size_t nbytes_src1_q8_1 = ne12*n_expert_used*ne10_padded * sizeof(block_q8_1)/QK8_1 +
get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq);
ggml_cuda_pool_alloc<char> src1_q8_1(ctx.pool(), nbytes_src1_q8_1);
const int64_t ne11_flat = ne12*n_expert_used;
const int64_t ne12_flat = 1;
const int64_t ne13_flat = 1;
{
const int64_t s11 = src1->nb[1] / ts_src1;
const int64_t s12 = src1->nb[2] / ts_src1;
const int64_t s13 = src1->nb[2] / ts_src1;
if (use_native_mxfp4) {
quantize_mmq_mxfp4_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13,
ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
} else {
quantize_mmq_q8_1_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13,
ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
}
CUDA_CHECK(cudaGetLastError());
}
void * ids_src1_ptr = ids_src1.get();
ids_dst_ptr = ids_dst.get();
expert_bounds_ptr = expert_bounds.get();
src1_q8_1_ptr = src1_q8_1.get();
std::vector<ggml_cuda_cache::pool_alloc> allocs;
// Store in allocation order; custom destructor will free in reverse (LIFO)
allocs.emplace_back(ggml_cuda_cache::pool_alloc(std::move(ids_src1)));
allocs.emplace_back(ggml_cuda_cache::pool_alloc(std::move(ids_dst)));
allocs.emplace_back(ggml_cuda_cache::pool_alloc(std::move(expert_bounds)));
allocs.emplace_back(ggml_cuda_cache::pool_alloc(std::move(src1_q8_1)));
cache.add_entry(
src1,
ggml_cuda_cache::cache_entry{
layout,
std::move(allocs),
ctx.node_count
});
}
const int64_t s12 = use_native_mxfp4 ? ne11 * ne10_padded * sizeof(block_fp4_mmq) / (8 * QK_MXFP4 * sizeof(int)) :
@ -208,7 +274,7 @@ void ggml_cuda_mul_mat_q(
// Note that ne02 is used instead of ne12 because the number of y channels determines the z dimension of the CUDA grid.
const mmq_args args = {
src0_d, src0->type, (const int *) src1_q8_1_ptr, (int32_t *) ids_dst_ptr, (int32_t *) expert_bounds_ptr, dst_d,
ne00, ne01, ne_get_rows, s01, ne_get_rows, s1,
ne02, ne02, s02, s12, s2,
ne03, ne13, s03, s13, s3,
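
Both the allocation and the cache-hit assertion above use the same shape of size expression: row count times padded row length, packed as block_q8_1 (QK8_1 values per block), plus a fixed amount of slack proportional to get_mmq_x_max_host(cc). The helper below only restates that expression so the GGML_ASSERT checks are easier to read; it is a sketch, the sizeof values and shapes used in main() are hypothetical stand-ins (the real block_q8_1 and block_q8_1_mmq types live in the CUDA headers), and only the formula itself is taken from the code.

#include <cstddef>
#include <cstdio>

// size of the quantized src1 buffer, as computed in ggml_cuda_mul_mat_q:
// rows * ne10_padded values packed as block_q8_1 (qk8_1 values per block),
// plus mmq_x_max extra block_q8_1_mmq blocks of slack, as in the original expression
constexpr std::size_t q8_1_buffer_bytes(std::size_t rows, std::size_t ne10_padded,
                                        std::size_t sizeof_block_q8_1, std::size_t qk8_1,
                                        std::size_t mmq_x_max, std::size_t sizeof_block_q8_1_mmq) {
    return rows*ne10_padded * sizeof_block_q8_1/qk8_1 + mmq_x_max*sizeof_block_q8_1_mmq;
}

int main() {
    // hypothetical shapes and type sizes, only to exercise the formula:
    // ids path with ne12 = 8 tokens, n_expert_used = 4, rows padded to ne10_padded = 4096
    const std::size_t bytes = q8_1_buffer_bytes(/*rows*/ 8*4, /*ne10_padded*/ 4096,
                                                /*sizeof(block_q8_1)*/ 36, /*QK8_1*/ 32,
                                                /*mmq_x_max*/ 128, /*sizeof(block_q8_1_mmq)*/ 144);
    printf("%zu bytes\n", bytes);
    return 0;
}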