diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp index 678a7d23ad..d806782108 100644 --- a/src/llama-kv-cache-unified.cpp +++ b/src/llama-kv-cache-unified.cpp @@ -1097,11 +1097,11 @@ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_ const int64_t n_embd_k_gqa = k->ne[0]; const int64_t n_tokens = k_cur->ne[2]; - k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens); + k_cur = ggml_reshape_2d(ctx, k_cur, n_embd_k_gqa, n_tokens); if (k_idxs && supports_set_rows) { if (k->ne[2] > 1) { - k = ggml_reshape_2d(ctx, k, k->ne[0], k->ne[1]*k->ne[2]); + k = ggml_reshape_2d(ctx, k, n_embd_k_gqa, k->ne[1]*k->ne[2]); } return ggml_set_rows(ctx, k, k_cur, k_idxs);