CUDA: cache intermediate tensors
This commit is contained in: parent c8a3798041, commit 23d04b313c
@@ -5,6 +5,7 @@
#include "ggml-cuda.h"

#include <cstdint>
#include <limits>
#include <memory>

#if defined(GGML_USE_HIP)
@@ -1218,6 +1219,104 @@ struct ggml_cuda_stream_context {
    }
};

// Cache to extend the lifetimes of ggml_cuda_pool_alloc allocations. ggml_cuda_pool expects memory
// to be allocated and freed in LIFO order, hence this cache works like a stack.
struct ggml_cuda_cache {
    struct pool_alloc {
        ggml_cuda_pool * pool;
        void * ptr;
        size_t actual_size;

        pool_alloc() : pool{nullptr}, ptr{nullptr}, actual_size{0} {}

        // Takes ownership of the buffer held by a ggml_cuda_pool_alloc so that it is not
        // returned to the pool when the original RAII wrapper goes out of scope.
        template<typename T>
        pool_alloc(ggml_cuda_pool_alloc<T> && other) {
            pool        = other.pool;
            ptr         = (void *) other.ptr;
            actual_size = other.actual_size;

            other.ptr         = nullptr;
            other.pool        = nullptr;
            other.actual_size = 0;
        }

        pool_alloc(pool_alloc && other) {
            pool        = other.pool;
            ptr         = (void *) other.ptr;
            actual_size = other.actual_size;

            other.ptr         = nullptr;
            other.pool        = nullptr;
            other.actual_size = 0;
        }

        ~pool_alloc() {
            if (ptr != nullptr) {
                pool->free(ptr, actual_size);
            }
        }
    };

    struct cache_entry {
        int layout; // mmq_q8_1_ds_layout value
        std::vector<pool_alloc> pool_ptrs;
        size_t ttl_nodes{};

        cache_entry() = default;

        cache_entry(cache_entry && other) = default;
        cache_entry & operator=(cache_entry && other) = default;

        cache_entry(const cache_entry &) = delete;
        cache_entry & operator=(const cache_entry &) = delete;

        ~cache_entry() {
            // Free the pool allocations in reverse order (LIFO).
            while (!pool_ptrs.empty()) {
                pool_ptrs.pop_back();
            }
        }
    };

    void clear_cache() {
        remove_expired(std::numeric_limits<size_t>::max());
        entries.clear();
    }

    void remove_expired(size_t node_count) {
        // Maximum lifetime of a cache entry: 10 nodes after the node that created it.
        while (!entries.empty() && entries.back().second.ttl_nodes + 10 <= node_count) {
            entries.pop_back();
        }
    }

    cache_entry * find(const ggml_tensor * node, int layout) {
        for (auto & entry : entries) {
            if (entry.first == node && entry.second.layout == layout) {
                return &entry.second;
            }
        }
        return nullptr;
    }

    ~ggml_cuda_cache() {
        while (!entries.empty()) {
            entries.pop_back();
        }
    }

    void add_entry(const ggml_tensor * node, cache_entry && entry) {
        entries.emplace_back(node, std::move(entry));
    }

    std::vector<std::pair<const ggml_tensor *, cache_entry>> entries;
};
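The usage pattern this struct is built for, distilled from the mmq.cu changes later in this commit, is: look up the tensor by its (pointer, layout) key, reuse the cached buffer on a hit, otherwise allocate from the pool and move ownership into a new entry. A minimal sketch, assuming the surrounding ggml CUDA types; the helper name and parameter list are illustrative and not part of the commit:

// Sketch only: a hypothetical helper showing the intended find / add_entry flow.
static void * cache_get_or_alloc(ggml_cuda_cache & cache, ggml_cuda_pool & pool,
        const ggml_tensor * key, int layout, size_t buf_size, size_t node_count) {
    // Hit: reuse the buffer cached for this (tensor, layout) pair.
    if (ggml_cuda_cache::cache_entry * entry = cache.find(key, layout)) {
        GGML_ASSERT(entry->pool_ptrs[0].actual_size >= buf_size);
        return entry->pool_ptrs[0].ptr;
    }

    // Miss: allocate from the pool and hand ownership to the cache, tagging the entry
    // with the current node counter so that it expires 10 nodes later.
    ggml_cuda_pool_alloc<char> buf(pool, buf_size);
    void * ptr = buf.get();

    std::vector<ggml_cuda_cache::pool_alloc> allocs;
    allocs.emplace_back(ggml_cuda_cache::pool_alloc(std::move(buf)));
    cache.add_entry(key, ggml_cuda_cache::cache_entry{layout, std::move(allocs), node_count});

    return ptr;
}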

struct ggml_backend_cuda_context {
    int device;
    std::string name;
@@ -1229,6 +1328,7 @@ struct ggml_backend_cuda_context {
    std::unique_ptr<ggml_cuda_graph> cuda_graph;

    int curr_stream_no = 0;
    size_t node_count = 0;

    explicit ggml_backend_cuda_context(int device) :
        device(device),
@@ -1266,6 +1366,7 @@

    // pool
    std::unique_ptr<ggml_cuda_pool> pools[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS];
    std::unique_ptr<ggml_cuda_cache> caches[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS]{{}};

    static std::unique_ptr<ggml_cuda_pool> new_pool_for_device(int device, int stream_no);
@@ -1276,6 +1377,27 @@
        return *pools[device][curr_stream_no];
    }

    ggml_cuda_cache & cache(int device, int stream) {
        if (caches[device][stream] == nullptr) {
            caches[device][stream] = std::unique_ptr<ggml_cuda_cache>(new ggml_cuda_cache());
        }
        return *caches[device][stream];
    }

    ggml_cuda_cache & cache() {
        return cache(device, curr_stream_no);
    }

    void clear_cache() {
        for (int i = 0; i < GGML_CUDA_MAX_DEVICES; ++i) {
            for (int j = 0; j < GGML_CUDA_MAX_STREAMS; ++j) {
                if (caches[i][j]) {
                    caches[i][j]->clear_cache();
                }
            }
        }
    }

    ggml_cuda_pool & pool() {
        return pool(device);
    }
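The time-to-live bookkeeping is counted in graph nodes: an entry remembers the node counter at creation time, remove_expired() evicts it once 10 further nodes have been processed, and clear_cache() passes the maximum counter value so that everything is evicted. A small worked sketch with illustrative numbers (not taken from the commit):

// Illustrative values only; the key stands in for a real ggml_tensor pointer.
ggml_cuda_cache cache;
const ggml_tensor * key = nullptr;

cache.add_entry(key, ggml_cuda_cache::cache_entry{/*layout=*/0, /*pool_ptrs=*/{}, /*ttl_nodes=*/42});

cache.remove_expired(51); // 42 + 10 <= 51 is false -> entry kept
cache.remove_expired(52); // 42 + 10 <= 52 is true  -> entry popped, its pool buffers (if any) are freed
cache.clear_cache();      // remove_expired(std::numeric_limits<size_t>::max()) followed by entries.clear()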
@@ -3232,6 +3232,8 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
    ggml_cuda_concurrent_event * concurrent_event = nullptr;
    bool should_launch_concurrent_events = false;

    cuda_ctx->clear_cache();

    const auto try_launch_concurrent_event = [&](const ggml_tensor * node) {
        if (stream_ctx.concurrent_events.find(node) != stream_ctx.concurrent_events.end()) {
            concurrent_event = &stream_ctx.concurrent_events[node];
@@ -3662,6 +3664,10 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
            }
            GGML_ASSERT(ok);

            // Increment the node counter and expire old cache entries.
            cuda_ctx->node_count++;
            cuda_ctx->cache().remove_expired(cuda_ctx->node_count);

            if (!is_concurrent_event_active) {
                try_launch_concurrent_event(node);
            }
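Taken together, the two hooks above bracket the node loop of evaluate_and_capture_cuda_graph: the cache is emptied before a graph is replayed, and after each node the counter advances and stale entries are evicted. A heavily simplified sketch of that control flow, assuming the usual cgraph argument of that function; the real loop also handles CUDA graph capture, op fusion and concurrent events:

// Simplified sketch; graph capture, fusion and concurrent-event handling are omitted.
cuda_ctx->clear_cache();                                        // return all cached buffers to their pools

for (int i = 0; i < cgraph->n_nodes; i++) {
    ggml_tensor * node = cgraph->nodes[i];

    const bool ok = ggml_cuda_compute_forward(*cuda_ctx, node); // may reuse a cached quantized src1 buffer
    GGML_ASSERT(ok);

    cuda_ctx->node_count++;                                     // monotonically increasing node counter
    cuda_ctx->cache().remove_expired(cuda_ctx->node_count);     // evict entries older than 10 nodes
}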
@@ -1,4 +1,5 @@
#include "common.cuh"
#include "ggml.h"
#include "mmq.cuh"
#include "quantize.cuh"
#include "mmid.cuh"
@@ -118,25 +119,53 @@ void ggml_cuda_mul_mat_q(
    // TODO: tighter pool buffer size vs q8 path
    const bool use_native_mxfp4 = blackwell_mma_available(cc) && src0->type == GGML_TYPE_MXFP4;

    ggml_cuda_cache & cache = ctx.cache();

    if (!ids) {
        const size_t nbytes_src1_q8_1 = ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1 +
            get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq);
        const int layout = use_native_mxfp4 ? -1 : mmq_get_q8_1_ds_layout(src0->type);

        void * src1_ptr = nullptr;
        ggml_cuda_cache::cache_entry * entry = cache.find(src1, layout);
        if (entry != nullptr) {
            // Cache hit: reuse the quantized src1 buffer from a previous node.
            GGML_ASSERT(entry->pool_ptrs.size() == 1);
            size_t expected_size = ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1 + get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq);
            GGML_ASSERT(entry->pool_ptrs[0].actual_size >= expected_size);
            src1_ptr = entry->pool_ptrs[0].ptr;
            GGML_ASSERT(src1_ptr != nullptr);
        } else {
            // Cache miss: quantize src1 as before and hand the buffer over to the cache.
            ggml_cuda_pool_alloc<char> src1_q8_1(ctx.pool(), nbytes_src1_q8_1);
            {
                const int64_t s11 = src1->nb[1] / ts_src1;
                const int64_t s12 = src1->nb[2] / ts_src1;
                const int64_t s13 = src1->nb[3] / ts_src1;
                if (use_native_mxfp4) {
                    static_assert(sizeof(block_fp4_mmq) == 4 * sizeof(block_q8_1));
                    quantize_mmq_mxfp4_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded,
                        ne11, ne12, ne13, stream);
                } else {
                    quantize_mmq_q8_1_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded,
                        ne11, ne12, ne13, stream);
                }
                CUDA_CHECK(cudaGetLastError());
            }

            src1_ptr = src1_q8_1.get();

            std::vector<ggml_cuda_cache::pool_alloc> allocs;
            allocs.emplace_back(ggml_cuda_cache::pool_alloc(std::move(src1_q8_1)));

            cache.add_entry(
                src1,
                ggml_cuda_cache::cache_entry{
                    layout,
                    std::move(allocs),
                    ctx.node_count
                });
        }

        // Stride depends on quantization format
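The q8_1 buffer-size expression above appears three times in this commit (the allocation, the hit check here, and the hit check on the MoE path below). A hypothetical helper, not part of the commit, that makes the shared shape of those expressions explicit:

// Hypothetical helper (illustration only): size of the quantized src1 buffer used by MMQ.
static size_t mmq_q8_1_buf_size(const int64_t nrows, const int64_t ne10_padded, const int cc) {
    return nrows*ne10_padded * sizeof(block_q8_1)/QK8_1 + get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq);
}
// Non-MoE path: nrows = ne13*ne12*ne11.  MoE path: nrows = ne12*n_expert_used.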
@@ -148,7 +177,7 @@ void ggml_cuda_mul_mat_q(
        const int64_t s13 = ne12*s12;

        const mmq_args args = {
            src0_d, src0->type, (const int *) src1_ptr, nullptr, nullptr, dst_d, // previously: (const int *) src1_q8_1.ptr
            ne00, ne01, ne1, s01, ne11, s1,
            ne02, ne12, s02, s12, s2,
            ne03, ne13, s03, s13, s3,
@@ -165,41 +194,78 @@ void ggml_cuda_mul_mat_q(
    const int64_t ne_get_rows = ne12 * n_expert_used;
    GGML_ASSERT(ne1 == n_expert_used);

    const int layout = use_native_mxfp4 ? -1 : mmq_get_q8_1_ds_layout(src0->type);

    void * ids_dst_ptr = nullptr;
    void * expert_bounds_ptr = nullptr;
    void * src1_q8_1_ptr = nullptr;

    ggml_cuda_cache::cache_entry * entry = cache.find(src1, layout);
    if (entry != nullptr) {
        // Cache hit: reuse the row mapping, expert bounds and quantized src1 buffers.
        GGML_ASSERT(entry->pool_ptrs.size() == 4);
        ids_dst_ptr       = entry->pool_ptrs[1].ptr;
        expert_bounds_ptr = entry->pool_ptrs[2].ptr;
        src1_q8_1_ptr     = entry->pool_ptrs[3].ptr;

        size_t expected_q8_1_size = ne12*n_expert_used*ne10_padded * sizeof(block_q8_1)/QK8_1 + get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq);
        GGML_ASSERT(entry->pool_ptrs[3].actual_size >= expected_q8_1_size);
    } else {
        ggml_cuda_pool_alloc<int32_t> ids_src1(ctx.pool(), ne_get_rows);
        ggml_cuda_pool_alloc<int32_t> ids_dst(ctx.pool(), ne_get_rows);
        ggml_cuda_pool_alloc<int32_t> expert_bounds(ctx.pool(), ne02 + 1);

        {
            GGML_ASSERT(ids->nb[0] == ggml_element_size(ids));
            const int si1 = ids->nb[1] / ggml_element_size(ids);
            const int sis1 = nb12 / nb11;

            ggml_cuda_launch_mm_ids_helper((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
                ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
            CUDA_CHECK(cudaGetLastError());
        }

        const size_t nbytes_src1_q8_1 = ne12*n_expert_used*ne10_padded * sizeof(block_q8_1)/QK8_1 +
            get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq);
        ggml_cuda_pool_alloc<char> src1_q8_1(ctx.pool(), nbytes_src1_q8_1);

        const int64_t ne11_flat = ne12*n_expert_used;
        const int64_t ne12_flat = 1;
        const int64_t ne13_flat = 1;

        {
            const int64_t s11 = src1->nb[1] / ts_src1;
            const int64_t s12 = src1->nb[2] / ts_src1;
            const int64_t s13 = src1->nb[2] / ts_src1;

            if (use_native_mxfp4) {
                quantize_mmq_mxfp4_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13,
                    ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
            } else {
                quantize_mmq_q8_1_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13,
                    ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
            }
            CUDA_CHECK(cudaGetLastError());
        }

        void * ids_src1_ptr = ids_src1.get();
        ids_dst_ptr       = ids_dst.get();
        expert_bounds_ptr = expert_bounds.get();
        src1_q8_1_ptr     = src1_q8_1.get();

        std::vector<ggml_cuda_cache::pool_alloc> allocs;
        // Store in allocation order; the custom destructor frees in reverse (LIFO).
        allocs.emplace_back(ggml_cuda_cache::pool_alloc(std::move(ids_src1)));
        allocs.emplace_back(ggml_cuda_cache::pool_alloc(std::move(ids_dst)));
        allocs.emplace_back(ggml_cuda_cache::pool_alloc(std::move(expert_bounds)));
        allocs.emplace_back(ggml_cuda_cache::pool_alloc(std::move(src1_q8_1)));

        cache.add_entry(
            src1,
            ggml_cuda_cache::cache_entry{
                layout,
                std::move(allocs),
                ctx.node_count
            });
    }

    const int64_t s12 = use_native_mxfp4 ? ne11 * ne10_padded * sizeof(block_fp4_mmq) / (8 * QK_MXFP4 * sizeof(int)) :
@@ -208,7 +274,7 @@ void ggml_cuda_mul_mat_q(

    // Note that ne02 is used instead of ne12 because the number of y channels determines the z dimension of the CUDA grid.
    const mmq_args args = {
        src0_d, src0->type, (const int *) src1_q8_1_ptr, (int32_t *) ids_dst_ptr, (int32_t *) expert_bounds_ptr, dst_d, // previously: src1_q8_1.get(), ids_dst.get(), expert_bounds.get()
        ne00, ne01, ne_get_rows, s01, ne_get_rows, s1,
        ne02, ne02, s02, s12, s2,
        ne03, ne13, s03, s13, s3,
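On the MoE (ids) path the cache entry always holds four pool allocations in a fixed order, and only the last three are read back on a hit, since the ids_src1 mapping is only needed while quantizing. A hypothetical enum, not part of the commit, naming those indices:

// Hypothetical slot names for entry->pool_ptrs on the MoE path; the commit itself uses raw indices.
enum mmq_moe_cache_slot : size_t {
    MMQ_MOE_SLOT_IDS_SRC1      = 0, // written during quantization, not reused on a cache hit
    MMQ_MOE_SLOT_IDS_DST       = 1,
    MMQ_MOE_SLOT_EXPERT_BOUNDS = 2,
    MMQ_MOE_SLOT_SRC1_Q8_1     = 3, // size-checked against the expected q8_1 buffer size on a hit
};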