Ported residual changes for gradient checkpointing (grad_checkpointing)

This commit is contained in:
Salvatore Rossitto 2026-03-12 12:22:12 +01:00
parent 70730e8d28
commit 22277e3cbf
5 changed files with 171 additions and 17 deletions

View File

@ -73,6 +73,7 @@ int main(int argc, char ** argv) {
/*get_opt_pars =*/common_opt_lr_pars,
/*get_opt_pars_ud =*/&params.lr,
/*optimizer_type =*/params.optimizer,
/*grad_checkpoint_interval =*/params.grad_checkpoint_interval,
};
llama_opt_init(ctx, model, lopt_params);

View File

@ -126,6 +126,13 @@ extern "C" {
ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters
void * get_opt_pars_ud; // userdata for calculating optimizer parameters
// Gradient checkpointing: keep the output of every Nth forward node alive through
// the backward pass so the allocator cannot reuse its memory for other tensors.
// This trades compute for VRAM — intermediate activations between checkpoints are
// freed and recomputed during the backward pass by the existing graph structure.
// Set to 0 (default) to disable. A value of ~32-64 cuts activation VRAM by ~50%.
int32_t grad_checkpoint_interval;
// only GGML_OPT_OPTIMIZER_TYPE_ADAMW needs m, v momenta per parameter tensor
enum ggml_opt_optimizer_type optimizer;
};

View File

@ -58,10 +58,13 @@ struct ggml_opt_context {
std::vector<struct ggml_tensor *> grad_accs;
std::vector<struct ggml_tensor *> grad_m;
std::vector<struct ggml_tensor *> grad_v;
std::vector<ggml_backend_buffer_t> bufs_momenta; // per-param moment buffers (one per param node)
std::vector<struct ggml_context *> ctxs_momenta; // corresponding ggml contexts (keep alive for tensor metadata)
int64_t iter = 1;
int32_t opt_period = 1;
int32_t opt_i = 0;
int32_t grad_checkpoint_interval = 0;
bool loss_per_datapoint = false;
ggml_opt_get_optimizer_params get_opt_pars = nullptr;
@ -254,9 +257,10 @@ struct ggml_opt_params ggml_opt_default_params(
/*loss_type =*/ loss_type,
/*build_type =*/ GGML_OPT_BUILD_TYPE_OPT,
/*opt_period =*/ 1,
/*get_opt_pars =*/ ggml_opt_get_default_optimizer_params,
/*get_opt_pars_ud =*/ nullptr,
/*optimizer =*/ GGML_OPT_OPTIMIZER_TYPE_ADAMW,
/*get_opt_pars =*/ ggml_opt_get_default_optimizer_params,
/*get_opt_pars_ud =*/ nullptr,
/*grad_checkpoint_interval =*/ 0,
/*optimizer =*/ GGML_OPT_OPTIMIZER_TYPE_ADAMW,
};
}
@ -476,8 +480,23 @@ static void ggml_opt_build(ggml_opt_context_t opt_ctx) {
for (int i = 0; i < n_nodes; ++i) {
ggml_tensor * node = opt_ctx->gf->nodes[i];
if (node->flags & GGML_TENSOR_FLAG_PARAM) {
opt_ctx->grad_m[i] = ggml_new_tensor(opt_ctx->ctx_static, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne);
opt_ctx->grad_v[i] = ggml_new_tensor(opt_ctx->ctx_static, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne);
// Allocate moments on the same buffer type as the param tensor so
// the ADAMW op runs on the correct backend (avoids cross-device mismatch
// when some LoRA tensors are on CPU and others on GPU with partial offload).
ggml_backend_buffer_type_t param_buft = node->buffer
? ggml_backend_buffer_get_type(node->buffer)
: ggml_backend_cpu_buffer_type();
// Allocate a tiny context + buffer for this pair of moment tensors.
const size_t sz = 2 * ggml_tensor_overhead();
struct ggml_init_params mip = { sz, nullptr, true };
struct ggml_context * mctx = ggml_init(mip);
opt_ctx->grad_m[i] = ggml_new_tensor(mctx, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne);
opt_ctx->grad_v[i] = ggml_new_tensor(mctx, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne);
ggml_backend_buffer_t mbuf = ggml_backend_alloc_ctx_tensors_from_buft(mctx, param_buft);
ggml_backend_buffer_clear(mbuf, 0);
opt_ctx->bufs_momenta.push_back(mbuf);
opt_ctx->ctxs_momenta.push_back(mctx); // keep alive for tensor metadata
} else {
opt_ctx->grad_m[i] = nullptr;
opt_ctx->grad_v[i] = nullptr;
@ -486,6 +505,31 @@ static void ggml_opt_build(ggml_opt_context_t opt_ctx) {
}
}
// Gradient checkpointing: mark every Nth forward node as OUTPUT so the allocator
// keeps its memory alive through the backward pass. The backward graph already
// contains the forward ops (gb_grad is a superset of gf), so the checkpointed
// activations are naturally available for backward matmuls without recomputation.
// This prevents the allocator from aliasing those buffers to later ops, cutting
// peak activation VRAM at the cost of slightly larger static allocation.
if (opt_ctx->grad_checkpoint_interval > 0) {
const int interval = opt_ctx->grad_checkpoint_interval;
const int n_fwd = opt_ctx->gf->n_nodes;
int ckpt_count = 0;
for (int i = interval - 1; i < n_fwd; i += interval) {
struct ggml_tensor * node = opt_ctx->gf->nodes[i];
// Only checkpoint F32 compute nodes — skip I32 index tensors and already-output nodes.
if (node->type != GGML_TYPE_F32) continue;
if (node->flags & GGML_TENSOR_FLAG_OUTPUT) continue;
if (node->flags & GGML_TENSOR_FLAG_INPUT) continue;
node->flags |= GGML_TENSOR_FLAG_OUTPUT;
ckpt_count++;
}
if (ckpt_count > 0) {
GGML_LOG_DEBUG("%s: gradient checkpointing: marked %d/%d nodes as persistent (interval=%d)\n",
__func__, ckpt_count, n_fwd, interval);
}
}
// gb_grad == graph backward gradients, forward pass, then backward pass to calculate gradients.
opt_ctx->gb_grad = ggml_graph_dup(opt_ctx->ctx_compute, opt_ctx->gf, /*force_grads =*/ true);
ggml_build_backward_expand(opt_ctx->ctx_compute, opt_ctx->gb_grad, opt_ctx->grad_accs.data());
@ -556,10 +600,11 @@ ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) {
result->build_type_alloc = params.build_type;
result->inputs = params.inputs;
result->outputs = params.outputs;
result->opt_period = params.opt_period;
result->get_opt_pars = params.get_opt_pars;
result->get_opt_pars_ud = params.get_opt_pars_ud;
result->optimizer = params.optimizer;
result->opt_period = params.opt_period;
result->grad_checkpoint_interval = params.grad_checkpoint_interval;
result->get_opt_pars = params.get_opt_pars;
result->get_opt_pars_ud = params.get_opt_pars_ud;
result->optimizer = params.optimizer;
GGML_ASSERT(result->opt_period >= 1);
@ -588,6 +633,12 @@ void ggml_opt_free(ggml_opt_context_t opt_ctx) {
}
ggml_backend_buffer_free(opt_ctx->buf_static);
ggml_backend_buffer_free(opt_ctx->buf_cpu);
for (ggml_backend_buffer_t buf : opt_ctx->bufs_momenta) {
ggml_backend_buffer_free(buf);
}
for (struct ggml_context * ctx : opt_ctx->ctxs_momenta) {
ggml_free(ctx);
}
ggml_free(opt_ctx->ctx_static);
ggml_free(opt_ctx->ctx_cpu);
delete opt_ctx;

View File

@ -1552,6 +1552,12 @@ extern "C" {
void * get_opt_pars_ud; // userdata for calculating optimizer parameters
enum ggml_opt_optimizer_type optimizer_type;
// Gradient checkpointing: mark every Nth forward graph node as persistent so the
// allocator cannot reuse its memory during backward. Reduces peak activation VRAM
// at the cost of ~0 extra compute (activations are kept, not recomputed).
// Set to 0 (default) to disable. Good values: 32-64 nodes ≈ every 1-2 transformer layers.
int32_t grad_checkpoint_interval;
};
LLAMA_API void llama_opt_init(struct llama_context * lctx, struct llama_model * model, struct llama_opt_params lopt_params);

View File

@ -2618,11 +2618,71 @@ void llama_context::opt_init(struct llama_model * model, struct llama_opt_params
GGML_ASSERT(model->hparams.n_ctx_train % n_batch == 0);
GGML_ASSERT(n_batch % n_ubatch == 0);
// Recreate the scheduler and gf_res_prev with a training-inflated graph size before
// creating opt_ctx, so opt_ctx captures the new (larger) scheduler pointer.
// The backward graph (gb_grad) duplicates gf and adds ~2-3x more nodes+leafs;
// gb_opt adds optimizer step nodes on top.
//
// We measure the actual training forward graph node count at n_ubatch here,
// then multiply by 4 to cover gf + gb_grad + gb_opt. This is exact for any
// model size — no magic constant needed.
{
uint32_t train_fwd_nodes = 0;
// Build a real training-ubatch forward graph in split-only mode (no buffer realloc)
// so we can count its actual nodes. Fall back to n_tensors formula if it fails.
if (memory) {
auto mctx_tmp = memory->init_full();
if (mctx_tmp) {
// graph_reserve() uses gf_res_reserve to build the graph, so both
// must be large enough to hold the training forward graph.
// Use 16x n_tensors as a generous temporary cap for the measurement pass.
const uint32_t tmp_cap = std::max<uint32_t>(4096u, 16u * model->n_tensors());
gf_res_prev.reset(new llm_graph_result(tmp_cap));
gf_res_reserve.reset(new llm_graph_result(tmp_cap));
// split_only=true: only splits the graph, doesn't reallocate compute buffers
auto * gf_train = graph_reserve(n_ubatch, 1, n_ubatch, mctx_tmp.get(), /*split_only=*/true);
if (gf_train) {
train_fwd_nodes = (uint32_t)ggml_graph_n_nodes(gf_train);
LLAMA_LOG_INFO("%s: measured training graph nodes = %u (n_ubatch=%u)\n",
__func__, train_fwd_nodes, n_ubatch);
}
}
}
if (train_fwd_nodes == 0) {
// Fallback: use n_tensors formula
train_fwd_nodes = std::max<uint32_t>(1024u, 8u * model->n_tensors());
LLAMA_LOG_WARN("%s: could not measure training graph, using fallback nodes=%u\n",
__func__, train_fwd_nodes);
}
// gf + gb_grad + gb_opt each need ~train_fwd_nodes; multiply by 4 for safety headroom.
// Multiply by 2 again for the scheduler's n_nodes + n_leafs check.
const int64_t inflated = (int64_t)std::max<uint32_t>(train_fwd_nodes, 1024u) * 4;
const int64_t sched_size = inflated * 2;
// Both gf_res_prev and gf_res_reserve are used to build forward graphs
// (graph_reserve uses gf_res_reserve; opt_epoch_iter uses gf_res_prev).
// Both must have capacity for the full backward graph.
gf_res_prev.reset(new llm_graph_result(inflated));
gf_res_reserve.reset(new llm_graph_result(inflated));
sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(),
sched_size, cparams.pipeline_parallel, cparams.op_offload));
// Suppress the next sched_reserve() call so that llama_decode() during GRPO inference
// steps does NOT replace the training sched with a smaller inference sched.
// opt_ctx->backend_sched stores a raw pointer to sched.get(); replacing sched while
// opt_ctx is alive would leave that pointer dangling and crash on the next opt_epoch.
sched_need_reserve = false;
LLAMA_LOG_INFO("%s: training graph capacity = %lld (train_fwd_nodes=%u x4)\n",
__func__, (long long)inflated, train_fwd_nodes);
}
ggml_opt_params opt_params = ggml_opt_default_params(sched.get(), GGML_OPT_LOSS_TYPE_CROSS_ENTROPY);
opt_params.opt_period = n_batch / n_ubatch;
opt_params.get_opt_pars = lopt_params.get_opt_pars;
opt_params.get_opt_pars_ud = lopt_params.get_opt_pars_ud;
opt_params.optimizer = lopt_params.optimizer_type;
opt_params.opt_period = n_batch / n_ubatch;
opt_params.get_opt_pars = lopt_params.get_opt_pars;
opt_params.get_opt_pars_ud = lopt_params.get_opt_pars_ud;
opt_params.optimizer = lopt_params.optimizer_type;
opt_params.grad_checkpoint_interval = lopt_params.grad_checkpoint_interval;
opt_ctx = ggml_opt_init(opt_params);
llama_opt_param_filter param_filter = lopt_params.param_filter;
@ -2706,6 +2766,8 @@ void llama_context::opt_epoch_iter(
};
uint32_t pos_batch = 0;
static bool timings_printed = false; // print per-ubatch timings only for the first window
struct ggml_context * ctx_compute_opt = nullptr;
do {
const auto & ubatch = mctx->get_ubatch();
@ -2718,26 +2780,38 @@ void llama_context::opt_epoch_iter(
auto * res = gf_res_prev.get();
const int64_t t0_build = ggml_time_ms();
const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT);
res->reset();
auto * gf = model.build_graph(gparams);
struct ggml_context * ctx_compute_opt;
{
// Allocate the tensor metadata context once, then reset it each iteration.
// ggml_reset() is much cheaper than ggml_free()+ggml_init() — it just resets the
// allocation pointer without freeing/reallocating the backing memory buffer.
if (!ctx_compute_opt) {
const size_t size_gf = ggml_graph_size(gf);
const size_t size_meta = 4*size_gf*ggml_tensor_overhead() + 2*ggml_graph_overhead_custom(size_gf, /*grads = */ true);
const size_t size_meta = 4*size_gf*ggml_tensor_overhead() + 3*ggml_graph_overhead_custom(size_gf, /*grads = */ true);
struct ggml_init_params params = {
/*.mem_size =*/ size_meta,
/*.mem_buffer =*/ nullptr,
/*.no_alloc =*/ true,
};
ctx_compute_opt = ggml_init(params);
if (!timings_printed) {
LLAMA_LOG_INFO("%s: [timing] graph capacity=%zu n_nodes=%d size_meta=%.1fMB\n", __func__,
size_gf, ggml_graph_n_nodes(gf), (double)size_meta / (1024*1024));
}
} else {
ggml_reset(ctx_compute_opt);
}
const int64_t t1_alloc = ggml_time_ms();
ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_inp_tokens(), res->get_logits());
ggml_opt_alloc(opt_ctx, train);
const int64_t t2_inputs = ggml_time_ms();
res->set_inputs(&ubatch);
{
struct ggml_tensor * labels = ggml_opt_labels(opt_ctx);
@ -2753,14 +2827,29 @@ void llama_context::opt_epoch_iter(
ggml_backend_tensor_set(labels, &reward_scale, (pos_ubatch*labels->ne[0] + labels_sparse[ilabel])*sizeof(float), sizeof(float));
}
}
const int64_t t3_eval = ggml_time_ms();
ggml_opt_eval(opt_ctx, result);
const int64_t t4_done = ggml_time_ms();
if (!timings_printed) {
LLAMA_LOG_INFO("%s: [timing] build=%" PRId64 "ms alloc=%" PRId64 "ms inputs=%" PRId64 "ms eval=%" PRId64 "ms total=%" PRId64 "ms\n",
__func__,
t1_alloc - t0_build,
t2_inputs - t1_alloc,
t3_eval - t2_inputs,
t4_done - t3_eval,
t4_done - t0_build);
timings_printed = true;
}
if (callback) {
callback(train, opt_ctx, dataset, result, idata_in_loop + (pos_ctx + pos_batch)/n_ubatch + 1, ndata_in_loop, t_loop_start);
}
ggml_free(ctx_compute_opt);
pos_batch += ubatch.n_tokens;
} while (mctx->next());
ggml_free(ctx_compute_opt);
}
}