diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 5a01c5e065..b0713cab45 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -94,9 +94,8 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str
         return GGML_BACKEND_SPLIT_AXIS_MIRRORED;
     };
 
-    auto get_split_granularity = [&]() -> int64_t {
-        // TODO determine this from tensors with AXIS_0
-        constexpr int64_t blck_size = 32;
+    auto get_split_granularity = [&](ggml_backend_meta_split_axis split_axis) -> int64_t {
+        const int64_t blck_size = split_axis == GGML_BACKEND_SPLIT_AXIS_1 && tensor->ne[1] % 256 == 0 ? 256 : 32;
 
         // attention
         if (std::regex_match(tensor->name, pattern_q_weight) || std::regex_match(tensor->name, pattern_q_bias) ||