Improved IO abstraction layer

Move to unique_ptr-like File class.
Move `if OS_WIN` into wrapper functions.
exists -> Exists.

PiperOrigin-RevId: 625923056
This commit is contained in:
Jan Wassenberg 2024-04-17 23:14:37 -07:00 committed by Copybara-Service
parent a939b5fc9f
commit a8ceb75f43
17 changed files with 293 additions and 557 deletions

View File

@ -55,6 +55,7 @@ cc_library(
name = "args",
hdrs = ["util/args.h"],
deps = [
"//compression:io",
"@hwy//:hwy",
],
)

View File

@ -38,6 +38,8 @@ set(SOURCES
compression/blob_store.h
compression/compress.h
compression/compress-inl.h
compression/io.cc
compression/io.h
compression/nuq.h
compression/nuq-inl.h
compression/sfp.h

View File

@ -169,9 +169,9 @@ inference path of the Gemma model.
The sentencepiece library we depend on requires some additional work to build
with the Bazel build system. First, it does not export its BUILD file, so we
provide `bazel/sentencepiece.bazel`. Second, it ships with a vendored subset of
the Abseil library. `bazel/com_google_sentencepiece.patch` changes the code to
support Abseil as a standalone dependency without third_party/ prefixes, similar
to the transforms we apply to Gemma via Copybara.
the Abseil library. `bazel/sentencepiece.patch` changes the code to support
Abseil as a standalone dependency without third_party/ prefixes, similar to the
transforms we apply to Gemma via Copybara.
## Discord

View File

@ -33,7 +33,7 @@ http_archive(
strip_prefix = "sentencepiece-0.1.96",
urls = ["https://github.com/google/sentencepiece/archive/refs/tags/v0.1.96.zip"],
build_file = "@//bazel:sentencepiece.bazel",
patches = ["@//bazel:com_google_sentencepiece.patch"],
patches = ["@//bazel:sentencepiece.patch"],
patch_args = ["-p1"],
)

View File

@ -1,4 +1,4 @@
# Required for referencing bazel:com_google_sentencepiece.patch
# Required for referencing bazel:sentencepiece.patch
package(
default_applicable_licenses = ["//:license"],
default_visibility = ["//:__subpackages__"],

View File

@ -10,11 +10,23 @@ package(
],
)
cc_library(
name = "io",
srcs = ["io.cc"],
hdrs = ["io.h"],
# Placeholder for io textual_hdrs, do not remove
deps = [
# Placeholder for io deps, do not remove
"@hwy//:hwy",
],
)
cc_library(
name = "blob_store",
srcs = ["blob_store.cc"],
hdrs = ["blob_store.h"],
deps = [
":io",
"@hwy//:hwy",
"@hwy//:thread_pool",
],

View File

@ -13,88 +13,20 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// Request POSIX 2008, including `pread()` and `posix_fadvise()`.
#if !defined(_XOPEN_SOURCE) || _XOPEN_SOURCE < 700
#undef _XOPEN_SOURCE
#define _XOPEN_SOURCE 700
#endif
#if !defined(_POSIX_C_SOURCE) || _POSIX_C_SOURCE < 200809
#define _POSIX_C_SOURCE 200809
#endif
// Make `off_t` 64-bit even on 32-bit systems. Works for Android >= r15c.
#undef _FILE_OFFSET_BITS
#define _FILE_OFFSET_BITS 64
#include "compression/blob_store.h"
#include <fcntl.h> // open
#include <stddef.h>
#include <stdint.h>
#include <stdio.h> // SEEK_END - unistd isn't enough for IDE.
#include <sys/stat.h> // O_RDONLY
#if HWY_OS_WIN
#include <fileapi.h>
#include <io.h> // read, write, close
#else
#include <unistd.h> // read, write, close
#endif
#include <atomic>
#include <vector>
#include "compression/io.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/detect_compiler_arch.h"
namespace {
#if HWY_OS_WIN
// pread is not supported on Windows
static int64_t pread(int fd, void* buf, uint64_t size, uint64_t offset) {
HANDLE file = reinterpret_cast<HANDLE>(_get_osfhandle(fd));
if (file == INVALID_HANDLE_VALUE) {
return -1;
}
OVERLAPPED overlapped = {0};
overlapped.Offset = offset & 0xFFFFFFFF;
overlapped.OffsetHigh = (offset >> 32) & 0xFFFFFFFF;
DWORD bytes_read;
if (!ReadFile(file, buf, size, &bytes_read, &overlapped)) {
if (GetLastError() != ERROR_HANDLE_EOF) {
return -1;
}
}
return bytes_read;
}
// pwrite is not supported on Windows
static int64_t pwrite(int fd, const void* buf, uint64_t size, uint64_t offset) {
HANDLE file = reinterpret_cast<HANDLE>(_get_osfhandle(fd));
if (file == INVALID_HANDLE_VALUE) {
return -1;
}
OVERLAPPED overlapped = {0};
overlapped.Offset = offset & 0xFFFFFFFF;
overlapped.OffsetHigh = (offset >> 32) & 0xFFFFFFFF;
DWORD bytes_written;
if (!WriteFile(file, buf, size, &bytes_written, &overlapped)) {
if (GetLastError() != ERROR_HANDLE_EOF) {
return -1;
}
}
return bytes_written;
}
#endif
} // namespace
namespace gcpp {
hwy::uint128_t MakeKey(const char* string) {
@ -131,61 +63,6 @@ void EnqueueChunkRequests(uint64_t offset, uint64_t size, uint8_t* data,
}
} // namespace
struct IO {
// Returns size in bytes or 0.
static uint64_t FileSize(const char* filename) {
int fd = open(filename, O_RDONLY);
if (fd < 0) {
return 0;
}
#if HWY_OS_WIN
const int64_t size = _lseeki64(fd, 0, SEEK_END);
HWY_ASSERT(close(fd) != -1);
if (size < 0) {
return 0;
}
#else
static_assert(sizeof(off_t) == 8, "64-bit off_t required");
const off_t size = lseek(fd, 0, SEEK_END);
HWY_ASSERT(close(fd) != -1);
if (size == static_cast<off_t>(-1)) {
return 0;
}
#endif
return static_cast<uint64_t>(size);
}
static bool Read(int fd, uint64_t offset, uint64_t size, void* to) {
uint8_t* bytes = reinterpret_cast<uint8_t*>(to);
uint64_t pos = 0;
for (;;) {
// pread seems to be faster than lseek + read when parallelized.
const auto bytes_read = pread(fd, bytes + pos, size - pos, offset + pos);
if (bytes_read <= 0) break;
pos += bytes_read;
HWY_ASSERT(pos <= size);
if (pos == size) break;
}
return pos == size; // success if managed to read desired size
}
static bool Write(const void* from, uint64_t size, uint64_t offset, int fd) {
const uint8_t* bytes = reinterpret_cast<const uint8_t*>(from);
uint64_t pos = 0;
for (;;) {
const auto bytes_written =
pwrite(fd, bytes + pos, size - pos, offset + pos);
if (bytes_written <= 0) break;
pos += bytes_written;
HWY_ASSERT(pos <= size);
if (pos == size) break;
}
return pos == size; // success if managed to write desired size
}
}; // IO
static_assert(HWY_IS_LITTLE_ENDIAN, "Assumes little endian");
// On-disk representation (little-endian).
@ -323,25 +200,11 @@ class BlobStore {
#pragma pack(pop)
BlobError BlobReader::Open(const char* filename) {
#if HWY_OS_WIN
DWORD flags = FILE_ATTRIBUTE_NORMAL | FILE_FLAG_SEQUENTIAL_SCAN;
HANDLE file = CreateFileA(filename, GENERIC_READ, FILE_SHARE_READ, nullptr,
OPEN_EXISTING, flags, nullptr);
if (file == INVALID_HANDLE_VALUE) return __LINE__;
fd_ = _open_osfhandle(reinterpret_cast<intptr_t>(file), _O_RDONLY);
#else
fd_ = open(filename, O_RDONLY);
#endif
if (fd_ < 0) return __LINE__;
#if HWY_OS_LINUX && (!defined(__ANDROID_API__) || __ANDROID_API__ >= 21)
// Doubles the readahead window, which seems slightly faster when cached.
(void)posix_fadvise(fd_, 0, 0, POSIX_FADV_SEQUENTIAL);
#endif
if (!file_.Open(filename, "r")) return __LINE__;
// Read first part of header to get actual size.
BlobStore bs;
if (!IO::Read(fd_, 0, sizeof(bs), &bs)) return __LINE__;
if (!file_.Read(0, sizeof(bs), &bs)) return __LINE__;
const size_t padded_size = bs.PaddedHeaderSize();
HWY_ASSERT(padded_size >= sizeof(bs));
@ -353,18 +216,11 @@ BlobError BlobReader::Open(const char* filename) {
hwy::CopySameSize(&bs, blob_store_.get());
// Read the rest of the header, but not the full file.
uint8_t* bytes = reinterpret_cast<uint8_t*>(blob_store_.get());
if (!IO::Read(fd_, sizeof(bs), padded_size - sizeof(bs),
bytes + sizeof(bs))) {
if (!file_.Read(sizeof(bs), padded_size - sizeof(bs), bytes + sizeof(bs))) {
return __LINE__;
}
return blob_store_->CheckValidity(IO::FileSize(filename));
}
BlobReader::~BlobReader() {
if (fd_ >= 0) {
HWY_ASSERT(close(fd_) != -1);
}
return blob_store_->CheckValidity(file_.FileSize());
}
BlobError BlobReader::Enqueue(hwy::uint128_t key, void* data, size_t size) {
@ -391,13 +247,13 @@ BlobError BlobReader::Enqueue(hwy::uint128_t key, void* data, size_t size) {
// between consecutive runs.
// - memory-mapped I/O is less predictable and adds noise to measurements.
BlobError BlobReader::ReadAll(hwy::ThreadPool& pool) {
const int fd = fd_;
File* pfile = &file_; // not owned
const auto& requests = requests_;
std::atomic_flag err = ATOMIC_FLAG_INIT;
// >5x speedup from parallel reads when cached.
pool.Run(0, requests.size(),
[fd, &requests, &err](uint64_t i, size_t /*thread*/) {
if (!IO::Read(fd, requests[i].offset, requests[i].size,
[pfile, &requests, &err](uint64_t i, size_t /*thread*/) {
if (!pfile->Read(requests[i].offset, requests[i].size,
requests[i].data)) {
err.test_and_set();
}
@ -406,8 +262,7 @@ BlobError BlobReader::ReadAll(hwy::ThreadPool& pool) {
return 0;
}
BlobError BlobWriter::WriteAll(hwy::ThreadPool& pool,
const char* filename) const {
BlobError BlobWriter::WriteAll(hwy::ThreadPool& pool, const char* filename) {
HWY_ASSERT(keys_.size() == blobs_.size());
// Concatenate blobs in memory.
@ -418,26 +273,18 @@ BlobError BlobWriter::WriteAll(hwy::ThreadPool& pool,
keys_.data(), blobs_.data(), keys_.size(), bs.get());
// Create/replace existing file.
#if HWY_OS_WIN
DWORD flags = FILE_ATTRIBUTE_NORMAL;
HANDLE file = CreateFileA(filename, GENERIC_WRITE, 0, nullptr, CREATE_ALWAYS,
flags, nullptr);
if (file == INVALID_HANDLE_VALUE) return __LINE__;
const int fd = _open_osfhandle(reinterpret_cast<intptr_t>(file), _O_WRONLY);
#else
const int fd = open(filename, O_CREAT | O_RDWR | O_TRUNC, 0644);
#endif
if (fd < 0) return __LINE__;
File file;
if (!file.Open(filename, "w+")) return __LINE__;
File* pfile = &file; // not owned
std::atomic_flag err = ATOMIC_FLAG_INIT;
pool.Run(0, requests.size(),
[fd, &requests, &err](uint64_t i, size_t /*thread*/) {
if (!IO::Write(requests[i].data, requests[i].size,
requests[i].offset, fd)) {
[pfile, &requests, &err](uint64_t i, size_t /*thread*/) {
if (!pfile->Write(requests[i].data, requests[i].size,
requests[i].offset)) {
err.test_and_set();
}
});
HWY_ASSERT(close(fd) != -1);
if (err.test_and_set()) return __LINE__;
return 0;
}

View File

@ -21,6 +21,7 @@
#include <vector>
#include "compression/io.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h" // hwy::uint128_t
#include "hwy/contrib/thread_pool/thread_pool.h"
@ -59,7 +60,7 @@ struct BlobIO {
class BlobReader {
public:
BlobReader() { requests_.reserve(500); }
~BlobReader();
~BlobReader() = default;
// Opens `filename` and reads its header.
BlobError Open(const char* filename);
@ -73,7 +74,7 @@ class BlobReader {
private:
BlobStorePtr blob_store_; // holds header, not the entire file
std::vector<BlobIO> requests_;
int fd_ = 0;
File file_;
};
class BlobWriter {
@ -84,7 +85,7 @@ class BlobWriter {
}
// Stores all blobs to disk in the given order with padding for alignment.
BlobError WriteAll(hwy::ThreadPool& pool, const char* filename) const;
BlobError WriteAll(hwy::ThreadPool& pool, const char* filename);
private:
std::vector<hwy::uint128_t> keys_;

188
compression/io.cc Normal file
View File

@ -0,0 +1,188 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Safe to be first, does not include POSIX headers.
#include "compression/io.h"
// 1.5x slowdown vs. POSIX (200 ms longer startup), hence opt-in.
#ifdef GEMMA_IO_GOOGLE
#include "compression/io_google.cc"
#else
// Request POSIX 2008, including `pread()` and `posix_fadvise()`.
#if !defined(_XOPEN_SOURCE) || _XOPEN_SOURCE < 700
#undef _XOPEN_SOURCE
#define _XOPEN_SOURCE 700
#endif
#if !defined(_POSIX_C_SOURCE) || _POSIX_C_SOURCE < 200809
#define _POSIX_C_SOURCE 200809
#endif
// Make `off_t` 64-bit even on 32-bit systems. Works for Android >= r15c.
#undef _FILE_OFFSET_BITS
#define _FILE_OFFSET_BITS 64
#include <fcntl.h> // open
#include <stddef.h>
#include <stdint.h>
#include <stdio.h> // SEEK_END - unistd isn't enough for IDE.
#include <sys/stat.h> // O_RDONLY
#include "hwy/base.h" // HWY_ASSERT
#include "hwy/detect_compiler_arch.h"
#if HWY_OS_WIN
#include <fileapi.h>
#include <io.h> // read, write, close
#else
#include <unistd.h> // read, write, close
#endif
namespace gcpp {
// Emulate missing POSIX functions.
#if HWY_OS_WIN
namespace {
static inline int open(const char* filename, int flags, int mode = 0) {
const bool is_read = (flags & _O_RDONLY) != 0;
const DWORD win_flags =
FILE_ATTRIBUTE_NORMAL | (is_read ? FILE_FLAG_SEQUENTIAL_SCAN : 0);
const DWORD access = is_read ? GENERIC_READ : GENERIC_WRITE;
const DWORD share = is_read ? FILE_SHARE_READ : 0;
const DWORD create = is_read ? OPEN_EXISTING : CREATE_ALWAYS;
const HANDLE file =
CreateFileA(filename, access, share, nullptr, create, win_flags, nullptr);
if (file == INVALID_HANDLE_VALUE) return -1;
return _open_osfhandle(reinterpret_cast<intptr_t>(file), flags);
}
static inline off_t lseek(int fd, off_t offset, int whence) {
return _lseeki64(fd, offset, whence);
}
static inline int64_t pread(int fd, void* buf, uint64_t size, uint64_t offset) {
HANDLE file = reinterpret_cast<HANDLE>(_get_osfhandle(fd));
if (file == INVALID_HANDLE_VALUE) {
return -1;
}
OVERLAPPED overlapped = {0};
overlapped.Offset = offset & 0xFFFFFFFF;
overlapped.OffsetHigh = (offset >> 32) & 0xFFFFFFFF;
DWORD bytes_read;
if (!ReadFile(file, buf, size, &bytes_read, &overlapped)) {
if (GetLastError() != ERROR_HANDLE_EOF) {
return -1;
}
}
return bytes_read;
}
static inline int64_t pwrite(int fd, const void* buf, uint64_t size,
uint64_t offset) {
HANDLE file = reinterpret_cast<HANDLE>(_get_osfhandle(fd));
if (file == INVALID_HANDLE_VALUE) {
return -1;
}
OVERLAPPED overlapped = {0};
overlapped.Offset = offset & 0xFFFFFFFF;
overlapped.OffsetHigh = (offset >> 32) & 0xFFFFFFFF;
DWORD bytes_written;
if (!WriteFile(file, buf, size, &bytes_written, &overlapped)) {
if (GetLastError() != ERROR_HANDLE_EOF) {
return -1;
}
}
return bytes_written;
}
} // namespace
#endif // HWY_OS_WIN
bool File::Open(const char* filename, const char* mode) {
const bool is_read = mode[0] != 'w';
const int flags = is_read ? O_RDONLY : O_CREAT | O_RDWR | O_TRUNC;
int fd = open(filename, flags, 0644);
if (fd < 0) {
p_ = 0;
return false;
}
if (is_read) {
#if HWY_OS_LINUX && (!defined(__ANDROID_API__) || __ANDROID_API__ >= 21)
// Doubles the readahead window, which seems slightly faster when cached.
(void)posix_fadvise(fd, 0, 0, POSIX_FADV_SEQUENTIAL);
#endif
}
p_ = static_cast<intptr_t>(fd);
return true;
}
void File::Close() {
const int fd = static_cast<int>(p_);
if (fd > 0) {
HWY_ASSERT(close(fd) != -1);
p_ = 0;
}
}
uint64_t File::FileSize() const {
static_assert(sizeof(off_t) == 8, "64-bit off_t required");
const int fd = static_cast<int>(p_);
const off_t size = lseek(fd, 0, SEEK_END);
if (size < 0) {
return 0;
}
return static_cast<uint64_t>(size);
}
bool File::Read(uint64_t offset, uint64_t size, void* to) const {
const int fd = static_cast<int>(p_);
uint8_t* bytes = reinterpret_cast<uint8_t*>(to);
uint64_t pos = 0;
for (;;) {
// pread seems to be faster than lseek + read when parallelized.
const auto bytes_read = pread(fd, bytes + pos, size - pos, offset + pos);
if (bytes_read <= 0) break;
pos += bytes_read;
HWY_ASSERT(pos <= size);
if (pos == size) break;
}
return pos == size; // success if managed to read desired size
}
bool File::Write(const void* from, uint64_t size, uint64_t offset) {
const int fd = static_cast<int>(p_);
const uint8_t* bytes = reinterpret_cast<const uint8_t*>(from);
uint64_t pos = 0;
for (;;) {
const auto bytes_written =
pwrite(fd, bytes + pos, size - pos, offset + pos);
if (bytes_written <= 0) break;
pos += bytes_written;
HWY_ASSERT(pos <= size);
if (pos == size) break;
}
return pos == size; // success if managed to write desired size
}
} // namespace gcpp
#endif // GEMMA_IO_GOOGLE

52
compression/io.h Normal file
View File

@ -0,0 +1,52 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_IO_H_
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_IO_H_
#include <stdint.h>
namespace gcpp {
// unique_ptr-like interface with RAII, but not (yet) moveable.
class File {
public:
File() = default;
~File() { Close(); }
File(const File& other) = delete;
const File& operator=(const File& other) = delete;
// Returns false on failure. `mode` is either "r" or "w+".
bool Open(const char* filename, const char* mode);
// No effect if `Open` returned false or `Close` already called.
void Close();
// Returns size in bytes or 0.
uint64_t FileSize() const;
// Returns true if all the requested bytes were read.
bool Read(uint64_t offset, uint64_t size, void* to) const;
// Returns true if all the requested bytes were written.
bool Write(const void* from, uint64_t size, uint64_t offset);
private:
intptr_t p_ = 0;
};
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_IO_H_

View File

@ -1,137 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <fstream>
#include <iostream>
#include <random>
#include <sstream>
#include <string>
#include <vector>
#include "third_party/benchmark/include/benchmark/benchmark.h"
#include "gemma/gemma.h"
#include "util/app.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
gcpp::LoaderArgs* loader = nullptr;
gcpp::InferenceArgs* inference = nullptr;
gcpp::Gemma* model = nullptr;
hwy::ThreadPool* pool = nullptr;
hwy::ThreadPool* inner_pool = nullptr;
void run_gemma_prompt(const std::string& prompt_string,
benchmark::State& state) {
std::mt19937 gen;
std::vector<int> prompt;
if (prompt_string.empty()) return;
HWY_ASSERT(model->Tokenizer().Encode(prompt_string, &prompt).ok());
int token_counter = 0;
auto stream_token = [&token_counter](int, float) {
token_counter++;
return true;
};
for (auto s : state) {
GenerateGemma(
*model, *inference, prompt, /*start_token=*/0, *pool, *inner_pool,
stream_token,
/*accept=*/[](int) { return true; }, gen, /*verbosity=*/0);
}
state.SetItemsProcessed(token_counter);
}
static void BM_short_prompt(benchmark::State& state) {
run_gemma_prompt("What is the capital of Spain?<ctrl23> ", state);
}
static void BM_factuality_prompt(benchmark::State& state) {
run_gemma_prompt("How does an inkjet printer work?<ctrl23> ", state);
}
static void BM_creative_prompt(benchmark::State& state) {
run_gemma_prompt(
"Tell me a story about a magical bunny and their TRS-80.<ctrl23> ",
state);
}
static void BM_coding_prompt(benchmark::State& state) {
run_gemma_prompt(
"Write a python program to generate a fibonacci sequence.<ctrl23> ",
state);
}
static void BM_long_coding_prompt(benchmark::State& state) {
std::ifstream t("benchmarks.cc", std::ios_base::in);
std::stringstream buffer;
buffer << t.rdbuf();
std::string prompt_string = buffer.str();
t.close();
run_gemma_prompt("Make improvements to the following code:\n " +
prompt_string + "<ctrl23> ",
state);
}
int main(int argc, char** argv) {
loader = new gcpp::LoaderArgs(argc, argv);
inference = new gcpp::InferenceArgs(argc, argv);
gcpp::AppArgs app(argc, argv);
pool = new ::hwy::ThreadPool(app.num_threads);
inner_pool = new ::hwy::ThreadPool(0);
model = new gcpp::Gemma(*loader, *pool);
inference->max_tokens = 128;
BENCHMARK(BM_short_prompt)
->Iterations(3)
->Unit(benchmark::kMillisecond)
->UseRealTime();
inference->max_tokens = 256;
BENCHMARK(BM_factuality_prompt)
->Iterations(3)
->Unit(benchmark::kMillisecond)
->UseRealTime();
BENCHMARK(BM_creative_prompt)
->Iterations(3)
->Unit(benchmark::kMillisecond)
->UseRealTime();
BENCHMARK(BM_coding_prompt)
->Iterations(3)
->Unit(benchmark::kMillisecond)
->UseRealTime();
inference->max_tokens = 1024;
BENCHMARK(BM_long_coding_prompt)
->Iterations(3)
->Unit(benchmark::kMillisecond)
->UseRealTime();
::benchmark ::RunSpecifiedBenchmarks();
::benchmark ::Shutdown();
delete loader;
delete inference;
delete model;
delete pool;
return 0;
}

View File

@ -55,7 +55,7 @@ struct Args : public ArgsBase<Args> {
return "Missing --compressed_weights flag, a file for the compressed "
"model.";
}
if (!weights.exists()) {
if (!weights.Exists()) {
return "Can't open file specified with --weights flag.";
}
return nullptr;

View File

@ -40,7 +40,6 @@
#include <algorithm>
#include <array>
#include <cmath>
#include <filesystem> // NOLINT
#include <iostream>
#include <memory>
#include <random>
@ -182,7 +181,7 @@ hwy::AlignedFreeUniquePtr<uint8_t[]> LoadWeights(
const Path& checkpoint, hwy::ThreadPool& pool,
bool scale_for_compression = false) {
PROFILER_ZONE("Startup.LoadWeights");
if (!std::filesystem::exists(checkpoint.path)) {
if (!checkpoint.Exists()) {
HWY_ABORT("The model weights file '%s' does not exist.",
checkpoint.path.c_str());
}
@ -1318,8 +1317,8 @@ void ForEachTensor(const Weights<TConfig>* weights,
template <class TConfig>
hwy::AlignedFreeUniquePtr<uint8_t[]> LoadCompressedWeights(
const Path& weights, hwy::ThreadPool& pool) {
PROFILER_ZONE("Startup.LoadCache");
if (!std::filesystem::exists(weights.path)) {
PROFILER_ZONE("Startup.LoadCompressedWeights");
if (!weights.Exists()) {
HWY_ABORT("The model weights file '%s' does not exist.",
weights.path.c_str());
}
@ -1395,7 +1394,7 @@ template <class TConfig>
void CompressWeights(const Path& weights_path,
const Path& compressed_weights_path,
hwy::ThreadPool& pool) {
if (!std::filesystem::exists(weights_path.path)) {
if (!weights_path.Exists()) {
HWY_ABORT("The model weights file '%s' does not exist.",
weights_path.path.c_str());
}

View File

@ -1,223 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Command line text interface to gemma.
#include <stdio.h>
#include <random>
#include <sstream>
#include <string>
#include <vector>
#include "gemma/configs.h"
#include "gemma/gemma.h"
#include "util/app.h"
#include "util/args.h" // ArgsBase
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/highway.h"
#include "hwy/profiler.h"
#include "third_party/riegeli/bytes/file_reader.h"
#include "third_party/riegeli/bytes/file_writer.h"
#include "third_party/riegeli/csv/csv_reader.h"
#include "third_party/riegeli/csv/csv_writer.h"
namespace gcpp {
struct CsvArgs : public ArgsBase<CsvArgs> {
CsvArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
Path input_csv;
Path output_csv;
int prompt_column;
template <class Visitor>
void ForEach(const Visitor& visitor) {
visitor(input_csv, "input_csv", Path(),
"When set, prompts will be read from this CSV.");
visitor(output_csv, "output_csv", Path("/tmp/output.csv"),
"When --input_csv is set, prompts will be written to this CSV.");
visitor(prompt_column, "prompt_column", 0, "Prompt column index");
};
};
void FileGemma(gcpp::Gemma& model, InferenceArgs& inference, AppArgs& app,
CsvArgs& csv, hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
const gcpp::AcceptFunc& accept_token) {
int abs_pos = 0; // absolute token index over all turns
int current_pos = 0; // token index within the current turn
int prompt_size{};
std::mt19937 gen;
if (inference.deterministic) {
gen.seed(42);
} else {
std::random_device rd;
gen.seed(rd());
}
std::stringstream response_stream;
// callback function invoked for each generated token.
auto stream_token = [&inference, &abs_pos, &current_pos, &gen, &prompt_size,
tokenizer = &model.Tokenizer(),
&response_stream](int token, float) {
++abs_pos;
++current_pos;
if (current_pos < prompt_size) {
// pass
} else if (token == gcpp::EOS_ID) {
if (!inference.multiturn) {
abs_pos = 0;
if (inference.deterministic) {
gen.seed(42);
}
}
// end of stream
} else {
std::string token_text;
HWY_ASSERT(tokenizer->Decode({token}, &token_text).ok());
// +1 since position is incremented above
if (current_pos == prompt_size + 1) {
// first token of response
token_text.erase(0, token_text.find_first_not_of(" \t\n"));
}
if (token_text != "\n")
response_stream << token_text;
else
response_stream << "\\n";
}
return true;
};
riegeli::CsvReader csv_reader(
riegeli::FileReader(csv.input_csv.path),
riegeli::CsvReaderBase::Options().set_comment('#').set_recovery(
[](absl::Status status, riegeli::CsvReaderBase& csv_reader) {
fprintf(stderr, "Invalid entry: %s", status.message().data());
return true;
}));
riegeli::CsvWriter csv_writer(
riegeli::FileWriter(csv.output_csv.path),
riegeli::CsvWriterBase::Options().set_header({"prompt", "response"}));
if (!csv_reader.ok()) {
HWY_ABORT("Invalid input CSV path %s", csv.input_csv.path.c_str());
}
if (!csv_writer.ok()) {
HWY_ABORT("Invalid output CSV path %s", csv.output_csv.path.c_str());
}
while (abs_pos < inference.max_tokens) {
std::string prompt_string;
std::vector<int> prompt;
current_pos = 0;
std::vector<std::string> record;
csv_reader.ReadRecord(record);
if (record.empty()) {
break;
}
prompt_string = record[csv.prompt_column];
fprintf(stdout, "Prompt: %s\n", prompt_string.c_str());
prompt_string =
"<ctrl99>user\n" + prompt_string + "<ctrl100>\n<ctrl99>model\n";
if (abs_pos > 0) {
// multi-turn dialogue continuation.
prompt_string = "<ctrl100>\n" + prompt_string;
} else {
HWY_DASSERT(abs_pos == 0);
if (gcpp::kSystemPrompt) {
prompt_string =
"<ctrl99>system\nYou are a large language model built by "
"Google.<ctrl100>\n" +
prompt_string;
}
}
HWY_ASSERT(model.Tokenizer().Encode(prompt_string, &prompt).ok());
prompt_size = prompt.size();
// generate prompt
GenerateGemma(model, inference, prompt, abs_pos, pool, inner_pool,
stream_token, accept_token, gen, app.verbosity);
std::string response_string = response_stream.str();
if (!csv_writer.WriteRecord({record[csv.prompt_column], response_string})) {
fprintf(stderr, "Failed to write CSV: %s\n",
csv_writer.status().message().data());
}
response_stream.str(std::string()); // reset stream
response_stream.clear();
abs_pos = 0;
}
if (!csv_reader.Close()) {
fprintf(stderr, "Failed to close the CSV reader\n");
}
if (!csv_writer.Close()) {
fprintf(stderr, "Failed to close the CSV writer\n");
}
}
void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app,
CsvArgs& csv) {
PROFILER_ZONE("Run.misc");
hwy::ThreadPool inner_pool(0);
hwy::ThreadPool pool(app.num_threads);
// For many-core, pinning threads to cores helps.
if (app.num_threads > 10) {
pool.Run(0, pool.NumThreads(),
[](uint64_t /*task*/, size_t thread) { PinThreadToCore(thread); });
}
gcpp::Gemma model(loader.tokenizer, loader.compressed_weights,
loader.ModelType(), loader.ModelTraining(), pool);
if (csv.input_csv.path.empty()) {
HWY_ABORT("Need to specify csv file.");
}
FileGemma(model, inference, app, csv, pool, inner_pool,
[](int) { return true; });
}
} // namespace gcpp
int main(int argc, char** argv) {
{
PROFILER_ZONE("Startup.misc");
gcpp::LoaderArgs loader(argc, argv);
gcpp::InferenceArgs inference(argc, argv);
gcpp::AppArgs app(argc, argv);
gcpp::CsvArgs csv(argc, argv);
if (const char* error = loader.Validate()) {
loader.Help();
HWY_ABORT("Invalid args: %s", error);
}
gcpp::Run(loader, inference, app, csv);
}
PROFILER_PRINT_RESULTS(); // Must call outside the zone above.
return 0;
}

View File

@ -133,7 +133,7 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
if (tokenizer.path.empty()) {
return "Missing --tokenizer flag, a file for the tokenizer is required.";
}
if (!tokenizer.exists()) {
if (!tokenizer.Exists()) {
return "Can't open file specified with --tokenizer flag.";
}
if (!compressed_weights.path.empty()) {
@ -148,7 +148,7 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
if (weights.path.empty()) {
return "Missing --weights flag, a file for the model weights.";
}
if (!weights.exists()) {
if (!weights.Exists()) {
return "Can't open file specified with --weights flag.";
}
return nullptr;

View File

@ -23,16 +23,9 @@
#include <algorithm> // std::transform
#include <string>
#include "compression/io.h"
#include "hwy/base.h" // HWY_ABORT
#if defined(_WIN32)
#include <io.h>
#define F_OK 0
#define access _access
#else
#include <unistd.h>
#endif
namespace gcpp {
// Wrapper for strings representing a path name. Differentiates vs. arbitrary
@ -57,9 +50,10 @@ struct Path {
return path;
}
// Beware, TOCTOU.
bool exists() const {
return (access(path.c_str(), F_OK) == 0);
// Returns whether the file existed when this was called.
bool Exists() const {
File file;
return file.Open(path.c_str(), "r");
}
std::string path;