support for GPT-OSS, Qwen 3 MoE

Johannes Gäßler 2026-02-06 17:09:01 +01:00
parent 39b96f8fe1
commit a630b27da7
2 changed files with 101 additions and 67 deletions


@@ -451,35 +451,24 @@ static void ggml_backend_meta_buffer_set_tensor(ggml_backend_buffer_t buffer, gg
}
switch (split_state) {
case GGML_BACKEND_SPLIT_STATE_BY_NE0: {
GGML_ASSERT(tensor->ne[2] == 1);
GGML_ASSERT(tensor->ne[3] == 1);
const size_t row_size_full = ggml_row_size(tensor->type, tensor->ne[0]);
GGML_ASSERT(offset % row_size_full == 0);
GGML_ASSERT(size % row_size_full == 0);
const int64_t i1_start = offset /row_size_full;
const int64_t i1_stop = (offset + size)/row_size_full;
size_t row_offset_j = 0;
case GGML_BACKEND_SPLIT_STATE_BY_NE0:
case GGML_BACKEND_SPLIT_STATE_BY_NE1:
case GGML_BACKEND_SPLIT_STATE_BY_NE2: {
// Exploit that the tensors are contiguous to splice the full tensor's data with the simple tensors in whole "chunks".
const size_t chunk_size_full = tensor->nb[int(split_state) + 1];
GGML_ASSERT(offset % chunk_size_full == 0);
GGML_ASSERT(size % chunk_size_full == 0);
const int64_t i_start = offset /chunk_size_full;
const int64_t i_stop = (offset + size)/chunk_size_full;
size_t offset_j = 0;
for (ggml_tensor * t : simple_tensors) {
const size_t row_size_j = ggml_row_size(tensor->type, t->ne[0]);
for (int64_t i1 = i1_start; i1 < i1_stop; i1++) {
ggml_backend_tensor_set(t, (const char *) data + i1*row_size_full + row_offset_j, i1*row_size_j, row_size_j);
const size_t chunk_size_j = t->nb[int(split_state) + 1];
for (int64_t i1 = i_start; i1 < i_stop; i1++) {
ggml_backend_tensor_set(t, (const char *) data + i1*chunk_size_full + offset_j, i1*chunk_size_j, chunk_size_j);
}
row_offset_j += row_size_j;
offset_j += chunk_size_j;
}
GGML_ASSERT(row_offset_j == row_size_full);
} break;
case GGML_BACKEND_SPLIT_STATE_BY_NE1: {
GGML_ASSERT(size == ggml_nbytes(tensor));
GGML_ASSERT(tensor->ne[2] == 1);
GGML_ASSERT(tensor->ne[3] == 1);
size_t data_offset_j = 0;
for (ggml_tensor * t : simple_tensors) {
const size_t nbytes_j = ggml_nbytes(t);
ggml_backend_tensor_set(t, (const char *) data + data_offset_j, 0, nbytes_j);
data_offset_j += nbytes_j;
}
GGML_ASSERT(data_offset_j == size);
GGML_ASSERT(offset_j == chunk_size_full);
} break;
case GGML_BACKEND_SPLIT_STATE_MIRRORED: {
for (ggml_tensor * t : simple_tensors) {
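The rewritten BY_NE* cases above reduce the old per-row logic to one chunk-based splice: for a split along dimension d, a "chunk" is nb[d+1] bytes of the contiguous full tensor, and split j owns a fixed byte range of every chunk at a running offset. Below is a minimal standalone sketch of that arithmetic, not the commit's code, using plain std::memcpy in place of ggml_backend_tensor_set and made-up sizes (a 4x6 F32 tensor split BY_NE1 into 2 + 4 rows):

#include <cstddef>
#include <cstring>
#include <vector>

int main() {
    // Hypothetical full tensor: 4 x 6 floats, contiguous, split BY_NE1 into rows [0,2) and [2,6).
    const size_t ne0             = 4;
    const size_t chunk_size_full = ne0 * 6 * sizeof(float);    // nb[2] of the full tensor (one ne[2]-slice)
    const size_t chunk_size_j[2] = { ne0 * 2 * sizeof(float),  // nb[2] of split 0
                                     ne0 * 4 * sizeof(float) };// nb[2] of split 1
    const size_t n_chunks        = 1;                          // ne[2]*ne[3] == 1 in this example

    std::vector<char> full(n_chunks * chunk_size_full, 0);
    std::vector<char> split0(n_chunks * chunk_size_j[0]), split1(n_chunks * chunk_size_j[1]);
    char * splits[2] = { split0.data(), split1.data() };

    // Same splice as the BY_NE* case: split j owns bytes [offset_j, offset_j + chunk_size_j[j])
    // of every full chunk, where offset_j accumulates the preceding splits' chunk sizes.
    size_t offset_j = 0;
    for (int j = 0; j < 2; j++) {
        for (size_t i = 0; i < n_chunks; i++) { // i plays the role of i_start..i_stop
            std::memcpy(splits[j] + i*chunk_size_j[j], full.data() + i*chunk_size_full + offset_j, chunk_size_j[j]);
        }
        offset_j += chunk_size_j[j];
    }
    // Mirrors the commit's final check: GGML_ASSERT(offset_j == chunk_size_full).
    return offset_j == chunk_size_full ? 0 : 1;
}

With splits of 2 and 4 rows the running offsets are 0 and 32 bytes, and the final offset_j equals chunk_size_full (96 bytes), matching the assertion.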
@@ -507,23 +496,24 @@ static void ggml_backend_meta_buffer_get_tensor(ggml_backend_buffer_t buffer, co
}
switch (split_state) {
case GGML_BACKEND_SPLIT_STATE_BY_NE0: {
GGML_ASSERT(tensor->ne[2] == 1);
GGML_ASSERT(tensor->ne[3] == 1);
const size_t row_size_full = ggml_row_size(tensor->type, tensor->ne[0]);
GGML_ASSERT(offset % row_size_full == 0);
GGML_ASSERT(size % row_size_full == 0);
const int64_t i1_start = offset /row_size_full;
const int64_t i1_stop = (offset + size)/row_size_full;
size_t row_offset_j = 0;
case GGML_BACKEND_SPLIT_STATE_BY_NE0:
case GGML_BACKEND_SPLIT_STATE_BY_NE1:
case GGML_BACKEND_SPLIT_STATE_BY_NE2: {
// Exploit that the tensors are contiguous to splice the full tensor's data with the simple tensors in whole "chunks".
const size_t chunk_size_full = tensor->nb[int(split_state) + 1];
GGML_ASSERT(offset % chunk_size_full == 0);
GGML_ASSERT(size % chunk_size_full == 0);
const int64_t i_start = offset /chunk_size_full;
const int64_t i_stop = (offset + size)/chunk_size_full;
size_t offset_j = 0;
for (ggml_tensor * t : simple_tensors) {
const size_t row_size_j = ggml_row_size(tensor->type, t->ne[0]);
for (int64_t i1 = i1_start; i1 < i1_stop; i1++) {
ggml_backend_tensor_set(t, (const char *) data + i1*row_size_full + row_offset_j, i1*row_size_j, row_size_j);
const size_t chunk_size_j = t->nb[int(split_state) + 1];
for (int64_t i1 = i_start; i1 < i_stop; i1++) {
ggml_backend_tensor_get(t, (char *) data + i1*chunk_size_full + offset_j, i1*chunk_size_j, chunk_size_j);
}
row_offset_j += row_size_j;
offset_j += chunk_size_j;
}
GGML_ASSERT(row_offset_j == row_size_full);
GGML_ASSERT(offset_j == chunk_size_full);
} break;
case GGML_BACKEND_SPLIT_STATE_MIRRORED: {
// TODO: a different simple backend may be better
@@ -986,6 +976,19 @@ enum ggml_backend_meta_split_state ggml_backend_meta_get_split_state(const struc
return src_split_states[0];
};
// Some ops broadcast the src1 data across src0:
auto handle_bin_bcast = [&](const std::vector<ggml_backend_meta_split_state> & src_split_states) -> ggml_backend_meta_split_state {
if (src_split_states[0] >= 0 && src_split_states[0] < GGML_MAX_DIMS &&
tensor->src[1]->ne[int(src_split_states[0])] == 1 && src_split_states[1] == GGML_BACKEND_SPLIT_STATE_MIRRORED) {
return src_split_states[0];
}
if (src_split_states[0] == src_split_states[1] && src_split_states[2] == GGML_BACKEND_SPLIT_STATE_MIRRORED) {
return src_split_states[0]; // GGML_OP_ADD_ID
}
GGML_ASSERT(tensor->src[2] == nullptr || src_split_states[2] == GGML_BACKEND_SPLIT_STATE_MIRRORED);
return handle_generic(src_split_states, /*scalar_only =*/ false);
};
auto handle_mul_mat = [&](const std::vector<ggml_backend_meta_split_state> & src_split_states) -> ggml_backend_meta_split_state {
if (src_split_states[0] == GGML_BACKEND_SPLIT_STATE_MIRRORED && src_split_states[1] == GGML_BACKEND_SPLIT_STATE_MIRRORED) {
return GGML_BACKEND_SPLIT_STATE_MIRRORED;
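handle_bin_bcast above lets element-wise ops keep their split when the second operand is broadcast across the split dimension. The following is a standalone sketch of just that first rule, not the commit's code, with a stand-in enum and the assumption (implied by the ne[int(...)] indexing) that BY_NE0..BY_NE3 equal the dimension index they split:

#include <cstdint>
#include <cstdio>

// Stand-in for ggml_backend_meta_split_state; values 0..3 mirror the dimension they split.
enum split_state { BY_NE0 = 0, BY_NE1, BY_NE2, BY_NE3, MIRRORED, PARTIAL, UNKNOWN };

// If src0 is split along dimension d and src1 is mirrored on every device with
// extent 1 along d (i.e. broadcast across the split), the op can run locally on
// each device and the result keeps src0's split state.
static split_state bin_bcast_split(split_state s0, split_state s1, const int64_t src1_ne[4]) {
    if ((int) s0 >= 0 && (int) s0 <= (int) BY_NE3 && s1 == MIRRORED && src1_ne[(int) s0] == 1) {
        return s0;
    }
    return UNKNOWN; // would fall through to handle_generic in the real code
}

int main() {
    const int64_t bias_ne[4] = { 4096, 1, 1, 1 }; // hypothetical mirrored bias, broadcast over dim 1
    std::printf("result split state: %d\n", (int) bin_bcast_split(BY_NE1, MIRRORED, bias_ne));
    return 0;
}

For example, adding a mirrored bias of shape {4096, 1, 1, 1} to a tensor split BY_NE1 returns BY_NE1, so the add stays device-local.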
@@ -1023,8 +1026,7 @@ enum ggml_backend_meta_split_state ggml_backend_meta_get_split_state(const struc
}
case GGML_BACKEND_SPLIT_STATE_MIRRORED:
case GGML_BACKEND_SPLIT_STATE_PARTIAL: {
GGML_ABORT("reshape not implemented for MIRRORED/PARTIAL");
return GGML_BACKEND_SPLIT_STATE_UNKNOWN;
return src_split_states[0];
}
default: {
GGML_ABORT("fatal error");
@@ -1033,6 +1035,17 @@ enum ggml_backend_meta_split_state ggml_backend_meta_get_split_state(const struc
}
};
auto handle_view = [&](const std::vector<ggml_backend_meta_split_state> & src_split_states) -> ggml_backend_meta_split_state {
if (ggml_is_contiguous(tensor)) {
return handle_reshape(src_split_states);
}
if (src_split_states[0] == GGML_BACKEND_SPLIT_STATE_MIRRORED || src_split_states[0] == GGML_BACKEND_SPLIT_STATE_PARTIAL) {
return src_split_states[0];
}
GGML_ABORT("non-contioguos view not implemented");
return GGML_BACKEND_SPLIT_STATE_UNKNOWN;
};
auto handle_permute = [&](const std::vector<ggml_backend_meta_split_state> & src_split_states) -> ggml_backend_meta_split_state {
switch (src_split_states[0]) {
case GGML_BACKEND_SPLIT_STATE_BY_NE0:
@@ -1065,9 +1078,11 @@ enum ggml_backend_meta_split_state ggml_backend_meta_get_split_state(const struc
};
auto handle_flash_attn_ext = [&](const std::vector<ggml_backend_meta_split_state> & src_split_states) -> ggml_backend_meta_split_state {
GGML_ASSERT(src_split_states[0] == GGML_BACKEND_SPLIT_STATE_BY_NE2);
GGML_ASSERT(src_split_states[1] == GGML_BACKEND_SPLIT_STATE_BY_NE2);
GGML_ASSERT(src_split_states[2] == GGML_BACKEND_SPLIT_STATE_BY_NE2);
GGML_ASSERT( src_split_states[0] == GGML_BACKEND_SPLIT_STATE_BY_NE2);
GGML_ASSERT( src_split_states[1] == GGML_BACKEND_SPLIT_STATE_BY_NE2);
GGML_ASSERT( src_split_states[2] == GGML_BACKEND_SPLIT_STATE_BY_NE2);
GGML_ASSERT(tensor->src[4] == nullptr || src_split_states[3] == GGML_BACKEND_SPLIT_STATE_MIRRORED);
GGML_ASSERT(tensor->src[4] == nullptr || src_split_states[4] == GGML_BACKEND_SPLIT_STATE_BY_NE0);
return GGML_BACKEND_SPLIT_STATE_BY_NE1;
};
@@ -1094,17 +1109,19 @@ enum ggml_backend_meta_split_state ggml_backend_meta_get_split_state(const struc
case GGML_OP_DUP: {
return handle_generic(src_split_states, /*scalar_only =*/ true);
}
case GGML_OP_ADD: {
return handle_generic(src_split_states, /*scalar_only =*/ false);
}
case GGML_OP_ADD:
case GGML_OP_ADD_ID: {
return handle_generic(src_split_states, /*scalar_only =*/ true);
return handle_bin_bcast(src_split_states);
}
case GGML_OP_ADD1:
case GGML_OP_ACC:
case GGML_OP_ACC: {
return handle_generic(src_split_states, /*scalar_only =*/ true);
}
case GGML_OP_SUB:
case GGML_OP_MUL:
case GGML_OP_DIV:
case GGML_OP_DIV: {
return handle_bin_bcast(src_split_states);
}
case GGML_OP_SQR:
case GGML_OP_SQRT:
case GGML_OP_LOG:
@@ -1137,10 +1154,10 @@ enum ggml_backend_meta_split_state ggml_backend_meta_get_split_state(const struc
case GGML_OP_L2_NORM: {
return handle_per_row(src_split_states);
}
case GGML_OP_MUL_MAT: {
case GGML_OP_MUL_MAT:
case GGML_OP_MUL_MAT_ID: {
return handle_mul_mat(src_split_states);
}
case GGML_OP_MUL_MAT_ID:
case GGML_OP_OUT_PROD: {
return handle_generic(src_split_states, /*scalar_only =*/ true);
}
@@ -1156,11 +1173,7 @@ enum ggml_backend_meta_split_state ggml_backend_meta_get_split_state(const struc
return handle_reshape(src_split_states);
}
case GGML_OP_VIEW: {
if (ggml_is_contiguous(tensor)) {
return handle_reshape(src_split_states);
}
GGML_ABORT("non-contioguos view not implemented");
return GGML_BACKEND_SPLIT_STATE_UNKNOWN;
return handle_view(src_split_states);
}
case GGML_OP_PERMUTE: {
return handle_permute(src_split_states);


@@ -886,28 +886,49 @@ static int llama_model_load(const std::string & fname, std::vector<std::string>
static enum ggml_backend_meta_split_state llama_meta_device_get_tensor_split(const struct ggml_tensor * tensor, void * userdata) {
// attention
const std::regex pattern_qkv_weight("blk\\.\\d*\\.attn_(q|k|v).*");
const std::regex pattern_qkv_weight("blk\\.\\d*\\.attn_(q|k|v).weight");
if (std::regex_match(tensor->name, pattern_qkv_weight)) {
return GGML_BACKEND_SPLIT_STATE_BY_NE1;
}
const std::regex pattern_kv_cache("cache_(k|v)_l\\d*");
if (std::regex_match(tensor->name, pattern_kv_cache)) {
const std::regex pattern_qkv_bias("blk\\.\\d*\\.attn_(q|k|v)\\.bias");
if (std::regex_match(tensor->name, pattern_qkv_bias)) {
return GGML_BACKEND_SPLIT_STATE_BY_NE0;
}
const std::regex pattern_attn_out("blk\\.\\d*\\.attn_output.*");
if (std::regex_match(tensor->name, pattern_attn_out)) {
const std::regex pattern_qk_norm("blk\\.\\d*\\.attn_(q|k)_norm\\.weight");
if (std::regex_match(tensor->name, pattern_qk_norm)) {
return tensor->ne[1] == 1 ? GGML_BACKEND_SPLIT_STATE_MIRRORED : GGML_BACKEND_SPLIT_STATE_BY_NE1;
}
const std::regex pattern_kv_cache("cache_(k|v)_l\\d*");
const std::regex pattern_attn_sinks("blk\\.\\d*\\.attn_sinks.weight");
if (std::regex_match(tensor->name, pattern_kv_cache) || std::regex_match(tensor->name, pattern_attn_sinks)) {
return GGML_BACKEND_SPLIT_STATE_BY_NE0;
}
const std::regex pattern_attn_out_weight("blk\\.\\d*\\.attn_output.weight");
if (std::regex_match(tensor->name, pattern_attn_out_weight)) {
return GGML_BACKEND_SPLIT_STATE_BY_NE0;
}
const std::regex pattern_attn_out_bias("blk\\.\\d*\\.attn_output.bias");
if (std::regex_match(tensor->name, pattern_attn_out_bias)) {
return GGML_BACKEND_SPLIT_STATE_MIRRORED;
}
// FFN
const std::regex pattern_ffn_up_gate("blk\\.\\d*\\.ffn_(up|gate).*");
if (std::regex_match(tensor->name, pattern_ffn_up_gate)) {
const std::regex pattern_ffn_up_gate_weight("blk\\.\\d*\\.ffn_(up|gate)(_exps)?.weight");
if (std::regex_match(tensor->name, pattern_ffn_up_gate_weight)) {
return GGML_BACKEND_SPLIT_STATE_BY_NE1;
}
const std::regex pattern_ffn_down("blk\\.\\d*\\.ffn_down.*");
if (std::regex_match(tensor->name, pattern_ffn_down)) {
const std::regex pattern_ffn_up_gate_bias("blk\\.\\d*\\.ffn_(up|gate)(_exps)?.bias");
if (std::regex_match(tensor->name, pattern_ffn_up_gate_bias)) {
return GGML_BACKEND_SPLIT_STATE_BY_NE0;
}
const std::regex pattern_ffn_down_weight("blk\\.\\d*\\.ffn_down(_exps)?.weight");
if (std::regex_match(tensor->name, pattern_ffn_down_weight)) {
return GGML_BACKEND_SPLIT_STATE_BY_NE0;
}
const std::regex pattern_ffn_down_bias("blk\\.\\d*\\.ffn_down(_exps)?.bias");
if (std::regex_match(tensor->name, pattern_ffn_down_bias)) {
return GGML_BACKEND_SPLIT_STATE_MIRRORED;
}
// output
const std::regex pattern_output("output");
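The updated patterns distinguish weights from biases and cover the MoE expert tensors ("_exps" suffix) used by GPT-OSS and Qwen 3 MoE. As a quick sanity check, here is a standalone sketch, not part of the commit, that runs a few of the patterns verbatim against representative tensor names (the names themselves are an assumption about the GGUF naming the regexes target):

#include <cstdio>
#include <regex>

int main() {
    // Two weight patterns and two bias/expert patterns, copied from the function above.
    const std::regex qkv_weight("blk\\.\\d*\\.attn_(q|k|v).weight");
    const std::regex qkv_bias  ("blk\\.\\d*\\.attn_(q|k|v)\\.bias");
    const std::regex up_gate_w ("blk\\.\\d*\\.ffn_(up|gate)(_exps)?.weight");
    const std::regex down_w    ("blk\\.\\d*\\.ffn_down(_exps)?.weight");

    const char * names[] = {
        "blk.0.attn_q.weight",        // -> BY_NE1 per the rules above
        "blk.0.attn_k.bias",          // -> BY_NE0
        "blk.7.ffn_up_exps.weight",   // -> BY_NE1 (MoE expert tensor)
        "blk.7.ffn_down_exps.weight", // -> BY_NE0
    };
    for (const char * name : names) {
        std::printf("%-28s qkv_w=%d qkv_b=%d up_gate_w=%d down_w=%d\n", name,
                    (int) std::regex_match(name, qkv_weight),
                    (int) std::regex_match(name, qkv_bias),
                    (int) std::regex_match(name, up_gate_w),
                    (int) std::regex_match(name, down_w));
    }
    return 0;
}

Here "blk.0.attn_q.weight" matches only the Q/K/V weight pattern (BY_NE1), while "blk.7.ffn_down_exps.weight" matches only the ffn_down weight pattern (BY_NE0).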