Add mmap support (not yet used)

Also: const-correct ArgsBase,
add assert to mat.h checking element_bytes_
BUILD deps update (:shared provides shared.h, not :sfp)
PiperOrigin-RevId: 746073312
This commit is contained in:
Jan Wassenberg 2025-04-10 10:02:58 -07:00 committed by Copybara-Service
parent 8532da47f7
commit 2e722f14f1
11 changed files with 108 additions and 25 deletions

View File

@ -138,7 +138,7 @@ cc_library(
deps = [ deps = [
":basics", ":basics",
"//compression:fields", "//compression:fields",
"//compression:sfp", "//compression:shared",
"@highway//:hwy", # base.h "@highway//:hwy", # base.h
], ],
) )
@ -159,10 +159,11 @@ cc_test(
deps = [ deps = [
":basics", ":basics",
":common", ":common",
":mat",
":weights", ":weights",
"@googletest//:gtest_main", # buildcleaner: keep "@googletest//:gtest_main", # buildcleaner: keep
"//compression:compress", "//compression:compress",
"@highway//:hwy", "@highway//:hwy", # aligned_allocator.h
], ],
) )
@ -176,7 +177,7 @@ cc_library(
":common", ":common",
":threading_context", ":threading_context",
"//compression:fields", "//compression:fields",
"//compression:sfp", "//compression:shared",
"@highway//:hwy", "@highway//:hwy",
"@highway//:profiler", "@highway//:profiler",
], ],
@ -348,7 +349,7 @@ cc_library(
":mat", ":mat",
"//compression:blob_store", "//compression:blob_store",
"//compression:compress", "//compression:compress",
"//compression:io", "//compression:io", # Path
"@highway//:hwy", "@highway//:hwy",
"@highway//:profiler", "@highway//:profiler",
"@highway//:stats", "@highway//:stats",
@ -362,8 +363,8 @@ cc_library(
hdrs = ["gemma/tokenizer.h"], hdrs = ["gemma/tokenizer.h"],
deps = [ deps = [
":common", ":common",
"//compression:io", "//compression:io", # Path
"//compression:sfp", "//compression:shared",
"@highway//:hwy", "@highway//:hwy",
"@highway//:profiler", "@highway//:profiler",
"@com_google_sentencepiece//:sentencepiece_processor", "@com_google_sentencepiece//:sentencepiece_processor",
@ -405,16 +406,17 @@ cc_library(
":allocator", ":allocator",
":basics", ":basics",
":common", ":common",
":ops",
":mat",
":tokenizer",
":kv_cache", ":kv_cache",
":weights", ":mat",
":ops",
":tokenizer",
":threading", ":threading",
":threading_context", ":threading_context",
":weights",
# Placeholder for internal dep, do not remove., # Placeholder for internal dep, do not remove.,
"//compression:blob_store",
"//compression:io", "//compression:io",
"//compression:sfp", "//compression:shared",
"//paligemma:image", "//paligemma:image",
"@highway//:hwy", "@highway//:hwy",
"@highway//:nanobenchmark", # timer "@highway//:nanobenchmark", # timer
@ -445,7 +447,7 @@ cc_library(
":gemma_lib", ":gemma_lib",
":ops", ":ops",
"//compression:io", "//compression:io",
"//compression:sfp", "//compression:shared",
"@highway//:hwy", "@highway//:hwy",
], ],
) )
@ -517,7 +519,7 @@ cc_binary(
":gemma_lib", ":gemma_lib",
":ops", ":ops",
":threading_context", ":threading_context",
"//compression:sfp", "//compression:shared",
"//paligemma:image", "//paligemma:image",
"@highway//:hwy", "@highway//:hwy",
"@highway//:profiler", "@highway//:profiler",
@ -706,6 +708,7 @@ cc_library(
":mat", ":mat",
":weights", ":weights",
"//compression:compress", "//compression:compress",
"//compression:shared",
"@highway//:hwy", "@highway//:hwy",
"@highway//:thread_pool", "@highway//:thread_pool",
], ],
@ -731,7 +734,7 @@ cc_test(
":threading", ":threading",
":weights", ":weights",
"@googletest//:gtest_main", # buildcleaner: keep "@googletest//:gtest_main", # buildcleaner: keep
"//compression:sfp", "//compression:shared",
"@highway//:thread_pool", "@highway//:thread_pool",
], ],
) )

View File

@ -40,6 +40,7 @@ cc_library(
"//conditions:default": [], "//conditions:default": [],
}), }),
deps = [ deps = [
"//:allocator",
"@highway//:hwy", "@highway//:hwy",
] + FILE_DEPS, ] + FILE_DEPS,
) )
@ -69,6 +70,7 @@ cc_library(
hdrs = ["blob_store.h"], hdrs = ["blob_store.h"],
deps = [ deps = [
":io", ":io",
"//:threading_context",
"@highway//:hwy", "@highway//:hwy",
"@highway//:thread_pool", "@highway//:thread_pool",
], ],
@ -81,7 +83,7 @@ cc_test(
":blob_store", ":blob_store",
":io", ":io",
"@googletest//:gtest_main", # buildcleaner: keep "@googletest//:gtest_main", # buildcleaner: keep
"@highway//:hwy", "//:threading_context",
"@highway//:hwy_test_util", "@highway//:hwy_test_util",
"@highway//:thread_pool", "@highway//:thread_pool",
], ],
@ -115,21 +117,30 @@ cc_test(
) )
cc_library( cc_library(
name = "sfp", name = "shared",
hdrs = ["shared.h"], hdrs = ["shared.h"],
textual_hdrs = ["sfp-inl.h"],
deps = [ deps = [
"//:basics", "//:basics",
"@highway//:hwy", "@highway//:hwy",
], ],
) )
cc_library(
name = "sfp",
textual_hdrs = ["sfp-inl.h"],
deps = [
":shared",
"//:basics",
"@highway//:hwy",
],
)
cc_library( cc_library(
name = "nuq", name = "nuq",
hdrs = ["shared.h"],
textual_hdrs = ["nuq-inl.h"], textual_hdrs = ["nuq-inl.h"],
deps = [ deps = [
":sfp", ":sfp",
":shared",
"//:basics", "//:basics",
"@highway//:hwy", "@highway//:hwy",
"@highway//hwy/contrib/sort:vqsort", "@highway//hwy/contrib/sort:vqsort",
@ -144,6 +155,7 @@ cc_library(
deps = [ deps = [
":compress", ":compress",
":distortion", ":distortion",
"//:mat",
"@highway//:hwy", "@highway//:hwy",
"@highway//:hwy_test_util", "@highway//:hwy_test_util",
"@highway//:thread_pool", "@highway//:thread_pool",
@ -254,6 +266,16 @@ cc_library(
], ],
) )
cc_library(
name = "io_win",
srcs = ["io_win.cc"],
deps = [
":io",
"//:allocator",
"@highway//:hwy",
],
)
cc_binary( cc_binary(
name = "blob_compare", name = "blob_compare",
srcs = ["blob_compare.cc"], srcs = ["blob_compare.cc"],

View File

@ -87,7 +87,7 @@ class PrintVisitor : public VisitorBase {
} }
void operator()(uint64_t& value) override { void operator()(uint64_t& value) override {
fprintf(stderr, "%sU64 %zu\n", indent_.c_str(), value); fprintf(stderr, "%sU64 %zu\n", indent_.c_str(), static_cast<size_t>(value));
} }
void operator()(float& value) override { void operator()(float& value) override {

View File

@ -36,12 +36,16 @@
#include <stddef.h> #include <stddef.h>
#include <stdint.h> #include <stdint.h>
#include <stdio.h> // SEEK_END - unistd isn't enough for IDE. #include <stdio.h> // SEEK_END - unistd isn't enough for IDE.
#include <sys/types.h>
// Old OSX may require sys/types.h before sys/mman.h.
#include <sys/mman.h> // mmap
#include <sys/stat.h> // O_RDONLY #include <sys/stat.h> // O_RDONLY
#include <unistd.h> // read, write, close #include <unistd.h> // read, write, close
#include <memory> #include <memory>
#include "compression/io.h" #include "compression/io.h"
#include "util/allocator.h"
#include "hwy/base.h" // HWY_ASSERT #include "hwy/base.h" // HWY_ASSERT
namespace gcpp { namespace gcpp {
@ -93,6 +97,28 @@ class FilePosix : public File {
} }
return pos == size; // success if managed to write desired size return pos == size; // success if managed to write desired size
} }
MapPtr Map() override {
const size_t mapping_size = FileSize();
// No `MAP_POPULATE` because we do not want to wait for I/O, and
// `MAP_NONBLOCK` is not guaranteed. `MAP_HUGETLB` fails. `MAP_SHARED` is
// more efficient than `MAP_PRIVATE`; the main difference is that the former
// will eventually see subsequent changes to the file.
const int flags = MAP_SHARED;
void* mapping =
mmap(nullptr, mapping_size, PROT_READ, flags, fd_, /*offset=*/0);
if (mapping == MAP_FAILED) return MapPtr();
#ifdef MADV_WILLNEED // Missing on some OSX.
// (Maybe) initiate readahead.
madvise(mapping, mapping_size, MADV_WILLNEED);
#endif
return MapPtr(static_cast<const uint8_t*>(mapping),
DeleterFunc2([mapping_size](void* ptr) {
HWY_ASSERT(munmap(ptr, mapping_size) == 0);
}));
}
}; // FilePosix }; // FilePosix
HWY_MAYBE_UNUSED extern std::unique_ptr<File> OpenFileGoogle( HWY_MAYBE_UNUSED extern std::unique_ptr<File> OpenFileGoogle(

View File

@ -23,6 +23,7 @@
#include <string> #include <string>
#include <utility> // std::move #include <utility> // std::move
#include "util/allocator.h"
#include "hwy/base.h" #include "hwy/base.h"
namespace gcpp { namespace gcpp {
@ -32,6 +33,8 @@ namespace gcpp {
// prefer to define Exists inline because there are multiple io*.cc files. // prefer to define Exists inline because there are multiple io*.cc files.
struct Path; struct Path;
using MapPtr = AlignedPtr2<const uint8_t[]>;
// Abstract base class enables multiple I/O backends in the same binary. // Abstract base class enables multiple I/O backends in the same binary.
class File { class File {
public: public:
@ -50,6 +53,12 @@ class File {
// Returns true if all the requested bytes were written. // Returns true if all the requested bytes were written.
virtual bool Write(const void* from, uint64_t size, uint64_t offset) = 0; virtual bool Write(const void* from, uint64_t size, uint64_t offset) = 0;
// Maps the entire file into read-only memory or returns nullptr on failure.
// We do not support offsets because Windows requires them to be a multiple of
// the allocation granularity, which is 64 KiB. Some implementations may fail
// if the file is zero-sized and return a nullptr.
virtual MapPtr Map() = 0;
}; };
// Returns nullptr on failure. `mode` is either "r" or "w+". This is not just // Returns nullptr on failure. `mode` is either "r" or "w+". This is not just
@ -87,6 +96,7 @@ struct Path {
std::string path; std::string path;
}; };
// Aborts on error.
static inline HWY_MAYBE_UNUSED std::string ReadFileToString(const Path& path) { static inline HWY_MAYBE_UNUSED std::string ReadFileToString(const Path& path) {
std::unique_ptr<File> file = OpenFileOrNull(path, "r"); std::unique_ptr<File> file = OpenFileOrNull(path, "r");
if (!file) { if (!file) {

View File

@ -22,6 +22,7 @@
#include <stdint.h> #include <stdint.h>
#include "compression/io.h" #include "compression/io.h"
#include "util/allocator.h"
#include "hwy/base.h" // HWY_ASSERT #include "hwy/base.h" // HWY_ASSERT
#ifndef WIN32_LEAN_AND_MEAN #ifndef WIN32_LEAN_AND_MEAN
#define WIN32_LEAN_AND_MEAN #define WIN32_LEAN_AND_MEAN
@ -96,6 +97,22 @@ class FileWin : public File {
} }
return true; // wrote everything => success return true; // wrote everything => success
} }
MapPtr Map() override {
if (hFile_ == INVALID_HANDLE_VALUE) return MapPtr();
// Size=0 means the entire file.
HANDLE hMapping =
CreateFileMappingA(hFile_, nullptr, PAGE_READONLY, 0, 0, nullptr);
// Offset zero and size=0 means the entire file.
void* ptr = MapViewOfFile(hMapping, FILE_MAP_READ, 0, 0, 0);
if (!ptr) return MapPtr();
return MapPtr(static_cast<const uint8_t*>(ptr),
DeleterFunc2([hMapping](void* ptr) {
HWY_ASSERT(UnmapViewOfFile(ptr));
HWY_ASSERT(CloseHandle(hMapping));
}));
}
}; // FileWin }; // FileWin
std::unique_ptr<File> OpenFileOrNull(const Path& filename, const char* mode) { std::unique_ptr<File> OpenFileOrNull(const Path& filename, const char* mode) {

View File

@ -32,7 +32,7 @@ pybind_extension(
deps = [ deps = [
":compression_clif_aux", ":compression_clif_aux",
"@abseil-cpp//absl/types:span", "@abseil-cpp//absl/types:span",
"//compression:sfp", "//compression:shared",
], ],
) )

View File

@ -43,7 +43,7 @@ cc_test(
"//:benchmark_helper", "//:benchmark_helper",
"//:common", "//:common",
"//:gemma_lib", "//:gemma_lib",
"//compression:sfp", "//compression:shared",
"@highway//:hwy", "@highway//:hwy",
"@highway//:hwy_test_util", "@highway//:hwy_test_util",
], ],

View File

@ -13,7 +13,7 @@ pybind_extension(
srcs = ["configs.cc"], srcs = ["configs.cc"],
deps = [ deps = [
"//:common", "//:common",
"//compression:sfp", "//compression:shared",
], ],
) )
@ -25,7 +25,7 @@ pybind_extension(
"//:benchmark_helper", "//:benchmark_helper",
"//:gemma_args", "//:gemma_args",
"//:gemma_lib", "//:gemma_lib",
"//compression:sfp", "//compression:shared",
"@highway//:hwy", "@highway//:hwy",
], ],
) )

View File

@ -181,6 +181,10 @@ class ArgsBase {
void ForEach(Visitor& visitor) { void ForEach(Visitor& visitor) {
static_cast<Args*>(this)->ForEach(visitor); static_cast<Args*>(this)->ForEach(visitor);
} }
template <class Visitor>
void ForEach(Visitor& visitor) const {
const_cast<ArgsBase*>(this)->ForEach(visitor);
}
public: public:
// WARNING: cannot call from ctor because the derived ctor has not yet run. // WARNING: cannot call from ctor because the derived ctor has not yet run.
@ -189,12 +193,12 @@ class ArgsBase {
ForEach(visitor); ForEach(visitor);
} }
void Help() { void Help() const {
HelpVisitor visitor; HelpVisitor visitor;
ForEach(visitor); ForEach(visitor);
} }
void Print(int verbosity = 0) { void Print(int verbosity = 0) const {
PrintVisitor visitor(verbosity); PrintVisitor visitor(verbosity);
ForEach(visitor); ForEach(visitor);
} }

View File

@ -112,6 +112,7 @@ class MatPtr : public IFields {
type_ = type; type_ = type;
element_bytes_ = static_cast<uint32_t>(hwy::DivCeil(TypeBits(type), 8)); element_bytes_ = static_cast<uint32_t>(hwy::DivCeil(TypeBits(type), 8));
num_elements_ = static_cast<uint32_t>(ComputeNumElements(type, Extents())); num_elements_ = static_cast<uint32_t>(ComputeNumElements(type, Extents()));
HWY_DASSERT(0 != element_bytes_ && element_bytes_ <= 16);
} }
bool IsEmpty() const { return rows_ == 0 || cols_ == 0; } bool IsEmpty() const { return rows_ == 0 || cols_ == 0; }