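optimize power-law sampler: hoist transform constants to file scope, reuse the original_probs buffer across calls, clamp decay once at init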
This commit is contained in:
parent
36b526d768
commit
6934780669
|
|
@ -2349,11 +2349,18 @@ struct llama_sampler_power_law {
|
|||
std::mt19937 rng;
|
||||
|
||||
// historical token probabilities weighted by recency
|
||||
float weighted_sum;
|
||||
float weighted_sum;
|
||||
// sum of weights, converges to 1/(1-decay)
|
||||
float total_weight;
|
||||
float total_weight;
|
||||
// used to store original token probabilities (needed for history update after selection)
|
||||
std::vector<float> original_probs;
|
||||
};
|
||||
|
||||
// transformation constants for the power-law logit mapping
// (applied per token as: dist = (p - target) * INV_WIDTH; logit = PEAK_LOGIT_VALUE / (1 + dist * dist))
|
||||
// scale applied to the distance between a token's probability and the target;
// a smaller width makes the peak around the target narrower
static constexpr float DISTRIBUTION_WIDTH = 0.3f;
|
||||
// logit assigned to a token whose probability equals the target exactly (dist == 0)
static constexpr float PEAK_LOGIT_VALUE = 5.0f;
|
||||
// reciprocal of the width, precomputed so the per-token loop multiplies instead of divides
static constexpr float INV_WIDTH = 1.0f / DISTRIBUTION_WIDTH;
|
||||
|
||||
// Returns the name of this sampler: always the constant string "power-law".
// The sampler argument is unused (its parameter name is commented out).
static const char * llama_sampler_power_law_name(const struct llama_sampler * /*smpl*/) {
|
||||
return "power-law";
|
||||
}
|
||||
|
|
@ -2369,7 +2376,7 @@ static float llama_sampler_power_law_compute_target(const llama_sampler_power_la
|
|||
fprintf(stderr, "power-law: compute_target: target = %.3f\n", target);
|
||||
|
||||
// clamp result to [0.0, 1.0]
|
||||
target = std::max(0.0f, std::min(target, 1.0f));
|
||||
target = std::clamp(target, 0.0f, 1.0f);
|
||||
fprintf(stderr, "power-law: compute_target: target (post-clamp) = %.3f\n", target); fflush(stderr);
|
||||
return target;
|
||||
}
|
||||
|
|
@ -2379,43 +2386,32 @@ static void llama_sampler_power_law_apply(struct llama_sampler * smpl, llama_tok
|
|||
|
||||
if (ctx->target < 0.0f) {
|
||||
// no-op: just sample from the distribution as-is
|
||||
fprintf(stderr, "power-law: no-op!"); fflush(stderr);
|
||||
fprintf(stderr, "power-law: no-op!");
|
||||
llama_sampler_softmax_impl(cur_p, false);
|
||||
const int idx = llama_sample_dist(cur_p, ctx->rng);
|
||||
cur_p->selected = idx;
|
||||
return;
|
||||
}
|
||||
|
||||
// clamp decay to avoid degenerate case at 1.0 (unbounded accumulation)
|
||||
const float decay = std::min(ctx->decay, 0.99f);
|
||||
fprintf(stderr, "power-law: decay = %.3f\n", decay); fflush(stderr);
|
||||
|
||||
// get the original probabilities
|
||||
llama_sampler_softmax_impl(cur_p, false);
|
||||
|
||||
// store the original probabilities (needed for history update after selection)
|
||||
std::vector<float> original_probs;
|
||||
original_probs.reserve(cur_p->size);
|
||||
// store the original probabilities
|
||||
ctx->original_probs.resize(cur_p->size);
|
||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||
original_probs.push_back(cur_p->data[i].p);
|
||||
ctx->original_probs[i] = cur_p->data[i].p;
|
||||
}
|
||||
|
||||
float computed_target = llama_sampler_power_law_compute_target(ctx);
|
||||
fprintf(stderr, "power-law: computed_target = %.3f\n", computed_target); fflush(stderr);
|
||||
fprintf(stderr, "power-law: computed_target = %.3f\n", computed_target);
|
||||
|
||||
//
|
||||
// power law transform
|
||||
//
|
||||
|
||||
// transformation constants
|
||||
const float distribution_width = 0.3f;
|
||||
const float peak_logit_value = 5.0f;
|
||||
|
||||
const float inv_width = 1.0f / distribution_width;
|
||||
|
||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||
float dist = (cur_p->data[i].p - computed_target) * inv_width;
|
||||
cur_p->data[i].logit = peak_logit_value / (1.0f + dist * dist);
|
||||
float dist = (cur_p->data[i].p - computed_target) * INV_WIDTH;
|
||||
cur_p->data[i].logit = PEAK_LOGIT_VALUE / (1.0f + dist * dist);
|
||||
}
|
||||
|
||||
llama_sampler_softmax_impl(cur_p, false);
|
||||
|
|
@ -2423,14 +2419,14 @@ static void llama_sampler_power_law_apply(struct llama_sampler * smpl, llama_tok
|
|||
// sample from transformed distribution
|
||||
const int idx = llama_sample_dist(cur_p, ctx->rng);
|
||||
cur_p->selected = idx;
|
||||
fprintf(stderr, "power-law: selected token at index %d\n", idx); fflush(stderr);
|
||||
fprintf(stderr, "power-law: selected token at index %d\n", idx);
|
||||
|
||||
// update running history with the original probability of the selected token
|
||||
float original_p = original_probs[idx];
|
||||
fprintf(stderr, "power-law: original prob was %.3f\n", original_p); fflush(stderr);
|
||||
ctx->weighted_sum = original_p + decay * ctx->weighted_sum;
|
||||
fprintf(stderr, "power-law: updated ctx->weighted_sum = %.3f\n", ctx->weighted_sum); fflush(stderr);
|
||||
ctx->total_weight = 1.0f + decay * ctx->total_weight;
|
||||
float original_p = ctx->original_probs[idx];
|
||||
fprintf(stderr, "power-law: original prob was %.3f\n", original_p);
|
||||
ctx->weighted_sum = original_p + ctx->decay * ctx->weighted_sum;
|
||||
fprintf(stderr, "power-law: updated ctx->weighted_sum = %.3f\n", ctx->weighted_sum);
|
||||
ctx->total_weight = 1.0f + ctx->decay * ctx->total_weight;
|
||||
fprintf(stderr, "power-law: updated ctx->total_weight = %.3f\n", ctx->total_weight); fflush(stderr);
|
||||
}
|
||||
|
||||
|
|
@ -2448,6 +2444,7 @@ static struct llama_sampler * llama_sampler_power_law_clone(const struct llama_s
|
|||
result_ctx->rng = ctx->rng;
|
||||
result_ctx->weighted_sum = ctx->weighted_sum;
|
||||
result_ctx->total_weight = ctx->total_weight;
|
||||
result_ctx->original_probs.reserve(ctx->original_probs.capacity());
|
||||
|
||||
return result;
|
||||
}
|
||||
|
|
@ -2475,7 +2472,7 @@ struct llama_sampler * llama_sampler_init_power_law(
|
|||
/* .iface = */ &llama_sampler_power_law_i,
|
||||
/* .ctx = */ new llama_sampler_power_law {
|
||||
/* .target = */ target,
|
||||
/* .decay = */ decay,
|
||||
/* .decay = */ std::min(decay, 0.99f),
|
||||
/* .seed = */ seed_cur,
|
||||
/* .rng = */ std::mt19937(seed_cur),
|
||||
/* .weighted_sum = */ 0.0f,
|
||||
|
|
|
|||
Loading…
Reference in New Issue