vulkan: support buffer_from_host_ptr
This commit is contained in:
parent
cd78e57c3a
commit
43b7e1ed9e
|
|
@ -523,6 +523,8 @@ struct vk_device_struct {
|
|||
uint64_t max_memory_allocation_size;
|
||||
uint64_t max_buffer_size;
|
||||
uint64_t suballocation_block_size;
|
||||
uint64_t min_imported_host_pointer_alignment;
|
||||
bool external_memory_host {};
|
||||
bool fp16;
|
||||
bool bf16;
|
||||
bool pipeline_robustness;
|
||||
|
|
@ -2373,7 +2375,8 @@ static std::vector<uint32_t> ggml_vk_find_memory_properties(const vk::PhysicalDe
|
|||
return indices;
|
||||
}
|
||||
|
||||
static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std::initializer_list<vk::MemoryPropertyFlags> & req_flags_list) {
|
||||
static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std::initializer_list<vk::MemoryPropertyFlags> & req_flags_list,
|
||||
void *import_ptr = nullptr, uint32_t import_memory_type = ~0u) {
|
||||
VK_LOG_DEBUG("ggml_vk_create_buffer(" << device->name << ", " << size << ", " << to_string(req_flags_list.begin()[0]) << ", " << to_string(req_flags_list.begin()[req_flags_list.size()-1]) << ")");
|
||||
if (size > device->max_buffer_size) {
|
||||
throw vk::OutOfDeviceMemoryError("Requested buffer size exceeds device buffer size limit");
|
||||
|
|
@ -2416,35 +2419,47 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
|
|||
mem_flags_info.setPNext(&mem_priority_info);
|
||||
}
|
||||
|
||||
for (auto it = req_flags_list.begin(); it != req_flags_list.end(); it++) {
|
||||
const auto & req_flags = *it;
|
||||
|
||||
const std::vector<uint32_t> memory_type_indices = ggml_vk_find_memory_properties(&mem_props, &mem_req, req_flags);
|
||||
|
||||
if (memory_type_indices.empty()) {
|
||||
continue;
|
||||
if (import_ptr) {
|
||||
buf->memory_property_flags = mem_props.memoryTypes[import_memory_type].propertyFlags;
|
||||
try {
|
||||
vk::ImportMemoryHostPointerInfoEXT import_info;
|
||||
import_info.handleType = vk::ExternalMemoryHandleTypeFlagBits::eHostAllocationEXT;
|
||||
import_info.pHostPointer = import_ptr;
|
||||
import_info.setPNext(&mem_flags_info);
|
||||
buf->device_memory = device->device.allocateMemory({ size, import_memory_type, &import_info });
|
||||
} catch (const vk::SystemError& e) {
|
||||
}
|
||||
buf->memory_property_flags = req_flags;
|
||||
} else {
|
||||
for (auto it = req_flags_list.begin(); it != req_flags_list.end(); it++) {
|
||||
const auto & req_flags = *it;
|
||||
|
||||
bool done = false;
|
||||
const std::vector<uint32_t> memory_type_indices = ggml_vk_find_memory_properties(&mem_props, &mem_req, req_flags);
|
||||
|
||||
for (auto mtype_it = memory_type_indices.begin(); mtype_it != memory_type_indices.end(); mtype_it++) {
|
||||
try {
|
||||
buf->device_memory = device->device.allocateMemory({ mem_req.size, *mtype_it, &mem_flags_info });
|
||||
done = true;
|
||||
break;
|
||||
} catch (const vk::SystemError& e) {
|
||||
// loop and retry
|
||||
// during last attempt throw the exception
|
||||
if (it + 1 == req_flags_list.end() && mtype_it + 1 == memory_type_indices.end()) {
|
||||
device->device.destroyBuffer(buf->buffer);
|
||||
throw e;
|
||||
if (memory_type_indices.empty()) {
|
||||
continue;
|
||||
}
|
||||
buf->memory_property_flags = req_flags;
|
||||
|
||||
bool done = false;
|
||||
|
||||
for (auto mtype_it = memory_type_indices.begin(); mtype_it != memory_type_indices.end(); mtype_it++) {
|
||||
try {
|
||||
buf->device_memory = device->device.allocateMemory({ mem_req.size, *mtype_it, &mem_flags_info });
|
||||
done = true;
|
||||
break;
|
||||
} catch (const vk::SystemError& e) {
|
||||
// loop and retry
|
||||
// during last attempt throw the exception
|
||||
if (it + 1 == req_flags_list.end() && mtype_it + 1 == memory_type_indices.end()) {
|
||||
device->device.destroyBuffer(buf->buffer);
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (done) {
|
||||
break;
|
||||
if (done) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -2455,8 +2470,12 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
|
|||
|
||||
buf->ptr = nullptr;
|
||||
|
||||
if (buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
|
||||
buf->ptr = device->device.mapMemory(buf->device_memory, 0, VK_WHOLE_SIZE);
|
||||
if (import_ptr) {
|
||||
buf->ptr = import_ptr;
|
||||
} else {
|
||||
if (buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
|
||||
buf->ptr = device->device.mapMemory(buf->device_memory, 0, VK_WHOLE_SIZE);
|
||||
}
|
||||
}
|
||||
|
||||
device->device.bindBufferMemory(buf->buffer, buf->device_memory, 0);
|
||||
|
|
@ -4397,6 +4416,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|||
} else if (strcmp("VK_EXT_memory_priority", properties.extensionName) == 0 &&
|
||||
getenv("GGML_VK_ENABLE_MEMORY_PRIORITY")) {
|
||||
device->memory_priority = true;
|
||||
} else if (strcmp("VK_EXT_external_memory_host", properties.extensionName) == 0) {
|
||||
device->external_memory_host = true;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -4411,6 +4432,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|||
vk::PhysicalDeviceVulkan12Properties vk12_props;
|
||||
vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props;
|
||||
vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR shader_integer_dot_product_props;
|
||||
vk::PhysicalDeviceExternalMemoryHostPropertiesEXT external_memory_host_props;
|
||||
|
||||
props2.pNext = &props3;
|
||||
props3.pNext = &subgroup_props;
|
||||
|
|
@ -4450,6 +4472,11 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|||
last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_props;
|
||||
}
|
||||
|
||||
if (device->external_memory_host) {
|
||||
last_struct->pNext = (VkBaseOutStructure *)&external_memory_host_props;
|
||||
last_struct = (VkBaseOutStructure *)&external_memory_host_props;
|
||||
}
|
||||
|
||||
device->physical_device.getProperties2(&props2);
|
||||
device->properties = props2.properties;
|
||||
device->vendor_id = device->properties.vendorID;
|
||||
|
|
@ -4536,6 +4563,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|||
|
||||
device->integer_dot_product = device->integer_dot_product && shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated;
|
||||
|
||||
device->min_imported_host_pointer_alignment = external_memory_host_props.minImportedHostPointerAlignment;
|
||||
|
||||
device->max_workgroup_size_log2 = uint32_t(log2f(float(device->properties.limits.maxComputeWorkGroupInvocations)));
|
||||
|
||||
std::vector<vk::QueueFamilyProperties> queue_family_props = device->physical_device.getQueueFamilyProperties();
|
||||
|
|
@ -4667,6 +4696,10 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|||
device_extensions.push_back("VK_KHR_pipeline_executable_properties");
|
||||
}
|
||||
|
||||
if (device->external_memory_host) {
|
||||
device_extensions.push_back("VK_EXT_external_memory_host");
|
||||
}
|
||||
|
||||
vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2);
|
||||
|
||||
device->pipeline_executable_properties_support = pipeline_executable_properties_support;
|
||||
|
|
@ -14007,10 +14040,13 @@ static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml
|
|||
props->type = ggml_backend_vk_device_get_type(dev);
|
||||
props->device_id = ctx->pci_bus_id.empty() ? nullptr : ctx->pci_bus_id.c_str();
|
||||
ggml_backend_vk_device_get_memory(dev, &props->memory_free, &props->memory_total);
|
||||
|
||||
auto device = ggml_vk_get_device(ctx->device);
|
||||
|
||||
props->caps = {
|
||||
/* .async = */ true,
|
||||
/* .host_buffer = */ true,
|
||||
/* .buffer_from_host_ptr = */ false,
|
||||
/* .buffer_from_host_ptr = */ device->external_memory_host,
|
||||
/* .events = */ true,
|
||||
};
|
||||
}
|
||||
|
|
@ -14589,6 +14625,70 @@ static void ggml_backend_vk_device_event_synchronize(ggml_backend_dev_t dev, ggm
|
|||
VK_CHECK(device->device.waitForFences({ vkev->fence }, true, UINT64_MAX), "event_synchronize");
|
||||
}
|
||||
|
||||
static ggml_backend_buffer_t ggml_backend_vk_device_buffer_from_host_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
|
||||
VK_LOG_DEBUG("ggml_backend_vk_device_buffer_from_host_ptr(backend=" << dev << ", ptr=" << ptr << ", size=" << size << ")");
|
||||
GGML_UNUSED(max_tensor_size);
|
||||
|
||||
ggml_backend_buffer_t ret {};
|
||||
|
||||
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
|
||||
auto device = ggml_vk_get_device(ctx->device);
|
||||
|
||||
if (!device->external_memory_host) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
uintptr_t uptr = reinterpret_cast<uintptr_t>(ptr);
|
||||
if (uptr & (device->min_imported_host_pointer_alignment - 1)) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
vk::MemoryHostPointerPropertiesEXT host_pointer_props;
|
||||
try {
|
||||
host_pointer_props = device->device.getMemoryHostPointerPropertiesEXT(vk::ExternalMemoryHandleTypeFlagBits::eHostAllocationEXT, ptr);
|
||||
} catch (vk::SystemError& e) {
|
||||
GGML_LOG_WARN("ggml_vulkan: Failed getMemoryHostPointerPropertiesEXT (%s)\n", e.what());
|
||||
return ret;
|
||||
}
|
||||
vk::PhysicalDeviceMemoryProperties mem_props = device->physical_device.getMemoryProperties();
|
||||
|
||||
uint32_t memory_type_idx;
|
||||
vk::MemoryPropertyFlags property_flags = vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached;
|
||||
for (memory_type_idx = 0; memory_type_idx < 32; ++memory_type_idx) {
|
||||
if (!(host_pointer_props.memoryTypeBits & (1u << memory_type_idx))) {
|
||||
continue;
|
||||
}
|
||||
|
||||
vk::MemoryType memory_type = mem_props.memoryTypes[memory_type_idx];
|
||||
// check for visible+coherent+cache. Other flags (e.g. devicelocal) are allowed
|
||||
if ((memory_type.propertyFlags & property_flags) == property_flags) {
|
||||
property_flags = memory_type.propertyFlags;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (memory_type_idx == 32) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
vk_buffer buf {};
|
||||
try {
|
||||
buf = ggml_vk_create_buffer(device, size, { property_flags }, ptr, memory_type_idx);
|
||||
} catch (vk::SystemError& e) {
|
||||
GGML_LOG_WARN("ggml_vulkan: Failed ggml_vk_create_buffer (%s)\n", e.what());
|
||||
return ret;
|
||||
}
|
||||
|
||||
if (!buf) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
ggml_backend_vk_buffer_context * bufctx = new ggml_backend_vk_buffer_context(device, std::move(buf), device->name);
|
||||
|
||||
ret = ggml_backend_buffer_init(ggml_backend_vk_device_get_buffer_type(dev), ggml_backend_vk_buffer_interface, bufctx, size);
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
static const struct ggml_backend_device_i ggml_backend_vk_device_i = {
|
||||
/* .get_name = */ ggml_backend_vk_device_get_name,
|
||||
/* .get_description = */ ggml_backend_vk_device_get_description,
|
||||
|
|
@ -14598,7 +14698,7 @@ static const struct ggml_backend_device_i ggml_backend_vk_device_i = {
|
|||
/* .init_backend = */ ggml_backend_vk_device_init,
|
||||
/* .get_buffer_type = */ ggml_backend_vk_device_get_buffer_type,
|
||||
/* .get_host_buffer_type = */ ggml_backend_vk_device_get_host_buffer_type,
|
||||
/* .buffer_from_host_ptr = */ NULL,
|
||||
/* .buffer_from_host_ptr = */ ggml_backend_vk_device_buffer_from_host_ptr,
|
||||
/* .supports_op = */ ggml_backend_vk_device_supports_op,
|
||||
/* .supports_buft = */ ggml_backend_vk_device_supports_buft,
|
||||
/* .offload_op = */ ggml_backend_vk_device_offload_op,
|
||||
|
|
|
|||
Loading…
Reference in New Issue