ported residual changes about grad_checkpointing
This commit is contained in:
parent
70730e8d28
commit
22277e3cbf
|
|
@ -73,6 +73,7 @@ int main(int argc, char ** argv) {
|
|||
/*get_opt_pars =*/common_opt_lr_pars,
|
||||
/*get_opt_pars_ud =*/¶ms.lr,
|
||||
/*optimizer_type =*/params.optimizer,
|
||||
/*grad_checkpoint_interval =*/params.grad_checkpoint_interval,
|
||||
};
|
||||
llama_opt_init(ctx, model, lopt_params);
|
||||
|
||||
|
|
|
|||
|
|
@ -126,6 +126,13 @@ extern "C" {
|
|||
ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters
|
||||
void * get_opt_pars_ud; // userdata for calculating optimizer parameters
|
||||
|
||||
// Gradient checkpointing: keep the output of every Nth forward node alive through
|
||||
// the backward pass so the allocator cannot reuse its memory for other tensors.
|
||||
// This trades compute for VRAM — intermediate activations between checkpoints are
|
||||
// freed and recomputed during the backward pass by the existing graph structure.
|
||||
// Set to 0 (default) to disable. A value of ~32–64 cuts activation VRAM by ~50%.
|
||||
int32_t grad_checkpoint_interval;
|
||||
|
||||
// only GGML_OPT_OPTIMIZER_TYPE_ADAMW needs m, v momenta per parameter tensor
|
||||
enum ggml_opt_optimizer_type optimizer;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -58,10 +58,13 @@ struct ggml_opt_context {
|
|||
std::vector<struct ggml_tensor *> grad_accs;
|
||||
std::vector<struct ggml_tensor *> grad_m;
|
||||
std::vector<struct ggml_tensor *> grad_v;
|
||||
std::vector<ggml_backend_buffer_t> bufs_momenta; // per-param moment buffers (one per param node)
|
||||
std::vector<struct ggml_context *> ctxs_momenta; // corresponding ggml contexts (keep alive for tensor metadata)
|
||||
|
||||
int64_t iter = 1;
|
||||
int32_t opt_period = 1;
|
||||
int32_t opt_i = 0;
|
||||
int32_t grad_checkpoint_interval = 0;
|
||||
bool loss_per_datapoint = false;
|
||||
|
||||
ggml_opt_get_optimizer_params get_opt_pars = nullptr;
|
||||
|
|
@ -254,9 +257,10 @@ struct ggml_opt_params ggml_opt_default_params(
|
|||
/*loss_type =*/ loss_type,
|
||||
/*build_type =*/ GGML_OPT_BUILD_TYPE_OPT,
|
||||
/*opt_period =*/ 1,
|
||||
/*get_opt_pars =*/ ggml_opt_get_default_optimizer_params,
|
||||
/*get_opt_pars_ud =*/ nullptr,
|
||||
/*optimizer =*/ GGML_OPT_OPTIMIZER_TYPE_ADAMW,
|
||||
/*get_opt_pars =*/ ggml_opt_get_default_optimizer_params,
|
||||
/*get_opt_pars_ud =*/ nullptr,
|
||||
/*grad_checkpoint_interval =*/ 0,
|
||||
/*optimizer =*/ GGML_OPT_OPTIMIZER_TYPE_ADAMW,
|
||||
};
|
||||
}
|
||||
|
||||
|
|
@ -476,8 +480,23 @@ static void ggml_opt_build(ggml_opt_context_t opt_ctx) {
|
|||
for (int i = 0; i < n_nodes; ++i) {
|
||||
ggml_tensor * node = opt_ctx->gf->nodes[i];
|
||||
if (node->flags & GGML_TENSOR_FLAG_PARAM) {
|
||||
opt_ctx->grad_m[i] = ggml_new_tensor(opt_ctx->ctx_static, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne);
|
||||
opt_ctx->grad_v[i] = ggml_new_tensor(opt_ctx->ctx_static, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne);
|
||||
// Allocate moments on the same buffer type as the param tensor so
|
||||
// the ADAMW op runs on the correct backend (avoids cross-device mismatch
|
||||
// when some LoRA tensors are on CPU and others on GPU with partial offload).
|
||||
ggml_backend_buffer_type_t param_buft = node->buffer
|
||||
? ggml_backend_buffer_get_type(node->buffer)
|
||||
: ggml_backend_cpu_buffer_type();
|
||||
|
||||
// Allocate a tiny context + buffer for this pair of moment tensors.
|
||||
const size_t sz = 2 * ggml_tensor_overhead();
|
||||
struct ggml_init_params mip = { sz, nullptr, true };
|
||||
struct ggml_context * mctx = ggml_init(mip);
|
||||
opt_ctx->grad_m[i] = ggml_new_tensor(mctx, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne);
|
||||
opt_ctx->grad_v[i] = ggml_new_tensor(mctx, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne);
|
||||
ggml_backend_buffer_t mbuf = ggml_backend_alloc_ctx_tensors_from_buft(mctx, param_buft);
|
||||
ggml_backend_buffer_clear(mbuf, 0);
|
||||
opt_ctx->bufs_momenta.push_back(mbuf);
|
||||
opt_ctx->ctxs_momenta.push_back(mctx); // keep alive for tensor metadata
|
||||
} else {
|
||||
opt_ctx->grad_m[i] = nullptr;
|
||||
opt_ctx->grad_v[i] = nullptr;
|
||||
|
|
@ -486,6 +505,31 @@ static void ggml_opt_build(ggml_opt_context_t opt_ctx) {
|
|||
}
|
||||
}
|
||||
|
||||
// Gradient checkpointing: mark every Nth forward node as OUTPUT so the allocator
|
||||
// keeps its memory alive through the backward pass. The backward graph already
|
||||
// contains the forward ops (gb_grad is a superset of gf), so the checkpointed
|
||||
// activations are naturally available for backward matmuls without recomputation.
|
||||
// This prevents the allocator from aliasing those buffers to later ops, cutting
|
||||
// peak activation VRAM at the cost of slightly larger static allocation.
|
||||
if (opt_ctx->grad_checkpoint_interval > 0) {
|
||||
const int interval = opt_ctx->grad_checkpoint_interval;
|
||||
const int n_fwd = opt_ctx->gf->n_nodes;
|
||||
int ckpt_count = 0;
|
||||
for (int i = interval - 1; i < n_fwd; i += interval) {
|
||||
struct ggml_tensor * node = opt_ctx->gf->nodes[i];
|
||||
// Only checkpoint F32 compute nodes — skip I32 index tensors and already-output nodes.
|
||||
if (node->type != GGML_TYPE_F32) continue;
|
||||
if (node->flags & GGML_TENSOR_FLAG_OUTPUT) continue;
|
||||
if (node->flags & GGML_TENSOR_FLAG_INPUT) continue;
|
||||
node->flags |= GGML_TENSOR_FLAG_OUTPUT;
|
||||
ckpt_count++;
|
||||
}
|
||||
if (ckpt_count > 0) {
|
||||
GGML_LOG_DEBUG("%s: gradient checkpointing: marked %d/%d nodes as persistent (interval=%d)\n",
|
||||
__func__, ckpt_count, n_fwd, interval);
|
||||
}
|
||||
}
|
||||
|
||||
// gb_grad == graph backward gradients, forward pass, then backward pass to calculate gradients.
|
||||
opt_ctx->gb_grad = ggml_graph_dup(opt_ctx->ctx_compute, opt_ctx->gf, /*force_grads =*/ true);
|
||||
ggml_build_backward_expand(opt_ctx->ctx_compute, opt_ctx->gb_grad, opt_ctx->grad_accs.data());
|
||||
|
|
@ -556,10 +600,11 @@ ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) {
|
|||
result->build_type_alloc = params.build_type;
|
||||
result->inputs = params.inputs;
|
||||
result->outputs = params.outputs;
|
||||
result->opt_period = params.opt_period;
|
||||
result->get_opt_pars = params.get_opt_pars;
|
||||
result->get_opt_pars_ud = params.get_opt_pars_ud;
|
||||
result->optimizer = params.optimizer;
|
||||
result->opt_period = params.opt_period;
|
||||
result->grad_checkpoint_interval = params.grad_checkpoint_interval;
|
||||
result->get_opt_pars = params.get_opt_pars;
|
||||
result->get_opt_pars_ud = params.get_opt_pars_ud;
|
||||
result->optimizer = params.optimizer;
|
||||
|
||||
GGML_ASSERT(result->opt_period >= 1);
|
||||
|
||||
|
|
@ -588,6 +633,12 @@ void ggml_opt_free(ggml_opt_context_t opt_ctx) {
|
|||
}
|
||||
ggml_backend_buffer_free(opt_ctx->buf_static);
|
||||
ggml_backend_buffer_free(opt_ctx->buf_cpu);
|
||||
for (ggml_backend_buffer_t buf : opt_ctx->bufs_momenta) {
|
||||
ggml_backend_buffer_free(buf);
|
||||
}
|
||||
for (struct ggml_context * ctx : opt_ctx->ctxs_momenta) {
|
||||
ggml_free(ctx);
|
||||
}
|
||||
ggml_free(opt_ctx->ctx_static);
|
||||
ggml_free(opt_ctx->ctx_cpu);
|
||||
delete opt_ctx;
|
||||
|
|
|
|||
|
|
@ -1552,6 +1552,12 @@ extern "C" {
|
|||
void * get_opt_pars_ud; // userdata for calculating optimizer parameters
|
||||
|
||||
enum ggml_opt_optimizer_type optimizer_type;
|
||||
|
||||
// Gradient checkpointing: mark every Nth forward graph node as persistent so the
|
||||
// allocator cannot reuse its memory during backward. Reduces peak activation VRAM
|
||||
// at the cost of ~0 extra compute (activations are kept, not recomputed).
|
||||
// Set to 0 (default) to disable. Good values: 32–64 nodes ≈ every 1–2 transformer layers.
|
||||
int32_t grad_checkpoint_interval;
|
||||
};
|
||||
|
||||
LLAMA_API void llama_opt_init(struct llama_context * lctx, struct llama_model * model, struct llama_opt_params lopt_params);
|
||||
|
|
|
|||
|
|
@ -2618,11 +2618,71 @@ void llama_context::opt_init(struct llama_model * model, struct llama_opt_params
|
|||
GGML_ASSERT(model->hparams.n_ctx_train % n_batch == 0);
|
||||
GGML_ASSERT(n_batch % n_ubatch == 0);
|
||||
|
||||
// Recreate the scheduler and gf_res_prev with a training-inflated graph size before
|
||||
// creating opt_ctx, so opt_ctx captures the new (larger) scheduler pointer.
|
||||
// The backward graph (gb_grad) duplicates gf and adds ~2-3x more nodes+leafs;
|
||||
// gb_opt adds optimizer step nodes on top.
|
||||
//
|
||||
// We measure the actual training forward graph node count at n_ubatch here,
|
||||
// then multiply by 4 to cover gf + gb_grad + gb_opt. This is exact for any
|
||||
// model size — no magic constant needed.
|
||||
{
|
||||
uint32_t train_fwd_nodes = 0;
|
||||
|
||||
// Build a real training-ubatch forward graph in split-only mode (no buffer realloc)
|
||||
// so we can count its actual nodes. Fall back to n_tensors formula if it fails.
|
||||
if (memory) {
|
||||
auto mctx_tmp = memory->init_full();
|
||||
if (mctx_tmp) {
|
||||
// graph_reserve() uses gf_res_reserve to build the graph, so both
|
||||
// must be large enough to hold the training forward graph.
|
||||
// Use 16x n_tensors as a generous temporary cap for the measurement pass.
|
||||
const uint32_t tmp_cap = std::max<uint32_t>(4096u, 16u * model->n_tensors());
|
||||
gf_res_prev.reset(new llm_graph_result(tmp_cap));
|
||||
gf_res_reserve.reset(new llm_graph_result(tmp_cap));
|
||||
// split_only=true: only splits the graph, doesn't reallocate compute buffers
|
||||
auto * gf_train = graph_reserve(n_ubatch, 1, n_ubatch, mctx_tmp.get(), /*split_only=*/true);
|
||||
if (gf_train) {
|
||||
train_fwd_nodes = (uint32_t)ggml_graph_n_nodes(gf_train);
|
||||
LLAMA_LOG_INFO("%s: measured training graph nodes = %u (n_ubatch=%u)\n",
|
||||
__func__, train_fwd_nodes, n_ubatch);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (train_fwd_nodes == 0) {
|
||||
// Fallback: use n_tensors formula
|
||||
train_fwd_nodes = std::max<uint32_t>(1024u, 8u * model->n_tensors());
|
||||
LLAMA_LOG_WARN("%s: could not measure training graph, using fallback nodes=%u\n",
|
||||
__func__, train_fwd_nodes);
|
||||
}
|
||||
|
||||
// gf + gb_grad + gb_opt each need ~train_fwd_nodes; multiply by 4 for safety headroom.
|
||||
// Multiply by 2 again for the scheduler's n_nodes + n_leafs check.
|
||||
const int64_t inflated = (int64_t)std::max<uint32_t>(train_fwd_nodes, 1024u) * 4;
|
||||
const int64_t sched_size = inflated * 2;
|
||||
// Both gf_res_prev and gf_res_reserve are used to build forward graphs
|
||||
// (graph_reserve uses gf_res_reserve; opt_epoch_iter uses gf_res_prev).
|
||||
// Both must have capacity for the full backward graph.
|
||||
gf_res_prev.reset(new llm_graph_result(inflated));
|
||||
gf_res_reserve.reset(new llm_graph_result(inflated));
|
||||
sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(),
|
||||
sched_size, cparams.pipeline_parallel, cparams.op_offload));
|
||||
// Suppress the next sched_reserve() call so that llama_decode() during GRPO inference
|
||||
// steps does NOT replace the training sched with a smaller inference sched.
|
||||
// opt_ctx->backend_sched stores a raw pointer to sched.get(); replacing sched while
|
||||
// opt_ctx is alive would leave that pointer dangling and crash on the next opt_epoch.
|
||||
sched_need_reserve = false;
|
||||
LLAMA_LOG_INFO("%s: training graph capacity = %lld (train_fwd_nodes=%u x4)\n",
|
||||
__func__, (long long)inflated, train_fwd_nodes);
|
||||
}
|
||||
|
||||
ggml_opt_params opt_params = ggml_opt_default_params(sched.get(), GGML_OPT_LOSS_TYPE_CROSS_ENTROPY);
|
||||
opt_params.opt_period = n_batch / n_ubatch;
|
||||
opt_params.get_opt_pars = lopt_params.get_opt_pars;
|
||||
opt_params.get_opt_pars_ud = lopt_params.get_opt_pars_ud;
|
||||
opt_params.optimizer = lopt_params.optimizer_type;
|
||||
opt_params.opt_period = n_batch / n_ubatch;
|
||||
opt_params.get_opt_pars = lopt_params.get_opt_pars;
|
||||
opt_params.get_opt_pars_ud = lopt_params.get_opt_pars_ud;
|
||||
opt_params.optimizer = lopt_params.optimizer_type;
|
||||
opt_params.grad_checkpoint_interval = lopt_params.grad_checkpoint_interval;
|
||||
opt_ctx = ggml_opt_init(opt_params);
|
||||
|
||||
llama_opt_param_filter param_filter = lopt_params.param_filter;
|
||||
|
|
@ -2706,6 +2766,8 @@ void llama_context::opt_epoch_iter(
|
|||
};
|
||||
|
||||
uint32_t pos_batch = 0;
|
||||
static bool timings_printed = false; // print per-ubatch timings only for the first window
|
||||
struct ggml_context * ctx_compute_opt = nullptr;
|
||||
do {
|
||||
const auto & ubatch = mctx->get_ubatch();
|
||||
|
||||
|
|
@ -2718,26 +2780,38 @@ void llama_context::opt_epoch_iter(
|
|||
|
||||
auto * res = gf_res_prev.get();
|
||||
|
||||
const int64_t t0_build = ggml_time_ms();
|
||||
const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT);
|
||||
|
||||
res->reset();
|
||||
|
||||
auto * gf = model.build_graph(gparams);
|
||||
|
||||
struct ggml_context * ctx_compute_opt;
|
||||
{
|
||||
// Allocate the tensor metadata context once, then reset it each iteration.
|
||||
// ggml_reset() is much cheaper than ggml_free()+ggml_init() — it just resets the
|
||||
// allocation pointer without freeing/reallocating the backing memory buffer.
|
||||
if (!ctx_compute_opt) {
|
||||
const size_t size_gf = ggml_graph_size(gf);
|
||||
const size_t size_meta = 4*size_gf*ggml_tensor_overhead() + 2*ggml_graph_overhead_custom(size_gf, /*grads = */ true);
|
||||
const size_t size_meta = 4*size_gf*ggml_tensor_overhead() + 3*ggml_graph_overhead_custom(size_gf, /*grads = */ true);
|
||||
struct ggml_init_params params = {
|
||||
/*.mem_size =*/ size_meta,
|
||||
/*.mem_buffer =*/ nullptr,
|
||||
/*.no_alloc =*/ true,
|
||||
};
|
||||
ctx_compute_opt = ggml_init(params);
|
||||
if (!timings_printed) {
|
||||
LLAMA_LOG_INFO("%s: [timing] graph capacity=%zu n_nodes=%d size_meta=%.1fMB\n", __func__,
|
||||
size_gf, ggml_graph_n_nodes(gf), (double)size_meta / (1024*1024));
|
||||
}
|
||||
} else {
|
||||
ggml_reset(ctx_compute_opt);
|
||||
}
|
||||
|
||||
const int64_t t1_alloc = ggml_time_ms();
|
||||
ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_inp_tokens(), res->get_logits());
|
||||
ggml_opt_alloc(opt_ctx, train);
|
||||
|
||||
const int64_t t2_inputs = ggml_time_ms();
|
||||
res->set_inputs(&ubatch);
|
||||
{
|
||||
struct ggml_tensor * labels = ggml_opt_labels(opt_ctx);
|
||||
|
|
@ -2753,14 +2827,29 @@ void llama_context::opt_epoch_iter(
|
|||
ggml_backend_tensor_set(labels, &reward_scale, (pos_ubatch*labels->ne[0] + labels_sparse[ilabel])*sizeof(float), sizeof(float));
|
||||
}
|
||||
}
|
||||
|
||||
const int64_t t3_eval = ggml_time_ms();
|
||||
ggml_opt_eval(opt_ctx, result);
|
||||
|
||||
const int64_t t4_done = ggml_time_ms();
|
||||
if (!timings_printed) {
|
||||
LLAMA_LOG_INFO("%s: [timing] build=%" PRId64 "ms alloc=%" PRId64 "ms inputs=%" PRId64 "ms eval=%" PRId64 "ms total=%" PRId64 "ms\n",
|
||||
__func__,
|
||||
t1_alloc - t0_build,
|
||||
t2_inputs - t1_alloc,
|
||||
t3_eval - t2_inputs,
|
||||
t4_done - t3_eval,
|
||||
t4_done - t0_build);
|
||||
timings_printed = true;
|
||||
}
|
||||
|
||||
if (callback) {
|
||||
callback(train, opt_ctx, dataset, result, idata_in_loop + (pos_ctx + pos_batch)/n_ubatch + 1, ndata_in_loop, t_loop_start);
|
||||
}
|
||||
ggml_free(ctx_compute_opt);
|
||||
|
||||
pos_batch += ubatch.n_tokens;
|
||||
} while (mctx->next());
|
||||
ggml_free(ctx_compute_opt);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue