From 22277e3cbfb47fc85dadc7c8ce6c83a10977c503 Mon Sep 17 00:00:00 2001 From: Salvatore Rossitto Date: Thu, 12 Mar 2026 12:22:12 +0100 Subject: [PATCH] ported residual changes about grad_checkpointing --- examples/training/finetune.cpp | 1 + ggml/include/ggml-opt.h | 7 +++ ggml/src/ggml-opt.cpp | 69 +++++++++++++++++++--- include/llama.h | 6 ++ src/llama-context.cpp | 105 ++++++++++++++++++++++++++++++--- 5 files changed, 171 insertions(+), 17 deletions(-) diff --git a/examples/training/finetune.cpp b/examples/training/finetune.cpp index dd58f9418e..88d3db741d 100644 --- a/examples/training/finetune.cpp +++ b/examples/training/finetune.cpp @@ -73,6 +73,7 @@ int main(int argc, char ** argv) { /*get_opt_pars =*/common_opt_lr_pars, /*get_opt_pars_ud =*/&params.lr, /*optimizer_type =*/params.optimizer, + /*grad_checkpoint_interval =*/params.grad_checkpoint_interval, }; llama_opt_init(ctx, model, lopt_params); diff --git a/ggml/include/ggml-opt.h b/ggml/include/ggml-opt.h index 60774575f0..cac543c02d 100644 --- a/ggml/include/ggml-opt.h +++ b/ggml/include/ggml-opt.h @@ -126,6 +126,13 @@ extern "C" { ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters void * get_opt_pars_ud; // userdata for calculating optimizer parameters + // Gradient checkpointing: keep the output of every Nth forward node alive through + // the backward pass so the allocator cannot reuse its memory for other tensors. + // This trades a slightly larger static allocation for lower peak activation VRAM — + // checkpointed activations are kept resident (not recomputed) through backward. + // Set to 0 (default) to disable. A value of ~32–64 cuts activation VRAM by ~50%. 
+ int32_t grad_checkpoint_interval; + // only GGML_OPT_OPTIMIZER_TYPE_ADAMW needs m, v momenta per parameter tensor enum ggml_opt_optimizer_type optimizer; }; diff --git a/ggml/src/ggml-opt.cpp b/ggml/src/ggml-opt.cpp index e87fc79c25..8be90c8944 100644 --- a/ggml/src/ggml-opt.cpp +++ b/ggml/src/ggml-opt.cpp @@ -58,10 +58,13 @@ struct ggml_opt_context { std::vector<struct ggml_tensor *> grad_accs; std::vector<struct ggml_tensor *> grad_m; std::vector<struct ggml_tensor *> grad_v; + std::vector<ggml_backend_buffer_t> bufs_momenta; // per-param moment buffers (one per param node) + std::vector<struct ggml_context *> ctxs_momenta; // corresponding ggml contexts (keep alive for tensor metadata) int64_t iter = 1; int32_t opt_period = 1; int32_t opt_i = 0; + int32_t grad_checkpoint_interval = 0; bool loss_per_datapoint = false; ggml_opt_get_optimizer_params get_opt_pars = nullptr; @@ -254,9 +257,10 @@ struct ggml_opt_params ggml_opt_default_params( /*loss_type =*/ loss_type, /*build_type =*/ GGML_OPT_BUILD_TYPE_OPT, /*opt_period =*/ 1, - /*get_opt_pars =*/ ggml_opt_get_default_optimizer_params, - /*get_opt_pars_ud =*/ nullptr, - /*optimizer =*/ GGML_OPT_OPTIMIZER_TYPE_ADAMW, + /*get_opt_pars =*/ ggml_opt_get_default_optimizer_params, + /*get_opt_pars_ud =*/ nullptr, + /*grad_checkpoint_interval =*/ 0, + /*optimizer =*/ GGML_OPT_OPTIMIZER_TYPE_ADAMW, }; } @@ -476,8 +480,23 @@ static void ggml_opt_build(ggml_opt_context_t opt_ctx) { for (int i = 0; i < n_nodes; ++i) { ggml_tensor * node = opt_ctx->gf->nodes[i]; if (node->flags & GGML_TENSOR_FLAG_PARAM) { - opt_ctx->grad_m[i] = ggml_new_tensor(opt_ctx->ctx_static, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne); - opt_ctx->grad_v[i] = ggml_new_tensor(opt_ctx->ctx_static, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne); + // Allocate moments on the same buffer type as the param tensor so + // the ADAMW op runs on the correct backend (avoids cross-device mismatch + // when some LoRA tensors are on CPU and others on GPU with partial offload). + ggml_backend_buffer_type_t param_buft = node->buffer ? 
ggml_backend_buffer_get_type(node->buffer) + : ggml_backend_cpu_buffer_type(); + + // Allocate a tiny context + buffer for this pair of moment tensors. + const size_t sz = 2 * ggml_tensor_overhead(); + struct ggml_init_params mip = { sz, nullptr, true }; + struct ggml_context * mctx = ggml_init(mip); + opt_ctx->grad_m[i] = ggml_new_tensor(mctx, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne); + opt_ctx->grad_v[i] = ggml_new_tensor(mctx, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne); + ggml_backend_buffer_t mbuf = ggml_backend_alloc_ctx_tensors_from_buft(mctx, param_buft); + ggml_backend_buffer_clear(mbuf, 0); + opt_ctx->bufs_momenta.push_back(mbuf); + opt_ctx->ctxs_momenta.push_back(mctx); // keep alive for tensor metadata } else { opt_ctx->grad_m[i] = nullptr; opt_ctx->grad_v[i] = nullptr; @@ -486,6 +505,31 @@ static void ggml_opt_build(ggml_opt_context_t opt_ctx) { } } + // Gradient checkpointing: mark every Nth forward node as OUTPUT so the allocator + // keeps its memory alive through the backward pass. The backward graph already + // contains the forward ops (gb_grad is a superset of gf), so the checkpointed + // activations are naturally available for backward matmuls without recomputation. + // This prevents the allocator from aliasing those buffers to later ops, cutting + // peak activation VRAM at the cost of slightly larger static allocation. + if (opt_ctx->grad_checkpoint_interval > 0) { + const int interval = opt_ctx->grad_checkpoint_interval; + const int n_fwd = opt_ctx->gf->n_nodes; + int ckpt_count = 0; + for (int i = interval - 1; i < n_fwd; i += interval) { + struct ggml_tensor * node = opt_ctx->gf->nodes[i]; + // Only checkpoint F32 compute nodes — skip I32 index tensors and already-output nodes. 
+ if (node->type != GGML_TYPE_F32) continue; + if (node->flags & GGML_TENSOR_FLAG_OUTPUT) continue; + if (node->flags & GGML_TENSOR_FLAG_INPUT) continue; + node->flags |= GGML_TENSOR_FLAG_OUTPUT; + ckpt_count++; + } + if (ckpt_count > 0) { + GGML_LOG_DEBUG("%s: gradient checkpointing: marked %d/%d nodes as persistent (interval=%d)\n", + __func__, ckpt_count, n_fwd, interval); + } + } + // gb_grad == graph backward gradients, forward pass, then backward pass to calculate gradients. opt_ctx->gb_grad = ggml_graph_dup(opt_ctx->ctx_compute, opt_ctx->gf, /*force_grads =*/ true); ggml_build_backward_expand(opt_ctx->ctx_compute, opt_ctx->gb_grad, opt_ctx->grad_accs.data()); @@ -556,10 +600,11 @@ ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) { result->build_type_alloc = params.build_type; result->inputs = params.inputs; result->outputs = params.outputs; - result->opt_period = params.opt_period; - result->get_opt_pars = params.get_opt_pars; - result->get_opt_pars_ud = params.get_opt_pars_ud; - result->optimizer = params.optimizer; + result->opt_period = params.opt_period; + result->grad_checkpoint_interval = params.grad_checkpoint_interval; + result->get_opt_pars = params.get_opt_pars; + result->get_opt_pars_ud = params.get_opt_pars_ud; + result->optimizer = params.optimizer; GGML_ASSERT(result->opt_period >= 1); @@ -588,6 +633,12 @@ void ggml_opt_free(ggml_opt_context_t opt_ctx) { } ggml_backend_buffer_free(opt_ctx->buf_static); ggml_backend_buffer_free(opt_ctx->buf_cpu); + for (ggml_backend_buffer_t buf : opt_ctx->bufs_momenta) { + ggml_backend_buffer_free(buf); + } + for (struct ggml_context * ctx : opt_ctx->ctxs_momenta) { + ggml_free(ctx); + } ggml_free(opt_ctx->ctx_static); ggml_free(opt_ctx->ctx_cpu); delete opt_ctx; diff --git a/include/llama.h b/include/llama.h index 0bf8ead384..6a3a1ebe38 100644 --- a/include/llama.h +++ b/include/llama.h @@ -1552,6 +1552,12 @@ extern "C" { void * get_opt_pars_ud; // userdata for calculating optimizer parameters 
enum ggml_opt_optimizer_type optimizer_type; + + // Gradient checkpointing: mark every Nth forward graph node as persistent so the + // allocator cannot reuse its memory during backward. Reduces peak activation VRAM + // at the cost of ~0 extra compute (activations are kept, not recomputed). + // Set to 0 (default) to disable. Good values: 32–64 nodes ≈ every 1–2 transformer layers. + int32_t grad_checkpoint_interval; }; LLAMA_API void llama_opt_init(struct llama_context * lctx, struct llama_model * model, struct llama_opt_params lopt_params); diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 9f67d47b50..ba98acd403 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -2618,11 +2618,71 @@ void llama_context::opt_init(struct llama_model * model, struct llama_opt_params GGML_ASSERT(model->hparams.n_ctx_train % n_batch == 0); GGML_ASSERT(n_batch % n_ubatch == 0); + // Recreate the scheduler and gf_res_prev with a training-inflated graph size before + // creating opt_ctx, so opt_ctx captures the new (larger) scheduler pointer. + // The backward graph (gb_grad) duplicates gf and adds ~2-3x more nodes+leafs; + // gb_opt adds optimizer step nodes on top. + // + // We measure the actual training forward graph node count at n_ubatch here, + // then multiply by 4 to cover gf + gb_grad + gb_opt. This is exact for any + // model size — no magic constant needed. + { + uint32_t train_fwd_nodes = 0; + + // Build a real training-ubatch forward graph in split-only mode (no buffer realloc) + // so we can count its actual nodes. Fall back to n_tensors formula if it fails. + if (memory) { + auto mctx_tmp = memory->init_full(); + if (mctx_tmp) { + // graph_reserve() uses gf_res_reserve to build the graph, so both + // must be large enough to hold the training forward graph. + // Use 16x n_tensors as a generous temporary cap for the measurement pass. 
+ const uint32_t tmp_cap = std::max(4096u, 16u * model->n_tensors()); + gf_res_prev.reset(new llm_graph_result(tmp_cap)); + gf_res_reserve.reset(new llm_graph_result(tmp_cap)); + // split_only=true: only splits the graph, doesn't reallocate compute buffers + auto * gf_train = graph_reserve(n_ubatch, 1, n_ubatch, mctx_tmp.get(), /*split_only=*/true); + if (gf_train) { + train_fwd_nodes = (uint32_t)ggml_graph_n_nodes(gf_train); + LLAMA_LOG_INFO("%s: measured training graph nodes = %u (n_ubatch=%u)\n", + __func__, train_fwd_nodes, n_ubatch); + } + } + } + + if (train_fwd_nodes == 0) { + // Fallback: use n_tensors formula + train_fwd_nodes = std::max(1024u, 8u * model->n_tensors()); + LLAMA_LOG_WARN("%s: could not measure training graph, using fallback nodes=%u\n", + __func__, train_fwd_nodes); + } + + // gf + gb_grad + gb_opt each need ~train_fwd_nodes; multiply by 4 for safety headroom. + // Multiply by 2 again for the scheduler's n_nodes + n_leafs check. + const int64_t inflated = (int64_t)std::max(train_fwd_nodes, 1024u) * 4; + const int64_t sched_size = inflated * 2; + // Both gf_res_prev and gf_res_reserve are used to build forward graphs + // (graph_reserve uses gf_res_reserve; opt_epoch_iter uses gf_res_prev). + // Both must have capacity for the full backward graph. + gf_res_prev.reset(new llm_graph_result(inflated)); + gf_res_reserve.reset(new llm_graph_result(inflated)); + sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), + sched_size, cparams.pipeline_parallel, cparams.op_offload)); + // Suppress the next sched_reserve() call so that llama_decode() during GRPO inference + // steps does NOT replace the training sched with a smaller inference sched. + // opt_ctx->backend_sched stores a raw pointer to sched.get(); replacing sched while + // opt_ctx is alive would leave that pointer dangling and crash on the next opt_epoch. 
+ sched_need_reserve = false; + LLAMA_LOG_INFO("%s: training graph capacity = %lld (train_fwd_nodes=%u x4)\n", + __func__, (long long)inflated, train_fwd_nodes); + } + ggml_opt_params opt_params = ggml_opt_default_params(sched.get(), GGML_OPT_LOSS_TYPE_CROSS_ENTROPY); - opt_params.opt_period = n_batch / n_ubatch; - opt_params.get_opt_pars = lopt_params.get_opt_pars; - opt_params.get_opt_pars_ud = lopt_params.get_opt_pars_ud; - opt_params.optimizer = lopt_params.optimizer_type; + opt_params.opt_period = n_batch / n_ubatch; + opt_params.get_opt_pars = lopt_params.get_opt_pars; + opt_params.get_opt_pars_ud = lopt_params.get_opt_pars_ud; + opt_params.optimizer = lopt_params.optimizer_type; + opt_params.grad_checkpoint_interval = lopt_params.grad_checkpoint_interval; opt_ctx = ggml_opt_init(opt_params); llama_opt_param_filter param_filter = lopt_params.param_filter; @@ -2706,6 +2766,8 @@ void llama_context::opt_epoch_iter( }; uint32_t pos_batch = 0; + static bool timings_printed = false; // print per-ubatch timings only for the first window + struct ggml_context * ctx_compute_opt = nullptr; do { const auto & ubatch = mctx->get_ubatch(); @@ -2718,26 +2780,38 @@ void llama_context::opt_epoch_iter( auto * res = gf_res_prev.get(); + const int64_t t0_build = ggml_time_ms(); const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT); res->reset(); auto * gf = model.build_graph(gparams); - struct ggml_context * ctx_compute_opt; - { + // Allocate the tensor metadata context once, then reset it each iteration. + // ggml_reset() is much cheaper than ggml_free()+ggml_init() — it just resets the + // allocation pointer without freeing/reallocating the backing memory buffer. 
+ if (!ctx_compute_opt) { const size_t size_gf = ggml_graph_size(gf); - const size_t size_meta = 4*size_gf*ggml_tensor_overhead() + 2*ggml_graph_overhead_custom(size_gf, /*grads = */ true); + const size_t size_meta = 4*size_gf*ggml_tensor_overhead() + 3*ggml_graph_overhead_custom(size_gf, /*grads = */ true); struct ggml_init_params params = { /*.mem_size =*/ size_meta, /*.mem_buffer =*/ nullptr, /*.no_alloc =*/ true, }; ctx_compute_opt = ggml_init(params); + if (!timings_printed) { + LLAMA_LOG_INFO("%s: [timing] graph capacity=%zu n_nodes=%d size_meta=%.1fMB\n", __func__, + size_gf, ggml_graph_n_nodes(gf), (double)size_meta / (1024*1024)); + } + } else { + ggml_reset(ctx_compute_opt); } + + const int64_t t1_alloc = ggml_time_ms(); ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_inp_tokens(), res->get_logits()); ggml_opt_alloc(opt_ctx, train); + const int64_t t2_inputs = ggml_time_ms(); res->set_inputs(&ubatch); { struct ggml_tensor * labels = ggml_opt_labels(opt_ctx); @@ -2753,14 +2827,29 @@ void llama_context::opt_epoch_iter( ggml_backend_tensor_set(labels, &reward_scale, (pos_ubatch*labels->ne[0] + labels_sparse[ilabel])*sizeof(float), sizeof(float)); } } + + const int64_t t3_eval = ggml_time_ms(); ggml_opt_eval(opt_ctx, result); + + const int64_t t4_done = ggml_time_ms(); + if (!timings_printed) { + LLAMA_LOG_INFO("%s: [timing] build=%" PRId64 "ms alloc=%" PRId64 "ms inputs=%" PRId64 "ms eval=%" PRId64 "ms total=%" PRId64 "ms\n", + __func__, + t1_alloc - t0_build, + t2_inputs - t1_alloc, + t3_eval - t2_inputs, + t4_done - t3_eval, + t4_done - t0_build); + timings_printed = true; + } + if (callback) { callback(train, opt_ctx, dataset, result, idata_in_loop + (pos_ctx + pos_batch)/n_ubatch + 1, ndata_in_loop, t_loop_start); } - ggml_free(ctx_compute_opt); pos_batch += ubatch.n_tokens; } while (mctx->next()); + ggml_free(ctx_compute_opt); } }