mirror of https://github.com/google/gemma.cpp.git
Improved IO abstraction layer
Move to unique_ptr-like File class. Move `if OS_WIN` into wrapper functions. exists -> Exists. PiperOrigin-RevId: 625923056
This commit is contained in:
parent
a939b5fc9f
commit
a8ceb75f43
|
|
@ -55,6 +55,7 @@ cc_library(
|
|||
name = "args",
|
||||
hdrs = ["util/args.h"],
|
||||
deps = [
|
||||
"//compression:io",
|
||||
"@hwy//:hwy",
|
||||
],
|
||||
)
|
||||
|
|
|
|||
|
|
@ -38,6 +38,8 @@ set(SOURCES
|
|||
compression/blob_store.h
|
||||
compression/compress.h
|
||||
compression/compress-inl.h
|
||||
compression/io.cc
|
||||
compression/io.h
|
||||
compression/nuq.h
|
||||
compression/nuq-inl.h
|
||||
compression/sfp.h
|
||||
|
|
|
|||
|
|
@ -169,9 +169,9 @@ inference path of the Gemma model.
|
|||
The sentencepiece library we depend on requires some additional work to build
|
||||
with the Bazel build system. First, it does not export its BUILD file, so we
|
||||
provide `bazel/sentencepiece.bazel`. Second, it ships with a vendored subset of
|
||||
the Abseil library. `bazel/com_google_sentencepiece.patch` changes the code to
|
||||
support Abseil as a standalone dependency without third_party/ prefixes, similar
|
||||
to the transforms we apply to Gemma via Copybara.
|
||||
the Abseil library. `bazel/sentencepiece.patch` changes the code to support
|
||||
Abseil as a standalone dependency without third_party/ prefixes, similar to the
|
||||
transforms we apply to Gemma via Copybara.
|
||||
|
||||
## Discord
|
||||
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ http_archive(
|
|||
strip_prefix = "sentencepiece-0.1.96",
|
||||
urls = ["https://github.com/google/sentencepiece/archive/refs/tags/v0.1.96.zip"],
|
||||
build_file = "@//bazel:sentencepiece.bazel",
|
||||
patches = ["@//bazel:com_google_sentencepiece.patch"],
|
||||
patches = ["@//bazel:sentencepiece.patch"],
|
||||
patch_args = ["-p1"],
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
# Required for referencing bazel:com_google_sentencepiece.patch
|
||||
# Required for referencing bazel:sentencepiece.patch
|
||||
package(
|
||||
default_applicable_licenses = ["//:license"],
|
||||
default_visibility = ["//:__subpackages__"],
|
||||
|
|
|
|||
|
|
@ -10,11 +10,23 @@ package(
|
|||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "io",
|
||||
srcs = ["io.cc"],
|
||||
hdrs = ["io.h"],
|
||||
# Placeholder for io textual_hdrs, do not remove
|
||||
deps = [
|
||||
# Placeholder for io deps, do not remove
|
||||
"@hwy//:hwy",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "blob_store",
|
||||
srcs = ["blob_store.cc"],
|
||||
hdrs = ["blob_store.h"],
|
||||
deps = [
|
||||
":io",
|
||||
"@hwy//:hwy",
|
||||
"@hwy//:thread_pool",
|
||||
],
|
||||
|
|
|
|||
|
|
@ -13,88 +13,20 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Request POSIX 2008, including `pread()` and `posix_fadvise()`.
|
||||
#if !defined(_XOPEN_SOURCE) || _XOPEN_SOURCE < 700
|
||||
#undef _XOPEN_SOURCE
|
||||
#define _XOPEN_SOURCE 700
|
||||
#endif
|
||||
#if !defined(_POSIX_C_SOURCE) || _POSIX_C_SOURCE < 200809
|
||||
#define _POSIX_C_SOURCE 200809
|
||||
#endif
|
||||
|
||||
// Make `off_t` 64-bit even on 32-bit systems. Works for Android >= r15c.
|
||||
#undef _FILE_OFFSET_BITS
|
||||
#define _FILE_OFFSET_BITS 64
|
||||
|
||||
#include "compression/blob_store.h"
|
||||
|
||||
#include <fcntl.h> // open
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h> // SEEK_END - unistd isn't enough for IDE.
|
||||
#include <sys/stat.h> // O_RDONLY
|
||||
#if HWY_OS_WIN
|
||||
#include <fileapi.h>
|
||||
#include <io.h> // read, write, close
|
||||
#else
|
||||
#include <unistd.h> // read, write, close
|
||||
#endif
|
||||
|
||||
#include <atomic>
|
||||
#include <vector>
|
||||
|
||||
#include "compression/io.h"
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
#include "hwy/detect_compiler_arch.h"
|
||||
|
||||
namespace {
|
||||
#if HWY_OS_WIN
|
||||
|
||||
// pread is not supported on Windows
|
||||
static int64_t pread(int fd, void* buf, uint64_t size, uint64_t offset) {
|
||||
HANDLE file = reinterpret_cast<HANDLE>(_get_osfhandle(fd));
|
||||
if (file == INVALID_HANDLE_VALUE) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
OVERLAPPED overlapped = {0};
|
||||
overlapped.Offset = offset & 0xFFFFFFFF;
|
||||
overlapped.OffsetHigh = (offset >> 32) & 0xFFFFFFFF;
|
||||
|
||||
DWORD bytes_read;
|
||||
if (!ReadFile(file, buf, size, &bytes_read, &overlapped)) {
|
||||
if (GetLastError() != ERROR_HANDLE_EOF) {
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
return bytes_read;
|
||||
}
|
||||
|
||||
// pwrite is not supported on Windows
|
||||
static int64_t pwrite(int fd, const void* buf, uint64_t size, uint64_t offset) {
|
||||
HANDLE file = reinterpret_cast<HANDLE>(_get_osfhandle(fd));
|
||||
if (file == INVALID_HANDLE_VALUE) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
OVERLAPPED overlapped = {0};
|
||||
overlapped.Offset = offset & 0xFFFFFFFF;
|
||||
overlapped.OffsetHigh = (offset >> 32) & 0xFFFFFFFF;
|
||||
|
||||
DWORD bytes_written;
|
||||
if (!WriteFile(file, buf, size, &bytes_written, &overlapped)) {
|
||||
if (GetLastError() != ERROR_HANDLE_EOF) {
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
return bytes_written;
|
||||
}
|
||||
|
||||
#endif
|
||||
} // namespace
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
hwy::uint128_t MakeKey(const char* string) {
|
||||
|
|
@ -131,61 +63,6 @@ void EnqueueChunkRequests(uint64_t offset, uint64_t size, uint8_t* data,
|
|||
}
|
||||
} // namespace
|
||||
|
||||
struct IO {
|
||||
// Returns size in bytes or 0.
|
||||
static uint64_t FileSize(const char* filename) {
|
||||
int fd = open(filename, O_RDONLY);
|
||||
if (fd < 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
#if HWY_OS_WIN
|
||||
const int64_t size = _lseeki64(fd, 0, SEEK_END);
|
||||
HWY_ASSERT(close(fd) != -1);
|
||||
if (size < 0) {
|
||||
return 0;
|
||||
}
|
||||
#else
|
||||
static_assert(sizeof(off_t) == 8, "64-bit off_t required");
|
||||
const off_t size = lseek(fd, 0, SEEK_END);
|
||||
HWY_ASSERT(close(fd) != -1);
|
||||
if (size == static_cast<off_t>(-1)) {
|
||||
return 0;
|
||||
}
|
||||
#endif
|
||||
|
||||
return static_cast<uint64_t>(size);
|
||||
}
|
||||
|
||||
static bool Read(int fd, uint64_t offset, uint64_t size, void* to) {
|
||||
uint8_t* bytes = reinterpret_cast<uint8_t*>(to);
|
||||
uint64_t pos = 0;
|
||||
for (;;) {
|
||||
// pread seems to be faster than lseek + read when parallelized.
|
||||
const auto bytes_read = pread(fd, bytes + pos, size - pos, offset + pos);
|
||||
if (bytes_read <= 0) break;
|
||||
pos += bytes_read;
|
||||
HWY_ASSERT(pos <= size);
|
||||
if (pos == size) break;
|
||||
}
|
||||
return pos == size; // success if managed to read desired size
|
||||
}
|
||||
|
||||
static bool Write(const void* from, uint64_t size, uint64_t offset, int fd) {
|
||||
const uint8_t* bytes = reinterpret_cast<const uint8_t*>(from);
|
||||
uint64_t pos = 0;
|
||||
for (;;) {
|
||||
const auto bytes_written =
|
||||
pwrite(fd, bytes + pos, size - pos, offset + pos);
|
||||
if (bytes_written <= 0) break;
|
||||
pos += bytes_written;
|
||||
HWY_ASSERT(pos <= size);
|
||||
if (pos == size) break;
|
||||
}
|
||||
return pos == size; // success if managed to write desired size
|
||||
}
|
||||
}; // IO
|
||||
|
||||
static_assert(HWY_IS_LITTLE_ENDIAN, "Assumes little endian");
|
||||
|
||||
// On-disk representation (little-endian).
|
||||
|
|
@ -323,25 +200,11 @@ class BlobStore {
|
|||
#pragma pack(pop)
|
||||
|
||||
BlobError BlobReader::Open(const char* filename) {
|
||||
#if HWY_OS_WIN
|
||||
DWORD flags = FILE_ATTRIBUTE_NORMAL | FILE_FLAG_SEQUENTIAL_SCAN;
|
||||
HANDLE file = CreateFileA(filename, GENERIC_READ, FILE_SHARE_READ, nullptr,
|
||||
OPEN_EXISTING, flags, nullptr);
|
||||
if (file == INVALID_HANDLE_VALUE) return __LINE__;
|
||||
fd_ = _open_osfhandle(reinterpret_cast<intptr_t>(file), _O_RDONLY);
|
||||
#else
|
||||
fd_ = open(filename, O_RDONLY);
|
||||
#endif
|
||||
if (fd_ < 0) return __LINE__;
|
||||
|
||||
#if HWY_OS_LINUX && (!defined(__ANDROID_API__) || __ANDROID_API__ >= 21)
|
||||
// Doubles the readahead window, which seems slightly faster when cached.
|
||||
(void)posix_fadvise(fd_, 0, 0, POSIX_FADV_SEQUENTIAL);
|
||||
#endif
|
||||
if (!file_.Open(filename, "r")) return __LINE__;
|
||||
|
||||
// Read first part of header to get actual size.
|
||||
BlobStore bs;
|
||||
if (!IO::Read(fd_, 0, sizeof(bs), &bs)) return __LINE__;
|
||||
if (!file_.Read(0, sizeof(bs), &bs)) return __LINE__;
|
||||
const size_t padded_size = bs.PaddedHeaderSize();
|
||||
HWY_ASSERT(padded_size >= sizeof(bs));
|
||||
|
||||
|
|
@ -353,18 +216,11 @@ BlobError BlobReader::Open(const char* filename) {
|
|||
hwy::CopySameSize(&bs, blob_store_.get());
|
||||
// Read the rest of the header, but not the full file.
|
||||
uint8_t* bytes = reinterpret_cast<uint8_t*>(blob_store_.get());
|
||||
if (!IO::Read(fd_, sizeof(bs), padded_size - sizeof(bs),
|
||||
bytes + sizeof(bs))) {
|
||||
if (!file_.Read(sizeof(bs), padded_size - sizeof(bs), bytes + sizeof(bs))) {
|
||||
return __LINE__;
|
||||
}
|
||||
|
||||
return blob_store_->CheckValidity(IO::FileSize(filename));
|
||||
}
|
||||
|
||||
BlobReader::~BlobReader() {
|
||||
if (fd_ >= 0) {
|
||||
HWY_ASSERT(close(fd_) != -1);
|
||||
}
|
||||
return blob_store_->CheckValidity(file_.FileSize());
|
||||
}
|
||||
|
||||
BlobError BlobReader::Enqueue(hwy::uint128_t key, void* data, size_t size) {
|
||||
|
|
@ -391,14 +247,14 @@ BlobError BlobReader::Enqueue(hwy::uint128_t key, void* data, size_t size) {
|
|||
// between consecutive runs.
|
||||
// - memory-mapped I/O is less predictable and adds noise to measurements.
|
||||
BlobError BlobReader::ReadAll(hwy::ThreadPool& pool) {
|
||||
const int fd = fd_;
|
||||
File* pfile = &file_; // not owned
|
||||
const auto& requests = requests_;
|
||||
std::atomic_flag err = ATOMIC_FLAG_INIT;
|
||||
// >5x speedup from parallel reads when cached.
|
||||
pool.Run(0, requests.size(),
|
||||
[fd, &requests, &err](uint64_t i, size_t /*thread*/) {
|
||||
if (!IO::Read(fd, requests[i].offset, requests[i].size,
|
||||
requests[i].data)) {
|
||||
[pfile, &requests, &err](uint64_t i, size_t /*thread*/) {
|
||||
if (!pfile->Read(requests[i].offset, requests[i].size,
|
||||
requests[i].data)) {
|
||||
err.test_and_set();
|
||||
}
|
||||
});
|
||||
|
|
@ -406,8 +262,7 @@ BlobError BlobReader::ReadAll(hwy::ThreadPool& pool) {
|
|||
return 0;
|
||||
}
|
||||
|
||||
BlobError BlobWriter::WriteAll(hwy::ThreadPool& pool,
|
||||
const char* filename) const {
|
||||
BlobError BlobWriter::WriteAll(hwy::ThreadPool& pool, const char* filename) {
|
||||
HWY_ASSERT(keys_.size() == blobs_.size());
|
||||
|
||||
// Concatenate blobs in memory.
|
||||
|
|
@ -418,26 +273,18 @@ BlobError BlobWriter::WriteAll(hwy::ThreadPool& pool,
|
|||
keys_.data(), blobs_.data(), keys_.size(), bs.get());
|
||||
|
||||
// Create/replace existing file.
|
||||
#if HWY_OS_WIN
|
||||
DWORD flags = FILE_ATTRIBUTE_NORMAL;
|
||||
HANDLE file = CreateFileA(filename, GENERIC_WRITE, 0, nullptr, CREATE_ALWAYS,
|
||||
flags, nullptr);
|
||||
if (file == INVALID_HANDLE_VALUE) return __LINE__;
|
||||
const int fd = _open_osfhandle(reinterpret_cast<intptr_t>(file), _O_WRONLY);
|
||||
#else
|
||||
const int fd = open(filename, O_CREAT | O_RDWR | O_TRUNC, 0644);
|
||||
#endif
|
||||
if (fd < 0) return __LINE__;
|
||||
File file;
|
||||
if (!file.Open(filename, "w+")) return __LINE__;
|
||||
File* pfile = &file; // not owned
|
||||
|
||||
std::atomic_flag err = ATOMIC_FLAG_INIT;
|
||||
pool.Run(0, requests.size(),
|
||||
[fd, &requests, &err](uint64_t i, size_t /*thread*/) {
|
||||
if (!IO::Write(requests[i].data, requests[i].size,
|
||||
requests[i].offset, fd)) {
|
||||
[pfile, &requests, &err](uint64_t i, size_t /*thread*/) {
|
||||
if (!pfile->Write(requests[i].data, requests[i].size,
|
||||
requests[i].offset)) {
|
||||
err.test_and_set();
|
||||
}
|
||||
});
|
||||
HWY_ASSERT(close(fd) != -1);
|
||||
if (err.test_and_set()) return __LINE__;
|
||||
return 0;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@
|
|||
|
||||
#include <vector>
|
||||
|
||||
#include "compression/io.h"
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/base.h" // hwy::uint128_t
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
|
@ -59,7 +60,7 @@ struct BlobIO {
|
|||
class BlobReader {
|
||||
public:
|
||||
BlobReader() { requests_.reserve(500); }
|
||||
~BlobReader();
|
||||
~BlobReader() = default;
|
||||
|
||||
// Opens `filename` and reads its header.
|
||||
BlobError Open(const char* filename);
|
||||
|
|
@ -73,7 +74,7 @@ class BlobReader {
|
|||
private:
|
||||
BlobStorePtr blob_store_; // holds header, not the entire file
|
||||
std::vector<BlobIO> requests_;
|
||||
int fd_ = 0;
|
||||
File file_;
|
||||
};
|
||||
|
||||
class BlobWriter {
|
||||
|
|
@ -84,7 +85,7 @@ class BlobWriter {
|
|||
}
|
||||
|
||||
// Stores all blobs to disk in the given order with padding for alignment.
|
||||
BlobError WriteAll(hwy::ThreadPool& pool, const char* filename) const;
|
||||
BlobError WriteAll(hwy::ThreadPool& pool, const char* filename);
|
||||
|
||||
private:
|
||||
std::vector<hwy::uint128_t> keys_;
|
||||
|
|
|
|||
|
|
@ -0,0 +1,188 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Safe to be first, does not include POSIX headers.
|
||||
#include "compression/io.h"
|
||||
|
||||
// 1.5x slowdown vs. POSIX (200 ms longer startup), hence opt-in.
|
||||
#ifdef GEMMA_IO_GOOGLE
|
||||
#include "compression/io_google.cc"
|
||||
#else
|
||||
|
||||
// Request POSIX 2008, including `pread()` and `posix_fadvise()`.
|
||||
#if !defined(_XOPEN_SOURCE) || _XOPEN_SOURCE < 700
|
||||
#undef _XOPEN_SOURCE
|
||||
#define _XOPEN_SOURCE 700
|
||||
#endif
|
||||
#if !defined(_POSIX_C_SOURCE) || _POSIX_C_SOURCE < 200809
|
||||
#define _POSIX_C_SOURCE 200809
|
||||
#endif
|
||||
|
||||
// Make `off_t` 64-bit even on 32-bit systems. Works for Android >= r15c.
|
||||
#undef _FILE_OFFSET_BITS
|
||||
#define _FILE_OFFSET_BITS 64
|
||||
|
||||
#include <fcntl.h> // open
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h> // SEEK_END - unistd isn't enough for IDE.
|
||||
#include <sys/stat.h> // O_RDONLY
|
||||
|
||||
#include "hwy/base.h" // HWY_ASSERT
|
||||
#include "hwy/detect_compiler_arch.h"
|
||||
#if HWY_OS_WIN
|
||||
#include <fileapi.h>
|
||||
#include <io.h> // read, write, close
|
||||
#else
|
||||
#include <unistd.h> // read, write, close
|
||||
#endif
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
// Emulate missing POSIX functions.
|
||||
#if HWY_OS_WIN
|
||||
namespace {
|
||||
|
||||
static inline int open(const char* filename, int flags, int mode = 0) {
|
||||
const bool is_read = (flags & _O_RDONLY) != 0;
|
||||
const DWORD win_flags =
|
||||
FILE_ATTRIBUTE_NORMAL | (is_read ? FILE_FLAG_SEQUENTIAL_SCAN : 0);
|
||||
const DWORD access = is_read ? GENERIC_READ : GENERIC_WRITE;
|
||||
const DWORD share = is_read ? FILE_SHARE_READ : 0;
|
||||
const DWORD create = is_read ? OPEN_EXISTING : CREATE_ALWAYS;
|
||||
const HANDLE file =
|
||||
CreateFileA(filename, access, share, nullptr, create, win_flags, nullptr);
|
||||
if (file == INVALID_HANDLE_VALUE) return -1;
|
||||
return _open_osfhandle(reinterpret_cast<intptr_t>(file), flags);
|
||||
}
|
||||
|
||||
static inline off_t lseek(int fd, off_t offset, int whence) {
|
||||
return _lseeki64(fd, offset, whence);
|
||||
}
|
||||
|
||||
static inline int64_t pread(int fd, void* buf, uint64_t size, uint64_t offset) {
|
||||
HANDLE file = reinterpret_cast<HANDLE>(_get_osfhandle(fd));
|
||||
if (file == INVALID_HANDLE_VALUE) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
OVERLAPPED overlapped = {0};
|
||||
overlapped.Offset = offset & 0xFFFFFFFF;
|
||||
overlapped.OffsetHigh = (offset >> 32) & 0xFFFFFFFF;
|
||||
|
||||
DWORD bytes_read;
|
||||
if (!ReadFile(file, buf, size, &bytes_read, &overlapped)) {
|
||||
if (GetLastError() != ERROR_HANDLE_EOF) {
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
return bytes_read;
|
||||
}
|
||||
|
||||
static inline int64_t pwrite(int fd, const void* buf, uint64_t size,
|
||||
uint64_t offset) {
|
||||
HANDLE file = reinterpret_cast<HANDLE>(_get_osfhandle(fd));
|
||||
if (file == INVALID_HANDLE_VALUE) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
OVERLAPPED overlapped = {0};
|
||||
overlapped.Offset = offset & 0xFFFFFFFF;
|
||||
overlapped.OffsetHigh = (offset >> 32) & 0xFFFFFFFF;
|
||||
|
||||
DWORD bytes_written;
|
||||
if (!WriteFile(file, buf, size, &bytes_written, &overlapped)) {
|
||||
if (GetLastError() != ERROR_HANDLE_EOF) {
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
return bytes_written;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
#endif // HWY_OS_WIN
|
||||
|
||||
bool File::Open(const char* filename, const char* mode) {
|
||||
const bool is_read = mode[0] != 'w';
|
||||
const int flags = is_read ? O_RDONLY : O_CREAT | O_RDWR | O_TRUNC;
|
||||
int fd = open(filename, flags, 0644);
|
||||
if (fd < 0) {
|
||||
p_ = 0;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (is_read) {
|
||||
#if HWY_OS_LINUX && (!defined(__ANDROID_API__) || __ANDROID_API__ >= 21)
|
||||
// Doubles the readahead window, which seems slightly faster when cached.
|
||||
(void)posix_fadvise(fd, 0, 0, POSIX_FADV_SEQUENTIAL);
|
||||
#endif
|
||||
}
|
||||
|
||||
p_ = static_cast<intptr_t>(fd);
|
||||
return true;
|
||||
}
|
||||
|
||||
void File::Close() {
|
||||
const int fd = static_cast<int>(p_);
|
||||
if (fd > 0) {
|
||||
HWY_ASSERT(close(fd) != -1);
|
||||
p_ = 0;
|
||||
}
|
||||
}
|
||||
|
||||
uint64_t File::FileSize() const {
|
||||
static_assert(sizeof(off_t) == 8, "64-bit off_t required");
|
||||
const int fd = static_cast<int>(p_);
|
||||
const off_t size = lseek(fd, 0, SEEK_END);
|
||||
if (size < 0) {
|
||||
return 0;
|
||||
}
|
||||
return static_cast<uint64_t>(size);
|
||||
}
|
||||
|
||||
bool File::Read(uint64_t offset, uint64_t size, void* to) const {
|
||||
const int fd = static_cast<int>(p_);
|
||||
uint8_t* bytes = reinterpret_cast<uint8_t*>(to);
|
||||
uint64_t pos = 0;
|
||||
for (;;) {
|
||||
// pread seems to be faster than lseek + read when parallelized.
|
||||
const auto bytes_read = pread(fd, bytes + pos, size - pos, offset + pos);
|
||||
if (bytes_read <= 0) break;
|
||||
pos += bytes_read;
|
||||
HWY_ASSERT(pos <= size);
|
||||
if (pos == size) break;
|
||||
}
|
||||
return pos == size; // success if managed to read desired size
|
||||
}
|
||||
|
||||
bool File::Write(const void* from, uint64_t size, uint64_t offset) {
|
||||
const int fd = static_cast<int>(p_);
|
||||
const uint8_t* bytes = reinterpret_cast<const uint8_t*>(from);
|
||||
uint64_t pos = 0;
|
||||
for (;;) {
|
||||
const auto bytes_written =
|
||||
pwrite(fd, bytes + pos, size - pos, offset + pos);
|
||||
if (bytes_written <= 0) break;
|
||||
pos += bytes_written;
|
||||
HWY_ASSERT(pos <= size);
|
||||
if (pos == size) break;
|
||||
}
|
||||
return pos == size; // success if managed to write desired size
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
#endif // GEMMA_IO_GOOGLE
|
||||
|
|
@ -0,0 +1,52 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_IO_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_IO_H_
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
// unique_ptr-like interface with RAII, but not (yet) moveable.
|
||||
class File {
|
||||
public:
|
||||
File() = default;
|
||||
~File() { Close(); }
|
||||
File(const File& other) = delete;
|
||||
const File& operator=(const File& other) = delete;
|
||||
|
||||
// Returns false on failure. `mode` is either "r" or "w+".
|
||||
bool Open(const char* filename, const char* mode);
|
||||
|
||||
// No effect if `Open` returned false or `Close` already called.
|
||||
void Close();
|
||||
|
||||
// Returns size in bytes or 0.
|
||||
uint64_t FileSize() const;
|
||||
|
||||
// Returns true if all the requested bytes were read.
|
||||
bool Read(uint64_t offset, uint64_t size, void* to) const;
|
||||
|
||||
// Returns true if all the requested bytes were written.
|
||||
bool Write(const void* from, uint64_t size, uint64_t offset);
|
||||
|
||||
private:
|
||||
intptr_t p_ = 0;
|
||||
};
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_IO_H_
|
||||
|
|
@ -1,137 +0,0 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <random>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "third_party/benchmark/include/benchmark/benchmark.h"
|
||||
#include "gemma/gemma.h"
|
||||
#include "util/app.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
||||
gcpp::LoaderArgs* loader = nullptr;
|
||||
gcpp::InferenceArgs* inference = nullptr;
|
||||
gcpp::Gemma* model = nullptr;
|
||||
hwy::ThreadPool* pool = nullptr;
|
||||
hwy::ThreadPool* inner_pool = nullptr;
|
||||
|
||||
void run_gemma_prompt(const std::string& prompt_string,
|
||||
benchmark::State& state) {
|
||||
std::mt19937 gen;
|
||||
std::vector<int> prompt;
|
||||
|
||||
if (prompt_string.empty()) return;
|
||||
HWY_ASSERT(model->Tokenizer().Encode(prompt_string, &prompt).ok());
|
||||
|
||||
int token_counter = 0;
|
||||
auto stream_token = [&token_counter](int, float) {
|
||||
token_counter++;
|
||||
return true;
|
||||
};
|
||||
|
||||
for (auto s : state) {
|
||||
GenerateGemma(
|
||||
*model, *inference, prompt, /*start_token=*/0, *pool, *inner_pool,
|
||||
stream_token,
|
||||
/*accept=*/[](int) { return true; }, gen, /*verbosity=*/0);
|
||||
}
|
||||
|
||||
state.SetItemsProcessed(token_counter);
|
||||
}
|
||||
|
||||
static void BM_short_prompt(benchmark::State& state) {
|
||||
run_gemma_prompt("What is the capital of Spain?<ctrl23> ", state);
|
||||
}
|
||||
|
||||
static void BM_factuality_prompt(benchmark::State& state) {
|
||||
run_gemma_prompt("How does an inkjet printer work?<ctrl23> ", state);
|
||||
}
|
||||
|
||||
static void BM_creative_prompt(benchmark::State& state) {
|
||||
run_gemma_prompt(
|
||||
"Tell me a story about a magical bunny and their TRS-80.<ctrl23> ",
|
||||
state);
|
||||
}
|
||||
|
||||
static void BM_coding_prompt(benchmark::State& state) {
|
||||
run_gemma_prompt(
|
||||
"Write a python program to generate a fibonacci sequence.<ctrl23> ",
|
||||
state);
|
||||
}
|
||||
|
||||
static void BM_long_coding_prompt(benchmark::State& state) {
|
||||
std::ifstream t("benchmarks.cc", std::ios_base::in);
|
||||
std::stringstream buffer;
|
||||
buffer << t.rdbuf();
|
||||
std::string prompt_string = buffer.str();
|
||||
t.close();
|
||||
|
||||
run_gemma_prompt("Make improvements to the following code:\n " +
|
||||
prompt_string + "<ctrl23> ",
|
||||
state);
|
||||
}
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
loader = new gcpp::LoaderArgs(argc, argv);
|
||||
inference = new gcpp::InferenceArgs(argc, argv);
|
||||
gcpp::AppArgs app(argc, argv);
|
||||
|
||||
pool = new ::hwy::ThreadPool(app.num_threads);
|
||||
inner_pool = new ::hwy::ThreadPool(0);
|
||||
model = new gcpp::Gemma(*loader, *pool);
|
||||
|
||||
inference->max_tokens = 128;
|
||||
BENCHMARK(BM_short_prompt)
|
||||
->Iterations(3)
|
||||
->Unit(benchmark::kMillisecond)
|
||||
->UseRealTime();
|
||||
|
||||
inference->max_tokens = 256;
|
||||
BENCHMARK(BM_factuality_prompt)
|
||||
->Iterations(3)
|
||||
->Unit(benchmark::kMillisecond)
|
||||
->UseRealTime();
|
||||
|
||||
BENCHMARK(BM_creative_prompt)
|
||||
->Iterations(3)
|
||||
->Unit(benchmark::kMillisecond)
|
||||
->UseRealTime();
|
||||
|
||||
BENCHMARK(BM_coding_prompt)
|
||||
->Iterations(3)
|
||||
->Unit(benchmark::kMillisecond)
|
||||
->UseRealTime();
|
||||
|
||||
inference->max_tokens = 1024;
|
||||
BENCHMARK(BM_long_coding_prompt)
|
||||
->Iterations(3)
|
||||
->Unit(benchmark::kMillisecond)
|
||||
->UseRealTime();
|
||||
|
||||
::benchmark ::RunSpecifiedBenchmarks();
|
||||
::benchmark ::Shutdown();
|
||||
|
||||
delete loader;
|
||||
delete inference;
|
||||
delete model;
|
||||
delete pool;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
|
@ -55,7 +55,7 @@ struct Args : public ArgsBase<Args> {
|
|||
return "Missing --compressed_weights flag, a file for the compressed "
|
||||
"model.";
|
||||
}
|
||||
if (!weights.exists()) {
|
||||
if (!weights.Exists()) {
|
||||
return "Can't open file specified with --weights flag.";
|
||||
}
|
||||
return nullptr;
|
||||
|
|
|
|||
|
|
@ -40,7 +40,6 @@
|
|||
#include <algorithm>
|
||||
#include <array>
|
||||
#include <cmath>
|
||||
#include <filesystem> // NOLINT
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <random>
|
||||
|
|
@ -182,7 +181,7 @@ hwy::AlignedFreeUniquePtr<uint8_t[]> LoadWeights(
|
|||
const Path& checkpoint, hwy::ThreadPool& pool,
|
||||
bool scale_for_compression = false) {
|
||||
PROFILER_ZONE("Startup.LoadWeights");
|
||||
if (!std::filesystem::exists(checkpoint.path)) {
|
||||
if (!checkpoint.Exists()) {
|
||||
HWY_ABORT("The model weights file '%s' does not exist.",
|
||||
checkpoint.path.c_str());
|
||||
}
|
||||
|
|
@ -1318,8 +1317,8 @@ void ForEachTensor(const Weights<TConfig>* weights,
|
|||
template <class TConfig>
|
||||
hwy::AlignedFreeUniquePtr<uint8_t[]> LoadCompressedWeights(
|
||||
const Path& weights, hwy::ThreadPool& pool) {
|
||||
PROFILER_ZONE("Startup.LoadCache");
|
||||
if (!std::filesystem::exists(weights.path)) {
|
||||
PROFILER_ZONE("Startup.LoadCompressedWeights");
|
||||
if (!weights.Exists()) {
|
||||
HWY_ABORT("The model weights file '%s' does not exist.",
|
||||
weights.path.c_str());
|
||||
}
|
||||
|
|
@ -1395,7 +1394,7 @@ template <class TConfig>
|
|||
void CompressWeights(const Path& weights_path,
|
||||
const Path& compressed_weights_path,
|
||||
hwy::ThreadPool& pool) {
|
||||
if (!std::filesystem::exists(weights_path.path)) {
|
||||
if (!weights_path.Exists()) {
|
||||
HWY_ABORT("The model weights file '%s' does not exist.",
|
||||
weights_path.path.c_str());
|
||||
}
|
||||
|
|
|
|||
223
gemma/run_csv.cc
223
gemma/run_csv.cc
|
|
@ -1,223 +0,0 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Command line text interface to gemma.
|
||||
|
||||
#include <stdio.h>
|
||||
|
||||
#include <random>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "gemma/configs.h"
|
||||
#include "gemma/gemma.h"
|
||||
#include "util/app.h"
|
||||
#include "util/args.h" // ArgsBase
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
#include "hwy/highway.h"
|
||||
#include "hwy/profiler.h"
|
||||
#include "third_party/riegeli/bytes/file_reader.h"
|
||||
#include "third_party/riegeli/bytes/file_writer.h"
|
||||
#include "third_party/riegeli/csv/csv_reader.h"
|
||||
#include "third_party/riegeli/csv/csv_writer.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
struct CsvArgs : public ArgsBase<CsvArgs> {
|
||||
CsvArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
|
||||
|
||||
Path input_csv;
|
||||
Path output_csv;
|
||||
int prompt_column;
|
||||
|
||||
template <class Visitor>
|
||||
void ForEach(const Visitor& visitor) {
|
||||
visitor(input_csv, "input_csv", Path(),
|
||||
"When set, prompts will be read from this CSV.");
|
||||
visitor(output_csv, "output_csv", Path("/tmp/output.csv"),
|
||||
"When --input_csv is set, prompts will be written to this CSV.");
|
||||
visitor(prompt_column, "prompt_column", 0, "Prompt column index");
|
||||
};
|
||||
};
|
||||
|
||||
void FileGemma(gcpp::Gemma& model, InferenceArgs& inference, AppArgs& app,
|
||||
CsvArgs& csv, hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
|
||||
const gcpp::AcceptFunc& accept_token) {
|
||||
int abs_pos = 0; // absolute token index over all turns
|
||||
int current_pos = 0; // token index within the current turn
|
||||
int prompt_size{};
|
||||
|
||||
std::mt19937 gen;
|
||||
if (inference.deterministic) {
|
||||
gen.seed(42);
|
||||
} else {
|
||||
std::random_device rd;
|
||||
gen.seed(rd());
|
||||
}
|
||||
|
||||
std::stringstream response_stream;
|
||||
|
||||
// callback function invoked for each generated token.
|
||||
auto stream_token = [&inference, &abs_pos, ¤t_pos, &gen, &prompt_size,
|
||||
tokenizer = &model.Tokenizer(),
|
||||
&response_stream](int token, float) {
|
||||
++abs_pos;
|
||||
++current_pos;
|
||||
if (current_pos < prompt_size) {
|
||||
// pass
|
||||
} else if (token == gcpp::EOS_ID) {
|
||||
if (!inference.multiturn) {
|
||||
abs_pos = 0;
|
||||
if (inference.deterministic) {
|
||||
gen.seed(42);
|
||||
}
|
||||
}
|
||||
// end of stream
|
||||
} else {
|
||||
std::string token_text;
|
||||
HWY_ASSERT(tokenizer->Decode({token}, &token_text).ok());
|
||||
// +1 since position is incremented above
|
||||
if (current_pos == prompt_size + 1) {
|
||||
// first token of response
|
||||
token_text.erase(0, token_text.find_first_not_of(" \t\n"));
|
||||
}
|
||||
if (token_text != "\n")
|
||||
response_stream << token_text;
|
||||
else
|
||||
response_stream << "\\n";
|
||||
}
|
||||
return true;
|
||||
};
|
||||
|
||||
riegeli::CsvReader csv_reader(
|
||||
riegeli::FileReader(csv.input_csv.path),
|
||||
riegeli::CsvReaderBase::Options().set_comment('#').set_recovery(
|
||||
[](absl::Status status, riegeli::CsvReaderBase& csv_reader) {
|
||||
fprintf(stderr, "Invalid entry: %s", status.message().data());
|
||||
return true;
|
||||
}));
|
||||
|
||||
riegeli::CsvWriter csv_writer(
|
||||
riegeli::FileWriter(csv.output_csv.path),
|
||||
riegeli::CsvWriterBase::Options().set_header({"prompt", "response"}));
|
||||
|
||||
if (!csv_reader.ok()) {
|
||||
HWY_ABORT("Invalid input CSV path %s", csv.input_csv.path.c_str());
|
||||
}
|
||||
|
||||
if (!csv_writer.ok()) {
|
||||
HWY_ABORT("Invalid output CSV path %s", csv.output_csv.path.c_str());
|
||||
}
|
||||
|
||||
while (abs_pos < inference.max_tokens) {
|
||||
std::string prompt_string;
|
||||
std::vector<int> prompt;
|
||||
current_pos = 0;
|
||||
|
||||
std::vector<std::string> record;
|
||||
csv_reader.ReadRecord(record);
|
||||
|
||||
if (record.empty()) {
|
||||
break;
|
||||
}
|
||||
|
||||
prompt_string = record[csv.prompt_column];
|
||||
fprintf(stdout, "Prompt: %s\n", prompt_string.c_str());
|
||||
|
||||
prompt_string =
|
||||
"<ctrl99>user\n" + prompt_string + "<ctrl100>\n<ctrl99>model\n";
|
||||
if (abs_pos > 0) {
|
||||
// multi-turn dialogue continuation.
|
||||
prompt_string = "<ctrl100>\n" + prompt_string;
|
||||
} else {
|
||||
HWY_DASSERT(abs_pos == 0);
|
||||
if (gcpp::kSystemPrompt) {
|
||||
prompt_string =
|
||||
"<ctrl99>system\nYou are a large language model built by "
|
||||
"Google.<ctrl100>\n" +
|
||||
prompt_string;
|
||||
}
|
||||
}
|
||||
HWY_ASSERT(model.Tokenizer().Encode(prompt_string, &prompt).ok());
|
||||
prompt_size = prompt.size();
|
||||
|
||||
// generate prompt
|
||||
GenerateGemma(model, inference, prompt, abs_pos, pool, inner_pool,
|
||||
stream_token, accept_token, gen, app.verbosity);
|
||||
|
||||
std::string response_string = response_stream.str();
|
||||
if (!csv_writer.WriteRecord({record[csv.prompt_column], response_string})) {
|
||||
fprintf(stderr, "Failed to write CSV: %s\n",
|
||||
csv_writer.status().message().data());
|
||||
}
|
||||
|
||||
response_stream.str(std::string()); // reset stream
|
||||
response_stream.clear();
|
||||
abs_pos = 0;
|
||||
}
|
||||
|
||||
if (!csv_reader.Close()) {
|
||||
fprintf(stderr, "Failed to close the CSV reader\n");
|
||||
}
|
||||
if (!csv_writer.Close()) {
|
||||
fprintf(stderr, "Failed to close the CSV writer\n");
|
||||
}
|
||||
}
|
||||
|
||||
void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app,
|
||||
CsvArgs& csv) {
|
||||
PROFILER_ZONE("Run.misc");
|
||||
|
||||
hwy::ThreadPool inner_pool(0);
|
||||
hwy::ThreadPool pool(app.num_threads);
|
||||
// For many-core, pinning threads to cores helps.
|
||||
if (app.num_threads > 10) {
|
||||
pool.Run(0, pool.NumThreads(),
|
||||
[](uint64_t /*task*/, size_t thread) { PinThreadToCore(thread); });
|
||||
}
|
||||
|
||||
gcpp::Gemma model(loader.tokenizer, loader.compressed_weights,
|
||||
loader.ModelType(), loader.ModelTraining(), pool);
|
||||
|
||||
if (csv.input_csv.path.empty()) {
|
||||
HWY_ABORT("Need to specify csv file.");
|
||||
}
|
||||
|
||||
FileGemma(model, inference, app, csv, pool, inner_pool,
|
||||
[](int) { return true; });
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
{
|
||||
PROFILER_ZONE("Startup.misc");
|
||||
gcpp::LoaderArgs loader(argc, argv);
|
||||
gcpp::InferenceArgs inference(argc, argv);
|
||||
gcpp::AppArgs app(argc, argv);
|
||||
gcpp::CsvArgs csv(argc, argv);
|
||||
|
||||
if (const char* error = loader.Validate()) {
|
||||
loader.Help();
|
||||
HWY_ABORT("Invalid args: %s", error);
|
||||
}
|
||||
|
||||
gcpp::Run(loader, inference, app, csv);
|
||||
}
|
||||
PROFILER_PRINT_RESULTS(); // Must call outside the zone above.
|
||||
return 0;
|
||||
}
|
||||
|
|
@ -133,7 +133,7 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
|
|||
if (tokenizer.path.empty()) {
|
||||
return "Missing --tokenizer flag, a file for the tokenizer is required.";
|
||||
}
|
||||
if (!tokenizer.exists()) {
|
||||
if (!tokenizer.Exists()) {
|
||||
return "Can't open file specified with --tokenizer flag.";
|
||||
}
|
||||
if (!compressed_weights.path.empty()) {
|
||||
|
|
@ -148,7 +148,7 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
|
|||
if (weights.path.empty()) {
|
||||
return "Missing --weights flag, a file for the model weights.";
|
||||
}
|
||||
if (!weights.exists()) {
|
||||
if (!weights.Exists()) {
|
||||
return "Can't open file specified with --weights flag.";
|
||||
}
|
||||
return nullptr;
|
||||
|
|
|
|||
16
util/args.h
16
util/args.h
|
|
@ -23,16 +23,9 @@
|
|||
#include <algorithm> // std::transform
|
||||
#include <string>
|
||||
|
||||
#include "compression/io.h"
|
||||
#include "hwy/base.h" // HWY_ABORT
|
||||
|
||||
#if defined(_WIN32)
|
||||
#include <io.h>
|
||||
#define F_OK 0
|
||||
#define access _access
|
||||
#else
|
||||
#include <unistd.h>
|
||||
#endif
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
// Wrapper for strings representing a path name. Differentiates vs. arbitrary
|
||||
|
|
@ -57,9 +50,10 @@ struct Path {
|
|||
return path;
|
||||
}
|
||||
|
||||
// Beware, TOCTOU.
|
||||
bool exists() const {
|
||||
return (access(path.c_str(), F_OK) == 0);
|
||||
// Returns whether the file existed when this was called.
|
||||
bool Exists() const {
|
||||
File file;
|
||||
return file.Open(path.c_str(), "r");
|
||||
}
|
||||
|
||||
std::string path;
|
||||
|
|
|
|||
Loading…
Reference in New Issue