added missing llama_opt_set_reward_weights

This commit is contained in:
Salvatore Rossitto 2026-03-12 11:58:14 +01:00
parent 84cab59ec6
commit 76d5b67980
2 changed files with 11 additions and 2 deletions

View File

@ -1556,6 +1556,12 @@ extern "C" {
LLAMA_API void llama_opt_init(struct llama_context * lctx, struct llama_model * model, struct llama_opt_params lopt_params);
// weights: array of floats, one per dataset window (indexed by idata), already normalized to [0,1].
// n_weights: length of the array.
// Pass NULL/0 to disable (equivalent to all-ones, i.e. standard SFT).
// The pointer must remain valid for the duration of all llama_opt_epoch calls.
LLAMA_API void llama_opt_set_reward_weights(const float * weights, int64_t n_weights);
LLAMA_API void llama_opt_epoch(
struct llama_context * lctx,
ggml_opt_dataset_t dataset,
@ -1563,7 +1569,8 @@ extern "C" {
ggml_opt_result_t result_eval,
int64_t idata_split,
ggml_opt_epoch_callback callback_train,
ggml_opt_epoch_callback callback_eval);
ggml_opt_epoch_callback callback_eval,
bool shuffle);
#ifdef __cplusplus
}

View File

@ -187,7 +187,8 @@ struct llama_context {
ggml_opt_result_t result_eval,
int64_t idata_split,
ggml_opt_epoch_callback callback_train,
ggml_opt_epoch_callback callback_eval);
ggml_opt_epoch_callback callback_eval,
bool shuffle);
void opt_epoch_iter(
ggml_opt_dataset_t dataset,
@ -195,6 +196,7 @@ struct llama_context {
const std::vector<llama_token> & tokens,
const std::vector<llama_token> & labels_sparse,
llama_batch & batch,
float reward_scale,
ggml_opt_epoch_callback callback,
bool train,
int64_t idata_in_loop,