diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index ed2b8ababf..650a6f026b 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -1301,7 +1301,9 @@ static void llama_sampler_top_k_backend_apply(
     ggml_set_name(top_k, "top_k");
 
     if (data->candidates) {
-        data->candidates = ggml_get_rows(ctx, data->candidates, top_k);
+        struct ggml_tensor * candidates_rows = ggml_reshape_2d(ctx, data->candidates, 1, data->candidates->ne[0]);
+        data->candidates = ggml_get_rows(ctx, candidates_rows, top_k);
+        data->candidates = ggml_reshape_1d(ctx, data->candidates, sctx->k);
         ggml_set_name(data->candidates, "top_k_candidates");
     } else {
         data->candidates = top_k;
@@ -1309,9 +1311,8 @@ static void llama_sampler_top_k_backend_apply(
 
     struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]);
     struct ggml_tensor * top_k_rows = ggml_get_rows(ctx, logits_rows, top_k);
-    ggml_set_name(top_k_rows, "top_k_rows");
-
     data->logits = ggml_reshape_1d(ctx, top_k_rows, sctx->k);
+    ggml_set_name(top_k_rows, "top_k_rows");
 
     GGML_UNUSED(gf);
 }
@@ -1848,14 +1849,14 @@ static void llama_sampler_backend_temp_sampling(
         ggml_set_name(max_idx, "temp_max_idx");
 
         if (data->candidates) {
-            data->candidates = ggml_get_rows(ctx, data->candidates, max_idx);
+            struct ggml_tensor * candidates_rows = ggml_reshape_2d(ctx, data->candidates, 1, data->candidates->ne[0]);
+            data->candidates = ggml_get_rows(ctx, candidates_rows, max_idx);
         } else {
             data->candidates = max_idx;
         }
 
-        struct ggml_tensor * logits = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]);
-
-        data->logits = ggml_get_rows(ctx, logits, max_idx);
+        struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]);
+        data->logits = ggml_get_rows(ctx, logits_rows, max_idx);
 
         return;
     }