From e9a0caed87ac9fe7427f66b60ddcb483e2ff2daa Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Fri, 19 Apr 2024 00:39:22 -0700 Subject: [PATCH] Further improve IO, enable multiple backends without -D. Move Path into io.h and use for opening files. Removes dependency of gemma_lib on args. Separate Windows codepath instead of emulating POSIX functions. Plus lint fixes. PiperOrigin-RevId: 626279004 --- BUILD.bazel | 40 ++++---- CMakeLists.txt | 1 + compression/BUILD | 27 ++---- compression/blob_store.cc | 22 +++-- compression/blob_store.h | 7 +- compression/compress-inl.h | 6 +- compression/compress.h | 5 +- compression/io.cc | 189 ++++++++++++------------------------- compression/io.h | 62 +++++++++--- compression/io_win.cc | 115 ++++++++++++++++++++++ compression/stats.h | 9 +- gemma/gemma.cc | 10 +- gemma/gemma.h | 4 +- util/args.h | 31 ------ 14 files changed, 290 insertions(+), 238 deletions(-) create mode 100644 compression/io_win.cc diff --git a/BUILD.bazel b/BUILD.bazel index 8f17a6a..47e8175 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -51,15 +51,6 @@ cc_test( ], ) -cc_library( - name = "args", - hdrs = ["util/args.h"], - deps = [ - "//compression:io", - "@hwy//:hwy", - ], -) - cc_library( name = "gemma_lib", srcs = [ @@ -70,10 +61,10 @@ cc_library( "gemma/gemma.h", ], deps = [ - ":args", ":ops", # "//base", "//compression:compress", + "//compression:io", "@hwy//:hwy", "@hwy//:matvec", "@hwy//:nanobenchmark", # timer @@ -83,6 +74,25 @@ cc_library( ], ) +cc_library( + name = "args", + hdrs = ["util/args.h"], + deps = [ + "//compression:io", + "@hwy//:hwy", + ], +) + +cc_library( + name = "app", + hdrs = ["util/app.h"], + deps = [ + ":args", + ":gemma_lib", + "@hwy//:hwy", + ], +) + cc_test( name = "gemma_test", srcs = ["gemma/gemma_test.cc"], @@ -102,16 +112,6 @@ cc_test( ], ) -cc_library( - name = "app", - hdrs = ["util/app.h"], - deps = [ - ":args", - ":gemma_lib", - "@hwy//:hwy", - ], -) - cc_binary( name = "gemma", srcs = ["gemma/run.cc"], diff --git a/CMakeLists.txt b/CMakeLists.txt index fed1619..a439e3b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -38,6 +38,7 @@ set(SOURCES compression/blob_store.h compression/compress.h compression/compress-inl.h + compression/io_win.cc compression/io.cc compression/io.h compression/nuq.h diff --git a/compression/BUILD b/compression/BUILD index ecf067a..8c090fb 100644 --- a/compression/BUILD +++ b/compression/BUILD @@ -12,9 +12,11 @@ package( cc_library( name = "io", - srcs = ["io.cc"], + srcs = [ + "io.cc", + # Placeholder for io backend, do not remove + ], hdrs = ["io.h"], - # Placeholder for io textual_hdrs, do not remove deps = [ # Placeholder for io deps, do not remove "@hwy//:hwy", @@ -80,12 +82,8 @@ cc_library( cc_library( name = "sfp", - hdrs = [ - "sfp.h", - ], - textual_hdrs = [ - "sfp-inl.h", - ], + hdrs = ["sfp.h"], + textual_hdrs = ["sfp-inl.h"], deps = [ "@hwy//:hwy", ], @@ -112,12 +110,8 @@ cc_test( cc_library( name = "nuq", - hdrs = [ - "nuq.h", - ], - textual_hdrs = [ - "nuq-inl.h", - ], + hdrs = ["nuq.h"], + textual_hdrs = ["nuq-inl.h"], deps = [ ":sfp", "@hwy//:hwy", @@ -158,6 +152,7 @@ cc_library( deps = [ ":blob_store", ":distortion", + ":io", ":nuq", ":sfp", ":stats", @@ -170,9 +165,7 @@ cc_library( # For internal experimentation cc_library( name = "analyze", - textual_hdrs = [ - "analyze.h", - ], + textual_hdrs = ["analyze.h"], deps = [ ":distortion", ":nuq", diff --git a/compression/blob_store.cc b/compression/blob_store.cc index 6dea7df..49565d4 100644 --- a/compression/blob_store.cc +++ b/compression/blob_store.cc @@ -19,6 +19,7 @@ #include #include +#include #include #include "compression/io.h" @@ -199,12 +200,13 @@ class BlobStore { }; #pragma pack(pop) -BlobError BlobReader::Open(const char* filename) { - if (!file_.Open(filename, "r")) return __LINE__; +BlobError BlobReader::Open(const Path& filename) { + file_ = OpenFileOrNull(filename, "r"); + if (!file_) return __LINE__; // Read first part of header to get actual size. BlobStore bs; - if (!file_.Read(0, sizeof(bs), &bs)) return __LINE__; + if (!file_->Read(0, sizeof(bs), &bs)) return __LINE__; const size_t padded_size = bs.PaddedHeaderSize(); HWY_ASSERT(padded_size >= sizeof(bs)); @@ -216,11 +218,11 @@ BlobError BlobReader::Open(const char* filename) { hwy::CopySameSize(&bs, blob_store_.get()); // Read the rest of the header, but not the full file. uint8_t* bytes = reinterpret_cast(blob_store_.get()); - if (!file_.Read(sizeof(bs), padded_size - sizeof(bs), bytes + sizeof(bs))) { + if (!file_->Read(sizeof(bs), padded_size - sizeof(bs), bytes + sizeof(bs))) { return __LINE__; } - return blob_store_->CheckValidity(file_.FileSize()); + return blob_store_->CheckValidity(file_->FileSize()); } BlobError BlobReader::Enqueue(hwy::uint128_t key, void* data, size_t size) { @@ -247,7 +249,7 @@ BlobError BlobReader::Enqueue(hwy::uint128_t key, void* data, size_t size) { // between consecutive runs. // - memory-mapped I/O is less predictable and adds noise to measurements. BlobError BlobReader::ReadAll(hwy::ThreadPool& pool) { - File* pfile = &file_; // not owned + File* pfile = file_.get(); // not owned const auto& requests = requests_; std::atomic_flag err = ATOMIC_FLAG_INIT; // >5x speedup from parallel reads when cached. @@ -262,7 +264,7 @@ BlobError BlobReader::ReadAll(hwy::ThreadPool& pool) { return 0; } -BlobError BlobWriter::WriteAll(hwy::ThreadPool& pool, const char* filename) { +BlobError BlobWriter::WriteAll(hwy::ThreadPool& pool, const Path& filename) { HWY_ASSERT(keys_.size() == blobs_.size()); // Concatenate blobs in memory. @@ -273,9 +275,9 @@ BlobError BlobWriter::WriteAll(hwy::ThreadPool& pool, const char* filename) { keys_.data(), blobs_.data(), keys_.size(), bs.get()); // Create/replace existing file. - File file; - if (!file.Open(filename, "w+")) return __LINE__; - File* pfile = &file; // not owned + std::unique_ptr file = OpenFileOrNull(filename, "w+"); + if (!file) return __LINE__; + File* pfile = file.get(); // not owned std::atomic_flag err = ATOMIC_FLAG_INIT; pool.Run(0, requests.size(), diff --git a/compression/blob_store.h b/compression/blob_store.h index 8e712c2..d95d7e1 100644 --- a/compression/blob_store.h +++ b/compression/blob_store.h @@ -19,6 +19,7 @@ #include #include +#include #include #include "compression/io.h" @@ -63,7 +64,7 @@ class BlobReader { ~BlobReader() = default; // Opens `filename` and reads its header. - BlobError Open(const char* filename); + BlobError Open(const Path& filename); // Enqueues read requests if `key` is found and its size matches `size`. BlobError Enqueue(hwy::uint128_t key, void* data, size_t size); @@ -74,7 +75,7 @@ class BlobReader { private: BlobStorePtr blob_store_; // holds header, not the entire file std::vector requests_; - File file_; + std::unique_ptr file_; }; class BlobWriter { @@ -85,7 +86,7 @@ class BlobWriter { } // Stores all blobs to disk in the given order with padding for alignment. - BlobError WriteAll(hwy::ThreadPool& pool, const char* filename); + BlobError WriteAll(hwy::ThreadPool& pool, const Path& filename); private: std::vector keys_; diff --git a/compression/compress-inl.h b/compression/compress-inl.h index 516e9e3..a6a4b7e 100644 --- a/compression/compress-inl.h +++ b/compression/compress-inl.h @@ -459,11 +459,11 @@ class Compressor { } } - void WriteAll(hwy::ThreadPool& pool, const char* blob_filename) { + void WriteAll(hwy::ThreadPool& pool, const Path& blob_filename) { const BlobError err = writer_.WriteAll(pool, blob_filename); if (err != 0) { - fprintf(stderr, "Failed to write blobs to %s (error %d)\n", blob_filename, - err); + fprintf(stderr, "Failed to write blobs to %s (error %d)\n", + blob_filename.path.c_str(), err); } } diff --git a/compression/compress.h b/compression/compress.h index 5c9a3b2..549ea6f 100644 --- a/compression/compress.h +++ b/compression/compress.h @@ -28,6 +28,7 @@ // IWYU pragma: begin_exports #include "compression/blob_store.h" +#include "compression/io.h" #include "compression/nuq.h" #include "compression/sfp.h" // IWYU pragma: end_exports @@ -166,13 +167,13 @@ hwy::uint128_t CacheKey(const char* name) { class CacheLoader { public: - explicit CacheLoader(const char* blob_filename) { + explicit CacheLoader(const Path& blob_filename) { err_ = reader_.Open(blob_filename); if (err_ != 0) { fprintf(stderr, "Cached compressed weights does not exist yet (code %d), " "compressing weights and creating file: %s.\n", - err_, blob_filename); + err_, blob_filename.path.c_str()); } } diff --git a/compression/io.cc b/compression/io.cc index 4b32f6b..84e3603 100644 --- a/compression/io.cc +++ b/compression/io.cc @@ -14,12 +14,10 @@ // limitations under the License. // Safe to be first, does not include POSIX headers. -#include "compression/io.h" - -// 1.5x slowdown vs. POSIX (200 ms longer startup), hence opt-in. -#ifdef GEMMA_IO_GOOGLE -#include "compression/io_google.cc" -#else +#include "hwy/detect_compiler_arch.h" +// Only compile this file on non-Windows; it replaces io_win.cc. It is easier to +// check this in source code because we support multiple build systems. +#if !HWY_OS_WIN // Request POSIX 2008, including `pread()` and `posix_fadvise()`. #if !defined(_XOPEN_SOURCE) || _XOPEN_SOURCE < 700 @@ -39,150 +37,85 @@ #include #include // SEEK_END - unistd isn't enough for IDE. #include // O_RDONLY +#include // read, write, close +#include + +#include "compression/io.h" #include "hwy/base.h" // HWY_ASSERT -#include "hwy/detect_compiler_arch.h" -#if HWY_OS_WIN -#include -#include // read, write, close -#else -#include // read, write, close -#endif namespace gcpp { -// Emulate missing POSIX functions. -#if HWY_OS_WIN -namespace { +class FilePosix : public File { + int fd_ = 0; -static inline int open(const char* filename, int flags, int mode = 0) { - const bool is_read = (flags & _O_RDONLY) != 0; - const DWORD win_flags = - FILE_ATTRIBUTE_NORMAL | (is_read ? FILE_FLAG_SEQUENTIAL_SCAN : 0); - const DWORD access = is_read ? GENERIC_READ : GENERIC_WRITE; - const DWORD share = is_read ? FILE_SHARE_READ : 0; - const DWORD create = is_read ? OPEN_EXISTING : CREATE_ALWAYS; - const HANDLE file = - CreateFileA(filename, access, share, nullptr, create, win_flags, nullptr); - if (file == INVALID_HANDLE_VALUE) return -1; - return _open_osfhandle(reinterpret_cast(file), flags); -} - -static inline off_t lseek(int fd, off_t offset, int whence) { - return _lseeki64(fd, offset, whence); -} - -static inline int64_t pread(int fd, void* buf, uint64_t size, uint64_t offset) { - HANDLE file = reinterpret_cast(_get_osfhandle(fd)); - if (file == INVALID_HANDLE_VALUE) { - return -1; - } - - OVERLAPPED overlapped = {0}; - overlapped.Offset = offset & 0xFFFFFFFF; - overlapped.OffsetHigh = (offset >> 32) & 0xFFFFFFFF; - - DWORD bytes_read; - if (!ReadFile(file, buf, size, &bytes_read, &overlapped)) { - if (GetLastError() != ERROR_HANDLE_EOF) { - return -1; + public: + explicit FilePosix(int fd) : fd_(fd) { HWY_ASSERT(fd > 0); } + ~FilePosix() override { + if (fd_ != 0) { + HWY_ASSERT(close(fd_) != -1); } } - return bytes_read; -} - -static inline int64_t pwrite(int fd, const void* buf, uint64_t size, - uint64_t offset) { - HANDLE file = reinterpret_cast(_get_osfhandle(fd)); - if (file == INVALID_HANDLE_VALUE) { - return -1; - } - - OVERLAPPED overlapped = {0}; - overlapped.Offset = offset & 0xFFFFFFFF; - overlapped.OffsetHigh = (offset >> 32) & 0xFFFFFFFF; - - DWORD bytes_written; - if (!WriteFile(file, buf, size, &bytes_written, &overlapped)) { - if (GetLastError() != ERROR_HANDLE_EOF) { - return -1; + uint64_t FileSize() const override { + static_assert(sizeof(off_t) == 8, "64-bit off_t required"); + const off_t size = lseek(fd_, 0, SEEK_END); + if (size < 0) { + return 0; } + return static_cast(size); } - return bytes_written; -} + bool Read(uint64_t offset, uint64_t size, void* to) const override { + uint8_t* bytes = reinterpret_cast(to); + uint64_t pos = 0; + for (;;) { + // pread seems to be faster than lseek + read when parallelized. + const auto bytes_read = pread(fd_, bytes + pos, size - pos, offset + pos); + if (bytes_read <= 0) break; + pos += bytes_read; + HWY_ASSERT(pos <= size); + if (pos == size) break; + } + return pos == size; // success if managed to read desired size + } -} // namespace -#endif // HWY_OS_WIN + bool Write(const void* from, uint64_t size, uint64_t offset) override { + const uint8_t* bytes = reinterpret_cast(from); + uint64_t pos = 0; + for (;;) { + const auto bytes_written = + pwrite(fd_, bytes + pos, size - pos, offset + pos); + if (bytes_written <= 0) break; + pos += bytes_written; + HWY_ASSERT(pos <= size); + if (pos == size) break; + } + return pos == size; // success if managed to write desired size + } +}; // FilePosix + +HWY_MAYBE_UNUSED extern std::unique_ptr OpenFileGoogle( + const Path& filename, const char* mode); + +std::unique_ptr OpenFileOrNull(const Path& filename, const char* mode) { + std::unique_ptr file; // OpenFileGoogle omitted + if (file) return file; -bool File::Open(const char* filename, const char* mode) { const bool is_read = mode[0] != 'w'; const int flags = is_read ? O_RDONLY : O_CREAT | O_RDWR | O_TRUNC; - int fd = open(filename, flags, 0644); - if (fd < 0) { - p_ = 0; - return false; - } + const int fd = open(filename.path.c_str(), flags, 0644); + if (fd < 0) return file; - if (is_read) { #if HWY_OS_LINUX && (!defined(__ANDROID_API__) || __ANDROID_API__ >= 21) + if (is_read) { // Doubles the readahead window, which seems slightly faster when cached. (void)posix_fadvise(fd, 0, 0, POSIX_FADV_SEQUENTIAL); + } #endif - } - p_ = static_cast(fd); - return true; -} - -void File::Close() { - const int fd = static_cast(p_); - if (fd > 0) { - HWY_ASSERT(close(fd) != -1); - p_ = 0; - } -} - -uint64_t File::FileSize() const { - static_assert(sizeof(off_t) == 8, "64-bit off_t required"); - const int fd = static_cast(p_); - const off_t size = lseek(fd, 0, SEEK_END); - if (size < 0) { - return 0; - } - return static_cast(size); -} - -bool File::Read(uint64_t offset, uint64_t size, void* to) const { - const int fd = static_cast(p_); - uint8_t* bytes = reinterpret_cast(to); - uint64_t pos = 0; - for (;;) { - // pread seems to be faster than lseek + read when parallelized. - const auto bytes_read = pread(fd, bytes + pos, size - pos, offset + pos); - if (bytes_read <= 0) break; - pos += bytes_read; - HWY_ASSERT(pos <= size); - if (pos == size) break; - } - return pos == size; // success if managed to read desired size -} - -bool File::Write(const void* from, uint64_t size, uint64_t offset) { - const int fd = static_cast(p_); - const uint8_t* bytes = reinterpret_cast(from); - uint64_t pos = 0; - for (;;) { - const auto bytes_written = - pwrite(fd, bytes + pos, size - pos, offset + pos); - if (bytes_written <= 0) break; - pos += bytes_written; - HWY_ASSERT(pos <= size); - if (pos == size) break; - } - return pos == size; // success if managed to write desired size + return std::make_unique(fd); } } // namespace gcpp -#endif // GEMMA_IO_GOOGLE +#endif // !HWY_OS_WIN diff --git a/compression/io.h b/compression/io.h index 29b429d..c5287b8 100644 --- a/compression/io.h +++ b/compression/io.h @@ -16,35 +16,71 @@ #ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_IO_H_ #define THIRD_PARTY_GEMMA_CPP_COMPRESSION_IO_H_ +#include #include +#include +#include +#include // std::move + namespace gcpp { -// unique_ptr-like interface with RAII, but not (yet) moveable. +// Forward-declare to break the circular dependency: OpenFileOrNull returns +// File and has a Path argument, and Path::Exists calls OpenFileOrNull. We +// prefer to define Exists inline because there are multiple io*.cc files. +struct Path; + +// Abstract base class enables multiple I/O backends in the same binary. class File { public: File() = default; - ~File() { Close(); } + virtual ~File() = default; + + // Noncopyable. File(const File& other) = delete; const File& operator=(const File& other) = delete; - // Returns false on failure. `mode` is either "r" or "w+". - bool Open(const char* filename, const char* mode); - - // No effect if `Open` returned false or `Close` already called. - void Close(); - // Returns size in bytes or 0. - uint64_t FileSize() const; + virtual uint64_t FileSize() const = 0; // Returns true if all the requested bytes were read. - bool Read(uint64_t offset, uint64_t size, void* to) const; + virtual bool Read(uint64_t offset, uint64_t size, void* to) const = 0; // Returns true if all the requested bytes were written. - bool Write(const void* from, uint64_t size, uint64_t offset); + virtual bool Write(const void* from, uint64_t size, uint64_t offset) = 0; +}; - private: - intptr_t p_ = 0; +// Returns nullptr on failure. `mode` is either "r" or "w+". This is not just +// named 'OpenFile' to avoid a conflict with Windows.h #define. +std::unique_ptr OpenFileOrNull(const Path& filename, const char* mode); + +// Wrapper for strings representing a path name. Differentiates vs. arbitrary +// strings and supports shortening for display purposes. +struct Path { + Path() {} + explicit Path(const char* p) : path(p) {} + explicit Path(std::string p) : path(std::move(p)) {} + + Path& operator=(const char* other) { + path = other; + return *this; + } + + std::string Shortened() const { + constexpr size_t kMaxLen = 48; + constexpr size_t kCutPoint = kMaxLen / 2 - 5; + if (path.size() > kMaxLen) { + return std::string(begin(path), begin(path) + kCutPoint) + " ... " + + std::string(end(path) - kCutPoint, end(path)); + } + if (path.empty()) return "[no path specified]"; + return path; + } + + // Returns whether the file existed when this was called. + bool Exists() const { return !!OpenFileOrNull(*this, "r"); } + + std::string path; }; } // namespace gcpp diff --git a/compression/io_win.cc b/compression/io_win.cc new file mode 100644 index 0000000..1cb1673 --- /dev/null +++ b/compression/io_win.cc @@ -0,0 +1,115 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/detect_compiler_arch.h" +// Only compile this file on Windows; it replaces io.cc. It is easier to check +// this in source code because we support multiple build systems. +#if HWY_OS_WIN + +#include +#include + +#include "compression/io.h" +#include "hwy/base.h" // HWY_ASSERT +#ifndef WIN32_LEAN_AND_MEAN +#define WIN32_LEAN_AND_MEAN +#endif +#ifndef VC_EXTRALEAN +#define VC_EXTRALEAN +#endif +#include + +namespace gcpp { + +class FileWin : public File { + HANDLE hFile_ = INVALID_HANDLE_VALUE; + + public: + FileWin(HANDLE hFile) : hFile_(hFile) { + HWY_ASSERT(hFile != INVALID_HANDLE_VALUE); + } + ~FileWin() override { + if (hFile_ != INVALID_HANDLE_VALUE) { + HWY_ASSERT(CloseHandle(hFile_) != 0); + } + } + + uint64_t FileSize() const override { + DWORD hi; + const DWORD lo = GetFileSize(hFile_, &hi); + if (lo == INVALID_FILE_SIZE) return 0; + return (static_cast(hi) << 32) | lo; + } + + bool Read(uint64_t offset, uint64_t size, void* to) const override { + uint8_t* bytes = reinterpret_cast(to); + OVERLAPPED overlapped = {0}; + // Loop is required because ReadFile[Ex] size argument is 32-bit. + while (size != 0) { + overlapped.Offset = offset & 0xFFFFFFFF; + overlapped.OffsetHigh = (offset >> 32) & 0xFFFFFFFF; + const DWORD want = + static_cast(HWY_MIN(size, uint64_t{0xFFFFFFFF})); + DWORD got; + if (!ReadFile(hFile_, bytes, want, &got, &overlapped)) { + if (GetLastError() != ERROR_HANDLE_EOF) { + return false; + } + } + offset += got; + bytes += got; + size -= got; + } + return true; // read everything => success + } + + bool Write(const void* from, uint64_t size, uint64_t offset) override { + const uint8_t* bytes = reinterpret_cast(from); + OVERLAPPED overlapped = {0}; + // Loop is required because WriteFile[Ex] size argument is 32-bit. + while (size != 0) { + overlapped.Offset = offset & 0xFFFFFFFF; + overlapped.OffsetHigh = (offset >> 32) & 0xFFFFFFFF; + const DWORD want = + static_cast(HWY_MIN(size, uint64_t{0xFFFFFFFF})); + DWORD got; + if (!WriteFile(hFile_, bytes, want, &got, &overlapped)) { + if (GetLastError() != ERROR_HANDLE_EOF) { + return false; + } + } + offset += got; + bytes += got; + size -= got; + } + return true; // wrote everything => success + } +}; // FileWin + +std::unique_ptr OpenFileOrNull(const Path& filename, const char* mode) { + const bool is_read = mode[0] != 'w'; + const DWORD flags = + FILE_ATTRIBUTE_NORMAL | (is_read ? FILE_FLAG_SEQUENTIAL_SCAN : 0); + const DWORD access = is_read ? GENERIC_READ : GENERIC_WRITE; + const DWORD share = is_read ? FILE_SHARE_READ : 0; + const DWORD create = is_read ? OPEN_EXISTING : CREATE_ALWAYS; + const HANDLE hFile = CreateFileA(filename.path.c_str(), access, share, + nullptr, create, flags, nullptr); + if (hFile == INVALID_HANDLE_VALUE) return std::unique_ptr(); + return std::make_unique(hFile); +} + +} // namespace gcpp +#endif // HWY_OS_WIN diff --git a/compression/stats.h b/compression/stats.h index 12985f4..1f1beb3 100644 --- a/compression/stats.h +++ b/compression/stats.h @@ -19,7 +19,6 @@ #include #include -#include #include #include @@ -77,8 +76,8 @@ class Stats { void Notify(const float x) { ++n_; - min_ = std::min(min_, x); - max_ = std::max(max_, x); + min_ = HWY_MIN(min_, x); + max_ = HWY_MAX(max_, x); product_ *= x; @@ -119,7 +118,7 @@ class Stats { // Near zero for normal distributions; if positive on a unimodal distribution, // the right tail is fatter. Assumes n_ is large. double SampleSkewness() const { - if (std::abs(m2_) < 1E-7) return 0.0; + if (hwy::ScalarAbs(m2_) < 1E-7) return 0.0; return m3_ * std::sqrt(static_cast(n_)) / std::pow(m2_, 1.5); } // Corrected for bias (same as Wikipedia and Minitab but not Excel). @@ -132,7 +131,7 @@ class Stats { // Near zero for normal distributions; smaller values indicate fewer/smaller // outliers and larger indicates more/larger outliers. Assumes n_ is large. double SampleKurtosis() const { - if (std::abs(m2_) < 1E-7) return 0.0; + if (hwy::ScalarAbs(m2_) < 1E-7) return 0.0; return m4_ * n_ / (m2_ * m2_); } // Corrected for bias (same as Wikipedia and Minitab but not Excel). diff --git a/gemma/gemma.cc b/gemma/gemma.cc index 99c4d53..f7d3daa 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -39,18 +39,20 @@ #include #include +#include #include #include #include #include #include // NOLINT #include +#include #include #include "compression/compress.h" +#include "compression/io.h" // Path #include "gemma/configs.h" #include "gemma/gemma.h" -#include "util/args.h" // Path #include "hwy/aligned_allocator.h" #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" @@ -914,7 +916,7 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos, } } - // TODO: sink the loop into these functions, i.e. make them matmuls. + // TODO: sink the loop into these functions, i.e. make them MatMul. pool.Run( 0, num_tokens, [&](const uint64_t token_idx, size_t thread_id) HWY_ATTR { @@ -1331,7 +1333,7 @@ hwy::AlignedFreeUniquePtr LoadCompressedWeights( new (&c_weights->c_layer_ptrs) CompressedLayerPointers(pool); std::array scales; - CacheLoader loader(weights.path.c_str()); + CacheLoader loader(weights); ForEachTensor(nullptr, *c_weights, loader); loader.LoadScales(scales.data(), scales.size()); if (!loader.ReadAll(pool)) { @@ -1415,7 +1417,7 @@ void CompressWeights(const Path& weights_path, Compressor compressor(pool); ForEachTensor(weights, *c_weights, compressor); compressor.AddScales(weights->scales.data(), weights->scales.size()); - compressor.WriteAll(pool, compressed_weights_path.path.c_str()); + compressor.WriteAll(pool, compressed_weights_path); weights->layer_ptrs.~LayerPointers(); c_weights->c_layer_ptrs.~CompressedLayerPointers(); diff --git a/gemma/gemma.h b/gemma/gemma.h index da37488..60d7843 100644 --- a/gemma/gemma.h +++ b/gemma/gemma.h @@ -22,8 +22,8 @@ #include #include +#include "compression/io.h" // Path #include "gemma/configs.h" -#include "util/args.h" // Path #include "hwy/aligned_allocator.h" #include "hwy/base.h" // hwy::bfloat16_t #include "hwy/contrib/thread_pool/thread_pool.h" @@ -35,7 +35,7 @@ using EmbedderInputT = hwy::bfloat16_t; // Will be called for layers output with: // - position in the tokens sequence // - name of the data, p.ex. "tokens", "block.1", "final_norm" -// - ponter to the data array +// - pointer to the data array // - size of the data array using LayersOutputT = std::function; diff --git a/util/args.h b/util/args.h index f5ed602..b50cbaa 100644 --- a/util/args.h +++ b/util/args.h @@ -28,37 +28,6 @@ namespace gcpp { -// Wrapper for strings representing a path name. Differentiates vs. arbitrary -// strings and supports shortening for display purposes. -struct Path { - Path() {} - explicit Path(const char* p) : path(p) {} - - Path& operator=(const char* other) { - path = other; - return *this; - } - - std::string Shortened() const { - constexpr size_t kMaxLen = 48; - constexpr size_t kCutPoint = kMaxLen / 2 - 5; - if (path.size() > kMaxLen) { - return std::string(begin(path), begin(path) + kCutPoint) + " ... " + - std::string(end(path) - kCutPoint, end(path)); - } - if (path.empty()) return "[no path specified]"; - return path; - } - - // Returns whether the file existed when this was called. - bool Exists() const { - File file; - return file.Open(path.c_str(), "r"); - } - - std::string path; -}; - // Args is a class that provides a ForEach member function which visits each of // its member variables. ArgsBase provides functions called by Args to // initialize values to their defaults (passed as an argument to the visitor),