mirror of https://github.com/google/gemma.cpp.git
Further improve IO, enable multiple backends without -D.
Move Path into io.h and use for opening files. Removes dependency of gemma_lib on args. Separate Windows codepath instead of emulating POSIX functions. Plus lint fixes. PiperOrigin-RevId: 626279004
This commit is contained in:
parent
38f1ea9b80
commit
e9a0caed87
40
BUILD.bazel
40
BUILD.bazel
|
|
@ -51,15 +51,6 @@ cc_test(
|
|||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "args",
|
||||
hdrs = ["util/args.h"],
|
||||
deps = [
|
||||
"//compression:io",
|
||||
"@hwy//:hwy",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "gemma_lib",
|
||||
srcs = [
|
||||
|
|
@ -70,10 +61,10 @@ cc_library(
|
|||
"gemma/gemma.h",
|
||||
],
|
||||
deps = [
|
||||
":args",
|
||||
":ops",
|
||||
# "//base",
|
||||
"//compression:compress",
|
||||
"//compression:io",
|
||||
"@hwy//:hwy",
|
||||
"@hwy//:matvec",
|
||||
"@hwy//:nanobenchmark", # timer
|
||||
|
|
@ -83,6 +74,25 @@ cc_library(
|
|||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "args",
|
||||
hdrs = ["util/args.h"],
|
||||
deps = [
|
||||
"//compression:io",
|
||||
"@hwy//:hwy",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "app",
|
||||
hdrs = ["util/app.h"],
|
||||
deps = [
|
||||
":args",
|
||||
":gemma_lib",
|
||||
"@hwy//:hwy",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "gemma_test",
|
||||
srcs = ["gemma/gemma_test.cc"],
|
||||
|
|
@ -102,16 +112,6 @@ cc_test(
|
|||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "app",
|
||||
hdrs = ["util/app.h"],
|
||||
deps = [
|
||||
":args",
|
||||
":gemma_lib",
|
||||
"@hwy//:hwy",
|
||||
],
|
||||
)
|
||||
|
||||
cc_binary(
|
||||
name = "gemma",
|
||||
srcs = ["gemma/run.cc"],
|
||||
|
|
|
|||
|
|
@ -38,6 +38,7 @@ set(SOURCES
|
|||
compression/blob_store.h
|
||||
compression/compress.h
|
||||
compression/compress-inl.h
|
||||
compression/io_win.cc
|
||||
compression/io.cc
|
||||
compression/io.h
|
||||
compression/nuq.h
|
||||
|
|
|
|||
|
|
@ -12,9 +12,11 @@ package(
|
|||
|
||||
cc_library(
|
||||
name = "io",
|
||||
srcs = ["io.cc"],
|
||||
srcs = [
|
||||
"io.cc",
|
||||
# Placeholder for io backend, do not remove
|
||||
],
|
||||
hdrs = ["io.h"],
|
||||
# Placeholder for io textual_hdrs, do not remove
|
||||
deps = [
|
||||
# Placeholder for io deps, do not remove
|
||||
"@hwy//:hwy",
|
||||
|
|
@ -80,12 +82,8 @@ cc_library(
|
|||
|
||||
cc_library(
|
||||
name = "sfp",
|
||||
hdrs = [
|
||||
"sfp.h",
|
||||
],
|
||||
textual_hdrs = [
|
||||
"sfp-inl.h",
|
||||
],
|
||||
hdrs = ["sfp.h"],
|
||||
textual_hdrs = ["sfp-inl.h"],
|
||||
deps = [
|
||||
"@hwy//:hwy",
|
||||
],
|
||||
|
|
@ -112,12 +110,8 @@ cc_test(
|
|||
|
||||
cc_library(
|
||||
name = "nuq",
|
||||
hdrs = [
|
||||
"nuq.h",
|
||||
],
|
||||
textual_hdrs = [
|
||||
"nuq-inl.h",
|
||||
],
|
||||
hdrs = ["nuq.h"],
|
||||
textual_hdrs = ["nuq-inl.h"],
|
||||
deps = [
|
||||
":sfp",
|
||||
"@hwy//:hwy",
|
||||
|
|
@ -158,6 +152,7 @@ cc_library(
|
|||
deps = [
|
||||
":blob_store",
|
||||
":distortion",
|
||||
":io",
|
||||
":nuq",
|
||||
":sfp",
|
||||
":stats",
|
||||
|
|
@ -170,9 +165,7 @@ cc_library(
|
|||
# For internal experimentation
|
||||
cc_library(
|
||||
name = "analyze",
|
||||
textual_hdrs = [
|
||||
"analyze.h",
|
||||
],
|
||||
textual_hdrs = ["analyze.h"],
|
||||
deps = [
|
||||
":distortion",
|
||||
":nuq",
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@
|
|||
#include <stdint.h>
|
||||
|
||||
#include <atomic>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "compression/io.h"
|
||||
|
|
@ -199,12 +200,13 @@ class BlobStore {
|
|||
};
|
||||
#pragma pack(pop)
|
||||
|
||||
BlobError BlobReader::Open(const char* filename) {
|
||||
if (!file_.Open(filename, "r")) return __LINE__;
|
||||
BlobError BlobReader::Open(const Path& filename) {
|
||||
file_ = OpenFileOrNull(filename, "r");
|
||||
if (!file_) return __LINE__;
|
||||
|
||||
// Read first part of header to get actual size.
|
||||
BlobStore bs;
|
||||
if (!file_.Read(0, sizeof(bs), &bs)) return __LINE__;
|
||||
if (!file_->Read(0, sizeof(bs), &bs)) return __LINE__;
|
||||
const size_t padded_size = bs.PaddedHeaderSize();
|
||||
HWY_ASSERT(padded_size >= sizeof(bs));
|
||||
|
||||
|
|
@ -216,11 +218,11 @@ BlobError BlobReader::Open(const char* filename) {
|
|||
hwy::CopySameSize(&bs, blob_store_.get());
|
||||
// Read the rest of the header, but not the full file.
|
||||
uint8_t* bytes = reinterpret_cast<uint8_t*>(blob_store_.get());
|
||||
if (!file_.Read(sizeof(bs), padded_size - sizeof(bs), bytes + sizeof(bs))) {
|
||||
if (!file_->Read(sizeof(bs), padded_size - sizeof(bs), bytes + sizeof(bs))) {
|
||||
return __LINE__;
|
||||
}
|
||||
|
||||
return blob_store_->CheckValidity(file_.FileSize());
|
||||
return blob_store_->CheckValidity(file_->FileSize());
|
||||
}
|
||||
|
||||
BlobError BlobReader::Enqueue(hwy::uint128_t key, void* data, size_t size) {
|
||||
|
|
@ -247,7 +249,7 @@ BlobError BlobReader::Enqueue(hwy::uint128_t key, void* data, size_t size) {
|
|||
// between consecutive runs.
|
||||
// - memory-mapped I/O is less predictable and adds noise to measurements.
|
||||
BlobError BlobReader::ReadAll(hwy::ThreadPool& pool) {
|
||||
File* pfile = &file_; // not owned
|
||||
File* pfile = file_.get(); // not owned
|
||||
const auto& requests = requests_;
|
||||
std::atomic_flag err = ATOMIC_FLAG_INIT;
|
||||
// >5x speedup from parallel reads when cached.
|
||||
|
|
@ -262,7 +264,7 @@ BlobError BlobReader::ReadAll(hwy::ThreadPool& pool) {
|
|||
return 0;
|
||||
}
|
||||
|
||||
BlobError BlobWriter::WriteAll(hwy::ThreadPool& pool, const char* filename) {
|
||||
BlobError BlobWriter::WriteAll(hwy::ThreadPool& pool, const Path& filename) {
|
||||
HWY_ASSERT(keys_.size() == blobs_.size());
|
||||
|
||||
// Concatenate blobs in memory.
|
||||
|
|
@ -273,9 +275,9 @@ BlobError BlobWriter::WriteAll(hwy::ThreadPool& pool, const char* filename) {
|
|||
keys_.data(), blobs_.data(), keys_.size(), bs.get());
|
||||
|
||||
// Create/replace existing file.
|
||||
File file;
|
||||
if (!file.Open(filename, "w+")) return __LINE__;
|
||||
File* pfile = &file; // not owned
|
||||
std::unique_ptr<File> file = OpenFileOrNull(filename, "w+");
|
||||
if (!file) return __LINE__;
|
||||
File* pfile = file.get(); // not owned
|
||||
|
||||
std::atomic_flag err = ATOMIC_FLAG_INIT;
|
||||
pool.Run(0, requests.size(),
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@
|
|||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "compression/io.h"
|
||||
|
|
@ -63,7 +64,7 @@ class BlobReader {
|
|||
~BlobReader() = default;
|
||||
|
||||
// Opens `filename` and reads its header.
|
||||
BlobError Open(const char* filename);
|
||||
BlobError Open(const Path& filename);
|
||||
|
||||
// Enqueues read requests if `key` is found and its size matches `size`.
|
||||
BlobError Enqueue(hwy::uint128_t key, void* data, size_t size);
|
||||
|
|
@ -74,7 +75,7 @@ class BlobReader {
|
|||
private:
|
||||
BlobStorePtr blob_store_; // holds header, not the entire file
|
||||
std::vector<BlobIO> requests_;
|
||||
File file_;
|
||||
std::unique_ptr<File> file_;
|
||||
};
|
||||
|
||||
class BlobWriter {
|
||||
|
|
@ -85,7 +86,7 @@ class BlobWriter {
|
|||
}
|
||||
|
||||
// Stores all blobs to disk in the given order with padding for alignment.
|
||||
BlobError WriteAll(hwy::ThreadPool& pool, const char* filename);
|
||||
BlobError WriteAll(hwy::ThreadPool& pool, const Path& filename);
|
||||
|
||||
private:
|
||||
std::vector<hwy::uint128_t> keys_;
|
||||
|
|
|
|||
|
|
@ -459,11 +459,11 @@ class Compressor {
|
|||
}
|
||||
}
|
||||
|
||||
void WriteAll(hwy::ThreadPool& pool, const char* blob_filename) {
|
||||
void WriteAll(hwy::ThreadPool& pool, const Path& blob_filename) {
|
||||
const BlobError err = writer_.WriteAll(pool, blob_filename);
|
||||
if (err != 0) {
|
||||
fprintf(stderr, "Failed to write blobs to %s (error %d)\n", blob_filename,
|
||||
err);
|
||||
fprintf(stderr, "Failed to write blobs to %s (error %d)\n",
|
||||
blob_filename.path.c_str(), err);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@
|
|||
|
||||
// IWYU pragma: begin_exports
|
||||
#include "compression/blob_store.h"
|
||||
#include "compression/io.h"
|
||||
#include "compression/nuq.h"
|
||||
#include "compression/sfp.h"
|
||||
// IWYU pragma: end_exports
|
||||
|
|
@ -166,13 +167,13 @@ hwy::uint128_t CacheKey(const char* name) {
|
|||
|
||||
class CacheLoader {
|
||||
public:
|
||||
explicit CacheLoader(const char* blob_filename) {
|
||||
explicit CacheLoader(const Path& blob_filename) {
|
||||
err_ = reader_.Open(blob_filename);
|
||||
if (err_ != 0) {
|
||||
fprintf(stderr,
|
||||
"Cached compressed weights does not exist yet (code %d), "
|
||||
"compressing weights and creating file: %s.\n",
|
||||
err_, blob_filename);
|
||||
err_, blob_filename.path.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -14,12 +14,10 @@
|
|||
// limitations under the License.
|
||||
|
||||
// Safe to be first, does not include POSIX headers.
|
||||
#include "compression/io.h"
|
||||
|
||||
// 1.5x slowdown vs. POSIX (200 ms longer startup), hence opt-in.
|
||||
#ifdef GEMMA_IO_GOOGLE
|
||||
#include "compression/io_google.cc"
|
||||
#else
|
||||
#include "hwy/detect_compiler_arch.h"
|
||||
// Only compile this file on non-Windows; it replaces io_win.cc. It is easier to
|
||||
// check this in source code because we support multiple build systems.
|
||||
#if !HWY_OS_WIN
|
||||
|
||||
// Request POSIX 2008, including `pread()` and `posix_fadvise()`.
|
||||
#if !defined(_XOPEN_SOURCE) || _XOPEN_SOURCE < 700
|
||||
|
|
@ -39,128 +37,41 @@
|
|||
#include <stdint.h>
|
||||
#include <stdio.h> // SEEK_END - unistd isn't enough for IDE.
|
||||
#include <sys/stat.h> // O_RDONLY
|
||||
|
||||
#include "hwy/base.h" // HWY_ASSERT
|
||||
#include "hwy/detect_compiler_arch.h"
|
||||
#if HWY_OS_WIN
|
||||
#include <fileapi.h>
|
||||
#include <io.h> // read, write, close
|
||||
#else
|
||||
#include <unistd.h> // read, write, close
|
||||
#endif
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "compression/io.h"
|
||||
#include "hwy/base.h" // HWY_ASSERT
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
// Emulate missing POSIX functions.
|
||||
#if HWY_OS_WIN
|
||||
namespace {
|
||||
class FilePosix : public File {
|
||||
int fd_ = 0;
|
||||
|
||||
static inline int open(const char* filename, int flags, int mode = 0) {
|
||||
const bool is_read = (flags & _O_RDONLY) != 0;
|
||||
const DWORD win_flags =
|
||||
FILE_ATTRIBUTE_NORMAL | (is_read ? FILE_FLAG_SEQUENTIAL_SCAN : 0);
|
||||
const DWORD access = is_read ? GENERIC_READ : GENERIC_WRITE;
|
||||
const DWORD share = is_read ? FILE_SHARE_READ : 0;
|
||||
const DWORD create = is_read ? OPEN_EXISTING : CREATE_ALWAYS;
|
||||
const HANDLE file =
|
||||
CreateFileA(filename, access, share, nullptr, create, win_flags, nullptr);
|
||||
if (file == INVALID_HANDLE_VALUE) return -1;
|
||||
return _open_osfhandle(reinterpret_cast<intptr_t>(file), flags);
|
||||
}
|
||||
|
||||
static inline off_t lseek(int fd, off_t offset, int whence) {
|
||||
return _lseeki64(fd, offset, whence);
|
||||
}
|
||||
|
||||
static inline int64_t pread(int fd, void* buf, uint64_t size, uint64_t offset) {
|
||||
HANDLE file = reinterpret_cast<HANDLE>(_get_osfhandle(fd));
|
||||
if (file == INVALID_HANDLE_VALUE) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
OVERLAPPED overlapped = {0};
|
||||
overlapped.Offset = offset & 0xFFFFFFFF;
|
||||
overlapped.OffsetHigh = (offset >> 32) & 0xFFFFFFFF;
|
||||
|
||||
DWORD bytes_read;
|
||||
if (!ReadFile(file, buf, size, &bytes_read, &overlapped)) {
|
||||
if (GetLastError() != ERROR_HANDLE_EOF) {
|
||||
return -1;
|
||||
public:
|
||||
explicit FilePosix(int fd) : fd_(fd) { HWY_ASSERT(fd > 0); }
|
||||
~FilePosix() override {
|
||||
if (fd_ != 0) {
|
||||
HWY_ASSERT(close(fd_) != -1);
|
||||
}
|
||||
}
|
||||
|
||||
return bytes_read;
|
||||
}
|
||||
|
||||
static inline int64_t pwrite(int fd, const void* buf, uint64_t size,
|
||||
uint64_t offset) {
|
||||
HANDLE file = reinterpret_cast<HANDLE>(_get_osfhandle(fd));
|
||||
if (file == INVALID_HANDLE_VALUE) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
OVERLAPPED overlapped = {0};
|
||||
overlapped.Offset = offset & 0xFFFFFFFF;
|
||||
overlapped.OffsetHigh = (offset >> 32) & 0xFFFFFFFF;
|
||||
|
||||
DWORD bytes_written;
|
||||
if (!WriteFile(file, buf, size, &bytes_written, &overlapped)) {
|
||||
if (GetLastError() != ERROR_HANDLE_EOF) {
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
return bytes_written;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
#endif // HWY_OS_WIN
|
||||
|
||||
bool File::Open(const char* filename, const char* mode) {
|
||||
const bool is_read = mode[0] != 'w';
|
||||
const int flags = is_read ? O_RDONLY : O_CREAT | O_RDWR | O_TRUNC;
|
||||
int fd = open(filename, flags, 0644);
|
||||
if (fd < 0) {
|
||||
p_ = 0;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (is_read) {
|
||||
#if HWY_OS_LINUX && (!defined(__ANDROID_API__) || __ANDROID_API__ >= 21)
|
||||
// Doubles the readahead window, which seems slightly faster when cached.
|
||||
(void)posix_fadvise(fd, 0, 0, POSIX_FADV_SEQUENTIAL);
|
||||
#endif
|
||||
}
|
||||
|
||||
p_ = static_cast<intptr_t>(fd);
|
||||
return true;
|
||||
}
|
||||
|
||||
void File::Close() {
|
||||
const int fd = static_cast<int>(p_);
|
||||
if (fd > 0) {
|
||||
HWY_ASSERT(close(fd) != -1);
|
||||
p_ = 0;
|
||||
}
|
||||
}
|
||||
|
||||
uint64_t File::FileSize() const {
|
||||
uint64_t FileSize() const override {
|
||||
static_assert(sizeof(off_t) == 8, "64-bit off_t required");
|
||||
const int fd = static_cast<int>(p_);
|
||||
const off_t size = lseek(fd, 0, SEEK_END);
|
||||
const off_t size = lseek(fd_, 0, SEEK_END);
|
||||
if (size < 0) {
|
||||
return 0;
|
||||
}
|
||||
return static_cast<uint64_t>(size);
|
||||
}
|
||||
|
||||
bool File::Read(uint64_t offset, uint64_t size, void* to) const {
|
||||
const int fd = static_cast<int>(p_);
|
||||
bool Read(uint64_t offset, uint64_t size, void* to) const override {
|
||||
uint8_t* bytes = reinterpret_cast<uint8_t*>(to);
|
||||
uint64_t pos = 0;
|
||||
for (;;) {
|
||||
// pread seems to be faster than lseek + read when parallelized.
|
||||
const auto bytes_read = pread(fd, bytes + pos, size - pos, offset + pos);
|
||||
const auto bytes_read = pread(fd_, bytes + pos, size - pos, offset + pos);
|
||||
if (bytes_read <= 0) break;
|
||||
pos += bytes_read;
|
||||
HWY_ASSERT(pos <= size);
|
||||
|
|
@ -169,13 +80,12 @@ bool File::Read(uint64_t offset, uint64_t size, void* to) const {
|
|||
return pos == size; // success if managed to read desired size
|
||||
}
|
||||
|
||||
bool File::Write(const void* from, uint64_t size, uint64_t offset) {
|
||||
const int fd = static_cast<int>(p_);
|
||||
bool Write(const void* from, uint64_t size, uint64_t offset) override {
|
||||
const uint8_t* bytes = reinterpret_cast<const uint8_t*>(from);
|
||||
uint64_t pos = 0;
|
||||
for (;;) {
|
||||
const auto bytes_written =
|
||||
pwrite(fd, bytes + pos, size - pos, offset + pos);
|
||||
pwrite(fd_, bytes + pos, size - pos, offset + pos);
|
||||
if (bytes_written <= 0) break;
|
||||
pos += bytes_written;
|
||||
HWY_ASSERT(pos <= size);
|
||||
|
|
@ -183,6 +93,29 @@ bool File::Write(const void* from, uint64_t size, uint64_t offset) {
|
|||
}
|
||||
return pos == size; // success if managed to write desired size
|
||||
}
|
||||
}; // FilePosix
|
||||
|
||||
HWY_MAYBE_UNUSED extern std::unique_ptr<File> OpenFileGoogle(
|
||||
const Path& filename, const char* mode);
|
||||
|
||||
std::unique_ptr<File> OpenFileOrNull(const Path& filename, const char* mode) {
|
||||
std::unique_ptr<File> file; // OpenFileGoogle omitted
|
||||
if (file) return file;
|
||||
|
||||
const bool is_read = mode[0] != 'w';
|
||||
const int flags = is_read ? O_RDONLY : O_CREAT | O_RDWR | O_TRUNC;
|
||||
const int fd = open(filename.path.c_str(), flags, 0644);
|
||||
if (fd < 0) return file;
|
||||
|
||||
#if HWY_OS_LINUX && (!defined(__ANDROID_API__) || __ANDROID_API__ >= 21)
|
||||
if (is_read) {
|
||||
// Doubles the readahead window, which seems slightly faster when cached.
|
||||
(void)posix_fadvise(fd, 0, 0, POSIX_FADV_SEQUENTIAL);
|
||||
}
|
||||
#endif
|
||||
|
||||
return std::make_unique<FilePosix>(fd);
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
#endif // GEMMA_IO_GOOGLE
|
||||
#endif // !HWY_OS_WIN
|
||||
|
|
|
|||
|
|
@ -16,35 +16,71 @@
|
|||
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_IO_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_IO_H_
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility> // std::move
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
// unique_ptr-like interface with RAII, but not (yet) moveable.
|
||||
// Forward-declare to break the circular dependency: OpenFileOrNull returns
|
||||
// File and has a Path argument, and Path::Exists calls OpenFileOrNull. We
|
||||
// prefer to define Exists inline because there are multiple io*.cc files.
|
||||
struct Path;
|
||||
|
||||
// Abstract base class enables multiple I/O backends in the same binary.
|
||||
class File {
|
||||
public:
|
||||
File() = default;
|
||||
~File() { Close(); }
|
||||
virtual ~File() = default;
|
||||
|
||||
// Noncopyable.
|
||||
File(const File& other) = delete;
|
||||
const File& operator=(const File& other) = delete;
|
||||
|
||||
// Returns false on failure. `mode` is either "r" or "w+".
|
||||
bool Open(const char* filename, const char* mode);
|
||||
|
||||
// No effect if `Open` returned false or `Close` already called.
|
||||
void Close();
|
||||
|
||||
// Returns size in bytes or 0.
|
||||
uint64_t FileSize() const;
|
||||
virtual uint64_t FileSize() const = 0;
|
||||
|
||||
// Returns true if all the requested bytes were read.
|
||||
bool Read(uint64_t offset, uint64_t size, void* to) const;
|
||||
virtual bool Read(uint64_t offset, uint64_t size, void* to) const = 0;
|
||||
|
||||
// Returns true if all the requested bytes were written.
|
||||
bool Write(const void* from, uint64_t size, uint64_t offset);
|
||||
virtual bool Write(const void* from, uint64_t size, uint64_t offset) = 0;
|
||||
};
|
||||
|
||||
private:
|
||||
intptr_t p_ = 0;
|
||||
// Returns nullptr on failure. `mode` is either "r" or "w+". This is not just
|
||||
// named 'OpenFile' to avoid a conflict with Windows.h #define.
|
||||
std::unique_ptr<File> OpenFileOrNull(const Path& filename, const char* mode);
|
||||
|
||||
// Wrapper for strings representing a path name. Differentiates vs. arbitrary
|
||||
// strings and supports shortening for display purposes.
|
||||
struct Path {
|
||||
Path() {}
|
||||
explicit Path(const char* p) : path(p) {}
|
||||
explicit Path(std::string p) : path(std::move(p)) {}
|
||||
|
||||
Path& operator=(const char* other) {
|
||||
path = other;
|
||||
return *this;
|
||||
}
|
||||
|
||||
std::string Shortened() const {
|
||||
constexpr size_t kMaxLen = 48;
|
||||
constexpr size_t kCutPoint = kMaxLen / 2 - 5;
|
||||
if (path.size() > kMaxLen) {
|
||||
return std::string(begin(path), begin(path) + kCutPoint) + " ... " +
|
||||
std::string(end(path) - kCutPoint, end(path));
|
||||
}
|
||||
if (path.empty()) return "[no path specified]";
|
||||
return path;
|
||||
}
|
||||
|
||||
// Returns whether the file existed when this was called.
|
||||
bool Exists() const { return !!OpenFileOrNull(*this, "r"); }
|
||||
|
||||
std::string path;
|
||||
};
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
|
|||
|
|
@ -0,0 +1,115 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "hwy/detect_compiler_arch.h"
|
||||
// Only compile this file on Windows; it replaces io.cc. It is easier to check
|
||||
// this in source code because we support multiple build systems.
|
||||
#if HWY_OS_WIN
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include "compression/io.h"
|
||||
#include "hwy/base.h" // HWY_ASSERT
|
||||
#ifndef WIN32_LEAN_AND_MEAN
|
||||
#define WIN32_LEAN_AND_MEAN
|
||||
#endif
|
||||
#ifndef VC_EXTRALEAN
|
||||
#define VC_EXTRALEAN
|
||||
#endif
|
||||
#include <Windows.h>
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
class FileWin : public File {
|
||||
HANDLE hFile_ = INVALID_HANDLE_VALUE;
|
||||
|
||||
public:
|
||||
FileWin(HANDLE hFile) : hFile_(hFile) {
|
||||
HWY_ASSERT(hFile != INVALID_HANDLE_VALUE);
|
||||
}
|
||||
~FileWin() override {
|
||||
if (hFile_ != INVALID_HANDLE_VALUE) {
|
||||
HWY_ASSERT(CloseHandle(hFile_) != 0);
|
||||
}
|
||||
}
|
||||
|
||||
uint64_t FileSize() const override {
|
||||
DWORD hi;
|
||||
const DWORD lo = GetFileSize(hFile_, &hi);
|
||||
if (lo == INVALID_FILE_SIZE) return 0;
|
||||
return (static_cast<uint64_t>(hi) << 32) | lo;
|
||||
}
|
||||
|
||||
bool Read(uint64_t offset, uint64_t size, void* to) const override {
|
||||
uint8_t* bytes = reinterpret_cast<uint8_t*>(to);
|
||||
OVERLAPPED overlapped = {0};
|
||||
// Loop is required because ReadFile[Ex] size argument is 32-bit.
|
||||
while (size != 0) {
|
||||
overlapped.Offset = offset & 0xFFFFFFFF;
|
||||
overlapped.OffsetHigh = (offset >> 32) & 0xFFFFFFFF;
|
||||
const DWORD want =
|
||||
static_cast<DWORD>(HWY_MIN(size, uint64_t{0xFFFFFFFF}));
|
||||
DWORD got;
|
||||
if (!ReadFile(hFile_, bytes, want, &got, &overlapped)) {
|
||||
if (GetLastError() != ERROR_HANDLE_EOF) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
offset += got;
|
||||
bytes += got;
|
||||
size -= got;
|
||||
}
|
||||
return true; // read everything => success
|
||||
}
|
||||
|
||||
bool Write(const void* from, uint64_t size, uint64_t offset) override {
|
||||
const uint8_t* bytes = reinterpret_cast<const uint8_t*>(from);
|
||||
OVERLAPPED overlapped = {0};
|
||||
// Loop is required because WriteFile[Ex] size argument is 32-bit.
|
||||
while (size != 0) {
|
||||
overlapped.Offset = offset & 0xFFFFFFFF;
|
||||
overlapped.OffsetHigh = (offset >> 32) & 0xFFFFFFFF;
|
||||
const DWORD want =
|
||||
static_cast<DWORD>(HWY_MIN(size, uint64_t{0xFFFFFFFF}));
|
||||
DWORD got;
|
||||
if (!WriteFile(hFile_, bytes, want, &got, &overlapped)) {
|
||||
if (GetLastError() != ERROR_HANDLE_EOF) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
offset += got;
|
||||
bytes += got;
|
||||
size -= got;
|
||||
}
|
||||
return true; // wrote everything => success
|
||||
}
|
||||
}; // FileWin
|
||||
|
||||
std::unique_ptr<File> OpenFileOrNull(const Path& filename, const char* mode) {
|
||||
const bool is_read = mode[0] != 'w';
|
||||
const DWORD flags =
|
||||
FILE_ATTRIBUTE_NORMAL | (is_read ? FILE_FLAG_SEQUENTIAL_SCAN : 0);
|
||||
const DWORD access = is_read ? GENERIC_READ : GENERIC_WRITE;
|
||||
const DWORD share = is_read ? FILE_SHARE_READ : 0;
|
||||
const DWORD create = is_read ? OPEN_EXISTING : CREATE_ALWAYS;
|
||||
const HANDLE hFile = CreateFileA(filename.path.c_str(), access, share,
|
||||
nullptr, create, flags, nullptr);
|
||||
if (hFile == INVALID_HANDLE_VALUE) return std::unique_ptr<File>();
|
||||
return std::make_unique<FileWin>(hFile);
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
#endif // HWY_OS_WIN
|
||||
|
|
@ -19,7 +19,6 @@
|
|||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <string>
|
||||
|
||||
|
|
@ -77,8 +76,8 @@ class Stats {
|
|||
void Notify(const float x) {
|
||||
++n_;
|
||||
|
||||
min_ = std::min(min_, x);
|
||||
max_ = std::max(max_, x);
|
||||
min_ = HWY_MIN(min_, x);
|
||||
max_ = HWY_MAX(max_, x);
|
||||
|
||||
product_ *= x;
|
||||
|
||||
|
|
@ -119,7 +118,7 @@ class Stats {
|
|||
// Near zero for normal distributions; if positive on a unimodal distribution,
|
||||
// the right tail is fatter. Assumes n_ is large.
|
||||
double SampleSkewness() const {
|
||||
if (std::abs(m2_) < 1E-7) return 0.0;
|
||||
if (hwy::ScalarAbs(m2_) < 1E-7) return 0.0;
|
||||
return m3_ * std::sqrt(static_cast<double>(n_)) / std::pow(m2_, 1.5);
|
||||
}
|
||||
// Corrected for bias (same as Wikipedia and Minitab but not Excel).
|
||||
|
|
@ -132,7 +131,7 @@ class Stats {
|
|||
// Near zero for normal distributions; smaller values indicate fewer/smaller
|
||||
// outliers and larger indicates more/larger outliers. Assumes n_ is large.
|
||||
double SampleKurtosis() const {
|
||||
if (std::abs(m2_) < 1E-7) return 0.0;
|
||||
if (hwy::ScalarAbs(m2_) < 1E-7) return 0.0;
|
||||
return m4_ * n_ / (m2_ * m2_);
|
||||
}
|
||||
// Corrected for bias (same as Wikipedia and Minitab but not Excel).
|
||||
|
|
|
|||
|
|
@ -39,18 +39,20 @@
|
|||
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
#include <cctype>
|
||||
#include <cmath>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <random>
|
||||
#include <regex> // NOLINT
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "compression/compress.h"
|
||||
#include "compression/io.h" // Path
|
||||
#include "gemma/configs.h"
|
||||
#include "gemma/gemma.h"
|
||||
#include "util/args.h" // Path
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
|
@ -914,7 +916,7 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
|
|||
}
|
||||
}
|
||||
|
||||
// TODO: sink the loop into these functions, i.e. make them matmuls.
|
||||
// TODO: sink the loop into these functions, i.e. make them MatMul.
|
||||
pool.Run(
|
||||
0, num_tokens,
|
||||
[&](const uint64_t token_idx, size_t thread_id) HWY_ATTR {
|
||||
|
|
@ -1331,7 +1333,7 @@ hwy::AlignedFreeUniquePtr<uint8_t[]> LoadCompressedWeights(
|
|||
new (&c_weights->c_layer_ptrs) CompressedLayerPointers<TConfig>(pool);
|
||||
|
||||
std::array<float, TConfig::kNumTensorScales> scales;
|
||||
CacheLoader loader(weights.path.c_str());
|
||||
CacheLoader loader(weights);
|
||||
ForEachTensor<TConfig>(nullptr, *c_weights, loader);
|
||||
loader.LoadScales(scales.data(), scales.size());
|
||||
if (!loader.ReadAll(pool)) {
|
||||
|
|
@ -1415,7 +1417,7 @@ void CompressWeights(const Path& weights_path,
|
|||
Compressor compressor(pool);
|
||||
ForEachTensor<TConfig>(weights, *c_weights, compressor);
|
||||
compressor.AddScales(weights->scales.data(), weights->scales.size());
|
||||
compressor.WriteAll(pool, compressed_weights_path.path.c_str());
|
||||
compressor.WriteAll(pool, compressed_weights_path);
|
||||
|
||||
weights->layer_ptrs.~LayerPointers<TConfig>();
|
||||
c_weights->c_layer_ptrs.~CompressedLayerPointers<TConfig>();
|
||||
|
|
|
|||
|
|
@ -22,8 +22,8 @@
|
|||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "compression/io.h" // Path
|
||||
#include "gemma/configs.h"
|
||||
#include "util/args.h" // Path
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/base.h" // hwy::bfloat16_t
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
|
@ -35,7 +35,7 @@ using EmbedderInputT = hwy::bfloat16_t;
|
|||
// Will be called for layers output with:
|
||||
// - position in the tokens sequence
|
||||
// - name of the data, p.ex. "tokens", "block.1", "final_norm"
|
||||
// - ponter to the data array
|
||||
// - pointer to the data array
|
||||
// - size of the data array
|
||||
using LayersOutputT =
|
||||
std::function<void(int, const std::string&, const float*, size_t)>;
|
||||
|
|
|
|||
31
util/args.h
31
util/args.h
|
|
@ -28,37 +28,6 @@
|
|||
|
||||
namespace gcpp {
|
||||
|
||||
// Wrapper for strings representing a path name. Differentiates vs. arbitrary
|
||||
// strings and supports shortening for display purposes.
|
||||
struct Path {
|
||||
Path() {}
|
||||
explicit Path(const char* p) : path(p) {}
|
||||
|
||||
Path& operator=(const char* other) {
|
||||
path = other;
|
||||
return *this;
|
||||
}
|
||||
|
||||
std::string Shortened() const {
|
||||
constexpr size_t kMaxLen = 48;
|
||||
constexpr size_t kCutPoint = kMaxLen / 2 - 5;
|
||||
if (path.size() > kMaxLen) {
|
||||
return std::string(begin(path), begin(path) + kCutPoint) + " ... " +
|
||||
std::string(end(path) - kCutPoint, end(path));
|
||||
}
|
||||
if (path.empty()) return "[no path specified]";
|
||||
return path;
|
||||
}
|
||||
|
||||
// Returns whether the file existed when this was called.
|
||||
bool Exists() const {
|
||||
File file;
|
||||
return file.Open(path.c_str(), "r");
|
||||
}
|
||||
|
||||
std::string path;
|
||||
};
|
||||
|
||||
// Args is a class that provides a ForEach member function which visits each of
|
||||
// its member variables. ArgsBase provides functions called by Args to
|
||||
// initialize values to their defaults (passed as an argument to the visitor),
|
||||
|
|
|
|||
Loading…
Reference in New Issue