sampling : fix reshapes
This commit is contained in:
parent
5d2156e893
commit
610e50a17d
|
|
@ -1301,7 +1301,9 @@ static void llama_sampler_top_k_backend_apply(
|
|||
ggml_set_name(top_k, "top_k");
|
||||
|
||||
if (data->candidates) {
|
||||
data->candidates = ggml_get_rows(ctx, data->candidates, top_k);
|
||||
struct ggml_tensor * candidates_rows = ggml_reshape_2d(ctx, data->candidates, 1, data->candidates->ne[0]);
|
||||
data->candidates = ggml_get_rows(ctx, candidates_rows, top_k);
|
||||
data->candidates = ggml_reshape_1d(ctx, data->candidates, sctx->k);
|
||||
ggml_set_name(data->candidates, "top_k_candidates");
|
||||
} else {
|
||||
data->candidates = top_k;
|
||||
|
|
@ -1309,9 +1311,8 @@ static void llama_sampler_top_k_backend_apply(
|
|||
|
||||
struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]);
|
||||
struct ggml_tensor * top_k_rows = ggml_get_rows(ctx, logits_rows, top_k);
|
||||
ggml_set_name(top_k_rows, "top_k_rows");
|
||||
|
||||
data->logits = ggml_reshape_1d(ctx, top_k_rows, sctx->k);
|
||||
ggml_set_name(top_k_rows, "top_k_rows");
|
||||
|
||||
GGML_UNUSED(gf);
|
||||
}
|
||||
|
|
@ -1848,14 +1849,14 @@ static void llama_sampler_backend_temp_sampling(
|
|||
ggml_set_name(max_idx, "temp_max_idx");
|
||||
|
||||
if (data->candidates) {
|
||||
data->candidates = ggml_get_rows(ctx, data->candidates, max_idx);
|
||||
struct ggml_tensor * candidates_rows = ggml_reshape_2d(ctx, data->candidates, 1, data->candidates->ne[0]);
|
||||
data->candidates = ggml_get_rows(ctx, candidates_rows, max_idx);
|
||||
} else {
|
||||
data->candidates = max_idx;
|
||||
}
|
||||
|
||||
struct ggml_tensor * logits = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]);
|
||||
|
||||
data->logits = ggml_get_rows(ctx, logits, max_idx);
|
||||
struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]);
|
||||
data->logits = ggml_get_rows(ctx, logits_rows, max_idx);
|
||||
|
||||
return;
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue