sampling : remove redundant calls to ggml_build_forward_expand

This commit is contained in:
Georgi Gerganov 2025-12-04 14:25:28 +02:00
parent fce571ee51
commit 1bde70785d
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
1 changed files with 6 additions and 10 deletions

View File

@ -1078,7 +1078,8 @@ static void llama_sampler_top_k_backend_apply(
ggml_set_name(top_k_rows, "top_k_rows");
data->logits = ggml_reshape_1d(ctx, top_k_rows, ctx_data->k);
ggml_build_forward_expand(gf, data->logits);
GGML_UNUSED(gf);
}
static struct llama_sampler_i llama_sampler_top_k_i = {
@ -1264,10 +1265,9 @@ static void llama_sampler_top_p_backend_apply(
ggml_set_name(data->logits, "top_p_logits");
ggml_set_output(data->candidates);
ggml_build_forward_expand(gf, data->candidates);
ggml_set_output(data->logits);
ggml_build_forward_expand(gf, data->logits);
GGML_UNUSED(gf);
}
static struct llama_sampler_i llama_sampler_top_p_i = {
@ -1421,7 +1421,7 @@ static void llama_sampler_min_p_backend_apply(
data->logits = ggml_add(ctx, data->logits, min_p_bias);
ggml_set_name(data->logits, "min_p_logits");
ggml_build_forward_expand(gf, data->logits);
GGML_UNUSED(gf);
}
static struct llama_sampler_i llama_sampler_min_p_i = {
@ -1602,7 +1602,6 @@ static void llama_sampler_backend_temp_sampling(
struct ggml_tensor * logit = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]);
data->logits = ggml_get_rows(ctx, logit, max_idx);
ggml_build_forward_expand(gf, data->logits);
return;
}
@ -1614,7 +1613,7 @@ static void llama_sampler_backend_temp_sampling(
data->logits = ggml_cont(ctx, scaled);
ggml_set_name(data->logits, "temp_scaled_logits");
ggml_build_forward_expand(gf, data->logits);
GGML_UNUSED(gf);
}
static void llama_sampler_temp_backend_apply(
@ -1807,7 +1806,6 @@ static void llama_sampler_temp_ext_backend_apply(
ggml_set_name(scaled_logits, "temp_ext_scaled_logits");
data->logits = scaled_logits;
ggml_build_forward_expand(gf, data->logits);
}
static struct llama_sampler_i llama_sampler_temp_ext_i = {
@ -3080,8 +3078,6 @@ static void llama_sampler_logit_bias_backend_apply(
// Add the sparse logit logit_bias to the logits
struct ggml_tensor * logit_biased = ggml_add_inplace(ctx, data->logits, sctx->inp_logit_bias);
data->logits = logit_biased;
ggml_build_forward_expand(gf, logit_biased);
}
static void llama_sampler_logit_bias_backend_set_input(struct llama_sampler * smpl) {