From a630b27da765f370b3331c9421ab3078c76916cc Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?=
Date: Fri, 6 Feb 2026 17:09:01 +0100
Subject: [PATCH] support for GPT-OSS, Qwen 3 MoE

---
 ggml/src/ggml-backend-meta.cpp | 129 ++++++++++++++++++---------------
 src/llama.cpp                  |  39 +++++++---
 2 files changed, 101 insertions(+), 67 deletions(-)

diff --git a/ggml/src/ggml-backend-meta.cpp b/ggml/src/ggml-backend-meta.cpp
index 635e718356..7d020fa0e0 100644
--- a/ggml/src/ggml-backend-meta.cpp
+++ b/ggml/src/ggml-backend-meta.cpp
@@ -451,35 +451,24 @@ static void ggml_backend_meta_buffer_set_tensor(ggml_backend_buffer_t buffer, gg
     }
 
     switch (split_state) {
-        case GGML_BACKEND_SPLIT_STATE_BY_NE0: {
-            GGML_ASSERT(tensor->ne[2] == 1);
-            GGML_ASSERT(tensor->ne[3] == 1);
-            const size_t row_size_full = ggml_row_size(tensor->type, tensor->ne[0]);
-            GGML_ASSERT(offset % row_size_full == 0);
-            GGML_ASSERT(size   % row_size_full == 0);
-            const int64_t i1_start =  offset        /row_size_full;
-            const int64_t i1_stop  = (offset + size)/row_size_full;
-            size_t row_offset_j = 0;
+        case GGML_BACKEND_SPLIT_STATE_BY_NE0:
+        case GGML_BACKEND_SPLIT_STATE_BY_NE1:
+        case GGML_BACKEND_SPLIT_STATE_BY_NE2: {
+            // Exploit that the tensors are contiguous to splice the full tensor with the simple tensors as "chunks".
+            const size_t chunk_size_full = tensor->nb[int(split_state) + 1];
+            GGML_ASSERT(offset % chunk_size_full == 0);
+            GGML_ASSERT(size   % chunk_size_full == 0);
+            const int64_t i_start =  offset        /chunk_size_full;
+            const int64_t i_stop  = (offset + size)/chunk_size_full;
+            size_t offset_j = 0;
             for (ggml_tensor * t : simple_tensors) {
-                const size_t row_size_j = ggml_row_size(tensor->type, t->ne[0]);
-                for (int64_t i1 = i1_start; i1 < i1_stop; i1++) {
-                    ggml_backend_tensor_set(t, (const char *) data + i1*row_size_full + row_offset_j, i1*row_size_j, row_size_j);
+                const size_t chunk_size_j = t->nb[int(split_state) + 1];
+                for (int64_t i1 = i_start; i1 < i_stop; i1++) {
+                    ggml_backend_tensor_set(t, (const char *) data + i1*chunk_size_full + offset_j, i1*chunk_size_j, chunk_size_j);
                 }
-                row_offset_j += row_size_j;
+                offset_j += chunk_size_j;
             }
-            GGML_ASSERT(row_offset_j == row_size_full);
-        } break;
-        case GGML_BACKEND_SPLIT_STATE_BY_NE1: {
-            GGML_ASSERT(size == ggml_nbytes(tensor));
-            GGML_ASSERT(tensor->ne[2] == 1);
-            GGML_ASSERT(tensor->ne[3] == 1);
-            size_t data_offset_j = 0;
-            for (ggml_tensor * t : simple_tensors) {
-                const size_t nbytes_j = ggml_nbytes(t);
-                ggml_backend_tensor_set(t, (const char *) data + data_offset_j, 0, nbytes_j);
-                data_offset_j += nbytes_j;
-            }
-            GGML_ASSERT(data_offset_j == size);
+            GGML_ASSERT(offset_j == chunk_size_full);
         } break;
         case GGML_BACKEND_SPLIT_STATE_MIRRORED: {
             for (ggml_tensor * t : simple_tensors) {
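Note on the unified BY_NE0/BY_NE1/BY_NE2 branch above: the full tensor is treated as a sequence of contiguous chunks of tensor->nb[int(split_state) + 1] bytes, and each chunk is scattered across the per-device simple tensors, whose chunk sizes add up to the full chunk size. A minimal standalone sketch of that splice arithmetic, with plain byte buffers and made-up chunk sizes standing in for the real ggml tensors and backend calls:

#include <cassert>
#include <cstdint>
#include <cstring>
#include <vector>

int main() {
    const size_t chunk_size_full = 8;                  // bytes per "chunk" of the full tensor
    const std::vector<size_t> chunk_size_dev = {3, 5}; // per-device chunk sizes, summing to chunk_size_full
    const size_t n_chunks = 4;                         // number of chunks covered by the set_tensor call

    std::vector<uint8_t> full(n_chunks*chunk_size_full);
    for (size_t i = 0; i < full.size(); i++) {
        full[i] = uint8_t(i);
    }

    // Splice the full buffer into one buffer per device, chunk by chunk, mirroring
    // ggml_backend_tensor_set(t, data + i1*chunk_size_full + offset_j, i1*chunk_size_j, chunk_size_j) above.
    std::vector<std::vector<uint8_t>> dev;
    size_t offset_j = 0;
    for (const size_t chunk_size_j : chunk_size_dev) {
        std::vector<uint8_t> d(n_chunks*chunk_size_j);
        for (size_t i = 0; i < n_chunks; i++) {
            std::memcpy(d.data() + i*chunk_size_j, full.data() + i*chunk_size_full + offset_j, chunk_size_j);
        }
        offset_j += chunk_size_j;
        dev.push_back(std::move(d));
    }
    assert(offset_j == chunk_size_full);
    assert(dev[1][0] == 3); // chunk 0 of the second device starts at byte 3 of chunk 0 of the full tensor
    return 0;
}

The get_tensor hunk that follows is the mirror image of this: the same loop gathers the chunks back with ggml_backend_tensor_get.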
@@ -507,23 +496,24 @@ static void ggml_backend_meta_buffer_get_tensor(ggml_backend_buffer_t buffer, co
     }
 
     switch (split_state) {
-        case GGML_BACKEND_SPLIT_STATE_BY_NE0: {
-            GGML_ASSERT(tensor->ne[2] == 1);
-            GGML_ASSERT(tensor->ne[3] == 1);
-            const size_t row_size_full = ggml_row_size(tensor->type, tensor->ne[0]);
-            GGML_ASSERT(offset % row_size_full == 0);
-            GGML_ASSERT(size   % row_size_full == 0);
-            const int64_t i1_start =  offset        /row_size_full;
-            const int64_t i1_stop  = (offset + size)/row_size_full;
-            size_t row_offset_j = 0;
+        case GGML_BACKEND_SPLIT_STATE_BY_NE0:
+        case GGML_BACKEND_SPLIT_STATE_BY_NE1:
+        case GGML_BACKEND_SPLIT_STATE_BY_NE2: {
+            // Exploit that the tensors are contiguous to splice the full tensor with the simple tensors as "chunks".
+            const size_t chunk_size_full = tensor->nb[int(split_state) + 1];
+            GGML_ASSERT(offset % chunk_size_full == 0);
+            GGML_ASSERT(size   % chunk_size_full == 0);
+            const int64_t i_start =  offset        /chunk_size_full;
+            const int64_t i_stop  = (offset + size)/chunk_size_full;
+            size_t offset_j = 0;
             for (ggml_tensor * t : simple_tensors) {
-                const size_t row_size_j = ggml_row_size(tensor->type, t->ne[0]);
-                for (int64_t i1 = i1_start; i1 < i1_stop; i1++) {
-                    ggml_backend_tensor_set(t, (const char *) data + i1*row_size_full + row_offset_j, i1*row_size_j, row_size_j);
+                const size_t chunk_size_j = t->nb[int(split_state) + 1];
+                for (int64_t i1 = i_start; i1 < i_stop; i1++) {
+                    ggml_backend_tensor_get(t, (char *) data + i1*chunk_size_full + offset_j, i1*chunk_size_j, chunk_size_j);
                 }
-                row_offset_j += row_size_j;
+                offset_j += chunk_size_j;
             }
-            GGML_ASSERT(row_offset_j == row_size_full);
+            GGML_ASSERT(offset_j == chunk_size_full);
         } break;
         case GGML_BACKEND_SPLIT_STATE_MIRRORED: {
             // TODO other simple backend may be better
@@ -986,6 +976,19 @@ enum ggml_backend_meta_split_state ggml_backend_meta_get_split_state(const struc
         return src_split_states[0];
     };
 
+    // Some ops broadcast the src1 data across src0:
+    auto handle_bin_bcast = [&](const std::vector<ggml_backend_meta_split_state> & src_split_states) -> ggml_backend_meta_split_state {
+        if (src_split_states[0] >= 0 && src_split_states[0] < GGML_MAX_DIMS &&
+                tensor->src[1]->ne[int(src_split_states[0])] == 1 && src_split_states[1] == GGML_BACKEND_SPLIT_STATE_MIRRORED) {
+            return src_split_states[0];
+        }
+        if (src_split_states[0] == src_split_states[1] && src_split_states[2] == GGML_BACKEND_SPLIT_STATE_MIRRORED) {
+            return src_split_states[0]; // GGML_OP_ADD_ID
+        }
+        GGML_ASSERT(tensor->src[2] == nullptr || src_split_states[2] == GGML_BACKEND_SPLIT_STATE_MIRRORED);
+        return handle_generic(src_split_states, /*scalar_only =*/ false);
+    };
+
     auto handle_mul_mat = [&](const std::vector<ggml_backend_meta_split_state> & src_split_states) -> ggml_backend_meta_split_state {
         if (src_split_states[0] == GGML_BACKEND_SPLIT_STATE_MIRRORED && src_split_states[1] == GGML_BACKEND_SPLIT_STATE_MIRRORED) {
             return GGML_BACKEND_SPLIT_STATE_MIRRORED;
@@ -1023,8 +1026,7 @@ enum ggml_backend_meta_split_state ggml_backend_meta_get_split_state(const struc
             }
             case GGML_BACKEND_SPLIT_STATE_MIRRORED:
             case GGML_BACKEND_SPLIT_STATE_PARTIAL: {
-                GGML_ABORT("reshape not implemented for MIRRORED/PARTIAL");
-                return GGML_BACKEND_SPLIT_STATE_UNKNOWN;
+                return src_split_states[0];
             }
             default: {
                 GGML_ABORT("fatal error");
@@ -1033,6 +1035,17 @@ enum ggml_backend_meta_split_state ggml_backend_meta_get_split_state(const struc
         }
     };
 
+    auto handle_view = [&](const std::vector<ggml_backend_meta_split_state> & src_split_states) -> ggml_backend_meta_split_state {
+        if (ggml_is_contiguous(tensor)) {
+            return handle_reshape(src_split_states);
+        }
+        if (src_split_states[0] == GGML_BACKEND_SPLIT_STATE_MIRRORED || src_split_states[0] == GGML_BACKEND_SPLIT_STATE_PARTIAL) {
+            return src_split_states[0];
+        }
+        GGML_ABORT("non-contiguous view not implemented");
+        return GGML_BACKEND_SPLIT_STATE_UNKNOWN;
+    };
+
     auto handle_permute = [&](const std::vector<ggml_backend_meta_split_state> & src_split_states) -> ggml_backend_meta_split_state {
         switch (src_split_states[0]) {
             case GGML_BACKEND_SPLIT_STATE_BY_NE0:
@@ -1065,9 +1078,11 @@ enum ggml_backend_meta_split_state ggml_backend_meta_get_split_state(const struc
     };
 
     auto handle_flash_attn_ext = [&](const std::vector<ggml_backend_meta_split_state> & src_split_states) -> ggml_backend_meta_split_state {
-        GGML_ASSERT(src_split_states[0] == GGML_BACKEND_SPLIT_STATE_BY_NE2);
-        GGML_ASSERT(src_split_states[1] == GGML_BACKEND_SPLIT_STATE_BY_NE2);
-        GGML_ASSERT(src_split_states[2] == GGML_BACKEND_SPLIT_STATE_BY_NE2);
+        GGML_ASSERT(                             src_split_states[0] == GGML_BACKEND_SPLIT_STATE_BY_NE2);
+        GGML_ASSERT(                             src_split_states[1] == GGML_BACKEND_SPLIT_STATE_BY_NE2);
+        GGML_ASSERT(                             src_split_states[2] == GGML_BACKEND_SPLIT_STATE_BY_NE2);
+        GGML_ASSERT(tensor->src[4] == nullptr || src_split_states[3] == GGML_BACKEND_SPLIT_STATE_MIRRORED);
+        GGML_ASSERT(tensor->src[4] == nullptr || src_split_states[4] == GGML_BACKEND_SPLIT_STATE_BY_NE0);
         return GGML_BACKEND_SPLIT_STATE_BY_NE1;
     };
 
@@ -1094,17 +1109,19 @@ enum ggml_backend_meta_split_state ggml_backend_meta_get_split_state(const struc
         case GGML_OP_DUP: {
             return handle_generic(src_split_states, /*scalar_only =*/ true);
         }
-        case GGML_OP_ADD: {
-            return handle_generic(src_split_states, /*scalar_only =*/ false);
-        }
+        case GGML_OP_ADD:
         case GGML_OP_ADD_ID: {
-            return handle_generic(src_split_states, /*scalar_only =*/ true);
+            return handle_bin_bcast(src_split_states);
         }
         case GGML_OP_ADD1:
-        case GGML_OP_ACC:
+        case GGML_OP_ACC: {
+            return handle_generic(src_split_states, /*scalar_only =*/ true);
+        }
         case GGML_OP_SUB:
         case GGML_OP_MUL:
-        case GGML_OP_DIV:
+        case GGML_OP_DIV: {
+            return handle_bin_bcast(src_split_states);
+        }
         case GGML_OP_SQR:
         case GGML_OP_SQRT:
         case GGML_OP_LOG:
@@ -1137,10 +1154,10 @@ enum ggml_backend_meta_split_state ggml_backend_meta_get_split_state(const struc
         case GGML_OP_L2_NORM: {
             return handle_per_row(src_split_states);
         }
-        case GGML_OP_MUL_MAT: {
+        case GGML_OP_MUL_MAT:
+        case GGML_OP_MUL_MAT_ID: {
             return handle_mul_mat(src_split_states);
         }
-        case GGML_OP_MUL_MAT_ID:
         case GGML_OP_OUT_PROD: {
             return handle_generic(src_split_states, /*scalar_only =*/ true);
         }
@@ -1156,11 +1173,7 @@ enum ggml_backend_meta_split_state ggml_backend_meta_get_split_state(const struc
             return handle_reshape(src_split_states);
         }
         case GGML_OP_VIEW: {
-            if (ggml_is_contiguous(tensor)) {
-                return handle_reshape(src_split_states);
-            }
-            GGML_ABORT("non-contioguos view not implemented");
-            return GGML_BACKEND_SPLIT_STATE_UNKNOWN;
+            return handle_view(src_split_states);
         }
         case GGML_OP_PERMUTE: {
             return handle_permute(src_split_states);
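Note on handle_bin_bcast above: a broadcast binary op keeps the split of src0 whenever src1 has extent 1 along src0's split dimension and is mirrored on every device, because each device can then apply its local copy of src1 to its own slice; if both inputs carry the same split, the result keeps it as well. A toy restatement of that rule (the real code uses ggml_backend_meta_split_state and ggml_tensor, additionally requires a possible third source to be mirrored, and falls back to handle_generic instead of returning UNKNOWN):

#include <cassert>
#include <cstdint>

// Toy stand-ins for the real enum; only BY_NE0..BY_NE3 doubling as dimension indices matters here.
enum split_state : int { BY_NE0 = 0, BY_NE1, BY_NE2, BY_NE3, MIRRORED, UNKNOWN };

split_state bin_bcast_split(split_state s0, split_state s1, const int64_t src1_ne[4]) {
    if (s0 >= BY_NE0 && s0 <= BY_NE3 && src1_ne[int(s0)] == 1 && s1 == MIRRORED) {
        return s0; // src1 is broadcast along the split dimension of src0
    }
    if (s0 == s1) {
        return s0; // both inputs split the same way
    }
    return UNKNOWN;
}

int main() {
    const int64_t ne_broadcast[4] = {4096, 1, 1, 1}; // e.g. a mirrored tensor broadcast along the split dimension
    assert(bin_bcast_split(BY_NE1, MIRRORED, ne_broadcast) == BY_NE1);

    const int64_t ne_same[4] = {4096, 8, 1, 1};
    assert(bin_bcast_split(BY_NE1, BY_NE1, ne_same) == BY_NE1);
    return 0;
}

The op switch further down routes ADD, ADD_ID, SUB, MUL and DIV through this helper.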
diff --git a/src/llama.cpp b/src/llama.cpp
index 18dacd1848..6f5d91c999 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -886,28 +886,49 @@ static int llama_model_load(const std::string & fname, std::vector
 
 static enum ggml_backend_meta_split_state llama_meta_device_get_tensor_split(const struct ggml_tensor * tensor, void * userdata) {
     // attention
-    const std::regex pattern_qkv_weight("blk\\.\\d*\\.attn_(q|k|v).*");
+    const std::regex pattern_qkv_weight("blk\\.\\d*\\.attn_(q|k|v)\\.weight");
     if (std::regex_match(tensor->name, pattern_qkv_weight)) {
         return GGML_BACKEND_SPLIT_STATE_BY_NE1;
     }
-    const std::regex pattern_kv_cache("cache_(k|v)_l\\d*");
-    if (std::regex_match(tensor->name, pattern_kv_cache)) {
+    const std::regex pattern_qkv_bias("blk\\.\\d*\\.attn_(q|k|v)\\.bias");
+    if (std::regex_match(tensor->name, pattern_qkv_bias)) {
         return GGML_BACKEND_SPLIT_STATE_BY_NE0;
     }
-    const std::regex pattern_attn_out("blk\\.\\d*\\.attn_output.*");
-    if (std::regex_match(tensor->name, pattern_attn_out)) {
+    const std::regex pattern_qk_norm("blk\\.\\d*\\.attn_(q|k)_norm\\.weight");
+    if (std::regex_match(tensor->name, pattern_qk_norm)) {
+        return tensor->ne[1] == 1 ? GGML_BACKEND_SPLIT_STATE_MIRRORED : GGML_BACKEND_SPLIT_STATE_BY_NE1;
+    }
+    const std::regex pattern_kv_cache("cache_(k|v)_l\\d*");
+    const std::regex pattern_attn_sinks("blk\\.\\d*\\.attn_sinks\\.weight");
+    if (std::regex_match(tensor->name, pattern_kv_cache) || std::regex_match(tensor->name, pattern_attn_sinks)) {
         return GGML_BACKEND_SPLIT_STATE_BY_NE0;
     }
+    const std::regex pattern_attn_out_weight("blk\\.\\d*\\.attn_output\\.weight");
+    if (std::regex_match(tensor->name, pattern_attn_out_weight)) {
+        return GGML_BACKEND_SPLIT_STATE_BY_NE0;
+    }
+    const std::regex pattern_attn_out_bias("blk\\.\\d*\\.attn_output\\.bias");
+    if (std::regex_match(tensor->name, pattern_attn_out_bias)) {
+        return GGML_BACKEND_SPLIT_STATE_MIRRORED;
+    }
 
     // FFN
-    const std::regex pattern_ffn_up_gate("blk\\.\\d*\\.ffn_(up|gate).*");
-    if (std::regex_match(tensor->name, pattern_ffn_up_gate)) {
+    const std::regex pattern_ffn_up_gate_weight("blk\\.\\d*\\.ffn_(up|gate)(_exps)?\\.weight");
+    if (std::regex_match(tensor->name, pattern_ffn_up_gate_weight)) {
         return GGML_BACKEND_SPLIT_STATE_BY_NE1;
     }
-    const std::regex pattern_ffn_down("blk\\.\\d*\\.ffn_down.*");
-    if (std::regex_match(tensor->name, pattern_ffn_down)) {
+    const std::regex pattern_ffn_up_gate_bias("blk\\.\\d*\\.ffn_(up|gate)(_exps)?\\.bias");
+    if (std::regex_match(tensor->name, pattern_ffn_up_gate_bias)) {
         return GGML_BACKEND_SPLIT_STATE_BY_NE0;
     }
+    const std::regex pattern_ffn_down_weight("blk\\.\\d*\\.ffn_down(_exps)?\\.weight");
+    if (std::regex_match(tensor->name, pattern_ffn_down_weight)) {
+        return GGML_BACKEND_SPLIT_STATE_BY_NE0;
+    }
+    const std::regex pattern_ffn_down_bias("blk\\.\\d*\\.ffn_down(_exps)?\\.bias");
+    if (std::regex_match(tensor->name, pattern_ffn_down_bias)) {
+        return GGML_BACKEND_SPLIT_STATE_MIRRORED;
+    }
 
     // output
     const std::regex pattern_output("output");
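The split assignment in llama_meta_device_get_tensor_split is keyed purely off tensor names: the new (_exps)? alternatives are what pick up the fused MoE expert tensors of GPT-OSS and Qwen 3 MoE, and narrowing the old .* suffixes to explicit \\.weight / \\.bias keeps names like attn_q_norm from falling into the attn_q rule. A small standalone check of two of the escaped patterns against typical GGUF tensor names (the names and the pared-down pattern set are illustrative only):

#include <cassert>
#include <regex>

int main() {
    const std::regex ffn_up_gate_weight("blk\\.\\d*\\.ffn_(up|gate)(_exps)?\\.weight");
    const std::regex attn_qkv_weight("blk\\.\\d*\\.attn_(q|k|v)\\.weight");

    // dense FFN tensors and fused MoE expert tensors both match the same rule
    assert( std::regex_match("blk.0.ffn_up.weight",         ffn_up_gate_weight));
    assert( std::regex_match("blk.17.ffn_gate_exps.weight", ffn_up_gate_weight));
    // escaping the '.' keeps lookalike names from matching by accident
    assert(!std::regex_match("blk.0.ffn_up_weight",         ffn_up_gate_weight));
    assert( std::regex_match("blk.3.attn_q.weight",         attn_qkv_weight));
    assert(!std::regex_match("blk.3.attn_q_norm.weight",    attn_qkv_weight));
    return 0;
}

Note the pattern in the rules above: weights split by ne1 (attn_q/k/v, ffn_up/gate) get biases split by ne0, while weights split by ne0 (attn_output, ffn_down) get mirrored biases.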