mirror of https://github.com/google/gemma.cpp.git
Merge branch 'dev' into feature/ISS-60/implement-self-extend
This commit is contained in:
commit
397952f918
13
BUILD.bazel
13
BUILD.bazel
|
|
@ -30,8 +30,11 @@ cc_library(
|
|||
|
||||
cc_library(
|
||||
name = "threading",
|
||||
srcs = ["util/threading.cc"],
|
||||
hdrs = ["util/threading.h"],
|
||||
deps = [
|
||||
":basics",
|
||||
# Placeholder for container detection, do not remove
|
||||
"@highway//:hwy",
|
||||
"@highway//:thread_pool",
|
||||
"@highway//:topology",
|
||||
|
|
@ -173,7 +176,9 @@ cc_test(
|
|||
tags = ["hwy_ops_test"],
|
||||
deps = [
|
||||
":allocator",
|
||||
":basics",
|
||||
":ops",
|
||||
":test_util",
|
||||
":threading",
|
||||
"@googletest//:gtest_main", # buildcleaner: keep
|
||||
"//compression:compress",
|
||||
|
|
@ -233,6 +238,7 @@ cc_library(
|
|||
deps = [
|
||||
":common",
|
||||
"//compression:io",
|
||||
"//compression:sfp",
|
||||
"@highway//:hwy",
|
||||
"@highway//:profiler",
|
||||
"@com_google_sentencepiece//:sentencepiece_processor",
|
||||
|
|
@ -279,6 +285,7 @@ cc_library(
|
|||
":kv_cache",
|
||||
":weights",
|
||||
":threading",
|
||||
"//compression:compress",
|
||||
"//compression:io",
|
||||
"//compression:sfp",
|
||||
"//paligemma:image",
|
||||
|
|
@ -306,6 +313,7 @@ cc_library(
|
|||
name = "args",
|
||||
hdrs = ["util/args.h"],
|
||||
deps = [
|
||||
":basics",
|
||||
"//compression:io",
|
||||
"@highway//:hwy",
|
||||
],
|
||||
|
|
@ -316,6 +324,7 @@ cc_library(
|
|||
hdrs = ["util/app.h"],
|
||||
deps = [
|
||||
":args",
|
||||
":basics",
|
||||
":common",
|
||||
":gemma_lib",
|
||||
":threading",
|
||||
|
|
@ -341,8 +350,6 @@ cc_library(
|
|||
"//compression:compress",
|
||||
"@highway//:hwy",
|
||||
"@highway//:nanobenchmark",
|
||||
"@highway//:thread_pool",
|
||||
"@highway//:topology",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
@ -376,6 +383,7 @@ cc_binary(
|
|||
":gemma_lib",
|
||||
":threading",
|
||||
# Placeholder for internal dep, do not remove.,
|
||||
"//compression:sfp",
|
||||
"//paligemma:image",
|
||||
"@highway//:hwy",
|
||||
"@highway//:profiler",
|
||||
|
|
@ -581,6 +589,7 @@ cc_test(
|
|||
},
|
||||
deps = [
|
||||
":backprop",
|
||||
":basics",
|
||||
":common",
|
||||
":gemma_lib",
|
||||
":optimizer",
|
||||
|
|
|
|||
|
|
@ -33,6 +33,9 @@ FetchContent_MakeAvailable(sentencepiece)
|
|||
FetchContent_Declare(json GIT_REPOSITORY https://github.com/nlohmann/json.git GIT_TAG 9cca280a4d0ccf0c08f47a99aa71d1b0e52f8d03 EXCLUDE_FROM_ALL)
|
||||
FetchContent_MakeAvailable(json)
|
||||
|
||||
set(BENCHMARK_ENABLE_TESTING OFF)
|
||||
set(BENCHMARK_ENABLE_GTEST_TESTS OFF)
|
||||
|
||||
FetchContent_Declare(benchmark GIT_REPOSITORY https://github.com/google/benchmark.git GIT_TAG v1.8.2 EXCLUDE_FROM_ALL)
|
||||
FetchContent_MakeAvailable(benchmark)
|
||||
|
||||
|
|
@ -42,6 +45,8 @@ set(SOURCES
|
|||
compression/compress.cc
|
||||
compression/compress.h
|
||||
compression/compress-inl.h
|
||||
compression/fields.cc
|
||||
compression/fields.h
|
||||
compression/io_win.cc
|
||||
compression/io.cc
|
||||
compression/io.h
|
||||
|
|
@ -96,6 +101,7 @@ set(SOURCES
|
|||
util/args.h
|
||||
util/basics.h
|
||||
util/test_util.h
|
||||
util/threading.cc
|
||||
util/threading.h
|
||||
)
|
||||
|
||||
|
|
@ -144,8 +150,10 @@ set(GEMMA_TEST_FILES
|
|||
backprop/backward_scalar_test.cc
|
||||
backprop/backward_test.cc
|
||||
backprop/optimize_test.cc
|
||||
compression/blob_store_test.cc
|
||||
compression/compress_test.cc
|
||||
compression/distortion_test.cc
|
||||
compression/fields_test.cc
|
||||
compression/nuq_test.cc
|
||||
compression/sfp_test.cc
|
||||
evals/gemma_test.cc
|
||||
|
|
|
|||
|
|
@ -33,13 +33,15 @@
|
|||
#include "gemma/configs.h"
|
||||
#include "gemma/gemma.h"
|
||||
#include "gemma/weights.h"
|
||||
#include "util/basics.h"
|
||||
#include "util/threading.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
TEST(OptimizeTest, GradientDescent) {
|
||||
NestedPools pools(1);
|
||||
NestedPools pools(1, /*pin=*/Tristate::kFalse, BoundedSlice(0, 1),
|
||||
BoundedSlice(0, 1));
|
||||
hwy::ThreadPool& pool = pools.Pool();
|
||||
std::mt19937 gen(42);
|
||||
|
||||
|
|
|
|||
|
|
@ -37,6 +37,26 @@ cc_library(
|
|||
] + FILE_DEPS,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "fields",
|
||||
srcs = ["fields.cc"],
|
||||
hdrs = ["fields.h"],
|
||||
deps = [
|
||||
"@highway//:hwy",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "fields_test",
|
||||
srcs = ["fields_test.cc"],
|
||||
deps = [
|
||||
":fields",
|
||||
"@googletest//:gtest_main", # buildcleaner: keep
|
||||
"@highway//:hwy",
|
||||
"@highway//:hwy_test_util",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "blob_store",
|
||||
srcs = ["blob_store.cc"],
|
||||
|
|
@ -48,6 +68,19 @@ cc_library(
|
|||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "blob_store_test",
|
||||
srcs = ["blob_store_test.cc"],
|
||||
deps = [
|
||||
":blob_store",
|
||||
":io",
|
||||
"@googletest//:gtest_main", # buildcleaner: keep
|
||||
"@highway//:hwy",
|
||||
"@highway//:hwy_test_util",
|
||||
"@highway//:thread_pool",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "distortion",
|
||||
hdrs = [
|
||||
|
|
@ -55,6 +88,7 @@ cc_library(
|
|||
"shared.h",
|
||||
],
|
||||
deps = [
|
||||
"//:basics",
|
||||
"@highway//:hwy",
|
||||
"@highway//:stats",
|
||||
"@highway//hwy/contrib/sort:vqsort",
|
||||
|
|
@ -79,6 +113,7 @@ cc_library(
|
|||
hdrs = ["shared.h"],
|
||||
textual_hdrs = ["sfp-inl.h"],
|
||||
deps = [
|
||||
"//:basics",
|
||||
"@highway//:hwy",
|
||||
],
|
||||
)
|
||||
|
|
@ -165,6 +200,7 @@ cc_library(
|
|||
":nuq",
|
||||
":sfp",
|
||||
"//:allocator",
|
||||
"//:basics",
|
||||
"@highway//:hwy",
|
||||
"@highway//:nanobenchmark",
|
||||
"@highway//:profiler",
|
||||
|
|
|
|||
|
|
@ -275,7 +275,8 @@ BlobError BlobReader::ReadAll(hwy::ThreadPool& pool) {
|
|||
[pfile, &requests, &err](uint64_t i, size_t /*thread*/) {
|
||||
if (!pfile->Read(requests[i].offset, requests[i].size,
|
||||
requests[i].data)) {
|
||||
fprintf(stderr, "Failed to read blob %zu\n", i);
|
||||
fprintf(stderr, "Failed to read blob %zu\n",
|
||||
static_cast<size_t>(i));
|
||||
err.test_and_set();
|
||||
}
|
||||
});
|
||||
|
|
@ -288,6 +289,14 @@ BlobError BlobReader::ReadOne(hwy::uint128_t key, void* data,
|
|||
uint64_t offset;
|
||||
size_t actual_size;
|
||||
if (!blob_store_->FindKey(key, offset, actual_size)) return __LINE__;
|
||||
if (actual_size != size) {
|
||||
fprintf(stderr,
|
||||
"Mismatch between expected %d and actual %d KiB size of blob %s. "
|
||||
"Please see README.md on how to update the weights.\n",
|
||||
static_cast<int>(size >> 10), static_cast<int>(actual_size >> 10),
|
||||
StringFromKey(key).c_str());
|
||||
return __LINE__;
|
||||
}
|
||||
if (!file_->Read(offset, actual_size, data)) {
|
||||
return __LINE__;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,81 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "compression/blob_store.h"
|
||||
|
||||
#include <stdio.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
|
||||
#include "compression/io.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
#include "hwy/tests/hwy_gtest.h"
|
||||
#include "hwy/tests/test_util-inl.h" // HWY_ASSERT_EQ
|
||||
|
||||
namespace gcpp {
|
||||
namespace {
|
||||
|
||||
#if !HWY_TEST_STANDALONE
|
||||
class BlobStoreTest : public testing::Test {};
|
||||
#endif
|
||||
|
||||
#if !HWY_OS_WIN
|
||||
TEST(BlobStoreTest, TestReadWrite) {
|
||||
static const std::array<float, 4> kOriginalData = {-1, 0, 3.14159, 2.71828};
|
||||
|
||||
// mkstemp will modify path_str so it holds a newly-created temporary file.
|
||||
char path_str[] = "/tmp/blob_store_test.sbs-XXXXXX";
|
||||
const int fd = mkstemp(path_str);
|
||||
HWY_ASSERT(fd > 0);
|
||||
|
||||
hwy::ThreadPool pool(4);
|
||||
const Path path(path_str);
|
||||
std::array<float, 4> buffer = kOriginalData;
|
||||
|
||||
const hwy::uint128_t keyA = MakeKey("0123456789abcdef");
|
||||
const hwy::uint128_t keyB = MakeKey("q");
|
||||
BlobWriter writer;
|
||||
writer.Add(keyA, "DATA", 5);
|
||||
writer.Add(keyB, buffer.data(), sizeof(buffer));
|
||||
HWY_ASSERT_EQ(writer.WriteAll(pool, path), 0);
|
||||
HWY_ASSERT_ARRAY_EQ(kOriginalData.data(), buffer.data(), buffer.size());
|
||||
|
||||
std::fill(buffer.begin(), buffer.end(), 0);
|
||||
BlobReader reader;
|
||||
HWY_ASSERT_EQ(reader.Open(path), 0);
|
||||
HWY_ASSERT_EQ(reader.BlobSize(keyA), 5);
|
||||
HWY_ASSERT_EQ(reader.BlobSize(keyB), sizeof(buffer));
|
||||
|
||||
HWY_ASSERT_EQ(reader.Enqueue(keyB, buffer.data(), sizeof(buffer)), 0);
|
||||
HWY_ASSERT_EQ(reader.ReadAll(pool), 0);
|
||||
HWY_ASSERT_ARRAY_EQ(kOriginalData.data(), buffer.data(), buffer.size());
|
||||
|
||||
{
|
||||
std::array<char, 5> buffer;
|
||||
HWY_ASSERT(reader.ReadOne(keyA, buffer.data(), 1) != 0);
|
||||
HWY_ASSERT_EQ(reader.ReadOne(keyA, buffer.data(), 5), 0);
|
||||
HWY_ASSERT_STRING_EQ("DATA", buffer.data());
|
||||
}
|
||||
|
||||
close(fd);
|
||||
unlink(path_str);
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace
|
||||
} // namespace gcpp
|
||||
|
||||
HWY_TEST_MAIN();
|
||||
|
|
@ -33,6 +33,7 @@
|
|||
#include "compression/blob_store.h"
|
||||
#include "compression/io.h"
|
||||
#include "compression/shared.h"
|
||||
#include "util/basics.h"
|
||||
// IWYU pragma: end_exports
|
||||
#include "util/allocator.h"
|
||||
#if COMPRESS_STATS
|
||||
|
|
@ -62,7 +63,9 @@ class MatPtr {
|
|||
num_elements_(rows * cols),
|
||||
rows_(rows),
|
||||
cols_(cols),
|
||||
ptr_(nullptr) {}
|
||||
ptr_(nullptr) {
|
||||
stride_ = cols;
|
||||
}
|
||||
// Default is to leave all fields default-initialized.
|
||||
MatPtr() = default;
|
||||
virtual ~MatPtr();
|
||||
|
|
@ -85,7 +88,9 @@ class MatPtr {
|
|||
element_size_(key2.hi),
|
||||
num_elements_(key2.lo),
|
||||
rows_(key3.lo),
|
||||
cols_(key3.hi) {}
|
||||
cols_(key3.hi) {
|
||||
stride_ = cols_;
|
||||
}
|
||||
|
||||
// Adds the contents entry to the table of contents.
|
||||
void AddToToc(std::vector<hwy::uint128_t>& toc) const {
|
||||
|
|
@ -137,6 +142,12 @@ class MatPtr {
|
|||
// Returns the number of columns in the 2-d array (inner dimension).
|
||||
size_t Cols() const { return cols_; }
|
||||
|
||||
Extents2D Extents() const { return Extents2D(rows_, cols_); }
|
||||
|
||||
// Currently same as cols, but may differ in the future. This is the offset by
|
||||
// which to advance pointers to the next row.
|
||||
size_t Stride() const { return stride_; }
|
||||
|
||||
// Decoded elements should be multiplied by this to restore their original
|
||||
// range. This is required because SfpStream can only encode a limited range
|
||||
// of magnitudes.
|
||||
|
|
@ -187,6 +198,8 @@ class MatPtr {
|
|||
// freed. The underlying memory is owned by a subclass or some external class
|
||||
// and must outlive this object.
|
||||
void* ptr_ = nullptr;
|
||||
|
||||
size_t stride_;
|
||||
};
|
||||
|
||||
// MatPtrT adds a single template argument to MatPtr for an explicit type.
|
||||
|
|
@ -288,7 +301,15 @@ decltype(auto) MatPtr::CallUpcasted(FuncT& func, TArgs&&... args) {
|
|||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
ConstMat<T> ConstMatFromWeights(const MatPtrT<T>& m, size_t ofs = 0) {
|
||||
ConstMat<T> mat = MakeConstMat(const_cast<T*>(m.data()), m.Extents(), ofs);
|
||||
mat.scale = m.scale();
|
||||
return mat;
|
||||
}
|
||||
|
||||
// MatStorageT adds the actual data storage to MatPtrT.
|
||||
// TODO: use Extents2D instead of rows and cols.
|
||||
template <typename MatT>
|
||||
class MatStorageT : public MatPtrT<MatT> {
|
||||
public:
|
||||
|
|
|
|||
|
|
@ -40,8 +40,9 @@
|
|||
#include <vector>
|
||||
|
||||
#include "compression/compress.h"
|
||||
#include "compression/io.h" // Path
|
||||
#include "gemma/common.h" // Model
|
||||
#include "compression/shared.h" // ModelTraining
|
||||
#include "compression/io.h" // Path
|
||||
#include "gemma/common.h" // Model
|
||||
#include "gemma/weights.h"
|
||||
#include "util/allocator.h"
|
||||
#include "util/args.h"
|
||||
|
|
@ -73,9 +74,8 @@ struct Args : public ArgsBase<Args> {
|
|||
|
||||
// Returns error string or nullptr if OK.
|
||||
const char* Validate() {
|
||||
ModelTraining model_training;
|
||||
if (const char* err = ParseModelTypeAndTraining(model_type_str, model_type_,
|
||||
model_training)) {
|
||||
model_training_)) {
|
||||
return err;
|
||||
}
|
||||
if (const char* err = ParseType(weight_type_str, weight_type_)) {
|
||||
|
|
@ -127,10 +127,12 @@ struct Args : public ArgsBase<Args> {
|
|||
|
||||
// Uninitialized before Validate, must call after that.
|
||||
gcpp::Model ModelType() const { return model_type_; }
|
||||
gcpp::ModelTraining ModelTrainingType() const { return model_training_; }
|
||||
gcpp::Type WeightType() const { return weight_type_; }
|
||||
|
||||
private:
|
||||
Model model_type_;
|
||||
ModelTraining model_training_;
|
||||
Type weight_type_;
|
||||
};
|
||||
|
||||
|
|
@ -210,10 +212,10 @@ namespace gcpp {
|
|||
|
||||
void Run(Args& args) {
|
||||
hwy::ThreadPool pool(args.num_threads);
|
||||
const Model model_type = args.ModelType();
|
||||
if (model_type == Model::PALIGEMMA_224) {
|
||||
if (args.ModelTrainingType() == ModelTraining::PALIGEMMA) {
|
||||
HWY_ABORT("PaliGemma is not supported in compress_weights.");
|
||||
}
|
||||
const Model model_type = args.ModelType();
|
||||
const Type weight_type = args.WeightType();
|
||||
switch (weight_type) {
|
||||
case Type::kF32:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,320 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "compression/fields.h"
|
||||
|
||||
#include <stdarg.h>
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
#include <string.h>
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/base.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
IFieldsVisitor::~IFieldsVisitor() = default;
|
||||
|
||||
void IFieldsVisitor::NotifyInvalid(const char* fmt, ...) {
|
||||
va_list args;
|
||||
va_start(args, fmt);
|
||||
vfprintf(stderr, fmt, args);
|
||||
va_end(args);
|
||||
|
||||
any_invalid_ = true;
|
||||
}
|
||||
|
||||
class VisitorBase : public IFieldsVisitor {
|
||||
public:
|
||||
VisitorBase() = default;
|
||||
~VisitorBase() override = default;
|
||||
|
||||
// This is the only call site of IFields::VisitFields.
|
||||
void operator()(IFields& fields) override { fields.VisitFields(*this); }
|
||||
|
||||
protected: // Functions shared between ReadVisitor and WriteVisitor:
|
||||
void CheckF32(float value) {
|
||||
if (HWY_UNLIKELY(hwy::ScalarIsInf(value) || hwy::ScalarIsNaN(value))) {
|
||||
NotifyInvalid("Invalid float %g\n", value);
|
||||
}
|
||||
}
|
||||
|
||||
// Return bool to avoid having to check AnyInvalid() after calling.
|
||||
bool CheckStringLength(uint32_t num_u32) {
|
||||
// Disallow long strings for safety, and to prevent them being used for
|
||||
// arbitrary data (we also require them to be ASCII).
|
||||
if (HWY_UNLIKELY(num_u32 > 64)) {
|
||||
NotifyInvalid("String num_u32=%u too large\n", num_u32);
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool CheckStringU32(uint32_t u32, uint32_t i, uint32_t num_u32) {
|
||||
// Although strings are zero-padded to u32, an entire u32 should not be
|
||||
// zero, and upper bits should not be set (ASCII-only).
|
||||
if (HWY_UNLIKELY(u32 == 0 || (u32 & 0x80808080))) {
|
||||
NotifyInvalid("Invalid characters %x at %u of %u\n", u32, i, num_u32);
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
class PrintVisitor : public VisitorBase {
|
||||
public:
|
||||
void operator()(uint32_t& value) override {
|
||||
fprintf(stderr, "%sU32 %u\n", indent_.c_str(), value);
|
||||
}
|
||||
|
||||
void operator()(float& value) override {
|
||||
fprintf(stderr, "%sF32 %f\n", indent_.c_str(), value);
|
||||
}
|
||||
|
||||
void operator()(std::string& value) override {
|
||||
fprintf(stderr, "%sStr %s\n", indent_.c_str(), value.c_str());
|
||||
}
|
||||
|
||||
void operator()(IFields& fields) override {
|
||||
fprintf(stderr, "%s%s\n", indent_.c_str(), fields.Name());
|
||||
indent_ += " ";
|
||||
|
||||
VisitorBase::operator()(fields);
|
||||
|
||||
HWY_ASSERT(!indent_.empty());
|
||||
indent_.resize(indent_.size() - 2);
|
||||
}
|
||||
|
||||
private:
|
||||
std::string indent_;
|
||||
};
|
||||
|
||||
class ReadVisitor : public VisitorBase {
|
||||
public:
|
||||
ReadVisitor(const hwy::Span<const uint32_t>& span, size_t pos)
|
||||
: span_(span), result_(pos) {}
|
||||
~ReadVisitor() {
|
||||
HWY_ASSERT(end_.empty()); // Bug if push/pop are not balanced.
|
||||
}
|
||||
|
||||
// All data is read through this overload.
|
||||
void operator()(uint32_t& value) override {
|
||||
if (HWY_UNLIKELY(SkipField())) return;
|
||||
|
||||
value = span_[result_.pos++];
|
||||
}
|
||||
|
||||
void operator()(float& value) override {
|
||||
if (HWY_UNLIKELY(SkipField())) return;
|
||||
|
||||
uint32_t u32 = hwy::BitCastScalar<uint32_t>(value);
|
||||
operator()(u32);
|
||||
value = hwy::BitCastScalar<float>(u32);
|
||||
CheckF32(value);
|
||||
}
|
||||
|
||||
void operator()(std::string& value) override {
|
||||
if (HWY_UNLIKELY(SkipField())) return;
|
||||
|
||||
uint32_t num_u32; // not including itself because this..
|
||||
operator()(num_u32); // increments result_.pos for the num_u32 field
|
||||
if (HWY_UNLIKELY(!CheckStringLength(num_u32))) return;
|
||||
|
||||
// Ensure we have that much data.
|
||||
if (HWY_UNLIKELY(result_.pos + num_u32 > end_.back())) {
|
||||
NotifyInvalid("Invalid string: pos %zu + num_u32 %u > end %zu\n",
|
||||
result_.pos, num_u32, span_.size());
|
||||
return;
|
||||
}
|
||||
|
||||
constexpr size_t k4 = sizeof(uint32_t);
|
||||
value.resize(num_u32 * k4);
|
||||
for (uint32_t i = 0; i < num_u32; ++i) {
|
||||
uint32_t u32;
|
||||
operator()(u32);
|
||||
(void)CheckStringU32(u32, i, num_u32);
|
||||
hwy::CopyBytes(&u32, value.data() + i * k4, k4);
|
||||
}
|
||||
|
||||
// Trim 0..3 trailing nulls.
|
||||
const size_t pos = value.find_last_not_of('\0');
|
||||
if (pos != std::string::npos) {
|
||||
value.resize(pos + 1);
|
||||
}
|
||||
}
|
||||
|
||||
void operator()(IFields& fields) override {
|
||||
// Our SkipField requires end_ to be set before reading num_u32, which
|
||||
// determines the actual end, so use an upper bound which is tight if this
|
||||
// IFields is last one in span_.
|
||||
end_.push_back(span_.size());
|
||||
|
||||
if (HWY_UNLIKELY(SkipField())) {
|
||||
end_.pop_back(); // undo `push_back` to keep the stack balanced
|
||||
return;
|
||||
}
|
||||
|
||||
uint32_t num_u32; // not including itself because this..
|
||||
operator()(num_u32); // increments result_.pos for the num_u32 field
|
||||
|
||||
// Ensure we have that much data and set end_.
|
||||
if (HWY_UNLIKELY(result_.pos + num_u32 > span_.size())) {
|
||||
NotifyInvalid("Invalid IFields: pos %zu + num_u32 %u > size %zu\n",
|
||||
result_.pos, num_u32, span_.size());
|
||||
return;
|
||||
}
|
||||
end_.back() = result_.pos + num_u32;
|
||||
|
||||
VisitorBase::operator()(fields);
|
||||
|
||||
HWY_ASSERT(!end_.empty() && result_.pos <= end_.back());
|
||||
// Count extra, which indicates old code and new data.
|
||||
result_.extra_u32 += end_.back() - result_.pos;
|
||||
end_.pop_back();
|
||||
}
|
||||
|
||||
// Override because ReadVisitor also does bounds checking.
|
||||
bool SkipField() override {
|
||||
// If invalid, all bets are off and we don't count missing fields.
|
||||
if (HWY_UNLIKELY(AnyInvalid())) return true;
|
||||
|
||||
// Reaching the end of the stored size, or the span, is not invalid -
|
||||
// it happens when we read old data with new code.
|
||||
if (HWY_UNLIKELY(result_.pos >= end_.back())) {
|
||||
result_.missing_fields++;
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
// Override so that operator()(std::vector<T>&) resizes the vector.
|
||||
bool IsReading() const override { return true; }
|
||||
|
||||
IFields::ReadResult Result() {
|
||||
if (HWY_UNLIKELY(AnyInvalid())) result_.pos = 0;
|
||||
return result_;
|
||||
}
|
||||
|
||||
private:
|
||||
const hwy::Span<const uint32_t> span_;
|
||||
IFields::ReadResult result_;
|
||||
// Stack of end positions of nested IFields. Updated in operator()(IFields&),
|
||||
// but read in SkipField.
|
||||
std::vector<uint32_t> end_;
|
||||
};
|
||||
|
||||
class WriteVisitor : public VisitorBase {
|
||||
public:
|
||||
WriteVisitor(std::vector<uint32_t>& storage) : storage_(storage) {}
|
||||
|
||||
// Note: while writing, only string lengths/characters can trigger AnyInvalid,
|
||||
// so we don't have to check SkipField.
|
||||
|
||||
void operator()(uint32_t& value) override { storage_.push_back(value); }
|
||||
|
||||
void operator()(float& value) override {
|
||||
storage_.push_back(hwy::BitCastScalar<uint32_t>(value));
|
||||
CheckF32(value);
|
||||
}
|
||||
|
||||
void operator()(std::string& value) override {
|
||||
constexpr size_t k4 = sizeof(uint32_t);
|
||||
|
||||
// Write length.
|
||||
uint32_t num_u32 = hwy::DivCeil(value.size(), k4);
|
||||
if (HWY_UNLIKELY(!CheckStringLength(num_u32))) return;
|
||||
operator()(num_u32); // always valid
|
||||
|
||||
// Copy whole uint32_t.
|
||||
const size_t num_whole_u32 = value.size() / k4;
|
||||
for (uint32_t i = 0; i < num_whole_u32; ++i) {
|
||||
uint32_t u32 = 0;
|
||||
hwy::CopyBytes(value.data() + i * k4, &u32, k4);
|
||||
if (HWY_UNLIKELY(!CheckStringU32(u32, i, num_u32))) return;
|
||||
storage_.push_back(u32);
|
||||
}
|
||||
|
||||
// Read remaining bytes into least-significant bits of u32.
|
||||
const size_t remainder = value.size() - num_whole_u32 * k4;
|
||||
if (remainder != 0) {
|
||||
HWY_DASSERT(remainder < k4);
|
||||
uint32_t u32 = 0;
|
||||
for (size_t i = 0; i < remainder; ++i) {
|
||||
const char c = value[num_whole_u32 * k4 + i];
|
||||
const uint32_t next = static_cast<uint32_t>(static_cast<uint8_t>(c));
|
||||
u32 += next << (i * 8);
|
||||
}
|
||||
if (HWY_UNLIKELY(!CheckStringU32(u32, num_whole_u32, num_u32))) return;
|
||||
storage_.push_back(u32);
|
||||
}
|
||||
}
|
||||
|
||||
void operator()(IFields& fields) override {
|
||||
const size_t pos_before_size = storage_.size();
|
||||
storage_.push_back(0); // placeholder, updated below
|
||||
|
||||
VisitorBase::operator()(fields);
|
||||
|
||||
HWY_ASSERT(storage_[pos_before_size] == 0);
|
||||
// Number of u32 written, including the one storing that number.
|
||||
const uint32_t num_u32 = storage_.size() - pos_before_size;
|
||||
HWY_ASSERT(num_u32 != 0); // at least one due to push_back above
|
||||
// Store the payload size, not including the num_u32 field itself, because
|
||||
// that is more convenient for ReadVisitor.
|
||||
storage_[pos_before_size] = num_u32 - 1;
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<uint32_t>& storage_;
|
||||
};
|
||||
|
||||
IFields::~IFields() = default;
|
||||
|
||||
void IFields::Print() const {
|
||||
PrintVisitor visitor;
|
||||
// VisitFields is non-const. It is safe to cast because PrintVisitor does not
|
||||
// modify the fields.
|
||||
visitor(*const_cast<IFields*>(this));
|
||||
}
|
||||
|
||||
IFields::ReadResult IFields::Read(const hwy::Span<const uint32_t>& span,
|
||||
size_t pos) {
|
||||
ReadVisitor visitor(span, pos);
|
||||
visitor(*this);
|
||||
return visitor.Result();
|
||||
}
|
||||
|
||||
bool IFields::AppendTo(std::vector<uint32_t>& storage) const {
|
||||
// VisitFields is non-const. It is safe to cast because WriteVisitor does not
|
||||
// modify the fields.
|
||||
IFields& fields = *const_cast<IFields*>(this);
|
||||
|
||||
// Reduce allocations, but not in debug builds so we notice any iterator
|
||||
// invalidation bugs.
|
||||
if constexpr (!HWY_IS_DEBUG_BUILD) {
|
||||
storage.reserve(storage.size() + 256);
|
||||
}
|
||||
|
||||
WriteVisitor visitor(storage);
|
||||
visitor(fields);
|
||||
return !visitor.AnyInvalid();
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
@ -0,0 +1,200 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_FIELDS_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_FIELDS_H_
|
||||
|
||||
// Simple serialization/deserialization for user-defined classes, inspired by
|
||||
// BSD-licensed code Copyright (c) the JPEG XL Project Authors:
|
||||
// https://github.com/libjxl/libjxl, lib/jxl/fields.h.
|
||||
|
||||
// IWYU pragma: begin_exports
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/base.h"
|
||||
// IWYU pragma: end_exports
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
// Design goals:
|
||||
// - self-contained to simplify installing/building, no separate compiler (rules
|
||||
// out Protocol Buffers, FlatBuffers, Cap'n Proto, Apache Thrift).
|
||||
// - simplicity: small codebase without JIT (rules out Apache Fury and bitsery).
|
||||
// - old code can read new data, and new code can read old data (rules out yas
|
||||
// and msgpack). This avoids rewriting weights when we add a new field.
|
||||
// - no user-specified versions (rules out cereal) nor field names (rules out
|
||||
// JSON and GGUF). These are error-prone; users should just be able to append
|
||||
// new fields.
|
||||
//
|
||||
// Non-goals:
|
||||
// - anything better than reasonable encoding size and decode speed: we only
|
||||
// anticipate ~KiB of data, alongside ~GiB of separately compressed weights.
|
||||
// - features such as maps, interfaces, random access, and optional/deleted
|
||||
// fields: not required for the intended use case of `ModelConfig`.
|
||||
// - support any other languages than C++ and Python (for the exporter).
|
||||
|
||||
class IFields; // breaks circular dependency
|
||||
|
||||
// Visitors are internal-only, but their base class is visible to user code
|
||||
// because their `IFields::VisitFields` calls `visitor.operator()`.
|
||||
//
|
||||
// Supported field types `T`: `uint32_t`, `float`, `std::string`, classes
|
||||
// derived from `IFields`, `bool`, `enum`, `std::vector<T>`.
|
||||
class IFieldsVisitor {
|
||||
public:
|
||||
virtual ~IFieldsVisitor();
|
||||
|
||||
// Indicates whether NotifyInvalid was called for any field. Once set, this is
|
||||
// sticky for all IFields visited by this visitor.
|
||||
bool AnyInvalid() const { return any_invalid_; }
|
||||
|
||||
// None of these fail directly, but they call NotifyInvalid() if any value
|
||||
// is out of range. A single generic/overloaded function is required to
|
||||
// support `std::vector<T>`.
|
||||
virtual void operator()(uint32_t& value) = 0;
|
||||
virtual void operator()(float& value) = 0;
|
||||
virtual void operator()(std::string& value) = 0;
|
||||
virtual void operator()(IFields& fields) = 0; // recurse into nested fields
|
||||
|
||||
// bool and enum fields are actually stored as uint32_t.
|
||||
void operator()(bool& value) {
|
||||
if (HWY_UNLIKELY(SkipField())) return;
|
||||
|
||||
uint32_t u32 = value ? 1 : 0;
|
||||
operator()(u32);
|
||||
if (HWY_UNLIKELY(u32 > 1)) {
|
||||
return NotifyInvalid("Invalid bool %u\n", u32);
|
||||
}
|
||||
value = (u32 == 1);
|
||||
}
|
||||
|
||||
template <typename EnumT, hwy::EnableIf<std::is_enum_v<EnumT>>* = nullptr>
|
||||
void operator()(EnumT& value) {
|
||||
if (HWY_UNLIKELY(SkipField())) return;
|
||||
|
||||
uint32_t u32 = static_cast<uint32_t>(value);
|
||||
operator()(u32);
|
||||
if (HWY_UNLIKELY(!EnumValid(static_cast<EnumT>(u32)))) {
|
||||
return NotifyInvalid("Invalid enum %u\n");
|
||||
}
|
||||
value = static_cast<EnumT>(u32);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void operator()(std::vector<T>& value) {
|
||||
if (HWY_UNLIKELY(SkipField())) return;
|
||||
|
||||
uint32_t num = static_cast<uint32_t>(value.size());
|
||||
operator()(num);
|
||||
if (HWY_UNLIKELY(num > 64 * 1024)) {
|
||||
return NotifyInvalid("Vector too long %u\n", num);
|
||||
}
|
||||
|
||||
if (IsReading()) {
|
||||
value.resize(num);
|
||||
}
|
||||
for (size_t i = 0; i < value.size(); ++i) {
|
||||
operator()(value[i]);
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
// Prints a message and causes subsequent AnyInvalid() to return true.
|
||||
void NotifyInvalid(const char* fmt, ...);
|
||||
|
||||
// Must check this before modifying any field, and if it returns true,
|
||||
// avoid doing so. This is important for strings and vectors in the
|
||||
// "new code, old data" test: resizing them may destroy their contents.
|
||||
virtual bool SkipField() { return AnyInvalid(); }
|
||||
// For operator()(std::vector&).
|
||||
virtual bool IsReading() const { return false; }
|
||||
|
||||
private:
|
||||
bool any_invalid_ = false;
|
||||
};
|
||||
|
||||
// Abstract base class for user-defined serializable classes, which are
|
||||
// forward- and backward compatible collection of fields (members). This means
|
||||
// old code can safely read new data, and new code can still handle old data.
|
||||
//
|
||||
// Fields are written in the unchanging order established by the user-defined
|
||||
// `VisitFields`; any new fields must be visited after all existing fields in
|
||||
// the same `IFields`. We encode each into `uint32_t` storage for simplicity.
|
||||
//
|
||||
// HOWTO:
|
||||
// - basic usage: define a struct with member variables ("fields") and their
|
||||
// initializers, e.g. `uint32_t field = 0;`. Then define a
|
||||
// `void VisitFields(IFieldsVisitor& v)` member function that calls
|
||||
// `v(field);` etc. for each field, and a `const char* Name()` function used
|
||||
// as a caption when printing.
|
||||
//
|
||||
// - enum fields: define `enum class EnumT` and `bool EnumValid(EnumT)`, then
|
||||
// call `v(field);` as usual. Note that `EnumT` is not extendable insofar as
|
||||
// `EnumValid` returns false for values beyond the initially known ones. You
|
||||
// can add placeholders, which requires user code to know how to handle them,
|
||||
// or later add new fields including enums to override the first enum.
|
||||
struct IFields {
|
||||
virtual ~IFields();
|
||||
|
||||
// User-defined caption used during Print().
|
||||
virtual const char* Name() const = 0;
|
||||
|
||||
// User-defined, called by IFieldsVisitor::operator()(IFields&).
|
||||
virtual void VisitFields(IFieldsVisitor& visitor) = 0;
|
||||
|
||||
// Prints name and fields to stderr.
|
||||
void Print() const;
|
||||
|
||||
struct ReadResult {
|
||||
ReadResult(size_t pos) : pos(pos), missing_fields(0), extra_u32(0) {}
|
||||
|
||||
// Where to resume reading in the next Read() call, or 0 if there was an
|
||||
// unrecoverable error: any field has an invalid value, or the span is
|
||||
// shorter than the data says it should be. If so, do not use the fields nor
|
||||
// continue reading.
|
||||
size_t pos;
|
||||
// From the perspective of VisitFields, how many more fields would have
|
||||
// been read beyond the stored size. If non-zero, the data is older than
|
||||
// the code, but valid, and extra_u32 should be zero.
|
||||
uint32_t missing_fields;
|
||||
// How many extra u32 are in the stored size, vs. what we actually read as
|
||||
// requested by VisitFields. If non-zero,, the data is newer than the code,
|
||||
// but valid, and missing_fields should be zero.
|
||||
uint32_t extra_u32;
|
||||
};
|
||||
|
||||
// Reads fields starting at `span[pos]`.
|
||||
ReadResult Read(const hwy::Span<const uint32_t>& span, size_t pos);
|
||||
|
||||
// Returns false if there was an unrecoverable error, typically because a
|
||||
// field has an invalid value. If so, `storage` is undefined.
|
||||
bool AppendTo(std::vector<uint32_t>& storage) const;
|
||||
|
||||
// Convenience wrapper for AppendTo when we only write once.
|
||||
std::vector<uint32_t> Write() const {
|
||||
std::vector<uint32_t> storage;
|
||||
if (!AppendTo(storage)) storage.clear();
|
||||
return storage;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_FIELDS_H_
|
||||
|
|
@ -0,0 +1,354 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "compression/fields.h"
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include <limits>
|
||||
#include <type_traits>
|
||||
|
||||
#include "hwy/tests/hwy_gtest.h"
|
||||
|
||||
namespace gcpp {
|
||||
namespace {
|
||||
|
||||
#if !HWY_TEST_STANDALONE
|
||||
class FieldsTest : public testing::Test {};
|
||||
#endif
|
||||
|
||||
void MaybePrint(const IFields& fields) {
|
||||
if (HWY_IS_DEBUG_BUILD) {
|
||||
fields.Print();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void CheckVectorEqual(const std::vector<T>& a, const std::vector<T>& b) {
|
||||
EXPECT_EQ(a.size(), b.size());
|
||||
for (size_t i = 0; i < a.size(); ++i) {
|
||||
if constexpr (std::is_base_of_v<IFields, T>) {
|
||||
a[i].CheckEqual(b[i]);
|
||||
} else {
|
||||
EXPECT_EQ(a[i], b[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
enum class Enum : uint32_t {
|
||||
k1 = 1,
|
||||
k3 = 3,
|
||||
k8 = 8,
|
||||
};
|
||||
HWY_MAYBE_UNUSED bool EnumValid(Enum e) {
|
||||
return e == Enum::k1 || e == Enum::k3 || e == Enum::k8;
|
||||
}
|
||||
|
||||
// Contains all supported types except IFields and std::vector<IFields>.
|
||||
struct Nested : public IFields {
|
||||
Nested() : nested_u32(0) {} // for std::vector
|
||||
explicit Nested(uint32_t u32) : nested_u32(u32) {}
|
||||
|
||||
const char* Name() const override { return "Nested"; }
|
||||
void VisitFields(IFieldsVisitor& visitor) override {
|
||||
visitor(nested_u32);
|
||||
visitor(nested_bool);
|
||||
visitor(nested_vector);
|
||||
visitor(nested_enum);
|
||||
visitor(nested_str);
|
||||
visitor(nested_f);
|
||||
}
|
||||
|
||||
void CheckEqual(const Nested& n) const {
|
||||
EXPECT_EQ(nested_u32, n.nested_u32);
|
||||
EXPECT_EQ(nested_bool, n.nested_bool);
|
||||
CheckVectorEqual(nested_vector, n.nested_vector);
|
||||
EXPECT_EQ(nested_enum, n.nested_enum);
|
||||
EXPECT_EQ(nested_str, n.nested_str);
|
||||
EXPECT_EQ(nested_f, n.nested_f);
|
||||
}
|
||||
|
||||
uint32_t nested_u32; // set in ctor
|
||||
bool nested_bool = true;
|
||||
std::vector<uint32_t> nested_vector = {1, 2, 3};
|
||||
Enum nested_enum = Enum::k1;
|
||||
std::string nested_str = "nested";
|
||||
float nested_f = 1.125f;
|
||||
};
|
||||
|
||||
// Contains all supported types.
|
||||
struct OldFields : public IFields {
|
||||
const char* Name() const override { return "OldFields"; }
|
||||
void VisitFields(IFieldsVisitor& visitor) override {
|
||||
visitor(old_str);
|
||||
visitor(old_nested);
|
||||
visitor(old1);
|
||||
visitor(old_vec_str);
|
||||
visitor(old_vec_nested);
|
||||
visitor(old_f);
|
||||
visitor(old_enum);
|
||||
visitor(old_bool);
|
||||
}
|
||||
|
||||
// Template allows comparing with NewFields.
|
||||
template <typename Other>
|
||||
void CheckEqual(const Other& n) const {
|
||||
EXPECT_EQ(old_str, n.old_str);
|
||||
old_nested.CheckEqual(n.old_nested);
|
||||
EXPECT_EQ(old1, n.old1);
|
||||
CheckVectorEqual(old_vec_str, n.old_vec_str);
|
||||
CheckVectorEqual(old_vec_nested, n.old_vec_nested);
|
||||
EXPECT_EQ(old_f, n.old_f);
|
||||
EXPECT_EQ(old_enum, n.old_enum);
|
||||
EXPECT_EQ(old_bool, n.old_bool);
|
||||
}
|
||||
|
||||
std::string old_str = "old";
|
||||
Nested old_nested = Nested(0);
|
||||
uint32_t old1 = 1;
|
||||
std::vector<std::string> old_vec_str = {"abc", "1234"};
|
||||
std::vector<Nested> old_vec_nested = {Nested(1), Nested(4)};
|
||||
float old_f = 1.125f;
|
||||
Enum old_enum = Enum::k1;
|
||||
bool old_bool = true;
|
||||
}; // OldFields
|
||||
|
||||
// Simulates adding new fields of all types to an existing struct.
|
||||
struct NewFields : public IFields {
|
||||
const char* Name() const override { return "NewFields"; }
|
||||
void VisitFields(IFieldsVisitor& visitor) override {
|
||||
visitor(old_str);
|
||||
visitor(old_nested);
|
||||
visitor(old1);
|
||||
visitor(old_vec_str);
|
||||
visitor(old_vec_nested);
|
||||
visitor(old_f);
|
||||
visitor(old_enum);
|
||||
visitor(old_bool);
|
||||
|
||||
// Change order of field types relative to OldFields to ensure that works.
|
||||
visitor(new_nested);
|
||||
visitor(new_bool);
|
||||
visitor(new_vec_nested);
|
||||
visitor(new_f);
|
||||
visitor(new_vec_str);
|
||||
visitor(new_enum);
|
||||
visitor(new2);
|
||||
visitor(new_str);
|
||||
}
|
||||
|
||||
void CheckEqual(const NewFields& n) const {
|
||||
EXPECT_EQ(old_str, n.old_str);
|
||||
old_nested.CheckEqual(n.old_nested);
|
||||
EXPECT_EQ(old1, n.old1);
|
||||
CheckVectorEqual(old_vec_str, n.old_vec_str);
|
||||
CheckVectorEqual(old_vec_nested, n.old_vec_nested);
|
||||
EXPECT_EQ(old_f, n.old_f);
|
||||
EXPECT_EQ(old_enum, n.old_enum);
|
||||
EXPECT_EQ(old_bool, n.old_bool);
|
||||
|
||||
new_nested.CheckEqual(n.new_nested);
|
||||
EXPECT_EQ(new_bool, n.new_bool);
|
||||
CheckVectorEqual(new_vec_nested, n.new_vec_nested);
|
||||
EXPECT_EQ(new_f, n.new_f);
|
||||
CheckVectorEqual(new_vec_str, n.new_vec_str);
|
||||
EXPECT_EQ(new_enum, n.new_enum);
|
||||
EXPECT_EQ(new2, n.new2);
|
||||
EXPECT_EQ(new_str, n.new_str);
|
||||
}
|
||||
|
||||
// Copied from OldFields to match the use case of adding new fields. If we
|
||||
// write an OldFields member, that would change the layout due to its size.
|
||||
std::string old_str = "old";
|
||||
Nested old_nested = Nested(0);
|
||||
uint32_t old1 = 1;
|
||||
std::vector<std::string> old_vec_str = {"abc", "1234"};
|
||||
std::vector<Nested> old_vec_nested = {Nested(1), Nested(4)};
|
||||
float old_f = 1.125f;
|
||||
Enum old_enum = Enum::k1;
|
||||
bool old_bool = true;
|
||||
|
||||
Nested new_nested = Nested(999);
|
||||
bool new_bool = false;
|
||||
std::vector<Nested> new_vec_nested = {Nested(2), Nested(3)};
|
||||
float new_f = -2.0f;
|
||||
std::vector<std::string> new_vec_str = {"AB", std::string(), "56789"};
|
||||
Enum new_enum = Enum::k3;
|
||||
uint32_t new2 = 2;
|
||||
std::string new_str = std::string(); // empty is allowed
|
||||
}; // NewFields
|
||||
|
||||
// Changes all fields to non-default values.
|
||||
NewFields ModifiedNewFields() {
|
||||
NewFields n;
|
||||
n.old_str = "old2";
|
||||
n.old_nested = Nested(5);
|
||||
n.old1 = 11;
|
||||
n.old_vec_str = {"abc2", "431", "ZZ"};
|
||||
n.old_vec_nested = {Nested(9)};
|
||||
n.old_f = -2.5f;
|
||||
n.old_enum = Enum::k3;
|
||||
n.old_bool = false;
|
||||
|
||||
n.new_nested = Nested(55);
|
||||
n.new_bool = true;
|
||||
n.new_vec_nested = {Nested(3), Nested(33), Nested(333)};
|
||||
n.new_f = 4.f;
|
||||
n.new_vec_str = {"4321", "321", "21", "1"};
|
||||
n.new_enum = Enum::k8;
|
||||
n.new2 = 22;
|
||||
n.new_str = "new and even longer";
|
||||
|
||||
return n;
|
||||
}
|
||||
|
||||
using Span = hwy::Span<const uint32_t>;
|
||||
|
||||
using ReadResult = IFields::ReadResult;
|
||||
void CheckConsumedAll(const ReadResult& result, size_t storage_size) {
|
||||
EXPECT_NE(0, storage_size); // Ensure we notice failure (pos == 0).
|
||||
EXPECT_EQ(storage_size, result.pos);
|
||||
EXPECT_EQ(0, result.missing_fields);
|
||||
EXPECT_EQ(0, result.extra_u32);
|
||||
}
|
||||
|
||||
// If we do not change any fields, Write+Read returns the defaults.
|
||||
TEST(FieldsTest, TestNewMatchesDefaults) {
|
||||
NewFields new_fields;
|
||||
const std::vector<uint32_t> storage = new_fields.Write();
|
||||
|
||||
const ReadResult result = new_fields.Read(Span(storage), 0);
|
||||
CheckConsumedAll(result, storage.size());
|
||||
|
||||
NewFields().CheckEqual(new_fields);
|
||||
}
|
||||
|
||||
// Change fields from default and check that Write+Read returns them again.
|
||||
TEST(FieldsTest, TestRoundTrip) {
|
||||
NewFields new_fields = ModifiedNewFields();
|
||||
const std::vector<uint32_t> storage = new_fields.Write();
|
||||
|
||||
NewFields copy;
|
||||
const ReadResult result = copy.Read(Span(storage), 0);
|
||||
CheckConsumedAll(result, storage.size());
|
||||
|
||||
new_fields.CheckEqual(copy);
|
||||
}
|
||||
|
||||
// Refuse to write invalid floats.
|
||||
TEST(FieldsTest, TestInvalidFloat) {
|
||||
NewFields new_fields;
|
||||
new_fields.new_f = std::numeric_limits<float>::infinity();
|
||||
EXPECT_TRUE(new_fields.Write().empty());
|
||||
|
||||
new_fields.new_f = std::numeric_limits<float>::quiet_NaN();
|
||||
EXPECT_TRUE(new_fields.Write().empty());
|
||||
}
|
||||
|
||||
// Refuse to write invalid strings.
|
||||
TEST(FieldsTest, TestInvalidString) {
|
||||
NewFields new_fields;
|
||||
// Four zero bytes
|
||||
new_fields.new_str.assign(4, '\0');
|
||||
EXPECT_TRUE(new_fields.Write().empty());
|
||||
|
||||
// Too long
|
||||
new_fields.new_str.assign(257, 'a');
|
||||
EXPECT_TRUE(new_fields.Write().empty());
|
||||
|
||||
// First byte not ASCII
|
||||
new_fields.new_str.assign("123");
|
||||
new_fields.new_str[0] = 128;
|
||||
EXPECT_TRUE(new_fields.Write().empty());
|
||||
|
||||
// Upper byte in later u32 not ASCII
|
||||
new_fields.new_str.assign("ABCDEFGH");
|
||||
new_fields.new_str[7] = 255;
|
||||
EXPECT_TRUE(new_fields.Write().empty());
|
||||
}
|
||||
|
||||
// Write two structs to the same storage.
|
||||
TEST(FieldsTest, TestMultipleWrite) {
|
||||
const NewFields modified = ModifiedNewFields();
|
||||
std::vector<uint32_t> storage = modified.Write();
|
||||
const size_t modified_size = storage.size();
|
||||
const NewFields defaults;
|
||||
defaults.AppendTo(storage);
|
||||
|
||||
// Start with defaults to ensure Read retrieves the modified values.
|
||||
NewFields modified_copy;
|
||||
const ReadResult result1 = modified_copy.Read(Span(storage), 0);
|
||||
CheckConsumedAll(result1, modified_size);
|
||||
modified.CheckEqual(modified_copy);
|
||||
|
||||
// Start with modified values to ensure Read retrieves the defaults.
|
||||
NewFields defaults_copy = modified;
|
||||
const ReadResult result2 = defaults_copy.Read(Span(storage), result1.pos);
|
||||
CheckConsumedAll(result2, storage.size());
|
||||
defaults.CheckEqual(defaults_copy);
|
||||
}
|
||||
|
||||
// Write old defaults, read old using new code.
|
||||
TEST(FieldsTest, TestNewCodeOldData) {
|
||||
OldFields old_fields;
|
||||
const std::vector<uint32_t> storage = old_fields.Write();
|
||||
|
||||
// Start with modified old values to ensure old defaults overwrite them.
|
||||
NewFields new_fields = ModifiedNewFields();
|
||||
const ReadResult result = new_fields.Read(Span(storage), 0);
|
||||
MaybePrint(new_fields);
|
||||
EXPECT_NE(0, result.pos); // did not fail
|
||||
EXPECT_NE(0, result.missing_fields);
|
||||
EXPECT_EQ(0, result.extra_u32);
|
||||
old_fields.CheckEqual(new_fields); // old fields are the same in both
|
||||
}
|
||||
|
||||
// Write old defaults, ensure new defaults remain unchanged.
|
||||
TEST(FieldsTest, TestNewCodeOldDataNewUnchanged) {
|
||||
OldFields old_fields;
|
||||
const std::vector<uint32_t> storage = old_fields.Write();
|
||||
|
||||
NewFields new_fields;
|
||||
const ReadResult result = new_fields.Read(Span(storage), 0);
|
||||
MaybePrint(new_fields);
|
||||
EXPECT_NE(0, result.pos); // did not fail
|
||||
EXPECT_NE(0, result.missing_fields);
|
||||
EXPECT_EQ(0, result.extra_u32);
|
||||
NewFields().CheckEqual(new_fields); // new fields match their defaults
|
||||
}
|
||||
|
||||
// Write new defaults, read using old code.
|
||||
TEST(FieldsTest, TestOldCodeNewData) {
|
||||
NewFields new_fields;
|
||||
const std::vector<uint32_t> storage = new_fields.Write();
|
||||
|
||||
OldFields old_fields;
|
||||
const ReadResult result = old_fields.Read(Span(storage), 0);
|
||||
MaybePrint(old_fields);
|
||||
EXPECT_NE(0, result.pos); // did not fail
|
||||
EXPECT_EQ(0, result.missing_fields);
|
||||
EXPECT_NE(0, result.extra_u32);
|
||||
EXPECT_EQ(storage.size(), result.pos + result.extra_u32);
|
||||
|
||||
old_fields.CheckEqual(new_fields); // old fields are the same in both
|
||||
// (Can't check new fields because we only read OldFields)
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace gcpp
|
||||
|
||||
HWY_TEST_MAIN();
|
||||
|
|
@ -25,13 +25,14 @@
|
|||
#include <complex>
|
||||
#include <cstdio>
|
||||
|
||||
// IWYU pragma: begin_exports
|
||||
#include "util/basics.h" // BF16
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/base.h" // HWY_INLINE
|
||||
// IWYU pragma: end_exports
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
using BF16 = hwy::bfloat16_t;
|
||||
|
||||
// Switching Floating Point: a hybrid 8-bit float representation of bf16/f32
|
||||
// inputs that combines the advantages of e4m3 and e5m2 into a single format.
|
||||
// It supports seeking at a granularity of 1 and decoding to bf16/f32.
|
||||
|
|
@ -266,8 +267,12 @@ struct PackedSpan {
|
|||
// check the compressed count and ensure we have that many.
|
||||
const size_t required =
|
||||
CompressedArrayElements<Packed>(packed_ofs + num_accessible);
|
||||
HWY_DASSERT(num >= required);
|
||||
(void)required;
|
||||
if constexpr (HWY_IS_DEBUG_BUILD) {
|
||||
if (num < required) {
|
||||
HWY_ABORT("PackedSpan: ofs %zu, want %zu, req %zu > %zu packed",
|
||||
packed_ofs, num_accessible, required, num);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Packed* HWY_RESTRICT ptr;
|
||||
|
|
|
|||
|
|
@ -229,12 +229,12 @@ void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app,
|
|||
fprintf(stderr,
|
||||
"Date & Time : %s" // dt includes \n
|
||||
"CPU : %s\n"
|
||||
"CPU topology : %s\n"
|
||||
"CPU topology : %s, %s\n"
|
||||
"Instruction set : %s (%zu bits)\n"
|
||||
"Compiled config : %s\n"
|
||||
"Weight Type : %s\n"
|
||||
"EmbedderInput Type : %s\n",
|
||||
dt, cpu100, pools.TopologyString(),
|
||||
dt, cpu100, pools.TopologyString(), pools.PinString(),
|
||||
hwy::TargetName(hwy::DispatchedTarget()), hwy::VectorBytes() * 8,
|
||||
CompiledConfig(), StringFromType(loader.Info().weight),
|
||||
TypeName<EmbedderInputT>());
|
||||
|
|
|
|||
|
|
@ -72,18 +72,11 @@ struct Activations {
|
|||
size_t seq_len;
|
||||
size_t cache_pos_size = 0;
|
||||
|
||||
// Multi-Head Attention?
|
||||
bool IsMHA() const { return layer_config.heads == layer_config.kv_heads; }
|
||||
|
||||
// Stride between subsequent queries. Each of Q, K, V are of length kQKVDim,
|
||||
// but for MHA we store them as Q,K,V, Q,K,V, .. instead of Q..Q, K..K, V..V.
|
||||
size_t QStride() const { return layer_config.qkv_dim * (IsMHA() ? 3 : 1); }
|
||||
|
||||
static RowVectorBatch<float> CreateInvTimescale(size_t qkv_dim,
|
||||
PostQKType post_qk) {
|
||||
const size_t rope_dim =
|
||||
post_qk == PostQKType::HalfRope ? qkv_dim / 2 : qkv_dim;
|
||||
RowVectorBatch<float> inv_timescale(1, rope_dim / 2);
|
||||
RowVectorBatch<float> inv_timescale(Extents2D(1, rope_dim / 2));
|
||||
for (size_t dim = 0; dim < rope_dim / 2; ++dim) {
|
||||
const float freq_exponents =
|
||||
static_cast<float>(2 * dim) / static_cast<float>(rope_dim);
|
||||
|
|
@ -100,29 +93,31 @@ struct Activations {
|
|||
const size_t ff_hidden_dim = layer_config.ff_hidden_dim;
|
||||
const size_t vocab_size = weights_config.vocab_size;
|
||||
|
||||
x = RowVectorBatch<float>(batch_size, model_dim);
|
||||
q = RowVectorBatch<float>(batch_size, layer_config.heads * QStride());
|
||||
x = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
|
||||
q = RowVectorBatch<float>(
|
||||
Extents2D(batch_size, layer_config.heads * layer_config.QStride()));
|
||||
if (vocab_size > 0) {
|
||||
logits = RowVectorBatch<float>(batch_size, vocab_size);
|
||||
logits = RowVectorBatch<float>(Extents2D(batch_size, vocab_size));
|
||||
}
|
||||
|
||||
pre_att_rms_out = RowVectorBatch<float>(batch_size, model_dim);
|
||||
att = RowVectorBatch<float>(batch_size,
|
||||
layer_config.heads * weights_config.seq_len);
|
||||
att_out = RowVectorBatch<float>(batch_size,
|
||||
layer_config.heads * layer_config.qkv_dim);
|
||||
att_sums = RowVectorBatch<float>(batch_size, model_dim);
|
||||
pre_att_rms_out = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
|
||||
att = RowVectorBatch<float>(
|
||||
Extents2D(batch_size, layer_config.heads * weights_config.seq_len));
|
||||
att_out = RowVectorBatch<float>(
|
||||
Extents2D(batch_size, layer_config.heads * layer_config.qkv_dim));
|
||||
att_sums = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
|
||||
|
||||
bf_pre_ffw_rms_out = RowVectorBatch<BF16>(batch_size, model_dim);
|
||||
C1 = RowVectorBatch<float>(batch_size, ff_hidden_dim);
|
||||
C2 = RowVectorBatch<float>(batch_size, ff_hidden_dim);
|
||||
ffw_out = RowVectorBatch<float>(batch_size, model_dim);
|
||||
bf_pre_ffw_rms_out = RowVectorBatch<BF16>(Extents2D(batch_size, model_dim));
|
||||
C1 = RowVectorBatch<float>(Extents2D(batch_size, ff_hidden_dim));
|
||||
C2 = RowVectorBatch<float>(Extents2D(batch_size, ff_hidden_dim));
|
||||
ffw_out = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
|
||||
|
||||
if (layer_config.type == LayerAttentionType::kGriffinRecurrentBlock) {
|
||||
griffin_x = RowVectorBatch<float>(batch_size, model_dim);
|
||||
griffin_y = RowVectorBatch<float>(batch_size, model_dim);
|
||||
griffin_gate_x = RowVectorBatch<float>(batch_size, model_dim);
|
||||
griffin_multiplier = RowVectorBatch<float>(batch_size, model_dim);
|
||||
griffin_x = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
|
||||
griffin_y = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
|
||||
griffin_gate_x = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
|
||||
griffin_multiplier =
|
||||
RowVectorBatch<float>(Extents2D(batch_size, model_dim));
|
||||
}
|
||||
|
||||
inv_timescale = CreateInvTimescale(layer_config.qkv_dim, post_qk);
|
||||
|
|
|
|||
|
|
@ -198,17 +198,15 @@ static ModelConfig ConfigGriffin2B() {
|
|||
return config;
|
||||
}
|
||||
|
||||
static ModelConfig ConfigPaliGemma_224() {
|
||||
ModelConfig config = ConfigGemma2B();
|
||||
config.model_name = "PaliGemma_224";
|
||||
config.model = Model::PALIGEMMA_224;
|
||||
// Adds a ViT config (SigLIP SoViT ViT, used in PaliGemma) to the model config.
|
||||
static void AddVitConfig(ModelConfig& config) {
|
||||
config.vit_model_dim = 1152;
|
||||
config.vocab_size = 256000 + 1024 + 128; // = 257152
|
||||
config.image_size = 224;
|
||||
config.patch_width = 14;
|
||||
const size_t num_patches = config.image_size / config.patch_width;
|
||||
config.vit_seq_len = num_patches * num_patches;
|
||||
LayerConfig layer_config = {
|
||||
LayerConfig vit_layer_config = {
|
||||
.model_dim = config.vit_model_dim,
|
||||
.ff_hidden_dim = 4304,
|
||||
.heads = 16,
|
||||
|
|
@ -217,8 +215,15 @@ static ModelConfig ConfigPaliGemma_224() {
|
|||
.ff_biases = true,
|
||||
.type = LayerAttentionType::kVit,
|
||||
};
|
||||
config.vit_layer_configs = {27, layer_config};
|
||||
config.vit_layer_configs = {27, vit_layer_config};
|
||||
config.num_vit_scales = 4 * config.vit_layer_configs.size();
|
||||
}
|
||||
|
||||
static ModelConfig ConfigPaliGemma_224() {
|
||||
ModelConfig config = ConfigGemma2B();
|
||||
config.model_name = "PaliGemma_224";
|
||||
config.model = Model::PALIGEMMA_224;
|
||||
AddVitConfig(config);
|
||||
return config;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@
|
|||
|
||||
#include <stddef.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
|
|
@ -118,6 +119,13 @@ enum class Model {
|
|||
struct LayerConfig {
|
||||
size_t CacheLayerSize() const { return kv_heads * qkv_dim * 2; }
|
||||
|
||||
// Multi-Head Attention?
|
||||
bool IsMHA() const { return heads == kv_heads; }
|
||||
|
||||
// Stride between subsequent queries. Each of Q, K, V are of length kQKVDim,
|
||||
// but for MHA we store them as Q,K,V, Q,K,V, .. instead of Q..Q, K..K, V..V.
|
||||
size_t QStride() const { return qkv_dim * (IsMHA() ? 3 : 1); }
|
||||
|
||||
size_t model_dim = 0;
|
||||
size_t griffin_dim = 0;
|
||||
size_t ff_hidden_dim = 0;
|
||||
|
|
|
|||
|
|
@ -20,9 +20,9 @@
|
|||
#include <stdio.h>
|
||||
|
||||
#include <algorithm> // std::min
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
#include "compression/compress.h"
|
||||
#include "gemma/activations.h"
|
||||
#include "gemma/common.h"
|
||||
#include "gemma/configs.h"
|
||||
|
|
@ -31,6 +31,7 @@
|
|||
// Placeholder for internal test4, do not remove
|
||||
#include "paligemma/image.h"
|
||||
#include "util/allocator.h"
|
||||
#include "util/basics.h"
|
||||
#include "util/threading.h"
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/base.h"
|
||||
|
|
@ -232,49 +233,49 @@ class GemmaAttention {
|
|||
// KV directly to KVCache.
|
||||
HWY_NOINLINE void ComputeQKV(const size_t num_interleaved) {
|
||||
PROFILER_ZONE("Gen.Attention.QKV");
|
||||
// For the computation of Q, K, and V, it is useful to remember that
|
||||
// qkv_einsum_w has shape [(layer_config_.heads + layer_config_.kv_heads *
|
||||
// 2), kKQVDim, layer_config_.model_dim] and q_stride_ =
|
||||
// layer_config_.qkv_dim * (is_mha_ ? 3 : 1);
|
||||
const size_t model_dim = layer_config_.model_dim;
|
||||
const size_t qkv_dim = layer_config_.qkv_dim;
|
||||
const size_t heads = layer_config_.heads;
|
||||
const size_t kv_heads = layer_config_.kv_heads;
|
||||
|
||||
const auto pre_att_rms_out =
|
||||
ConstMat(activations_.pre_att_rms_out.All(), layer_config_.model_dim);
|
||||
const auto w_q1 = layer_weights_.qkv_einsum_w.data() == nullptr
|
||||
? ConstMat(layer_weights_.qkv_einsum_w1.data(),
|
||||
layer_config_.model_dim)
|
||||
: ConstMat(layer_weights_.qkv_einsum_w.data(),
|
||||
layer_config_.model_dim);
|
||||
const auto w_q2 =
|
||||
layer_weights_.qkv_einsum_w.data() == nullptr
|
||||
? ConstMat(layer_weights_.qkv_einsum_w2.data(),
|
||||
layer_config_.model_dim)
|
||||
: ConstMat(layer_weights_.qkv_einsum_w.data(),
|
||||
layer_config_.model_dim, layer_config_.model_dim,
|
||||
layer_config_.heads * layer_config_.qkv_dim *
|
||||
layer_config_.model_dim);
|
||||
MatMul</*kAdd=*/false>(
|
||||
num_interleaved, pre_att_rms_out, w_q1,
|
||||
layer_weights_.qkv_einsum_w.scale(), /*add=*/nullptr, activations_.env,
|
||||
MutableMat(activations_.q.All(), layer_config_.heads * q_stride_));
|
||||
ConstMatFromBatch(num_interleaved, activations_.pre_att_rms_out);
|
||||
auto w_q1 = layer_weights_.qkv_einsum_w.data()
|
||||
? ConstMatFromWeights(layer_weights_.qkv_einsum_w)
|
||||
: ConstMatFromWeights(layer_weights_.qkv_einsum_w1);
|
||||
// The original qkv_einsum_w has shape [(heads + kv_heads * 2), kKQVDim,
|
||||
// model_dim], which we reshaped to (heads + kv_heads * 2) * kKQVDim rows.
|
||||
// We must shrink to the actual size because MatMul verifies
|
||||
// `B.extents.rows == C.Cols()`. If MHA, `QStride() == 3 * qkv_dim` and all
|
||||
// rows are used. Otherwise, `QStride() == qkv_dim` and KV will be
|
||||
// computed in the second MatMul.
|
||||
const size_t w1_rows = heads * layer_config_.QStride();
|
||||
w_q1.ShrinkRows(w1_rows);
|
||||
MatMul(pre_att_rms_out, w_q1,
|
||||
/*add=*/nullptr, activations_.env, RowPtrFromBatch(activations_.q));
|
||||
|
||||
if (is_mha_) {
|
||||
// Multi-Head Attention a.k.a. "use_qkv_einsum" computed QKV already.
|
||||
} else {
|
||||
auto w_q2 = layer_weights_.qkv_einsum_w.data()
|
||||
? ConstMatFromWeights(layer_weights_.qkv_einsum_w,
|
||||
w1_rows * model_dim)
|
||||
: ConstMatFromWeights(layer_weights_.qkv_einsum_w2);
|
||||
// KV structure is [k, v, k, v, ....] = kv_heads pairs of (k, v).
|
||||
const size_t w_rows_kv_cols = kv_heads * 2 * qkv_dim;
|
||||
w_q2.ShrinkRows(w_rows_kv_cols);
|
||||
|
||||
// Single query and no wraparound means we can use a matmul and write
|
||||
// directly into the KV cache with a stride of cache_pos_size_.
|
||||
if (num_queries_ == 1 &&
|
||||
queries_pos_[0] + num_tokens_ <= div_seq_len_.GetDivisor()) {
|
||||
const size_t kv_ofs =
|
||||
queries_pos_[0] * cache_pos_size_ + layer_ * cache_layer_size_;
|
||||
// KV structure is [k, v, k, v, ....] = layer_config_.kv_heads pairs of
|
||||
// (k, v).
|
||||
float* HWY_RESTRICT kv = kv_caches_[0].kv_cache.get() + kv_ofs;
|
||||
MatMul</*kAdd=*/false>(
|
||||
num_tokens_, pre_att_rms_out, w_q2,
|
||||
layer_weights_.qkv_einsum_w.scale(), /*add=*/nullptr,
|
||||
activations_.env,
|
||||
MutableMat(kv, layer_config_.kv_heads * 2 * layer_config_.qkv_dim,
|
||||
cache_pos_size_));
|
||||
RowPtrF kv_rows(kv, w_rows_kv_cols);
|
||||
kv_rows.SetStride(cache_pos_size_);
|
||||
MatMul(pre_att_rms_out, w_q2,
|
||||
/*add=*/nullptr, activations_.env, kv_rows);
|
||||
} else {
|
||||
// Proceed row by row because there will be wraparound.
|
||||
for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
|
||||
|
|
@ -288,38 +289,32 @@ class GemmaAttention {
|
|||
const size_t kv_offset =
|
||||
cache_pos * cache_pos_size_ + layer_ * cache_layer_size_;
|
||||
float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
|
||||
// KV structure is [k, v, k, v, ....] = layer_config_.kv_heads pairs
|
||||
// of (k, v).
|
||||
if (layer_weights_.qkv_einsum_w.data() == nullptr) {
|
||||
MatVec(layer_weights_.qkv_einsum_w2, 0,
|
||||
layer_config_.kv_heads * 2 * layer_config_.qkv_dim,
|
||||
layer_config_.model_dim, x, kv, pool_);
|
||||
if (layer_weights_.qkv_einsum_w.data()) {
|
||||
MatVec(layer_weights_.qkv_einsum_w, heads * qkv_dim * model_dim,
|
||||
w_rows_kv_cols, model_dim, x, kv, pool_);
|
||||
} else {
|
||||
MatVec(layer_weights_.qkv_einsum_w,
|
||||
layer_config_.heads * layer_config_.qkv_dim *
|
||||
layer_config_.model_dim,
|
||||
layer_config_.kv_heads * 2 * layer_config_.qkv_dim,
|
||||
layer_config_.model_dim, x, kv, pool_);
|
||||
MatVec(layer_weights_.qkv_einsum_w2, 0, //
|
||||
w_rows_kv_cols, model_dim, x, kv, pool_);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} // !is_mha_
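To make the ShrinkRows bookkeeping above concrete: the fused qkv_einsum_w is treated as (heads + 2 * kv_heads) * qkv_dim rows of length model_dim, the Q matmul consumes the first heads * QStride() rows, and in the non-MHA case the K/V matmul starts at element offset w1_rows * model_dim and spans kv_heads * 2 * qkv_dim rows. A scalar sketch of that arithmetic with hypothetical dimensions:

#include <cstddef>
#include <cstdio>

int main() {
  // Hypothetical non-MHA (grouped-query) layer.
  const size_t heads = 16, kv_heads = 8, qkv_dim = 256, model_dim = 2048;
  const bool is_mha = (heads == kv_heads);
  const size_t q_stride = qkv_dim * (is_mha ? 3 : 1);

  const size_t total_rows = (heads + 2 * kv_heads) * qkv_dim;  // fused weight rows
  const size_t w1_rows = heads * q_stride;                     // rows used by the Q matmul
  const size_t kv_rows = kv_heads * 2 * qkv_dim;               // rows used by the K/V matmul
  const size_t kv_elem_offset = w1_rows * model_dim;           // where the K/V block starts

  std::printf("total=%zu q_rows=%zu kv_rows=%zu kv_offset=%zu\n", total_rows,
              w1_rows, kv_rows, kv_elem_offset);
  // For MHA, q_stride == 3 * qkv_dim, so w1_rows == total_rows and the second
  // matmul is skipped entirely.
  return 0;
}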
|
||||
|
||||
// Self-extension
|
||||
const hwy::Divisor div_grp_size(
|
||||
static_cast<uint32_t>(layer_config_.grp_size));
|
||||
// Apply positional encodings for K (and copy KV to cache if MHA).
|
||||
pool_.Run(0, layer_config_.kv_heads * num_interleaved,
|
||||
pool_.Run(0, kv_heads * num_interleaved,
|
||||
[&](uint64_t task, size_t /*thread*/) HWY_ATTR {
|
||||
const size_t head = task % layer_config_.kv_heads;
|
||||
const size_t interleaved_idx = task / layer_config_.kv_heads;
|
||||
const size_t head = task % kv_heads;
|
||||
const size_t interleaved_idx = task / kv_heads;
|
||||
const size_t query_idx = interleaved_idx % num_queries_;
|
||||
const size_t batch_idx = interleaved_idx / num_queries_;
|
||||
size_t pos = queries_pos_[query_idx] + batch_idx;
|
||||
const size_t cache_pos = div_seq_len_.Remainder(pos);
|
||||
const size_t kv_offset = cache_pos * cache_pos_size_ +
|
||||
layer_ * cache_layer_size_ +
|
||||
head * layer_config_.qkv_dim * 2;
|
||||
head * qkv_dim * 2;
|
||||
KVCache& kv_cache = kv_caches_[query_idx];
|
||||
|
||||
const size_t ngb_size = layer_config_.ngb_size;
|
||||
|
|
@ -328,7 +323,7 @@ class GemmaAttention {
|
|||
float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
|
||||
const float* HWY_RESTRICT mha_kv =
|
||||
activations_.q.Batch(interleaved_idx) + head * q_stride_ +
|
||||
layer_config_.qkv_dim;
|
||||
qkv_dim;
|
||||
|
||||
// In self-extend, when embedding position,
|
||||
// we will use grouped key position
|
||||
|
|
@ -340,9 +335,8 @@ class GemmaAttention {
|
|||
kv);
|
||||
// If MHA, also copy V into KVCache.
|
||||
if (is_mha_) {
|
||||
hwy::CopyBytes(mha_kv + layer_config_.qkv_dim,
|
||||
kv + layer_config_.qkv_dim,
|
||||
layer_config_.qkv_dim * sizeof(*kv));
|
||||
hwy::CopyBytes(mha_kv + qkv_dim, kv + qkv_dim,
|
||||
qkv_dim * sizeof(*kv));
|
||||
}
|
||||
});
|
||||
}
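The kv_offset arithmetic above addresses the cache position-major, then by layer, then by KV head, with each head holding a (K, V) pair of qkv_dim floats. A standalone sketch of the indexing; dimensions are hypothetical and cache_pos_size_ is assumed here to be the per-position footprint across all layers:

#include <cstddef>
#include <cstdio>

int main() {
  // Hypothetical dimensions; the real values come from ModelConfig.
  const size_t seq_len = 4096, num_layers = 26, kv_heads = 8, qkv_dim = 256;
  const size_t cache_layer_size = kv_heads * qkv_dim * 2;       // K and V per layer
  const size_t cache_pos_size = num_layers * cache_layer_size;  // assumed per-position footprint

  const size_t pos = 5000, layer = 3, head = 2;
  const size_t cache_pos = pos % seq_len;  // wraparound, as in div_seq_len_.Remainder(pos)
  const size_t kv_offset = cache_pos * cache_pos_size +
                           layer * cache_layer_size +
                           head * qkv_dim * 2;  // K at +0, V at +qkv_dim
  std::printf("cache_pos=%zu kv_offset=%zu\n", cache_pos, kv_offset);
  return 0;
}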
|
||||
|
|
@ -484,27 +478,14 @@ class GemmaAttention {
|
|||
HWY_DASSERT(layer_weights_.att_weights.data() != nullptr);
|
||||
HWY_DASSERT(activations_.att_out.All() != nullptr);
|
||||
HWY_DASSERT(activations_.att_sums.All() != nullptr);
|
||||
if (layer_weights_.layer_config.softmax_attn_output_biases) {
|
||||
MatMul</*kAdd=*/true>(
|
||||
num_interleaved,
|
||||
ConstMat(activations_.att_out.All(),
|
||||
layer_config_.heads * layer_config_.qkv_dim),
|
||||
ConstMat(layer_weights_.att_weights.data(),
|
||||
layer_config_.heads * layer_config_.qkv_dim),
|
||||
layer_weights_.att_weights.scale(),
|
||||
layer_weights_.attention_output_biases.data_scale1(),
|
||||
activations_.env,
|
||||
MutableMat(activations_.att_sums.All(), layer_config_.model_dim));
|
||||
} else {
|
||||
MatMul</*kAdd=*/false>(
|
||||
num_interleaved,
|
||||
ConstMat(activations_.att_out.All(),
|
||||
layer_config_.heads * layer_config_.qkv_dim),
|
||||
ConstMat(layer_weights_.att_weights.data(),
|
||||
layer_config_.heads * layer_config_.qkv_dim),
|
||||
layer_weights_.att_weights.scale(), nullptr, activations_.env,
|
||||
MutableMat(activations_.att_sums.All(), layer_config_.model_dim));
|
||||
}
|
||||
|
||||
const float* add =
|
||||
layer_weights_.layer_config.softmax_attn_output_biases
|
||||
? layer_weights_.attention_output_biases.data_scale1()
|
||||
: nullptr;
|
||||
MatMul(ConstMatFromBatch(num_interleaved, activations_.att_out),
|
||||
ConstMatFromWeights(layer_weights_.att_weights), add,
|
||||
activations_.env, RowPtrFromBatch(activations_.att_sums));
|
||||
}
|
||||
|
||||
public:
|
||||
|
|
@ -545,13 +526,13 @@ class GemmaAttention {
|
|||
num_queries_(queries_pos.size()),
|
||||
num_tokens_(num_tokens),
|
||||
layer_(layer),
|
||||
q_stride_(activations.QStride()),
|
||||
layer_config_(layer_weights->layer_config),
|
||||
q_stride_(layer_config_.QStride()),
|
||||
cache_layer_size_(layer_weights->layer_config.CacheLayerSize()),
|
||||
cache_pos_size_(activations.cache_pos_size),
|
||||
is_mha_(activations.IsMHA()),
|
||||
is_mha_(layer_config_.IsMHA()),
|
||||
activations_(activations),
|
||||
layer_weights_(*layer_weights),
|
||||
layer_config_(layer_weights->layer_config),
|
||||
div_seq_len_(div_seq_len),
|
||||
kv_caches_(kv_caches),
|
||||
pool_(activations.env.Pool()) {
|
||||
|
|
@ -573,6 +554,7 @@ class GemmaAttention {
|
|||
const size_t num_queries_;
|
||||
const size_t num_tokens_;
|
||||
const size_t layer_;
|
||||
const LayerConfig& layer_config_;
|
||||
const size_t q_stride_ = 0;
|
||||
const size_t cache_layer_size_ = 0;
|
||||
const size_t cache_pos_size_ = 0;
|
||||
|
|
@ -580,7 +562,6 @@ class GemmaAttention {
|
|||
|
||||
Activations& activations_;
|
||||
const LayerWeightsPtrs<T>& layer_weights_;
|
||||
const LayerConfig& layer_config_;
|
||||
const hwy::Divisor& div_seq_len_;
|
||||
const KVCaches& kv_caches_;
|
||||
hwy::ThreadPool& pool_;
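Note on the member reordering above: q_stride_ and is_mha_ are now initialized from layer_config_ in the constructor, and C++ initializes non-static members in declaration order, so layer_config_ most likely had to move ahead of them. A small sketch of the pitfall, with hypothetical names:

#include <cstdio>

struct Config { int qkv_dim = 256; };

struct Bad {  // never instantiated here; shown only to illustrate the hazard
  // q_stride is declared before config, so it is initialized first and would
  // read through an unbound reference. Compilers flag this via -Wreorder.
  int q_stride;
  const Config& config;
  Bad(const Config& c) : config(c), q_stride(config.qkv_dim) {}
};

struct Good {
  const Config& config;  // declared (and therefore initialized) first
  int q_stride;
  Good(const Config& c) : config(c), q_stride(config.qkv_dim) {}
};

int main() {
  Config c;
  Good g(c);
  std::printf("q_stride=%d\n", g.q_stride);
  return 0;
}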
|
||||
|
|
@ -622,17 +603,13 @@ class VitAttention {
|
|||
// Computes Q, K, V for all heads, stored in activations_.q.
|
||||
HWY_NOINLINE void ComputeQKV() {
|
||||
PROFILER_ZONE("Gen.VitAttention.QKV");
|
||||
const auto y =
|
||||
ConstMat(activations_.pre_att_rms_out.All(), layer_config_.model_dim);
|
||||
auto& qkv = activations_.q;
|
||||
HWY_ASSERT(qkv.BatchSize() == num_tokens_);
|
||||
HWY_ASSERT(qkv.Len() == layer_config_.heads * 3 * layer_config_.qkv_dim);
|
||||
MatMul</*kAdd=*/true>(
|
||||
num_tokens_, y,
|
||||
ConstMat(layer_weights_.vit.qkv_einsum_w.data_scale1(),
|
||||
layer_config_.model_dim),
|
||||
/*scale=*/1.0f, layer_weights_.vit.qkv_einsum_b.data_scale1(),
|
||||
activations_.env, MutableMat(qkv.All(), qkv.Len()));
|
||||
HWY_ASSERT(qkv.Cols() == layer_config_.heads * 3 * layer_config_.qkv_dim);
|
||||
MatMul(ConstMatFromBatch(num_tokens_, activations_.pre_att_rms_out),
|
||||
ConstMatFromWeights(layer_weights_.vit.qkv_einsum_w),
|
||||
layer_weights_.vit.qkv_einsum_b.data_scale1(), activations_.env,
|
||||
RowPtrFromBatch(qkv));
|
||||
}
|
||||
|
||||
HWY_NOINLINE void DotSoftmaxWeightedSum() {
|
||||
|
|
@ -679,17 +656,13 @@ class VitAttention {
|
|||
HWY_NOINLINE void SumHeads() {
|
||||
PROFILER_ZONE("Gen.VitAttention.SumHeads");
|
||||
auto* bias = layer_weights_.vit.attn_out_b.data_scale1();
|
||||
auto att_out = ConstMat(activations_.att_out.All(),
|
||||
layer_config_.heads * layer_config_.qkv_dim);
|
||||
auto att_weights = ConstMat(layer_weights_.vit.attn_out_w.data_scale1(),
|
||||
layer_config_.heads * layer_config_.qkv_dim);
|
||||
auto att_sums =
|
||||
MutableMat(activations_.att_sums.All(), layer_config_.model_dim);
|
||||
// att_weights and att_out are concatenated heads, each of length
|
||||
// layer_config_.qkv_dim. Thus the [num_tokens_, layer_config_.model_dim]
|
||||
// matmul output is the sum over heads.
|
||||
MatMul</*kAdd=*/true>(num_tokens_, att_out, att_weights, /*scale=*/1.0f,
|
||||
bias, activations_.env, att_sums);
|
||||
auto att_out = ConstMatFromBatch(num_tokens_, activations_.att_out);
|
||||
auto att_weights = ConstMatFromWeights(layer_weights_.vit.attn_out_w);
|
||||
auto att_sums = RowPtrFromBatch(activations_.att_sums);
|
||||
MatMul(att_out, att_weights, bias, activations_.env, att_sums);
|
||||
}
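As the comment in SumHeads notes, att_out and att_weights are concatenated heads, so a single [num_tokens, model_dim] matmul already sums the per-head projections. A scalar check of that identity with tiny hypothetical sizes:

#include <cstdio>
#include <vector>

int main() {
  const int heads = 2, qkv_dim = 3, model_dim = 4;
  // att_out: one token, heads * qkv_dim concatenated head outputs.
  std::vector<float> att_out = {1, 2, 3, 4, 5, 6};
  // att_weights: (heads * qkv_dim) x model_dim projection, heads stacked row-wise.
  std::vector<float> w(heads * qkv_dim * model_dim);
  for (size_t i = 0; i < w.size(); ++i) w[i] = 0.01f * float(i + 1);

  // (a) single matmul over the concatenated vector.
  std::vector<float> concat(model_dim, 0.0f);
  for (int k = 0; k < heads * qkv_dim; ++k)
    for (int m = 0; m < model_dim; ++m) concat[m] += att_out[k] * w[k * model_dim + m];

  // (b) per-head projections, summed afterwards. Same result.
  std::vector<float> summed(model_dim, 0.0f);
  for (int h = 0; h < heads; ++h)
    for (int k = 0; k < qkv_dim; ++k)
      for (int m = 0; m < model_dim; ++m)
        summed[m] += att_out[h * qkv_dim + k] * w[(h * qkv_dim + k) * model_dim + m];

  for (int m = 0; m < model_dim; ++m)
    std::printf("%d: concat=%.4f per-head-sum=%.4f\n", m, concat[m], summed[m]);
  return 0;
}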
|
||||
|
||||
public:
|
||||
|
|
@ -741,125 +714,94 @@ HWY_NOINLINE void FFWNoVit(Activations& activations, size_t num_interleaved,
|
|||
PROFILER_ZONE("Gen.FFW");
|
||||
const size_t model_dim = layer_weights->layer_config.model_dim;
|
||||
const size_t ffh_hidden_dim = layer_weights->layer_config.ff_hidden_dim;
|
||||
const bool add_bias = layer_weights->layer_config.ff_biases;
|
||||
using WeightType = T;
|
||||
HWY_DASSERT(num_interleaved <= activations.bf_pre_ffw_rms_out.BatchSize());
|
||||
|
||||
// Define slightly more readable names for the weights and activations.
|
||||
const auto x = ConstMat(activations.bf_pre_ffw_rms_out.All(), model_dim);
|
||||
Mat<const WeightType> w1;
|
||||
const float* bias1 = nullptr;
|
||||
Mat<const WeightType> w2;
|
||||
const float* bias2 = nullptr;
|
||||
float scale = 1.0f;
|
||||
Mat<const WeightType> w_output;
|
||||
const float* output_bias = nullptr;
|
||||
float output_scale = 1.0f;
|
||||
auto hidden_activations = MutableMat(activations.C1.All(), ffh_hidden_dim);
|
||||
auto multiplier = MutableMat(activations.C2.All(), ffh_hidden_dim);
|
||||
auto ffw_out = MutableMat(activations.ffw_out.All(), model_dim);
|
||||
const bool add_bias = layer_weights->layer_config.ff_biases;
|
||||
const float* bias1 =
|
||||
add_bias ? layer_weights->ffw_gating_biases.data_scale1() : nullptr;
|
||||
const float* bias2 = add_bias ? bias1 + ffh_hidden_dim : nullptr;
|
||||
const float* output_bias =
|
||||
add_bias ? layer_weights->ffw_output_biases.data_scale1() : nullptr;
|
||||
|
||||
// For some of the weights and activations, it depends on the config where to
|
||||
// get them from or whether to use them at all.
|
||||
bias1 = layer_weights->ffw_gating_biases.data_scale1();
|
||||
bias2 = bias1 + ffh_hidden_dim;
|
||||
output_bias = layer_weights->ffw_output_biases.data_scale1();
|
||||
w1 = layer_weights->gating_einsum_w.data() == nullptr
|
||||
? ConstMat(layer_weights->gating_einsum_w1.data(), model_dim)
|
||||
: ConstMat(layer_weights->gating_einsum_w.data(), model_dim);
|
||||
w2 = layer_weights->gating_einsum_w.data() == nullptr
|
||||
? ConstMat(layer_weights->gating_einsum_w2.data(), model_dim)
|
||||
: ConstMat(layer_weights->gating_einsum_w.data(), model_dim,
|
||||
model_dim, model_dim * ffh_hidden_dim);
|
||||
scale = layer_weights->gating_einsum_w.data() == nullptr
|
||||
? layer_weights->gating_einsum_w1.scale()
|
||||
: layer_weights->gating_einsum_w.scale();
|
||||
w_output = ConstMat(layer_weights->linear_w.data(), ffh_hidden_dim);
|
||||
output_scale = layer_weights->linear_w.scale();
|
||||
// Define slightly more readable names for the weights and activations.
|
||||
const auto x =
|
||||
ConstMatFromBatch(num_interleaved, activations.bf_pre_ffw_rms_out);
|
||||
|
||||
auto hidden_activations = RowPtrFromBatch(activations.C1);
|
||||
auto multiplier = RowPtrFromBatch(activations.C2);
|
||||
auto ffw_out = RowPtrFromBatch(activations.ffw_out);
|
||||
|
||||
// gating_einsum_w holds two half-matrices. We plan to change the importer to
|
||||
// avoid this confusion by splitting into gating_einsum_w1 and
|
||||
// gating_einsum_w2.
|
||||
const bool split = !!layer_weights->gating_einsum_w.data();
|
||||
auto w1 = split ? ConstMatFromWeights(layer_weights->gating_einsum_w)
|
||||
: ConstMatFromWeights(layer_weights->gating_einsum_w1);
|
||||
auto w2 = split ? ConstMatFromWeights(layer_weights->gating_einsum_w,
|
||||
model_dim * ffh_hidden_dim)
|
||||
: ConstMatFromWeights(layer_weights->gating_einsum_w2);
|
||||
if (split) {
|
||||
// Ensure that B.Extents().rows matches C.Cols() because MatMul checks that.
|
||||
w1.ShrinkRows(ffh_hidden_dim);
|
||||
w2.ShrinkRows(ffh_hidden_dim);
|
||||
}
|
||||
auto w_output = ConstMatFromWeights(layer_weights->linear_w);
|
||||
|
||||
// Compute the hidden layer activations.
|
||||
if (add_bias) {
|
||||
MatMul</*kAddBias=*/true>(num_interleaved, x, w1, scale, bias1,
|
||||
activations.env, hidden_activations);
|
||||
MatMul</*kAddBias=*/true>(num_interleaved, x, w2, scale, bias2,
|
||||
activations.env, multiplier);
|
||||
} else {
|
||||
MatMul</*kAddBias=*/false>(num_interleaved, x, w1, scale, bias1,
|
||||
activations.env, hidden_activations);
|
||||
MatMul</*kAddBias=*/false>(num_interleaved, x, w2, scale, bias2,
|
||||
activations.env, multiplier);
|
||||
}
|
||||
MatMul(x, w1, bias1, activations.env, hidden_activations);
|
||||
MatMul(x, w2, bias2, activations.env, multiplier);
|
||||
|
||||
// Activation (Gelu) and maybe multiply by gate. Store activations in act.
|
||||
Activation(layer_weights->layer_config.activation, hidden_activations.ptr,
|
||||
multiplier.ptr, ffh_hidden_dim * num_interleaved);
|
||||
Activation(layer_weights->layer_config.activation, hidden_activations.Row(0),
|
||||
multiplier.Row(0), ffh_hidden_dim * num_interleaved);
|
||||
|
||||
// Hidden layer -> output layer.
|
||||
if (add_bias) {
|
||||
MatMul</*kAddBias=*/true>(num_interleaved, ConstMat(hidden_activations),
|
||||
w_output, output_scale, output_bias,
|
||||
activations.env, ffw_out);
|
||||
} else {
|
||||
MatMul</*kAddBias=*/false>(num_interleaved, ConstMat(hidden_activations),
|
||||
w_output, output_scale, output_bias,
|
||||
activations.env, ffw_out);
|
||||
}
|
||||
auto activations_mat = MakeConstMat(
|
||||
hidden_activations.Row(0), Extents2D(num_interleaved, ffh_hidden_dim));
|
||||
|
||||
MatMul(activations_mat, w_output, output_bias, activations.env, ffw_out);
|
||||
}
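FFWNoVit above is the usual gated FFN: two matmuls into the hidden dimension, the activation of the first gated by the second, then a projection back to model_dim. A scalar sketch with biases omitted; the tanh-approximation GELU is an assumption about what Activation applies:

#include <cmath>
#include <cstdio>
#include <vector>

// tanh-approximation GELU; assumed to match the gated activation used here.
static float Gelu(float x) {
  return 0.5f * x * (1.0f + std::tanh(0.79788456f * (x + 0.044715f * x * x * x)));
}

int main() {
  const int model_dim = 4, ff_hidden_dim = 6;
  std::vector<float> x(model_dim, 0.5f);
  std::vector<float> w1(ff_hidden_dim * model_dim, 0.02f);  // gating matrix 1
  std::vector<float> w2(ff_hidden_dim * model_dim, 0.03f);  // gating matrix 2 ("multiplier")
  std::vector<float> w_out(model_dim * ff_hidden_dim, 0.01f);

  std::vector<float> hidden(ff_hidden_dim), out(model_dim, 0.0f);
  for (int h = 0; h < ff_hidden_dim; ++h) {
    float c1 = 0.0f, c2 = 0.0f;
    for (int m = 0; m < model_dim; ++m) {
      c1 += x[m] * w1[h * model_dim + m];  // analogous to C1
      c2 += x[m] * w2[h * model_dim + m];  // analogous to C2
    }
    hidden[h] = Gelu(c1) * c2;  // activation on C1, gated by C2
  }
  for (int m = 0; m < model_dim; ++m)
    for (int h = 0; h < ff_hidden_dim; ++h) out[m] += hidden[h] * w_out[m * ff_hidden_dim + h];

  std::printf("out[0]=%f\n", out[0]);
  return 0;
}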
|
||||
|
||||
// Same as FFWNoVit, but with different layer_weights members and no second
|
||||
// gating matrix.
|
||||
template <typename T>
|
||||
HWY_NOINLINE void FFWVit(Activations& activations, size_t num_interleaved,
|
||||
const LayerWeightsPtrs<T>* layer_weights) {
|
||||
PROFILER_ZONE("Gen.FFW");
|
||||
const size_t model_dim = layer_weights->layer_config.model_dim;
|
||||
const size_t ff_hidden_dim = layer_weights->layer_config.ff_hidden_dim;
|
||||
const bool add_bias = layer_weights->layer_config.ff_biases;
|
||||
using WeightType = typename LayerWeightsPtrs<T>::WeightF32OrBF16;
|
||||
HWY_DASSERT(num_interleaved <= activations.bf_pre_ffw_rms_out.BatchSize());
|
||||
|
||||
// Define slightly more readable names for the weights and activations.
|
||||
const auto x = ConstMat(activations.bf_pre_ffw_rms_out.All(), model_dim);
|
||||
Mat<const WeightType> w1;
|
||||
const float* bias1 = nullptr;
|
||||
float scale = 1.0f;
|
||||
Mat<const WeightType> w_output;
|
||||
const float* output_bias = nullptr;
|
||||
float output_scale = 1.0f;
|
||||
auto hidden_activations = MutableMat(activations.C1.All(), ff_hidden_dim);
|
||||
auto multiplier = MutableMat(activations.C2.All(), ff_hidden_dim);
|
||||
auto ffw_out = MutableMat(activations.ffw_out.All(), model_dim);
|
||||
const bool add_bias = layer_weights->layer_config.ff_biases;
|
||||
const float* bias1 =
|
||||
add_bias ? layer_weights->vit.linear_0_b.data_scale1() : nullptr;
|
||||
const float* output_bias =
|
||||
add_bias ? layer_weights->vit.linear_1_b.data_scale1() : nullptr;
|
||||
|
||||
// For some of the weights and activations, it depends on the config where to
|
||||
// get them from or whether to use them at all.
|
||||
w1 = ConstMat(layer_weights->vit.linear_0_w.data_scale1(), model_dim);
|
||||
bias1 = layer_weights->vit.linear_0_b.data_scale1();
|
||||
multiplier.ptr = nullptr;
|
||||
w_output =
|
||||
ConstMat(layer_weights->vit.linear_1_w.data_scale1(), ff_hidden_dim);
|
||||
output_bias = layer_weights->vit.linear_1_b.data_scale1();
|
||||
// Define slightly more readable names for the weights and activations.
|
||||
const auto x =
|
||||
ConstMatFromBatch(num_interleaved, activations.bf_pre_ffw_rms_out);
|
||||
|
||||
auto hidden_activations = RowPtrFromBatch(activations.C1);
|
||||
auto ffw_out = RowPtrFromBatch(activations.ffw_out);
|
||||
|
||||
auto w1 = ConstMatFromWeights(layer_weights->vit.linear_0_w);
|
||||
auto w_output = ConstMatFromWeights(layer_weights->vit.linear_1_w);
|
||||
|
||||
// Compute the hidden layer activations.
|
||||
if (add_bias) {
|
||||
MatMul</*kAddBias=*/true>(num_interleaved, x, w1, scale, bias1,
|
||||
activations.env, hidden_activations);
|
||||
} else {
|
||||
MatMul</*kAddBias=*/false>(num_interleaved, x, w1, scale, bias1,
|
||||
activations.env, hidden_activations);
|
||||
}
|
||||
MatMul(x, w1, bias1, activations.env, hidden_activations);
|
||||
|
||||
// Activation (Gelu) and maybe multiply by gate. Store activations in act.
|
||||
Activation(layer_weights->layer_config.activation, hidden_activations.ptr,
|
||||
multiplier.ptr, ff_hidden_dim * num_interleaved);
|
||||
// Activation (Gelu), store in act.
|
||||
RowPtrF multiplier = RowPtrF(nullptr, 0);
|
||||
Activation(layer_weights->layer_config.activation, hidden_activations.Row(0),
|
||||
multiplier.Row(0), ff_hidden_dim * num_interleaved);
|
||||
|
||||
// Hidden layer -> output layer.
|
||||
if (add_bias) {
|
||||
MatMul</*kAddBias=*/true>(num_interleaved, ConstMat(hidden_activations),
|
||||
w_output, output_scale, output_bias,
|
||||
activations.env, ffw_out);
|
||||
} else {
|
||||
MatMul</*kAddBias=*/false>(num_interleaved, ConstMat(hidden_activations),
|
||||
w_output, output_scale, output_bias,
|
||||
activations.env, ffw_out);
|
||||
}
|
||||
auto activations_mat = MakeConstMat(
|
||||
hidden_activations.Row(0), Extents2D(num_interleaved, ff_hidden_dim));
|
||||
|
||||
MatMul(activations_mat, w_output, output_bias, activations.env, ffw_out);
|
||||
}
|
||||
|
||||
// `batch_idx` indicates which row of `x` to write to.
|
||||
|
|
@ -874,7 +816,7 @@ HWY_NOINLINE void EmbedToken(int token, size_t batch_idx, size_t pos,
|
|||
// Image tokens just need to be copied.
|
||||
if (image_tokens != nullptr && pos_in_prompt < image_tokens->BatchSize()) {
|
||||
hwy::CopyBytes(image_tokens->Batch(pos_in_prompt), x.Batch(batch_idx),
|
||||
x.Len() * sizeof(x.Const()[0]));
|
||||
x.Cols() * sizeof(x.Const()[0]));
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
@ -963,7 +905,7 @@ HWY_NOINLINE void TransformerLayer(const QueriesPos& queries_pos,
|
|||
// the Big Vision codebase. See
|
||||
// github.com/google-research/big_vision/blob/main/big_vision/models/vit.py
|
||||
// TODO(keysers): consider adding a wrapper for both LayerNorm with RMSNorm and
|
||||
// try mergig this with TransformerLayer.
|
||||
// try merging this with TransformerLayer.
|
||||
template <typename T>
|
||||
HWY_NOINLINE void VitTransformerLayer(size_t num_tokens, size_t layer,
|
||||
const LayerWeightsPtrs<T>* layer_weights,
|
||||
|
|
@ -974,7 +916,7 @@ HWY_NOINLINE void VitTransformerLayer(size_t num_tokens, size_t layer,
|
|||
|
||||
auto& x = activations.x;
|
||||
HWY_DASSERT(x.BatchSize() == num_tokens);
|
||||
HWY_DASSERT(x.Len() == model_dim);
|
||||
HWY_DASSERT(x.Cols() == model_dim);
|
||||
|
||||
// y = nn.LayerNorm()(x)
|
||||
// y ~ pre_att_rms_out
|
||||
|
|
@ -1127,7 +1069,7 @@ HWY_NOINLINE void EmbedImagePatches(const Image& image,
|
|||
const size_t patch_size = patch_width * patch_width * 3;
|
||||
HWY_DASSERT(weights.vit_img_embedding_kernel.NumElements() ==
|
||||
patch_size * model_dim);
|
||||
HWY_DASSERT(activations.x.Len() == model_dim);
|
||||
HWY_DASSERT(activations.x.Cols() == model_dim);
|
||||
std::vector<hwy::AlignedFreeUniquePtr<float[]>> image_patches(seq_len);
|
||||
for (size_t i = 0; i < seq_len; ++i) {
|
||||
image_patches[i] = hwy::AllocateAligned<float>(patch_size);
|
||||
|
|
@ -1139,11 +1081,11 @@ HWY_NOINLINE void EmbedImagePatches(const Image& image,
|
|||
// This could be done as one MatMul like:
|
||||
// RowVectorBatch<float> image_patches(kSeqLen, kPatchSize);
|
||||
// [Get patches]
|
||||
// MatMul</*kAdd=*/true>(
|
||||
// kVitSeqLen, ConstMat(image_patches.All(), kPatchSize),
|
||||
// ConstMat(weights.vit_img_embedding_kernel.data_scale1(), kPatchSize),
|
||||
// /*scale=*/1.0f, weights.vit_img_embedding_bias.data_scale1(),
|
||||
// activations.env, MutableMat(activations.x.All(), kVitModelDim));
|
||||
// MatMul(
|
||||
// MatFromBatch(kVitSeqLen, image_patches),
|
||||
// MatFromWeights(weights.vit_img_embedding_kernel),
|
||||
// weights.vit_img_embedding_bias.data_scale1(), activations.env,
|
||||
// RowPtrF(activations.x.All(), kVitModelDim));
|
||||
// However, MatMul currently requires that
|
||||
// A.cols % (2 * hn::Lanes(hn::ScalableTag<MulT>())) == 0
|
||||
// which is not the case here. We should relax that requirement on MatMul and
|
||||
|
|
@ -1184,11 +1126,10 @@ HWY_NOINLINE void PrefillVit(const ModelWeightsPtrs<T>& weights,
|
|||
activations.x.All(), vit_model_dim);
|
||||
|
||||
// Apply head embedding into image_tokens of size of the LLM kModelDim.
|
||||
MatMul</*kAdd=*/true>(
|
||||
num_tokens, ConstMat(activations.x.All(), vit_model_dim),
|
||||
ConstMat(weights.vit_img_head_kernel.data_scale1(), vit_model_dim),
|
||||
/*scale=*/1.0f, weights.vit_img_head_bias.data_scale1(), activations.env,
|
||||
MutableMat(image_tokens.All(), weights.weights_config.model_dim));
|
||||
MatMul(ConstMatFromBatch(num_tokens, activations.x),
|
||||
ConstMatFromWeights(weights.vit_img_head_kernel),
|
||||
weights.vit_img_head_bias.data_scale1(), activations.env,
|
||||
RowPtrFromBatch(image_tokens));
|
||||
}
|
||||
|
||||
// Generates one token for each query. `queries_token` is the previous token
|
||||
|
|
@ -1320,7 +1261,6 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations,
|
|||
const QueriesPos& queries_prefix_end,
|
||||
const size_t query_idx_start, const KVCaches& kv_caches,
|
||||
TimingInfo& timing_info) {
|
||||
const size_t model_dim = model.Config().model_dim;
|
||||
const size_t vocab_size = model.Config().vocab_size;
|
||||
const ModelWeightsPtrs<T>& weights = *model.GetWeightsOfType<T>();
|
||||
|
||||
|
|
@ -1408,11 +1348,10 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations,
|
|||
{
|
||||
PROFILER_ZONE("Gen.EmbeddingMatmul");
|
||||
// Compute logits from last layer activations.
|
||||
MatMul</*kAdd=*/false>(
|
||||
num_queries, ConstMat(activations.x.All(), model_dim),
|
||||
ConstMat(weights.embedder_input_embedding.data(), model_dim),
|
||||
weights.embedder_input_embedding.scale(), /*add=*/nullptr,
|
||||
activations.env, MutableMat(activations.logits.All(), vocab_size));
|
||||
MatMul(ConstMatFromBatch(num_queries, activations.x),
|
||||
ConstMatFromWeights(weights.embedder_input_embedding),
|
||||
/*add=*/nullptr, activations.env,
|
||||
RowPtrFromBatch(activations.logits));
|
||||
}
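The EmbeddingMatmul above multiplies the final activations by the transposed embedding table to get per-query logits, which the following Softcap zone then squashes before sampling. A scalar sketch; the soft-cap form cap * tanh(logit / cap) and the cap value are assumptions for illustration:

#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  const int model_dim = 4, vocab_size = 8;
  std::vector<float> x(model_dim, 0.25f);                      // one query's activations
  std::vector<float> embedding(vocab_size * model_dim, 0.1f);  // one row per vocab entry
  std::vector<float> logits(vocab_size, 0.0f);

  // logits = x * E^T: each embedding row acts as one output column.
  for (int v = 0; v < vocab_size; ++v)
    for (int m = 0; m < model_dim; ++m) logits[v] += x[m] * embedding[v * model_dim + m];

  // Assumed soft-cap of the form cap * tanh(logit / cap); the cap value is illustrative.
  const float cap = 30.0f;
  for (float& l : logits) l = cap * std::tanh(l / cap);

  std::printf("logit[0]=%f\n", logits[0]);
  return 0;
}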
|
||||
PROFILER_ZONE("Gen.Softcap+Sample+Stream");
|
||||
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
|
||||
|
|
|
|||
|
|
@ -35,7 +35,6 @@
|
|||
#include "util/threading.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
#include "hwy/highway.h"
|
||||
#include "hwy/profiler.h" // also uses SIMD
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
|
|
@ -119,12 +118,12 @@ struct GenerateImageTokensT {
|
|||
void Gemma::Generate(const RuntimeConfig& runtime_config,
|
||||
const PromptTokens& prompt, size_t pos, size_t prefix_end,
|
||||
KVCache& kv_cache, TimingInfo& timing_info) {
|
||||
if (runtime_config.use_spinning) pools_.StartSpinning();
|
||||
pools_.MaybeStartSpinning(runtime_config.use_spinning);
|
||||
|
||||
model_.CallForModelWeight<GenerateSingleT>(
|
||||
runtime_config, prompt, pos, prefix_end, kv_cache, pools_, timing_info);
|
||||
|
||||
if (runtime_config.use_spinning) pools_.StopSpinning();
|
||||
pools_.MaybeStopSpinning(runtime_config.use_spinning);
|
||||
}
|
||||
|
||||
void Gemma::GenerateBatch(const RuntimeConfig& runtime_config,
|
||||
|
|
@ -141,23 +140,23 @@ void Gemma::GenerateBatch(const RuntimeConfig& runtime_config,
|
|||
QueriesPos(prefix_end_vec.data(), prefix_end_vec.size());
|
||||
}
|
||||
|
||||
if (runtime_config.use_spinning) pools_.StartSpinning();
|
||||
pools_.MaybeStartSpinning(runtime_config.use_spinning);
|
||||
|
||||
model_.CallForModelWeight<GenerateBatchT>(
|
||||
runtime_config, queries_prompt, queries_pos, mutable_queries_prefix_end,
|
||||
kv_caches, pools_, timing_info);
|
||||
|
||||
if (runtime_config.use_spinning) pools_.StopSpinning();
|
||||
pools_.MaybeStopSpinning(runtime_config.use_spinning);
|
||||
}
|
||||
|
||||
void Gemma::GenerateImageTokens(const RuntimeConfig& runtime_config,
|
||||
const Image& image, ImageTokens& image_tokens) {
|
||||
if (runtime_config.use_spinning) pools_.StartSpinning();
|
||||
pools_.MaybeStartSpinning(runtime_config.use_spinning);
|
||||
|
||||
model_.CallForModelWeight<GenerateImageTokensT>(runtime_config, image,
|
||||
image_tokens, pools_);
|
||||
|
||||
if (runtime_config.use_spinning) pools_.StopSpinning();
|
||||
pools_.MaybeStopSpinning(runtime_config.use_spinning);
|
||||
}
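The Maybe{Start,Stop}Spinning calls above replace the unconditional if (use_spinning) guards. A sketch of the intended pattern, where kDefault lets the pool decide (for example based on whether thread pinning succeeded) and the decision is written back so the stop call matches; names and behavior are illustrative, not the actual NestedPools implementation:

#include <cstdio>

enum class Tristate { kFalse, kTrue, kDefault };

class Pools {
 public:
  explicit Pools(bool pinned) : pinned_(pinned) {}

  // If the caller did not decide (kDefault), pick a default and write it back
  // so the matching MaybeStopSpinning sees the same decision.
  void MaybeStartSpinning(Tristate& use_spinning) {
    if (use_spinning == Tristate::kDefault)
      use_spinning = pinned_ ? Tristate::kTrue : Tristate::kFalse;
    if (use_spinning == Tristate::kTrue) std::printf("spinning on\n");
  }
  void MaybeStopSpinning(const Tristate use_spinning) {
    if (use_spinning == Tristate::kTrue) std::printf("spinning off\n");
  }

 private:
  bool pinned_;
};

int main() {
  Pools pools(/*pinned=*/true);
  Tristate use_spinning = Tristate::kDefault;  // the mutable member in RuntimeConfig
  pools.MaybeStartSpinning(use_spinning);
  // ... generate ...
  pools.MaybeStopSpinning(use_spinning);
  return 0;
}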
|
||||
|
||||
// Non-template functions moved from gemma-inl.h to avoid ODR violations.
|
||||
|
|
|
|||
|
|
@ -121,7 +121,11 @@ struct RuntimeConfig {
|
|||
const ImageTokens *image_tokens = nullptr;
|
||||
|
||||
// Whether to use thread spinning to reduce barrier synchronization latency.
|
||||
bool use_spinning = true;
|
||||
// Mutable so we can change kDefault to kTrue/kFalse during Generate, because
|
||||
// RuntimeConfig is const there and is not passed to the Gemma ctor. This
|
||||
// default decision is likely sufficient because it is based on whether
|
||||
// threads are successfully pinned.
|
||||
mutable Tristate use_spinning = Tristate::kDefault;
|
||||
|
||||
// End-of-sequence token.
|
||||
int eos_id = EOS_ID;
|
||||
|
|
|
|||
43
gemma/run.cc
|
|
@ -22,6 +22,7 @@
|
|||
#include <vector>
|
||||
|
||||
// Placeholder for internal header, do not modify.
|
||||
#include "compression/shared.h" // ModelTraining
|
||||
#include "evals/benchmark_helper.h"
|
||||
#include "gemma/common.h"
|
||||
#include "gemma/gemma.h" // Gemma
|
||||
|
|
@ -77,8 +78,8 @@ std::string GetPrompt(std::istream& input, int verbosity,
|
|||
}
|
||||
|
||||
// The main Read-Eval-Print Loop.
|
||||
void ReplGemma(Gemma& model, KVCache& kv_cache, const InferenceArgs& args,
|
||||
int verbosity, const AcceptFunc& accept_token,
|
||||
void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
|
||||
const InferenceArgs& args, const AcceptFunc& accept_token,
|
||||
std::string& eot_line) {
|
||||
PROFILER_ZONE("Gen.misc");
|
||||
size_t abs_pos = 0; // across turns
|
||||
|
|
@ -90,13 +91,23 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const InferenceArgs& args,
|
|||
|
||||
const bool have_image = !args.image_file.path.empty();
|
||||
Image image;
|
||||
ImageTokens image_tokens(256, 2048);
|
||||
ImageTokens image_tokens;
|
||||
if (have_image) {
|
||||
HWY_ASSERT(model.Info().model == Model::PALIGEMMA_224);
|
||||
image_tokens = ImageTokens(Extents2D(model.GetModelConfig().vit_seq_len,
|
||||
model.GetModelConfig().model_dim));
|
||||
HWY_ASSERT(model.Info().training == ModelTraining::PALIGEMMA);
|
||||
HWY_ASSERT(image.ReadPPM(args.image_file.path));
|
||||
image.Resize();
|
||||
RuntimeConfig runtime_config = {.verbosity = verbosity, .gen = &gen};
|
||||
RuntimeConfig runtime_config = {
|
||||
.verbosity = app.verbosity, .gen = &gen, .use_spinning = app.spin};
|
||||
double image_tokens_start = hwy::platform::Now();
|
||||
model.GenerateImageTokens(runtime_config, image, image_tokens);
|
||||
if (app.verbosity >= 1) {
|
||||
double image_tokens_duration = hwy::platform::Now() - image_tokens_start;
|
||||
fprintf(stderr,
|
||||
"\n\n[ Timing info ] Image token generation took: %d ms\n",
|
||||
static_cast<int>(image_tokens_duration * 1000));
|
||||
}
|
||||
}
|
||||
|
||||
// callback function invoked for each generated token.
|
||||
|
|
@ -111,7 +122,7 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const InferenceArgs& args,
|
|||
abs_pos = 0;
|
||||
InitGenerator(args, gen);
|
||||
}
|
||||
if (verbosity >= 2) {
|
||||
if (app.verbosity >= 2) {
|
||||
std::cout << "\n[ End ]\n";
|
||||
}
|
||||
} else {
|
||||
|
|
@ -122,7 +133,7 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const InferenceArgs& args,
|
|||
if (tokens_generated_this_turn == prompt_size + 1) {
|
||||
// first token of response
|
||||
token_text.erase(0, token_text.find_first_not_of(" \t\n"));
|
||||
if (verbosity >= 1) {
|
||||
if (app.verbosity >= 1) {
|
||||
std::cout << "\n\n";
|
||||
}
|
||||
}
|
||||
|
|
@ -133,7 +144,7 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const InferenceArgs& args,
|
|||
|
||||
while (true) { // Loop until user quits.
|
||||
tokens_generated_this_turn = 0;
|
||||
std::string prompt_string = GetPrompt(std::cin, verbosity, eot_line);
|
||||
std::string prompt_string = GetPrompt(std::cin, app.verbosity, eot_line);
|
||||
if (!std::cin) return;
|
||||
// If !eot_line.empty(), we append \n, so only look at the first 2 chars.
|
||||
if (prompt_string.size() >= 2 && prompt_string[0] == '%') {
|
||||
|
|
@ -160,13 +171,12 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const InferenceArgs& args,
|
|||
}
|
||||
}
|
||||
|
||||
TimingInfo timing_info = {.verbosity = verbosity};
|
||||
RuntimeConfig runtime_config = {
|
||||
.verbosity = verbosity,
|
||||
.gen = &gen,
|
||||
.stream_token = stream_token,
|
||||
.accept_token = accept_token,
|
||||
};
|
||||
TimingInfo timing_info = {.verbosity = app.verbosity};
|
||||
RuntimeConfig runtime_config = {.verbosity = app.verbosity,
|
||||
.gen = &gen,
|
||||
.stream_token = stream_token,
|
||||
.accept_token = accept_token,
|
||||
.use_spinning = app.spin};
|
||||
args.CopyTo(runtime_config);
|
||||
size_t prefix_end = 0;
|
||||
if (have_image) {
|
||||
|
|
@ -226,8 +236,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
|
|||
std::cout << "\n" << instructions << "\n";
|
||||
}
|
||||
|
||||
ReplGemma(model, kv_cache, inference, app.verbosity, AcceptFunc(),
|
||||
app.eot_line);
|
||||
ReplGemma(model, kv_cache, app, inference, AcceptFunc(), app.eot_line);
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
|
|||
|
|
@ -21,9 +21,10 @@
|
|||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "compression/io.h" // Path
|
||||
#include "gemma/common.h" // Wrap
|
||||
#include "hwy/base.h" // HWY_ASSERT
|
||||
#include "compression/io.h" // Path
|
||||
#include "compression/shared.h" // ModelTraining
|
||||
#include "gemma/common.h" // Wrap
|
||||
#include "hwy/base.h" // HWY_ASSERT
|
||||
#include "hwy/profiler.h"
|
||||
// copybara:import_next_line:sentencepiece
|
||||
#include "src/sentencepiece_processor.h"
|
||||
|
|
@ -109,7 +110,7 @@ std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
|
|||
}
|
||||
|
||||
// PaliGemma separator. The SEP token "\n" is always tokenized separately.
|
||||
if (info.model == Model::PALIGEMMA_224) {
|
||||
if (info.training == ModelTraining::PALIGEMMA) {
|
||||
std::vector<int> sep_tokens;
|
||||
HWY_ASSERT(tokenizer.Encode("\n", &sep_tokens));
|
||||
tokens.insert(tokens.end(), sep_tokens.begin(), sep_tokens.end());
|
||||
|
|
|
|||
|
|
@ -95,11 +95,11 @@ struct LayerWeightsPtrs {
|
|||
config.model_dim},
|
||||
.qkv_einsum_b = {"qkv_ein_b", (config.heads + 2 * config.kv_heads),
|
||||
config.qkv_dim},
|
||||
.linear_0_w = {"linear_0_w", config.model_dim,
|
||||
config.ff_hidden_dim},
|
||||
.linear_0_b = {"linear_0_b", 1, config.ff_hidden_dim},
|
||||
.linear_1_w = {"linear_1_w", config.ff_hidden_dim,
|
||||
.linear_0_w = {"linear_0_w", config.ff_hidden_dim,
|
||||
config.model_dim},
|
||||
.linear_0_b = {"linear_0_b", 1, config.ff_hidden_dim},
|
||||
.linear_1_w = {"linear_1_w", config.model_dim,
|
||||
config.ff_hidden_dim},
|
||||
.linear_1_b = {"linear_1_b", 1, config.model_dim},
|
||||
.layer_norm_0_bias = {"ln_0_bias", 1, config.model_dim},
|
||||
.layer_norm_0_scale = {"ln_0_scale", 1, config.model_dim},
|
||||
|
|
@ -349,14 +349,13 @@ struct ModelWeightsPtrs {
|
|||
vit_encoder_norm_bias("enc_norm_bias", 1, config.vit_model_dim),
|
||||
vit_encoder_norm_scale("enc_norm_scale", 1, config.vit_model_dim),
|
||||
vit_img_embedding_bias("img_emb_bias", 1, config.vit_model_dim),
|
||||
vit_img_embedding_kernel(
|
||||
"img_emb_kernel",
|
||||
config.patch_width * config.patch_width * 3,
|
||||
config.vit_model_dim),
|
||||
vit_img_embedding_kernel("img_emb_kernel",
|
||||
config.patch_width * config.patch_width * 3,
|
||||
config.vit_model_dim),
|
||||
vit_img_pos_embedding("img_pos_emb", 256, config.vit_model_dim),
|
||||
vit_img_head_bias("img_head_bias", 1, config.model_dim),
|
||||
vit_img_head_kernel("img_head_kernel", config.vit_model_dim,
|
||||
config.model_dim),
|
||||
vit_img_head_kernel("img_head_kernel", config.model_dim,
|
||||
config.vit_model_dim),
|
||||
scale_names(config.scale_names),
|
||||
weights_config(config) {
|
||||
c_layers.reserve(config.layer_configs.size());
|
||||
|
|
|
|||
|
|
@ -1011,14 +1011,14 @@ struct TestShortDotsT {
|
|||
// hence they require padding to one vector.
|
||||
const size_t padded_num = hwy::RoundUpTo(num, N);
|
||||
const size_t packed_num = CompressedArrayElements<Packed>(num);
|
||||
RowVectorBatch<float> raw_w(1, padded_num);
|
||||
RowVectorBatch<float> raw_v(1, padded_num);
|
||||
RowVectorBatch<Packed> weights(1, packed_num);
|
||||
RowVectorBatch<float> raw_w(Extents2D(1, padded_num));
|
||||
RowVectorBatch<float> raw_v(Extents2D(1, padded_num));
|
||||
RowVectorBatch<Packed> weights(Extents2D(1, packed_num));
|
||||
const PackedSpan<Packed> w(weights.Batch(0), packed_num);
|
||||
RowVectorBatch<T> vectors(1, num);
|
||||
RowVectorBatch<T> vectors(Extents2D(1, num));
|
||||
const PackedSpan<T> v(vectors.Batch(0), num);
|
||||
|
||||
RowVectorBatch<double> bufs(1, num);
|
||||
RowVectorBatch<double> bufs(Extents2D(1, num));
|
||||
double* HWY_RESTRICT buf = bufs.Batch(0);
|
||||
|
||||
for (size_t rep = 0; rep < hn::AdjustedReps(20); ++rep) {
|
||||
|
|
@ -1107,11 +1107,11 @@ void TestAllDot() {
|
|||
|
||||
constexpr size_t kReps = hn::AdjustedReps(40);
|
||||
const size_t num = 24 * 1024;
|
||||
NestedPools pools(kMaxWorkers - 1, /*pin=*/1, BoundedSlice(0, 1),
|
||||
BoundedSlice(0, 1));
|
||||
RowVectorBatch<float> a(kMaxWorkers, num);
|
||||
RowVectorBatch<float> b(kMaxWorkers, num);
|
||||
RowVectorBatch<double> bufs(kMaxWorkers, num);
|
||||
NestedPools pools(kMaxWorkers - 1, /*pin=*/Tristate::kDefault,
|
||||
BoundedSlice(0, 1), BoundedSlice(0, 1));
|
||||
RowVectorBatch<float> a(Extents2D(kMaxWorkers, num));
|
||||
RowVectorBatch<float> b(Extents2D(kMaxWorkers, num));
|
||||
RowVectorBatch<double> bufs(Extents2D(kMaxWorkers, num));
|
||||
std::array<DotStats, kMaxWorkers> all_stats;
|
||||
|
||||
pools.Cluster(0, 0).Run(0, kReps, [&](const uint32_t rep, size_t thread) {
|
||||
|
|
|
|||
309
ops/matmul-inl.h
|
|
@ -16,8 +16,9 @@
|
|||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include "compression/compress.h" // IWYU pragma: keep, b/conditionally used
|
||||
#include "ops/matmul.h" // IWYU pragma: export
|
||||
#include "util/allocator.h"
|
||||
#include "util/basics.h"
|
||||
|
||||
// Include guard for (potentially) SIMD code.
|
||||
#if defined(THIRD_PARTY_GEMMA_CPP_MATMUL_TOGGLE) == defined(HWY_TARGET_TOGGLE)
|
||||
|
|
@ -30,7 +31,7 @@
|
|||
#include "hwy/highway.h"
|
||||
// After highway.h
|
||||
#include "compression/compress-inl.h"
|
||||
#include "hwy/contrib/math/math-inl.h"
|
||||
#include "ops/ops-inl.h"
|
||||
|
||||
HWY_BEFORE_NAMESPACE();
|
||||
namespace gcpp {
|
||||
|
|
@ -53,38 +54,20 @@ constexpr size_t kRegCols = 4;
|
|||
// generally `kRegRows`, but `batch_size % kRegRows` on the last row (if != 0).
|
||||
constexpr size_t kRegRows = kRegCols;
|
||||
|
||||
// NEON_BF16/SVE/AVX3_ZEN4 have instructions for bf16 * bf16 + f32 which are
|
||||
// more efficient than f32 * f32 + f32 because they process twice as many lanes
|
||||
// at a time. Any combination of A and B can be bf16: activations may already be
|
||||
// bf16, and weights can be decompressed to bf16.
|
||||
//
|
||||
// The corresponding op is `ReorderWidenMulAccumulate`, and it is always
|
||||
// supported, but only useful if it returns a single vector of pairwise sums
|
||||
// `a[0] * b[0] + a[1] * b[1]`. On other targets, `ReorderWidenMulAccumulate`
|
||||
// insteads return `a[1] * b[1]` in its `sum1` output. We cannot afford to keep
|
||||
// a `sum1` for each of the `kRegRows * kRegCols` C vectors, and it would be
|
||||
// expensive to add each `sum0` and `sum1`, hence we only 'decompress' A and B
|
||||
// to bf16 if the native op is available. This will actually demote f32
|
||||
// activations to bf16. Otherwise, we decompress to f32 and use normal FMA.
|
||||
using MulT = hwy::If<HWY_NATIVE_DOT_BF16, BF16, float>;
|
||||
|
||||
// Loads two vectors at a time with element type MulT from a row of transposed
|
||||
// B. Called in a loop over col_ab. No bounds checking because `kRow` is
|
||||
// actually from B columns, which we checked is a multiple of `kRegCols`.
|
||||
// Loads two vectors at a time with element type hn::TFromD<DR> from a row of
|
||||
// transposed B. Called in a loop over col_ab. No bounds checking because
|
||||
// `kRow` is from B columns, which we checked is a multiple of `kRegCols`.
|
||||
template <size_t kRow, typename MatTB>
|
||||
class BRow {
|
||||
static_assert(kRow < kRegRows); // which unrolled instance we are
|
||||
|
||||
public:
|
||||
BRow(const Mat<const MatTB>& B, size_t row_b, size_t cols_c)
|
||||
// B.cols * C.cols is the total number of elements, required for
|
||||
// PackedSpan::BoundsCheck.
|
||||
: B_(MakeSpan(B.ptr, B.ofs + B.cols * cols_c)),
|
||||
B_ofs_(B.Row(row_b + kRow)) {}
|
||||
BRow(const ConstMat<MatTB>& B, size_t row_b)
|
||||
: B_(MakeSpan(B.ptr, B.ofs + B.Extents().Area())),
|
||||
B_ofs_(B.Row(HWY_MIN(row_b + kRow, B.Extents().rows - 1))) {}
|
||||
|
||||
template <class DM, class VM = hn::Vec<DM>>
|
||||
HWY_INLINE void Load2(DM d, size_t col_ab, VM& b0, VM& b1) const {
|
||||
static_assert(hwy::IsSame<hn::TFromD<DM>, MulT>());
|
||||
template <class DR, class VR = hn::Vec<DR>>
|
||||
HWY_INLINE void Load2(DR d, size_t col_ab, VR& b0, VR& b1) const {
|
||||
Decompress2(d, B_, B_ofs_ + col_ab, b0, b1);
|
||||
}
|
||||
|
||||
|
|
@ -93,11 +76,11 @@ class BRow {
|
|||
const size_t B_ofs_;
|
||||
};
|
||||
|
||||
// Loads *two* row vectors from A via `Decompress2`, multiplies element-wise
|
||||
// with `kRegRows` x 2 row vectors from transposed B, and adds them to
|
||||
// `kRegRows` x `kRegCols` C vectors. The lanes of `C[r,c]` are thus a subset of
|
||||
// the terms of the dot products that make up the MatMul result at `r,c`.
|
||||
// No-op for the bottom-most tile where kRow >= kNumRows.
|
||||
// Loads *two* row vectors from A via `Decompress2`, widens to f32, multiplies
|
||||
// element-wise with `kRegRows` x 2 row vectors from transposed B, and adds
|
||||
// them to `kRegRows` x `kRegCols` C vectors. The lanes of `C[r,c]` are thus a
|
||||
// subset of the terms of the dot products that make up the MatMul result at
|
||||
// `r,c`. No-op for the bottom-most rows whose `kRow >= kNumRows`.
|
||||
//
|
||||
// This approach is atypical because it requires a horizontal sum, for which we
|
||||
// introduce a fast and new(?) vector-length agnostic 'transpose', see
|
||||
|
|
@ -107,22 +90,24 @@ class BRow {
|
|||
// - `Decompress2` decompresses two vectors at a time;
|
||||
// - B is column-major, so unit-stride SIMD loads return a column, not values
|
||||
// from different columns, i.e. a row.
|
||||
// Both could be fixed in a packing stage, which is not implemented yet, and
|
||||
// might not be necessary otherwise. However, `ReorderWidenMulAccumulate` is
|
||||
// important for bf16 performance and incompatible with the conventional
|
||||
// approach, because its pairwise adds would add together unrelated terms.
|
||||
// By contrast, pairwise adds are fine when our C lanes are the terms of a
|
||||
// single dot product, which can be reordered or pre-reduced.
|
||||
// - `ReorderWidenMulAccumulate` is important for bf16 performance, but its
|
||||
// pairwise adds would add together unrelated terms.
|
||||
// The first two could be fixed in a packing stage, which is not implemented
|
||||
// yet, and might not be necessary otherwise. The third seems a fundamental
|
||||
// mismatch. However, pairwise adds are fine in our setting because C lanes are
|
||||
// the terms of a single dot product, which can be reordered or pre-reduced.
|
||||
template <size_t kRow, typename MatTA>
|
||||
class ALoadAccumulate {
|
||||
static_assert(kRow < kRegRows); // which unrolled instance we are
|
||||
|
||||
public:
|
||||
ALoadAccumulate(const Mat<const MatTA>& A, size_t row_ac, size_t batch_size)
|
||||
// A.cols * batch_size is the total number of elements, required for
|
||||
// PackedSpan::BoundsCheck.
|
||||
: A_(MakeSpan(A.ptr, A.ofs + A.cols * batch_size)),
|
||||
A_ofs_(A.Row(row_ac + kRow)) {}
|
||||
static_assert(kRow < kRegRows); // which unrolled instance we are
|
||||
// `First` and `Next` handle a single row of A, so the horizontal sums of
|
||||
// their `C0..3` are the (partial) dot products for 4 consecutive values in
|
||||
// one row of C.
|
||||
static_assert(kRegCols == 4);
|
||||
|
||||
ALoadAccumulate(const ConstMat<MatTA>& A, size_t row_ac)
|
||||
: A_(MakeSpan(A.ptr, A.ofs + A.Extents().Area())),
|
||||
A_ofs_(A.Row(HWY_MIN(row_ac + kRow, A.Extents().rows - 1))) {}
|
||||
|
||||
// First iteration, col_ab = 0: initialize C0..3 instead of updating them.
|
||||
template <size_t kNumRows, class DM, class VM = hn::Vec<DM>, HWY_IF_F32_D(DM)>
|
||||
|
|
@ -161,20 +146,27 @@ class ALoadAccumulate {
|
|||
Decompress2(dm, A_, A_ofs_, a0, a1);
|
||||
|
||||
const DF df;
|
||||
VF unused_sum1 = hn::Zero(df);
|
||||
|
||||
static_assert(kRegCols == 4);
|
||||
C0 = hn::WidenMulPairwiseAdd(df, a0, b00);
|
||||
C1 = hn::WidenMulPairwiseAdd(df, a0, b10);
|
||||
C2 = hn::WidenMulPairwiseAdd(df, a0, b20);
|
||||
C3 = hn::WidenMulPairwiseAdd(df, a0, b30);
|
||||
C0 = hn::ReorderWidenMulAccumulate(df, a1, b01, C0, unused_sum1);
|
||||
C1 = hn::ReorderWidenMulAccumulate(df, a1, b11, C1, unused_sum1);
|
||||
C2 = hn::ReorderWidenMulAccumulate(df, a1, b21, C2, unused_sum1);
|
||||
C3 = hn::ReorderWidenMulAccumulate(df, a1, b31, C3, unused_sum1);
|
||||
|
||||
// Ensure sum1 was indeed unused.
|
||||
HWY_DASSERT(hn::AllTrue(df, hn::Eq(unused_sum1, hn::Zero(df))));
|
||||
if constexpr (HWY_NATIVE_DOT_BF16) {
|
||||
// Native ReorderWidenMulAccumulate adds to C0..3 for free.
|
||||
VF unused_sum1 = hn::Zero(df);
|
||||
C0 = hn::ReorderWidenMulAccumulate(df, a1, b01, C0, unused_sum1);
|
||||
C1 = hn::ReorderWidenMulAccumulate(df, a1, b11, C1, unused_sum1);
|
||||
C2 = hn::ReorderWidenMulAccumulate(df, a1, b21, C2, unused_sum1);
|
||||
C3 = hn::ReorderWidenMulAccumulate(df, a1, b31, C3, unused_sum1);
|
||||
// Ensure sum1 was indeed unused.
|
||||
HWY_DASSERT(hn::AllTrue(df, hn::Eq(unused_sum1, hn::Zero(df))));
|
||||
} else {
|
||||
C0 = hn::Add(C0, hn::WidenMulPairwiseAdd(df, a1, b01));
|
||||
C1 = hn::Add(C1, hn::WidenMulPairwiseAdd(df, a1, b11));
|
||||
C2 = hn::Add(C2, hn::WidenMulPairwiseAdd(df, a1, b21));
|
||||
C3 = hn::Add(C3, hn::WidenMulPairwiseAdd(df, a1, b31));
|
||||
}
|
||||
}
|
||||
}
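The comment earlier in this file argues that pairwise adds are safe here because each C vector only ever holds partial terms of a single dot product, so pairing neighboring products merely reorders the summation without changing the horizontal sum. A scalar sketch of the two accumulation schemes:

#include <cstdio>
#include <vector>

int main() {
  // One row of A and one row of transposed B, hypothetical values.
  std::vector<float> a = {1, 2, 3, 4, 5, 6, 7, 8};
  std::vector<float> b = {8, 7, 6, 5, 4, 3, 2, 1};

  // (a) plain FMA: one partial sum per lane.
  float lanes[8] = {0};
  for (size_t i = 0; i < a.size(); ++i) lanes[i] += a[i] * b[i];
  float dot_fma = 0.0f;
  for (float l : lanes) dot_fma += l;

  // (b) pairwise widening multiply-add: each lane holds a[2i]*b[2i] + a[2i+1]*b[2i+1],
  // i.e. half as many partial sums. The horizontal sum is the same dot product.
  float pair_lanes[4] = {0};
  for (size_t i = 0; i < a.size() / 2; ++i)
    pair_lanes[i] += a[2 * i] * b[2 * i] + a[2 * i + 1] * b[2 * i + 1];
  float dot_pairwise = 0.0f;
  for (float l : pair_lanes) dot_pairwise += l;

  std::printf("fma=%f pairwise=%f\n", dot_fma, dot_pairwise);  // identical
  return 0;
}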
|
||||
|
||||
|
|
@ -217,20 +209,31 @@ class ALoadAccumulate {
|
|||
Decompress2(dm, A_, A_ofs_ + col_ab, a0, a1);
|
||||
|
||||
const DF df;
|
||||
hn::Vec<DF> unused_sum1 = hn::Zero(df);
|
||||
|
||||
static_assert(kRegCols == 4);
|
||||
C0 = hn::ReorderWidenMulAccumulate(df, a0, b00, C0, unused_sum1);
|
||||
C1 = hn::ReorderWidenMulAccumulate(df, a0, b10, C1, unused_sum1);
|
||||
C2 = hn::ReorderWidenMulAccumulate(df, a0, b20, C2, unused_sum1);
|
||||
C3 = hn::ReorderWidenMulAccumulate(df, a0, b30, C3, unused_sum1);
|
||||
C0 = hn::ReorderWidenMulAccumulate(df, a1, b01, C0, unused_sum1);
|
||||
C1 = hn::ReorderWidenMulAccumulate(df, a1, b11, C1, unused_sum1);
|
||||
C2 = hn::ReorderWidenMulAccumulate(df, a1, b21, C2, unused_sum1);
|
||||
C3 = hn::ReorderWidenMulAccumulate(df, a1, b31, C3, unused_sum1);
|
||||
|
||||
// Ensure sum1 was indeed unused.
|
||||
HWY_DASSERT(hn::AllTrue(df, hn::Eq(unused_sum1, hn::Zero(df))));
|
||||
if constexpr (HWY_NATIVE_DOT_BF16) {
|
||||
// Native ReorderWidenMulAccumulate adds to C0..3 for free.
|
||||
VF unused_sum1 = hn::Zero(df);
|
||||
C0 = hn::ReorderWidenMulAccumulate(df, a0, b00, C0, unused_sum1);
|
||||
C1 = hn::ReorderWidenMulAccumulate(df, a0, b10, C1, unused_sum1);
|
||||
C2 = hn::ReorderWidenMulAccumulate(df, a0, b20, C2, unused_sum1);
|
||||
C3 = hn::ReorderWidenMulAccumulate(df, a0, b30, C3, unused_sum1);
|
||||
C0 = hn::ReorderWidenMulAccumulate(df, a1, b01, C0, unused_sum1);
|
||||
C1 = hn::ReorderWidenMulAccumulate(df, a1, b11, C1, unused_sum1);
|
||||
C2 = hn::ReorderWidenMulAccumulate(df, a1, b21, C2, unused_sum1);
|
||||
C3 = hn::ReorderWidenMulAccumulate(df, a1, b31, C3, unused_sum1);
|
||||
// Ensure sum1 was indeed unused.
|
||||
HWY_DASSERT(hn::AllTrue(df, hn::Eq(unused_sum1, hn::Zero(df))));
|
||||
} else {
|
||||
C0 = hn::Add(C0, hn::WidenMulPairwiseAdd(df, a0, b00));
|
||||
C1 = hn::Add(C1, hn::WidenMulPairwiseAdd(df, a0, b10));
|
||||
C2 = hn::Add(C2, hn::WidenMulPairwiseAdd(df, a0, b20));
|
||||
C3 = hn::Add(C3, hn::WidenMulPairwiseAdd(df, a0, b30));
|
||||
C0 = hn::Add(C0, hn::WidenMulPairwiseAdd(df, a1, b01));
|
||||
C1 = hn::Add(C1, hn::WidenMulPairwiseAdd(df, a1, b11));
|
||||
C2 = hn::Add(C2, hn::WidenMulPairwiseAdd(df, a1, b21));
|
||||
C3 = hn::Add(C3, hn::WidenMulPairwiseAdd(df, a1, b31));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -356,116 +359,113 @@ class AddHorizontalSums {
|
|||
// Streams a `kNumRows` high strip of `A` and the transposed `B`, then writes a
|
||||
// *finished* tile of f32 `C` whose top left is (row_ac, row_b_col_c).
|
||||
// TODO: loop over sections instead of full rows and accumulate into `tile_c`.
|
||||
// `buf` is 16 vectors of thread-local storage.
|
||||
template <size_t kNumRows, bool kAdd, typename MatTA, typename MatTB>
|
||||
HWY_INLINE void MatMulTile(const size_t batch_size, const Mat<const MatTA>& A,
|
||||
const Mat<const MatTB>& B, const size_t row_ac,
|
||||
const size_t row_b_col_c, const float scale,
|
||||
const float* HWY_RESTRICT add,
|
||||
float* HWY_RESTRICT buf, const Mat<float>& C) {
|
||||
// For 'decompressing' A and B into BF16 or float.
|
||||
const hn::ScalableTag<MulT> dm;
|
||||
using VM = hn::Vec<decltype(dm)>;
|
||||
const size_t NM = hn::Lanes(dm);
|
||||
HWY_INLINE void MatMulTile(const ConstMat<MatTA>& A, const size_t row_ac,
|
||||
const ConstMat<MatTB>& B, const size_t row_b_col_c,
|
||||
const float scale, const float* HWY_RESTRICT add,
|
||||
float* HWY_RESTRICT buf, const RowPtr<float>& C) {
|
||||
// Decompress A and B to which type, which will then be widened to f32,
|
||||
// multiplied, added once into f32, then promoted to f64 and accumulated.
|
||||
// NEON_BF16/SVE/AVX3_ZEN4 have instructions for bf16 * bf16 + f32 which are
|
||||
// more efficient than f32 * f32 + f32 because they process twice as many
|
||||
// lanes at a time. If available, we definitely want to use them. Otherwise,
|
||||
// bf16 is still worthwhile if A (activations) are bf16: SFP weights are
|
||||
// cheaper to decode to bf16, relative to the minor extra cost of promoting
|
||||
// bf16 when multiplying. However, if A is f32, demoting to bf16 can be
|
||||
// expensive unless we also have native bf16 dot.
|
||||
using Raw = hwy::If<HWY_NATIVE_DOT_BF16 || !IsF32<MatTA>(), BF16, float>;
|
||||
const hn::ScalableTag<Raw> dr;
|
||||
using VR = hn::Vec<decltype(dr)>;
|
||||
const size_t NR = hn::Lanes(dr);
|
||||
|
||||
const Range1D cols_ab(0, A.Extents().cols);
|
||||
HWY_DASSERT(row_ac + kNumRows <= A.Extents().rows);
|
||||
HWY_DASSERT(row_b_col_c + kNumRows <= B.Extents().rows);
|
||||
HWY_DASSERT(cols_ab.end() % (2 * NR) == 0);
|
||||
|
||||
static_assert(kRegRows == 4);
|
||||
const BRow<0, MatTB> b_row0(B, row_b_col_c, C.cols);
|
||||
const BRow<1, MatTB> b_row1(B, row_b_col_c, C.cols);
|
||||
const BRow<2, MatTB> b_row2(B, row_b_col_c, C.cols);
|
||||
const BRow<3, MatTB> b_row3(B, row_b_col_c, C.cols);
|
||||
const BRow<0, MatTB> b_row0(B, row_b_col_c);
|
||||
const BRow<1, MatTB> b_row1(B, row_b_col_c);
|
||||
const BRow<2, MatTB> b_row2(B, row_b_col_c);
|
||||
const BRow<3, MatTB> b_row3(B, row_b_col_c);
|
||||
|
||||
const ALoadAccumulate<0, MatTA> a_row0(A, row_ac, batch_size);
|
||||
const ALoadAccumulate<1, MatTA> a_row1(A, row_ac, batch_size);
|
||||
const ALoadAccumulate<2, MatTA> a_row2(A, row_ac, batch_size);
|
||||
const ALoadAccumulate<3, MatTA> a_row3(A, row_ac, batch_size);
|
||||
const ALoadAccumulate<0, MatTA> a_row0(A, row_ac);
|
||||
const ALoadAccumulate<1, MatTA> a_row1(A, row_ac);
|
||||
const ALoadAccumulate<2, MatTA> a_row2(A, row_ac);
|
||||
const ALoadAccumulate<3, MatTA> a_row3(A, row_ac);
|
||||
|
||||
const hn::Repartition<float, decltype(dm)> df;
|
||||
const hn::Repartition<float, decltype(dr)> df;
|
||||
using VF = hn::Vec<decltype(df)>;
|
||||
VF C00, C01, C02, C03;
|
||||
VF C10, C11, C12, C13;
|
||||
VF C20, C21, C22, C23;
|
||||
VF C30, C31, C32, C33;
|
||||
|
||||
size_t col_ab = cols_ab.begin();
|
||||
{ // First iteration initializes the `Crc` vectors.
|
||||
VM b00, b01, b10, b11, b20, b21, b30, b31;
|
||||
b_row0.Load2(dm, /*col_ab=*/0, b00, b01);
|
||||
b_row1.Load2(dm, /*col_ab=*/0, b10, b11);
|
||||
b_row2.Load2(dm, /*col_ab=*/0, b20, b21);
|
||||
b_row3.Load2(dm, /*col_ab=*/0, b30, b31);
|
||||
VR b00, b01, b10, b11, b20, b21, b30, b31;
|
||||
b_row0.Load2(dr, col_ab, b00, b01);
|
||||
b_row1.Load2(dr, col_ab, b10, b11);
|
||||
b_row2.Load2(dr, col_ab, b20, b21);
|
||||
b_row3.Load2(dr, col_ab, b30, b31);
|
||||
|
||||
a_row0.template First<kNumRows>(dm, b00, b01, b10, b11, b20, b21, b30, b31,
|
||||
a_row0.template First<kNumRows>(dr, b00, b01, b10, b11, b20, b21, b30, b31,
|
||||
C00, C01, C02, C03);
|
||||
a_row1.template First<kNumRows>(dm, b00, b01, b10, b11, b20, b21, b30, b31,
|
||||
a_row1.template First<kNumRows>(dr, b00, b01, b10, b11, b20, b21, b30, b31,
|
||||
C10, C11, C12, C13);
|
||||
a_row2.template First<kNumRows>(dm, b00, b01, b10, b11, b20, b21, b30, b31,
|
||||
a_row2.template First<kNumRows>(dr, b00, b01, b10, b11, b20, b21, b30, b31,
|
||||
C20, C21, C22, C23);
|
||||
a_row3.template First<kNumRows>(dm, b00, b01, b10, b11, b20, b21, b30, b31,
|
||||
a_row3.template First<kNumRows>(dr, b00, b01, b10, b11, b20, b21, b30, b31,
|
||||
C30, C31, C32, C33);
|
||||
col_ab += 2 * NR;
|
||||
}
|
||||
|
||||
// `2 * NM` per iteration because `Load2` returns two vectors.
|
||||
// `2 * NR` per iteration because `Load2` returns two vectors.
|
||||
HWY_UNROLL(1)
|
||||
for (size_t col_ab = 2 * NM; col_ab <= A.cols - 2 * NM; col_ab += 2 * NM) {
|
||||
VM b00, b01, b10, b11, b20, b21, b30, b31;
|
||||
b_row0.Load2(dm, col_ab, b00, b01);
|
||||
b_row1.Load2(dm, col_ab, b10, b11);
|
||||
b_row2.Load2(dm, col_ab, b20, b21);
|
||||
b_row3.Load2(dm, col_ab, b30, b31);
|
||||
for (; col_ab < cols_ab.end(); col_ab += 2 * NR) {
|
||||
VR b00, b01, b10, b11, b20, b21, b30, b31;
|
||||
b_row0.Load2(dr, col_ab, b00, b01);
|
||||
b_row1.Load2(dr, col_ab, b10, b11);
|
||||
b_row2.Load2(dr, col_ab, b20, b21);
|
||||
b_row3.Load2(dr, col_ab, b30, b31);
|
||||
|
||||
a_row0.template Next<kNumRows>(dm, col_ab, b00, b01, b10, b11, b20, b21,
|
||||
a_row0.template Next<kNumRows>(dr, col_ab, b00, b01, b10, b11, b20, b21,
|
||||
b30, b31, C00, C01, C02, C03);
|
||||
a_row1.template Next<kNumRows>(dm, col_ab, b00, b01, b10, b11, b20, b21,
|
||||
a_row1.template Next<kNumRows>(dr, col_ab, b00, b01, b10, b11, b20, b21,
|
||||
b30, b31, C10, C11, C12, C13);
|
||||
a_row2.template Next<kNumRows>(dm, col_ab, b00, b01, b10, b11, b20, b21,
|
||||
a_row2.template Next<kNumRows>(dr, col_ab, b00, b01, b10, b11, b20, b21,
|
||||
b30, b31, C20, C21, C22, C23);
|
||||
a_row3.template Next<kNumRows>(dm, col_ab, b00, b01, b10, b11, b20, b21,
|
||||
a_row3.template Next<kNumRows>(dr, col_ab, b00, b01, b10, b11, b20, b21,
|
||||
b30, b31, C30, C31, C32, C33);
|
||||
}
|
||||
|
||||
// TODO: hoist into outer loop.
|
||||
float* HWY_RESTRICT C_tile = C.ptr + C.Row(row_ac) + row_b_col_c;
|
||||
InitC<kNumRows, kAdd>(add, row_b_col_c, C_tile, C.stride);
|
||||
float* HWY_RESTRICT C_tile = C.Row(row_ac) + row_b_col_c;
|
||||
InitC<kNumRows, kAdd>(add, row_b_col_c, C_tile, C.Stride());
|
||||
|
||||
AddHorizontalSums<kNumRows>()(df, scale, C00, C01, C02, C03, C10, C11, C12,
|
||||
C13, C20, C21, C22, C23, C30, C31, C32, C33,
|
||||
buf, C_tile, C.stride);
|
||||
buf, C_tile, C.Stride());
|
||||
}
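A compile-time sketch of the Raw type choice described in the comment at the top of MatMulTile: decompress to bf16 when the target has a native bf16 dot, or when A is not already f32 (so nothing is demoted); otherwise stay in f32. HWY_NATIVE_DOT_BF16 is modeled here by an ordinary constexpr flag, and the BF16 type is a stand-in:

#include <cstdio>
#include <type_traits>

struct BF16 { unsigned short bits; };  // stand-in for hwy::bfloat16_t

// Stand-in for the per-target HWY_NATIVE_DOT_BF16 macro.
constexpr bool kNativeDotBF16 = false;

template <typename MatTA>
using Raw = std::conditional_t<kNativeDotBF16 || !std::is_same_v<MatTA, float>,
                               BF16, float>;

int main() {
  // f32 activations without a native bf16 dot: keep f32 to avoid a lossy demotion.
  std::printf("A=f32  -> %s\n", std::is_same_v<Raw<float>, BF16> ? "bf16" : "f32");
  // bf16 (or compressed) activations: decoding to bf16 is the cheap path.
  std::printf("A=bf16 -> %s\n", std::is_same_v<Raw<BF16>, BF16> ? "bf16" : "f32");
  return 0;
}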
|
||||
|
||||
// Computes the matrix product `A * B * scale [+ add]` and stores it in `C`.
|
||||
//
|
||||
// `A` is a row-major matrix of shape `(batch_size, A.cols)`.
|
||||
// `B` is transposed; `B.cols`, which must match `A.cols`, denotes the number of
|
||||
// rows in the original B, and `C.cols` the number of columns in the original B.
|
||||
//
|
||||
// `scale` allows expanding the smaller range of `SfpStream` to the original
|
||||
// values. When `A` and/or `B` are from CompressedArray, `scale` should be the
|
||||
// product of their `.scale()` values, otherwise 1.0f.
|
||||
//
|
||||
// If `kAdd` is true, the row-vector `add` is added to each row of `C`,
|
||||
// otherwise `add` is ignored and can be nullptr. A scale for `add` is not
|
||||
// supported, so make sure its scale is 1.
|
||||
//
|
||||
// `C` is a row-major matrix of size `(batch_size, C.cols)`.
|
||||
//
|
||||
// Updates 4x4 tiles of C in parallel using a work-stealing thread pool.
|
||||
// Typically `batch_size` is 1..512, `A.cols` and `C.cols` are 3k or 24k.
|
||||
// Must not be called concurrently with the same `env`.
|
||||
template <bool kAdd, typename MatTA, typename MatTB>
|
||||
HWY_NOINLINE void MatMul(const size_t batch_size, const Mat<const MatTA>& A,
|
||||
const Mat<const MatTB>& B, const float scale,
|
||||
const float* HWY_RESTRICT add, MatMulEnv& env,
|
||||
const Mat<float>& C) {
|
||||
HWY_NOINLINE void MatMulImpl(const ConstMat<MatTA>& A, const ConstMat<MatTB>& B,
|
||||
const float* HWY_RESTRICT add, MatMulEnv& env,
|
||||
const RowPtr<float>& C) {
|
||||
// PROFILER_ZONE("Matmul");
|
||||
HWY_DASSERT(A.NotEmpty() && B.NotEmpty() && C.NotEmpty());
|
||||
HWY_DASSERT(A.cols == B.cols);
|
||||
HWY_DASSERT(A.Extents().cols == B.Extents().cols);
|
||||
const size_t batch_size = A.Extents().rows;
|
||||
HWY_DASSERT(C.Cols() % kRegCols == 0);
|
||||
HWY_DASSERT(C.Stride() >= C.Cols());
|
||||
HWY_DASSERT(B.Extents().rows == C.Cols());
|
||||
|
||||
// Must be a multiple of two vectors because we Decompress2.
|
||||
HWY_DASSERT(A.cols % (2 * hn::Lanes(hn::ScalableTag<MulT>())) == 0);
|
||||
HWY_DASSERT(C.cols % kRegCols == 0);
|
||||
const float scale = A.scale * B.scale;
|
||||
|
||||
// We currently write C directly, which touches more memory than fits in L3.
|
||||
// TODO: add another level of loops to finish L3-sized pieces of C at a time.
|
||||
const size_t tilesY = hwy::DivCeil(batch_size, kRegRows);
|
||||
const size_t tilesX = C.cols / kRegCols;
|
||||
const size_t tilesX = C.Cols() / kRegCols;
|
||||
|
||||
env.Pool().Run(
|
||||
0, tilesX * tilesY, [&](const uint64_t idx_tile, size_t thread) HWY_ATTR {
|
||||
|
|
@ -481,24 +481,45 @@ HWY_NOINLINE void MatMul(const size_t batch_size, const Mat<const MatTA>& A,
|
|||
HWY_DASSERT(num_rows != 0);
|
||||
switch (num_rows) {
|
||||
case 1:
|
||||
MatMulTile<1, kAdd>(batch_size, A, B, row_ac, row_b_col_c, scale,
|
||||
add, buf, C);
|
||||
MatMulTile<1, kAdd>(A, row_ac, B, row_b_col_c, scale, add, buf, C);
|
||||
break;
|
||||
case 2:
|
||||
MatMulTile<2, kAdd>(batch_size, A, B, row_ac, row_b_col_c, scale,
|
||||
add, buf, C);
|
||||
MatMulTile<2, kAdd>(A, row_ac, B, row_b_col_c, scale, add, buf, C);
|
||||
break;
|
||||
case 3:
|
||||
MatMulTile<3, kAdd>(batch_size, A, B, row_ac, row_b_col_c, scale,
|
||||
add, buf, C);
|
||||
MatMulTile<3, kAdd>(A, row_ac, B, row_b_col_c, scale, add, buf, C);
|
||||
break;
|
||||
default:
|
||||
MatMulTile<4, kAdd>(batch_size, A, B, row_ac, row_b_col_c, scale,
|
||||
add, buf, C);
|
||||
MatMulTile<4, kAdd>(A, row_ac, B, row_b_col_c, scale, add, buf, C);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Computes the matrix product `A * B * scale [+ add]` and stores it in `C`.
|
||||
//
|
||||
// `A` is a row-major matrix and `B` is transposed. Its `B.Extents().cols`,
|
||||
// which must match `A.Extents().cols`, is the number of rows in the original B.
|
||||
//
|
||||
// If `add` is non-null, the row-vector `add` is added to each row of `C`.
|
||||
// A scale for `add` is not supported, so make sure its scale is 1.
|
||||
//
|
||||
// `C` is a row-major matrix of size `(A.rows, C.Cols())` with support for
|
||||
// arbitrary strides.
|
||||
//
|
||||
// Updates 4x4 tiles of C in parallel using a work-stealing thread pool.
|
||||
// Typically `A.rows` is 1..512, `A.Extents().cols` and `B.Extents().rows` are
|
||||
// 3k or 24k. Must not be called concurrently with the same `env`.
|
||||
template <typename MatTA, typename MatTB>
|
||||
HWY_NOINLINE void MatMul(const ConstMat<MatTA>& A, const ConstMat<MatTB>& B,
|
||||
const float* HWY_RESTRICT add, MatMulEnv& env,
|
||||
const RowPtr<float>& C) {
|
||||
if (add) {
|
||||
MatMulImpl<true>(A, B, add, env, C);
|
||||
} else {
|
||||
MatMulImpl<false>(A, B, nullptr, env, C);
|
||||
}
|
||||
}
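A minimal usage sketch of the new entry point (not part of this change's call sites): `weights`, `activations`, `c_batch` and `env` are placeholders, and ConstMatFromBatch, ConstMatFromWeights and RowPtrFromBatch are the helpers introduced elsewhere in this commit. The sketch assumes it lives in the same gcpp::HWY_NAMESPACE translation unit as MatMul.
// Sketch only: A is (rows x A.cols) activations, B is the transposed weight
// matrix, C receives (rows x B rows) floats.
template <typename WeightT, typename ActT>
void MatMulSketch(const MatStorageT<WeightT>& weights,
                  const RowVectorBatch<ActT>& activations, size_t rows,
                  MatMulEnv& env, RowVectorBatch<float>& c_batch) {
  const auto A = ConstMatFromBatch(rows, activations);  // (rows x cols) view of A
  const auto B = ConstMatFromWeights(weights);          // transposed weights + scale
  const RowPtrF C = RowPtrFromBatch(c_batch);           // writable rows of C
  MatMul(A, B, /*add=*/nullptr, env, C);                // nullptr add = no bias row
}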
|
||||
|
||||
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
||||
} // namespace HWY_NAMESPACE
|
||||
} // namespace gcpp
|
||||
|
|
|
|||
55
ops/matmul.h

|
|
@ -19,73 +19,22 @@
|
|||
#include <stddef.h>
|
||||
|
||||
// IWYU pragma: begin_exports
|
||||
#include "util/basics.h"
|
||||
#include "util/threading.h"
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
// IWYU pragma: end_exports
|
||||
|
||||
#include "util/allocator.h" // RowVectorBatch
|
||||
#include "hwy/per_target.h" // VectorBytes
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
// Bundles ptr/size/stride arguments to simplify MatMul call sites. T can be
|
||||
// const or non-const. Create via ConstMat/MutableMat.
|
||||
// TODO(rays): Replace with MatPtr and get rid of stride, which is only != cols
|
||||
// in one place.
|
||||
template <typename T>
|
||||
struct Mat {
|
||||
bool NotEmpty() const {
|
||||
return ptr != nullptr && cols != 0 && stride >= cols;
|
||||
}
|
||||
size_t Row(size_t r) const { return ofs + stride * r; }
|
||||
|
||||
T* HWY_RESTRICT ptr;
|
||||
size_t cols;
|
||||
|
||||
// elements between rows, which is typically the same as `cols`.
|
||||
size_t stride;
|
||||
|
||||
// Offset to add to `ptr`; separate because T=NuqStream does not support
|
||||
// pointer arithmetic.
|
||||
size_t ofs;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
Mat<T> MutableMat(T* HWY_RESTRICT ptr, size_t cols, size_t stride,
|
||||
size_t ofs = 0) {
|
||||
return Mat<T>{.ptr = ptr, .cols = cols, .stride = stride, .ofs = ofs};
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Mat<const T> ConstMat(const T* HWY_RESTRICT ptr, size_t cols, size_t stride,
|
||||
size_t ofs = 0) {
|
||||
return Mat<const T>{.ptr = ptr, .cols = cols, .stride = stride, .ofs = ofs};
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Mat<const T> ConstMat(Mat<T> mat) {
|
||||
return ConstMat(mat.ptr, mat.cols, mat.stride, mat.ofs);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Mat<T> MutableMat(T* HWY_RESTRICT ptr, size_t cols) {
|
||||
return MutableMat(ptr, cols, cols);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Mat<const T> ConstMat(const T* HWY_RESTRICT ptr, size_t cols) {
|
||||
return ConstMat(ptr, cols, cols);
|
||||
}
|
||||
|
||||
// Allocations and threads, shared across MatMul calls.
|
||||
class MatMulEnv {
|
||||
public:
|
||||
MatMulEnv() : pools_(nullptr) {}
|
||||
explicit MatMulEnv(NestedPools& pools) : pools_(&pools) {
|
||||
const size_t N = hwy::VectorBytes() / sizeof(float);
|
||||
buf_ = RowVectorBatch<float>(pools.MaxWorkers(), 16 * N);
|
||||
buf_ = RowVectorBatch<float>(Extents2D(pools.MaxWorkers(), 16 * N));
|
||||
}
|
||||
|
||||
RowVectorBatch<float>& Buf() { return buf_; }
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@
|
|||
|
||||
#include "compression/compress.h"
|
||||
#include "util/allocator.h"
|
||||
#include "util/basics.h"
|
||||
#include "util/threading.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
|
@ -55,19 +56,23 @@ namespace HWY_NAMESPACE {
|
|||
|
||||
using FloatPtr = hwy::AlignedFreeUniquePtr<float[]>;
|
||||
|
||||
template <typename MatT>
|
||||
using MatStoragePtr = std::unique_ptr<MatStorageT<MatT>>;
|
||||
|
||||
// Generates inputs: deterministic, within max SfpStream range.
|
||||
template <typename MatT, size_t kRows, size_t kCols,
|
||||
class MatPtr = std::unique_ptr<MatStorageT<MatT>>>
|
||||
MatPtr GenerateMat(size_t offset, hwy::ThreadPool& pool) {
|
||||
template <typename MatT>
|
||||
MatStoragePtr<MatT> GenerateMat(const Extents2D extents,
|
||||
hwy::ThreadPool& pool) {
|
||||
gcpp::CompressWorkingSet ws;
|
||||
auto mat = std::make_unique<MatStorageT<MatT>>("test", kRows, kCols);
|
||||
auto mat =
|
||||
std::make_unique<MatStorageT<MatT>>("mat", extents.rows, extents.cols);
|
||||
FloatPtr content = hwy::AllocateAligned<float>(mat->NumElements());
|
||||
HWY_ASSERT(content);
|
||||
const float scale = SfpStream::kMax / (mat->NumElements() + offset);
|
||||
pool.Run(0, kRows, [&](const size_t i, size_t /*thread*/) {
|
||||
for (size_t j = 0; j < kCols; j++) {
|
||||
content[i * kCols + j] =
|
||||
static_cast<float>((i * kCols + j + offset) * scale);
|
||||
const float scale = SfpStream::kMax / (mat->NumElements());
|
||||
pool.Run(0, extents.rows, [&](const size_t r, size_t /*thread*/) {
|
||||
for (size_t c = 0; c < extents.cols; c++) {
|
||||
content[r * extents.cols + c] =
|
||||
static_cast<float>(r * extents.cols + c) * scale;
|
||||
}
|
||||
});
|
||||
|
||||
|
|
@ -76,185 +81,173 @@ MatPtr GenerateMat(size_t offset, hwy::ThreadPool& pool) {
|
|||
return mat;
|
||||
}
|
||||
|
||||
template <typename MatT, size_t kRows, size_t kCols,
|
||||
class MatPtr = std::unique_ptr<MatStorageT<MatT>>>
|
||||
MatPtr GenerateTransposedMat(size_t offset, hwy::ThreadPool& pool) {
|
||||
// extents describes the transposed matrix.
|
||||
template <typename MatT>
|
||||
MatStoragePtr<MatT> GenerateTransposedMat(const Extents2D extents,
|
||||
hwy::ThreadPool& pool) {
|
||||
gcpp::CompressWorkingSet ws;
|
||||
MatPtr mat = std::make_unique<MatStorageT<MatT>>("test", kCols, kRows);
|
||||
auto mat =
|
||||
std::make_unique<MatStorageT<MatT>>("trans", extents.rows, extents.cols);
|
||||
FloatPtr content = hwy::AllocateAligned<float>(mat->NumElements());
|
||||
const float scale = SfpStream::kMax / (mat->NumElements() + offset);
|
||||
pool.Run(0, kRows, [&](const size_t i, size_t /*thread*/) {
|
||||
for (size_t j = 0; j < kCols; j++) {
|
||||
content[j * kRows + i] =
|
||||
static_cast<float>((i * kCols + j + offset) * scale);
|
||||
const float scale = SfpStream::kMax / (mat->NumElements());
|
||||
pool.Run(0, extents.rows, [&](const size_t r, size_t /*thread*/) {
|
||||
for (size_t c = 0; c < extents.cols; c++) {
|
||||
content[r * extents.cols + c] =
|
||||
static_cast<float>(c * extents.rows + r) * scale;
|
||||
}
|
||||
});
|
||||
|
||||
CompressScaled(content.get(), mat->NumElements(), ws, *mat, pool);
|
||||
// Arbitrary value, different from 1, must match GenerateMatHeap.
|
||||
// Arbitrary value, different from 1, must match GenerateMat.
|
||||
mat->set_scale(0.6f);
|
||||
return mat;
|
||||
}
|
||||
|
||||
template <typename MatT, size_t kRows, size_t kCols,
|
||||
class MatPtr = std::unique_ptr<MatStorageT<MatT>>>
|
||||
MatPtr GenerateZeroMat(hwy::ThreadPool& pool) {
|
||||
gcpp::CompressWorkingSet ws;
|
||||
auto mat = std::make_unique<MatStorageT<MatT>>("Array", kRows, kCols);
|
||||
FloatPtr content = hwy::AllocateAligned<float>(mat->NumElements());
|
||||
HWY_ASSERT(content);
|
||||
|
||||
pool.Run(0, kRows, [&](const size_t i, size_t thread) {
|
||||
hwy::ZeroBytes(&content[i * kCols], kCols * sizeof(content[0]));
|
||||
});
|
||||
|
||||
CompressScaled(content.get(), mat->NumElements(), ws, *mat, pool);
|
||||
mat->set_scale(1.2f); // Arbitrary value, different from 1.
|
||||
return mat;
|
||||
}
|
||||
|
||||
// Returns 1-norm, used for estimating tolerable numerical differences.
|
||||
double MaxColAbsSum(const float* HWY_RESTRICT a, size_t rows, size_t cols) {
|
||||
double MaxColAbsSum(const float* HWY_RESTRICT a, const Extents2D& extents) {
|
||||
double max_col_abs_sum = 0.0;
|
||||
for (size_t c = 0; c < cols; c++) {
|
||||
for (size_t c = 0; c < extents.cols; c++) {
|
||||
double col_abs_sum = 0.0;
|
||||
for (size_t r = 0; r < rows; r++) {
|
||||
col_abs_sum += hwy::ScalarAbs(a[r * cols + c]);
|
||||
for (size_t r = 0; r < extents.rows; r++) {
|
||||
col_abs_sum += hwy::ScalarAbs(a[r * extents.cols + c]);
|
||||
}
|
||||
max_col_abs_sum = HWY_MAX(max_col_abs_sum, col_abs_sum);
|
||||
}
|
||||
return max_col_abs_sum;
|
||||
}
|
||||
|
||||
// B is already transposed.
|
||||
template <typename MatTA, typename MatTB>
|
||||
void AssertClose(size_t rows_ac, size_t cols_ab, size_t cols_c_rows_b,
|
||||
const MatTA* HWY_RESTRICT pa,
|
||||
const MatTB* HWY_RESTRICT pb_trans,
|
||||
const float* HWY_RESTRICT expected_c,
|
||||
const float* HWY_RESTRICT actual_c) {
|
||||
void AssertClose(const ConstMat<MatTA>& A, const ConstMat<MatTB>& B,
|
||||
const RowPtrF& C_slow, const RowPtrF& C) {
|
||||
const hn::ScalableTag<float> df;
|
||||
const size_t num_a = rows_ac * cols_ab;
|
||||
const size_t num_b = cols_c_rows_b * cols_ab;
|
||||
const size_t num_a = A.extents.Area();
|
||||
const size_t num_b = B.extents.Area();
|
||||
HWY_ASSERT(num_a % hn::Lanes(df) == 0); // for DecompressAndZeroPad
|
||||
HWY_ASSERT(num_b % hn::Lanes(df) == 0); // for DecompressAndZeroPad
|
||||
const size_t num_c = rows_ac * cols_c_rows_b;
|
||||
FloatPtr a = hwy::AllocateAligned<float>(num_a);
|
||||
FloatPtr b_trans = hwy::AllocateAligned<float>(num_b);
|
||||
HWY_ASSERT(a && b_trans);
|
||||
DecompressAndZeroPad(df, MakeSpan(pa, num_a), 0, a.get(), num_a);
|
||||
DecompressAndZeroPad(df, MakeSpan(pb_trans, num_b), 0, b_trans.get(), num_b);
|
||||
HWY_ASSERT(A.ofs == 0 && B.ofs == 0);
|
||||
DecompressAndZeroPad(df, MakeSpan(A.ptr, num_a), 0, a.get(), num_a);
|
||||
DecompressAndZeroPad(df, MakeSpan(B.ptr, num_b), 0, b_trans.get(), num_b);
|
||||
|
||||
const double norm = MaxColAbsSum(a.get(), rows_ac, cols_ab) *
|
||||
MaxColAbsSum(b_trans.get(), cols_c_rows_b, cols_ab);
|
||||
const double norm = MaxColAbsSum(a.get(), A.Extents()) *
|
||||
MaxColAbsSum(b_trans.get(), B.Extents());
|
||||
// Dot(float,BF16) rounds both to BF16.
|
||||
using RefType = hwy::If<IsF32<MatTA>() && IsF32<MatTB>(), float, BF16>;
|
||||
const double epsilon = hwy::ConvertScalarTo<double>(hwy::Epsilon<RefType>());
|
||||
const double tolerance = 200.0 * norm * epsilon;
|
||||
|
||||
for (size_t idx = 0; idx < num_c; idx++) {
|
||||
const double expected_value = expected_c[idx];
|
||||
const double actual_value = actual_c[idx];
|
||||
for (size_t r = 0; r < A.extents.rows; r++) {
|
||||
const float* expected_row = C_slow.Row(r);
|
||||
const float* actual_row = C.Row(r);
|
||||
for (size_t c = 0; c < B.extents.rows; c++) {
|
||||
const double expected_value = static_cast<double>(expected_row[c]);
|
||||
const double actual_value = static_cast<double>(actual_row[c]);
|
||||
|
||||
if (!(expected_value - tolerance <= actual_value &&
|
||||
actual_value <= expected_value + tolerance)) {
|
||||
fprintf(
|
||||
stderr,
|
||||
"expected[%lu]: %f, actual[%lu]: %f, norm %f eps %E tolerance %f\n",
|
||||
idx, expected_value, idx, actual_value, norm, epsilon, tolerance);
|
||||
HWY_ASSERT(0);
|
||||
if (!(expected_value - tolerance <= actual_value &&
|
||||
actual_value <= expected_value + tolerance)) {
|
||||
fprintf(
|
||||
stderr,
|
||||
"(%zu,%zu): expected %f, actual %f, norm %f eps %E tolerance %f\n",
|
||||
r, c, expected_value, actual_value, norm, epsilon, tolerance);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// B is already transposed.
|
||||
template <typename MatTA, typename MatTB>
|
||||
HWY_INLINE void MatMulSlow(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc,
|
||||
const MatTA* HWY_RESTRICT a,
|
||||
const MatTB* HWY_RESTRICT b_trans, const float scale,
|
||||
HWY_INLINE void MatMulSlow(const ConstMat<MatTA> A, const ConstMat<MatTB> B,
|
||||
const float* HWY_RESTRICT add_row, MatMulEnv& env,
|
||||
float* HWY_RESTRICT out) {
|
||||
const RowPtrF& C) {
|
||||
// MatTA can be any Packed except NuqStream because it uses pointer
|
||||
// arithmetic, because it is the second argument to Dot, which does not
|
||||
// support a v_ofs.
|
||||
static_assert(sizeof(MatTA) >= sizeof(BF16), "A matrix must be BF16/f32");
|
||||
const float scale = A.scale * B.scale;
|
||||
|
||||
const hn::ScalableTag<float> df; // lane type is ignored
|
||||
const PackedSpan<const MatTB> b_span =
|
||||
MakeSpan(b_trans, cols_a_rows_b * cols_bc);
|
||||
MakeSpan(B.ptr, B.ofs + B.extents.Area());
|
||||
const Extents2D C_extents(A.extents.rows, C.Cols());
|
||||
|
||||
StaticPartitionRowsAndCols(
|
||||
env.Pools(), rows_ac, cols_bc, sizeof(MatTB),
|
||||
[&](size_t /*node*/, hwy::ThreadPool& pool,
|
||||
const size_t /*worker_offset*/, const size_t row_begin,
|
||||
const size_t row_end, const size_t col_begin, const size_t col_end) {
|
||||
pool.Run(row_begin, row_end,
|
||||
[&](const uint64_t row, size_t /*thread*/) {
|
||||
for (size_t col = col_begin; col < col_end; ++col) {
|
||||
const float add = add_row ? add_row[col] : 0.0f;
|
||||
out[row * cols_bc + col] =
|
||||
scale * Dot(df, b_span, col * cols_a_rows_b,
|
||||
a + row * cols_a_rows_b, cols_a_rows_b) +
|
||||
add;
|
||||
}
|
||||
});
|
||||
env.Pools(), C_extents, sizeof(MatTB),
|
||||
[&](const Range2D& C_range, const TaskLocation& loc) {
|
||||
loc.cluster.Run(
|
||||
C_range.rows.begin(), C_range.rows.end(),
|
||||
[&](const uint64_t row, size_t /*thread*/) {
|
||||
float* HWY_RESTRICT C_row = C.Row(row);
|
||||
for (size_t row_b_col_c : C_range.cols) {
|
||||
const float add = add_row ? add_row[row_b_col_c] : 0.0f;
|
||||
C_row[row_b_col_c] =
|
||||
add + scale * Dot(df, b_span, row_b_col_c * B.extents.cols,
|
||||
A.ptr + A.Row(row), A.extents.cols);
|
||||
}
|
||||
});
|
||||
});
|
||||
}
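For reference, each output element computed by MatMulSlow reduces to the scalar expression below; this is a plain-C++ sketch that ignores compression and SIMD, with all names hypothetical.
// C[r][c] = add[c] + A.scale * B.scale * sum_k A[r][k] * Bt[c][k],
// where Bt is the already-transposed B, so its row c is column c of the original B.
double SlowElementSketch(const float* a_row, const float* bt_row,
                         size_t cols_ab, float scale, float add) {
  double sum = 0.0;
  for (size_t k = 0; k < cols_ab; ++k) {
    sum += static_cast<double>(a_row[k]) * static_cast<double>(bt_row[k]);
  }
  return add + scale * sum;
}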
|
||||
|
||||
void PrintSpeed(const char* algo, size_t rows_ac, size_t cols_a_rows_b,
|
||||
size_t cols_bc, double elapsed) {
|
||||
const size_t num_b = cols_a_rows_b * cols_bc;
|
||||
void PrintSpeed(const char* algo, const Extents2D& A_extents,
|
||||
const Extents2D& B_extents, double elapsed) {
|
||||
const size_t num_b = B_extents.Area();
|
||||
// 2x because of FMA.
|
||||
fprintf(stderr, " %10s: %f seconds, %.1f GFLOPS.\n", algo,
|
||||
elapsed, 2 * 1E-9 * rows_ac * num_b / elapsed);
|
||||
elapsed, 2 * 1E-9 * A_extents.rows * num_b / elapsed);
|
||||
}
|
||||
|
||||
template <size_t kRowsAC, size_t kColsARowsB, size_t kColsBC, bool kAdd,
|
||||
typename MatTA, typename MatTB = MatTA>
|
||||
void TestMatMul(MatMulEnv& env) {
|
||||
template <typename MatTA, typename MatTB = MatTA>
|
||||
void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
|
||||
MatMulEnv& env) {
|
||||
hwy::ThreadPool& pool = env.Pool();
|
||||
const bool want_bench = kColsBC > 2000; // avoid spam for small matrices
|
||||
const bool want_bench = cols_bc > 2000; // avoid spam for small matrices
|
||||
fprintf(stderr, "TestMatMul %lu, %lu, %lu, add=%d, MatTA=%s, MatTB=%s\n",
|
||||
kRowsAC, kColsARowsB, kColsBC, kAdd, TypeName<MatTA>(),
|
||||
rows_ac, cols_a_rows_b, cols_bc, add, TypeName<MatTA>(),
|
||||
TypeName<MatTB>());
|
||||
|
||||
std::unique_ptr<MatStorageT<MatTA>> a =
|
||||
GenerateMat<MatTA, kRowsAC, kColsARowsB>(0, pool);
|
||||
std::unique_ptr<MatStorageT<MatTB>> b_trans =
|
||||
GenerateTransposedMat<MatTB, kColsARowsB, kColsBC>(0, pool);
|
||||
FloatPtr c = hwy::AllocateAligned<float>(kRowsAC * kColsBC);
|
||||
HWY_ASSERT(c);
|
||||
const Extents2D A_extents(rows_ac, cols_a_rows_b);
|
||||
const Extents2D B_extents(cols_bc, cols_a_rows_b); // already transposed
|
||||
const Extents2D C_extents(rows_ac, cols_bc);
|
||||
|
||||
const float scale = a->scale() * b_trans->scale();
|
||||
std::unique_ptr<MatStorageT<float>> add;
|
||||
if (kAdd) {
|
||||
add = GenerateMat<float, 1, kColsBC>(0, pool);
|
||||
add->set_scale(1.0f);
|
||||
MatStoragePtr<MatTA> a = GenerateMat<MatTA>(A_extents, pool);
|
||||
MatStoragePtr<MatTB> b_trans = GenerateTransposedMat<MatTB>(B_extents, pool);
|
||||
RowVectorBatch<float> c_slow_batch(C_extents);
|
||||
RowVectorBatch<float> c_batch(C_extents);
|
||||
HWY_ASSERT(a && b_trans);
|
||||
|
||||
std::unique_ptr<MatStorageT<float>> add_storage;
|
||||
if (add) {
|
||||
add_storage = GenerateMat<float>(Extents2D(1, cols_bc), pool);
|
||||
HWY_ASSERT(add_storage);
|
||||
add_storage->set_scale(1.0f);
|
||||
}
|
||||
|
||||
std::unique_ptr<MatStorageT<float>> c_slow =
|
||||
GenerateZeroMat<float, kRowsAC, kColsBC>(pool);
|
||||
const auto A = ConstMatFromWeights(*a);
|
||||
const auto B = ConstMatFromWeights(*b_trans);
|
||||
const float* add_row = add ? add_storage->data_scale1() : nullptr;
|
||||
const RowPtrF C_slow = RowPtrFromBatch(c_slow_batch);
|
||||
const RowPtrF C = RowPtrFromBatch(c_batch);
|
||||
|
||||
const double start_slow = hwy::platform::Now();
|
||||
MatMulSlow(kRowsAC, kColsARowsB, kColsBC, a->data(), b_trans->data(), scale,
|
||||
kAdd ? add->data() : nullptr, env, c_slow->data());
|
||||
MatMulSlow(A, B, add_row, env, C_slow);
|
||||
if (want_bench) {
|
||||
PrintSpeed("MatMulSlow", kRowsAC, kColsARowsB, kColsBC,
|
||||
PrintSpeed("MatMulSlow", A_extents, B_extents,
|
||||
hwy::platform::Now() - start_slow);
|
||||
}
|
||||
|
||||
double min_elapsed = hwy::HighestValue<double>();
|
||||
for (int rep = 0; rep < (want_bench ? 3 : 1); ++rep) {
|
||||
const double start_tiled = hwy::platform::Now();
|
||||
MatMul<kAdd>(kRowsAC, ConstMat(a->data(), kColsARowsB),
|
||||
ConstMat(b_trans->data(), kColsARowsB), scale,
|
||||
kAdd ? add->data_scale1() : nullptr, env,
|
||||
MutableMat(c.get(), kColsBC));
|
||||
MatMul(A, B, add_row, env, C);
|
||||
min_elapsed = HWY_MIN(min_elapsed, hwy::platform::Now() - start_tiled);
|
||||
}
|
||||
if (want_bench) {
|
||||
PrintSpeed("MatMul", kRowsAC, kColsARowsB, kColsBC, min_elapsed);
|
||||
PrintSpeed("MatMul", A_extents, B_extents, min_elapsed);
|
||||
}
|
||||
|
||||
AssertClose(kRowsAC, kColsARowsB, kColsBC, a->data(), b_trans->data(),
|
||||
c_slow->data(), c.get());
|
||||
AssertClose(A, B, C_slow, C);
|
||||
}
|
||||
|
||||
void TestAllMatMul() {
|
||||
|
|
@ -264,8 +257,9 @@ void TestAllMatMul() {
|
|||
return;
|
||||
}
|
||||
|
||||
NestedPools pools(4, /*pin=*/1);
|
||||
pools.StartSpinning();
|
||||
NestedPools pools(4, /*pin=*/Tristate::kDefault);
|
||||
Tristate use_spinning = Tristate::kDefault;
|
||||
pools.MaybeStartSpinning(use_spinning);
|
||||
Allocator::Init(pools.Topology());
|
||||
MatMulEnv env(pools);
|
||||
|
||||
|
|
@ -273,52 +267,54 @@ void TestAllMatMul() {
|
|||
using SFP = SfpStream;
|
||||
|
||||
// large-scale test: batch_size=128 is better than 64 or 256 for SKX.
|
||||
TestMatMul<128, 24576, 3072, /*kAdd=*/false, F32, SFP>(env);
|
||||
TestMatMul<128, 3072, 24576, /*kAdd=*/false, F32, SFP>(env);
|
||||
TestMatMul<1, 24576, 3072, /*kAdd=*/false, F32, F32>(env);
|
||||
TestMatMul<1, 3072, 24576, /*kAdd=*/false, F32, F32>(env);
|
||||
// TestMatMul<F32, SFP>(128, 24576, 3072, /*add=*/false, env);
|
||||
// TestMatMul<F32, SFP>(128, 3072, 24576, /*add=*/false, env);
|
||||
TestMatMul<F32, F32>(1, 24576, 3072, /*add=*/false, env);
|
||||
TestMatMul<F32, F32>(1, 3072, 24576, /*add=*/false, env);
|
||||
TestMatMul<F32, SFP>(1, 24576, 3072, /*add=*/false, env);
|
||||
TestMatMul<F32, SFP>(1, 3072, 24576, /*add=*/false, env);
|
||||
|
||||
// medium-sized square test - temporarily disabled for faster testing.
|
||||
if constexpr (false) {
|
||||
TestMatMul<512, 512, 512, /*kAdd=*/false, F32>(env);
|
||||
TestMatMul<512, 512, 512, /*kAdd=*/true, BF16>(env);
|
||||
TestMatMul<512, 512, 512, /*kAdd=*/false, F32, BF16>(env);
|
||||
TestMatMul<512, 512, 512, /*kAdd=*/true, BF16, F32>(env);
|
||||
TestMatMul<512, 512, 512, /*kAdd=*/false, F32, SFP>(env);
|
||||
TestMatMul<512, 512, 512, /*kAdd=*/true, BF16, SFP>(env);
|
||||
TestMatMul<F32>(512, 512, 512, /*add=*/false, env);
|
||||
TestMatMul<BF16>(512, 512, 512, /*add=*/true, env);
|
||||
TestMatMul<F32, BF16>(512, 512, 512, /*add=*/false, env);
|
||||
TestMatMul<BF16, F32>(512, 512, 512, /*add=*/true, env);
|
||||
TestMatMul<F32, SFP>(512, 512, 512, /*add=*/false, env);
|
||||
TestMatMul<BF16, SFP>(512, 512, 512, /*add=*/true, env);
|
||||
}
|
||||
|
||||
// minimal non-square test. kColsARowsB must be at least 2 vectors.
|
||||
TestMatMul<35, 128, 32, /*kAdd=*/false, F32>(env);
|
||||
TestMatMul<34, 128, 32, /*kAdd=*/true, BF16>(env);
|
||||
TestMatMul<33, 128, 32, /*kAdd=*/false, F32, BF16>(env);
|
||||
TestMatMul<33, 128, 32, /*kAdd=*/true, BF16, F32>(env);
|
||||
TestMatMul<31, 128, 32, /*kAdd=*/false, F32, SFP>(env);
|
||||
TestMatMul<29, 128, 32, /*kAdd=*/true, BF16, SFP>(env);
|
||||
TestMatMul<4, 128, 32, /*kAdd=*/true, F32>(env);
|
||||
TestMatMul<4, 128, 32, /*kAdd=*/false, BF16>(env);
|
||||
TestMatMul<4, 128, 32, /*kAdd=*/true, F32, BF16>(env);
|
||||
TestMatMul<4, 128, 32, /*kAdd=*/false, BF16, F32>(env);
|
||||
TestMatMul<4, 128, 32, /*kAdd=*/true, F32, SFP>(env);
|
||||
TestMatMul<4, 128, 32, /*kAdd=*/false, BF16, SFP>(env);
|
||||
TestMatMul<3, 128, 32, /*kAdd=*/false, F32>(env);
|
||||
TestMatMul<3, 128, 32, /*kAdd=*/true, BF16>(env);
|
||||
TestMatMul<3, 128, 32, /*kAdd=*/false, F32, BF16>(env);
|
||||
TestMatMul<3, 128, 32, /*kAdd=*/true, BF16, F32>(env);
|
||||
TestMatMul<3, 128, 32, /*kAdd=*/false, F32, SFP>(env);
|
||||
TestMatMul<3, 128, 32, /*kAdd=*/true, BF16, SFP>(env);
|
||||
TestMatMul<2, 128, 64, /*kAdd=*/true, F32>(env);
|
||||
TestMatMul<2, 128, 64, /*kAdd=*/false, BF16>(env);
|
||||
TestMatMul<2, 128, 64, /*kAdd=*/true, F32, BF16>(env);
|
||||
TestMatMul<2, 128, 64, /*kAdd=*/false, BF16, F32>(env);
|
||||
TestMatMul<2, 128, 64, /*kAdd=*/true, F32, SFP>(env);
|
||||
TestMatMul<2, 128, 64, /*kAdd=*/false, BF16, SFP>(env);
|
||||
TestMatMul<1, 128, 32, /*kAdd=*/false, F32>(env);
|
||||
TestMatMul<1, 128, 32, /*kAdd=*/true, BF16>(env);
|
||||
TestMatMul<1, 128, 32, /*kAdd=*/false, F32, BF16>(env);
|
||||
TestMatMul<1, 128, 32, /*kAdd=*/true, BF16, F32>(env);
|
||||
TestMatMul<1, 128, 32, /*kAdd=*/false, F32, SFP>(env);
|
||||
TestMatMul<1, 128, 32, /*kAdd=*/true, BF16, SFP>(env);
|
||||
TestMatMul<F32>(35, 128, 32, /*add=*/false, env);
|
||||
TestMatMul<BF16>(34, 128, 32, /*add=*/true, env);
|
||||
TestMatMul<F32, BF16>(33, 128, 32, /*add=*/false, env);
|
||||
TestMatMul<BF16, F32>(33, 128, 32, /*add=*/true, env);
|
||||
TestMatMul<F32, SFP>(31, 128, 32, /*add=*/false, env);
|
||||
TestMatMul<BF16, SFP>(29, 128, 32, /*add=*/true, env);
|
||||
TestMatMul<F32>(4, 128, 32, /*add=*/true, env);
|
||||
TestMatMul<BF16>(4, 128, 32, /*add=*/false, env);
|
||||
TestMatMul<F32, BF16>(4, 128, 32, /*add=*/true, env);
|
||||
TestMatMul<BF16, F32>(4, 128, 32, /*add=*/false, env);
|
||||
TestMatMul<F32, SFP>(4, 128, 32, /*add=*/true, env);
|
||||
TestMatMul<BF16, SFP>(4, 128, 32, /*add=*/false, env);
|
||||
TestMatMul<F32>(3, 128, 32, /*add=*/false, env);
|
||||
TestMatMul<BF16>(3, 128, 32, /*add=*/true, env);
|
||||
TestMatMul<F32, BF16>(3, 128, 32, /*add=*/false, env);
|
||||
TestMatMul<BF16, F32>(3, 128, 32, /*add=*/true, env);
|
||||
TestMatMul<F32, SFP>(3, 128, 32, /*add=*/false, env);
|
||||
TestMatMul<BF16, SFP>(3, 128, 32, /*add=*/true, env);
|
||||
TestMatMul<F32>(2, 128, 64, /*add=*/true, env);
|
||||
TestMatMul<BF16>(2, 128, 64, /*add=*/false, env);
|
||||
TestMatMul<F32, BF16>(2, 128, 64, /*add=*/true, env);
|
||||
TestMatMul<BF16, F32>(2, 128, 64, /*add=*/false, env);
|
||||
TestMatMul<F32, SFP>(2, 128, 64, /*add=*/true, env);
|
||||
TestMatMul<BF16, SFP>(2, 128, 64, /*add=*/false, env);
|
||||
TestMatMul<F32>(1, 128, 32, /*add=*/false, env);
|
||||
TestMatMul<BF16>(1, 128, 32, /*add=*/true, env);
|
||||
TestMatMul<F32, BF16>(1, 128, 32, /*add=*/false, env);
|
||||
TestMatMul<BF16, F32>(1, 128, 32, /*add=*/true, env);
|
||||
TestMatMul<F32, SFP>(1, 128, 32, /*add=*/false, env);
|
||||
TestMatMul<BF16, SFP>(1, 128, 32, /*add=*/true, env);
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
||||
|
|
|
|||
|
|
@ -389,7 +389,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void ScalarRopeAndMulBy(
|
|||
void TestRopeAndMulBy() {
|
||||
ModelConfig config = ConfigFromModel(Model::GEMMA2_9B);
|
||||
int dim_qkv = config.layer_configs[0].qkv_dim;
|
||||
RowVectorBatch<float> x(1, dim_qkv);
|
||||
RowVectorBatch<float> x(Extents2D(1, dim_qkv));
|
||||
|
||||
std::mt19937 gen;
|
||||
gen.seed(0x12345678);
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ cc_library(
|
|||
deps = [
|
||||
"//compression:io",
|
||||
"@highway//:hwy",
|
||||
"@highway//:profiler",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
@ -37,11 +38,11 @@ cc_test(
|
|||
"no_tap",
|
||||
],
|
||||
deps = [
|
||||
"@googletest//:gtest_main",
|
||||
"@googletest//:gtest_main", # buildcleaner: keep
|
||||
"//:benchmark_helper",
|
||||
"//:common",
|
||||
"//:gemma_lib",
|
||||
"//:tokenizer",
|
||||
"//compression:sfp",
|
||||
"@highway//:hwy",
|
||||
"@highway//:hwy_test_util",
|
||||
],
|
||||
|
|
|
|||
|
|
@ -14,8 +14,8 @@
|
|||
// limitations under the License.
|
||||
|
||||
#include "paligemma/image.h"
|
||||
#include "compression/io.h"
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include <algorithm>
|
||||
|
|
@ -24,11 +24,15 @@
|
|||
#include <cstdio>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <limits>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "compression/io.h"
|
||||
#include "hwy/aligned_allocator.h" // hwy::Span
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/profiler.h"
|
||||
|
||||
namespace gcpp {
|
||||
namespace {
|
||||
|
|
@ -95,12 +99,12 @@ bool Image::ReadPPM(const std::string& filename) {
|
|||
std::cerr << filename << " does not exist\n";
|
||||
return false;
|
||||
}
|
||||
auto content = ReadFileToString(path);
|
||||
const std::string content = ReadFileToString(path);
|
||||
return ReadPPM(hwy::Span<const char>(content.data(), content.size()));
|
||||
}
|
||||
|
||||
bool Image::ReadPPM(const hwy::Span<const char>& buf) {
|
||||
auto pos = CheckP6Format(buf.cbegin(), buf.cend());
|
||||
const char* pos = CheckP6Format(buf.cbegin(), buf.cend());
|
||||
if (!pos) {
|
||||
std::cerr << "We only support binary PPM (P6)\n";
|
||||
return false;
|
||||
|
|
@ -134,8 +138,8 @@ bool Image::ReadPPM(const hwy::Span<const char>& buf) {
|
|||
return false;
|
||||
}
|
||||
++pos;
|
||||
auto data_size = width * height * 3;
|
||||
if (buf.cend() - pos < data_size) {
|
||||
const size_t data_size = width * height * 3;
|
||||
if (buf.cend() - pos < static_cast<ptrdiff_t>(data_size)) {
|
||||
std::cerr << "Insufficient data remaining\n";
|
||||
return false;
|
||||
}
|
||||
|
|
@ -149,6 +153,27 @@ bool Image::ReadPPM(const hwy::Span<const char>& buf) {
|
|||
return true;
|
||||
}
|
||||
|
||||
void Image::Set(int width, int height, const float* data) {
|
||||
width_ = width;
|
||||
height_ = height;
|
||||
int num_elements = width * height * 3;
|
||||
data_.resize(num_elements);
|
||||
data_.assign(data, data + num_elements);
|
||||
float min_value = std::numeric_limits<float>::infinity();
|
||||
float max_value = -std::numeric_limits<float>::infinity();
|
||||
for (int i = 0; i < num_elements; ++i) {
|
||||
if (data_[i] < min_value) min_value = data_[i];
|
||||
if (data_[i] > max_value) max_value = data_[i];
|
||||
}
|
||||
// -> out_min + (value - in_min) * (out_max - out_min) / (in_max - in_min)
|
||||
float in_range = max_value - min_value;
|
||||
if (in_range == 0.0f) in_range = 1.0f;
|
||||
float scale = 2.0f / in_range;
|
||||
for (int i = 0; i < num_elements; ++i) {
|
||||
data_[i] = (data_[i] - min_value) * scale - 1.0f;
|
||||
}
|
||||
}
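A quick numeric check of the rescaling above (a standalone sketch, not part of the change): with out = (v - min) * 2 / (max - min) - 1, the minimum maps to -1, the maximum to +1, and values in between land proportionally.
#include <cassert>
#include <cmath>

int main() {
  const float v[3] = {2.0f, 5.0f, 11.0f};      // arbitrary input values
  const float min_v = 2.0f, max_v = 11.0f;
  const float scale = 2.0f / (max_v - min_v);  // same scale as Image::Set
  assert(std::fabs((v[0] - min_v) * scale - 1.0f - (-1.0f)) < 1e-6f);  // min -> -1
  assert(std::fabs((v[2] - min_v) * scale - 1.0f - (+1.0f)) < 1e-6f);  // max -> +1
  assert(std::fabs((v[1] - min_v) * scale - 1.0f - (-1.0f / 3.0f)) < 1e-6f);
  return 0;
}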
|
||||
|
||||
void Image::Resize() {
|
||||
int new_width = 224;
|
||||
int new_height = kImageSize;
|
||||
|
|
@ -190,23 +215,24 @@ bool Image::WriteBinary(const std::string& filename) const {
|
|||
// We want the N-th patch (of 256) of size kPatchSize x kPatchSize x 3.
|
||||
// Patches are numbered in usual "pixel-order".
|
||||
void Image::GetPatch(size_t patch_num, float* patch) const {
|
||||
PROFILER_FUNC;
|
||||
constexpr size_t kDataSize = kImageSize * kImageSize * 3;
|
||||
HWY_ASSERT(size() == kDataSize);
|
||||
constexpr size_t kPatchDataSize = kPatchSize * kPatchSize * 3;
|
||||
int i_offs = patch_num / kNumPatches;
|
||||
int j_offs = patch_num % kNumPatches;
|
||||
size_t i_offs = patch_num / kNumPatches;
|
||||
size_t j_offs = patch_num % kNumPatches;
|
||||
HWY_ASSERT(0 <= i_offs && i_offs < kNumPatches);
|
||||
HWY_ASSERT(0 <= j_offs && j_offs < kNumPatches);
|
||||
i_offs *= kPatchSize;
|
||||
j_offs *= kPatchSize;
|
||||
// This can be made faster, but let's first see whether it matters.
|
||||
const float* image_data = data();
|
||||
for (int i = 0; i < kPatchSize; ++i) {
|
||||
for (int j = 0; j < kPatchSize; ++j) {
|
||||
for (int k = 0; k < 3; ++k) {
|
||||
const int patch_index = (i * kPatchSize + j) * 3 + k;
|
||||
for (size_t i = 0; i < kPatchSize; ++i) {
|
||||
for (size_t j = 0; j < kPatchSize; ++j) {
|
||||
for (size_t k = 0; k < 3; ++k) {
|
||||
const size_t patch_index = (i * kPatchSize + j) * 3 + k;
|
||||
HWY_ASSERT(patch_index < kPatchDataSize);
|
||||
const int image_index =
|
||||
const size_t image_index =
|
||||
((i + i_offs) * kImageSize + (j + j_offs)) * 3 + k;
|
||||
HWY_ASSERT(image_index < kDataSize);
|
||||
patch[patch_index] = image_data[image_index];
|
||||
|
|
@ -214,4 +240,5 @@ void Image::GetPatch(size_t patch_num, float* patch) const {
|
|||
}
|
||||
}
|
||||
}
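The patch geometry GetPatch relies on, as a sketch: the constants below are inferred from the "N-th patch (of 256)" comment and kImageSize == 224, and are not defined by this change.
constexpr size_t kImageSizeSketch = 224;
constexpr size_t kNumPatchesSketch = 16;                                   // per side
constexpr size_t kPatchSizeSketch = kImageSizeSketch / kNumPatchesSketch;  // 14
static_assert(kNumPatchesSketch * kNumPatchesSketch == 256, "256 patches total");
static_assert(kPatchSizeSketch * kNumPatchesSketch == kImageSizeSketch,
              "patches tile the image exactly");
// patch_num selects the patch at row (patch_num / 16), column (patch_num % 16),
// each holding kPatchSizeSketch * kPatchSizeSketch * 3 floats.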
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
|
|||
|
|
@ -35,6 +35,9 @@ class Image {
|
|||
// Reads PPM format (P6, binary) data from a hwy::Span, normalizes to [-1, 1].
|
||||
// Returns true on success.
|
||||
bool ReadPPM(const hwy::Span<const char>& buf);
|
||||
// Sets the image content to the given data. The data is copied and normalized
|
||||
// to [-1, 1]. The data is expected to be of size width * height * 3.
|
||||
void Set(int width, int height, const float* data);
|
||||
// Resizes to 224x224 (nearest-neighbor for now, bilinear or antialias would
|
||||
// be better).
|
||||
void Resize();
|
||||
|
|
|
|||
|
|
@ -14,10 +14,10 @@
|
|||
// limitations under the License.
|
||||
|
||||
#include <cstdio>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "compression/shared.h"
|
||||
#include "evals/benchmark_helper.h"
|
||||
#include "gemma/common.h"
|
||||
#include "gemma/gemma.h"
|
||||
|
|
@ -44,19 +44,20 @@ class PaliGemmaTest : public ::testing::Test {
|
|||
std::string GemmaReply(const std::string& prompt_text) const;
|
||||
void TestQuestions(const char* kQA[][2], size_t num_questions);
|
||||
|
||||
std::unique_ptr<ImageTokens> image_tokens_;
|
||||
ImageTokens image_tokens_;
|
||||
};
|
||||
|
||||
void PaliGemmaTest::InitVit(const std::string& path) {
|
||||
ASSERT_NE(s_env->GetModel(), nullptr);
|
||||
Gemma& model = *(s_env->GetModel());
|
||||
image_tokens_ = std::make_unique<ImageTokens>(256, 2048);
|
||||
image_tokens_ = ImageTokens(Extents2D(model.GetModelConfig().vit_seq_len,
|
||||
model.GetModelConfig().model_dim));
|
||||
Image image;
|
||||
HWY_ASSERT(model.Info().model == Model::PALIGEMMA_224);
|
||||
HWY_ASSERT(model.Info().training == ModelTraining::PALIGEMMA);
|
||||
HWY_ASSERT(image.ReadPPM(path));
|
||||
image.Resize();
|
||||
RuntimeConfig runtime_config = {.verbosity = 0, .gen = &s_env->MutableGen()};
|
||||
model.GenerateImageTokens(runtime_config, image, *image_tokens_);
|
||||
model.GenerateImageTokens(runtime_config, image, image_tokens_);
|
||||
}
|
||||
|
||||
std::string PaliGemmaTest::GemmaReply(const std::string& prompt_text) const{
|
||||
|
|
@ -65,7 +66,7 @@ std::string PaliGemmaTest::GemmaReply(const std::string& prompt_text) const{
|
|||
RuntimeConfig runtime_config = {.max_generated_tokens = 512,
|
||||
.verbosity = 0,
|
||||
.gen = &s_env->MutableGen()};
|
||||
runtime_config.image_tokens = image_tokens_.get();
|
||||
runtime_config.image_tokens = &image_tokens_;
|
||||
size_t abs_pos = 0;
|
||||
std::string mutable_prompt = prompt_text;
|
||||
std::vector<int> tokens = s_env->WrapAndTokenize(mutable_prompt);
|
||||
|
|
@ -77,7 +78,7 @@ std::string PaliGemmaTest::GemmaReply(const std::string& prompt_text) const{
|
|||
return true;
|
||||
};
|
||||
runtime_config.stream_token = stream_token,
|
||||
tokens.insert(tokens.begin(), image_tokens_->BatchSize(), 0);
|
||||
tokens.insert(tokens.begin(), image_tokens_.BatchSize(), 0);
|
||||
size_t num_tokens = tokens.size();
|
||||
size_t prefix_end = num_tokens;
|
||||
runtime_config.prefill_tbatch_size = num_tokens;
|
||||
|
|
|
|||
|
|
@ -162,20 +162,19 @@ static void BindMemory(void* ptr, size_t bytes, size_t node) {
|
|||
static void BindMemory(void*, size_t, size_t) {}
|
||||
#endif // GEMMA_NUMA && HWY_OS_LINUX
|
||||
|
||||
void BindTensor(NestedPools& nested, size_t rows, size_t cols,
|
||||
void BindTensor(NestedPools& nested, const Extents2D& extents,
|
||||
size_t bytes_per_col, void* ptr) {
|
||||
if (!Allocator::UseNUMA()) return;
|
||||
uint8_t* p8 = static_cast<uint8_t*>(ptr);
|
||||
const size_t bytes_per_row = cols * bytes_per_col;
|
||||
const size_t bytes_per_row = extents.cols * bytes_per_col;
|
||||
StaticPartitionRowsAndCols(
|
||||
nested, rows, cols, bytes_per_col,
|
||||
[&](size_t node, hwy::ThreadPool&, const size_t /*worker_offset*/,
|
||||
const size_t row_begin, const size_t row_end, const size_t col_begin,
|
||||
const size_t col_end) {
|
||||
for (size_t row = row_begin; row < row_end; ++row) {
|
||||
uint8_t* slice = p8 + row * bytes_per_row + col_begin * bytes_per_col;
|
||||
const size_t slice_size = (col_end - col_begin) * bytes_per_col;
|
||||
BindMemory(slice, slice_size, node);
|
||||
nested, extents, bytes_per_col,
|
||||
[&](const Range2D& r, const TaskLocation& loc) {
|
||||
for (size_t row : r.rows) {
|
||||
uint8_t* slice =
|
||||
p8 + row * bytes_per_row + r.cols.begin() * bytes_per_col;
|
||||
const size_t slice_size = r.cols.Num() * bytes_per_col;
|
||||
BindMemory(slice, slice_size, loc.node);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
|
|
|||
|
|
@ -19,9 +19,10 @@
|
|||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include <cstdlib> // std::aligned_alloc
|
||||
#include <cstdlib> // std::aligned_alloc / _aligned_malloc
|
||||
|
||||
// IWYU pragma: begin_exports
|
||||
#include "util/basics.h"
|
||||
#include "util/threading.h"
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/base.h"
|
||||
|
|
@ -52,49 +53,6 @@ ByteStorageT AllocateSizeof() {
|
|||
return hwy::AllocateAligned<uint8_t>(sizeof(T));
|
||||
}
|
||||
|
||||
// Owns dynamically-allocated aligned memory for a batch of row vectors.
|
||||
// This can be seen as a (batch_size x len) matrix.
|
||||
template <typename T>
|
||||
class RowVectorBatch {
|
||||
public:
|
||||
// Default ctor for Activations ctor.
|
||||
RowVectorBatch() : batch_size_(0), len_(0) {}
|
||||
// Main ctor, called from Activations::Allocate.
|
||||
RowVectorBatch(size_t batch_size, size_t len)
|
||||
: batch_size_(batch_size), len_(len) {
|
||||
mem_ = hwy::AllocateAligned<T>(batch_size * len);
|
||||
}
|
||||
|
||||
// Move-only
|
||||
RowVectorBatch(RowVectorBatch&) noexcept = delete;
|
||||
RowVectorBatch& operator=(RowVectorBatch&) noexcept = delete;
|
||||
RowVectorBatch(RowVectorBatch&&) noexcept = default;
|
||||
RowVectorBatch& operator=(RowVectorBatch&&) noexcept = default;
|
||||
|
||||
size_t BatchSize() const { return batch_size_; }
|
||||
size_t Len() const { return len_; }
|
||||
|
||||
// Returns the given row vector of length `Len()`.
|
||||
T* Batch(size_t batch_idx) {
|
||||
HWY_DASSERT(batch_idx < batch_size_);
|
||||
return mem_.get() + batch_idx * len_;
|
||||
}
|
||||
const T* Batch(size_t batch_idx) const {
|
||||
HWY_DASSERT(batch_idx < batch_size_);
|
||||
return mem_.get() + batch_idx * len_;
|
||||
}
|
||||
|
||||
// For MatMul or other operations that process the entire batch at once.
|
||||
T* All() { return mem_.get(); }
|
||||
const T* Const() const { return mem_.get(); }
|
||||
size_t NumBytes() const { return batch_size_ * len_ * sizeof(T); }
|
||||
|
||||
private:
|
||||
hwy::AlignedFreeUniquePtr<T[]> mem_;
|
||||
size_t batch_size_; // rows in the matrix
|
||||
size_t len_; // columns in the matrix = vector length
|
||||
};
|
||||
|
||||
// Stateful in order to know whether to bind to NUMA nodes. `Monostate` for
|
||||
// convenience - avoids passing around a reference.
|
||||
class Allocator {
|
||||
|
|
@ -140,15 +98,19 @@ class Allocator {
|
|||
}
|
||||
|
||||
// AlignedFreeUniquePtr has a deleter that can call an arbitrary `free`, but
|
||||
// with an extra opaque pointer, which we discard via this adapter.
|
||||
// with an extra opaque pointer, which we discard via `call_free`.
|
||||
#if defined(__ANDROID_API__) && __ANDROID_API__ < 28
|
||||
const auto call_free = [](void* ptr, void*) { std::free(ptr); };
|
||||
#if !defined(__ANDROID_API__) || __ANDROID_API__ >= 28
|
||||
T* p = static_cast<T*>(std::aligned_alloc(Alignment(), bytes));
|
||||
#else
|
||||
void* mem = nullptr;
|
||||
int err = posix_memalign(&mem, Alignment(), bytes);
|
||||
HWY_ASSERT(err == 0);
|
||||
T* p = static_cast<T*>(mem);
|
||||
#elif HWY_OS_WIN
|
||||
const auto call_free = [](void* ptr, void*) { _aligned_free(ptr); };
|
||||
T* p = static_cast<T*>(_aligned_malloc(bytes, Alignment()));
|
||||
#else
|
||||
const auto call_free = [](void* ptr, void*) { std::free(ptr); };
|
||||
T* p = static_cast<T*>(std::aligned_alloc(Alignment(), bytes));
|
||||
#endif
|
||||
return hwy::AlignedFreeUniquePtr<T[]>(
|
||||
p, hwy::AlignedFreer(call_free, nullptr));
|
||||
|
|
@ -163,10 +125,24 @@ class Allocator {
|
|||
static size_t alignment_;
|
||||
};
|
||||
|
||||
// For shorter arguments to the StaticPartitionRowsAndCols functor.
|
||||
struct TaskLocation {
|
||||
TaskLocation(size_t node, size_t package_idx, hwy::ThreadPool& cluster,
|
||||
size_t worker_offset)
|
||||
: node(node),
|
||||
package_idx(package_idx),
|
||||
cluster(cluster),
|
||||
worker_offset(worker_offset) {}
|
||||
size_t node;
|
||||
size_t package_idx;
|
||||
hwy::ThreadPool& cluster;
|
||||
const size_t worker_offset;
|
||||
};
|
||||
|
||||
// Used in MatMul and allocator.h. Defined here because it depends on
|
||||
// Allocator::Alignment().
|
||||
template <class Func>
|
||||
void StaticPartitionRowsAndCols(NestedPools& nested, size_t rows, size_t cols,
|
||||
void StaticPartitionRowsAndCols(NestedPools& nested, Extents2D extents,
|
||||
size_t bytes_per_element, const Func& func) {
|
||||
// Both rows and cols must be a multiple of the alignment to avoid
|
||||
// touching remote pages.
|
||||
|
|
@ -179,14 +155,15 @@ void StaticPartitionRowsAndCols(NestedPools& nested, size_t rows, size_t cols,
|
|||
hwy::ThreadPool& all_packages = nested.AllPackages();
|
||||
const size_t num_packages = all_packages.NumWorkers();
|
||||
const size_t cols_per_package =
|
||||
hwy::RoundUpTo(hwy::DivCeil(cols, num_packages), multiple);
|
||||
const size_t col_tasks = hwy::DivCeil(cols, cols_per_package);
|
||||
hwy::RoundUpTo(hwy::DivCeil(extents.cols, num_packages), multiple);
|
||||
const size_t col_tasks = hwy::DivCeil(extents.cols, cols_per_package);
|
||||
HWY_ASSERT(col_tasks <= num_packages);
|
||||
all_packages.Run(
|
||||
0, col_tasks, [&](uint64_t package_idx, size_t package_thread) {
|
||||
HWY_ASSERT(package_idx == package_thread); // one task per worker
|
||||
const size_t col_begin = package_idx * cols_per_package;
|
||||
const size_t col_end = HWY_MIN(col_begin + cols_per_package, cols);
|
||||
const Range1D col_range =
|
||||
MakeRange1D(col_begin, extents.cols, cols_per_package);
|
||||
|
||||
// Static partitioning of rows across the package's clusters. We assume
|
||||
// that row sharding is cheaper. In MatMul, results can indeed be
|
||||
|
|
@ -194,8 +171,8 @@ void StaticPartitionRowsAndCols(NestedPools& nested, size_t rows, size_t cols,
|
|||
hwy::ThreadPool& all_clusters = nested.AllClusters(package_idx);
|
||||
const size_t num_clusters = all_clusters.NumWorkers();
|
||||
const size_t rows_per_cluster =
|
||||
hwy::RoundUpTo(hwy::DivCeil(rows, num_clusters), multiple);
|
||||
const size_t row_tasks = hwy::DivCeil(rows, rows_per_cluster);
|
||||
hwy::RoundUpTo(hwy::DivCeil(extents.rows, num_clusters), multiple);
|
||||
const size_t row_tasks = hwy::DivCeil(extents.rows, rows_per_cluster);
|
||||
HWY_ASSERT(row_tasks <= num_clusters);
|
||||
all_clusters.Run(
|
||||
0, row_tasks, [&](uint64_t cluster_idx, size_t cluster_thread) {
|
||||
|
|
@ -213,11 +190,11 @@ void StaticPartitionRowsAndCols(NestedPools& nested, size_t rows, size_t cols,
|
|||
nested.WorkerOffset(package_idx, cluster_idx);
|
||||
|
||||
const size_t row_begin = cluster_idx * rows_per_cluster;
|
||||
const size_t row_end =
|
||||
HWY_MIN(row_begin + rows_per_cluster, rows);
|
||||
const Range1D row_range =
|
||||
MakeRange1D(row_begin, extents.rows, rows_per_cluster);
|
||||
|
||||
func(node, cluster, worker_offset, row_begin, row_end, col_begin,
|
||||
col_end);
|
||||
func(Range2D(row_range, col_range),
|
||||
TaskLocation(node, package_idx, cluster, worker_offset));
|
||||
});
|
||||
});
|
||||
}
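A sketch of a caller under the new functor signature, which now receives a Range2D plus a TaskLocation instead of six scalar indices; `nested` and `extents` are assumed to be set up elsewhere, and the body is illustrative only.
void VisitTilesSketch(NestedPools& nested, const Extents2D& extents) {
  StaticPartitionRowsAndCols(
      nested, extents, /*bytes_per_element=*/sizeof(float),
      [&](const Range2D& r, const TaskLocation& loc) {
        // Each invocation sees one statically partitioned tile and the
        // package/cluster/NUMA node it was assigned to.
        fprintf(stderr, "rows [%zu,%zu) x cols [%zu,%zu) -> node %zu\n",
                static_cast<size_t>(r.rows.begin()),
                static_cast<size_t>(r.rows.end()),
                static_cast<size_t>(r.cols.begin()),
                static_cast<size_t>(r.cols.end()), loc.node);
      });
}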
|
||||
|
|
|
|||
10
util/app.h
|
|
@ -28,6 +28,7 @@
|
|||
#include "gemma/common.h"
|
||||
#include "gemma/gemma.h" // For CreateGemma
|
||||
#include "util/args.h"
|
||||
#include "util/basics.h" // Tristate
|
||||
#include "util/threading.h"
|
||||
#include "hwy/base.h" // HWY_IS_ASAN
|
||||
|
||||
|
|
@ -59,7 +60,9 @@ class AppArgs : public ArgsBase<AppArgs> {
|
|||
int verbosity;
|
||||
|
||||
size_t max_threads; // divided among the detected clusters
|
||||
int pin; // -1 = auto, 0 = no, 1 = yes
|
||||
Tristate pin; // pin threads?
|
||||
Tristate spin; // use spin waits?
|
||||
|
||||
// For BoundedSlice:
|
||||
size_t skip_packages;
|
||||
size_t max_packages;
|
||||
|
|
@ -81,7 +84,10 @@ class AppArgs : public ArgsBase<AppArgs> {
|
|||
// The exact meaning is more subtle: see the comment at NestedPools ctor.
|
||||
visitor(max_threads, "num_threads", size_t{0},
|
||||
"Maximum number of threads to use; default 0 = unlimited.", 2);
|
||||
visitor(pin, "pin", -1, "Pin threads? -1 = auto, 0 = no, 1 = yes.", 2);
|
||||
visitor(pin, "pin", Tristate::kDefault,
|
||||
"Pin threads? -1 = auto, 0 = no, 1 = yes.", 2);
|
||||
visitor(spin, "spin", Tristate::kDefault,
|
||||
"Use spin waits? -1 = auto, 0 = no, 1 = yes.", 2);
|
||||
// These can be used to partition CPU sockets/packages and their
|
||||
// clusters/CCXs across several program instances. The default is to use
|
||||
// all available resources.
|
||||
|
|
|
|||
32
util/args.h
|
|
@ -24,6 +24,7 @@
|
|||
#include <string>
|
||||
|
||||
#include "compression/io.h"
|
||||
#include "util/basics.h" // Tristate
|
||||
#include "hwy/base.h" // HWY_ABORT
|
||||
|
||||
namespace gcpp {
|
||||
|
|
@ -62,6 +63,13 @@ class ArgsBase {
|
|||
}
|
||||
}
|
||||
|
||||
void operator()(const Tristate& t, const char* name,
|
||||
const Tristate& /*init*/, const char* /*help*/,
|
||||
int print_verbosity = 0) const {
|
||||
if (verbosity_ >= print_verbosity) {
|
||||
fprintf(stderr, "%-30s: %s\n", name, ToString(t));
|
||||
}
|
||||
}
|
||||
void operator()(const std::string& t, const char* name,
|
||||
const std::string& /*init*/, const char* /*help*/,
|
||||
int print_verbosity = 0) const {
|
||||
|
|
@ -127,13 +135,33 @@ class ArgsBase {
|
|||
return true;
|
||||
}
|
||||
|
||||
static bool SetValue(const char* string, bool& t) {
|
||||
// Returns lower-cased string. Arg names are expected to be ASCII-only.
|
||||
static std::string ToLower(const char* string) {
|
||||
std::string value(string);
|
||||
// Lower-case. Arg names are expected to be ASCII-only.
|
||||
std::transform(value.begin(), value.end(), value.begin(), [](char c) {
|
||||
return 'A' <= c && c <= 'Z' ? c - ('Z' - 'z') : c;
|
||||
});
|
||||
return value;
|
||||
}
|
||||
|
||||
static bool SetValue(const char* string, Tristate& t) {
|
||||
const std::string value = ToLower(string);
|
||||
if (value == "true" || value == "on" || value == "1") {
|
||||
t = Tristate::kTrue;
|
||||
return true;
|
||||
} else if (value == "false" || value == "off" || value == "0") {
|
||||
t = Tristate::kFalse;
|
||||
return true;
|
||||
} else if (value == "default" || value == "auto" || value == "-1") {
|
||||
t = Tristate::kDefault;
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
static bool SetValue(const char* string, bool& t) {
|
||||
const std::string value = ToLower(string);
|
||||
if (value == "true" || value == "on" || value == "1") {
|
||||
t = true;
|
||||
return true;
|
||||
|
|
|
|||
205
util/basics.h
|
|
@ -20,6 +20,7 @@
|
|||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/base.h" // HWY_IS_MSAN
|
||||
// IWYU pragma: end_exports
|
||||
|
||||
|
|
@ -29,6 +30,21 @@
|
|||
|
||||
namespace gcpp {
|
||||
|
||||
enum class Tristate : int32_t { kFalse = 0, kTrue = 1, kDefault = -1 };
|
||||
|
||||
static inline const char* ToString(Tristate t) {
|
||||
switch (t) {
|
||||
case Tristate::kFalse:
|
||||
return "false";
|
||||
case Tristate::kTrue:
|
||||
return "true";
|
||||
case Tristate::kDefault:
|
||||
return "default";
|
||||
}
|
||||
}
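As a standalone sketch of how the three states print (assumes only this header; the program itself is illustrative):
#include <cstdio>
#include <initializer_list>

#include "util/basics.h"

int main() {
  for (gcpp::Tristate t : {gcpp::Tristate::kFalse, gcpp::Tristate::kTrue,
                           gcpp::Tristate::kDefault}) {
    // Prints 0 -> false, 1 -> true, -1 -> default.
    std::printf("%d -> %s\n", static_cast<int>(t), gcpp::ToString(t));
  }
  return 0;
}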
|
||||
|
||||
using BF16 = hwy::bfloat16_t;
|
||||
|
||||
static inline void MaybeCheckInitialized(const void* ptr, size_t size) {
|
||||
#if HWY_IS_MSAN
|
||||
__msan_check_mem_is_initialized(ptr, size);
|
||||
|
|
@ -44,6 +60,195 @@ struct TokenAndProb {
|
|||
float prob;
|
||||
};
|
||||
|
||||
// Entire size of a 2D array. By contrast, Range2D is a subrange.
|
||||
struct Extents2D {
|
||||
Extents2D() : rows(0), cols(0) {}
|
||||
Extents2D(size_t rows, size_t cols) : rows(rows), cols(cols) {
|
||||
HWY_DASSERT(rows != 0);
|
||||
HWY_DASSERT(cols != 0);
|
||||
}
|
||||
|
||||
size_t Area() const { return rows * cols; }
|
||||
|
||||
size_t rows;
|
||||
size_t cols;
|
||||
};
|
||||
|
||||
// Range2D consists of two Range1D.
|
||||
struct Range1D {
|
||||
Range1D(size_t begin, size_t end) : begin_(begin), end_(end) {
|
||||
HWY_DASSERT(begin < end);
|
||||
}
|
||||
size_t Num() const { return end_ - begin_; }
|
||||
|
||||
// Enable range-based for loops.
|
||||
class Iterator {
|
||||
public:
|
||||
Iterator(size_t i) : i_(i) {}
|
||||
|
||||
Iterator& operator++() {
|
||||
++i_;
|
||||
return *this;
|
||||
}
|
||||
bool operator!=(const Iterator& other) const { return i_ != other.i_; }
|
||||
size_t operator*() const { return i_; }
|
||||
// Enable using begin() directly as a size_t.
|
||||
operator size_t() const { return i_; }
|
||||
|
||||
private:
|
||||
size_t i_;
|
||||
};
|
||||
Iterator begin() const { return Iterator(begin_); }
|
||||
Iterator end() const { return Iterator(end_); }
|
||||
|
||||
const size_t begin_;
|
||||
const size_t end_;
|
||||
};
|
||||
|
||||
static inline Range1D MakeRange1D(size_t begin, size_t end, size_t max_size) {
|
||||
return Range1D(begin, HWY_MIN(begin + max_size, end));
|
||||
}
|
||||
|
||||
// In MatMul, the two axes are used independently, hence we do not define
|
||||
// Range2D as a top-left and extents.
|
||||
struct Range2D {
|
||||
Range2D(Range1D rows, Range1D cols) : rows(rows), cols(cols) {}
|
||||
const Range1D rows;
|
||||
const Range1D cols;
|
||||
};
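A small sketch of how these ranges behave: MakeRange1D clamps at `end`, Range1D supports range-based for, and Range2D just pairs two independent 1D ranges. Function and values below are illustrative only.
inline void RangeSketch() {
  const Range1D rows = MakeRange1D(/*begin=*/8, /*end=*/10, /*max_size=*/4);
  HWY_ASSERT(rows.Num() == 2);                // clamped to [8, 10), not [8, 12)
  for (size_t r : rows) (void)r;              // visits 8, then 9
  const Range2D tile(rows, Range1D(0, 128));  // row and column ranges are independent
  (void)tile;
}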
|
||||
|
||||
// Lightweight version of `MatPtr` used for the C argument of `MatMul`, because
|
||||
// it is always float and does not support compressed T, but does support an
|
||||
// arbitrary stride >= cols.
|
||||
template <typename T>
|
||||
class RowPtr {
|
||||
public:
|
||||
RowPtr(T* HWY_RESTRICT row0, size_t cols)
|
||||
: row0_(row0), cols_(cols), stride_(cols) {}
|
||||
|
||||
T* HWY_RESTRICT Row(size_t r) const { return row0_ + stride_ * r; }
|
||||
size_t Cols() const { return cols_; }
|
||||
|
||||
size_t Stride() const { return stride_; }
|
||||
void SetStride(size_t stride) {
|
||||
HWY_DASSERT(stride >= Cols());
|
||||
stride_ = stride;
|
||||
}
|
||||
|
||||
private:
|
||||
T* HWY_RESTRICT row0_;
|
||||
size_t cols_;
|
||||
size_t stride_;
|
||||
};
|
||||
|
||||
using RowPtrF = RowPtr<float>;
|
||||
|
||||
// Owns dynamically-allocated aligned memory for a batch of row vectors.
|
||||
// This can be seen as a (batch_size x cols) matrix. Unlike `RowPtr`, this owns
|
||||
// the memory.
|
||||
template <typename T>
|
||||
class RowVectorBatch {
|
||||
public:
|
||||
// Default ctor for Activations ctor.
|
||||
RowVectorBatch() = default;
|
||||
// Main ctor, called from Activations::Allocate.
|
||||
RowVectorBatch(Extents2D extents) : extents_(extents) {
|
||||
mem_ = hwy::AllocateAligned<T>(extents_.rows * extents_.cols);
|
||||
}
|
||||
|
||||
// Move-only
|
||||
RowVectorBatch(RowVectorBatch&) noexcept = delete;
|
||||
RowVectorBatch& operator=(RowVectorBatch&) noexcept = delete;
|
||||
RowVectorBatch(RowVectorBatch&&) noexcept = default;
|
||||
RowVectorBatch& operator=(RowVectorBatch&&) noexcept = default;
|
||||
|
||||
size_t BatchSize() const { return extents_.rows; }
|
||||
size_t Cols() const { return extents_.cols; }
|
||||
Extents2D Extents() const { return extents_; }
|
||||
|
||||
// Returns the given row vector of length `Cols()`.
|
||||
T* Batch(size_t batch_idx) {
|
||||
HWY_DASSERT(batch_idx < BatchSize());
|
||||
return mem_.get() + batch_idx * Cols();
|
||||
}
|
||||
const T* Batch(size_t batch_idx) const {
|
||||
HWY_DASSERT(batch_idx < BatchSize());
|
||||
return mem_.get() + batch_idx * Cols();
|
||||
}
|
||||
|
||||
// For MatMul or other operations that process the entire batch at once.
|
||||
// TODO: remove once we only use Mat.
|
||||
T* All() { return mem_.get(); }
|
||||
const T* Const() const { return mem_.get(); }
|
||||
size_t NumBytes() const { return BatchSize() * Cols() * sizeof(T); }
|
||||
|
||||
private:
|
||||
hwy::AlignedFreeUniquePtr<T[]> mem_;
|
||||
Extents2D extents_;
|
||||
};
|
||||
|
||||
// Used for the A and B arguments of `MatMul`, which are always const.
|
||||
// Create via MakeConstMat. This differs from `RowPtr` in that it supports the
|
||||
// `ofs` required for compressed T.
|
||||
template <typename T>
|
||||
struct ConstMat {
|
||||
ConstMat(const T* ptr, Extents2D extents, size_t ofs = 0)
|
||||
: ptr(ptr), extents(extents), ofs(ofs) {
|
||||
HWY_DASSERT(ptr != nullptr);
|
||||
}
|
||||
// TODO: support stride for page alignment.
|
||||
size_t Row(size_t r) const {
|
||||
if constexpr (HWY_IS_DEBUG_BUILD) {
|
||||
if (r >= extents.rows) {
|
||||
HWY_ABORT("ConstMat::Row %zu out of bounds %zu", r, extents.rows);
|
||||
}
|
||||
}
|
||||
return ofs + extents.cols * r;
|
||||
}
|
||||
|
||||
const Extents2D& Extents() const { return extents; }
|
||||
|
||||
// Shrinks the row-extent of this matrix view, i.e. reduces the view to a
|
||||
// subrange of the original rows starting at row 0.
|
||||
void ShrinkRows(size_t rows) {
|
||||
HWY_ASSERT(rows <= extents.rows);
|
||||
extents.rows = rows;
|
||||
}
|
||||
|
||||
const T* HWY_RESTRICT ptr;
|
||||
Extents2D extents;
|
||||
|
||||
// `scale` allows expanding the smaller range of `SfpStream` to the original
|
||||
// values. MatFromWeights sets this from `MatPtr`.
|
||||
float scale = 1.0f;
|
||||
|
||||
// Offset to add to `ptr`; separate because T=NuqStream does not support
|
||||
// pointer arithmetic.
|
||||
size_t ofs;
|
||||
};
|
||||
|
||||
// For deducing T.
|
||||
template <typename T>
|
||||
ConstMat<T> MakeConstMat(T* HWY_RESTRICT ptr, Extents2D extents,
|
||||
size_t ofs = 0) {
|
||||
return ConstMat<T>(ptr, extents, ofs);
|
||||
}
|
||||
|
||||
// For A argument to MatMul (activations).
|
||||
template <typename T>
|
||||
ConstMat<T> ConstMatFromBatch(size_t batch_size,
|
||||
const RowVectorBatch<T>& row_vectors) {
|
||||
HWY_DASSERT(batch_size <= row_vectors.BatchSize());
|
||||
return MakeConstMat(const_cast<T*>(row_vectors.Const()),
|
||||
Extents2D(batch_size, row_vectors.Cols()));
|
||||
}
|
||||
|
||||
// For C argument to MatMul.
|
||||
template <typename T>
|
||||
RowPtr<T> RowPtrFromBatch(RowVectorBatch<T>& row_vectors) {
|
||||
return RowPtr<T>(row_vectors.All(), row_vectors.Cols());
|
||||
}
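A sketch tying these helpers together for the A and C sides of a MatMul call; the batch sizes are illustrative and the function exists only for this example.
inline void MatViewsSketch() {
  RowVectorBatch<float> acts(Extents2D(/*rows=*/32, /*cols=*/256));
  ConstMat<float> A = ConstMatFromBatch(/*batch_size=*/32, acts);
  A.ShrinkRows(8);                    // view only the 8 rows actually filled
  HWY_ASSERT(A.Extents().rows == 8);  // ptr, cols and scale are unchanged

  RowVectorBatch<float> c_batch(Extents2D(/*rows=*/8, /*cols=*/512));
  const RowPtrF C = RowPtrFromBatch(c_batch);  // non-owning view for MatMul's C
  (void)C;
}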
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_UTIL_BASICS_H_
|
||||
|
|
|
|||
|
|
@ -0,0 +1,400 @@
|
|||
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "util/threading.h"

#include <stdio.h>

#include <algorithm>  // std::sort
#include <atomic>
#include <memory>   // std::make_unique
#include <utility>  // std::move
#include <vector>

// Placeholder for container detection, do not remove
#include "util/basics.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/contrib/thread_pool/topology.h"

namespace gcpp {

// Sort T := packages/clusters by descending 'size' so that users who only use
// one Group get the largest.
template <class T>
static void SortByDescendingSize(std::vector<T>& groups) {
  std::sort(groups.begin(), groups.end(),
            [](const T& a, const T& b) { return a.Size() > b.Size(); });
}

BoundedTopology::BoundedTopology(BoundedSlice package_slice,
                                 BoundedSlice cluster_slice,
                                 BoundedSlice lp_slice) {
  // Regardless of topology, ignore LPs disabled via OS, taskset, or numactl.
  LPS enabled_lps;
  if (HWY_UNLIKELY(!GetThreadAffinity(enabled_lps))) {
    const size_t num_lps = hwy::TotalLogicalProcessors();
    fprintf(stderr,
            "Warning, unknown OS affinity, considering all %zu LPs enabled.\n",
            num_lps);
    for (size_t lp = 0; lp < num_lps; ++lp) {
      enabled_lps.Set(lp);
    }
  }

  // Without threading support, only keep the first enabled LP; it might still
  // make sense to pin the main thread to avoid migrations.
  if (HWY_UNLIKELY(!hwy::HaveThreadingSupport())) {
    HWY_ASSERT(enabled_lps.Any());
    const size_t lp = enabled_lps.First();
    enabled_lps = LPS();
    enabled_lps.Set(lp);
    fprintf(stderr,
            "Warning, threads not supported, using only the main thread.\n");
  }

#if !GEMMA_DISABLE_TOPOLOGY
  if (HWY_LIKELY(!topology_.packages.empty())) {
    InitFromTopology(enabled_lps, package_slice, cluster_slice);
  }
#endif

  // Topology unknown or no packages with enabled LPs: create a single
  // package with one cluster, and one node.
  if (HWY_UNLIKELY(NumPackages() == 0)) {
    InitFromSlice(enabled_lps, lp_slice);
  }

  HWY_ASSERT(NumPackages() != 0 && NumClusters(0) != 0 && NumNodes() != 0);
}
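
A hedged construction sketch, not part of this commit: default-constructed BoundedSlice values impose no bound (see the declaration in util/threading.h below), so this builds the full detected topology and prints its summary.

static void TopologySketch() {
  gcpp::BoundedSlice all;  // default-constructed: no restriction
  gcpp::BoundedTopology topology(all, all, all);
  fprintf(stderr, "Topology %s: %zu packages, %zu nodes\n",
          topology.TopologyString(), topology.NumPackages(),
          topology.NumNodes());
}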

// Topology is unknown, rely on OS affinity and user-specified slice.
BoundedTopology::Cluster::Cluster(const LPS& enabled_lps,
                                  BoundedSlice lp_slice) {
  // Interpret `lp_slice` as a slice of the 1-bits of `enabled_lps`, so
  // we honor both the OS affinity and the user-specified slice. Note that
  // this can be used to exclude hyperthreads because Linux groups LPs by
  // sibling index. For example, the first `num_cores` are not siblings.
  const size_t detected = enabled_lps.Count();
  size_t enabled_idx = 0;
  enabled_lps.Foreach([&](size_t lp) {
    if (lp_slice.Contains(detected, enabled_idx++)) {
      AddLP(lp);
    }
  });

  // lp_slice can only reduce the number of `enabled_lps`, and not below 1.
  HWY_ASSERT(num_workers_ != 0);
}

BoundedTopology::Cluster::Cluster(const LPS& enabled_lps,
                                  const std::vector<hwy::Topology::LP>& all_lps,
                                  const hwy::Topology::Cluster& tcluster) {
  bool is_first_lp = true;

  tcluster.lps.Foreach([&](size_t lp) {
    // Skip if not first-hyperthread or disabled.
    if (all_lps[lp].smt != 0 || !enabled_lps.Get(lp)) return;

    AddLP(lp);

    // Set `node` once, and ensure subsequent nodes match - we assume there
    // is only one NUMA node per cluster.
    const size_t lp_node = static_cast<size_t>(all_lps[lp].node);
    if (is_first_lp) {
      is_first_lp = false;
      node_ = lp_node;
    } else {
      static bool warned = false;
      if (lp_node != node_ && !warned) {
        warned = true;
        fprintf(stderr, "WARNING: lp %zu on node %zu != cluster node %zu.\n",
                lp, lp_node, node_);
      }
    }
  });
}

// NOTE: caller is responsible for checking whether `clusters` is empty.
BoundedTopology::Package::Package(const LPS& enabled_lps,
                                  const hwy::Topology& topology,
                                  size_t package_idx,
                                  BoundedSlice cluster_slice) {
  const hwy::Topology::Package& tpackage = topology.packages[package_idx];
  // Populate `clusters` with the subset of clusters in `cluster_slice` that
  // have any enabled LPs. If `clusters` remains empty, the caller will
  // skip this `Package`.
  clusters.reserve(cluster_slice.Num(tpackage.clusters.size()));
  cluster_slice.Foreach(
      "cluster", tpackage.clusters.size(), [&](size_t cluster_idx) {
        const hwy::Topology::Cluster& tcluster = tpackage.clusters[cluster_idx];
        Cluster cluster(enabled_lps, topology.lps, tcluster);
        // Skip if empty, i.e. too few `enabled_lps`.
        if (HWY_LIKELY(cluster.Size() != 0)) {
          clusters.push_back(std::move(cluster));
        }
      });
  SortByDescendingSize(clusters);
}

#if !GEMMA_DISABLE_TOPOLOGY

static size_t CoresFromLPs(const LPS& lps, const hwy::Topology& topology) {
  LPS cores;
  lps.Foreach([&](size_t lp) {
    if (topology.lps[lp].smt == 0) cores.Set(lp);
  });
  return cores.Count();
}

// Scans hwy::Topology for clusters and their size, for use by topology_string_.
static void ScanTClusters(hwy::Topology& topology_, size_t& max_tclusters,
                          size_t& max_tcluster_cores,
                          size_t& max_tcluster_lps) {
  max_tclusters = 0;
  max_tcluster_cores = 0;
  max_tcluster_lps = 0;
  for (size_t package_idx = 0; package_idx < topology_.packages.size();
       ++package_idx) {
    const std::vector<hwy::Topology::Cluster>& tclusters =
        topology_.packages[package_idx].clusters;
    max_tclusters = HWY_MAX(max_tclusters, tclusters.size());
    size_t tcluster_cores = 0;
    size_t tcluster_lps = 0;
    for (size_t cluster_idx = 0; cluster_idx < tclusters.size();
         ++cluster_idx) {
      const size_t cores = CoresFromLPs(tclusters[cluster_idx].lps, topology_);
      const size_t lps = tclusters[cluster_idx].lps.Count();
      tcluster_cores = HWY_MAX(tcluster_cores, cores);
      tcluster_lps = HWY_MAX(tcluster_lps, lps);
    }

    if (tclusters.size() > 1 && tcluster_cores > 8) {
      fprintf(stderr,
              "Package %zu: multiple clusters with max size %zu, whereas a CCX "
              "only has 8; may indicate a bug in hwy::Topology.\n",
              package_idx, tcluster_cores);
    }
    max_tcluster_cores = HWY_MAX(max_tcluster_cores, tcluster_cores);
    max_tcluster_lps = HWY_MAX(max_tcluster_lps, tcluster_lps);
  }
  HWY_ASSERT(max_tclusters != 0);
  HWY_ASSERT(max_tcluster_cores != 0);
  HWY_ASSERT(max_tcluster_lps >= max_tcluster_cores);
}

// Main part of ctor, called when topology is known.
void BoundedTopology::InitFromTopology(const LPS& enabled_lps,
                                       BoundedSlice package_slice,
                                       BoundedSlice cluster_slice) {
  size_t max_tclusters, max_tcluster_cores, max_tcluster_lps;
  ScanTClusters(topology_, max_tclusters, max_tcluster_cores, max_tcluster_lps);

  // (Possibly empty) subset of `Topology` packages that have `enabled_lps`.
  package_slice.Foreach(
      "package", topology_.packages.size(), [&](size_t package_idx) {
        Package package(enabled_lps, topology_, package_idx, cluster_slice);
        // Skip if empty, i.e. too few `enabled_lps`.
        if (HWY_LIKELY(!package.clusters.empty())) {
          packages_.push_back(std::move(package));
        }
      });
  if (NumPackages() == 0) return;
  SortByDescendingSize(packages_);

  // Remember NUMA nodes that we are actually using (not just enabled).
  for (const Package& p : packages_) {
    for (const Cluster& c : p.clusters) {
      nodes_.Set(c.Node());
    }
  }

  // Scan for max BoundedTopology clusters and their size, for topology_string_.
  size_t all_max_cluster_size = 0;
  for (size_t package_idx = 0; package_idx < NumPackages(); ++package_idx) {
    size_t max_cluster_size = 0;
    for (size_t cluster_idx = 0; cluster_idx < NumClusters(package_idx);
         ++cluster_idx) {
      max_cluster_size = HWY_MAX(max_cluster_size,
                                 GetCluster(package_idx, cluster_idx).Size());
    }
    if (NumClusters(package_idx) > 1 && max_cluster_size > 8) {
      fprintf(stderr,
              "Package %zu: multiple clusters with max size %zu, whereas a CCX "
              "only has 8; may indicate a bug in BoundedTopology.\n",
              package_idx, max_cluster_size);
    }
    all_max_cluster_size = HWY_MAX(all_max_cluster_size, max_cluster_size);
  }

  snprintf(topology_string_, sizeof(topology_string_),
           "%zuS %zuX %zuC %zuH, using %zuS %zuX %zuC (nodes=%zu)",
           topology_.packages.size(), max_tclusters, max_tcluster_cores,
           max_tcluster_lps / max_tcluster_cores, packages_.size(),
           NumClusters(0), all_max_cluster_size, nodes_.Count());
}

#endif  // !GEMMA_DISABLE_TOPOLOGY

void BoundedTopology::InitFromSlice(const LPS& enabled_lps,
                                    BoundedSlice lp_slice) {
  packages_.push_back(Package(enabled_lps, lp_slice));

  snprintf(topology_string_, sizeof(topology_string_), "LPs=%zu",
           GetCluster(0, 0).Size());

  // Assume a single NUMA node.
  nodes_.Set(0);
  HWY_ASSERT(NumNodes() == 1);
}

static PoolPtr MakePool(size_t num_workers) {
  // `ThreadPool` expects the number of threads to create, which is one less
  // than the number of workers, but avoid underflow if zero.
  const size_t num_threads = num_workers == 0 ? 0 : num_workers - 1;
  return std::make_unique<hwy::ThreadPool>(num_threads);
}
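
A small illustrative sketch (hypothetical function, not in the commit) of the worker-vs-thread convention described in the comment above:

static void MakePoolSketch() {
  // MakePool(n) asks ThreadPool for n - 1 new threads, so the calling thread
  // is itself one of the n workers.
  PoolPtr single = MakePool(1);  // no extra threads; only the calling thread
  PoolPtr quad = MakePool(4);    // three extra threads plus the calling thread
  (void)single;
  (void)quad;
}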

static bool InContainer() { return false; }

class NestedPools::Pinning {
 public:
  Pinning(Tristate pin, const BoundedTopology& topology) {
    if (pin == Tristate::kDefault) {
      // Pinning is unreliable inside containers because the hypervisor might
      // periodically change our affinity mask, or other processes might also
      // pin themselves to the same LPs.
      pin = InContainer() ? Tristate::kFalse : Tristate::kTrue;
    }
    want_pin_ = (pin == Tristate::kTrue);
  }

  // If want_pin_, tries to pin each worker in `pool` to an LP in `cluster`,
  // and sets `any_error_` if any fails.
  void MaybePin(const BoundedTopology::Cluster& cluster, PoolPtr& pool) {
    if (HWY_UNLIKELY(!want_pin_)) return;

    const std::vector<size_t> lps = cluster.LPVector();
    HWY_ASSERT(pool->NumWorkers() <= lps.size());
    pool->Run(
        0, pool->NumWorkers(),
        [this, &pool, &lps](uint64_t task, size_t thread) {
          HWY_ASSERT(task == thread);  // each worker has one task
          if (HWY_UNLIKELY(!hwy::PinThreadToLogicalProcessor(lps[task]))) {
            fprintf(stderr,
                    "Pinning failed for task %zu of %zu to %zu (size %zu)\n",
                    task, pool->NumWorkers(), lps[task], lps.size());
            (void)any_error_.test_and_set();
          }
        });
  }

  bool WantPin() const { return want_pin_; }

  // Called ONCE after all MaybePin because it invalidates the error status.
  bool AllPinned() {
    // If !want_pin_, MaybePin will return without setting any_error_, but in
    // that case we still want to return false to avoid spinning.
    // .test() was only added in C++20, so we use .test_and_set() instead.
    return want_pin_ && !any_error_.test_and_set();
  }

 private:
  std::atomic_flag any_error_ = ATOMIC_FLAG_INIT;
  bool want_pin_;  // set in ctor
};  // Pinning

// Used to divide max_threads and max_workers_per_package across packages and
// clusters. Ensures small upper bounds are respected.
static size_t DivideMaxAcross(const size_t max, const size_t instances) {
  // No limit.
  if (max == 0) return 0;
  // We have enough to distribute.
  if (max >= instances) return max / instances;
  // Use max as the upper bound for each instance because division would return
  // zero, which means 'unlimited'.
  return max;
}
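
A brief worked example of the three branches above (hypothetical helper; the values are illustrative, not from the source):

static void DivideMaxAcrossExamples() {
  HWY_ASSERT(DivideMaxAcross(/*max=*/0, /*instances=*/4) == 0);   // no limit
  HWY_ASSERT(DivideMaxAcross(/*max=*/12, /*instances=*/3) == 4);  // split evenly
  HWY_ASSERT(DivideMaxAcross(/*max=*/2, /*instances=*/4) == 2);   // keep the small cap
}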

NestedPools::NestedPools(size_t max_threads, Tristate pin,
                         BoundedSlice package_slice, BoundedSlice cluster_slice,
                         BoundedSlice lp_slice)
    : topology_(package_slice, cluster_slice, lp_slice) {
  Pinning pinning(pin, topology_);
  packages_.resize(topology_.NumPackages());
  all_packages_ = MakePool(packages_.size());
  const size_t max_workers_per_package =
      DivideMaxAcross(max_threads, packages_.size());
  // Each worker in all_packages_, including the main thread, will be the
  // calling thread of an all_clusters->Run, and hence pinned to one of the
  // `cluster.lps` if `pin`.
  all_packages_->Run(
      0, all_packages_->NumWorkers(), [&](uint64_t package_idx, size_t thread) {
        HWY_ASSERT(package_idx == thread);  // each thread has one task
        packages_[package_idx] = Package(
            topology_, package_idx, max_workers_per_package, pinning, lp_slice);
      });

  all_pinned_ = pinning.AllPinned();
  pin_string_ = all_pinned_          ? "pinned"
                : pinning.WantPin()  ? "pinning failed"
                                     : "pinning skipped";

  // For mapping package/cluster/thread to noncontiguous TLS indices, in case
  // cluster/thread counts differ.
  HWY_ASSERT(!packages_.empty() && packages_.size() <= 16);
  for (const Package& p : packages_) {
    max_clusters_per_package_ =
        HWY_MAX(max_clusters_per_package_, p.NumClusters());
    max_workers_per_cluster_ =
        HWY_MAX(max_workers_per_cluster_, p.MaxWorkersPerCluster());
  }
  HWY_ASSERT(max_clusters_per_package_ >= 1);
  HWY_ASSERT(max_clusters_per_package_ <= 64);
  HWY_ASSERT(max_workers_per_cluster_ >= 1);
  HWY_ASSERT(max_workers_per_cluster_ <= 256);
}

// `max_or_zero` == 0 means no limit.
static inline size_t CapIfNonZero(size_t num, size_t max_or_zero) {
  return (max_or_zero == 0) ? num : HWY_MIN(num, max_or_zero);
}

NestedPools::Package::Package(const BoundedTopology& topology,
                              size_t package_idx,
                              size_t max_workers_per_package, Pinning& pinning,
                              BoundedSlice lp_slice) {
  // Pre-allocate because elements are set concurrently.
  clusters_.resize(topology.NumClusters(package_idx));
  const size_t max_workers_per_cluster =
      DivideMaxAcross(max_workers_per_package, clusters_.size());

  all_clusters_ = MakePool(clusters_.size());
  // Parallel so we also pin the calling worker in `all_clusters` to
  // `cluster.lps`.
  all_clusters_->Run(
      0, all_clusters_->NumWorkers(), [&](size_t cluster_idx, size_t thread) {
        HWY_ASSERT(cluster_idx == thread);  // each thread has one task
        const BoundedTopology::Cluster& cluster =
            topology.GetCluster(package_idx, cluster_idx);
        clusters_[cluster_idx] =
            MakePool(CapIfNonZero(cluster.Size(), max_workers_per_cluster));
        // Pin workers AND the calling thread from `all_clusters`.
        pinning.MaybePin(cluster, clusters_[cluster_idx]);
      });
}

}  // namespace gcpp
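
For context, a hedged usage sketch of the new interface (accessor names are taken from the header below; the function itself is hypothetical):

static void NestedPoolsSketch() {
  // max_threads == 0 means no limit; kDefault lets Pinning decide.
  gcpp::NestedPools pools(/*max_threads=*/0, gcpp::Tristate::kDefault);
  fprintf(stderr, "Topology %s, %s\n", pools.TopologyString(),
          pools.PinString());
  hwy::ThreadPool& clusters0 = pools.AllClusters(/*package_idx=*/0);
  (void)clusters0;
}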

296  util/threading.h

@@ -17,17 +17,19 @@
#define THIRD_PARTY_GEMMA_CPP_UTIL_THREADING_H_

#include <stddef.h>
#include <stdio.h>

#include <algorithm>  // std::sort
#include <memory>     // std::unique_ptr
#include <utility>    // std::move
#include <memory>     // std::unique_ptr
#include <vector>

#include "hwy/base.h"     // HWY_ASSERT
#include "util/basics.h"  // Tristate
#include "hwy/base.h"     // HWY_ASSERT
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/contrib/thread_pool/topology.h"

#ifndef GEMMA_DISABLE_TOPOLOGY
#define GEMMA_DISABLE_TOPOLOGY 0
#endif  // !GEMMA_DISABLE_TOPOLOGY

namespace gcpp {

// A slice of a 1D integer range such as the indices of packages or clusters.

@@ -74,6 +76,10 @@ class BoundedSlice {

// "LP" is a logical processor, a 0-based index passed to the OS.
using LPS = hwy::LogicalProcessorSet;

// We want vectors of hwy::ThreadPool, which is unfortunately not movable,
// hence we wrap them in unique_ptr.
using PoolPtr = std::unique_ptr<hwy::ThreadPool>;

// Wraps hwy::Topology and only keeps the subset of packages and clusters
// apportioned by BoundedSlice, further limited by the OS affinity mask.
// NOTE: if topology is unknown or the OS affinity is too restrictive, we fall

@@ -81,94 +87,18 @@ using LPS = hwy::LogicalProcessorSet;

class BoundedTopology {
 public:
  BoundedTopology(BoundedSlice package_slice, BoundedSlice cluster_slice,
                  BoundedSlice lp_slice) {
    // Regardless of topology, ignore LPs disabled via OS, taskset, or numactl.
    LPS enabled_lps;
    if (HWY_UNLIKELY(!GetThreadAffinity(enabled_lps))) {
      const size_t num_lps = hwy::TotalLogicalProcessors();
      fprintf(
          stderr,
          "Warning, unknown OS affinity, considering all %zu LPs enabled\n.",
          num_lps);
      for (size_t lp = 0; lp < hwy::TotalLogicalProcessors(); ++lp) {
        enabled_lps.Set(lp);
      }
    }

    // Without threading support, only keep the first enabled LP; it might still
    // make sense to pin the main thread.
    if (HWY_UNLIKELY(!hwy::HaveThreadingSupport())) {
      HWY_ASSERT(enabled_lps.Any());
      const size_t lp = enabled_lps.First();
      enabled_lps = LPS();
      enabled_lps.Set(lp);
    }

    if (HWY_LIKELY(!topology_.packages.empty())) {
      InitFromTopology(enabled_lps, package_slice, cluster_slice);
    }

    // Topology unknown or no packages with enabled LPs: create a single
    // package with one cluster, and one node.
    if (HWY_UNLIKELY(NumPackages() == 0)) {
      InitFromSlice(enabled_lps, lp_slice);
    }

    HWY_ASSERT(NumPackages() != 0 && NumClusters(0) != 0 && NumNodes() != 0);
  }
                  BoundedSlice lp_slice);

  size_t NumPackages() const { return packages_.size(); }
  const char* TopologyString() const { return topology_string_; }
  size_t NumNodes() const { return nodes_.Count(); }
  const char* TopologyString() const { return topology_string_; }

  class Cluster {
   public:
    // Topology is unknown, rely on OS affinity and user-specified slice.
    Cluster(const LPS& enabled_lps, BoundedSlice lp_slice) {
      // Interpret `lp_slice` as a slice of the 1-bits of `enabled_lps`, so
      // we honor both the OS affinity and the user-specified slice. Note that
      // this can be used to exclude hyperthreads because Linux groups LPs by
      // sibling index. For example, the first `num_cores` are not siblings.
      const size_t detected = enabled_lps.Count();
      size_t enabled_idx = 0;
      enabled_lps.Foreach([&](size_t lp) {
        if (lp_slice.Contains(detected, enabled_idx++)) {
          AddLP(lp);
        }
      });

      // lp_slice can only reduce the number of `enabled_lps`, and not below 1.
      HWY_ASSERT(num_workers_ != 0);
    }

    Cluster(const LPS& enabled_lps, BoundedSlice lp_slice);
    Cluster(const LPS& enabled_lps,
            const std::vector<hwy::Topology::LP>& all_lps,
            const hwy::Topology::Cluster& tcluster) {
      bool is_first_lp = true;

      tcluster.lps.Foreach([&](size_t lp) {
        // Skip if not first-hyperthread or disabled.
        if (all_lps[lp].smt != 0 || !enabled_lps.Get(lp)) return;

        AddLP(lp);

        // Set `node` once, and ensure subsequent nodes match - we assume there
        // is only one NUMA node per cluster.
        const size_t lp_node = static_cast<size_t>(all_lps[lp].node);
        if (is_first_lp) {
          is_first_lp = false;
          node_ = lp_node;
        } else {
          static bool warned = false;
          if (lp_node != node_ && !warned) {
            warned = true;
            fprintf(stderr,
                    "WARNING: lp %zu on node %zu != cluster node %zu.\n", lp,
                    lp_node, node_);
          }
        }
      });
    }
            const hwy::Topology::Cluster& tcluster);

    // For SortByDescendingSize.
    size_t Size() const { return num_workers_; }

@@ -215,53 +145,15 @@ class BoundedTopology {

    return package.clusters[cluster_idx];
  }

  // Returns total number of cluster workers, for deciding whether to pin.
  size_t TotalWorkers() const {
    size_t total_workers = 0;
    for (size_t package_idx = 0; package_idx < NumPackages(); ++package_idx) {
      const size_t num_clusters = NumClusters(package_idx);
      for (size_t cluster_idx = 0; cluster_idx < num_clusters; ++cluster_idx) {
        total_workers += GetCluster(package_idx, cluster_idx).Size();
      }
    }
    return total_workers;
  }

 private:
  // Sort T := packages/clusters by descending 'size' so that users who only use
  // one Group get the largest.
  template <class T>
  static void SortByDescendingSize(std::vector<T>& groups) {
    std::sort(groups.begin(), groups.end(),
              [](const T& a, const T& b) { return a.Size() > b.Size(); });
  }

  struct Package {
    // Topology is unknown, rely on OS affinity and user-specified slice.
    Package(const LPS& enabled_lps, BoundedSlice lp_slice) {
      clusters.push_back(Cluster(enabled_lps, lp_slice));
    }

    // NOTE: caller is responsible for checking whether `clusters` is empty.
    Package(const LPS& enabled_lps, const hwy::Topology& topology,
            size_t package_idx, BoundedSlice cluster_slice) {
      const hwy::Topology::Package& tpackage = topology.packages[package_idx];
      // Populate `clusters` with the subset of clusters in `cluster_slice` that
      // have any enabled LPs. If `clusters` remains empty, the caller will
      // skip this `Package`.
      clusters.reserve(cluster_slice.Num(tpackage.clusters.size()));
      cluster_slice.Foreach(
          "cluster", tpackage.clusters.size(), [&](size_t cluster_idx) {
            const hwy::Topology::Cluster& tcluster =
                tpackage.clusters[cluster_idx];
            Cluster cluster(enabled_lps, topology.lps, tcluster);
            // Skip if empty, i.e. too few `enabled_lps`.
            if (HWY_LIKELY(cluster.Size() != 0)) {
              clusters.push_back(std::move(cluster));
            }
          });
      SortByDescendingSize(clusters);
    }
            size_t package_idx, BoundedSlice cluster_slice);

    // For SortByDescendingSize.
    size_t Size() const { return clusters.size(); }

@@ -269,48 +161,13 @@ class BoundedTopology {

    std::vector<Cluster> clusters;
  };  // Package

  // Main part of ctor, called when topology is known.
  void InitFromTopology(const LPS& enabled_lps, BoundedSlice package_slice,
                        BoundedSlice cluster_slice) {
    // (Possibly empty) subset of `Topology` packages that have `enabled_lps`.
    package_slice.Foreach(
        "package", topology_.packages.size(), [&](size_t package_idx) {
          Package package(enabled_lps, topology_, package_idx, cluster_slice);
          // Skip if empty, i.e. too few `enabled_lps`.
          if (HWY_LIKELY(!package.clusters.empty())) {
            packages_.push_back(std::move(package));
          }
        });
    if (NumPackages() == 0) return;
    SortByDescendingSize(packages_);

    const hwy::Topology::Package& tpackage0 = topology_.packages[0];
    HWY_ASSERT(!tpackage0.clusters.empty());
    const hwy::Topology::Cluster& tcluster0 = tpackage0.clusters[0];
    // GetCluster(0, 0) is valid because only non-empty Packages were kept.
    snprintf(topology_string_, sizeof(topology_string_),
             "%zux%zux%zu, using %zux%zux%zu", topology_.packages.size(),
             tpackage0.clusters.size(), tcluster0.lps.Count(), packages_.size(),
             NumClusters(0), GetCluster(0, 0).Size());

    // Remember NUMA nodes of *enabled* LPs.
    enabled_lps.Foreach([&](size_t lp) {
      nodes_.Set(static_cast<size_t>(topology_.lps[lp].node));
    });
  }

  void InitFromSlice(const LPS& enabled_lps, BoundedSlice lp_slice) {
    packages_.push_back(Package(enabled_lps, lp_slice));

    snprintf(topology_string_, sizeof(topology_string_), "LPs=%zu",
             GetCluster(0, 0).Size());

    // Assume a single NUMA node.
    nodes_.Set(0);
    HWY_ASSERT(NumNodes() == 1);
  }
                        BoundedSlice cluster_slice);
  void InitFromSlice(const LPS& enabled_lps, BoundedSlice lp_slice);

#if !GEMMA_DISABLE_TOPOLOGY
  hwy::Topology topology_;
#endif
  std::vector<Package> packages_;
  char topology_string_[96];
  LPS nodes_;

@@ -350,51 +207,32 @@ class NestedPools {

  // would cause huge slowdowns when spinning, the `BoundedSlice` arguments
  // only impose upper bounds on the number of detected packages and clusters
  // rather than defining the actual number of threads.
  //
  // `pin` is 0 or 1 to force enable/disable, or -1 to choose automatically.
  NestedPools(size_t max_threads, int pin = -1,
  NestedPools(size_t max_threads, Tristate pin = Tristate::kDefault,
              BoundedSlice package_slice = BoundedSlice(),
              BoundedSlice cluster_slice = BoundedSlice(),
              BoundedSlice lp_slice = BoundedSlice())
      : topology_(package_slice, cluster_slice, lp_slice) {
    if (pin == -1) pin = topology_.TotalWorkers() >= 12;
              BoundedSlice lp_slice = BoundedSlice());

    packages_.resize(topology_.NumPackages());
    all_packages_ = MakePool(packages_.size());
    const size_t max_workers_per_package = max_threads / packages_.size();
    // Each worker in all_packages_, including the main thread, will be the
    // calling thread of an all_clusters->Run, and hence pinned to one of the
    // `cluster.lps` if `pin`.
    all_packages_->Run(
        0, all_packages_->NumWorkers(),
        [&](uint64_t package_idx, size_t thread) {
          HWY_ASSERT(package_idx == thread);  // each thread has one task
          packages_[package_idx] = Package(
              topology_, package_idx, max_workers_per_package, pin, lp_slice);
        });

    // For mapping package/cluster/thread to noncontiguous TLS indices, in case
    // cluster/thread counts differ.
    HWY_ASSERT(!packages_.empty() && packages_.size() <= 16);
    for (const Package& p : packages_) {
      max_clusters_per_package_ =
          HWY_MAX(max_clusters_per_package_, p.NumClusters());
      max_workers_per_cluster_ =
          HWY_MAX(max_workers_per_cluster_, p.MaxWorkersPerCluster());
  // Subject to `use_spinning`, enables spin waits with the goal of reducing the
  // latency of barrier synchronization. We only spin during Generate to avoid
  // wasting energy during long waits. If `use_spinning` is kDefault, we first
  // set it to kTrue or kFalse based on a heuristic.
  void MaybeStartSpinning(Tristate& use_spinning) {
    if (HWY_UNLIKELY(use_spinning == Tristate::kDefault)) {
      // The default is to only spin when pinning was enabled and supported by
      // the OS. Unless spin-waits have near-exclusive use of a core, the tail
      // latency can be higher than blocking waits.
      use_spinning = all_pinned_ ? Tristate::kTrue : Tristate::kFalse;
    }
    if (use_spinning == Tristate::kTrue) {
      SetWaitMode(hwy::PoolWaitMode::kSpin);
    }
  }
  void MaybeStopSpinning(const Tristate use_spinning) {
    HWY_DASSERT(use_spinning != Tristate::kDefault);  // see MaybeStartSpinning
    if (use_spinning == Tristate::kTrue) {
      SetWaitMode(hwy::PoolWaitMode::kBlock);
    }
    HWY_ASSERT(max_clusters_per_package_ >= 1);
    HWY_ASSERT(max_clusters_per_package_ <= 64);
    HWY_ASSERT(max_workers_per_cluster_ >= 1);
    HWY_ASSERT(max_workers_per_cluster_ <= 256);
  }

  // Spinning reduces the latency of barrier synchronization, but wastes lots
  // of energy for long waits, so only do it during generation. Spinning might
  // also be unsafe in virtualized environments because we require threads to
  // be running on their own core and thus responsive to the barrier
  // synchronization.
  void StartSpinning() { SetWaitMode(hwy::PoolWaitMode::kSpin); }
  void StopSpinning() { SetWaitMode(hwy::PoolWaitMode::kBlock); }

  hwy::ThreadPool& AllPackages() { return *all_packages_; }
  hwy::ThreadPool& AllClusters(size_t package_idx) {

@@ -425,7 +263,9 @@ class NestedPools {

  // For Allocator
  const BoundedTopology& Topology() const { return topology_; }
  // For ShowConfig
  const char* TopologyString() const { return topology_.TopologyString(); }
  const char* PinString() const { return pin_string_; }

  // Returns a single pool on the first package: either one thread per cluster
  // if there is more than one, which maximizes available memory bandwidth, or

@@ -439,56 +279,14 @@

  }

 private:
  // `max_or_zero` == 0 means no limit.
  static inline size_t CapIfNonZero(size_t num, size_t max_or_zero) {
    return (max_or_zero == 0) ? num : HWY_MIN(num, max_or_zero);
  }

  // We want vectors of hwy::ThreadPool, which is unfortunately not movable,
  // hence we wrap them in unique_ptr.
  using PoolPtr = std::unique_ptr<hwy::ThreadPool>;

  static PoolPtr MakePool(size_t num_workers) {
    // `ThreadPool` expects the number of threads to create, which is one less
    // than the number of workers, but avoid underflow if zero.
    const size_t num_threads = num_workers == 0 ? 0 : num_workers - 1;
    return std::make_unique<hwy::ThreadPool>(num_threads);
  }
  class Pinning;

  class Package {
   public:
    Package() = default;  // for vector
    Package(const BoundedTopology& topology, size_t package_idx,
            size_t max_workers_per_package, int pin, BoundedSlice lp_slice) {
      // Pre-allocate because elements are set concurrently.
      clusters_.resize(topology.NumClusters(package_idx));
      const size_t max_workers_per_cluster =
          max_workers_per_package / clusters_.size();

      all_clusters_ = MakePool(clusters_.size());
      // Parallel so we also pin the calling worker in `all_clusters` to
      // `cluster.lps`.
      all_clusters_->Run(
          0, all_clusters_->NumWorkers(),
          [&](size_t cluster_idx, size_t thread) {
            HWY_ASSERT(cluster_idx == thread);  // each thread has one task
            const BoundedTopology::Cluster& cluster =
                topology.GetCluster(package_idx, cluster_idx);
            clusters_[cluster_idx] =
                MakePool(CapIfNonZero(cluster.Size(), max_workers_per_cluster));
            if (HWY_LIKELY(pin)) {
              // Pin threads AND the calling thread from `all_clusters` to lps.
              const std::vector<size_t> lps = cluster.LPVector();
              HWY_ASSERT(clusters_[cluster_idx]->NumWorkers() <= lps.size());
              clusters_[cluster_idx]->Run(
                  0, clusters_[cluster_idx]->NumWorkers(),
                  [&lps](uint64_t task, size_t thread) {
                    HWY_ASSERT(task == thread);  // each worker has one task
                    hwy::PinThreadToLogicalProcessor(lps[task]);
                  });
            }
          });
    }
            size_t max_workers_per_package, Pinning& pinning,
            BoundedSlice lp_slice);

    size_t NumClusters() const { return clusters_.size(); }
    size_t MaxWorkersPerCluster() const {

@@ -526,6 +324,8 @@ class NestedPools {

  }

  BoundedTopology topology_;
  bool all_pinned_;
  const char* pin_string_;

  std::vector<Package> packages_;
  PoolPtr all_packages_;
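
Finally, a hedged sketch (not from this commit) of the intended spin-wait lifecycle around generation, using only the members declared above; the function name is hypothetical:

static void GenerateWithSpinning(gcpp::NestedPools& pools) {
  gcpp::Tristate use_spinning = gcpp::Tristate::kDefault;
  pools.MaybeStartSpinning(use_spinning);  // resolves kDefault via all_pinned_
  // ... run token generation using pools.AllClusters(0), etc. ...
  pools.MaybeStopSpinning(use_spinning);   // restore blocking waits
}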