sampling : simplify
This commit is contained in:
parent 8cac9dee45
commit 2464d1b3fc
@@ -149,8 +149,7 @@ static void llama_sampler_backend_top_k_apply_ggml(
         fprintf(stderr, "CPU backend will be used instead which defeats the purpose of having backend samplers\n");
     }

-    // TODO: temporary cont until https://github.com/ggml-org/llama.cpp/pull/17365 is merged
-    ggml_data->candidates = ggml_cont(ctx, top_k);
+    ggml_data->candidates = top_k;

     struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, ggml_data->logits, 1, ggml_data->logits->ne[0]);
     struct ggml_tensor * top_k_rows = ggml_get_rows(ctx, logits_rows, top_k);
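
Note on the hunk above: with the referenced upstream change in place, the temporary ggml_cont copy is no longer needed and the top-k index tensor is stored as the candidates directly; the two unchanged lines then gather the selected logits by reshaping the logit vector into one value per row and indexing it with ggml_get_rows. Below is a minimal, self-contained CPU-only sketch of that gather pattern, assuming a recent ggml checkout where CPU graph execution lives in ggml-cpu.h; the vocabulary size, k, and logit values are made up for illustration.

// Sketch: top-k indices of a logit vector, then gather the matching logit
// values with ggml_get_rows on a [1, n_vocab] reshape (as in the hunk above).
#include "ggml.h"
#include "ggml-cpu.h"   // assumption: recent ggml, CPU graph compute declared here

#include <stdio.h>

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16u*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    const int n_vocab = 8;
    const int k       = 3;

    struct ggml_tensor * logits = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_vocab);
    float * l = (float *) logits->data;
    for (int i = 0; i < n_vocab; ++i) {
        l[i] = (float) ((i*5) % n_vocab); // arbitrary, non-monotonic values
    }

    // indices of the k largest logits (I32), playing the role of ggml_data->candidates
    struct ggml_tensor * top_k = ggml_top_k(ctx, logits, k);

    // one logit per row, then gather the rows selected by top_k -- mirrors the
    // two unchanged lines in the hunk above
    struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, logits, 1, logits->ne[0]);
    struct ggml_tensor * top_k_rows  = ggml_get_rows(ctx, logits_rows, top_k);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, top_k);
    ggml_build_forward_expand(gf, top_k_rows);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 1);

    for (int i = 0; i < k; ++i) {
        printf("top-%d: idx=%d logit=%.1f\n", i,
            (int) ((int32_t *) top_k->data)[i],
            ((float *) top_k_rows->data)[i]);
    }

    ggml_free(ctx);
    return 0;
}
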
@@ -525,12 +524,8 @@ static void llama_sampler_backend_min_p_apply_ggml(
     struct ggml_tensor * threshold = ggml_scale_bias(ctx, max_logit, 1.0f, logf(sctx->p));
     ggml_set_name(threshold, "min_p_threshold");

-    // Broadcast the threshold to match the shape of logits.
-    struct ggml_tensor * threshold_b = ggml_repeat(ctx, threshold, ggml_data->logits);
-    ggml_set_name(threshold_b, "min_p_threshold_b");
-
     // Subtract the threshold from logits.
-    struct ggml_tensor * sub = ggml_sub(ctx, ggml_data->logits, threshold_b);
+    struct ggml_tensor * sub = ggml_sub(ctx, ggml_data->logits, threshold);

     // Create a mask where logits below the threshold are 0 (discard),
     // and others are 1 (keep).
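
Note on the hunk above: min-p keeps a token when softmax(l_i) >= p * softmax(l_max), which in logit space is l_i >= l_max + log(p); that is exactly the threshold built with ggml_scale_bias(max_logit, 1.0f, logf(p)). The explicit ggml_repeat is dropped because ggml_sub broadcasts the single-element threshold across the logits. Below is a minimal CPU-only sketch of that math, assuming the same ggml revision this sampler builds against (ggml_scale_bias and broadcasting ggml_sub); the logit values and p are made up.

// Sketch: min-p threshold in logit space and broadcasted subtraction.
#include "ggml.h"
#include "ggml-cpu.h"

#include <math.h>
#include <stdio.h>
#include <string.h>

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16u*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    const int   n_vocab = 5;
    const float p       = 0.1f; // min-p parameter

    struct ggml_tensor * logits = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_vocab);
    const float l[] = { 2.0f, 1.5f, -0.5f, -3.0f, 0.0f };
    memcpy(logits->data, l, sizeof(l));

    // in the real sampler max_logit comes out of the graph; here it is simply
    // the known maximum (2.0) stored in a 1-element tensor so it can broadcast
    struct ggml_tensor * max_logit = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
    ((float *) max_logit->data)[0] = 2.0f;

    // threshold = 1.0 * max_logit + log(p)
    struct ggml_tensor * threshold = ggml_scale_bias(ctx, max_logit, 1.0f, logf(p));

    // no explicit ggml_repeat: the 1-element threshold broadcasts over logits
    struct ggml_tensor * sub = ggml_sub(ctx, logits, threshold);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, sub);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 1);

    for (int i = 0; i < n_vocab; ++i) {
        const float d = ((float *) sub->data)[i];
        printf("token %d: logit=%5.2f logit-threshold=%5.2f -> %s\n",
            i, l[i], d, d >= 0.0f ? "keep" : "discard");
    }

    ggml_free(ctx);
    return 0;
}
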
@@ -713,4 +708,4 @@ struct llama_sampler * llama_sampler_backend_init_top_p(float p) {
     };

     return sampler;
-}
+}