ggml-virtgpu: protect the use of the shared memory to transfer data
This commit is contained in:
parent
92390ad9f6
commit
fcc6890710
|
|
@ -18,9 +18,14 @@ ggml_status apir_backend_graph_compute(virtgpu * gpu, ggml_cgraph * cgraph) {
|
|||
|
||||
virtgpu_shmem temp_shmem; // Local storage for large buffers
|
||||
virtgpu_shmem * shmem = &temp_shmem;
|
||||
bool using_shared_shmem = false;
|
||||
|
||||
if (cgraph_size <= gpu->data_shmem.mmap_size) {
|
||||
// prefer the init-time allocated page, if large enough
|
||||
// Lock mutex before using shared data_shmem buffer
|
||||
if (mtx_lock(&gpu->data_shmem_mutex) != thrd_success) {
|
||||
GGML_ABORT("Failed to lock data_shmem mutex");
|
||||
}
|
||||
using_shared_shmem = true;
|
||||
shmem = &gpu->data_shmem;
|
||||
} else if (virtgpu_shmem_create(gpu, cgraph_size, shmem)) {
|
||||
GGML_ABORT("Couldn't allocate the guest-host shared buffer");
|
||||
|
|
@ -42,7 +47,10 @@ ggml_status apir_backend_graph_compute(virtgpu * gpu, ggml_cgraph * cgraph) {
|
|||
|
||||
remote_call_finish(gpu, encoder, decoder);
|
||||
|
||||
if (shmem != &gpu->data_shmem) {
|
||||
// Unlock mutex before cleanup
|
||||
if (using_shared_shmem) {
|
||||
mtx_unlock(&gpu->data_shmem_mutex);
|
||||
} else {
|
||||
virtgpu_shmem_destroy(gpu, shmem);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -36,9 +36,14 @@ void apir_buffer_set_tensor(virtgpu * gpu,
|
|||
|
||||
virtgpu_shmem temp_shmem; // Local storage for large buffers
|
||||
virtgpu_shmem * shmem = &temp_shmem;
|
||||
bool using_shared_shmem = false;
|
||||
|
||||
if (size <= gpu->data_shmem.mmap_size) {
|
||||
// prefer the init-time allocated page, if large enough
|
||||
// Lock mutex before using shared data_shmem buffer
|
||||
if (mtx_lock(&gpu->data_shmem_mutex) != thrd_success) {
|
||||
GGML_ABORT("Failed to lock data_shmem mutex");
|
||||
}
|
||||
using_shared_shmem = true;
|
||||
shmem = &gpu->data_shmem;
|
||||
|
||||
} else if (virtgpu_shmem_create(gpu, size, shmem)) {
|
||||
|
|
@ -55,7 +60,10 @@ void apir_buffer_set_tensor(virtgpu * gpu,
|
|||
|
||||
remote_call_finish(gpu, encoder, decoder);
|
||||
|
||||
if (shmem != &gpu->data_shmem) {
|
||||
// Unlock mutex before cleanup
|
||||
if (using_shared_shmem) {
|
||||
mtx_unlock(&gpu->data_shmem_mutex);
|
||||
} else {
|
||||
virtgpu_shmem_destroy(gpu, shmem);
|
||||
}
|
||||
|
||||
|
|
@ -79,9 +87,14 @@ void apir_buffer_get_tensor(virtgpu * gpu,
|
|||
|
||||
virtgpu_shmem temp_shmem; // Local storage for large buffers
|
||||
virtgpu_shmem * shmem = &temp_shmem;
|
||||
bool using_shared_shmem = false;
|
||||
|
||||
if (size <= gpu->data_shmem.mmap_size) {
|
||||
// prefer the init-time allocated page, if large enough
|
||||
// Lock mutex before using shared data_shmem buffer
|
||||
if (mtx_lock(&gpu->data_shmem_mutex) != thrd_success) {
|
||||
GGML_ABORT("Failed to lock data_shmem mutex");
|
||||
}
|
||||
using_shared_shmem = true;
|
||||
shmem = &gpu->data_shmem;
|
||||
|
||||
} else if (virtgpu_shmem_create(gpu, size, shmem)) {
|
||||
|
|
@ -98,7 +111,10 @@ void apir_buffer_get_tensor(virtgpu * gpu,
|
|||
|
||||
remote_call_finish(gpu, encoder, decoder);
|
||||
|
||||
if (shmem != &gpu->data_shmem) {
|
||||
// Unlock mutex before cleanup
|
||||
if (using_shared_shmem) {
|
||||
mtx_unlock(&gpu->data_shmem_mutex);
|
||||
} else {
|
||||
virtgpu_shmem_destroy(gpu, shmem);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -149,6 +149,13 @@ virtgpu * create_virtgpu() {
|
|||
gpu->use_apir_capset = getenv("GGML_REMOTING_USE_APIR_CAPSET") != nullptr;
|
||||
util_sparse_array_init(&gpu->shmem_array, sizeof(virtgpu_shmem), 1024);
|
||||
|
||||
// Initialize mutex to protect shared data_shmem buffer
|
||||
if (mtx_init(&gpu->data_shmem_mutex, mtx_plain) != thrd_success) {
|
||||
delete gpu;
|
||||
GGML_ABORT("%s: failed to initialize data_shmem mutex", __func__);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
if (virtgpu_open(gpu) != APIR_SUCCESS) {
|
||||
GGML_ABORT("%s: failed to open the virtgpu device", __func__);
|
||||
return NULL;
|
||||
|
|
|
|||
|
|
@ -74,6 +74,9 @@ struct virtgpu {
|
|||
virtgpu_shmem reply_shmem;
|
||||
virtgpu_shmem data_shmem;
|
||||
|
||||
/* Mutex to protect shared data_shmem buffer from concurrent access */
|
||||
mtx_t data_shmem_mutex;
|
||||
|
||||
/* Cached device information to prevent memory leaks and race conditions */
|
||||
struct {
|
||||
char * description;
|
||||
|
|
|
|||
Loading…
Reference in New Issue