@@ -6,6 +6,9 @@
 #include <cinttypes>
 #include <string>
 #include <vector>
+#include <queue>
+#include <condition_variable>
+#include <future>
 #include <memory>
 #include <mutex>
 #include <unordered_map>
@@ -30,6 +33,8 @@
 #include <fstream>
 #include <filesystem>
 #include <algorithm>
+#include <atomic>
+#include <thread>

 static const char * RPC_DEBUG = std::getenv("GGML_RPC_DEBUG");

@@ -107,6 +112,7 @@ enum rpc_cmd {
     RPC_CMD_HELLO,
     RPC_CMD_DEVICE_COUNT,
     RPC_CMD_GRAPH_RECOMPUTE,
+    RPC_CMD_NONE,
     RPC_CMD_COUNT,
 };

@@ -261,17 +267,18 @@ struct graph_cache {
     std::vector<ggml_tensor> last_graph;
 };

+class rpc_dispatcher;
 struct ggml_backend_rpc_context {
-    std::string endpoint;
-    uint32_t device;
-    std::string name;
-    graph_cache gc;
+    std::shared_ptr<rpc_dispatcher> dispatcher;
+    uint32_t device;
+    std::string name;
+    graph_cache gc;
 };

 struct ggml_backend_rpc_buffer_context {
-    std::shared_ptr<socket_t> sock;
-    void * base_ptr;
-    uint64_t remote_ptr;
+    std::shared_ptr<rpc_dispatcher> dispatcher;
+    void * base_ptr;
+    uint64_t remote_ptr;
 };

 // RPC helper functions
@@ -495,67 +502,267 @@ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cm

 // RPC client-side implementation

-static bool check_server_version(const std::shared_ptr<socket_t> & sock) {
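+// Thread-safe FIFO used to hand RPC messages to the dispatcher thread.
+// pop() blocks on a condition variable until a message arrives; interrupt()
+// wakes all waiters and makes push()/pop() fail, which is how the worker
+// is told to shut down.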
+template <typename T>
+class message_queue {
+public:
+    message_queue() {}
+
+    bool push(const T &value) {
+        std::unique_lock<std::mutex> lock(mutex);
+        if (interrupted) {
+            return false;
+        }
+        queue.push(value);
+        cvar.notify_all();
+        return true;
+    }
+
+    bool pop(T* out) {
+        std::unique_lock<std::mutex> lock(mutex);
+        cvar.wait(lock, [this] { return !queue.empty() || interrupted; });
+        if (interrupted) {
+            return false;
+        }
+        *out = queue.front();
+        queue.pop();
+        return true;
+    }
+
+    void interrupt() {
+        std::unique_lock<std::mutex> lock(mutex);
+        interrupted = true;
+        lock.unlock();
+        cvar.notify_all();
+    }
+
+private:
+    bool interrupted = false;
+    std::queue<T> queue;
+    std::mutex mutex;
+    std::condition_variable cvar;
+};
+
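+// Owns the socket for one endpoint and serializes all RPC traffic over it:
+// requests are queued as rpc_msg objects and executed in FIFO order on a
+// single worker thread, so callers never touch the socket directly.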
+class rpc_dispatcher {
+public:
+    rpc_dispatcher() {
+    }
+
+    void send(enum rpc_cmd cmd, std::shared_ptr<const void> input, size_t input_size);
+    void send(enum rpc_cmd cmd, std::shared_ptr<const void> input, size_t input_size, void * output, size_t output_size);
+    void send_async(enum rpc_cmd cmd, std::shared_ptr<const void> input, size_t input_size);
+    void send_async(enum rpc_cmd cmd, std::shared_ptr<const void> input, size_t input_size, void * output, size_t output_size);
+
+    ggml_backend_event_t event_new(ggml_backend_dev_t dev);
+    void event_free(ggml_backend_event_t event);
+    void event_synchronize(ggml_backend_event_t event);
+    void event_record(ggml_backend_event_t event);
+    void synchronize();
+
+    void start(const std::string & endpoint);
+    void work();
+
+    ~rpc_dispatcher();
+
+private:
+    struct rpc_msg {
+        rpc_cmd cmd;
+        std::shared_ptr<const void> input;
+        size_t input_size;
+        void * output;
+        size_t output_size;
+        std::promise<void> completion;
+    };
+    using rpc_msg_ptr = std::shared_ptr<rpc_msg>;
+    using rpc_msg_queue = message_queue<rpc_msg_ptr>;
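+    // An event is a queued RPC_CMD_NONE message plus a shared_future on its
+    // completion promise: the future becomes ready once the worker has
+    // drained everything queued before it.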
+    struct rpc_event {
+        rpc_msg_ptr msg;
+        std::shared_future<void> sf;
+    };
+    rpc_msg_queue queue;
+    std::shared_ptr<socket_t> sock;
+    std::atomic_bool running;
+    std::thread thread;
+};
+
+static void rpc_dispatcher_trampoline(rpc_dispatcher * dispatcher)
+{
+    dispatcher->work();
+}
+
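+// Blocking send: queue the message, then wait on its completion promise
+// until the worker thread has finished the round trip.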
+void rpc_dispatcher::send(enum rpc_cmd cmd, std::shared_ptr<const void> input, size_t input_size) {
+    auto msg = std::make_shared<rpc_msg>();
+    msg->cmd = cmd;
+    msg->input = input;
+    msg->input_size = input_size;
+    msg->output = nullptr;
+    msg->output_size = 0;
+    GGML_ASSERT(queue.push(msg));
+    auto future = msg->completion.get_future();
+    future.wait();
+}
+
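+// Fire-and-forget send: the message is queued and the call returns at once.
+// The shared_ptr keeps the input alive until the worker is done with it; for
+// the variant with an output buffer, the caller must keep that buffer valid
+// until the message is processed, e.g. until the next synchronize().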
+void rpc_dispatcher::send_async(enum rpc_cmd cmd, std::shared_ptr<const void> input, size_t input_size) {
+    auto msg = std::make_shared<rpc_msg>();
+    msg->cmd = cmd;
+    msg->input = input;
+    msg->input_size = input_size;
+    msg->output = nullptr;
+    msg->output_size = 0;
+    GGML_ASSERT(queue.push(msg));
+}
+
+void rpc_dispatcher::send(enum rpc_cmd cmd, std::shared_ptr<const void> input, size_t input_size, void * output, size_t output_size) {
+    auto msg = std::make_shared<rpc_msg>();
+    msg->cmd = cmd;
+    msg->input = input;
+    msg->input_size = input_size;
+    msg->output = output;
+    msg->output_size = output_size;
+    GGML_ASSERT(queue.push(msg));
+    auto future = msg->completion.get_future();
+    future.wait();
+}
+
+void rpc_dispatcher::send_async(enum rpc_cmd cmd, std::shared_ptr<const void> input, size_t input_size, void * output, size_t output_size) {
+    auto msg = std::make_shared<rpc_msg>();
+    msg->cmd = cmd;
+    msg->input = input;
+    msg->input_size = input_size;
+    msg->output = output;
+    msg->output_size = output_size;
+    GGML_ASSERT(queue.push(msg));
+}
+
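+// Events are RPC_CMD_NONE fence messages: the worker performs no network I/O
+// for them and simply fulfills their promise once every message queued ahead
+// of them has been processed.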
+ggml_backend_event_t rpc_dispatcher::event_new(ggml_backend_dev_t dev) {
+    rpc_event * ev = new rpc_event;
+    ev->msg = std::make_shared<rpc_msg>();
+    ev->msg->cmd = RPC_CMD_NONE;
+    ev->sf = ev->msg->completion.get_future().share();
+    GGML_ASSERT(queue.push(ev->msg));
+    return new ggml_backend_event {
+        /* .device = */ dev,
+        /* .context = */ ev,
+    };
+}
+
+void rpc_dispatcher::event_free(ggml_backend_event_t event) {
+    rpc_event * ev = (rpc_event *)event->context;
+    delete ev;
+}
+
+void rpc_dispatcher::event_synchronize(ggml_backend_event_t event) {
+    rpc_event * ev = (rpc_event *)event->context;
+    ev->sf.wait();
+}
+
+void rpc_dispatcher::event_record(ggml_backend_event_t event) {
+    rpc_event * ev = (rpc_event *)event->context;
+    ev->msg = std::make_shared<rpc_msg>();
+    ev->msg->cmd = RPC_CMD_NONE;
+    ev->sf = ev->msg->completion.get_future().share();
+    GGML_ASSERT(queue.push(ev->msg));
+}
+
+void rpc_dispatcher::synchronize() {
+    // to ensure all messages are processed, submit a dummy message and wait for it to complete
+    auto msg = std::make_shared<rpc_msg>();
+    msg->cmd = RPC_CMD_NONE;
+    GGML_ASSERT(queue.push(msg));
+    msg->completion.get_future().wait();
+}
+
+static void check_server_version(const std::shared_ptr<socket_t> & sock) {
     rpc_msg_hello_rsp response;
     bool status = send_rpc_cmd(sock, RPC_CMD_HELLO, nullptr, 0, &response, sizeof(response));
     RPC_STATUS_ASSERT(status);
     if (response.major != RPC_PROTO_MAJOR_VERSION || response.minor > RPC_PROTO_MINOR_VERSION) {
-        GGML_LOG_ERROR("RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch);
-        return false;
+        GGML_ABORT("RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch);
     }
     if (response.minor != RPC_PROTO_MINOR_VERSION || response.patch != RPC_PROTO_PATCH_VERSION) {
         GGML_LOG_INFO("WARNING: RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch);
     }
-    return true;
 }

-static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
-    static std::mutex mutex;
-    std::lock_guard<std::mutex> lock(mutex);
-    static std::unordered_map<std::string, std::weak_ptr<socket_t>> sockets;
-    static bool initialized = false;
-
-    auto it = sockets.find(endpoint);
-    if (it != sockets.end()) {
-        if (auto sock = it->second.lock()) {
-            return sock;
-        }
-    }
+void rpc_dispatcher::start(const std::string & endpoint) {
+    static bool win32_init = false;
     std::string host;
     int port;
     if (!parse_endpoint(endpoint, host, port)) {
-        GGML_LOG_ERROR("Failed to parse endpoint: %s\n", endpoint.c_str());
-        return nullptr;
+        GGML_ABORT("Failed to parse endpoint: %s\n", endpoint.c_str());
     }
 #ifdef _WIN32
-    if (!initialized) {
+    if (!win32_init) {
         WSADATA wsaData;
         int res = WSAStartup(MAKEWORD(2, 2), &wsaData);
         if (res != 0) {
-            return nullptr;
+            GGML_ABORT("WSAStartup failed");
         }
-        initialized = true;
+        win32_init = true;
     }
 #else
-    GGML_UNUSED(initialized);
+    GGML_UNUSED(win32_init);
 #endif
-    auto sock = socket_connect(host.c_str(), port);
+    sock = socket_connect(host.c_str(), port);
     if (sock == nullptr) {
-        return nullptr;
+        GGML_ABORT("Failed to connect to %s\n", endpoint.c_str());
     }
-    if (!check_server_version(sock)) {
-        return nullptr;
-    }
+    check_server_version(sock);
+    LOG_DBG("[rpc_dispatcher] connected to %s, sockfd=%d\n", endpoint.c_str(), sock->fd);
+    running = true;
+    thread = std::thread(rpc_dispatcher_trampoline, this);
+}
+
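+// Worker loop: pop messages in FIFO order, perform the socket round trip for
+// real commands, skip the I/O for RPC_CMD_NONE fences, and fulfill each
+// message's completion promise.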
+void rpc_dispatcher::work() {
+    while (running) {
+        rpc_msg_ptr msg_ptr;
+        if (!queue.pop(&msg_ptr)) {
+            break;
+        }
+        if (msg_ptr->cmd != RPC_CMD_NONE) {
+            if (msg_ptr->output) {
+                bool status = send_rpc_cmd(sock, msg_ptr->cmd, msg_ptr->input.get(), msg_ptr->input_size, msg_ptr->output, msg_ptr->output_size);
+                RPC_STATUS_ASSERT(status);
+            } else {
+                bool status = send_rpc_cmd(sock, msg_ptr->cmd, msg_ptr->input.get(), msg_ptr->input_size);
+                RPC_STATUS_ASSERT(status);
+            }
+        }
+        msg_ptr->completion.set_value();
+    }
-    LOG_DBG("[%s] connected to %s, sockfd=%d\n", __func__, endpoint.c_str(), sock->fd);
-    sockets[endpoint] = sock;
-    return sock;
 }

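+// Shutdown: clear the running flag and interrupt the queue so a blocked
+// pop() returns, then join the worker thread.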
+rpc_dispatcher::~rpc_dispatcher() {
+    running = false;
+    queue.interrupt();
+    sock = nullptr;
+    if (thread.joinable()) {
+        thread.join();
+    }
+}
+
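+// One dispatcher (one socket and worker thread) per endpoint, cached via
+// weak_ptr so it is recreated if all users have released it.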
+static std::shared_ptr<rpc_dispatcher> get_dispatcher(const std::string & endpoint) {
+    static std::mutex mutex;
+    std::lock_guard<std::mutex> lock(mutex);
+    static std::unordered_map<std::string, std::weak_ptr<rpc_dispatcher>> dispatchers;
+
+    auto it = dispatchers.find(endpoint);
+    if (it != dispatchers.end()) {
+        if (auto dispatcher = it->second.lock()) {
+            return dispatcher;
+        }
+    }
+
+    auto dispatcher = std::make_shared<rpc_dispatcher>();
+    dispatcher->start(endpoint);
+    dispatchers[endpoint] = dispatcher;
+    return dispatcher;
+}

 static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) {
     ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
-    rpc_msg_free_buffer_req request = {ctx->remote_ptr};
-    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_FREE_BUFFER, &request, sizeof(request), nullptr, 0);
-    RPC_STATUS_ASSERT(status);
+    auto request = std::make_shared<rpc_msg_free_buffer_req>();
+    request->remote_ptr = ctx->remote_ptr;
+    ctx->dispatcher->send(RPC_CMD_FREE_BUFFER, request, sizeof(*request));
     delete ctx;
 }

@@ -564,10 +771,10 @@ static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) {
     if (ctx->base_ptr != nullptr) {
         return ctx->base_ptr;
     }
-    rpc_msg_buffer_get_base_req request = {ctx->remote_ptr};
+    auto request = std::make_shared<rpc_msg_buffer_get_base_req>();
+    request->remote_ptr = ctx->remote_ptr;
     rpc_msg_buffer_get_base_rsp response;
-    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_GET_BASE, &request, sizeof(request), &response, sizeof(response));
-    RPC_STATUS_ASSERT(status);
+    ctx->dispatcher->send(RPC_CMD_BUFFER_GET_BASE, request, sizeof(*request), &response, sizeof(response));
     ctx->base_ptr = reinterpret_cast<void *>(response.base_ptr);
     return ctx->base_ptr;
 }
@@ -623,12 +830,9 @@ static enum ggml_status ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_
     // Due to bandwidth constraints, we only call the server init tensor functions if necessary.
     // In particular, only quantized tensors need padding
     if (ggml_is_quantized(tensor->type) && (tensor->ne[0] % 512 != 0) && (tensor->view_src == nullptr)) {
-        rpc_msg_init_tensor_req request;
-
-        request.tensor = serialize_tensor(tensor);
-
-        bool status = send_rpc_cmd(ctx->sock, RPC_CMD_INIT_TENSOR, &request, sizeof(request), nullptr, 0);
-        RPC_STATUS_ASSERT(status);
+        auto request = std::make_shared<rpc_msg_init_tensor_req>();
+        request->tensor = serialize_tensor(tensor);
+        ctx->dispatcher->send(RPC_CMD_INIT_TENSOR, request, sizeof(*request));
     }
     return GGML_STATUS_SUCCESS;
 }
@@ -637,13 +841,12 @@ static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggm
     ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
     rpc_tensor rpc_tensor = serialize_tensor(tensor);
     if (size > HASH_THRESHOLD) {
-        rpc_msg_set_tensor_hash_req request;
-        request.tensor = rpc_tensor;
-        request.offset = offset;
-        request.hash = fnv_hash((const uint8_t*)data, size);
+        auto request = std::make_shared<rpc_msg_set_tensor_hash_req>();
+        request->tensor = rpc_tensor;
+        request->offset = offset;
+        request->hash = fnv_hash((const uint8_t*)data, size);
         rpc_msg_set_tensor_hash_rsp response;
-        bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR_HASH, &request, sizeof(request), &response, sizeof(response));
-        RPC_STATUS_ASSERT(status);
+        ctx->dispatcher->send(RPC_CMD_SET_TENSOR_HASH, request, sizeof(*request), &response, sizeof(response));
         if (response.result) {
             // the server has the same data, no need to send it
             return;
@@ -651,22 +854,21 @@ static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggm
     }
     // input serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes) |
     size_t input_size = sizeof(rpc_tensor) + sizeof(uint64_t) + size;
-    std::vector<uint8_t> input(input_size, 0);
-    memcpy(input.data(), &rpc_tensor, sizeof(rpc_tensor));
-    memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
-    memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), data, size);
-    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR, input.data(), input.size());
-    RPC_STATUS_ASSERT(status);
+    uint8_t * input = new uint8_t[input_size]();
+    memcpy(input, &rpc_tensor, sizeof(rpc_tensor));
+    memcpy(input + sizeof(rpc_tensor), &offset, sizeof(offset));
+    memcpy(input + sizeof(rpc_tensor) + sizeof(offset), data, size);
+    std::shared_ptr<uint8_t> input_ptr(input, std::default_delete<uint8_t[]>());
+    ctx->dispatcher->send(RPC_CMD_SET_TENSOR, input_ptr, input_size);
 }

 static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
     ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
-    rpc_msg_get_tensor_req request;
-    request.tensor = serialize_tensor(tensor);
-    request.offset = offset;
-    request.size = size;
-    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_GET_TENSOR, &request, sizeof(request), data, size);
-    RPC_STATUS_ASSERT(status);
+    auto request = std::make_shared<rpc_msg_get_tensor_req>();
+    request->tensor = serialize_tensor(tensor);
+    request->offset = offset;
+    request->size = size;
+    ctx->dispatcher->send(RPC_CMD_GET_TENSOR, request, sizeof(*request), data, size);
 }

 static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
@@ -676,16 +878,15 @@ static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, con
         ggml_backend_rpc_buffer_context * src_ctx = (ggml_backend_rpc_buffer_context *)src_buffer->context;
         ggml_backend_buffer_t dst_buffer = dst->buffer;
         ggml_backend_rpc_buffer_context * dst_ctx = (ggml_backend_rpc_buffer_context *)dst_buffer->context;
-        if (src_ctx->sock != dst_ctx->sock) {
+        if (src_ctx->dispatcher != dst_ctx->dispatcher) {
            return false;
        }
        ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
-       rpc_msg_copy_tensor_req request;
-       request.src = serialize_tensor(src);
-       request.dst = serialize_tensor(dst);
+       auto request = std::make_shared<rpc_msg_copy_tensor_req>();
+       request->src = serialize_tensor(src);
+       request->dst = serialize_tensor(dst);
        rpc_msg_copy_tensor_rsp response;
-       bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, &request, sizeof(request), &response, sizeof(response));
-       RPC_STATUS_ASSERT(status);
+       ctx->dispatcher->send(RPC_CMD_COPY_TENSOR, request, sizeof(*request), &response, sizeof(response));
        return response.result;
    }
    return false;
@@ -693,9 +894,10 @@ static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, con

 static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
     ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
-    rpc_msg_buffer_clear_req request = {ctx->remote_ptr, value};
-    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_CLEAR, &request, sizeof(request), nullptr, 0);
-    RPC_STATUS_ASSERT(status);
+    auto request = std::make_shared<rpc_msg_buffer_clear_req>();
+    request->remote_ptr = ctx->remote_ptr;
+    request->value = value;
+    ctx->dispatcher->send(RPC_CMD_BUFFER_CLEAR, request, sizeof(*request));
 }

 static ggml_backend_buffer_i ggml_backend_rpc_buffer_interface = {
@@ -717,15 +919,17 @@ static const char * ggml_backend_rpc_buffer_type_name(ggml_backend_buffer_type_t

 static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
     ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
-    rpc_msg_alloc_buffer_req request = {buft_ctx->device, size};
+    auto request = std::make_shared<rpc_msg_alloc_buffer_req>();
+    request->device = buft_ctx->device;
+    request->size = size;
     rpc_msg_alloc_buffer_rsp response;
-    auto sock = get_socket(buft_ctx->endpoint);
-    bool status = send_rpc_cmd(sock, RPC_CMD_ALLOC_BUFFER, &request, sizeof(request), &response, sizeof(response));
-    RPC_STATUS_ASSERT(status);
+
+    auto dispatcher = get_dispatcher(buft_ctx->endpoint);
+    dispatcher->send(RPC_CMD_ALLOC_BUFFER, request, sizeof(*request), &response, sizeof(response));
     if (response.remote_ptr != 0) {
         ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
             ggml_backend_rpc_buffer_interface,
-            new ggml_backend_rpc_buffer_context{sock, nullptr, response.remote_ptr},
+            new ggml_backend_rpc_buffer_context{dispatcher, nullptr, response.remote_ptr},
             response.remote_size);
         return buffer;
     } else {
@@ -733,11 +937,11 @@ static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_back
     }
 }

-static size_t get_alignment(const std::shared_ptr<socket_t> & sock, uint32_t device) {
-    rpc_msg_get_alignment_req request = {device};
+static size_t get_alignment(const std::shared_ptr<rpc_dispatcher> & dispatcher, uint32_t device) {
+    auto request = std::make_shared<rpc_msg_get_alignment_req>();
+    request->device = device;
     rpc_msg_get_alignment_rsp response;
-    bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, &request, sizeof(request), &response, sizeof(response));
-    RPC_STATUS_ASSERT(status);
+    dispatcher->send(RPC_CMD_GET_ALIGNMENT, request, sizeof(*request), &response, sizeof(response));
     return response.alignment;
 }

@@ -746,11 +950,11 @@ static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_buffer_typ
     return buft_ctx->alignment;
 }

-static size_t get_max_size(const std::shared_ptr<socket_t> & sock, uint32_t device) {
-    rpc_msg_get_max_size_req request = {device};
+static size_t get_max_size(const std::shared_ptr<rpc_dispatcher> & dispatcher, uint32_t device) {
+    auto request = std::make_shared<rpc_msg_get_max_size_req>();
+    request->device = device;
     rpc_msg_get_max_size_rsp response;
-    bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, &request, sizeof(request), &response, sizeof(response));
-    RPC_STATUS_ASSERT(status);
+    dispatcher->send(RPC_CMD_GET_MAX_SIZE, request, sizeof(*request), &response, sizeof(response));
     return response.max_size;
 }

@@ -773,23 +977,20 @@ static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_ty

     if (rpc_get) {
         ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
-        auto sock = get_socket(buft_ctx->endpoint);
+        auto dispatcher = get_dispatcher(buft_ctx->endpoint);

-        rpc_msg_get_alloc_size_req request = {
-            /*.device =*/ buft_ctx->device,
-            /*.tensor =*/ serialize_tensor(tensor),
-            /*.srcs =*/ {},
-        };
+        auto request = std::make_shared<rpc_msg_get_alloc_size_req>();
+        request->device = buft_ctx->device;
+        request->tensor = serialize_tensor(tensor);

         // .get_alloc_size could be a function of the tensor's srcs, so we must serialize them as well
         for (int i = 0; i < GGML_MAX_SRC; i++) {
-            request.srcs[i] = serialize_tensor(tensor->src[i]);
+            request->srcs[i] = serialize_tensor(tensor->src[i]);
         }

         // TODO: cache the alloc responses to avoid extra RPC calls?
         rpc_msg_get_alloc_size_rsp response;
-        bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALLOC_SIZE, &request, sizeof(request), &response, sizeof(response));
-        RPC_STATUS_ASSERT(status);
+        dispatcher->send(RPC_CMD_GET_ALLOC_SIZE, request, sizeof(*request), &response, sizeof(response));

         return response.alloc_size;
     }
@@ -818,9 +1019,44 @@ static void ggml_backend_rpc_free(ggml_backend_t backend) {
     delete backend;
 }

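+// Async upload: the tensor bytes are copied into a heap buffer owned by a
+// shared_ptr, so the caller's data pointer does not have to outlive the call;
+// the copy is freed after the worker has sent it.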
+static void ggml_backend_rpc_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+    ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context;
+    rpc_tensor rpc_tensor = serialize_tensor(tensor);
+    if (size > HASH_THRESHOLD) {
+        auto request = std::make_shared<rpc_msg_set_tensor_hash_req>();
+        request->tensor = rpc_tensor;
+        request->offset = offset;
+        request->hash = fnv_hash((const uint8_t*)data, size);
+        rpc_msg_set_tensor_hash_rsp response;
+        // TODO: make this async
+        ctx->dispatcher->send(RPC_CMD_SET_TENSOR_HASH, request, sizeof(*request), &response, sizeof(response));
+        if (response.result) {
+            // the server has the same data, no need to send it
+            return;
+        }
+    }
+    // input serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes) |
+    size_t input_size = sizeof(rpc_tensor) + sizeof(uint64_t) + size;
+    uint8_t * input = new uint8_t[input_size]();
+    memcpy(input, &rpc_tensor, sizeof(rpc_tensor));
+    memcpy(input + sizeof(rpc_tensor), &offset, sizeof(offset));
+    memcpy(input + sizeof(rpc_tensor) + sizeof(offset), data, size);
+    std::shared_ptr<uint8_t> input_ptr(input, std::default_delete<uint8_t[]>());
+    ctx->dispatcher->send_async(RPC_CMD_SET_TENSOR, input_ptr, input_size);
+}
+
+static void ggml_backend_rpc_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+    ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context;
+    auto request = std::make_shared<rpc_msg_get_tensor_req>();
+    request->tensor = serialize_tensor(tensor);
+    request->offset = offset;
+    request->size = size;
+    ctx->dispatcher->send_async(RPC_CMD_GET_TENSOR, request, sizeof(*request), data, size);
+}
+
 static void ggml_backend_rpc_synchronize(ggml_backend_t backend) {
-    GGML_UNUSED(backend);
-    // this is no-op because we don't have any async operations
+    ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
+    rpc_ctx->dispatcher->synchronize();
 }

 static void add_tensor(ggml_tensor * tensor, std::vector<rpc_tensor> & tensors, std::unordered_set<ggml_tensor*> & visited) {
@@ -838,7 +1074,7 @@ static void add_tensor(ggml_tensor * tensor, std::vector<rpc_tensor> & tensors,
     tensors.push_back(serialize_tensor(tensor));
 }

-static void serialize_graph(uint32_t device, const ggml_cgraph * cgraph, std::vector<uint8_t> & output) {
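+// Returns a heap-allocated buffer (size reported via output_size) instead of
+// filling a caller-provided vector, so ownership can be handed to a shared_ptr
+// that keeps the serialized graph alive in the dispatcher queue.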
+static uint8_t * serialize_graph(uint32_t device, const ggml_cgraph * cgraph, size_t * output_size) {
     uint32_t n_nodes = cgraph->n_nodes;
     std::vector<rpc_tensor> tensors;
     std::unordered_set<ggml_tensor*> visited;
@@ -848,9 +1084,9 @@ static void serialize_graph(uint32_t device, const ggml_cgraph * cgraph, std::ve
     // serialization format:
     // | device (4 bytes) | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
     uint32_t n_tensors = tensors.size();
-    int output_size = 2*sizeof(uint32_t) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t) + n_tensors * sizeof(rpc_tensor);
-    output.resize(output_size, 0);
-    uint8_t * dest = output.data();
+    *output_size = 2*sizeof(uint32_t) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t) + n_tensors * sizeof(rpc_tensor);
+    uint8_t * output = new uint8_t[*output_size]();
+    uint8_t * dest = output;
     memcpy(dest, &device, sizeof(device));
     dest += sizeof(device);
     memcpy(dest, &n_nodes, sizeof(n_nodes));
@@ -863,6 +1099,7 @@ static void serialize_graph(uint32_t device, const ggml_cgraph * cgraph, std::ve
     dest += sizeof(n_tensors);
     rpc_tensor * out_tensors = (rpc_tensor *)dest;
     memcpy(out_tensors, tensors.data(), n_tensors * sizeof(rpc_tensor));
+    return output;
 }

 static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
@@ -871,27 +1108,35 @@ static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, g
     GGML_ASSERT(cgraph->n_nodes > 0);
     bool reuse = rpc_ctx->gc.is_cached(cgraph);
     if (reuse) {
-        rpc_msg_graph_recompute_req request;
-        request.device = rpc_ctx->device;
-        auto sock = get_socket(rpc_ctx->endpoint);
-        bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_RECOMPUTE, &request, sizeof(request));
-        RPC_STATUS_ASSERT(status);
+        auto request = std::make_shared<rpc_msg_graph_recompute_req>();
+        request->device = rpc_ctx->device;
+        rpc_ctx->dispatcher->send_async(RPC_CMD_GRAPH_RECOMPUTE, request, sizeof(*request));
     } else {
         rpc_ctx->gc.add(cgraph);
-        std::vector<uint8_t> input;
-        serialize_graph(rpc_ctx->device, cgraph, input);
-        auto sock = get_socket(rpc_ctx->endpoint);
-        bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size());
-        RPC_STATUS_ASSERT(status);
+        size_t input_size = 0;
+        uint8_t * input = serialize_graph(rpc_ctx->device, cgraph, &input_size);
+        std::shared_ptr<uint8_t> input_ptr(input, std::default_delete<uint8_t[]>());
+        rpc_ctx->dispatcher->send_async(RPC_CMD_GRAPH_COMPUTE, input_ptr, input_size);
     }
     return GGML_STATUS_SUCCESS;
 }

+static void ggml_backend_rpc_event_record(ggml_backend_t backend, ggml_backend_event_t event) {
+    ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
+    rpc_ctx->dispatcher->event_record(event);
+}
+
+static void ggml_backend_rpc_event_wait(ggml_backend_t backend, ggml_backend_event_t event) {
+    // this is noop for RPC as we have a single stream
+    GGML_UNUSED(backend);
+    GGML_UNUSED(event);
+}
+
 static ggml_backend_i ggml_backend_rpc_interface = {
     /* .get_name = */ ggml_backend_rpc_name,
     /* .free = */ ggml_backend_rpc_free,
-    /* .set_tensor_async = */ NULL,
-    /* .get_tensor_async = */ NULL,
+    /* .set_tensor_async = */ ggml_backend_rpc_set_tensor_async,
+    /* .get_tensor_async = */ ggml_backend_rpc_get_tensor_async,
     /* .cpy_tensor_async = */ NULL,
     /* .synchronize = */ ggml_backend_rpc_synchronize,
     /* .graph_plan_create = */ NULL,
@@ -899,8 +1144,8 @@ static ggml_backend_i ggml_backend_rpc_interface = {
     /* .graph_plan_update = */ NULL,
     /* .graph_plan_compute = */ NULL,
     /* .graph_compute = */ ggml_backend_rpc_graph_compute,
-    /* .event_record = */ NULL,
-    /* .event_wait = */ NULL,
+    /* .event_record = */ ggml_backend_rpc_event_record,
+    /* .event_wait = */ ggml_backend_rpc_event_wait,
     /* .graph_optimize = */ NULL,
 };

@@ -914,13 +1159,9 @@ ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint, u
     if (it != buft_map.end()) {
         return it->second;
     }
-    auto sock = get_socket(endpoint);
-    if (sock == nullptr) {
-        GGML_LOG_ERROR("Failed to connect to %s\n", endpoint);
-        return nullptr;
-    }
-    size_t alignment = get_alignment(sock, device);
-    size_t max_size = get_max_size(sock, device);
+    auto dispatcher = get_dispatcher(endpoint);
+    size_t alignment = get_alignment(dispatcher, device);
+    size_t max_size = get_max_size(dispatcher, device);
     ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context {
         /* .endpoint = */ endpoint,
         /* .device = */ device,
@@ -940,11 +1181,12 @@ ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint, u

 ggml_backend_t ggml_backend_rpc_init(const char * endpoint, uint32_t device) {
     std::string dev_name = "RPC" + std::to_string(device) + "[" + std::string(endpoint) + "]";
+    auto dispatcher = get_dispatcher(endpoint);
     ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
-        /* .endpoint = */ endpoint,
-        /* .device = */ device,
-        /* .name = */ dev_name,
-        /* .gc = */ {},
+        /* .dispatcher = */ dispatcher,
+        /* .device = */ device,
+        /* .name = */ dev_name,
+        /* .gc = */ {},
     };
     auto reg = ggml_backend_rpc_add_server(endpoint);
     ggml_backend_t backend = new ggml_backend {
@@ -960,26 +1202,16 @@ bool ggml_backend_is_rpc(ggml_backend_t backend) {
     return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_rpc_guid());
 }

-static void get_device_memory(const std::shared_ptr<socket_t> & sock, uint32_t device, size_t * free, size_t * total) {
-    rpc_msg_get_device_memory_req request;
-    request.device = device;
+void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device, size_t * free, size_t * total) {
+    auto dispatcher = get_dispatcher(endpoint);
+    auto request = std::make_shared<rpc_msg_get_device_memory_req>();
+    request->device = device;
     rpc_msg_get_device_memory_rsp response;
-    bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, &request, sizeof(request), &response, sizeof(response));
-    RPC_STATUS_ASSERT(status);
+    dispatcher->send(RPC_CMD_GET_DEVICE_MEMORY, request, sizeof(*request), &response, sizeof(response));
     *free = response.free_mem;
     *total = response.total_mem;
 }

-void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device, size_t * free, size_t * total) {
-    auto sock = get_socket(endpoint);
-    if (sock == nullptr) {
-        *free = 0;
-        *total = 0;
-        return;
-    }
-    get_device_memory(sock, device, free, total);
-}
-
 // RPC server-side implementation

 class rpc_server {
@@ -1701,9 +1933,6 @@ static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const
             if (!server.free_buffer(request)) {
                 return;
             }
-            if (!send_msg(sockfd, nullptr, 0)) {
-                return;
-            }
             break;
         }
         case RPC_CMD_BUFFER_CLEAR: {
@@ -1714,9 +1943,6 @@ static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const
             if (!server.buffer_clear(request)) {
                 return;
             }
-            if (!send_msg(sockfd, nullptr, 0)) {
-                return;
-            }
             break;
         }
         case RPC_CMD_SET_TENSOR: {
@@ -1751,9 +1977,6 @@ static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const
             if (!server.init_tensor(request)) {
                 return;
             }
-            if (!send_msg(sockfd, nullptr, 0)) {
-                return;
-            }
             break;
         }
         case RPC_CMD_GET_TENSOR: {
@@ -1941,10 +2164,10 @@ static void ggml_backend_rpc_device_get_props(ggml_backend_dev_t dev, struct ggm
     props->type = ggml_backend_rpc_device_get_type(dev);
     ggml_backend_rpc_device_get_memory(dev, &props->memory_free, &props->memory_total);
     props->caps = {
-        /* .async = */ false,
+        /* .async = */ true,
         /* .host_buffer = */ false,
         /* .buffer_from_host_ptr = */ false,
-        /* .events = */ false,
+        /* .events = */ true,
     };
 }

@@ -1980,6 +2203,24 @@ static bool ggml_backend_rpc_device_supports_buft(ggml_backend_dev_t dev, ggml_b
     return buft_ctx->endpoint == dev_ctx->endpoint && buft_ctx->device == dev_ctx->device;
 }

+static ggml_backend_event_t ggml_backend_rpc_device_event_new(ggml_backend_dev_t dev) {
+    ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
+    auto dispatcher = get_dispatcher(ctx->endpoint);
+    return dispatcher->event_new(dev);
+}
+
+static void ggml_backend_rpc_device_event_free(ggml_backend_dev_t dev, ggml_backend_event_t event) {
+    ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
+    auto dispatcher = get_dispatcher(ctx->endpoint);
+    dispatcher->event_free(event);
+}
+
+static void ggml_backend_rpc_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) {
+    ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
+    auto dispatcher = get_dispatcher(ctx->endpoint);
+    dispatcher->event_synchronize(event);
+}
+
 static const struct ggml_backend_device_i ggml_backend_rpc_device_i = {
     /* .get_name = */ ggml_backend_rpc_device_get_name,
     /* .get_description = */ ggml_backend_rpc_device_get_description,
@@ -1993,9 +2234,9 @@ static const struct ggml_backend_device_i ggml_backend_rpc_device_i = {
     /* .supports_op = */ ggml_backend_rpc_device_supports_op,
     /* .supports_buft = */ ggml_backend_rpc_device_supports_buft,
     /* .offload_op = */ NULL,
-    /* .event_new = */ NULL,
-    /* .event_free = */ NULL,
-    /* .event_synchronize = */ NULL,
+    /* .event_new = */ ggml_backend_rpc_device_event_new,
+    /* .event_free = */ ggml_backend_rpc_device_event_free,
+    /* .event_synchronize = */ ggml_backend_rpc_device_event_synchronize,
 };

 // backend reg interface
@@ -2055,14 +2296,9 @@ ggml_backend_reg_t ggml_backend_rpc_reg(void) {
 }

 static uint32_t ggml_backend_rpc_get_device_count(const char * endpoint) {
-    auto sock = get_socket(endpoint);
-    if (sock == nullptr) {
-        GGML_LOG_ERROR("Failed to connect to %s\n", endpoint);
-        return 0;
-    }
+    auto dispatcher = get_dispatcher(endpoint);
     rpc_msg_device_count_rsp response;
-    bool status = send_rpc_cmd(sock, RPC_CMD_DEVICE_COUNT, nullptr, 0, &response, sizeof(response));
-    RPC_STATUS_ASSERT(status);
+    dispatcher->send(RPC_CMD_DEVICE_COUNT, nullptr, 0, &response, sizeof(response));
     return response.device_count;
 }