take buffer memory types into account
This commit is contained in:
parent
57a53944a0
commit
51682440b0
|
|
@ -2377,7 +2377,7 @@ static std::vector<uint32_t> ggml_vk_find_memory_properties(const vk::PhysicalDe
|
|||
}
|
||||
|
||||
static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std::initializer_list<vk::MemoryPropertyFlags> & req_flags_list,
|
||||
void *import_ptr = nullptr, uint32_t import_memory_type = ~0u) {
|
||||
void *import_ptr = nullptr) {
|
||||
VK_LOG_DEBUG("ggml_vk_create_buffer(" << device->name << ", " << size << ", " << to_string(req_flags_list.begin()[0]) << ", " << to_string(req_flags_list.begin()[req_flags_list.size()-1]) << ")");
|
||||
if (size > device->max_buffer_size) {
|
||||
throw vk::OutOfDeviceMemoryError("Requested buffer size exceeds device buffer size limit");
|
||||
|
|
@ -2427,13 +2427,46 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
|
|||
}
|
||||
|
||||
if (import_ptr) {
|
||||
buf->memory_property_flags = mem_props.memoryTypes[import_memory_type].propertyFlags;
|
||||
vk::MemoryHostPointerPropertiesEXT host_pointer_props;
|
||||
try {
|
||||
host_pointer_props = device->device.getMemoryHostPointerPropertiesEXT(vk::ExternalMemoryHandleTypeFlagBits::eHostAllocationEXT, import_ptr);
|
||||
} catch (vk::SystemError& e) {
|
||||
GGML_LOG_WARN("ggml_vulkan: Failed getMemoryHostPointerPropertiesEXT (%s)\n", e.what());
|
||||
device->device.destroyBuffer(buf->buffer);
|
||||
return {};
|
||||
}
|
||||
vk::PhysicalDeviceMemoryProperties mem_props = device->physical_device.getMemoryProperties();
|
||||
|
||||
uint32_t memory_type_idx;
|
||||
vk::MemoryPropertyFlags property_flags = *req_flags_list.begin();
|
||||
for (memory_type_idx = 0; memory_type_idx < 32; ++memory_type_idx) {
|
||||
if (!(host_pointer_props.memoryTypeBits & (1u << memory_type_idx))) {
|
||||
continue;
|
||||
}
|
||||
if (!(mem_req.memoryTypeBits & (1u << memory_type_idx))) {
|
||||
continue;
|
||||
}
|
||||
|
||||
vk::MemoryType memory_type = mem_props.memoryTypes[memory_type_idx];
|
||||
// check for visible+coherent+cached. Other flags (e.g. devicelocal) are allowed
|
||||
if ((memory_type.propertyFlags & property_flags) == property_flags) {
|
||||
property_flags = memory_type.propertyFlags;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (memory_type_idx == 32) {
|
||||
GGML_LOG_WARN("ggml_vulkan: Memory type for host allocation not found\n");
|
||||
device->device.destroyBuffer(buf->buffer);
|
||||
return {};
|
||||
}
|
||||
|
||||
buf->memory_property_flags = mem_props.memoryTypes[memory_type_idx].propertyFlags;
|
||||
try {
|
||||
vk::ImportMemoryHostPointerInfoEXT import_info;
|
||||
import_info.handleType = vk::ExternalMemoryHandleTypeFlagBits::eHostAllocationEXT;
|
||||
import_info.pHostPointer = import_ptr;
|
||||
import_info.setPNext(&mem_flags_info);
|
||||
buf->device_memory = device->device.allocateMemory({ size, import_memory_type, &import_info });
|
||||
buf->device_memory = device->device.allocateMemory({ size, memory_type_idx, &import_info });
|
||||
} catch (const vk::SystemError& e) {
|
||||
}
|
||||
} else {
|
||||
|
|
@ -14666,36 +14699,11 @@ static vk_buffer ggml_vk_buffer_from_host_ptr(vk_device & device, void * ptr, si
|
|||
return {};
|
||||
}
|
||||
|
||||
vk::MemoryHostPointerPropertiesEXT host_pointer_props;
|
||||
try {
|
||||
host_pointer_props = device->device.getMemoryHostPointerPropertiesEXT(vk::ExternalMemoryHandleTypeFlagBits::eHostAllocationEXT, ptr);
|
||||
} catch (vk::SystemError& e) {
|
||||
GGML_LOG_WARN("ggml_vulkan: Failed getMemoryHostPointerPropertiesEXT (%s)\n", e.what());
|
||||
return {};
|
||||
}
|
||||
vk::PhysicalDeviceMemoryProperties mem_props = device->physical_device.getMemoryProperties();
|
||||
|
||||
uint32_t memory_type_idx;
|
||||
vk::MemoryPropertyFlags property_flags = vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached;
|
||||
for (memory_type_idx = 0; memory_type_idx < 32; ++memory_type_idx) {
|
||||
if (!(host_pointer_props.memoryTypeBits & (1u << memory_type_idx))) {
|
||||
continue;
|
||||
}
|
||||
|
||||
vk::MemoryType memory_type = mem_props.memoryTypes[memory_type_idx];
|
||||
// check for visible+coherent+cached. Other flags (e.g. devicelocal) are allowed
|
||||
if ((memory_type.propertyFlags & property_flags) == property_flags) {
|
||||
property_flags = memory_type.propertyFlags;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (memory_type_idx == 32) {
|
||||
return {};
|
||||
}
|
||||
const vk::MemoryPropertyFlags property_flags = vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached;
|
||||
|
||||
vk_buffer buf {};
|
||||
try {
|
||||
buf = ggml_vk_create_buffer(device, size, { property_flags }, ptr, memory_type_idx);
|
||||
buf = ggml_vk_create_buffer(device, size, { property_flags }, ptr);
|
||||
} catch (vk::SystemError& e) {
|
||||
GGML_LOG_WARN("ggml_vulkan: Failed ggml_vk_create_buffer (%s)\n", e.what());
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue