address review comments
This commit is contained in:
parent df27d80ae3
commit 9a67778451
@@ -862,32 +862,6 @@ static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggm
     ctx->dispatcher->send(RPC_CMD_SET_TENSOR, input_ptr, input_size);
 }
 
-static void ggml_backend_rpc_buffer_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
-    ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context;
-    rpc_tensor rpc_tensor = serialize_tensor(tensor);
-    if (size > HASH_THRESHOLD) {
-        auto request = std::make_shared<rpc_msg_set_tensor_hash_req>();
-        request->tensor = rpc_tensor;
-        request->offset = offset;
-        request->hash = fnv_hash((const uint8_t*)data, size);
-        rpc_msg_set_tensor_hash_rsp response;
-        // TODO: make this async
-        ctx->dispatcher->send(RPC_CMD_SET_TENSOR_HASH, request, sizeof(*request), &response, sizeof(response));
-        if (response.result) {
-            // the server has the same data, no need to send it
-            return;
-        }
-    }
-    // input serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes)
-    size_t input_size = sizeof(rpc_tensor) + sizeof(uint64_t) + size;
-    uint8_t * input = new uint8_t[input_size]();
-    memcpy(input, &rpc_tensor, sizeof(rpc_tensor));
-    memcpy(input + sizeof(rpc_tensor), &offset, sizeof(offset));
-    memcpy(input + sizeof(rpc_tensor) + sizeof(offset), data, size);
-    std::shared_ptr<uint8_t> input_ptr(input, std::default_delete<uint8_t[]>());
-    ctx->dispatcher->send_async(RPC_CMD_SET_TENSOR, input_ptr, input_size);
-}
-
 static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
     ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
     auto request = std::make_shared<rpc_msg_get_tensor_req>();
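This hunk and the next drop the buffer-level async implementations; they reappear below as backend-level functions. Their hash fast path skips resending large payloads: above HASH_THRESHOLD the client first sends a 64-bit hash, and the server answers whether it already holds identical bytes. The fnv_hash() helper is defined elsewhere in ggml-rpc.cpp and is not part of this diff; as a sketch, a standard 64-bit FNV-1a over the tensor bytes would look like this:

    // Sketch of a 64-bit FNV-1a hash compatible with the fnv_hash() call
    // above; constants are the standard FNV-1a 64-bit parameters.
    static uint64_t fnv_hash(const uint8_t * data, size_t len) {
        const uint64_t fnv_prime = 0x100000001b3ULL;
        uint64_t hash = 0xcbf29ce484222325ULL; // FNV offset basis
        for (size_t i = 0; i < len; i++) {
            hash ^= data[i];   // xor in the byte first...
            hash *= fnv_prime; // ...then multiply: this order is FNV-1a
        }
        return hash;
    }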
@@ -897,15 +871,6 @@ static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, con
     ctx->dispatcher->send(RPC_CMD_GET_TENSOR, request, sizeof(*request), data, size);
 }
 
-static void ggml_backend_rpc_buffer_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
-    ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context;
-    auto request = std::make_shared<rpc_msg_get_tensor_req>();
-    request->tensor = serialize_tensor(tensor);
-    request->offset = offset;
-    request->size = size;
-    ctx->dispatcher->send_async(RPC_CMD_GET_TENSOR, request, sizeof(*request), data, size);
-}
-
 static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
     if (ggml_backend_buffer_is_rpc(src->buffer)) {
         // check if src and dst are on the same server
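For context, the rpc_msg_get_tensor_req used by both the synchronous and async paths is declared earlier in ggml-rpc.cpp, outside this diff; its layout is approximately:

    // Approximate shape of the GET_TENSOR request struct, shown for
    // context only; the authoritative definition lives elsewhere in
    // ggml-rpc.cpp.
    struct rpc_msg_get_tensor_req {
        rpc_tensor tensor; // serialized descriptor of the source tensor
        uint64_t   offset; // byte offset into the tensor's data
        uint64_t   size;   // number of bytes to read back
    };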
@@ -1054,6 +1019,41 @@ static void ggml_backend_rpc_free(ggml_backend_t backend) {
     delete backend;
 }
 
+static void ggml_backend_rpc_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+    ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context;
+    rpc_tensor rpc_tensor = serialize_tensor(tensor);
+    if (size > HASH_THRESHOLD) {
+        auto request = std::make_shared<rpc_msg_set_tensor_hash_req>();
+        request->tensor = rpc_tensor;
+        request->offset = offset;
+        request->hash = fnv_hash((const uint8_t*)data, size);
+        rpc_msg_set_tensor_hash_rsp response;
+        // TODO: make this async
+        ctx->dispatcher->send(RPC_CMD_SET_TENSOR_HASH, request, sizeof(*request), &response, sizeof(response));
+        if (response.result) {
+            // the server has the same data, no need to send it
+            return;
+        }
+    }
+    // input serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes)
+    size_t input_size = sizeof(rpc_tensor) + sizeof(uint64_t) + size;
+    uint8_t * input = new uint8_t[input_size]();
+    memcpy(input, &rpc_tensor, sizeof(rpc_tensor));
+    memcpy(input + sizeof(rpc_tensor), &offset, sizeof(offset));
+    memcpy(input + sizeof(rpc_tensor) + sizeof(offset), data, size);
+    std::shared_ptr<uint8_t> input_ptr(input, std::default_delete<uint8_t[]>());
+    ctx->dispatcher->send_async(RPC_CMD_SET_TENSOR, input_ptr, input_size);
+}
+
+static void ggml_backend_rpc_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+    ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context;
+    auto request = std::make_shared<rpc_msg_get_tensor_req>();
+    request->tensor = serialize_tensor(tensor);
+    request->offset = offset;
+    request->size = size;
+    ctx->dispatcher->send_async(RPC_CMD_GET_TENSOR, request, sizeof(*request), data, size);
+}
+
 static void ggml_backend_rpc_synchronize(ggml_backend_t backend) {
     ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
     rpc_ctx->dispatcher->synchronize();
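The server-side handler for RPC_CMD_SET_TENSOR is not part of this diff; a minimal sketch of how a receiver might unpack the | rpc_tensor | offset | data | layout described by the comment above (names are illustrative, not taken from the codebase):

    // Hypothetical parse of the SET_TENSOR payload; it mirrors the memcpy
    // packing in ggml_backend_rpc_set_tensor_async above.
    static bool parse_set_tensor(const uint8_t * input, size_t input_size,
                                 rpc_tensor * tensor, uint64_t * offset,
                                 const uint8_t ** data, size_t * data_size) {
        if (input_size < sizeof(rpc_tensor) + sizeof(uint64_t)) {
            return false; // too small to hold the fixed-size header
        }
        memcpy(tensor, input, sizeof(rpc_tensor));
        memcpy(offset, input + sizeof(rpc_tensor), sizeof(uint64_t));
        *data      = input + sizeof(rpc_tensor) + sizeof(uint64_t);
        *data_size = input_size - sizeof(rpc_tensor) - sizeof(uint64_t);
        return true;
    }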
@@ -1126,16 +1126,16 @@ static void ggml_backend_rpc_event_record(ggml_backend_t backend, ggml_backend_e
     rpc_ctx->dispatcher->event_record(event);
 }
 
-static void ggml_backend_rpc_event_wait(ggml_backend_t dev, ggml_backend_event_t event) {
-    ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)dev->context;
+static void ggml_backend_rpc_event_wait(ggml_backend_t backend, ggml_backend_event_t event) {
+    ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context;
     ctx->dispatcher->event_synchronize(event);
 }
 
 static ggml_backend_i ggml_backend_rpc_interface = {
     /* .get_name                = */ ggml_backend_rpc_name,
     /* .free                    = */ ggml_backend_rpc_free,
-    /* .set_tensor_async        = */ ggml_backend_rpc_buffer_set_tensor_async,
-    /* .get_tensor_async        = */ ggml_backend_rpc_buffer_get_tensor_async,
+    /* .set_tensor_async        = */ ggml_backend_rpc_set_tensor_async,
+    /* .get_tensor_async        = */ ggml_backend_rpc_get_tensor_async,
     /* .cpy_tensor_async        = */ NULL,
     /* .synchronize             = */ ggml_backend_rpc_synchronize,
     /* .graph_plan_create       = */ NULL,
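With these table entries in place, the new functions are reached through the generic ggml-backend API. A minimal caller sketch (assuming an initialized RPC backend and an allocated tensor; the buffer names are illustrative):

    // Both async calls return before the RPC completes; synchronize()
    // is what guarantees the transfers have finished.
    ggml_backend_tensor_set_async(backend, tensor, src_buf, 0, ggml_nbytes(tensor));
    ggml_backend_tensor_get_async(backend, tensor, dst_buf, 0, ggml_nbytes(tensor));
    ggml_backend_synchronize(backend); // dst_buf is valid only after this returns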