support for GPT-OSS, Qwen 3 MoE
parent 39b96f8fe1
commit a630b27da7
@@ -451,35 +451,24 @@ static void ggml_backend_meta_buffer_set_tensor(ggml_backend_buffer_t buffer, gg
     }

     switch (split_state) {
-        case GGML_BACKEND_SPLIT_STATE_BY_NE0: {
-            GGML_ASSERT(tensor->ne[2] == 1);
-            GGML_ASSERT(tensor->ne[3] == 1);
-            const size_t row_size_full = ggml_row_size(tensor->type, tensor->ne[0]);
-            GGML_ASSERT(offset % row_size_full == 0);
-            GGML_ASSERT(size % row_size_full == 0);
-            const int64_t i1_start = offset /row_size_full;
-            const int64_t i1_stop = (offset + size)/row_size_full;
-            size_t row_offset_j = 0;
+        case GGML_BACKEND_SPLIT_STATE_BY_NE0:
+        case GGML_BACKEND_SPLIT_STATE_BY_NE1:
+        case GGML_BACKEND_SPLIT_STATE_BY_NE2: {
+            // Exploit that tensors are contiguous to splice it with simple tensors as "chunks".
+            const size_t chunk_size_full = tensor->nb[int(split_state) + 1];
+            GGML_ASSERT(offset % chunk_size_full == 0);
+            GGML_ASSERT(size % chunk_size_full == 0);
+            const int64_t i_start = offset /chunk_size_full;
+            const int64_t i_stop = (offset + size)/chunk_size_full;
+            size_t offset_j = 0;
             for (ggml_tensor * t : simple_tensors) {
-                const size_t row_size_j = ggml_row_size(tensor->type, t->ne[0]);
-                for (int64_t i1 = i1_start; i1 < i1_stop; i1++) {
-                    ggml_backend_tensor_set(t, (const char *) data + i1*row_size_full + row_offset_j, i1*row_size_j, row_size_j);
+                const size_t chunk_size_j = t->nb[int(split_state) + 1];
+                for (int64_t i1 = i_start; i1 < i_stop; i1++) {
+                    ggml_backend_tensor_set(t, (const char *) data + i1*chunk_size_full + offset_j, i1*chunk_size_j, chunk_size_j);
                 }
-                row_offset_j += row_size_j;
+                offset_j += chunk_size_j;
             }
-            GGML_ASSERT(row_offset_j == row_size_full);
-        } break;
-        case GGML_BACKEND_SPLIT_STATE_BY_NE1: {
-            GGML_ASSERT(size == ggml_nbytes(tensor));
-            GGML_ASSERT(tensor->ne[2] == 1);
-            GGML_ASSERT(tensor->ne[3] == 1);
-            size_t data_offset_j = 0;
-            for (ggml_tensor * t : simple_tensors) {
-                const size_t nbytes_j = ggml_nbytes(t);
-                ggml_backend_tensor_set(t, (const char *) data + data_offset_j, 0, nbytes_j);
-                data_offset_j += nbytes_j;
-            }
-            GGML_ASSERT(data_offset_j == size);
+            GGML_ASSERT(offset_j == chunk_size_full);
         } break;
         case GGML_BACKEND_SPLIT_STATE_MIRRORED: {
             for (ggml_tensor * t : simple_tensors) {
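The merged BY_NE0/BY_NE1/BY_NE2 path treats tensor->nb[int(split_state) + 1], the stride of the dimension just above the split axis, as a chunk size: every chunk of the full buffer is the concatenation of the matching chunks of the per-device simple tensors. Below is a minimal standalone sketch of that arithmetic using plain memcpy into host vectors instead of ggml_backend_tensor_set; the splice_chunks helper and the sizes in main are made up for illustration, this is not the ggml API.

    // Sketch: splice a contiguous "full" buffer into per-device shards, mirroring
    // the chunk arithmetic in ggml_backend_meta_buffer_set_tensor above.
    // chunk_size_full is the stride of the split dimension's parent in the full
    // tensor, chunk_sizes[j] the corresponding stride in device j's shard.
    #include <cassert>
    #include <cstddef>
    #include <cstring>
    #include <vector>

    static void splice_chunks(const char * data, size_t offset, size_t size,
                              size_t chunk_size_full,
                              std::vector<std::vector<char>> & shards,
                              const std::vector<size_t> & chunk_sizes) {
        assert(offset % chunk_size_full == 0);
        assert(size   % chunk_size_full == 0);
        const size_t i_start = offset          / chunk_size_full;
        const size_t i_stop  = (offset + size) / chunk_size_full;

        size_t offset_j = 0; // where device j's slice starts inside each full chunk
        for (size_t j = 0; j < shards.size(); j++) {
            const size_t chunk_size_j = chunk_sizes[j];
            for (size_t i = i_start; i < i_stop; i++) {
                std::memcpy(shards[j].data() + i*chunk_size_j,
                            data + i*chunk_size_full + offset_j,
                            chunk_size_j);
            }
            offset_j += chunk_size_j;
        }
        assert(offset_j == chunk_size_full); // the shards must tile each full chunk exactly
    }

    int main() {
        // Full tensor modelled as 4 chunks of 6 bytes, split 4 + 2 bytes across two devices.
        std::vector<char> full(4 * 6);
        for (size_t i = 0; i < full.size(); i++) full[i] = (char) i;
        std::vector<std::vector<char>> shards = { std::vector<char>(4*4), std::vector<char>(4*2) };
        splice_chunks(full.data(), 0, full.size(), 6, shards, {4, 2});
        return 0;
    }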
@@ -507,23 +496,24 @@ static void ggml_backend_meta_buffer_get_tensor(ggml_backend_buffer_t buffer, co
     }

     switch (split_state) {
-        case GGML_BACKEND_SPLIT_STATE_BY_NE0: {
-            GGML_ASSERT(tensor->ne[2] == 1);
-            GGML_ASSERT(tensor->ne[3] == 1);
-            const size_t row_size_full = ggml_row_size(tensor->type, tensor->ne[0]);
-            GGML_ASSERT(offset % row_size_full == 0);
-            GGML_ASSERT(size % row_size_full == 0);
-            const int64_t i1_start = offset /row_size_full;
-            const int64_t i1_stop = (offset + size)/row_size_full;
-            size_t row_offset_j = 0;
+        case GGML_BACKEND_SPLIT_STATE_BY_NE0:
+        case GGML_BACKEND_SPLIT_STATE_BY_NE1:
+        case GGML_BACKEND_SPLIT_STATE_BY_NE2: {
+            // Exploit that tensors are contiguous to splice it with simple tensors as "chunks".
+            const size_t chunk_size_full = tensor->nb[int(split_state) + 1];
+            GGML_ASSERT(offset % chunk_size_full == 0);
+            GGML_ASSERT(size % chunk_size_full == 0);
+            const int64_t i_start = offset /chunk_size_full;
+            const int64_t i_stop = (offset + size)/chunk_size_full;
+            size_t offset_j = 0;
             for (ggml_tensor * t : simple_tensors) {
-                const size_t row_size_j = ggml_row_size(tensor->type, t->ne[0]);
-                for (int64_t i1 = i1_start; i1 < i1_stop; i1++) {
-                    ggml_backend_tensor_set(t, (const char *) data + i1*row_size_full + row_offset_j, i1*row_size_j, row_size_j);
+                const size_t chunk_size_j = t->nb[int(split_state) + 1];
+                for (int64_t i1 = i_start; i1 < i_stop; i1++) {
+                    ggml_backend_tensor_get(t, (char *) data + i1*chunk_size_full + offset_j, i1*chunk_size_j, chunk_size_j);
                 }
-                row_offset_j += row_size_j;
+                offset_j += chunk_size_j;
            }
-            GGML_ASSERT(row_offset_j == row_size_full);
+            GGML_ASSERT(offset_j == chunk_size_full);
         } break;
         case GGML_BACKEND_SPLIT_STATE_MIRRORED: {
             // TODO other simple backend may be better
@@ -986,6 +976,19 @@ enum ggml_backend_meta_split_state ggml_backend_meta_get_split_state(const struc
         return src_split_states[0];
     };

+    // Some ops broadcast the src1 data across src0:
+    auto handle_bin_bcast = [&](const std::vector<ggml_backend_meta_split_state> & src_split_states) -> ggml_backend_meta_split_state {
+        if (src_split_states[0] >= 0 && src_split_states[0] < GGML_MAX_DIMS &&
+                tensor->src[1]->ne[int(src_split_states[0])] == 1 && src_split_states[1] == GGML_BACKEND_SPLIT_STATE_MIRRORED) {
+            return src_split_states[0];
+        }
+        if (src_split_states[0] == src_split_states[1] && src_split_states[2] == GGML_BACKEND_SPLIT_STATE_MIRRORED) {
+            return src_split_states[0]; // GGML_ADD_ID
+        }
+        GGML_ASSERT(tensor->src[2] == nullptr || src_split_states[2] == GGML_BACKEND_SPLIT_STATE_MIRRORED);
+        return handle_generic(src_split_states, /*scalar_only =*/ false);
+    };
+
     auto handle_mul_mat = [&](const std::vector<ggml_backend_meta_split_state> & src_split_states) -> ggml_backend_meta_split_state {
         if (src_split_states[0] == GGML_BACKEND_SPLIT_STATE_MIRRORED && src_split_states[1] == GGML_BACKEND_SPLIT_STATE_MIRRORED) {
             return GGML_BACKEND_SPLIT_STATE_MIRRORED;
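handle_bin_bcast keeps a binary op sharded when the second operand does not actually need to be sharded: if src0 is split along some dimension d and src1 is mirrored with ne[d] == 1, every device already holds all of the src1 values its shard needs, so the result inherits src0's split. Below is a compact model of that rule; the enum values and the bin_bcast_split helper are illustrative assumptions, not the ggml definitions, and the third operand (which the real code requires to be mirrored) is left out.

    // Model of the broadcast rule in handle_bin_bcast above: a split state in
    // BY_NE0..BY_NE2 doubles as the index of the split dimension.
    #include <cstdint>

    enum split_state : int { BY_NE0 = 0, BY_NE1 = 1, BY_NE2 = 2, MIRRORED = 100, UNKNOWN = -1 };

    static split_state bin_bcast_split(split_state s0, split_state s1, const int64_t * ne1) {
        // src0 split along dim s0, src1 mirrored and broadcast along that dim:
        // each device already has the src1 data for its shard, keep src0's split.
        if (s0 >= BY_NE0 && s0 <= BY_NE2 && ne1[(int) s0] == 1 && s1 == MIRRORED) {
            return s0;
        }
        // Both operands sharded the same way (e.g. per-expert biases for ADD_ID).
        if (s0 == s1) {
            return s0;
        }
        return UNKNOWN; // would need a gather/rescatter; not modelled here
    }

For example, adding a mirrored per-row bias (ne1 = {1, n, 1, 1}) to a matrix split by NE0 keeps the result split by NE0 without any communication.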
@@ -1023,8 +1026,7 @@ enum ggml_backend_meta_split_state ggml_backend_meta_get_split_state(const struc
             }
             case GGML_BACKEND_SPLIT_STATE_MIRRORED:
             case GGML_BACKEND_SPLIT_STATE_PARTIAL: {
-                GGML_ABORT("reshape not implemented for MIRRORED/PARTIAL");
-                return GGML_BACKEND_SPLIT_STATE_UNKNOWN;
+                return src_split_states[0];
             }
             default: {
                 GGML_ABORT("fatal error");
@@ -1033,6 +1035,17 @@ enum ggml_backend_meta_split_state ggml_backend_meta_get_split_state(const struc
         }
     };

+    auto handle_view = [&](const std::vector<ggml_backend_meta_split_state> & src_split_states) -> ggml_backend_meta_split_state {
+        if (ggml_is_contiguous(tensor)) {
+            return handle_reshape(src_split_states);
+        }
+        if (src_split_states[0] == GGML_BACKEND_SPLIT_STATE_MIRRORED || src_split_states[0] == GGML_BACKEND_SPLIT_STATE_PARTIAL) {
+            return src_split_states[0];
+        }
+        GGML_ABORT("non-contioguos view not implemented");
+        return GGML_BACKEND_SPLIT_STATE_UNKNOWN;
+    };
+
     auto handle_permute = [&](const std::vector<ggml_backend_meta_split_state> & src_split_states) -> ggml_backend_meta_split_state {
         switch (src_split_states[0]) {
             case GGML_BACKEND_SPLIT_STATE_BY_NE0:
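The new handle_view falls back to the reshape handling whenever the view is contiguous, because a contiguous view occupies exactly the same bytes as its source and only reinterprets the shape. A simplified version of the contiguity condition, loosely what ggml_is_contiguous verifies, ignoring quantized block types:

    #include <cstddef>
    #include <cstdint>

    // True if the strides are exactly the packed strides for the shape, i.e. the
    // view is just a reinterpretation of the same contiguous byte range.
    static bool is_contiguous(const int64_t ne[4], const size_t nb[4], size_t type_size) {
        if (nb[0] != type_size) {
            return false;
        }
        for (int i = 1; i < 4; i++) {
            if (nb[i] != nb[i-1] * (size_t) ne[i-1]) {
                return false; // a gap or overlap: not a pure reinterpretation
            }
        }
        return true;
    }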
@@ -1065,9 +1078,11 @@ enum ggml_backend_meta_split_state ggml_backend_meta_get_split_state(const struc
     };

     auto handle_flash_attn_ext = [&](const std::vector<ggml_backend_meta_split_state> & src_split_states) -> ggml_backend_meta_split_state {
-        GGML_ASSERT(src_split_states[0] == GGML_BACKEND_SPLIT_STATE_BY_NE2);
-        GGML_ASSERT(src_split_states[1] == GGML_BACKEND_SPLIT_STATE_BY_NE2);
-        GGML_ASSERT(src_split_states[2] == GGML_BACKEND_SPLIT_STATE_BY_NE2);
+        GGML_ASSERT( src_split_states[0] == GGML_BACKEND_SPLIT_STATE_BY_NE2);
+        GGML_ASSERT( src_split_states[1] == GGML_BACKEND_SPLIT_STATE_BY_NE2);
+        GGML_ASSERT( src_split_states[2] == GGML_BACKEND_SPLIT_STATE_BY_NE2);
+        GGML_ASSERT(tensor->src[4] == nullptr || src_split_states[3] == GGML_BACKEND_SPLIT_STATE_MIRRORED);
+        GGML_ASSERT(tensor->src[4] == nullptr || src_split_states[4] == GGML_BACKEND_SPLIT_STATE_BY_NE0);
         return GGML_BACKEND_SPLIT_STATE_BY_NE1;
     };

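The updated handle_flash_attn_ext encodes head-parallel attention, now covering the optional GPT-OSS attention sinks as src[4]: Q, K and V must be sharded by ne[2] (the head axis in the usual ggml flash_attn_ext layout), the mask mirrored, and the sinks sharded by ne[0], so each device attends over its own heads and the result comes back split along ne[1], where the head index lives in the output. A small illustrative model of that check; the enum and flash_attn_split are assumptions for the sketch, not ggml code.

    enum split_state { BY_NE0, BY_NE1, BY_NE2, MIRRORED, UNKNOWN };

    // Head-parallel flash attention: inputs sharded by head, mask mirrored,
    // per-head sinks sharded with the heads; output sharded along its head axis.
    static split_state flash_attn_split(split_state q, split_state k, split_state v,
                                        split_state mask, split_state sinks, bool has_sinks) {
        const bool ok = q == BY_NE2 && k == BY_NE2 && v == BY_NE2 &&
                        mask == MIRRORED && (!has_sinks || sinks == BY_NE0);
        return ok ? BY_NE1 : UNKNOWN;
    }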
@@ -1094,17 +1109,19 @@ enum ggml_backend_meta_split_state ggml_backend_meta_get_split_state(const struc
         case GGML_OP_DUP: {
             return handle_generic(src_split_states, /*scalar_only =*/ true);
         }
-        case GGML_OP_ADD: {
-            return handle_generic(src_split_states, /*scalar_only =*/ false);
-        }
+        case GGML_OP_ADD:
         case GGML_OP_ADD_ID: {
-            return handle_generic(src_split_states, /*scalar_only =*/ true);
+            return handle_bin_bcast(src_split_states);
         }
         case GGML_OP_ADD1:
-        case GGML_OP_ACC:
+        case GGML_OP_ACC: {
+            return handle_generic(src_split_states, /*scalar_only =*/ true);
+        }
         case GGML_OP_SUB:
         case GGML_OP_MUL:
-        case GGML_OP_DIV:
+        case GGML_OP_DIV: {
+            return handle_bin_bcast(src_split_states);
+        }
         case GGML_OP_SQR:
         case GGML_OP_SQRT:
         case GGML_OP_LOG:
@@ -1137,10 +1154,10 @@ enum ggml_backend_meta_split_state ggml_backend_meta_get_split_state(const struc
         case GGML_OP_L2_NORM: {
             return handle_per_row(src_split_states);
         }
-        case GGML_OP_MUL_MAT: {
+        case GGML_OP_MUL_MAT:
+        case GGML_OP_MUL_MAT_ID: {
             return handle_mul_mat(src_split_states);
         }
-        case GGML_OP_MUL_MAT_ID:
         case GGML_OP_OUT_PROD: {
             return handle_generic(src_split_states, /*scalar_only =*/ true);
         }
@@ -1156,11 +1173,7 @@ enum ggml_backend_meta_split_state ggml_backend_meta_get_split_state(const struc
             return handle_reshape(src_split_states);
         }
         case GGML_OP_VIEW: {
-            if (ggml_is_contiguous(tensor)) {
-                return handle_reshape(src_split_states);
-            }
-            GGML_ABORT("non-contioguos view not implemented");
-            return GGML_BACKEND_SPLIT_STATE_UNKNOWN;
+            return handle_view(src_split_states);
         }
         case GGML_OP_PERMUTE: {
             return handle_permute(src_split_states);
@@ -886,28 +886,49 @@ static int llama_model_load(const std::string & fname, std::vector<std::string>

 static enum ggml_backend_meta_split_state llama_meta_device_get_tensor_split(const struct ggml_tensor * tensor, void * userdata) {
     // attention
-    const std::regex pattern_qkv_weight("blk\\.\\d*\\.attn_(q|k|v).*");
+    const std::regex pattern_qkv_weight("blk\\.\\d*\\.attn_(q|k|v).weight");
     if (std::regex_match(tensor->name, pattern_qkv_weight)) {
         return GGML_BACKEND_SPLIT_STATE_BY_NE1;
     }
-    const std::regex pattern_kv_cache("cache_(k|v)_l\\d*");
-    if (std::regex_match(tensor->name, pattern_kv_cache)) {
+    const std::regex pattern_qkv_bias("blk\\.\\d*\\.attn_(q|k|v)\\.bias");
+    if (std::regex_match(tensor->name, pattern_qkv_bias)) {
         return GGML_BACKEND_SPLIT_STATE_BY_NE0;
     }
-    const std::regex pattern_attn_out("blk\\.\\d*\\.attn_output.*");
-    if (std::regex_match(tensor->name, pattern_attn_out)) {
+    const std::regex pattern_qk_norm("blk\\.\\d*\\.attn_(q|k)_norm\\.weight");
+    if (std::regex_match(tensor->name, pattern_qk_norm)) {
+        return tensor->ne[1] == 1 ? GGML_BACKEND_SPLIT_STATE_MIRRORED : GGML_BACKEND_SPLIT_STATE_BY_NE1;
+    }
+    const std::regex pattern_kv_cache("cache_(k|v)_l\\d*");
+    const std::regex pattern_attn_sinks("blk\\.\\d*\\.attn_sinks.weight");
+    if (std::regex_match(tensor->name, pattern_kv_cache) || std::regex_match(tensor->name, pattern_attn_sinks)) {
+        return GGML_BACKEND_SPLIT_STATE_BY_NE0;
+    }
+    const std::regex pattern_attn_out_weight("blk\\.\\d*\\.attn_output.weight");
+    if (std::regex_match(tensor->name, pattern_attn_out_weight)) {
         return GGML_BACKEND_SPLIT_STATE_BY_NE0;
     }
+    const std::regex pattern_attn_out_bias("blk\\.\\d*\\.attn_output.bias");
+    if (std::regex_match(tensor->name, pattern_attn_out_bias)) {
+        return GGML_BACKEND_SPLIT_STATE_MIRRORED;
+    }

     // FFN
-    const std::regex pattern_ffn_up_gate("blk\\.\\d*\\.ffn_(up|gate).*");
-    if (std::regex_match(tensor->name, pattern_ffn_up_gate)) {
+    const std::regex pattern_ffn_up_gate_weight("blk\\.\\d*\\.ffn_(up|gate)(_exps)?.weight");
+    if (std::regex_match(tensor->name, pattern_ffn_up_gate_weight)) {
         return GGML_BACKEND_SPLIT_STATE_BY_NE1;
     }
-    const std::regex pattern_ffn_down("blk\\.\\d*\\.ffn_down.*");
-    if (std::regex_match(tensor->name, pattern_ffn_down)) {
+    const std::regex pattern_ffn_up_gate_bias("blk\\.\\d*\\.ffn_(up|gate)(_exps)?.bias");
+    if (std::regex_match(tensor->name, pattern_ffn_up_gate_bias)) {
         return GGML_BACKEND_SPLIT_STATE_BY_NE0;
     }
+    const std::regex pattern_ffn_down_weight("blk\\.\\d*\\.ffn_down(_exps)?.weight");
+    if (std::regex_match(tensor->name, pattern_ffn_down_weight)) {
+        return GGML_BACKEND_SPLIT_STATE_BY_NE0;
+    }
+    const std::regex pattern_ffn_down_bias("blk\\.\\d*\\.ffn_down(_exps)?.bias");
+    if (std::regex_match(tensor->name, pattern_ffn_down_bias)) {
+        return GGML_BACKEND_SPLIT_STATE_MIRRORED;
+    }

     // output
     const std::regex pattern_output("output");
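The split decision for model weights is driven purely by tensor names, with the patterns above extended to separate weights from biases and to cover MoE expert stacks ("_exps") and the GPT-OSS attention sinks. Below is a standalone check of a few representative GGUF-style names against three of those patterns; the names are illustrative, chosen to follow the usual llama.cpp naming scheme.

    #include <cstdio>
    #include <regex>
    #include <string>

    int main() {
        // Same patterns as in llama_meta_device_get_tensor_split above.
        const std::regex qkv_weight ("blk\\.\\d*\\.attn_(q|k|v).weight");
        const std::regex ffn_up_gate("blk\\.\\d*\\.ffn_(up|gate)(_exps)?.weight");
        const std::regex attn_sinks ("blk\\.\\d*\\.attn_sinks.weight");

        const std::string names[] = {
            "blk.0.attn_q.weight",       // -> BY_NE1 per the table above
            "blk.7.ffn_up_exps.weight",  // -> BY_NE1, "_exps" covers the stacked MoE experts
            "blk.3.attn_sinks.weight",   // -> BY_NE0, sinks are split together with the heads
        };
        for (const std::string & name : names) {
            std::printf("%-26s qkv=%d up/gate=%d sinks=%d\n", name.c_str(),
                        (int) std::regex_match(name, qkv_weight),
                        (int) std::regex_match(name, ffn_up_gate),
                        (int) std::regex_match(name, attn_sinks));
        }
        return 0;
    }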