Radoslav Gerganov 2026-02-02 00:18:18 +02:00 committed by GitHub
commit 22c41b44aa
2 changed files with 401 additions and 165 deletions

View File

@ -7,7 +7,7 @@ extern "C" {
#endif
#define RPC_PROTO_MAJOR_VERSION 3
#define RPC_PROTO_MINOR_VERSION 6
#define RPC_PROTO_MINOR_VERSION 7
#define RPC_PROTO_PATCH_VERSION 0
#define GGML_RPC_MAX_SERVERS 16

View File

@ -6,6 +6,9 @@
#include <cinttypes>
#include <string>
#include <vector>
#include <queue>
#include <condition_variable>
#include <future>
#include <memory>
#include <mutex>
#include <unordered_map>
@ -30,6 +33,8 @@
#include <fstream>
#include <filesystem>
#include <algorithm>
#include <atomic>
#include <thread>
static const char * RPC_DEBUG = std::getenv("GGML_RPC_DEBUG");
@ -107,6 +112,7 @@ enum rpc_cmd {
RPC_CMD_HELLO,
RPC_CMD_DEVICE_COUNT,
RPC_CMD_GRAPH_RECOMPUTE,
RPC_CMD_NONE,
RPC_CMD_COUNT,
};
@ -261,17 +267,18 @@ struct graph_cache {
std::vector<ggml_tensor> last_graph;
};
class rpc_dispatcher;
struct ggml_backend_rpc_context {
std::string endpoint;
uint32_t device;
std::string name;
graph_cache gc;
std::shared_ptr<rpc_dispatcher> dispatcher;
uint32_t device;
std::string name;
graph_cache gc;
};
struct ggml_backend_rpc_buffer_context {
std::shared_ptr<socket_t> sock;
void * base_ptr;
uint64_t remote_ptr;
std::shared_ptr<rpc_dispatcher> dispatcher;
void * base_ptr;
uint64_t remote_ptr;
};
// RPC helper functions
@ -495,67 +502,267 @@ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cm
// RPC client-side implementation
static bool check_server_version(const std::shared_ptr<socket_t> & sock) {
template <typename T>
class message_queue {
public:
message_queue() {}
bool push(const T &value) {
std::unique_lock<std::mutex> lock(mutex);
if (interrupted) {
return false;
}
queue.push(value);
cvar.notify_all();
return true;
}
bool pop(T* out) {
std::unique_lock<std::mutex> lock(mutex);
cvar.wait(lock, [this] { return !queue.empty() || interrupted; });
if (interrupted) {
return false;
}
*out = queue.front();
queue.pop();
return true;
}
void interrupt() {
std::unique_lock<std::mutex> lock(mutex);
interrupted = true;
lock.unlock();
cvar.notify_all();
}
private:
bool interrupted = false;
std::queue<T> queue;
std::mutex mutex;
std::condition_variable cvar;
};
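
The message_queue above is a minimal blocking queue: push() and pop() share a single mutex, pop() blocks on the condition variable, and interrupt() wakes all waiters and makes both operations fail from then on. A small standalone sketch of how such a queue behaves (illustrative only, not part of the commit; it assumes the message_queue<T> definition above is visible):

#include <cstdio>
#include <thread>

int main() {
    message_queue<int> q;
    std::thread consumer([&q] {
        int v;
        while (q.pop(&v)) {       // blocks until a value arrives or interrupt() is called
            std::printf("got %d\n", v);
        }
        // pop() returned false: the queue was interrupted
    });
    q.push(1);
    q.push(2);
    q.interrupt();                // note: values still queued at this point are dropped,
                                  // because pop() checks the interrupted flag first
    consumer.join();
    return 0;
}
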
class rpc_dispatcher {
public:
rpc_dispatcher() {
}
void send(enum rpc_cmd cmd, std::shared_ptr<const void> input, size_t input_size);
void send(enum rpc_cmd cmd, std::shared_ptr<const void> input, size_t input_size, void * output, size_t output_size);
void send_async(enum rpc_cmd cmd, std::shared_ptr<const void> input, size_t input_size);
void send_async(enum rpc_cmd cmd, std::shared_ptr<const void> input, size_t input_size, void * output, size_t output_size);
ggml_backend_event_t event_new(ggml_backend_dev_t dev);
void event_free(ggml_backend_event_t event);
void event_synchronize(ggml_backend_event_t event);
void event_record(ggml_backend_event_t event);
void synchronize();
void start(const std::string & endpoint);
void work();
~rpc_dispatcher();
private:
struct rpc_msg {
rpc_cmd cmd;
std::shared_ptr<const void> input;
size_t input_size;
void * output;
size_t output_size;
std::promise<void> completion;
};
using rpc_msg_ptr = std::shared_ptr<rpc_msg>;
using rpc_msg_queue = message_queue<rpc_msg_ptr>;
struct rpc_event {
rpc_msg_ptr msg;
std::shared_future<void> sf;
};
rpc_msg_queue queue;
std::shared_ptr<socket_t> sock;
std::atomic_bool running;
std::thread thread;
};
static void rpc_dispatcher_trampoline(rpc_dispatcher * dispatcher)
{
dispatcher->work();
}
void rpc_dispatcher::send(enum rpc_cmd cmd, std::shared_ptr<const void> input, size_t input_size) {
auto msg = std::make_shared<rpc_msg>();
msg->cmd = cmd;
msg->input = input;
msg->input_size = input_size;
msg->output = nullptr;
msg->output_size = 0;
GGML_ASSERT(queue.push(msg));
auto future = msg->completion.get_future();
future.wait();
}
void rpc_dispatcher::send_async(enum rpc_cmd cmd, std::shared_ptr<const void> input, size_t input_size) {
auto msg = std::make_shared<rpc_msg>();
msg->cmd = cmd;
msg->input = input;
msg->input_size = input_size;
msg->output = nullptr;
msg->output_size = 0;
GGML_ASSERT(queue.push(msg));
}
void rpc_dispatcher::send(enum rpc_cmd cmd, std::shared_ptr<const void> input, size_t input_size, void * output, size_t output_size) {
auto msg = std::make_shared<rpc_msg>();
msg->cmd = cmd;
msg->input = input;
msg->input_size = input_size;
msg->output = output;
msg->output_size = output_size;
GGML_ASSERT(queue.push(msg));
auto future = msg->completion.get_future();
future.wait();
}
void rpc_dispatcher::send_async(enum rpc_cmd cmd, std::shared_ptr<const void> input, size_t input_size, void * output, size_t output_size) {
auto msg = std::make_shared<rpc_msg>();
msg->cmd = cmd;
msg->input = input;
msg->input_size = input_size;
msg->output = output;
msg->output_size = output_size;
GGML_ASSERT(queue.push(msg));
}
ggml_backend_event_t rpc_dispatcher::event_new(ggml_backend_dev_t dev) {
rpc_event * ev = new rpc_event;
ev->msg = std::make_shared<rpc_msg>();
ev->msg->cmd = RPC_CMD_NONE;
ev->sf = ev->msg->completion.get_future().share();
GGML_ASSERT(queue.push(ev->msg));
return new ggml_backend_event {
/* .device = */ dev,
/* .context = */ ev,
};
}
void rpc_dispatcher::event_free(ggml_backend_event_t event) {
rpc_event * ev = (rpc_event *)event->context;
delete ev;
}
void rpc_dispatcher::event_synchronize(ggml_backend_event_t event) {
rpc_event * ev = (rpc_event *)event->context;
ev->sf.wait();
}
void rpc_dispatcher::event_record(ggml_backend_event_t event) {
rpc_event * ev = (rpc_event *)event->context;
ev->msg = std::make_shared<rpc_msg>();
ev->msg->cmd = RPC_CMD_NONE;
ev->sf = ev->msg->completion.get_future().share();
GGML_ASSERT(queue.push(ev->msg));
}
void rpc_dispatcher::synchronize() {
// to ensure all messages are processed, submit a dummy message and wait for it to complete
auto msg = std::make_shared<rpc_msg>();
msg->cmd = RPC_CMD_NONE;
GGML_ASSERT(queue.push(msg));
msg->completion.get_future().wait();
}
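
Taken together, send() blocks the caller on the message's promise, send_async() only enqueues, and synchronize() enqueues an RPC_CMD_NONE marker and waits for it; since the worker thread drains the queue in order, the marker doubles as a barrier for everything submitted before it. The event_* functions use the same trick: event_record() pushes a fresh marker and event_synchronize() waits on its shared_future. A rough standalone sketch of this barrier pattern, reusing the message_queue above (toy_msg and the other names are hypothetical, not from the commit):

#include <future>
#include <memory>
#include <thread>

struct toy_msg {
    bool is_marker = false;          // plays the role of RPC_CMD_NONE
    std::promise<void> completion;
};

int main() {
    message_queue<std::shared_ptr<toy_msg>> q;
    std::thread worker([&q] {
        std::shared_ptr<toy_msg> m;
        while (q.pop(&m)) {
            if (!m->is_marker) {
                // a real dispatcher would call send_rpc_cmd() here
            }
            m->completion.set_value();   // wakes send()/synchronize()/event waiters
        }
    });

    auto work = std::make_shared<toy_msg>();      // like send_async(): enqueue and return
    q.push(work);

    auto barrier = std::make_shared<toy_msg>();   // like synchronize(): push a marker...
    barrier->is_marker = true;
    q.push(barrier);
    barrier->completion.get_future().wait();      // ...and wait until the worker reaches it

    q.interrupt();                                // shut the worker down
    worker.join();
    return 0;
}
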
static void check_server_version(const std::shared_ptr<socket_t> & sock) {
rpc_msg_hello_rsp response;
bool status = send_rpc_cmd(sock, RPC_CMD_HELLO, nullptr, 0, &response, sizeof(response));
RPC_STATUS_ASSERT(status);
if (response.major != RPC_PROTO_MAJOR_VERSION || response.minor > RPC_PROTO_MINOR_VERSION) {
GGML_LOG_ERROR("RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch);
return false;
GGML_ABORT("RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch);
}
if (response.minor != RPC_PROTO_MINOR_VERSION || response.patch != RPC_PROTO_PATCH_VERSION) {
GGML_LOG_INFO("WARNING: RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch);
}
return true;
}
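
With the client now at protocol version 3.7.0, this check accepts a server reporting 3.5.0 (same major, minor not newer than the client's) but logs the version-mismatch warning; a server reporting 3.8.0 aborts because its minor version is newer than the client's, and a server on a different major version (for example 2.6.0) aborts regardless of minor and patch.
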
static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
static std::mutex mutex;
std::lock_guard<std::mutex> lock(mutex);
static std::unordered_map<std::string, std::weak_ptr<socket_t>> sockets;
static bool initialized = false;
auto it = sockets.find(endpoint);
if (it != sockets.end()) {
if (auto sock = it->second.lock()) {
return sock;
}
}
void rpc_dispatcher::start(const std::string & endpoint) {
static bool win32_init = false;
std::string host;
int port;
if (!parse_endpoint(endpoint, host, port)) {
GGML_LOG_ERROR("Failed to parse endpoint: %s\n", endpoint.c_str());
return nullptr;
GGML_ABORT("Failed to parse endpoint: %s\n", endpoint.c_str());
}
#ifdef _WIN32
if (!initialized) {
if (!win32_init) {
WSADATA wsaData;
int res = WSAStartup(MAKEWORD(2, 2), &wsaData);
if (res != 0) {
return nullptr;
GGML_ABORT("WSAStartup failed");
}
initialized = true;
win32_init = true;
}
#else
GGML_UNUSED(initialized);
GGML_UNUSED(win32_init);
#endif
auto sock = socket_connect(host.c_str(), port);
sock = socket_connect(host.c_str(), port);
if (sock == nullptr) {
return nullptr;
GGML_ABORT("Failed to connect to %s\n", endpoint.c_str());
}
if (!check_server_version(sock)) {
return nullptr;
check_server_version(sock);
LOG_DBG("[rpc_dispatcher] connected to %s, sockfd=%d\n", endpoint.c_str(), sock->fd);
running = true;
thread = std::thread(rpc_dispatcher_trampoline, this);
}
void rpc_dispatcher::work() {
while (running) {
rpc_msg_ptr msg_ptr;
if (!queue.pop(&msg_ptr)) {
break;
}
if (msg_ptr->cmd != RPC_CMD_NONE) {
if (msg_ptr->output) {
bool status = send_rpc_cmd(sock, msg_ptr->cmd, msg_ptr->input.get(), msg_ptr->input_size, msg_ptr->output, msg_ptr->output_size);
RPC_STATUS_ASSERT(status);
} else {
bool status = send_rpc_cmd(sock, msg_ptr->cmd, msg_ptr->input.get(), msg_ptr->input_size);
RPC_STATUS_ASSERT(status);
}
}
msg_ptr->completion.set_value();
}
LOG_DBG("[%s] connected to %s, sockfd=%d\n", __func__, endpoint.c_str(), sock->fd);
sockets[endpoint] = sock;
return sock;
}
rpc_dispatcher::~rpc_dispatcher() {
running = false;
queue.interrupt();
sock = nullptr;
if (thread.joinable()) {
thread.join();
}
}
static std::shared_ptr<rpc_dispatcher> get_dispatcher(const std::string & endpoint) {
static std::mutex mutex;
std::lock_guard<std::mutex> lock(mutex);
static std::unordered_map<std::string, std::weak_ptr<rpc_dispatcher>> dispatchers;
auto it = dispatchers.find(endpoint);
if (it != dispatchers.end()) {
if (auto dispatcher = it->second.lock()) {
return dispatcher;
}
}
auto dispatcher = std::make_shared<rpc_dispatcher>();
dispatcher->start(endpoint);
dispatchers[endpoint] = dispatcher;
return dispatcher;
}
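
get_dispatcher() keeps one dispatcher (and therefore one socket and one worker thread) per endpoint, cached in a map of weak_ptr so the cache never keeps a connection alive by itself. A minimal sketch of that caching idiom (illustrative and single-threaded for brevity; the real function additionally holds a static mutex, and the connection type and endpoint string here are made up):

#include <cassert>
#include <memory>
#include <string>
#include <unordered_map>

struct connection { std::string endpoint; };

static std::shared_ptr<connection> get_connection(const std::string & endpoint) {
    static std::unordered_map<std::string, std::weak_ptr<connection>> cache;
    auto it = cache.find(endpoint);
    if (it != cache.end()) {
        if (auto existing = it->second.lock()) {
            return existing;                 // still alive somewhere: reuse it
        }
    }
    auto fresh = std::make_shared<connection>(connection{endpoint});
    cache[endpoint] = fresh;                 // store a non-owning reference
    return fresh;
}

int main() {
    auto a = get_connection("host:port");
    auto b = get_connection("host:port");
    assert(a == b);                          // same endpoint -> shared instance
    a.reset();
    b.reset();                               // last owner gone -> connection destroyed
    auto c = get_connection("host:port");    // the expired cache entry is replaced
    assert(c != nullptr);
    return 0;
}
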
static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) {
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
rpc_msg_free_buffer_req request = {ctx->remote_ptr};
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_FREE_BUFFER, &request, sizeof(request), nullptr, 0);
RPC_STATUS_ASSERT(status);
auto request = std::make_shared<rpc_msg_free_buffer_req>();
request->remote_ptr = ctx->remote_ptr;
ctx->dispatcher->send(RPC_CMD_FREE_BUFFER, request, sizeof(*request));
delete ctx;
}
@ -564,10 +771,10 @@ static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) {
if (ctx->base_ptr != nullptr) {
return ctx->base_ptr;
}
rpc_msg_buffer_get_base_req request = {ctx->remote_ptr};
auto request = std::make_shared<rpc_msg_buffer_get_base_req>();
request->remote_ptr = ctx->remote_ptr;
rpc_msg_buffer_get_base_rsp response;
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_GET_BASE, &request, sizeof(request), &response, sizeof(response));
RPC_STATUS_ASSERT(status);
ctx->dispatcher->send(RPC_CMD_BUFFER_GET_BASE, request, sizeof(*request), &response, sizeof(response));
ctx->base_ptr = reinterpret_cast<void *>(response.base_ptr);
return ctx->base_ptr;
}
@ -623,12 +830,9 @@ static enum ggml_status ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_
// Due to bandwidth constraints, we only call the server init tensor functions if necessary.
// In particular, only quantized tensors need padding
if (ggml_is_quantized(tensor->type) && (tensor->ne[0] % 512 != 0) && (tensor->view_src == nullptr)) {
rpc_msg_init_tensor_req request;
request.tensor = serialize_tensor(tensor);
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_INIT_TENSOR, &request, sizeof(request), nullptr, 0);
RPC_STATUS_ASSERT(status);
auto request = std::make_shared<rpc_msg_init_tensor_req>();
request->tensor = serialize_tensor(tensor);
ctx->dispatcher->send(RPC_CMD_INIT_TENSOR, request, sizeof(*request));
}
return GGML_STATUS_SUCCESS;
}
@ -637,13 +841,12 @@ static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggm
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
rpc_tensor rpc_tensor = serialize_tensor(tensor);
if (size > HASH_THRESHOLD) {
rpc_msg_set_tensor_hash_req request;
request.tensor = rpc_tensor;
request.offset = offset;
request.hash = fnv_hash((const uint8_t*)data, size);
auto request = std::make_shared<rpc_msg_set_tensor_hash_req>();
request->tensor = rpc_tensor;
request->offset = offset;
request->hash = fnv_hash((const uint8_t*)data, size);
rpc_msg_set_tensor_hash_rsp response;
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR_HASH, &request, sizeof(request), &response, sizeof(response));
RPC_STATUS_ASSERT(status);
ctx->dispatcher->send(RPC_CMD_SET_TENSOR_HASH, request, sizeof(*request), &response, sizeof(response));
if (response.result) {
// the server has the same data, no need to send it
return;
@ -651,22 +854,21 @@ static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggm
}
// input serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes)
size_t input_size = sizeof(rpc_tensor) + sizeof(uint64_t) + size;
std::vector<uint8_t> input(input_size, 0);
memcpy(input.data(), &rpc_tensor, sizeof(rpc_tensor));
memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), data, size);
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR, input.data(), input.size());
RPC_STATUS_ASSERT(status);
uint8_t * input = new uint8_t[input_size]();
memcpy(input, &rpc_tensor, sizeof(rpc_tensor));
memcpy(input + sizeof(rpc_tensor), &offset, sizeof(offset));
memcpy(input + sizeof(rpc_tensor) + sizeof(offset), data, size);
std::shared_ptr<uint8_t> input_ptr(input, std::default_delete<uint8_t[]>());
ctx->dispatcher->send(RPC_CMD_SET_TENSOR, input_ptr, input_size);
}
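
Because the dispatcher's message keeps only a std::shared_ptr<const void> to the payload, and the async variants return before the worker thread has written anything to the socket, the payload has to own its bytes; the old stack-local std::vector is therefore replaced by a heap buffer with an array deleter. A stripped-down sketch of that hand-off (toy_payload and make_payload are hypothetical names, not from the commit):

#include <cstdint>
#include <cstring>
#include <memory>

// toy stand-in for the dispatcher's message: it only holds a shared_ptr,
// so whoever builds the payload must make sure the bytes outlive the caller
struct toy_payload {
    std::shared_ptr<const void> input;
    size_t input_size;
};

static toy_payload make_payload(const void * data, size_t size) {
    uint8_t * buf = new uint8_t[size];
    std::memcpy(buf, data, size);
    // default_delete<uint8_t[]> guarantees delete[] once the last reference is gone
    return { std::shared_ptr<const void>(buf, std::default_delete<uint8_t[]>()), size };
}

int main() {
    int local = 42;                                    // could be any caller-local data
    toy_payload p = make_payload(&local, sizeof(local));
    // p can now be enqueued for a worker thread; 'local' going out of scope is harmless
    (void)p;
    return 0;
}
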
static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
rpc_msg_get_tensor_req request;
request.tensor = serialize_tensor(tensor);
request.offset = offset;
request.size = size;
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_GET_TENSOR, &request, sizeof(request), data, size);
RPC_STATUS_ASSERT(status);
auto request = std::make_shared<rpc_msg_get_tensor_req>();
request->tensor = serialize_tensor(tensor);
request->offset = offset;
request->size = size;
ctx->dispatcher->send(RPC_CMD_GET_TENSOR, request, sizeof(*request), data, size);
}
static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
@ -676,16 +878,15 @@ static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, con
ggml_backend_rpc_buffer_context * src_ctx = (ggml_backend_rpc_buffer_context *)src_buffer->context;
ggml_backend_buffer_t dst_buffer = dst->buffer;
ggml_backend_rpc_buffer_context * dst_ctx = (ggml_backend_rpc_buffer_context *)dst_buffer->context;
if (src_ctx->sock != dst_ctx->sock) {
if (src_ctx->dispatcher != dst_ctx->dispatcher) {
return false;
}
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
rpc_msg_copy_tensor_req request;
request.src = serialize_tensor(src);
request.dst = serialize_tensor(dst);
auto request = std::make_shared<rpc_msg_copy_tensor_req>();
request->src = serialize_tensor(src);
request->dst = serialize_tensor(dst);
rpc_msg_copy_tensor_rsp response;
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, &request, sizeof(request), &response, sizeof(response));
RPC_STATUS_ASSERT(status);
ctx->dispatcher->send(RPC_CMD_COPY_TENSOR, request, sizeof(*request), &response, sizeof(response));
return response.result;
}
return false;
@ -693,9 +894,10 @@ static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, con
static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
rpc_msg_buffer_clear_req request = {ctx->remote_ptr, value};
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_CLEAR, &request, sizeof(request), nullptr, 0);
RPC_STATUS_ASSERT(status);
auto request = std::make_shared<rpc_msg_buffer_clear_req>();
request->remote_ptr = ctx->remote_ptr;
request->value = value;
ctx->dispatcher->send(RPC_CMD_BUFFER_CLEAR, request, sizeof(*request));
}
static ggml_backend_buffer_i ggml_backend_rpc_buffer_interface = {
@ -717,15 +919,17 @@ static const char * ggml_backend_rpc_buffer_type_name(ggml_backend_buffer_type_t
static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
rpc_msg_alloc_buffer_req request = {buft_ctx->device, size};
auto request = std::make_shared<rpc_msg_alloc_buffer_req>();
request->device = buft_ctx->device;
request->size = size;
rpc_msg_alloc_buffer_rsp response;
auto sock = get_socket(buft_ctx->endpoint);
bool status = send_rpc_cmd(sock, RPC_CMD_ALLOC_BUFFER, &request, sizeof(request), &response, sizeof(response));
RPC_STATUS_ASSERT(status);
auto dispatcher = get_dispatcher(buft_ctx->endpoint);
dispatcher->send(RPC_CMD_ALLOC_BUFFER, request, sizeof(*request), &response, sizeof(response));
if (response.remote_ptr != 0) {
ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
ggml_backend_rpc_buffer_interface,
new ggml_backend_rpc_buffer_context{sock, nullptr, response.remote_ptr},
new ggml_backend_rpc_buffer_context{dispatcher, nullptr, response.remote_ptr},
response.remote_size);
return buffer;
} else {
@ -733,11 +937,11 @@ static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_back
}
}
static size_t get_alignment(const std::shared_ptr<socket_t> & sock, uint32_t device) {
rpc_msg_get_alignment_req request = {device};
static size_t get_alignment(const std::shared_ptr<rpc_dispatcher> & dispatcher, uint32_t device) {
auto request = std::make_shared<rpc_msg_get_alignment_req>();
request->device = device;
rpc_msg_get_alignment_rsp response;
bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, &request, sizeof(request), &response, sizeof(response));
RPC_STATUS_ASSERT(status);
dispatcher->send(RPC_CMD_GET_ALIGNMENT, request, sizeof(*request), &response, sizeof(response));
return response.alignment;
}
@ -746,11 +950,11 @@ static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_buffer_typ
return buft_ctx->alignment;
}
static size_t get_max_size(const std::shared_ptr<socket_t> & sock, uint32_t device) {
rpc_msg_get_max_size_req request = {device};
static size_t get_max_size(const std::shared_ptr<rpc_dispatcher> & dispatcher, uint32_t device) {
auto request = std::make_shared<rpc_msg_get_max_size_req>();
request->device = device;
rpc_msg_get_max_size_rsp response;
bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, &request, sizeof(request), &response, sizeof(response));
RPC_STATUS_ASSERT(status);
dispatcher->send(RPC_CMD_GET_MAX_SIZE, request, sizeof(*request), &response, sizeof(response));
return response.max_size;
}
@ -773,23 +977,20 @@ static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_ty
if (rpc_get) {
ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
auto sock = get_socket(buft_ctx->endpoint);
auto dispatcher = get_dispatcher(buft_ctx->endpoint);
rpc_msg_get_alloc_size_req request = {
/*.device =*/ buft_ctx->device,
/*.tensor =*/ serialize_tensor(tensor),
/*.srcs =*/ {},
};
auto request = std::make_shared<rpc_msg_get_alloc_size_req>();
request->device = buft_ctx->device;
request->tensor = serialize_tensor(tensor);
// .get_alloc_size could be a function of the tensor's srcs, so we must serialize them as well
for (int i = 0; i < GGML_MAX_SRC; i++) {
request.srcs[i] = serialize_tensor(tensor->src[i]);
request->srcs[i] = serialize_tensor(tensor->src[i]);
}
// TODO: cache the alloc responses to avoid extra RPC calls?
rpc_msg_get_alloc_size_rsp response;
bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALLOC_SIZE, &request, sizeof(request), &response, sizeof(response));
RPC_STATUS_ASSERT(status);
dispatcher->send(RPC_CMD_GET_ALLOC_SIZE, request, sizeof(*request), &response, sizeof(response));
return response.alloc_size;
}
@ -818,9 +1019,44 @@ static void ggml_backend_rpc_free(ggml_backend_t backend) {
delete backend;
}
static void ggml_backend_rpc_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context;
rpc_tensor rpc_tensor = serialize_tensor(tensor);
if (size > HASH_THRESHOLD) {
auto request = std::make_shared<rpc_msg_set_tensor_hash_req>();
request->tensor = rpc_tensor;
request->offset = offset;
request->hash = fnv_hash((const uint8_t*)data, size);
rpc_msg_set_tensor_hash_rsp response;
// TODO: make this async
ctx->dispatcher->send(RPC_CMD_SET_TENSOR_HASH, request, sizeof(*request), &response, sizeof(response));
if (response.result) {
// the server has the same data, no need to send it
return;
}
}
// input serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes)
size_t input_size = sizeof(rpc_tensor) + sizeof(uint64_t) + size;
uint8_t * input = new uint8_t[input_size]();
memcpy(input, &rpc_tensor, sizeof(rpc_tensor));
memcpy(input + sizeof(rpc_tensor), &offset, sizeof(offset));
memcpy(input + sizeof(rpc_tensor) + sizeof(offset), data, size);
std::shared_ptr<uint8_t> input_ptr(input, std::default_delete<uint8_t[]>());
ctx->dispatcher->send_async(RPC_CMD_SET_TENSOR, input_ptr, input_size);
}
static void ggml_backend_rpc_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context;
auto request = std::make_shared<rpc_msg_get_tensor_req>();
request->tensor = serialize_tensor(tensor);
request->offset = offset;
request->size = size;
ctx->dispatcher->send_async(RPC_CMD_GET_TENSOR, request, sizeof(*request), data, size);
}
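
Note that ggml_backend_rpc_get_tensor_async() only enqueues the read: the bytes arrive in data when the dispatcher's worker thread processes the message, so the caller must synchronize the backend (see ggml_backend_rpc_synchronize() below, which drains the queue) before consuming them.
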
static void ggml_backend_rpc_synchronize(ggml_backend_t backend) {
GGML_UNUSED(backend);
// this is no-op because we don't have any async operations
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
rpc_ctx->dispatcher->synchronize();
}
static void add_tensor(ggml_tensor * tensor, std::vector<rpc_tensor> & tensors, std::unordered_set<ggml_tensor*> & visited) {
@ -838,7 +1074,7 @@ static void add_tensor(ggml_tensor * tensor, std::vector<rpc_tensor> & tensors,
tensors.push_back(serialize_tensor(tensor));
}
static void serialize_graph(uint32_t device, const ggml_cgraph * cgraph, std::vector<uint8_t> & output) {
static uint8_t * serialize_graph(uint32_t device, const ggml_cgraph * cgraph, size_t * output_size) {
uint32_t n_nodes = cgraph->n_nodes;
std::vector<rpc_tensor> tensors;
std::unordered_set<ggml_tensor*> visited;
@ -848,9 +1084,9 @@ static void serialize_graph(uint32_t device, const ggml_cgraph * cgraph, std::ve
// serialization format:
// | device (4 bytes) | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
uint32_t n_tensors = tensors.size();
int output_size = 2*sizeof(uint32_t) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t) + n_tensors * sizeof(rpc_tensor);
output.resize(output_size, 0);
uint8_t * dest = output.data();
*output_size = 2*sizeof(uint32_t) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t) + n_tensors * sizeof(rpc_tensor);
uint8_t * output = new uint8_t[*output_size]();
uint8_t * dest = output;
memcpy(dest, &device, sizeof(device));
dest += sizeof(device);
memcpy(dest, &n_nodes, sizeof(n_nodes));
@ -863,6 +1099,7 @@ static void serialize_graph(uint32_t device, const ggml_cgraph * cgraph, std::ve
dest += sizeof(n_tensors);
rpc_tensor * out_tensors = (rpc_tensor *)dest;
memcpy(out_tensors, tensors.data(), n_tensors * sizeof(rpc_tensor));
return output;
}
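
As a worked example of the layout described in the comment above: a graph with 3 nodes referencing 5 distinct tensors (the nodes plus their sources) serializes to 4 (device) + 4 (n_nodes) + 3*8 (node ids as uint64_t) + 4 (n_tensors) + 5*sizeof(rpc_tensor) bytes, i.e. 36 + 5*sizeof(rpc_tensor).
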
static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
@ -871,27 +1108,35 @@ static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, g
GGML_ASSERT(cgraph->n_nodes > 0);
bool reuse = rpc_ctx->gc.is_cached(cgraph);
if (reuse) {
rpc_msg_graph_recompute_req request;
request.device = rpc_ctx->device;
auto sock = get_socket(rpc_ctx->endpoint);
bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_RECOMPUTE, &request, sizeof(request));
RPC_STATUS_ASSERT(status);
auto request = std::make_shared<rpc_msg_graph_recompute_req>();
request->device = rpc_ctx->device;
rpc_ctx->dispatcher->send_async(RPC_CMD_GRAPH_RECOMPUTE, request, sizeof(*request));
} else {
rpc_ctx->gc.add(cgraph);
std::vector<uint8_t> input;
serialize_graph(rpc_ctx->device, cgraph, input);
auto sock = get_socket(rpc_ctx->endpoint);
bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size());
RPC_STATUS_ASSERT(status);
size_t input_size = 0;
uint8_t * input = serialize_graph(rpc_ctx->device, cgraph, &input_size);
std::shared_ptr<uint8_t> input_ptr(input, std::default_delete<uint8_t[]>());
rpc_ctx->dispatcher->send_async(RPC_CMD_GRAPH_COMPUTE, input_ptr, input_size);
}
return GGML_STATUS_SUCCESS;
}
static void ggml_backend_rpc_event_record(ggml_backend_t backend, ggml_backend_event_t event) {
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
rpc_ctx->dispatcher->event_record(event);
}
static void ggml_backend_rpc_event_wait(ggml_backend_t backend, ggml_backend_event_t event) {
// this is a no-op for RPC as we have a single stream
GGML_UNUSED(backend);
GGML_UNUSED(event);
}
static ggml_backend_i ggml_backend_rpc_interface = {
/* .get_name = */ ggml_backend_rpc_name,
/* .free = */ ggml_backend_rpc_free,
/* .set_tensor_async = */ NULL,
/* .get_tensor_async = */ NULL,
/* .set_tensor_async = */ ggml_backend_rpc_set_tensor_async,
/* .get_tensor_async = */ ggml_backend_rpc_get_tensor_async,
/* .cpy_tensor_async = */ NULL,
/* .synchronize = */ ggml_backend_rpc_synchronize,
/* .graph_plan_create = */ NULL,
@ -899,8 +1144,8 @@ static ggml_backend_i ggml_backend_rpc_interface = {
/* .graph_plan_update = */ NULL,
/* .graph_plan_compute = */ NULL,
/* .graph_compute = */ ggml_backend_rpc_graph_compute,
/* .event_record = */ NULL,
/* .event_wait = */ NULL,
/* .event_record = */ ggml_backend_rpc_event_record,
/* .event_wait = */ ggml_backend_rpc_event_wait,
/* .graph_optimize = */ NULL,
};
@ -914,13 +1159,9 @@ ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint, u
if (it != buft_map.end()) {
return it->second;
}
auto sock = get_socket(endpoint);
if (sock == nullptr) {
GGML_LOG_ERROR("Failed to connect to %s\n", endpoint);
return nullptr;
}
size_t alignment = get_alignment(sock, device);
size_t max_size = get_max_size(sock, device);
auto dispatcher = get_dispatcher(endpoint);
size_t alignment = get_alignment(dispatcher, device);
size_t max_size = get_max_size(dispatcher, device);
ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context {
/* .endpoint = */ endpoint,
/* .device = */ device,
@ -940,11 +1181,12 @@ ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint, u
ggml_backend_t ggml_backend_rpc_init(const char * endpoint, uint32_t device) {
std::string dev_name = "RPC" + std::to_string(device) + "[" + std::string(endpoint) + "]";
auto dispatcher = get_dispatcher(endpoint);
ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
/* .endpoint = */ endpoint,
/* .device = */ device,
/* .name = */ dev_name,
/* .gc = */ {},
/* .dispatcher = */ dispatcher,
/* .device = */ device,
/* .name = */ dev_name,
/* .gc = */ {},
};
auto reg = ggml_backend_rpc_add_server(endpoint);
ggml_backend_t backend = new ggml_backend {
@ -960,26 +1202,16 @@ bool ggml_backend_is_rpc(ggml_backend_t backend) {
return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_rpc_guid());
}
static void get_device_memory(const std::shared_ptr<socket_t> & sock, uint32_t device, size_t * free, size_t * total) {
rpc_msg_get_device_memory_req request;
request.device = device;
void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device, size_t * free, size_t * total) {
auto dispatcher = get_dispatcher(endpoint);
auto request = std::make_shared<rpc_msg_get_device_memory_req>();
request->device = device;
rpc_msg_get_device_memory_rsp response;
bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, &request, sizeof(request), &response, sizeof(response));
RPC_STATUS_ASSERT(status);
dispatcher->send(RPC_CMD_GET_DEVICE_MEMORY, request, sizeof(*request), &response, sizeof(response));
*free = response.free_mem;
*total = response.total_mem;
}
void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device, size_t * free, size_t * total) {
auto sock = get_socket(endpoint);
if (sock == nullptr) {
*free = 0;
*total = 0;
return;
}
get_device_memory(sock, device, free, total);
}
// RPC server-side implementation
class rpc_server {
@ -1701,9 +1933,6 @@ static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const
if (!server.free_buffer(request)) {
return;
}
if (!send_msg(sockfd, nullptr, 0)) {
return;
}
break;
}
case RPC_CMD_BUFFER_CLEAR: {
@ -1714,9 +1943,6 @@ static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const
if (!server.buffer_clear(request)) {
return;
}
if (!send_msg(sockfd, nullptr, 0)) {
return;
}
break;
}
case RPC_CMD_SET_TENSOR: {
@ -1751,9 +1977,6 @@ static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const
if (!server.init_tensor(request)) {
return;
}
if (!send_msg(sockfd, nullptr, 0)) {
return;
}
break;
}
case RPC_CMD_GET_TENSOR: {
@ -1941,10 +2164,10 @@ static void ggml_backend_rpc_device_get_props(ggml_backend_dev_t dev, struct ggm
props->type = ggml_backend_rpc_device_get_type(dev);
ggml_backend_rpc_device_get_memory(dev, &props->memory_free, &props->memory_total);
props->caps = {
/* .async = */ false,
/* .async = */ true,
/* .host_buffer = */ false,
/* .buffer_from_host_ptr = */ false,
/* .events = */ false,
/* .events = */ true,
};
}
@ -1980,6 +2203,24 @@ static bool ggml_backend_rpc_device_supports_buft(ggml_backend_dev_t dev, ggml_b
return buft_ctx->endpoint == dev_ctx->endpoint && buft_ctx->device == dev_ctx->device;
}
static ggml_backend_event_t ggml_backend_rpc_device_event_new(ggml_backend_dev_t dev) {
ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
auto dispatcher = get_dispatcher(ctx->endpoint);
return dispatcher->event_new(dev);
}
static void ggml_backend_rpc_device_event_free(ggml_backend_dev_t dev, ggml_backend_event_t event) {
ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
auto dispatcher = get_dispatcher(ctx->endpoint);
dispatcher->event_free(event);
}
static void ggml_backend_rpc_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) {
ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
auto dispatcher = get_dispatcher(ctx->endpoint);
dispatcher->event_synchronize(event);
}
static const struct ggml_backend_device_i ggml_backend_rpc_device_i = {
/* .get_name = */ ggml_backend_rpc_device_get_name,
/* .get_description = */ ggml_backend_rpc_device_get_description,
@ -1993,9 +2234,9 @@ static const struct ggml_backend_device_i ggml_backend_rpc_device_i = {
/* .supports_op = */ ggml_backend_rpc_device_supports_op,
/* .supports_buft = */ ggml_backend_rpc_device_supports_buft,
/* .offload_op = */ NULL,
/* .event_new = */ NULL,
/* .event_free = */ NULL,
/* .event_synchronize = */ NULL,
/* .event_new = */ ggml_backend_rpc_device_event_new,
/* .event_free = */ ggml_backend_rpc_device_event_free,
/* .event_synchronize = */ ggml_backend_rpc_device_event_synchronize,
};
// backend reg interface
@ -2055,14 +2296,9 @@ ggml_backend_reg_t ggml_backend_rpc_reg(void) {
}
static uint32_t ggml_backend_rpc_get_device_count(const char * endpoint) {
auto sock = get_socket(endpoint);
if (sock == nullptr) {
GGML_LOG_ERROR("Failed to connect to %s\n", endpoint);
return 0;
}
auto dispatcher = get_dispatcher(endpoint);
rpc_msg_device_count_rsp response;
bool status = send_rpc_cmd(sock, RPC_CMD_DEVICE_COUNT, nullptr, 0, &response, sizeof(response));
RPC_STATUS_ASSERT(status);
dispatcher->send(RPC_CMD_DEVICE_COUNT, nullptr, 0, &response, sizeof(response));
return response.device_count;
}