diff --git a/BUILD.bazel b/BUILD.bazel index a504f2f..8f17a6a 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -55,6 +55,7 @@ cc_library( name = "args", hdrs = ["util/args.h"], deps = [ + "//compression:io", "@hwy//:hwy", ], ) diff --git a/CMakeLists.txt b/CMakeLists.txt index f1c7104..fed1619 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -38,6 +38,8 @@ set(SOURCES compression/blob_store.h compression/compress.h compression/compress-inl.h + compression/io.cc + compression/io.h compression/nuq.h compression/nuq-inl.h compression/sfp.h diff --git a/DEVELOPERS.md b/DEVELOPERS.md index 4e104b9..08565ad 100644 --- a/DEVELOPERS.md +++ b/DEVELOPERS.md @@ -169,9 +169,9 @@ inference path of the Gemma model. The sentencepiece library we depend on requires some additional work to build with the Bazel build system. First, it does not export its BUILD file, so we provide `bazel/sentencepiece.bazel`. Second, it ships with a vendored subset of -the Abseil library. `bazel/com_google_sentencepiece.patch` changes the code to -support Abseil as a standalone dependency without third_party/ prefixes, similar -to the transforms we apply to Gemma via Copybara. +the Abseil library. `bazel/sentencepiece.patch` changes the code to support +Abseil as a standalone dependency without third_party/ prefixes, similar to the +transforms we apply to Gemma via Copybara. ## Discord diff --git a/MODULE.bazel b/MODULE.bazel index d183e89..72d1d42 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -33,7 +33,7 @@ http_archive( strip_prefix = "sentencepiece-0.1.96", urls = ["https://github.com/google/sentencepiece/archive/refs/tags/v0.1.96.zip"], build_file = "@//bazel:sentencepiece.bazel", - patches = ["@//bazel:com_google_sentencepiece.patch"], + patches = ["@//bazel:sentencepiece.patch"], patch_args = ["-p1"], ) diff --git a/bazel/BUILD b/bazel/BUILD index 194a082..7964848 100644 --- a/bazel/BUILD +++ b/bazel/BUILD @@ -1,4 +1,4 @@ -# Required for referencing bazel:com_google_sentencepiece.patch +# Required for referencing bazel:sentencepiece.patch package( default_applicable_licenses = ["//:license"], default_visibility = ["//:__subpackages__"], diff --git a/bazel/com_google_sentencepiece.patch b/bazel/sentencepiece.patch similarity index 100% rename from bazel/com_google_sentencepiece.patch rename to bazel/sentencepiece.patch diff --git a/compression/BUILD b/compression/BUILD index b0c8431..ecf067a 100644 --- a/compression/BUILD +++ b/compression/BUILD @@ -10,11 +10,23 @@ package( ], ) +cc_library( + name = "io", + srcs = ["io.cc"], + hdrs = ["io.h"], + # Placeholder for io textual_hdrs, do not remove + deps = [ + # Placeholder for io deps, do not remove + "@hwy//:hwy", + ], +) + cc_library( name = "blob_store", srcs = ["blob_store.cc"], hdrs = ["blob_store.h"], deps = [ + ":io", "@hwy//:hwy", "@hwy//:thread_pool", ], diff --git a/compression/blob_store.cc b/compression/blob_store.cc index 7f03833..6dea7df 100644 --- a/compression/blob_store.cc +++ b/compression/blob_store.cc @@ -13,88 +13,20 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Request POSIX 2008, including `pread()` and `posix_fadvise()`. -#if !defined(_XOPEN_SOURCE) || _XOPEN_SOURCE < 700 -#undef _XOPEN_SOURCE -#define _XOPEN_SOURCE 700 -#endif -#if !defined(_POSIX_C_SOURCE) || _POSIX_C_SOURCE < 200809 -#define _POSIX_C_SOURCE 200809 -#endif - -// Make `off_t` 64-bit even on 32-bit systems. Works for Android >= r15c. -#undef _FILE_OFFSET_BITS -#define _FILE_OFFSET_BITS 64 - #include "compression/blob_store.h" -#include // open +#include #include -#include // SEEK_END - unistd isn't enough for IDE. -#include // O_RDONLY -#if HWY_OS_WIN -#include -#include // read, write, close -#else -#include // read, write, close -#endif #include #include +#include "compression/io.h" #include "hwy/aligned_allocator.h" #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/detect_compiler_arch.h" -namespace { -#if HWY_OS_WIN - -// pread is not supported on Windows -static int64_t pread(int fd, void* buf, uint64_t size, uint64_t offset) { - HANDLE file = reinterpret_cast(_get_osfhandle(fd)); - if (file == INVALID_HANDLE_VALUE) { - return -1; - } - - OVERLAPPED overlapped = {0}; - overlapped.Offset = offset & 0xFFFFFFFF; - overlapped.OffsetHigh = (offset >> 32) & 0xFFFFFFFF; - - DWORD bytes_read; - if (!ReadFile(file, buf, size, &bytes_read, &overlapped)) { - if (GetLastError() != ERROR_HANDLE_EOF) { - return -1; - } - } - - return bytes_read; -} - -// pwrite is not supported on Windows -static int64_t pwrite(int fd, const void* buf, uint64_t size, uint64_t offset) { - HANDLE file = reinterpret_cast(_get_osfhandle(fd)); - if (file == INVALID_HANDLE_VALUE) { - return -1; - } - - OVERLAPPED overlapped = {0}; - overlapped.Offset = offset & 0xFFFFFFFF; - overlapped.OffsetHigh = (offset >> 32) & 0xFFFFFFFF; - - DWORD bytes_written; - if (!WriteFile(file, buf, size, &bytes_written, &overlapped)) { - if (GetLastError() != ERROR_HANDLE_EOF) { - return -1; - } - } - - return bytes_written; -} - -#endif -} // namespace - namespace gcpp { hwy::uint128_t MakeKey(const char* string) { @@ -131,61 +63,6 @@ void EnqueueChunkRequests(uint64_t offset, uint64_t size, uint8_t* data, } } // namespace -struct IO { - // Returns size in bytes or 0. - static uint64_t FileSize(const char* filename) { - int fd = open(filename, O_RDONLY); - if (fd < 0) { - return 0; - } - -#if HWY_OS_WIN - const int64_t size = _lseeki64(fd, 0, SEEK_END); - HWY_ASSERT(close(fd) != -1); - if (size < 0) { - return 0; - } -#else - static_assert(sizeof(off_t) == 8, "64-bit off_t required"); - const off_t size = lseek(fd, 0, SEEK_END); - HWY_ASSERT(close(fd) != -1); - if (size == static_cast(-1)) { - return 0; - } -#endif - - return static_cast(size); - } - - static bool Read(int fd, uint64_t offset, uint64_t size, void* to) { - uint8_t* bytes = reinterpret_cast(to); - uint64_t pos = 0; - for (;;) { - // pread seems to be faster than lseek + read when parallelized. - const auto bytes_read = pread(fd, bytes + pos, size - pos, offset + pos); - if (bytes_read <= 0) break; - pos += bytes_read; - HWY_ASSERT(pos <= size); - if (pos == size) break; - } - return pos == size; // success if managed to read desired size - } - - static bool Write(const void* from, uint64_t size, uint64_t offset, int fd) { - const uint8_t* bytes = reinterpret_cast(from); - uint64_t pos = 0; - for (;;) { - const auto bytes_written = - pwrite(fd, bytes + pos, size - pos, offset + pos); - if (bytes_written <= 0) break; - pos += bytes_written; - HWY_ASSERT(pos <= size); - if (pos == size) break; - } - return pos == size; // success if managed to write desired size - } -}; // IO - static_assert(HWY_IS_LITTLE_ENDIAN, "Assumes little endian"); // On-disk representation (little-endian). @@ -323,25 +200,11 @@ class BlobStore { #pragma pack(pop) BlobError BlobReader::Open(const char* filename) { -#if HWY_OS_WIN - DWORD flags = FILE_ATTRIBUTE_NORMAL | FILE_FLAG_SEQUENTIAL_SCAN; - HANDLE file = CreateFileA(filename, GENERIC_READ, FILE_SHARE_READ, nullptr, - OPEN_EXISTING, flags, nullptr); - if (file == INVALID_HANDLE_VALUE) return __LINE__; - fd_ = _open_osfhandle(reinterpret_cast(file), _O_RDONLY); -#else - fd_ = open(filename, O_RDONLY); -#endif - if (fd_ < 0) return __LINE__; - -#if HWY_OS_LINUX && (!defined(__ANDROID_API__) || __ANDROID_API__ >= 21) - // Doubles the readahead window, which seems slightly faster when cached. - (void)posix_fadvise(fd_, 0, 0, POSIX_FADV_SEQUENTIAL); -#endif + if (!file_.Open(filename, "r")) return __LINE__; // Read first part of header to get actual size. BlobStore bs; - if (!IO::Read(fd_, 0, sizeof(bs), &bs)) return __LINE__; + if (!file_.Read(0, sizeof(bs), &bs)) return __LINE__; const size_t padded_size = bs.PaddedHeaderSize(); HWY_ASSERT(padded_size >= sizeof(bs)); @@ -353,18 +216,11 @@ BlobError BlobReader::Open(const char* filename) { hwy::CopySameSize(&bs, blob_store_.get()); // Read the rest of the header, but not the full file. uint8_t* bytes = reinterpret_cast(blob_store_.get()); - if (!IO::Read(fd_, sizeof(bs), padded_size - sizeof(bs), - bytes + sizeof(bs))) { + if (!file_.Read(sizeof(bs), padded_size - sizeof(bs), bytes + sizeof(bs))) { return __LINE__; } - return blob_store_->CheckValidity(IO::FileSize(filename)); -} - -BlobReader::~BlobReader() { - if (fd_ >= 0) { - HWY_ASSERT(close(fd_) != -1); - } + return blob_store_->CheckValidity(file_.FileSize()); } BlobError BlobReader::Enqueue(hwy::uint128_t key, void* data, size_t size) { @@ -391,14 +247,14 @@ BlobError BlobReader::Enqueue(hwy::uint128_t key, void* data, size_t size) { // between consecutive runs. // - memory-mapped I/O is less predictable and adds noise to measurements. BlobError BlobReader::ReadAll(hwy::ThreadPool& pool) { - const int fd = fd_; + File* pfile = &file_; // not owned const auto& requests = requests_; std::atomic_flag err = ATOMIC_FLAG_INIT; // >5x speedup from parallel reads when cached. pool.Run(0, requests.size(), - [fd, &requests, &err](uint64_t i, size_t /*thread*/) { - if (!IO::Read(fd, requests[i].offset, requests[i].size, - requests[i].data)) { + [pfile, &requests, &err](uint64_t i, size_t /*thread*/) { + if (!pfile->Read(requests[i].offset, requests[i].size, + requests[i].data)) { err.test_and_set(); } }); @@ -406,8 +262,7 @@ BlobError BlobReader::ReadAll(hwy::ThreadPool& pool) { return 0; } -BlobError BlobWriter::WriteAll(hwy::ThreadPool& pool, - const char* filename) const { +BlobError BlobWriter::WriteAll(hwy::ThreadPool& pool, const char* filename) { HWY_ASSERT(keys_.size() == blobs_.size()); // Concatenate blobs in memory. @@ -418,26 +273,18 @@ BlobError BlobWriter::WriteAll(hwy::ThreadPool& pool, keys_.data(), blobs_.data(), keys_.size(), bs.get()); // Create/replace existing file. -#if HWY_OS_WIN - DWORD flags = FILE_ATTRIBUTE_NORMAL; - HANDLE file = CreateFileA(filename, GENERIC_WRITE, 0, nullptr, CREATE_ALWAYS, - flags, nullptr); - if (file == INVALID_HANDLE_VALUE) return __LINE__; - const int fd = _open_osfhandle(reinterpret_cast(file), _O_WRONLY); -#else - const int fd = open(filename, O_CREAT | O_RDWR | O_TRUNC, 0644); -#endif - if (fd < 0) return __LINE__; + File file; + if (!file.Open(filename, "w+")) return __LINE__; + File* pfile = &file; // not owned std::atomic_flag err = ATOMIC_FLAG_INIT; pool.Run(0, requests.size(), - [fd, &requests, &err](uint64_t i, size_t /*thread*/) { - if (!IO::Write(requests[i].data, requests[i].size, - requests[i].offset, fd)) { + [pfile, &requests, &err](uint64_t i, size_t /*thread*/) { + if (!pfile->Write(requests[i].data, requests[i].size, + requests[i].offset)) { err.test_and_set(); } }); - HWY_ASSERT(close(fd) != -1); if (err.test_and_set()) return __LINE__; return 0; } diff --git a/compression/blob_store.h b/compression/blob_store.h index 8736d0f..8e712c2 100644 --- a/compression/blob_store.h +++ b/compression/blob_store.h @@ -21,6 +21,7 @@ #include +#include "compression/io.h" #include "hwy/aligned_allocator.h" #include "hwy/base.h" // hwy::uint128_t #include "hwy/contrib/thread_pool/thread_pool.h" @@ -59,7 +60,7 @@ struct BlobIO { class BlobReader { public: BlobReader() { requests_.reserve(500); } - ~BlobReader(); + ~BlobReader() = default; // Opens `filename` and reads its header. BlobError Open(const char* filename); @@ -73,7 +74,7 @@ class BlobReader { private: BlobStorePtr blob_store_; // holds header, not the entire file std::vector requests_; - int fd_ = 0; + File file_; }; class BlobWriter { @@ -84,7 +85,7 @@ class BlobWriter { } // Stores all blobs to disk in the given order with padding for alignment. - BlobError WriteAll(hwy::ThreadPool& pool, const char* filename) const; + BlobError WriteAll(hwy::ThreadPool& pool, const char* filename); private: std::vector keys_; diff --git a/compression/io.cc b/compression/io.cc new file mode 100644 index 0000000..4b32f6b --- /dev/null +++ b/compression/io.cc @@ -0,0 +1,188 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Safe to be first, does not include POSIX headers. +#include "compression/io.h" + +// 1.5x slowdown vs. POSIX (200 ms longer startup), hence opt-in. +#ifdef GEMMA_IO_GOOGLE +#include "compression/io_google.cc" +#else + +// Request POSIX 2008, including `pread()` and `posix_fadvise()`. +#if !defined(_XOPEN_SOURCE) || _XOPEN_SOURCE < 700 +#undef _XOPEN_SOURCE +#define _XOPEN_SOURCE 700 +#endif +#if !defined(_POSIX_C_SOURCE) || _POSIX_C_SOURCE < 200809 +#define _POSIX_C_SOURCE 200809 +#endif + +// Make `off_t` 64-bit even on 32-bit systems. Works for Android >= r15c. +#undef _FILE_OFFSET_BITS +#define _FILE_OFFSET_BITS 64 + +#include // open +#include +#include +#include // SEEK_END - unistd isn't enough for IDE. +#include // O_RDONLY + +#include "hwy/base.h" // HWY_ASSERT +#include "hwy/detect_compiler_arch.h" +#if HWY_OS_WIN +#include +#include // read, write, close +#else +#include // read, write, close +#endif + +namespace gcpp { + +// Emulate missing POSIX functions. +#if HWY_OS_WIN +namespace { + +static inline int open(const char* filename, int flags, int mode = 0) { + const bool is_read = (flags & _O_RDONLY) != 0; + const DWORD win_flags = + FILE_ATTRIBUTE_NORMAL | (is_read ? FILE_FLAG_SEQUENTIAL_SCAN : 0); + const DWORD access = is_read ? GENERIC_READ : GENERIC_WRITE; + const DWORD share = is_read ? FILE_SHARE_READ : 0; + const DWORD create = is_read ? OPEN_EXISTING : CREATE_ALWAYS; + const HANDLE file = + CreateFileA(filename, access, share, nullptr, create, win_flags, nullptr); + if (file == INVALID_HANDLE_VALUE) return -1; + return _open_osfhandle(reinterpret_cast(file), flags); +} + +static inline off_t lseek(int fd, off_t offset, int whence) { + return _lseeki64(fd, offset, whence); +} + +static inline int64_t pread(int fd, void* buf, uint64_t size, uint64_t offset) { + HANDLE file = reinterpret_cast(_get_osfhandle(fd)); + if (file == INVALID_HANDLE_VALUE) { + return -1; + } + + OVERLAPPED overlapped = {0}; + overlapped.Offset = offset & 0xFFFFFFFF; + overlapped.OffsetHigh = (offset >> 32) & 0xFFFFFFFF; + + DWORD bytes_read; + if (!ReadFile(file, buf, size, &bytes_read, &overlapped)) { + if (GetLastError() != ERROR_HANDLE_EOF) { + return -1; + } + } + + return bytes_read; +} + +static inline int64_t pwrite(int fd, const void* buf, uint64_t size, + uint64_t offset) { + HANDLE file = reinterpret_cast(_get_osfhandle(fd)); + if (file == INVALID_HANDLE_VALUE) { + return -1; + } + + OVERLAPPED overlapped = {0}; + overlapped.Offset = offset & 0xFFFFFFFF; + overlapped.OffsetHigh = (offset >> 32) & 0xFFFFFFFF; + + DWORD bytes_written; + if (!WriteFile(file, buf, size, &bytes_written, &overlapped)) { + if (GetLastError() != ERROR_HANDLE_EOF) { + return -1; + } + } + + return bytes_written; +} + +} // namespace +#endif // HWY_OS_WIN + +bool File::Open(const char* filename, const char* mode) { + const bool is_read = mode[0] != 'w'; + const int flags = is_read ? O_RDONLY : O_CREAT | O_RDWR | O_TRUNC; + int fd = open(filename, flags, 0644); + if (fd < 0) { + p_ = 0; + return false; + } + + if (is_read) { +#if HWY_OS_LINUX && (!defined(__ANDROID_API__) || __ANDROID_API__ >= 21) + // Doubles the readahead window, which seems slightly faster when cached. + (void)posix_fadvise(fd, 0, 0, POSIX_FADV_SEQUENTIAL); +#endif + } + + p_ = static_cast(fd); + return true; +} + +void File::Close() { + const int fd = static_cast(p_); + if (fd > 0) { + HWY_ASSERT(close(fd) != -1); + p_ = 0; + } +} + +uint64_t File::FileSize() const { + static_assert(sizeof(off_t) == 8, "64-bit off_t required"); + const int fd = static_cast(p_); + const off_t size = lseek(fd, 0, SEEK_END); + if (size < 0) { + return 0; + } + return static_cast(size); +} + +bool File::Read(uint64_t offset, uint64_t size, void* to) const { + const int fd = static_cast(p_); + uint8_t* bytes = reinterpret_cast(to); + uint64_t pos = 0; + for (;;) { + // pread seems to be faster than lseek + read when parallelized. + const auto bytes_read = pread(fd, bytes + pos, size - pos, offset + pos); + if (bytes_read <= 0) break; + pos += bytes_read; + HWY_ASSERT(pos <= size); + if (pos == size) break; + } + return pos == size; // success if managed to read desired size +} + +bool File::Write(const void* from, uint64_t size, uint64_t offset) { + const int fd = static_cast(p_); + const uint8_t* bytes = reinterpret_cast(from); + uint64_t pos = 0; + for (;;) { + const auto bytes_written = + pwrite(fd, bytes + pos, size - pos, offset + pos); + if (bytes_written <= 0) break; + pos += bytes_written; + HWY_ASSERT(pos <= size); + if (pos == size) break; + } + return pos == size; // success if managed to write desired size +} + +} // namespace gcpp +#endif // GEMMA_IO_GOOGLE diff --git a/compression/io.h b/compression/io.h new file mode 100644 index 0000000..29b429d --- /dev/null +++ b/compression/io.h @@ -0,0 +1,52 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_IO_H_ +#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_IO_H_ + +#include + +namespace gcpp { + +// unique_ptr-like interface with RAII, but not (yet) moveable. +class File { + public: + File() = default; + ~File() { Close(); } + File(const File& other) = delete; + const File& operator=(const File& other) = delete; + + // Returns false on failure. `mode` is either "r" or "w+". + bool Open(const char* filename, const char* mode); + + // No effect if `Open` returned false or `Close` already called. + void Close(); + + // Returns size in bytes or 0. + uint64_t FileSize() const; + + // Returns true if all the requested bytes were read. + bool Read(uint64_t offset, uint64_t size, void* to) const; + + // Returns true if all the requested bytes were written. + bool Write(const void* from, uint64_t size, uint64_t offset); + + private: + intptr_t p_ = 0; +}; + +} // namespace gcpp + +#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_IO_H_ diff --git a/gemma/benchmarks.cc b/gemma/benchmarks.cc deleted file mode 100644 index fb2c6e3..0000000 --- a/gemma/benchmarks.cc +++ /dev/null @@ -1,137 +0,0 @@ -// Copyright 2024 Google LLC -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include -#include -#include -#include -#include - -#include "third_party/benchmark/include/benchmark/benchmark.h" -#include "gemma/gemma.h" -#include "util/app.h" -#include "hwy/base.h" -#include "hwy/contrib/thread_pool/thread_pool.h" - -gcpp::LoaderArgs* loader = nullptr; -gcpp::InferenceArgs* inference = nullptr; -gcpp::Gemma* model = nullptr; -hwy::ThreadPool* pool = nullptr; -hwy::ThreadPool* inner_pool = nullptr; - -void run_gemma_prompt(const std::string& prompt_string, - benchmark::State& state) { - std::mt19937 gen; - std::vector prompt; - - if (prompt_string.empty()) return; - HWY_ASSERT(model->Tokenizer().Encode(prompt_string, &prompt).ok()); - - int token_counter = 0; - auto stream_token = [&token_counter](int, float) { - token_counter++; - return true; - }; - - for (auto s : state) { - GenerateGemma( - *model, *inference, prompt, /*start_token=*/0, *pool, *inner_pool, - stream_token, - /*accept=*/[](int) { return true; }, gen, /*verbosity=*/0); - } - - state.SetItemsProcessed(token_counter); -} - -static void BM_short_prompt(benchmark::State& state) { - run_gemma_prompt("What is the capital of Spain? ", state); -} - -static void BM_factuality_prompt(benchmark::State& state) { - run_gemma_prompt("How does an inkjet printer work? ", state); -} - -static void BM_creative_prompt(benchmark::State& state) { - run_gemma_prompt( - "Tell me a story about a magical bunny and their TRS-80. ", - state); -} - -static void BM_coding_prompt(benchmark::State& state) { - run_gemma_prompt( - "Write a python program to generate a fibonacci sequence. ", - state); -} - -static void BM_long_coding_prompt(benchmark::State& state) { - std::ifstream t("benchmarks.cc", std::ios_base::in); - std::stringstream buffer; - buffer << t.rdbuf(); - std::string prompt_string = buffer.str(); - t.close(); - - run_gemma_prompt("Make improvements to the following code:\n " + - prompt_string + " ", - state); -} - -int main(int argc, char** argv) { - loader = new gcpp::LoaderArgs(argc, argv); - inference = new gcpp::InferenceArgs(argc, argv); - gcpp::AppArgs app(argc, argv); - - pool = new ::hwy::ThreadPool(app.num_threads); - inner_pool = new ::hwy::ThreadPool(0); - model = new gcpp::Gemma(*loader, *pool); - - inference->max_tokens = 128; - BENCHMARK(BM_short_prompt) - ->Iterations(3) - ->Unit(benchmark::kMillisecond) - ->UseRealTime(); - - inference->max_tokens = 256; - BENCHMARK(BM_factuality_prompt) - ->Iterations(3) - ->Unit(benchmark::kMillisecond) - ->UseRealTime(); - - BENCHMARK(BM_creative_prompt) - ->Iterations(3) - ->Unit(benchmark::kMillisecond) - ->UseRealTime(); - - BENCHMARK(BM_coding_prompt) - ->Iterations(3) - ->Unit(benchmark::kMillisecond) - ->UseRealTime(); - - inference->max_tokens = 1024; - BENCHMARK(BM_long_coding_prompt) - ->Iterations(3) - ->Unit(benchmark::kMillisecond) - ->UseRealTime(); - - ::benchmark ::RunSpecifiedBenchmarks(); - ::benchmark ::Shutdown(); - - delete loader; - delete inference; - delete model; - delete pool; - - return 0; -} diff --git a/gemma/compress_weights.cc b/gemma/compress_weights.cc index 776da50..a4d414c 100644 --- a/gemma/compress_weights.cc +++ b/gemma/compress_weights.cc @@ -55,7 +55,7 @@ struct Args : public ArgsBase { return "Missing --compressed_weights flag, a file for the compressed " "model."; } - if (!weights.exists()) { + if (!weights.Exists()) { return "Can't open file specified with --weights flag."; } return nullptr; diff --git a/gemma/gemma.cc b/gemma/gemma.cc index be3fca4..7aaf912 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -40,7 +40,6 @@ #include #include #include -#include // NOLINT #include #include #include @@ -182,7 +181,7 @@ hwy::AlignedFreeUniquePtr LoadWeights( const Path& checkpoint, hwy::ThreadPool& pool, bool scale_for_compression = false) { PROFILER_ZONE("Startup.LoadWeights"); - if (!std::filesystem::exists(checkpoint.path)) { + if (!checkpoint.Exists()) { HWY_ABORT("The model weights file '%s' does not exist.", checkpoint.path.c_str()); } @@ -1318,8 +1317,8 @@ void ForEachTensor(const Weights* weights, template hwy::AlignedFreeUniquePtr LoadCompressedWeights( const Path& weights, hwy::ThreadPool& pool) { - PROFILER_ZONE("Startup.LoadCache"); - if (!std::filesystem::exists(weights.path)) { + PROFILER_ZONE("Startup.LoadCompressedWeights"); + if (!weights.Exists()) { HWY_ABORT("The model weights file '%s' does not exist.", weights.path.c_str()); } @@ -1395,7 +1394,7 @@ template void CompressWeights(const Path& weights_path, const Path& compressed_weights_path, hwy::ThreadPool& pool) { - if (!std::filesystem::exists(weights_path.path)) { + if (!weights_path.Exists()) { HWY_ABORT("The model weights file '%s' does not exist.", weights_path.path.c_str()); } diff --git a/gemma/run_csv.cc b/gemma/run_csv.cc deleted file mode 100644 index cea3826..0000000 --- a/gemma/run_csv.cc +++ /dev/null @@ -1,223 +0,0 @@ -// Copyright 2024 Google LLC -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Command line text interface to gemma. - -#include - -#include -#include -#include -#include - -#include "gemma/configs.h" -#include "gemma/gemma.h" -#include "util/app.h" -#include "util/args.h" // ArgsBase -#include "hwy/base.h" -#include "hwy/contrib/thread_pool/thread_pool.h" -#include "hwy/highway.h" -#include "hwy/profiler.h" -#include "third_party/riegeli/bytes/file_reader.h" -#include "third_party/riegeli/bytes/file_writer.h" -#include "third_party/riegeli/csv/csv_reader.h" -#include "third_party/riegeli/csv/csv_writer.h" - -namespace gcpp { - -struct CsvArgs : public ArgsBase { - CsvArgs(int argc, char* argv[]) { InitAndParse(argc, argv); } - - Path input_csv; - Path output_csv; - int prompt_column; - - template - void ForEach(const Visitor& visitor) { - visitor(input_csv, "input_csv", Path(), - "When set, prompts will be read from this CSV."); - visitor(output_csv, "output_csv", Path("/tmp/output.csv"), - "When --input_csv is set, prompts will be written to this CSV."); - visitor(prompt_column, "prompt_column", 0, "Prompt column index"); - }; -}; - -void FileGemma(gcpp::Gemma& model, InferenceArgs& inference, AppArgs& app, - CsvArgs& csv, hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, - const gcpp::AcceptFunc& accept_token) { - int abs_pos = 0; // absolute token index over all turns - int current_pos = 0; // token index within the current turn - int prompt_size{}; - - std::mt19937 gen; - if (inference.deterministic) { - gen.seed(42); - } else { - std::random_device rd; - gen.seed(rd()); - } - - std::stringstream response_stream; - - // callback function invoked for each generated token. - auto stream_token = [&inference, &abs_pos, ¤t_pos, &gen, &prompt_size, - tokenizer = &model.Tokenizer(), - &response_stream](int token, float) { - ++abs_pos; - ++current_pos; - if (current_pos < prompt_size) { - // pass - } else if (token == gcpp::EOS_ID) { - if (!inference.multiturn) { - abs_pos = 0; - if (inference.deterministic) { - gen.seed(42); - } - } - // end of stream - } else { - std::string token_text; - HWY_ASSERT(tokenizer->Decode({token}, &token_text).ok()); - // +1 since position is incremented above - if (current_pos == prompt_size + 1) { - // first token of response - token_text.erase(0, token_text.find_first_not_of(" \t\n")); - } - if (token_text != "\n") - response_stream << token_text; - else - response_stream << "\\n"; - } - return true; - }; - - riegeli::CsvReader csv_reader( - riegeli::FileReader(csv.input_csv.path), - riegeli::CsvReaderBase::Options().set_comment('#').set_recovery( - [](absl::Status status, riegeli::CsvReaderBase& csv_reader) { - fprintf(stderr, "Invalid entry: %s", status.message().data()); - return true; - })); - - riegeli::CsvWriter csv_writer( - riegeli::FileWriter(csv.output_csv.path), - riegeli::CsvWriterBase::Options().set_header({"prompt", "response"})); - - if (!csv_reader.ok()) { - HWY_ABORT("Invalid input CSV path %s", csv.input_csv.path.c_str()); - } - - if (!csv_writer.ok()) { - HWY_ABORT("Invalid output CSV path %s", csv.output_csv.path.c_str()); - } - - while (abs_pos < inference.max_tokens) { - std::string prompt_string; - std::vector prompt; - current_pos = 0; - - std::vector record; - csv_reader.ReadRecord(record); - - if (record.empty()) { - break; - } - - prompt_string = record[csv.prompt_column]; - fprintf(stdout, "Prompt: %s\n", prompt_string.c_str()); - - prompt_string = - "user\n" + prompt_string + "\nmodel\n"; - if (abs_pos > 0) { - // multi-turn dialogue continuation. - prompt_string = "\n" + prompt_string; - } else { - HWY_DASSERT(abs_pos == 0); - if (gcpp::kSystemPrompt) { - prompt_string = - "system\nYou are a large language model built by " - "Google.\n" + - prompt_string; - } - } - HWY_ASSERT(model.Tokenizer().Encode(prompt_string, &prompt).ok()); - prompt_size = prompt.size(); - - // generate prompt - GenerateGemma(model, inference, prompt, abs_pos, pool, inner_pool, - stream_token, accept_token, gen, app.verbosity); - - std::string response_string = response_stream.str(); - if (!csv_writer.WriteRecord({record[csv.prompt_column], response_string})) { - fprintf(stderr, "Failed to write CSV: %s\n", - csv_writer.status().message().data()); - } - - response_stream.str(std::string()); // reset stream - response_stream.clear(); - abs_pos = 0; - } - - if (!csv_reader.Close()) { - fprintf(stderr, "Failed to close the CSV reader\n"); - } - if (!csv_writer.Close()) { - fprintf(stderr, "Failed to close the CSV writer\n"); - } -} - -void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app, - CsvArgs& csv) { - PROFILER_ZONE("Run.misc"); - - hwy::ThreadPool inner_pool(0); - hwy::ThreadPool pool(app.num_threads); - // For many-core, pinning threads to cores helps. - if (app.num_threads > 10) { - pool.Run(0, pool.NumThreads(), - [](uint64_t /*task*/, size_t thread) { PinThreadToCore(thread); }); - } - - gcpp::Gemma model(loader.tokenizer, loader.compressed_weights, - loader.ModelType(), loader.ModelTraining(), pool); - - if (csv.input_csv.path.empty()) { - HWY_ABORT("Need to specify csv file."); - } - - FileGemma(model, inference, app, csv, pool, inner_pool, - [](int) { return true; }); -} - -} // namespace gcpp - -int main(int argc, char** argv) { - { - PROFILER_ZONE("Startup.misc"); - gcpp::LoaderArgs loader(argc, argv); - gcpp::InferenceArgs inference(argc, argv); - gcpp::AppArgs app(argc, argv); - gcpp::CsvArgs csv(argc, argv); - - if (const char* error = loader.Validate()) { - loader.Help(); - HWY_ABORT("Invalid args: %s", error); - } - - gcpp::Run(loader, inference, app, csv); - } - PROFILER_PRINT_RESULTS(); // Must call outside the zone above. - return 0; -} diff --git a/util/app.h b/util/app.h index 296ec9a..8161419 100644 --- a/util/app.h +++ b/util/app.h @@ -133,7 +133,7 @@ struct LoaderArgs : public ArgsBase { if (tokenizer.path.empty()) { return "Missing --tokenizer flag, a file for the tokenizer is required."; } - if (!tokenizer.exists()) { + if (!tokenizer.Exists()) { return "Can't open file specified with --tokenizer flag."; } if (!compressed_weights.path.empty()) { @@ -148,7 +148,7 @@ struct LoaderArgs : public ArgsBase { if (weights.path.empty()) { return "Missing --weights flag, a file for the model weights."; } - if (!weights.exists()) { + if (!weights.Exists()) { return "Can't open file specified with --weights flag."; } return nullptr; diff --git a/util/args.h b/util/args.h index 7b17c99..f5ed602 100644 --- a/util/args.h +++ b/util/args.h @@ -23,16 +23,9 @@ #include // std::transform #include +#include "compression/io.h" #include "hwy/base.h" // HWY_ABORT -#if defined(_WIN32) -#include -#define F_OK 0 -#define access _access -#else -#include -#endif - namespace gcpp { // Wrapper for strings representing a path name. Differentiates vs. arbitrary @@ -57,9 +50,10 @@ struct Path { return path; } - // Beware, TOCTOU. - bool exists() const { - return (access(path.c_str(), F_OK) == 0); + // Returns whether the file existed when this was called. + bool Exists() const { + File file; + return file.Open(path.c_str(), "r"); } std::string path;