added reward scaling to opt_epoch_iter calls

This commit is contained in:
Salvatore Rossitto 2026-03-12 12:04:34 +01:00
parent 76d5b67980
commit 70730e8d28
2 changed files with 32 additions and 8 deletions

View File

@ -83,7 +83,7 @@ int main(int argc, char ** argv) {
for (lr.epoch = 0; lr.epoch < lr.epochs; ++lr.epoch) {
llama_opt_epoch(ctx, dataset, result_train, result_eval, idata_split,
ggml_opt_epoch_callback_progress_bar, ggml_opt_epoch_callback_progress_bar);
ggml_opt_epoch_callback_progress_bar, ggml_opt_epoch_callback_progress_bar, /*shuffle=*/false);
fprintf(stderr, "\n");
ggml_opt_result_reset(result_train);

View File

@ -2657,6 +2657,7 @@ void llama_context::opt_epoch_iter(
const std::vector<llama_token> & tokens,
const std::vector<llama_token> & labels_sparse,
llama_batch & batch,
float reward_scale,
ggml_opt_epoch_callback callback,
bool train,
int64_t idata_in_loop,
@ -2742,11 +2743,14 @@ void llama_context::opt_epoch_iter(
struct ggml_tensor * labels = ggml_opt_labels(opt_ctx);
GGML_ASSERT(labels->ne[1] == n_ubatch);
ggml_set_zero(labels);
const float onef = 1.0f;
for (uint32_t pos_ubatch = 0; pos_ubatch < n_ubatch; ++pos_ubatch) {
const uint32_t ilabel = pos_ctx + pos_batch + pos_ubatch;
// -1 sentinel means "masked position" (prompt token, BOS separator, etc).
// Leave the label tensor zeroed at this position → zero cross-entropy
// contribution. Do NOT write anything — ggml_set_zero already handled it.
if (labels_sparse[ilabel] < 0) continue;
GGML_ASSERT(labels_sparse[ilabel] < labels->ne[0]);
ggml_backend_tensor_set(labels, &onef, (pos_ubatch*labels->ne[0] + labels_sparse[ilabel])*sizeof(float), sizeof(float));
ggml_backend_tensor_set(labels, &reward_scale, (pos_ubatch*labels->ne[0] + labels_sparse[ilabel])*sizeof(float), sizeof(float));
}
}
ggml_opt_eval(opt_ctx, result);
@ -2760,13 +2764,25 @@ void llama_context::opt_epoch_iter(
}
}
// Optional per-window reward weights for reward-weighted SFT.
//
// Set via llama_opt_set_reward_weights() before calling llama_opt_epoch().
// In the training part of an epoch, data window `idata` uses weights[idata] as
// the value written into its target labels; windows with idata >= n_weights,
// and every window of the evaluation part, use 1.0.
// Null/0 means all rewards are 1.0 (standard SFT).
//
// The state is thread_local, so it is only seen by an epoch that runs on the
// same thread that called the setter.
//
// NOTE(review): weights are looked up by the loop index `idata`. When the epoch
// is run with shuffle enabled the dataset order is permuted first, so the
// weights presumably no longer line up with their original windows — confirm
// against ggml_opt_dataset_shuffle / ggml_opt_dataset_get_batch_host.
static thread_local const float * g_reward_weights   = nullptr;
static thread_local int64_t       g_reward_weights_n = 0;

// Registers (or clears) the reward weights used by subsequent training epochs
// on the calling thread.
//
// weights   - borrowed array of at least n_weights floats; it is NOT copied, so
//             it must stay alive and unchanged until it is replaced or cleared.
// n_weights - number of valid entries in `weights`.
//
// Passing a null pointer or a non-positive count disables reward weighting.
// Both globals are reset together in that case so the stored state can never be
// half-valid (null pointer with a positive count, or a live pointer with a
// zero/negative count).
void llama_opt_set_reward_weights(const float * weights, int64_t n_weights) {
    const bool enabled = weights != nullptr && n_weights > 0;

    g_reward_weights   = enabled ? weights   : nullptr;
    g_reward_weights_n = enabled ? n_weights : 0;
}
void llama_context::opt_epoch(
ggml_opt_dataset_t dataset,
ggml_opt_result_t result_train,
ggml_opt_result_t result_eval,
int64_t idata_split,
ggml_opt_epoch_callback callback_train,
ggml_opt_epoch_callback callback_eval) {
ggml_opt_epoch_callback callback_eval,
bool shuffle) {
const uint32_t n_ctx = this->n_ctx();
const uint32_t n_batch = std::min(cparams.n_batch, n_ctx);
const uint32_t n_ubatch = std::min(cparams.n_ubatch, n_batch);
@ -2775,6 +2791,10 @@ void llama_context::opt_epoch(
GGML_ASSERT(idata_split >= 0);
GGML_ASSERT(idata_split <= ndata);
if (shuffle && idata_split > 1) {
ggml_opt_dataset_shuffle(opt_ctx, dataset, idata_split);
}
const uint32_t ubatch_per_ctx = n_ctx / n_ubatch;
struct llama_batch batch = llama_batch_init(n_batch, 0, 1);
@ -2788,9 +2808,11 @@ void llama_context::opt_epoch(
for (; idata < idata_split; ++idata) {
constexpr bool train = true;
const int64_t idata_in_loop = idata*ubatch_per_ctx;
const float reward = (g_reward_weights && idata < g_reward_weights_n)
? g_reward_weights[idata] : 1.0f;
ggml_opt_dataset_get_batch_host(dataset, tokens.data(), n_ctx*sizeof(llama_token), labels_sparse.data(), idata);
opt_epoch_iter(dataset, result_train, tokens, labels_sparse, batch,
opt_epoch_iter(dataset, result_train, tokens, labels_sparse, batch, reward,
callback_train, train, idata_in_loop, ndata_in_loop, t_loop_start);
}
@ -2801,7 +2823,7 @@ void llama_context::opt_epoch(
const int64_t idata_in_loop = (idata - idata_split)*ubatch_per_ctx;
ggml_opt_dataset_get_batch_host(dataset, tokens.data(), n_ctx*sizeof(llama_token), labels_sparse.data(), idata);
opt_epoch_iter(dataset, result_eval, tokens, labels_sparse, batch,
opt_epoch_iter(dataset, result_eval, tokens, labels_sparse, batch, 1.0f,
callback_eval, train, idata_in_loop, ndata_in_loop, t_loop_start);
}
@ -3550,12 +3572,14 @@ void llama_opt_epoch(
ggml_opt_result_t result_eval,
int64_t idata_split,
ggml_opt_epoch_callback callback_train,
ggml_opt_epoch_callback callback_eval) {
ggml_opt_epoch_callback callback_eval,
bool shuffle) {
ctx->opt_epoch(
dataset,
result_train,
result_eval,
idata_split,
callback_train,
callback_eval);
callback_eval,
shuffle);
}