rpc : check src buffer when copying tensor (#16421)
Only dst buffer is guaranteed to be an RPC buffer. Add check for the src one.
This commit is contained in:
parent
898acba681
commit
f39283960b
|
|
@ -631,23 +631,30 @@ static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, con
|
||||||
RPC_STATUS_ASSERT(status);
|
RPC_STATUS_ASSERT(status);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static bool ggml_backend_buffer_is_rpc(ggml_backend_buffer_t buffer) {
|
||||||
|
return buffer->iface.free_buffer == ggml_backend_rpc_buffer_free_buffer;
|
||||||
|
}
|
||||||
|
|
||||||
static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
|
static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
|
||||||
// check if src and dst are on the same server
|
if (ggml_backend_buffer_is_rpc(src->buffer)) {
|
||||||
ggml_backend_buffer_t src_buffer = src->buffer;
|
// check if src and dst are on the same server
|
||||||
ggml_backend_rpc_buffer_context * src_ctx = (ggml_backend_rpc_buffer_context *)src_buffer->context;
|
ggml_backend_buffer_t src_buffer = src->buffer;
|
||||||
ggml_backend_buffer_t dst_buffer = dst->buffer;
|
ggml_backend_rpc_buffer_context * src_ctx = (ggml_backend_rpc_buffer_context *)src_buffer->context;
|
||||||
ggml_backend_rpc_buffer_context * dst_ctx = (ggml_backend_rpc_buffer_context *)dst_buffer->context;
|
ggml_backend_buffer_t dst_buffer = dst->buffer;
|
||||||
if (src_ctx->sock != dst_ctx->sock) {
|
ggml_backend_rpc_buffer_context * dst_ctx = (ggml_backend_rpc_buffer_context *)dst_buffer->context;
|
||||||
return false;
|
if (src_ctx->sock != dst_ctx->sock) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
|
||||||
|
rpc_msg_copy_tensor_req request;
|
||||||
|
request.src = serialize_tensor(src);
|
||||||
|
request.dst = serialize_tensor(dst);
|
||||||
|
rpc_msg_copy_tensor_rsp response;
|
||||||
|
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, &request, sizeof(request), &response, sizeof(response));
|
||||||
|
RPC_STATUS_ASSERT(status);
|
||||||
|
return response.result;
|
||||||
}
|
}
|
||||||
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
|
return false;
|
||||||
rpc_msg_copy_tensor_req request;
|
|
||||||
request.src = serialize_tensor(src);
|
|
||||||
request.dst = serialize_tensor(dst);
|
|
||||||
rpc_msg_copy_tensor_rsp response;
|
|
||||||
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, &request, sizeof(request), &response, sizeof(response));
|
|
||||||
RPC_STATUS_ASSERT(status);
|
|
||||||
return response.result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
|
static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue