diff --git a/CMakeLists.txt b/CMakeLists.txt index 3858968..e5ab191 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -68,6 +68,8 @@ target_link_libraries(gemma hwy hwy_contrib sentencepiece) target_include_directories(gemma PRIVATE ./) FetchContent_GetProperties(sentencepiece) target_include_directories(gemma PRIVATE ${sentencepiece_SOURCE_DIR}) +target_compile_definitions(gemma PRIVATE $<$:_CRT_SECURE_NO_WARNINGS NOMINMAX>) +target_compile_options(gemma PRIVATE $<$:-Wno-deprecated-declarations>) ## Library Target @@ -77,3 +79,5 @@ set_target_properties(libgemma PROPERTIES PREFIX "") target_include_directories(libgemma PUBLIC ./) target_link_libraries(libgemma hwy hwy_contrib sentencepiece) target_include_directories(libgemma PRIVATE ${sentencepiece_SOURCE_DIR}) +target_compile_definitions(libgemma PRIVATE $<$:_CRT_SECURE_NO_WARNINGS NOMINMAX>) +target_compile_options(libgemma PRIVATE $<$:-Wno-deprecated-declarations>) diff --git a/compression/blob_store.cc b/compression/blob_store.cc index 8d6c1d0..550c727 100644 --- a/compression/blob_store.cc +++ b/compression/blob_store.cc @@ -16,11 +16,16 @@ // copybara:import_next_line:gemma_cpp #include "compression/blob_store.h" -#include // open #include #include // SEEK_END - unistd isn't enough for IDE. #include // O_RDONLY -#include // read, close +#include // open +#if HWY_OS_WIN +#include // read, write, close +#include +#else +#include // read, write, close +#endif #include #include @@ -30,6 +35,54 @@ #include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/detect_compiler_arch.h" +namespace { +#if HWY_OS_WIN + +// pread is not supported on Windows +static int64_t pread(int fd, void* buf, uint64_t size, uint64_t offset) { + HANDLE file = reinterpret_cast(_get_osfhandle(fd)); + if (file == INVALID_HANDLE_VALUE) { + return -1; + } + + OVERLAPPED overlapped = {0}; + overlapped.Offset = offset & 0xFFFFFFFF; + overlapped.OffsetHigh = (offset >> 32) & 0xFFFFFFFF; + + DWORD bytes_read; + if (!ReadFile(file, buf, size, &bytes_read, &overlapped)) { + if (GetLastError() != ERROR_HANDLE_EOF) { + return -1; + } + } + + return bytes_read; +} + +// pwrite is not supported on Windows +static int64_t pwrite(int fd, const void* buf, uint64_t size, uint64_t offset) { + HANDLE file = reinterpret_cast(_get_osfhandle(fd)); + if (file == INVALID_HANDLE_VALUE) { + return -1; + } + + OVERLAPPED overlapped = {0}; + overlapped.Offset = offset & 0xFFFFFFFF; + overlapped.OffsetHigh = (offset >> 32) & 0xFFFFFFFF; + + DWORD bytes_written; + if (!WriteFile(file, buf, size, &bytes_written, &overlapped)) { + if (GetLastError() != ERROR_HANDLE_EOF) { + return -1; + } + } + + return bytes_written; +} + +#endif +} + namespace gcpp { hwy::uint128_t MakeKey(const char* string) { @@ -64,19 +117,30 @@ static void EnqueueChunkRequests(uint64_t offset, uint64_t size, uint8_t* data, } } + struct IO { // Returns size in bytes or 0. static uint64_t FileSize(const char* filename) { int fd = open(filename, O_RDONLY); - if (fd >= 0) { - const off_t size = lseek(fd, 0, SEEK_END); - HWY_ASSERT(close(fd) != -1); - if (size != static_cast(-1)) { - return static_cast(size); - } + if (fd < 0) { + return 0; } - return 0; +#if HWY_OS_WIN + const int64_t size = _lseeki64(fd, 0, SEEK_END); + HWY_ASSERT(close(fd) != -1); + if (size < 0) { + return 0; + } +#else + const off_t size = lseek(fd, 0, SEEK_END); + HWY_ASSERT(close(fd) != -1); + if (size == static_cast(-1)) { + return 0; + } +#endif + + return static_cast(size); } static bool Read(int fd, uint64_t offset, uint64_t size, void* to) { @@ -252,7 +316,14 @@ class BlobStore { #pragma pack(pop) BlobError BlobReader::Open(const char* filename) { +#if HWY_OS_WIN + DWORD flags = FILE_ATTRIBUTE_NORMAL | FILE_FLAG_SEQUENTIAL_SCAN; + HANDLE file = CreateFileA(filename, GENERIC_READ, FILE_SHARE_READ, nullptr, OPEN_EXISTING, flags, nullptr); + if (file == INVALID_HANDLE_VALUE) return __LINE__; + fd_ = _open_osfhandle(reinterpret_cast(file), _O_RDONLY); +#else fd_ = open(filename, O_RDONLY); +#endif if (fd_ < 0) return __LINE__; #if _POSIX_C_SOURCE >= 200112L @@ -330,7 +401,14 @@ BlobError BlobWriter::WriteAll(hwy::ThreadPool& pool, keys_.data(), blobs_.data(), keys_.size()); // Create/replace existing file. +#if HWY_OS_WIN + DWORD flags = FILE_ATTRIBUTE_NORMAL; + HANDLE file = CreateFileA(filename, GENERIC_WRITE, 0, nullptr, CREATE_ALWAYS, flags, nullptr); + if (file == INVALID_HANDLE_VALUE) return __LINE__; + const int fd = _open_osfhandle(reinterpret_cast(file), _O_WRONLY); +#else const int fd = open(filename, O_CREAT | O_RDWR | O_TRUNC, 0644); +#endif if (fd < 0) return __LINE__; std::atomic_flag err = ATOMIC_FLAG_INIT; @@ -341,6 +419,7 @@ BlobError BlobWriter::WriteAll(hwy::ThreadPool& pool, err.test_and_set(); } }); + HWY_ASSERT(close(fd) != -1); if (err.test_and_set()) return __LINE__; return 0; } diff --git a/util/app.h b/util/app.h index 966fa41..bd665a4 100644 --- a/util/app.h +++ b/util/app.h @@ -18,7 +18,9 @@ #ifndef THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_ #define THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_ +#if HWY_OS_LINUX #include +#endif #include #include // std::clamp