diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index 62e618850b..ee2254b7db 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -5,6 +5,7 @@
 #include "ggml-cuda.h"
 
 #include <cstdint>
+#include <limits>
 #include <memory>
 
 #if defined(GGML_USE_HIP)
@@ -1218,6 +1219,104 @@ struct ggml_cuda_stream_context {
     }
 };
 
+// cache to extend lifetimes of ggml_cuda_pool_alloc; ggml_cuda_pool expects memory to be allocated in LIFO order,
+// hence this cache works like a stack
+struct ggml_cuda_cache {
+    struct pool_alloc {
+        ggml_cuda_pool * pool;
+        void * ptr;
+        size_t actual_size;
+
+        pool_alloc():
+            pool{nullptr}
+            , ptr{nullptr}
+            , actual_size{0}
+        {}
+
+        template <typename T>
+        pool_alloc(ggml_cuda_pool_alloc<T> && other) {
+            pool = other.pool;
+            ptr = (void *) other.ptr;
+            actual_size = other.actual_size;
+
+            other.ptr = nullptr;
+            other.pool = nullptr;
+            other.actual_size = 0;
+        }
+
+        pool_alloc(pool_alloc && other) {
+            pool = other.pool;
+            ptr = (void *) other.ptr;
+            actual_size = other.actual_size;
+            other.ptr = nullptr;
+            other.pool = nullptr;
+            other.actual_size = 0;
+        }
+
+        ~pool_alloc() {
+            if (ptr != nullptr) {
+                pool->free(ptr, actual_size);
+            }
+        }
+    };
+
+    struct cache_entry {
+        int layout; // mmq_q8_1_ds_layout value
+        std::vector<pool_alloc> pool_ptrs;
+        size_t ttl_nodes{};
+
+        cache_entry() = default;
+
+        cache_entry(cache_entry && other) = default;
+        cache_entry& operator=(cache_entry && other) = default;
+
+        cache_entry(const cache_entry &) = delete;
+        cache_entry& operator=(const cache_entry &) = delete;
+
+        ~cache_entry() {
+            // Free pool allocations in reverse order (LIFO)
+            while (!pool_ptrs.empty()) {
+                pool_ptrs.pop_back();
+            }
+        }
+    };
+
+    void clear_cache() {
+        remove_expired(std::numeric_limits<size_t>::max());
+        entries.clear();
+    }
+
+    void remove_expired(size_t node_count) {
+        // max lifetime of a cache entry: 10 nodes after it was created
+        while (!entries.empty() && entries.back().second.ttl_nodes + 10 <= node_count) {
+            entries.pop_back();
+        }
+    }
+
+    cache_entry * find(const ggml_tensor * node, int layout) {
+        for (auto & entry : entries) {
+            if (entry.first == node && entry.second.layout == layout) {
+                return &entry.second;
+            }
+        }
+        return nullptr;
+    }
+
+    ~ggml_cuda_cache() {
+        while (!entries.empty()) {
+            entries.pop_back();
+        }
+    }
+
+    void add_entry(const ggml_tensor * node, cache_entry && entry) {
+        entries.emplace_back(node, std::move(entry));
+    }
+
+    std::vector<std::pair<const ggml_tensor *, cache_entry>> entries;
+};
+
+
+
 struct ggml_backend_cuda_context {
     int device;
     std::string name;
@@ -1229,6 +1328,7 @@ struct ggml_backend_cuda_context {
 
     std::unique_ptr<ggml_cuda_graph> cuda_graph;
     int curr_stream_no = 0;
+    size_t node_count = 0;
 
     explicit ggml_backend_cuda_context(int device) :
         device(device),
@@ -1266,6 +1366,7 @@ struct ggml_backend_cuda_context {
 
     // pool
    std::unique_ptr<ggml_cuda_pool> pools[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS];
+    std::unique_ptr<ggml_cuda_cache> caches[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS]{{}};
 
     static std::unique_ptr<ggml_cuda_pool> new_pool_for_device(int device, int stream_no);
 
@@ -1276,6 +1377,27 @@ struct ggml_backend_cuda_context {
         return *pools[device][curr_stream_no];
     }
 
+    ggml_cuda_cache & cache(int device, int stream) {
+        if (caches[device][stream] == nullptr) {
+            caches[device][stream] = std::unique_ptr<ggml_cuda_cache>(new ggml_cuda_cache());
+        }
+        return *caches[device][stream];
+    }
+
+    ggml_cuda_cache & cache() {
+        return cache(device, curr_stream_no);
+    }
+
+    void clear_cache() {
+        for (int i = 0; i < GGML_CUDA_MAX_DEVICES; ++i) {
+            for (int j = 0; j < GGML_CUDA_MAX_STREAMS; ++j) {
+                if (caches[i][j]) {
+                    caches[i][j]->clear_cache();
+                }
+            }
+        }
+    }
+
     ggml_cuda_pool & pool() { return pool(device); }
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 55e1c20c96..7574536b3b 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -3232,6 +3232,8 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
     ggml_cuda_concurrent_event * concurrent_event = nullptr;
     bool should_launch_concurrent_events = false;
 
+    cuda_ctx->clear_cache();
+
     const auto try_launch_concurrent_event = [&](const ggml_tensor * node) {
         if (stream_ctx.concurrent_events.find(node) != stream_ctx.concurrent_events.end()) {
             concurrent_event = &stream_ctx.concurrent_events[node];
@@ -3662,6 +3664,10 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
         }
         GGML_ASSERT(ok);
 
+        // Increment node counter and expire old cache entries
+        cuda_ctx->node_count++;
+        cuda_ctx->cache().remove_expired(cuda_ctx->node_count);
+
         if (!is_concurrent_event_active) {
             try_launch_concurrent_event(node);
         }
diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu
index 85692d4543..ce30afcc4f 100644
--- a/ggml/src/ggml-cuda/mmq.cu
+++ b/ggml/src/ggml-cuda/mmq.cu
@@ -1,4 +1,5 @@
 #include "common.cuh"
+#include "ggml.h"
 #include "mmq.cuh"
 #include "quantize.cuh"
 #include "mmid.cuh"
@@ -118,25 +119,53 @@ void ggml_cuda_mul_mat_q(
 
     // TODO: tighter pool buffer size vs q8 path
     const bool use_native_mxfp4 = blackwell_mma_available(cc) && src0->type == GGML_TYPE_MXFP4;
 
+    ggml_cuda_cache & cache = ctx.cache();
+
     if (!ids) {
         const size_t nbytes_src1_q8_1 = ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1 +
             get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq);
-        ggml_cuda_pool_alloc<char> src1_q8_1(ctx.pool(), nbytes_src1_q8_1);
-        {
-            const int64_t s11 = src1->nb[1] / ts_src1;
-            const int64_t s12 = src1->nb[2] / ts_src1;
-            const int64_t s13 = src1->nb[3] / ts_src1;
-            if (use_native_mxfp4) {
-                static_assert(sizeof(block_fp4_mmq) == 4 * sizeof(block_q8_1));
-                quantize_mmq_mxfp4_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded,
+        const int layout = use_native_mxfp4 ? -1 : mmq_get_q8_1_ds_layout(src0->type);
+
+        void * src1_ptr = nullptr;
+        ggml_cuda_cache::cache_entry * entry = cache.find(src1, layout);
+        if (entry != nullptr) {
+            GGML_ASSERT(entry->pool_ptrs.size() == 1);
+            size_t expected_size = ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1 + get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq);
+            GGML_ASSERT(entry->pool_ptrs[0].actual_size >= expected_size);
+            src1_ptr = entry->pool_ptrs[0].ptr;
+            GGML_ASSERT(src1_ptr != nullptr);
+        } else {
+
+            ggml_cuda_pool_alloc<char> src1_q8_1(ctx.pool(), nbytes_src1_q8_1);
+            {
+                const int64_t s11 = src1->nb[1] / ts_src1;
+                const int64_t s12 = src1->nb[2] / ts_src1;
+                const int64_t s13 = src1->nb[3] / ts_src1;
+                if (use_native_mxfp4) {
+                    static_assert(sizeof(block_fp4_mmq) == 4 * sizeof(block_q8_1));
+                    quantize_mmq_mxfp4_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded,
+                        ne11, ne12, ne13, stream);
+
+                } else {
+                    quantize_mmq_q8_1_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded,
                     ne11, ne12, ne13, stream);
-
-            } else {
-                quantize_mmq_q8_1_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded,
-                    ne11, ne12, ne13, stream);
+                }
+                CUDA_CHECK(cudaGetLastError());
             }
-            CUDA_CHECK(cudaGetLastError());
+
+            src1_ptr = src1_q8_1.get();
+
+            std::vector<ggml_cuda_cache::pool_alloc> allocs;
+            allocs.emplace_back(ggml_cuda_cache::pool_alloc(std::move(src1_q8_1)));
+
+            cache.add_entry(
+                src1,
+                ggml_cuda_cache::cache_entry{
+                    layout,
+                    std::move(allocs),
+                    ctx.node_count
+                });
         }
 
         // Stride depends on quantization format
@@ -148,7 +177,7 @@ void ggml_cuda_mul_mat_q(
         const int64_t s13 = ne12*s12;
 
         const mmq_args args = {
-            src0_d, src0->type, (const int *) src1_q8_1.ptr, nullptr, nullptr, dst_d,
+            src0_d, src0->type, (const int *) src1_ptr, nullptr, nullptr, dst_d,
             ne00, ne01, ne1, s01, ne11, s1,
             ne02, ne12, s02, s12, s2,
             ne03, ne13, s03, s13, s3,
@@ -165,41 +194,78 @@ void ggml_cuda_mul_mat_q(
     const int64_t ne_get_rows = ne12 * n_expert_used;
     GGML_ASSERT(ne1 == n_expert_used);
 
-    ggml_cuda_pool_alloc<int32_t> ids_src1(ctx.pool(), ne_get_rows);
-    ggml_cuda_pool_alloc<int32_t> ids_dst(ctx.pool(), ne_get_rows);
-    ggml_cuda_pool_alloc<int32_t> expert_bounds(ctx.pool(), ne02 + 1);
+    const int layout = use_native_mxfp4 ? -1 : mmq_get_q8_1_ds_layout(src0->type);
 
-    {
-        GGML_ASSERT(ids->nb[0] == ggml_element_size(ids));
-        const int si1 = ids->nb[1] / ggml_element_size(ids);
-        const int sis1 = nb12 / nb11;
+    void * ids_dst_ptr = nullptr;
+    void * expert_bounds_ptr = nullptr;
+    void * src1_q8_1_ptr = nullptr;
 
-        ggml_cuda_launch_mm_ids_helper((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
-            ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
-        CUDA_CHECK(cudaGetLastError());
-    }
+    ggml_cuda_cache::cache_entry * entry = cache.find(src1, layout);
+    if (entry != nullptr) {
+        GGML_ASSERT(entry->pool_ptrs.size() == 4);
+        ids_dst_ptr = entry->pool_ptrs[1].ptr;
+        expert_bounds_ptr = entry->pool_ptrs[2].ptr;
+        src1_q8_1_ptr = entry->pool_ptrs[3].ptr;
 
-    const size_t nbytes_src1_q8_1 = ne12*n_expert_used*ne10_padded * sizeof(block_q8_1)/QK8_1 +
-        get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq);
-    ggml_cuda_pool_alloc<char> src1_q8_1(ctx.pool(), nbytes_src1_q8_1);
+        size_t expected_q8_1_size = ne12*n_expert_used*ne10_padded * sizeof(block_q8_1)/QK8_1 + get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq);
+        GGML_ASSERT(entry->pool_ptrs[3].actual_size >= expected_q8_1_size);
+    } else {
+        ggml_cuda_pool_alloc<int32_t> ids_src1(ctx.pool(), ne_get_rows);
+        ggml_cuda_pool_alloc<int32_t> ids_dst(ctx.pool(), ne_get_rows);
+        ggml_cuda_pool_alloc<int32_t> expert_bounds(ctx.pool(), ne02 + 1);
 
-    const int64_t ne11_flat = ne12*n_expert_used;
-    const int64_t ne12_flat = 1;
-    const int64_t ne13_flat = 1;
+        {
+            GGML_ASSERT(ids->nb[0] == ggml_element_size(ids));
+            const int si1 = ids->nb[1] / ggml_element_size(ids);
+            const int sis1 = nb12 / nb11;
 
-    {
-        const int64_t s11 = src1->nb[1] / ts_src1;
-        const int64_t s12 = src1->nb[2] / ts_src1;
-        const int64_t s13 = src1->nb[2] / ts_src1;
-
-        if (use_native_mxfp4) {
-            quantize_mmq_mxfp4_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13,
-                ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
-        } else {
-            quantize_mmq_q8_1_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13,
-                ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
+            ggml_cuda_launch_mm_ids_helper((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
+                ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
+            CUDA_CHECK(cudaGetLastError());
         }
-        CUDA_CHECK(cudaGetLastError());
+
+        const size_t nbytes_src1_q8_1 = ne12*n_expert_used*ne10_padded * sizeof(block_q8_1)/QK8_1 +
+            get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq);
+        ggml_cuda_pool_alloc<char> src1_q8_1(ctx.pool(), nbytes_src1_q8_1);
+
+        const int64_t ne11_flat = ne12*n_expert_used;
+        const int64_t ne12_flat = 1;
+        const int64_t ne13_flat = 1;
+
+        {
+            const int64_t s11 = src1->nb[1] / ts_src1;
+            const int64_t s12 = src1->nb[2] / ts_src1;
+            const int64_t s13 = src1->nb[2] / ts_src1;
+
+            if (use_native_mxfp4) {
+                quantize_mmq_mxfp4_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13,
+                    ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
+            } else {
+                quantize_mmq_q8_1_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13,
+                    ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
+            }
+            CUDA_CHECK(cudaGetLastError());
+        }
+
+        void * ids_src1_ptr = ids_src1.get();
+        ids_dst_ptr = ids_dst.get();
+        expert_bounds_ptr = expert_bounds.get();
+        src1_q8_1_ptr = src1_q8_1.get();
+
+        std::vector<ggml_cuda_cache::pool_alloc> allocs;
+        // Store in allocation order; custom destructor will free in reverse (LIFO)
+        allocs.emplace_back(ggml_cuda_cache::pool_alloc(std::move(ids_src1)));
+        allocs.emplace_back(ggml_cuda_cache::pool_alloc(std::move(ids_dst)));
+        allocs.emplace_back(ggml_cuda_cache::pool_alloc(std::move(expert_bounds)));
+        allocs.emplace_back(ggml_cuda_cache::pool_alloc(std::move(src1_q8_1)));
+
+        cache.add_entry(
+            src1,
+            ggml_cuda_cache::cache_entry{
+                layout,
+                std::move(allocs),
+                ctx.node_count
+            });
     }
 
     const int64_t s12 = use_native_mxfp4 ? ne11 * ne10_padded * sizeof(block_fp4_mmq) / (8 * QK_MXFP4 * sizeof(int)) :
@@ -208,7 +274,7 @@ void ggml_cuda_mul_mat_q(
 
     // Note that ne02 is used instead of ne12 because the number of y channels determines the z dimension of the CUDA grid.
     const mmq_args args = {
-        src0_d, src0->type, (const int *) src1_q8_1.get(), ids_dst.get(), expert_bounds.get(), dst_d,
+        src0_d, src0->type, (const int *) src1_q8_1_ptr, (int32_t *) ids_dst_ptr, (int32_t *) expert_bounds_ptr, dst_d,
         ne00, ne01, ne_get_rows, s01, ne_get_rows, s1,
         ne02, ne02, s02, s12, s2,
         ne03, ne13, s03, s13, s3,
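Note on the caching scheme introduced by this patch: ggml_backend_cuda_context now owns one ggml_cuda_cache per (device, stream), and clear_cache() is called at the start of every graph evaluation. ggml_cuda_mul_mat_q keys entries on the src1 tensor pointer plus the q8_1 data layout, so when the same activations feed several MMQ nodes the quantized buffers (and, in the MoE path, the ids/expert-bounds buffers) are reused instead of being recomputed. Entries are appended to and evicted from the back of a vector and expire 10 nodes after creation, which keeps frees in reverse allocation order and therefore compatible with the LIFO contract of ggml_cuda_pool.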
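The following is a minimal, self-contained sketch (plain C++, no CUDA, no ggml headers) of the stack-like cache behaviour described above, useful for reasoning about the LIFO and expiry rules in isolation. mock_pool, mock_cache and every other name in it are hypothetical stand-ins for illustration only, not part of the ggml API.

// sketch.cpp: stack-like cache with LIFO frees and a 10-node TTL (illustrative only)
#include <cassert>
#include <cstddef>
#include <cstdio>
#include <new>
#include <utility>
#include <vector>

struct mock_pool {
    std::vector<void *> live; // allocation stack: frees must arrive in reverse order
    void * alloc(std::size_t n) {
        void * p = ::operator new(n);
        live.push_back(p);
        return p;
    }
    void free(void * p) {
        assert(!live.empty() && live.back() == p && "pool expects LIFO frees");
        live.pop_back();
        ::operator delete(p);
    }
};

struct mock_cache {
    struct entry {
        const void *        node;   // key part 1: source tensor of the quantized data
        int                 layout; // key part 2: quantization layout
        std::size_t         ttl;    // node counter value when the entry was created
        std::vector<void *> ptrs;   // pool allocations, stored in allocation order
    };

    mock_pool &        pool;
    std::vector<entry> entries;

    explicit mock_cache(mock_pool & p) : pool(p) {}

    void * find(const void * node, int layout) {
        for (auto & e : entries) {
            if (e.node == node && e.layout == layout) {
                return e.ptrs.back();
            }
        }
        return nullptr;
    }

    void add(const void * node, int layout, std::size_t node_count, std::vector<void *> ptrs) {
        entries.push_back({node, layout, node_count, std::move(ptrs)});
    }

    // Entries live for at most 10 nodes. Eviction pops the newest entry first and frees its
    // allocations newest-first, so the pool always sees frees in reverse allocation order.
    void remove_expired(std::size_t node_count) {
        while (!entries.empty() && entries.back().ttl + 10 <= node_count) {
            entry e = std::move(entries.back());
            entries.pop_back();
            for (auto it = e.ptrs.rbegin(); it != e.ptrs.rend(); ++it) {
                pool.free(*it);
            }
        }
    }
};

int main() {
    mock_pool  pool;
    mock_cache cache(pool);
    int        node_a = 0; // stands in for a ggml_tensor *

    // cache miss: "quantize" (omitted) into a fresh pool buffer and remember it
    if (cache.find(&node_a, /*layout=*/1) == nullptr) {
        cache.add(&node_a, /*layout=*/1, /*node_count=*/0, { pool.alloc(256) });
    }
    // a later node reusing the same src1: cache hit, no re-quantization needed
    assert(cache.find(&node_a, 1) != nullptr);

    cache.remove_expired(/*node_count=*/11); // entry created at node 0 has now expired
    assert(cache.find(&node_a, 1) == nullptr);
    std::printf("ok\n");
    return 0;
}

Compiling and running the sketch prints "ok"; the assert in mock_pool::free would fire if eviction ever released buffers out of allocation order, which is the invariant the real cache relies on when handing memory back to ggml_cuda_pool.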