Further improve IO, enable multiple backends without -D.

Move Path into io.h and use for opening files.
Removes dependency of gemma_lib on args.
Separate Windows codepath instead of emulating POSIX functions.

Plus lint fixes.

PiperOrigin-RevId: 626279004
This commit is contained in:
Jan Wassenberg 2024-04-19 00:39:22 -07:00 committed by Copybara-Service
parent 38f1ea9b80
commit e9a0caed87
14 changed files with 290 additions and 238 deletions

View File

@ -51,15 +51,6 @@ cc_test(
], ],
) )
cc_library(
name = "args",
hdrs = ["util/args.h"],
deps = [
"//compression:io",
"@hwy//:hwy",
],
)
cc_library( cc_library(
name = "gemma_lib", name = "gemma_lib",
srcs = [ srcs = [
@ -70,10 +61,10 @@ cc_library(
"gemma/gemma.h", "gemma/gemma.h",
], ],
deps = [ deps = [
":args",
":ops", ":ops",
# "//base", # "//base",
"//compression:compress", "//compression:compress",
"//compression:io",
"@hwy//:hwy", "@hwy//:hwy",
"@hwy//:matvec", "@hwy//:matvec",
"@hwy//:nanobenchmark", # timer "@hwy//:nanobenchmark", # timer
@ -83,6 +74,25 @@ cc_library(
], ],
) )
cc_library(
name = "args",
hdrs = ["util/args.h"],
deps = [
"//compression:io",
"@hwy//:hwy",
],
)
cc_library(
name = "app",
hdrs = ["util/app.h"],
deps = [
":args",
":gemma_lib",
"@hwy//:hwy",
],
)
cc_test( cc_test(
name = "gemma_test", name = "gemma_test",
srcs = ["gemma/gemma_test.cc"], srcs = ["gemma/gemma_test.cc"],
@ -102,16 +112,6 @@ cc_test(
], ],
) )
cc_library(
name = "app",
hdrs = ["util/app.h"],
deps = [
":args",
":gemma_lib",
"@hwy//:hwy",
],
)
cc_binary( cc_binary(
name = "gemma", name = "gemma",
srcs = ["gemma/run.cc"], srcs = ["gemma/run.cc"],

View File

@ -38,6 +38,7 @@ set(SOURCES
compression/blob_store.h compression/blob_store.h
compression/compress.h compression/compress.h
compression/compress-inl.h compression/compress-inl.h
compression/io_win.cc
compression/io.cc compression/io.cc
compression/io.h compression/io.h
compression/nuq.h compression/nuq.h

View File

@ -12,9 +12,11 @@ package(
cc_library( cc_library(
name = "io", name = "io",
srcs = ["io.cc"], srcs = [
"io.cc",
# Placeholder for io backend, do not remove
],
hdrs = ["io.h"], hdrs = ["io.h"],
# Placeholder for io textual_hdrs, do not remove
deps = [ deps = [
# Placeholder for io deps, do not remove # Placeholder for io deps, do not remove
"@hwy//:hwy", "@hwy//:hwy",
@ -80,12 +82,8 @@ cc_library(
cc_library( cc_library(
name = "sfp", name = "sfp",
hdrs = [ hdrs = ["sfp.h"],
"sfp.h", textual_hdrs = ["sfp-inl.h"],
],
textual_hdrs = [
"sfp-inl.h",
],
deps = [ deps = [
"@hwy//:hwy", "@hwy//:hwy",
], ],
@ -112,12 +110,8 @@ cc_test(
cc_library( cc_library(
name = "nuq", name = "nuq",
hdrs = [ hdrs = ["nuq.h"],
"nuq.h", textual_hdrs = ["nuq-inl.h"],
],
textual_hdrs = [
"nuq-inl.h",
],
deps = [ deps = [
":sfp", ":sfp",
"@hwy//:hwy", "@hwy//:hwy",
@ -158,6 +152,7 @@ cc_library(
deps = [ deps = [
":blob_store", ":blob_store",
":distortion", ":distortion",
":io",
":nuq", ":nuq",
":sfp", ":sfp",
":stats", ":stats",
@ -170,9 +165,7 @@ cc_library(
# For internal experimentation # For internal experimentation
cc_library( cc_library(
name = "analyze", name = "analyze",
textual_hdrs = [ textual_hdrs = ["analyze.h"],
"analyze.h",
],
deps = [ deps = [
":distortion", ":distortion",
":nuq", ":nuq",

View File

@ -19,6 +19,7 @@
#include <stdint.h> #include <stdint.h>
#include <atomic> #include <atomic>
#include <memory>
#include <vector> #include <vector>
#include "compression/io.h" #include "compression/io.h"
@ -199,12 +200,13 @@ class BlobStore {
}; };
#pragma pack(pop) #pragma pack(pop)
BlobError BlobReader::Open(const char* filename) { BlobError BlobReader::Open(const Path& filename) {
if (!file_.Open(filename, "r")) return __LINE__; file_ = OpenFileOrNull(filename, "r");
if (!file_) return __LINE__;
// Read first part of header to get actual size. // Read first part of header to get actual size.
BlobStore bs; BlobStore bs;
if (!file_.Read(0, sizeof(bs), &bs)) return __LINE__; if (!file_->Read(0, sizeof(bs), &bs)) return __LINE__;
const size_t padded_size = bs.PaddedHeaderSize(); const size_t padded_size = bs.PaddedHeaderSize();
HWY_ASSERT(padded_size >= sizeof(bs)); HWY_ASSERT(padded_size >= sizeof(bs));
@ -216,11 +218,11 @@ BlobError BlobReader::Open(const char* filename) {
hwy::CopySameSize(&bs, blob_store_.get()); hwy::CopySameSize(&bs, blob_store_.get());
// Read the rest of the header, but not the full file. // Read the rest of the header, but not the full file.
uint8_t* bytes = reinterpret_cast<uint8_t*>(blob_store_.get()); uint8_t* bytes = reinterpret_cast<uint8_t*>(blob_store_.get());
if (!file_.Read(sizeof(bs), padded_size - sizeof(bs), bytes + sizeof(bs))) { if (!file_->Read(sizeof(bs), padded_size - sizeof(bs), bytes + sizeof(bs))) {
return __LINE__; return __LINE__;
} }
return blob_store_->CheckValidity(file_.FileSize()); return blob_store_->CheckValidity(file_->FileSize());
} }
BlobError BlobReader::Enqueue(hwy::uint128_t key, void* data, size_t size) { BlobError BlobReader::Enqueue(hwy::uint128_t key, void* data, size_t size) {
@ -247,7 +249,7 @@ BlobError BlobReader::Enqueue(hwy::uint128_t key, void* data, size_t size) {
// between consecutive runs. // between consecutive runs.
// - memory-mapped I/O is less predictable and adds noise to measurements. // - memory-mapped I/O is less predictable and adds noise to measurements.
BlobError BlobReader::ReadAll(hwy::ThreadPool& pool) { BlobError BlobReader::ReadAll(hwy::ThreadPool& pool) {
File* pfile = &file_; // not owned File* pfile = file_.get(); // not owned
const auto& requests = requests_; const auto& requests = requests_;
std::atomic_flag err = ATOMIC_FLAG_INIT; std::atomic_flag err = ATOMIC_FLAG_INIT;
// >5x speedup from parallel reads when cached. // >5x speedup from parallel reads when cached.
@ -262,7 +264,7 @@ BlobError BlobReader::ReadAll(hwy::ThreadPool& pool) {
return 0; return 0;
} }
BlobError BlobWriter::WriteAll(hwy::ThreadPool& pool, const char* filename) { BlobError BlobWriter::WriteAll(hwy::ThreadPool& pool, const Path& filename) {
HWY_ASSERT(keys_.size() == blobs_.size()); HWY_ASSERT(keys_.size() == blobs_.size());
// Concatenate blobs in memory. // Concatenate blobs in memory.
@ -273,9 +275,9 @@ BlobError BlobWriter::WriteAll(hwy::ThreadPool& pool, const char* filename) {
keys_.data(), blobs_.data(), keys_.size(), bs.get()); keys_.data(), blobs_.data(), keys_.size(), bs.get());
// Create/replace existing file. // Create/replace existing file.
File file; std::unique_ptr<File> file = OpenFileOrNull(filename, "w+");
if (!file.Open(filename, "w+")) return __LINE__; if (!file) return __LINE__;
File* pfile = &file; // not owned File* pfile = file.get(); // not owned
std::atomic_flag err = ATOMIC_FLAG_INIT; std::atomic_flag err = ATOMIC_FLAG_INIT;
pool.Run(0, requests.size(), pool.Run(0, requests.size(),

View File

@ -19,6 +19,7 @@
#include <stddef.h> #include <stddef.h>
#include <stdint.h> #include <stdint.h>
#include <memory>
#include <vector> #include <vector>
#include "compression/io.h" #include "compression/io.h"
@ -63,7 +64,7 @@ class BlobReader {
~BlobReader() = default; ~BlobReader() = default;
// Opens `filename` and reads its header. // Opens `filename` and reads its header.
BlobError Open(const char* filename); BlobError Open(const Path& filename);
// Enqueues read requests if `key` is found and its size matches `size`. // Enqueues read requests if `key` is found and its size matches `size`.
BlobError Enqueue(hwy::uint128_t key, void* data, size_t size); BlobError Enqueue(hwy::uint128_t key, void* data, size_t size);
@ -74,7 +75,7 @@ class BlobReader {
private: private:
BlobStorePtr blob_store_; // holds header, not the entire file BlobStorePtr blob_store_; // holds header, not the entire file
std::vector<BlobIO> requests_; std::vector<BlobIO> requests_;
File file_; std::unique_ptr<File> file_;
}; };
class BlobWriter { class BlobWriter {
@ -85,7 +86,7 @@ class BlobWriter {
} }
// Stores all blobs to disk in the given order with padding for alignment. // Stores all blobs to disk in the given order with padding for alignment.
BlobError WriteAll(hwy::ThreadPool& pool, const char* filename); BlobError WriteAll(hwy::ThreadPool& pool, const Path& filename);
private: private:
std::vector<hwy::uint128_t> keys_; std::vector<hwy::uint128_t> keys_;

View File

@ -459,11 +459,11 @@ class Compressor {
} }
} }
void WriteAll(hwy::ThreadPool& pool, const char* blob_filename) { void WriteAll(hwy::ThreadPool& pool, const Path& blob_filename) {
const BlobError err = writer_.WriteAll(pool, blob_filename); const BlobError err = writer_.WriteAll(pool, blob_filename);
if (err != 0) { if (err != 0) {
fprintf(stderr, "Failed to write blobs to %s (error %d)\n", blob_filename, fprintf(stderr, "Failed to write blobs to %s (error %d)\n",
err); blob_filename.path.c_str(), err);
} }
} }

View File

@ -28,6 +28,7 @@
// IWYU pragma: begin_exports // IWYU pragma: begin_exports
#include "compression/blob_store.h" #include "compression/blob_store.h"
#include "compression/io.h"
#include "compression/nuq.h" #include "compression/nuq.h"
#include "compression/sfp.h" #include "compression/sfp.h"
// IWYU pragma: end_exports // IWYU pragma: end_exports
@ -166,13 +167,13 @@ hwy::uint128_t CacheKey(const char* name) {
class CacheLoader { class CacheLoader {
public: public:
explicit CacheLoader(const char* blob_filename) { explicit CacheLoader(const Path& blob_filename) {
err_ = reader_.Open(blob_filename); err_ = reader_.Open(blob_filename);
if (err_ != 0) { if (err_ != 0) {
fprintf(stderr, fprintf(stderr,
"Cached compressed weights does not exist yet (code %d), " "Cached compressed weights does not exist yet (code %d), "
"compressing weights and creating file: %s.\n", "compressing weights and creating file: %s.\n",
err_, blob_filename); err_, blob_filename.path.c_str());
} }
} }

View File

@ -14,12 +14,10 @@
// limitations under the License. // limitations under the License.
// Safe to be first, does not include POSIX headers. // Safe to be first, does not include POSIX headers.
#include "compression/io.h" #include "hwy/detect_compiler_arch.h"
// Only compile this file on non-Windows; it replaces io_win.cc. It is easier to
// 1.5x slowdown vs. POSIX (200 ms longer startup), hence opt-in. // check this in source code because we support multiple build systems.
#ifdef GEMMA_IO_GOOGLE #if !HWY_OS_WIN
#include "compression/io_google.cc"
#else
// Request POSIX 2008, including `pread()` and `posix_fadvise()`. // Request POSIX 2008, including `pread()` and `posix_fadvise()`.
#if !defined(_XOPEN_SOURCE) || _XOPEN_SOURCE < 700 #if !defined(_XOPEN_SOURCE) || _XOPEN_SOURCE < 700
@ -39,150 +37,85 @@
#include <stdint.h> #include <stdint.h>
#include <stdio.h> // SEEK_END - unistd isn't enough for IDE. #include <stdio.h> // SEEK_END - unistd isn't enough for IDE.
#include <sys/stat.h> // O_RDONLY #include <sys/stat.h> // O_RDONLY
#include <unistd.h> // read, write, close
#include <memory>
#include "compression/io.h"
#include "hwy/base.h" // HWY_ASSERT #include "hwy/base.h" // HWY_ASSERT
#include "hwy/detect_compiler_arch.h"
#if HWY_OS_WIN
#include <fileapi.h>
#include <io.h> // read, write, close
#else
#include <unistd.h> // read, write, close
#endif
namespace gcpp { namespace gcpp {
// Emulate missing POSIX functions. class FilePosix : public File {
#if HWY_OS_WIN int fd_ = 0;
namespace {
static inline int open(const char* filename, int flags, int mode = 0) { public:
const bool is_read = (flags & _O_RDONLY) != 0; explicit FilePosix(int fd) : fd_(fd) { HWY_ASSERT(fd > 0); }
const DWORD win_flags = ~FilePosix() override {
FILE_ATTRIBUTE_NORMAL | (is_read ? FILE_FLAG_SEQUENTIAL_SCAN : 0); if (fd_ != 0) {
const DWORD access = is_read ? GENERIC_READ : GENERIC_WRITE; HWY_ASSERT(close(fd_) != -1);
const DWORD share = is_read ? FILE_SHARE_READ : 0;
const DWORD create = is_read ? OPEN_EXISTING : CREATE_ALWAYS;
const HANDLE file =
CreateFileA(filename, access, share, nullptr, create, win_flags, nullptr);
if (file == INVALID_HANDLE_VALUE) return -1;
return _open_osfhandle(reinterpret_cast<intptr_t>(file), flags);
}
static inline off_t lseek(int fd, off_t offset, int whence) {
return _lseeki64(fd, offset, whence);
}
static inline int64_t pread(int fd, void* buf, uint64_t size, uint64_t offset) {
HANDLE file = reinterpret_cast<HANDLE>(_get_osfhandle(fd));
if (file == INVALID_HANDLE_VALUE) {
return -1;
}
OVERLAPPED overlapped = {0};
overlapped.Offset = offset & 0xFFFFFFFF;
overlapped.OffsetHigh = (offset >> 32) & 0xFFFFFFFF;
DWORD bytes_read;
if (!ReadFile(file, buf, size, &bytes_read, &overlapped)) {
if (GetLastError() != ERROR_HANDLE_EOF) {
return -1;
} }
} }
return bytes_read; uint64_t FileSize() const override {
} static_assert(sizeof(off_t) == 8, "64-bit off_t required");
const off_t size = lseek(fd_, 0, SEEK_END);
static inline int64_t pwrite(int fd, const void* buf, uint64_t size, if (size < 0) {
uint64_t offset) { return 0;
HANDLE file = reinterpret_cast<HANDLE>(_get_osfhandle(fd));
if (file == INVALID_HANDLE_VALUE) {
return -1;
}
OVERLAPPED overlapped = {0};
overlapped.Offset = offset & 0xFFFFFFFF;
overlapped.OffsetHigh = (offset >> 32) & 0xFFFFFFFF;
DWORD bytes_written;
if (!WriteFile(file, buf, size, &bytes_written, &overlapped)) {
if (GetLastError() != ERROR_HANDLE_EOF) {
return -1;
} }
return static_cast<uint64_t>(size);
} }
return bytes_written; bool Read(uint64_t offset, uint64_t size, void* to) const override {
} uint8_t* bytes = reinterpret_cast<uint8_t*>(to);
uint64_t pos = 0;
for (;;) {
// pread seems to be faster than lseek + read when parallelized.
const auto bytes_read = pread(fd_, bytes + pos, size - pos, offset + pos);
if (bytes_read <= 0) break;
pos += bytes_read;
HWY_ASSERT(pos <= size);
if (pos == size) break;
}
return pos == size; // success if managed to read desired size
}
} // namespace bool Write(const void* from, uint64_t size, uint64_t offset) override {
#endif // HWY_OS_WIN const uint8_t* bytes = reinterpret_cast<const uint8_t*>(from);
uint64_t pos = 0;
for (;;) {
const auto bytes_written =
pwrite(fd_, bytes + pos, size - pos, offset + pos);
if (bytes_written <= 0) break;
pos += bytes_written;
HWY_ASSERT(pos <= size);
if (pos == size) break;
}
return pos == size; // success if managed to write desired size
}
}; // FilePosix
HWY_MAYBE_UNUSED extern std::unique_ptr<File> OpenFileGoogle(
const Path& filename, const char* mode);
std::unique_ptr<File> OpenFileOrNull(const Path& filename, const char* mode) {
std::unique_ptr<File> file; // OpenFileGoogle omitted
if (file) return file;
bool File::Open(const char* filename, const char* mode) {
const bool is_read = mode[0] != 'w'; const bool is_read = mode[0] != 'w';
const int flags = is_read ? O_RDONLY : O_CREAT | O_RDWR | O_TRUNC; const int flags = is_read ? O_RDONLY : O_CREAT | O_RDWR | O_TRUNC;
int fd = open(filename, flags, 0644); const int fd = open(filename.path.c_str(), flags, 0644);
if (fd < 0) { if (fd < 0) return file;
p_ = 0;
return false;
}
if (is_read) {
#if HWY_OS_LINUX && (!defined(__ANDROID_API__) || __ANDROID_API__ >= 21) #if HWY_OS_LINUX && (!defined(__ANDROID_API__) || __ANDROID_API__ >= 21)
if (is_read) {
// Doubles the readahead window, which seems slightly faster when cached. // Doubles the readahead window, which seems slightly faster when cached.
(void)posix_fadvise(fd, 0, 0, POSIX_FADV_SEQUENTIAL); (void)posix_fadvise(fd, 0, 0, POSIX_FADV_SEQUENTIAL);
}
#endif #endif
}
p_ = static_cast<intptr_t>(fd); return std::make_unique<FilePosix>(fd);
return true;
}
void File::Close() {
const int fd = static_cast<int>(p_);
if (fd > 0) {
HWY_ASSERT(close(fd) != -1);
p_ = 0;
}
}
uint64_t File::FileSize() const {
static_assert(sizeof(off_t) == 8, "64-bit off_t required");
const int fd = static_cast<int>(p_);
const off_t size = lseek(fd, 0, SEEK_END);
if (size < 0) {
return 0;
}
return static_cast<uint64_t>(size);
}
bool File::Read(uint64_t offset, uint64_t size, void* to) const {
const int fd = static_cast<int>(p_);
uint8_t* bytes = reinterpret_cast<uint8_t*>(to);
uint64_t pos = 0;
for (;;) {
// pread seems to be faster than lseek + read when parallelized.
const auto bytes_read = pread(fd, bytes + pos, size - pos, offset + pos);
if (bytes_read <= 0) break;
pos += bytes_read;
HWY_ASSERT(pos <= size);
if (pos == size) break;
}
return pos == size; // success if managed to read desired size
}
bool File::Write(const void* from, uint64_t size, uint64_t offset) {
const int fd = static_cast<int>(p_);
const uint8_t* bytes = reinterpret_cast<const uint8_t*>(from);
uint64_t pos = 0;
for (;;) {
const auto bytes_written =
pwrite(fd, bytes + pos, size - pos, offset + pos);
if (bytes_written <= 0) break;
pos += bytes_written;
HWY_ASSERT(pos <= size);
if (pos == size) break;
}
return pos == size; // success if managed to write desired size
} }
} // namespace gcpp } // namespace gcpp
#endif // GEMMA_IO_GOOGLE #endif // !HWY_OS_WIN

View File

@ -16,35 +16,71 @@
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_IO_H_ #ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_IO_H_
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_IO_H_ #define THIRD_PARTY_GEMMA_CPP_COMPRESSION_IO_H_
#include <stddef.h>
#include <stdint.h> #include <stdint.h>
#include <memory>
#include <string>
#include <utility> // std::move
namespace gcpp { namespace gcpp {
// unique_ptr-like interface with RAII, but not (yet) moveable. // Forward-declare to break the circular dependency: OpenFileOrNull returns
// File and has a Path argument, and Path::Exists calls OpenFileOrNull. We
// prefer to define Exists inline because there are multiple io*.cc files.
struct Path;
// Abstract base class enables multiple I/O backends in the same binary.
class File { class File {
public: public:
File() = default; File() = default;
~File() { Close(); } virtual ~File() = default;
// Noncopyable.
File(const File& other) = delete; File(const File& other) = delete;
const File& operator=(const File& other) = delete; const File& operator=(const File& other) = delete;
// Returns false on failure. `mode` is either "r" or "w+".
bool Open(const char* filename, const char* mode);
// No effect if `Open` returned false or `Close` already called.
void Close();
// Returns size in bytes or 0. // Returns size in bytes or 0.
uint64_t FileSize() const; virtual uint64_t FileSize() const = 0;
// Returns true if all the requested bytes were read. // Returns true if all the requested bytes were read.
bool Read(uint64_t offset, uint64_t size, void* to) const; virtual bool Read(uint64_t offset, uint64_t size, void* to) const = 0;
// Returns true if all the requested bytes were written. // Returns true if all the requested bytes were written.
bool Write(const void* from, uint64_t size, uint64_t offset); virtual bool Write(const void* from, uint64_t size, uint64_t offset) = 0;
};
private: // Returns nullptr on failure. `mode` is either "r" or "w+". This is not just
intptr_t p_ = 0; // named 'OpenFile' to avoid a conflict with Windows.h #define.
std::unique_ptr<File> OpenFileOrNull(const Path& filename, const char* mode);
// Wrapper for strings representing a path name. Differentiates vs. arbitrary
// strings and supports shortening for display purposes.
struct Path {
Path() {}
explicit Path(const char* p) : path(p) {}
explicit Path(std::string p) : path(std::move(p)) {}
Path& operator=(const char* other) {
path = other;
return *this;
}
std::string Shortened() const {
constexpr size_t kMaxLen = 48;
constexpr size_t kCutPoint = kMaxLen / 2 - 5;
if (path.size() > kMaxLen) {
return std::string(begin(path), begin(path) + kCutPoint) + " ... " +
std::string(end(path) - kCutPoint, end(path));
}
if (path.empty()) return "[no path specified]";
return path;
}
// Returns whether the file existed when this was called.
bool Exists() const { return !!OpenFileOrNull(*this, "r"); }
std::string path;
}; };
} // namespace gcpp } // namespace gcpp

115
compression/io_win.cc Normal file
View File

@ -0,0 +1,115 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "hwy/detect_compiler_arch.h"
// Only compile this file on Windows; it replaces io.cc. It is easier to check
// this in source code because we support multiple build systems.
#if HWY_OS_WIN
#include <stddef.h>
#include <stdint.h>
#include "compression/io.h"
#include "hwy/base.h" // HWY_ASSERT
#ifndef WIN32_LEAN_AND_MEAN
#define WIN32_LEAN_AND_MEAN
#endif
#ifndef VC_EXTRALEAN
#define VC_EXTRALEAN
#endif
#include <Windows.h>
namespace gcpp {
class FileWin : public File {
HANDLE hFile_ = INVALID_HANDLE_VALUE;
public:
FileWin(HANDLE hFile) : hFile_(hFile) {
HWY_ASSERT(hFile != INVALID_HANDLE_VALUE);
}
~FileWin() override {
if (hFile_ != INVALID_HANDLE_VALUE) {
HWY_ASSERT(CloseHandle(hFile_) != 0);
}
}
uint64_t FileSize() const override {
DWORD hi;
const DWORD lo = GetFileSize(hFile_, &hi);
if (lo == INVALID_FILE_SIZE) return 0;
return (static_cast<uint64_t>(hi) << 32) | lo;
}
bool Read(uint64_t offset, uint64_t size, void* to) const override {
uint8_t* bytes = reinterpret_cast<uint8_t*>(to);
OVERLAPPED overlapped = {0};
// Loop is required because ReadFile[Ex] size argument is 32-bit.
while (size != 0) {
overlapped.Offset = offset & 0xFFFFFFFF;
overlapped.OffsetHigh = (offset >> 32) & 0xFFFFFFFF;
const DWORD want =
static_cast<DWORD>(HWY_MIN(size, uint64_t{0xFFFFFFFF}));
DWORD got;
if (!ReadFile(hFile_, bytes, want, &got, &overlapped)) {
if (GetLastError() != ERROR_HANDLE_EOF) {
return false;
}
}
offset += got;
bytes += got;
size -= got;
}
return true; // read everything => success
}
bool Write(const void* from, uint64_t size, uint64_t offset) override {
const uint8_t* bytes = reinterpret_cast<const uint8_t*>(from);
OVERLAPPED overlapped = {0};
// Loop is required because WriteFile[Ex] size argument is 32-bit.
while (size != 0) {
overlapped.Offset = offset & 0xFFFFFFFF;
overlapped.OffsetHigh = (offset >> 32) & 0xFFFFFFFF;
const DWORD want =
static_cast<DWORD>(HWY_MIN(size, uint64_t{0xFFFFFFFF}));
DWORD got;
if (!WriteFile(hFile_, bytes, want, &got, &overlapped)) {
if (GetLastError() != ERROR_HANDLE_EOF) {
return false;
}
}
offset += got;
bytes += got;
size -= got;
}
return true; // wrote everything => success
}
}; // FileWin
std::unique_ptr<File> OpenFileOrNull(const Path& filename, const char* mode) {
const bool is_read = mode[0] != 'w';
const DWORD flags =
FILE_ATTRIBUTE_NORMAL | (is_read ? FILE_FLAG_SEQUENTIAL_SCAN : 0);
const DWORD access = is_read ? GENERIC_READ : GENERIC_WRITE;
const DWORD share = is_read ? FILE_SHARE_READ : 0;
const DWORD create = is_read ? OPEN_EXISTING : CREATE_ALWAYS;
const HANDLE hFile = CreateFileA(filename.path.c_str(), access, share,
nullptr, create, flags, nullptr);
if (hFile == INVALID_HANDLE_VALUE) return std::unique_ptr<File>();
return std::make_unique<FileWin>(hFile);
}
} // namespace gcpp
#endif // HWY_OS_WIN

View File

@ -19,7 +19,6 @@
#include <stdint.h> #include <stdint.h>
#include <stdio.h> #include <stdio.h>
#include <algorithm>
#include <cmath> #include <cmath>
#include <string> #include <string>
@ -77,8 +76,8 @@ class Stats {
void Notify(const float x) { void Notify(const float x) {
++n_; ++n_;
min_ = std::min(min_, x); min_ = HWY_MIN(min_, x);
max_ = std::max(max_, x); max_ = HWY_MAX(max_, x);
product_ *= x; product_ *= x;
@ -119,7 +118,7 @@ class Stats {
// Near zero for normal distributions; if positive on a unimodal distribution, // Near zero for normal distributions; if positive on a unimodal distribution,
// the right tail is fatter. Assumes n_ is large. // the right tail is fatter. Assumes n_ is large.
double SampleSkewness() const { double SampleSkewness() const {
if (std::abs(m2_) < 1E-7) return 0.0; if (hwy::ScalarAbs(m2_) < 1E-7) return 0.0;
return m3_ * std::sqrt(static_cast<double>(n_)) / std::pow(m2_, 1.5); return m3_ * std::sqrt(static_cast<double>(n_)) / std::pow(m2_, 1.5);
} }
// Corrected for bias (same as Wikipedia and Minitab but not Excel). // Corrected for bias (same as Wikipedia and Minitab but not Excel).
@ -132,7 +131,7 @@ class Stats {
// Near zero for normal distributions; smaller values indicate fewer/smaller // Near zero for normal distributions; smaller values indicate fewer/smaller
// outliers and larger indicates more/larger outliers. Assumes n_ is large. // outliers and larger indicates more/larger outliers. Assumes n_ is large.
double SampleKurtosis() const { double SampleKurtosis() const {
if (std::abs(m2_) < 1E-7) return 0.0; if (hwy::ScalarAbs(m2_) < 1E-7) return 0.0;
return m4_ * n_ / (m2_ * m2_); return m4_ * n_ / (m2_ * m2_);
} }
// Corrected for bias (same as Wikipedia and Minitab but not Excel). // Corrected for bias (same as Wikipedia and Minitab but not Excel).

View File

@ -39,18 +39,20 @@
#include <algorithm> #include <algorithm>
#include <array> #include <array>
#include <cctype>
#include <cmath> #include <cmath>
#include <iostream> #include <iostream>
#include <memory> #include <memory>
#include <random> #include <random>
#include <regex> // NOLINT #include <regex> // NOLINT
#include <string> #include <string>
#include <utility>
#include <vector> #include <vector>
#include "compression/compress.h" #include "compression/compress.h"
#include "compression/io.h" // Path
#include "gemma/configs.h" #include "gemma/configs.h"
#include "gemma/gemma.h" #include "gemma/gemma.h"
#include "util/args.h" // Path
#include "hwy/aligned_allocator.h" #include "hwy/aligned_allocator.h"
#include "hwy/base.h" #include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/contrib/thread_pool/thread_pool.h"
@ -914,7 +916,7 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
} }
} }
// TODO: sink the loop into these functions, i.e. make them matmuls. // TODO: sink the loop into these functions, i.e. make them MatMul.
pool.Run( pool.Run(
0, num_tokens, 0, num_tokens,
[&](const uint64_t token_idx, size_t thread_id) HWY_ATTR { [&](const uint64_t token_idx, size_t thread_id) HWY_ATTR {
@ -1331,7 +1333,7 @@ hwy::AlignedFreeUniquePtr<uint8_t[]> LoadCompressedWeights(
new (&c_weights->c_layer_ptrs) CompressedLayerPointers<TConfig>(pool); new (&c_weights->c_layer_ptrs) CompressedLayerPointers<TConfig>(pool);
std::array<float, TConfig::kNumTensorScales> scales; std::array<float, TConfig::kNumTensorScales> scales;
CacheLoader loader(weights.path.c_str()); CacheLoader loader(weights);
ForEachTensor<TConfig>(nullptr, *c_weights, loader); ForEachTensor<TConfig>(nullptr, *c_weights, loader);
loader.LoadScales(scales.data(), scales.size()); loader.LoadScales(scales.data(), scales.size());
if (!loader.ReadAll(pool)) { if (!loader.ReadAll(pool)) {
@ -1415,7 +1417,7 @@ void CompressWeights(const Path& weights_path,
Compressor compressor(pool); Compressor compressor(pool);
ForEachTensor<TConfig>(weights, *c_weights, compressor); ForEachTensor<TConfig>(weights, *c_weights, compressor);
compressor.AddScales(weights->scales.data(), weights->scales.size()); compressor.AddScales(weights->scales.data(), weights->scales.size());
compressor.WriteAll(pool, compressed_weights_path.path.c_str()); compressor.WriteAll(pool, compressed_weights_path);
weights->layer_ptrs.~LayerPointers<TConfig>(); weights->layer_ptrs.~LayerPointers<TConfig>();
c_weights->c_layer_ptrs.~CompressedLayerPointers<TConfig>(); c_weights->c_layer_ptrs.~CompressedLayerPointers<TConfig>();

View File

@ -22,8 +22,8 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "compression/io.h" // Path
#include "gemma/configs.h" #include "gemma/configs.h"
#include "util/args.h" // Path
#include "hwy/aligned_allocator.h" #include "hwy/aligned_allocator.h"
#include "hwy/base.h" // hwy::bfloat16_t #include "hwy/base.h" // hwy::bfloat16_t
#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/contrib/thread_pool/thread_pool.h"
@ -35,7 +35,7 @@ using EmbedderInputT = hwy::bfloat16_t;
// Will be called for layers output with: // Will be called for layers output with:
// - position in the tokens sequence // - position in the tokens sequence
// - name of the data, p.ex. "tokens", "block.1", "final_norm" // - name of the data, p.ex. "tokens", "block.1", "final_norm"
// - ponter to the data array // - pointer to the data array
// - size of the data array // - size of the data array
using LayersOutputT = using LayersOutputT =
std::function<void(int, const std::string&, const float*, size_t)>; std::function<void(int, const std::string&, const float*, size_t)>;

View File

@ -28,37 +28,6 @@
namespace gcpp { namespace gcpp {
// Wrapper for strings representing a path name. Differentiates vs. arbitrary
// strings and supports shortening for display purposes.
struct Path {
Path() {}
explicit Path(const char* p) : path(p) {}
Path& operator=(const char* other) {
path = other;
return *this;
}
std::string Shortened() const {
constexpr size_t kMaxLen = 48;
constexpr size_t kCutPoint = kMaxLen / 2 - 5;
if (path.size() > kMaxLen) {
return std::string(begin(path), begin(path) + kCutPoint) + " ... " +
std::string(end(path) - kCutPoint, end(path));
}
if (path.empty()) return "[no path specified]";
return path;
}
// Returns whether the file existed when this was called.
bool Exists() const {
File file;
return file.Open(path.c_str(), "r");
}
std::string path;
};
// Args is a class that provides a ForEach member function which visits each of // Args is a class that provides a ForEach member function which visits each of
// its member variables. ArgsBase provides functions called by Args to // its member variables. ArgsBase provides functions called by Args to
// initialize values to their defaults (passed as an argument to the visitor), // initialize values to their defaults (passed as an argument to the visitor),