fix output pattern

This commit is contained in:
Johannes Gäßler 2026-02-09 22:40:30 +01:00
parent c925563499
commit c531444411
2 changed files with 44 additions and 7 deletions

View File

@@ -384,7 +384,7 @@ static enum ggml_status ggml_backend_meta_buffer_init_tensor(ggml_backend_buffer
nb[k] = tensor->nb[k];
}
if (split_dim >= 0 && split_dim < GGML_MAX_DIMS) {
GGML_ASSERT(ne[split_dim] % (n_simple_bufs*ggml_blck_size(tensor->type)) == 0);
GGML_ASSERT(ne[split_dim] % (split_dim == 0 ? n_simple_bufs*ggml_blck_size(tensor->type) : n_simple_bufs) == 0);
ne[split_dim] /= n_simple_bufs;
for (int i = 0; i < GGML_MAX_DIMS; i++) {
if (tensor->nb[i] > tensor->nb[split_dim]) {
@@ -738,11 +738,44 @@ static void ggml_backend_meta_set_tensor_async(ggml_backend_t backend, ggml_tens
}
// Asynchronously read a (meta) tensor back into host memory, gathering the data
// from the simple backends that hold its shards.
//
// backend: the meta backend wrapping n_backends simple backends.
// tensor:  the meta tensor to read; must be contiguous.
// data:    destination host buffer of at least `size` bytes.
// offset:  byte offset into the tensor data (currently required to be 0).
// size:    number of bytes to read.
static void ggml_backend_meta_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
    const size_t n_backends = ggml_backend_meta_n_backends(backend);
    GGML_ASSERT(n_backends >= 1);
    GGML_ASSERT(offset == 0);
    GGML_ASSERT(ggml_is_contiguous(tensor));
    const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false);
    switch (split_state) {
        case GGML_BACKEND_SPLIT_STATE_BY_NE0:
        case GGML_BACKEND_SPLIT_STATE_BY_NE1:
        case GGML_BACKEND_SPLIT_STATE_BY_NE2: {
            // Exploit that tensors are contiguous to splice it with simple tensors as "chunks".
            // The stride one dimension above the split dimension is the size of one full chunk;
            // each simple tensor contributes a contiguous sub-chunk of that stride.
            const size_t chunk_size_full = tensor->nb[int(split_state) + 1];
            GGML_ASSERT(offset % chunk_size_full == 0);
            GGML_ASSERT(size   % chunk_size_full == 0);
            const int64_t i_start =  offset         / chunk_size_full;
            const int64_t i_stop  = (offset + size) / chunk_size_full;
            size_t offset_j = 0; // byte offset of backend j's sub-chunk within a full chunk
            for (size_t j = 0; j < n_backends; j++) {
                ggml_backend_t      simple_backend = ggml_backend_meta_simple_backend(backend, j);
                const ggml_tensor * simple_tensor  = ggml_backend_meta_buffer_simple_tensor(tensor, j);
                const size_t chunk_size_j = simple_tensor->nb[int(split_state) + 1];
                for (int64_t i1 = i_start; i1 < i_stop; i1++) {
                    ggml_backend_tensor_get_async(simple_backend, simple_tensor,
                        (char *) data + i1*chunk_size_full + offset_j, i1*chunk_size_j, chunk_size_j);
                }
                offset_j += chunk_size_j;
            }
            // All sub-chunks together must reconstruct exactly one full chunk.
            GGML_ASSERT(offset_j == chunk_size_full);
        } break;
        case GGML_BACKEND_SPLIT_STATE_MIRRORED: {
            // Data is replicated on every simple backend, reading from any single one suffices.
            // TODO other simple backend may be better
            ggml_backend_t      simple_backend = ggml_backend_meta_simple_backend(backend, 0);
            const ggml_tensor * simple_tensor  = ggml_backend_meta_buffer_simple_tensor(tensor, 0);
            ggml_backend_tensor_get_async(simple_backend, simple_tensor, data, offset, size);
        } break;
        default: {
            GGML_ABORT("fatal error");
        } break;
    }
}
static void ggml_backend_meta_synchronize(ggml_backend_t backend) {

View File

@@ -931,10 +931,14 @@ static enum ggml_backend_meta_split_state llama_meta_device_get_tensor_split(con
}
// output
const std::regex pattern_output("output");
if (std::regex_match(tensor->name, pattern_output)) {
const std::regex pattern_output_weight("output\\.weight");
if (std::regex_match(tensor->name, pattern_output_weight)) {
return GGML_BACKEND_SPLIT_STATE_BY_NE1;
}
const std::regex pattern_output_bias("output\\.bias");
if (std::regex_match(tensor->name, pattern_output_bias)) {
return GGML_BACKEND_SPLIT_STATE_BY_NE0;
}
// everything else
return GGML_BACKEND_SPLIT_STATE_MIRRORED;