remove debug logging, explicitly clamp params at init
This commit is contained in:
parent
85b6e52e39
commit
fcb5129086
|
|
@ -2370,10 +2370,8 @@ static void llama_sampler_power_law_apply(struct llama_sampler * smpl, llama_tok
|
||||||
|
|
||||||
if (ctx->target < 0.0f) {
|
if (ctx->target < 0.0f) {
|
||||||
// no-op: just sample from the distribution as-is
|
// no-op: just sample from the distribution as-is
|
||||||
fprintf(stderr, "power-law: no-op!\n"); fflush(stderr);
|
|
||||||
llama_sampler_softmax_impl(cur_p, false);
|
llama_sampler_softmax_impl(cur_p, false);
|
||||||
const int idx = llama_sample_dist(cur_p, ctx->rng);
|
cur_p->selected = llama_sample_dist(cur_p, ctx->rng);
|
||||||
cur_p->selected = idx;
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -2389,13 +2387,8 @@ static void llama_sampler_power_law_apply(struct llama_sampler * smpl, llama_tok
|
||||||
ctx->total_weight == 0.0f ? ctx->target : 2.0f * ctx->target - (ctx->weighted_sum / ctx->total_weight),
|
ctx->total_weight == 0.0f ? ctx->target : 2.0f * ctx->target - (ctx->weighted_sum / ctx->total_weight),
|
||||||
0.0f, 1.0f
|
0.0f, 1.0f
|
||||||
);
|
);
|
||||||
fprintf(stderr, "power-law: computed target = %.3f\n", computed_target);
|
|
||||||
|
|
||||||
//
|
|
||||||
// power law transform
|
// power law transform
|
||||||
//
|
|
||||||
|
|
||||||
fprintf(stderr, "power-law: cur_p->size = %d\n", (int)cur_p->size);
|
|
||||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||||
float dist = (cur_p->data[i].p - computed_target) * INV_WIDTH;
|
float dist = (cur_p->data[i].p - computed_target) * INV_WIDTH;
|
||||||
cur_p->data[i].logit = PEAK_LOGIT_VALUE / (1.0f + dist * dist);
|
cur_p->data[i].logit = PEAK_LOGIT_VALUE / (1.0f + dist * dist);
|
||||||
|
|
@ -2406,14 +2399,10 @@ static void llama_sampler_power_law_apply(struct llama_sampler * smpl, llama_tok
|
||||||
// sample from transformed distribution
|
// sample from transformed distribution
|
||||||
const int idx = llama_sample_dist(cur_p, ctx->rng);
|
const int idx = llama_sample_dist(cur_p, ctx->rng);
|
||||||
cur_p->selected = idx;
|
cur_p->selected = idx;
|
||||||
fprintf(stderr, "power-law: selected token at index %d\n", idx);
|
|
||||||
|
|
||||||
// update running history with the original probability of the selected token
|
// update running history with the original probability of the selected token
|
||||||
float original_p = ctx->original_probs[idx];
|
ctx->weighted_sum = ctx->original_probs[idx] + ctx->decay * ctx->weighted_sum;
|
||||||
ctx->weighted_sum = original_p + ctx->decay * ctx->weighted_sum;
|
ctx->total_weight = 1.0f + ctx->decay * ctx->total_weight; // history fades over time
|
||||||
fprintf(stderr, "power-law: updated ctx->weighted_sum = %.3f\n", ctx->weighted_sum);
|
|
||||||
ctx->total_weight = 1.0f + ctx->decay * ctx->total_weight;
|
|
||||||
fprintf(stderr, "power-law: updated ctx->total_weight = %.3f\n", ctx->total_weight); fflush(stderr);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static void llama_sampler_power_law_reset(struct llama_sampler * smpl) {
|
static void llama_sampler_power_law_reset(struct llama_sampler * smpl) {
|
||||||
|
|
@ -2453,15 +2442,12 @@ struct llama_sampler * llama_sampler_init_power_law(
|
||||||
float decay,
|
float decay,
|
||||||
uint32_t seed
|
uint32_t seed
|
||||||
) {
|
) {
|
||||||
const float _decay = std::min(decay, 0.99f);
|
|
||||||
fprintf(stderr, "power-law: init: target %.3f, decay %.3f\n", (double)target, (double)_decay);
|
|
||||||
fflush(stderr);
|
|
||||||
auto seed_cur = get_rng_seed(seed);
|
auto seed_cur = get_rng_seed(seed);
|
||||||
return llama_sampler_init(
|
return llama_sampler_init(
|
||||||
/* .iface = */ &llama_sampler_power_law_i,
|
/* .iface = */ &llama_sampler_power_law_i,
|
||||||
/* .ctx = */ new llama_sampler_power_law {
|
/* .ctx = */ new llama_sampler_power_law {
|
||||||
/* .target = */ target,
|
/* .target = */ std::clamp(target, 0.0f, 1.0f),
|
||||||
/* .decay = */ _decay,
|
/* .decay = */ std::clamp(decay, 0.0f, 0.99f),
|
||||||
/* .seed = */ seed_cur,
|
/* .seed = */ seed_cur,
|
||||||
/* .rng = */ std::mt19937(seed_cur),
|
/* .rng = */ std::mt19937(seed_cur),
|
||||||
/* .weighted_sum = */ 0.0f,
|
/* .weighted_sum = */ 0.0f,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue