diff --git a/src/models/deepseek32.cpp b/src/models/deepseek32.cpp index aad6ecf532..23bb45c534 100644 --- a/src/models/deepseek32.cpp +++ b/src/models/deepseek32.cpp @@ -145,6 +145,11 @@ llm_build_deepseek32::llm_build_deepseek32(const llama_model & model, const llm_ // get cached indexer keys indexer_k = mctx_cur->get_ik(ctx0, il); + // split the batch into streams if needed + const auto n_stream = indexer_k->ne[3]; + indexer_q = ggml_view_4d(ctx0, indexer_q, indexer_q->ne[0], indexer_q->ne[1], indexer_q->ne[2]/n_stream, n_stream, indexer_q->nb[1], indexer_q->nb[2], indexer_q->nb[3]/n_stream, 0); + indexer_weights = ggml_view_4d(ctx0, indexer_weights, indexer_weights->ne[0], indexer_weights->ne[1]/n_stream, indexer_weights->ne[2], n_stream, indexer_weights->nb[1], indexer_weights->nb[2]/n_stream, indexer_weights->nb[3]/n_stream, 0); + indexer_q = ggml_permute(ctx0, indexer_q, 0, 2, 1, 3); cb(indexer_q, "indexer_q", il); indexer_k = ggml_permute(ctx0, indexer_k, 0, 2, 1, 3);