rpc : check src buffer when copying tensor (#16421)

Only dst buffer is guaranteed to be an RPC buffer. Add check for the src
one.
This commit is contained in:
Radoslav Gerganov 2025-10-04 16:22:45 +03:00 committed by GitHub
parent 898acba681
commit f39283960b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 22 additions and 15 deletions

View File

@ -631,7 +631,12 @@ static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, con
RPC_STATUS_ASSERT(status);
}
static bool ggml_backend_buffer_is_rpc(ggml_backend_buffer_t buffer) {
return buffer->iface.free_buffer == ggml_backend_rpc_buffer_free_buffer;
}
static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
if (ggml_backend_buffer_is_rpc(src->buffer)) {
// check if src and dst are on the same server
ggml_backend_buffer_t src_buffer = src->buffer;
ggml_backend_rpc_buffer_context * src_ctx = (ggml_backend_rpc_buffer_context *)src_buffer->context;
@ -648,6 +653,8 @@ static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, con
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, &request, sizeof(request), &response, sizeof(response));
RPC_STATUS_ASSERT(status);
return response.result;
}
return false;
}
static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {