Simpler MatMul interface, vocab types, Tristate for use_spinning

Add Extents2D and Range2D vocab types.
MatMul uses ConstMat for inputs and RowPtr for its output.
Move RowVectorBatch to basics.h.
Separate threading.cc.
Fix topology string: report cores (not LPs) and #HT.
Move QStride/IsMHA into LayerConfig.
ImageTokens no longer requires make_unique.
matmul_test: no longer requires template args.
PiperOrigin-RevId: 692963605
Jan Wassenberg, 2024-11-04 07:47:49 -08:00, committed by Copybara-Service
parent baaa221787, commit 868b01601f
26 changed files with 1311 additions and 971 deletions
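To orient readers of the diff below: the MatMul call shape changes from `MatMul<kAdd>(num_rows, A, B, scale, add, env, MutableMat(...))` to `MatMul(A, B, add, env, RowPtr(...))`. The following is a minimal sketch with stand-in types, not the real gemma.cpp/Highway implementation; it only illustrates the new argument order, the scale being carried by the views, and `add == nullptr` meaning "no bias".

```cpp
// Sketch only: simplified stand-ins for ConstMat/RowPtr/MatMulEnv, assuming
// row-major A [rows x cols] and transposed B [rows = C cols, cols = A cols].
#include <cstddef>

template <typename T>
struct ConstMat {   // read-only view, as produced by ConstMatFromBatch/Weights
  const T* ptr;
  size_t rows, cols;
  float scale = 1.0f;
};

struct RowPtrF {    // writable output rows; stride defaults to cols
  RowPtrF(float* p, size_t c) : ptr(p), cols(c), stride(c) {}
  float* Row(size_t r) const { return ptr + r * stride; }
  float* ptr;
  size_t cols, stride;
};

struct MatMulEnv {};  // placeholder for thread pools / scratch buffers

// New-style call: MatMul(A, B, add, env, C); add == nullptr means "no bias".
inline void MatMul(const ConstMat<float>& A, const ConstMat<float>& B,
                   const float* add, MatMulEnv&, const RowPtrF& C) {
  for (size_t r = 0; r < A.rows; ++r) {
    for (size_t bc = 0; bc < B.rows; ++bc) {  // B is transposed: rows == C cols
      float sum = 0.0f;
      for (size_t k = 0; k < A.cols; ++k) {
        sum += A.ptr[r * A.cols + k] * B.ptr[bc * B.cols + k];
      }
      C.Row(r)[bc] = A.scale * B.scale * sum + (add ? add[bc] : 0.0f);
    }
  }
}

int main() {
  const float a[2 * 3] = {1, 2, 3, 4, 5, 6};   // A: 2x3
  const float bt[2 * 3] = {1, 0, 0, 0, 1, 0};  // B^T: 2 rows of 3
  const float bias[2] = {10, 20};
  float c[2 * 2] = {0};
  MatMulEnv env;
  MatMul(ConstMat<float>{a, 2, 3}, ConstMat<float>{bt, 2, 3}, bias, env,
         RowPtrF(c, 2));
  return static_cast<int>(c[0]);  // 1*1 + 10 = 11
}
```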


@@ -30,8 +30,11 @@ cc_library(
 cc_library(
     name = "threading",
+    srcs = ["util/threading.cc"],
     hdrs = ["util/threading.h"],
     deps = [
+        ":basics",
+        # Placeholder for container detection, do not remove
         "@highway//:hwy",
         "@highway//:thread_pool",
         "@highway//:topology",
@@ -173,7 +176,9 @@ cc_test(
     tags = ["hwy_ops_test"],
     deps = [
         ":allocator",
+        ":basics",
         ":ops",
+        ":test_util",
         ":threading",
         "@googletest//:gtest_main",  # buildcleaner: keep
         "//compression:compress",
@@ -280,6 +285,7 @@ cc_library(
         ":kv_cache",
         ":weights",
         ":threading",
+        "//compression:compress",
        "//compression:io",
        "//compression:sfp",
        "//paligemma:image",
@@ -307,6 +313,7 @@ cc_library(
     name = "args",
     hdrs = ["util/args.h"],
     deps = [
+        ":basics",
        "//compression:io",
        "@highway//:hwy",
     ],
@@ -317,6 +324,7 @@ cc_library(
     hdrs = ["util/app.h"],
     deps = [
         ":args",
+        ":basics",
         ":common",
         ":gemma_lib",
         ":threading",
@@ -342,8 +350,6 @@ cc_library(
        "//compression:compress",
        "@highway//:hwy",
        "@highway//:nanobenchmark",
-        "@highway//:thread_pool",
-        "@highway//:topology",
     ],
 )
@@ -583,6 +589,7 @@ cc_test(
     },
     deps = [
         ":backprop",
+        ":basics",
         ":common",
         ":gemma_lib",
         ":optimizer",


@@ -101,6 +101,7 @@ set(SOURCES
   util/args.h
   util/basics.h
   util/test_util.h
+  util/threading.cc
   util/threading.h
 )


@@ -33,13 +33,15 @@
 #include "gemma/configs.h"
 #include "gemma/gemma.h"
 #include "gemma/weights.h"
+#include "util/basics.h"
 #include "util/threading.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"

 namespace gcpp {

 TEST(OptimizeTest, GradientDescent) {
-  NestedPools pools(1, /*pin=*/0, BoundedSlice(0, 1), BoundedSlice(0, 1));
+  NestedPools pools(1, /*pin=*/Tristate::kFalse, BoundedSlice(0, 1),
+                    BoundedSlice(0, 1));
   hwy::ThreadPool& pool = pools.Pool();
   std::mt19937 gen(42);


@@ -33,6 +33,7 @@
 #include "compression/blob_store.h"
 #include "compression/io.h"
 #include "compression/shared.h"
+#include "util/basics.h"
 // IWYU pragma: end_exports
 #include "util/allocator.h"
 #if COMPRESS_STATS
@@ -62,7 +63,9 @@ class MatPtr {
         num_elements_(rows * cols),
         rows_(rows),
         cols_(cols),
-        ptr_(nullptr) {}
+        ptr_(nullptr) {
+    stride_ = cols;
+  }
   // Default is to leave all fields default-initialized.
   MatPtr() = default;
   virtual ~MatPtr();
@@ -85,7 +88,9 @@ class MatPtr {
         element_size_(key2.hi),
         num_elements_(key2.lo),
         rows_(key3.lo),
-        cols_(key3.hi) {}
+        cols_(key3.hi) {
+    stride_ = cols_;
+  }

   // Adds the contents entry to the table of contents.
   void AddToToc(std::vector<hwy::uint128_t>& toc) const {
@@ -137,6 +142,12 @@ class MatPtr {
   // Returns the number of columns in the 2-d array (inner dimension).
   size_t Cols() const { return cols_; }

+  Extents2D Extents() const { return Extents2D(rows_, cols_); }
+
+  // Currently same as cols, but may differ in the future. This is the offset by
+  // which to advance pointers to the next row.
+  size_t Stride() const { return stride_; }
+
   // Decoded elements should be multiplied by this to restore their original
   // range. This is required because SfpStream can only encode a limited range
   // of magnitudes.
@@ -187,6 +198,8 @@ class MatPtr {
   // freed. The underlying memory is owned by a subclass or some external class
   // and must outlive this object.
   void* ptr_ = nullptr;
+
+  size_t stride_;
 };

 // MatPtrT adds a single template argument to MatPtr for an explicit type.
@@ -288,7 +301,15 @@ decltype(auto) MatPtr::CallUpcasted(FuncT& func, TArgs&&... args) {
   }
 }

+template <typename T>
+ConstMat<T> ConstMatFromWeights(const MatPtrT<T>& m, size_t ofs = 0) {
+  ConstMat<T> mat = MakeConstMat(const_cast<T*>(m.data()), m.Extents(), ofs);
+  mat.scale = m.scale();
+  return mat;
+}
+
 // MatStorageT adds the actual data storage to MatPtrT.
+// TODO: use Extents2D instead of rows and cols.
 template <typename MatT>
 class MatStorageT : public MatPtrT<MatT> {
  public:
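A small self-contained illustration of the new Extents()/Stride() accessors (toy types, not the real MatPtr): stride_ is set to cols in the constructors, so Stride() == Cols() today, but callers that advance row pointers via Stride() keep working if padded layouts are introduced later.

```cpp
#include <cstddef>
#include <cstdio>

struct Extents2D {
  Extents2D(size_t r, size_t c) : rows(r), cols(c) {}
  size_t Area() const { return rows * cols; }
  size_t rows, cols;
};

// Toy view mirroring the MatPtr additions above.
struct View {
  View(float* p, Extents2D e) : ptr(p), extents(e), stride(e.cols) {}
  float* Row(size_t r) const { return ptr + r * stride; }
  size_t Cols() const { return extents.cols; }
  size_t Stride() const { return stride; }
  float* ptr;
  Extents2D extents;
  size_t stride;
};

int main() {
  float storage[6] = {0, 1, 2, 3, 4, 5};
  View v(storage, Extents2D(2, 3));
  std::printf("row1[0]=%g stride=%zu\n", *v.Row(1), v.Stride());  // 3, 3
}
```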


@@ -267,8 +267,12 @@ struct PackedSpan {
     // check the compressed count and ensure we have that many.
     const size_t required =
         CompressedArrayElements<Packed>(packed_ofs + num_accessible);
-    HWY_DASSERT(num >= required);
-    (void)required;
+    if constexpr (HWY_IS_DEBUG_BUILD) {
+      if (num < required) {
+        HWY_ABORT("PackedSpan: ofs %zu, want %zu, req %zu > %zu packed",
+                  packed_ofs, num_accessible, required, num);
+      }
+    }
   }

   Packed* HWY_RESTRICT ptr;
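The same idea expressed without Highway macros, as a rough sketch: in debug builds the check reports the offending sizes and aborts, in release builds it compiles to nothing. Variable names mirror the hunk above.

```cpp
#include <cstddef>
#include <cstdio>
#include <cstdlib>

// Sketch of the debug-only bounds check (assumed equivalent, not the real code).
inline void CheckPacked(size_t num, size_t packed_ofs, size_t num_accessible,
                        size_t required) {
#if !defined(NDEBUG)
  if (num < required) {
    std::fprintf(stderr,
                 "PackedSpan: ofs %zu, want %zu, req %zu > %zu packed\n",
                 packed_ofs, num_accessible, required, num);
    std::abort();
  }
#else
  (void)num; (void)packed_ofs; (void)num_accessible; (void)required;
#endif
}
```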


@@ -229,12 +229,12 @@ void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app,
   fprintf(stderr,
           "Date & Time                : %s"  // dt includes \n
           "CPU                        : %s\n"
-          "CPU topology               : %s\n"
+          "CPU topology               : %s, %s\n"
           "Instruction set            : %s (%zu bits)\n"
           "Compiled config            : %s\n"
           "Weight Type                : %s\n"
           "EmbedderInput Type         : %s\n",
-          dt, cpu100, pools.TopologyString(),
+          dt, cpu100, pools.TopologyString(), pools.PinString(),
           hwy::TargetName(hwy::DispatchedTarget()), hwy::VectorBytes() * 8,
           CompiledConfig(), StringFromType(loader.Info().weight),
           TypeName<EmbedderInputT>());


@@ -72,18 +72,11 @@ struct Activations {
   size_t seq_len;
   size_t cache_pos_size = 0;

-  // Multi-Head Attention?
-  bool IsMHA() const { return layer_config.heads == layer_config.kv_heads; }
-
-  // Stride between subsequent queries. Each of Q, K, V are of length kQKVDim,
-  // but for MHA we store them as Q,K,V, Q,K,V, .. instead of Q..Q, K..K, V..V.
-  size_t QStride() const { return layer_config.qkv_dim * (IsMHA() ? 3 : 1); }
-
   static RowVectorBatch<float> CreateInvTimescale(size_t qkv_dim,
                                                   PostQKType post_qk) {
     const size_t rope_dim =
         post_qk == PostQKType::HalfRope ? qkv_dim / 2 : qkv_dim;
-    RowVectorBatch<float> inv_timescale(1, rope_dim / 2);
+    RowVectorBatch<float> inv_timescale(Extents2D(1, rope_dim / 2));
     for (size_t dim = 0; dim < rope_dim / 2; ++dim) {
       const float freq_exponents =
           static_cast<float>(2 * dim) / static_cast<float>(rope_dim);
@@ -100,29 +93,31 @@ struct Activations {
     const size_t ff_hidden_dim = layer_config.ff_hidden_dim;
     const size_t vocab_size = weights_config.vocab_size;

-    x = RowVectorBatch<float>(batch_size, model_dim);
-    q = RowVectorBatch<float>(batch_size, layer_config.heads * QStride());
+    x = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
+    q = RowVectorBatch<float>(
+        Extents2D(batch_size, layer_config.heads * layer_config.QStride()));
     if (vocab_size > 0) {
-      logits = RowVectorBatch<float>(batch_size, vocab_size);
+      logits = RowVectorBatch<float>(Extents2D(batch_size, vocab_size));
     }

-    pre_att_rms_out = RowVectorBatch<float>(batch_size, model_dim);
-    att = RowVectorBatch<float>(batch_size,
-                                layer_config.heads * weights_config.seq_len);
-    att_out = RowVectorBatch<float>(batch_size,
-                                    layer_config.heads * layer_config.qkv_dim);
-    att_sums = RowVectorBatch<float>(batch_size, model_dim);
+    pre_att_rms_out = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
+    att = RowVectorBatch<float>(
+        Extents2D(batch_size, layer_config.heads * weights_config.seq_len));
+    att_out = RowVectorBatch<float>(
+        Extents2D(batch_size, layer_config.heads * layer_config.qkv_dim));
+    att_sums = RowVectorBatch<float>(Extents2D(batch_size, model_dim));

-    bf_pre_ffw_rms_out = RowVectorBatch<BF16>(batch_size, model_dim);
-    C1 = RowVectorBatch<float>(batch_size, ff_hidden_dim);
-    C2 = RowVectorBatch<float>(batch_size, ff_hidden_dim);
-    ffw_out = RowVectorBatch<float>(batch_size, model_dim);
+    bf_pre_ffw_rms_out = RowVectorBatch<BF16>(Extents2D(batch_size, model_dim));
+    C1 = RowVectorBatch<float>(Extents2D(batch_size, ff_hidden_dim));
+    C2 = RowVectorBatch<float>(Extents2D(batch_size, ff_hidden_dim));
+    ffw_out = RowVectorBatch<float>(Extents2D(batch_size, model_dim));

     if (layer_config.type == LayerAttentionType::kGriffinRecurrentBlock) {
-      griffin_x = RowVectorBatch<float>(batch_size, model_dim);
-      griffin_y = RowVectorBatch<float>(batch_size, model_dim);
-      griffin_gate_x = RowVectorBatch<float>(batch_size, model_dim);
-      griffin_multiplier = RowVectorBatch<float>(batch_size, model_dim);
+      griffin_x = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
+      griffin_y = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
+      griffin_gate_x = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
+      griffin_multiplier =
+          RowVectorBatch<float>(Extents2D(batch_size, model_dim));
     }

     inv_timescale = CreateInvTimescale(layer_config.qkv_dim, post_qk);
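RowVectorBatch now takes its shape as a single Extents2D. A simplified stand-in (the real class lives in util/basics.h and uses aligned allocation) showing the constructor and the Len()→Cols() rename used elsewhere in this change:

```cpp
#include <cstddef>
#include <vector>

struct Extents2D {
  size_t rows, cols;
};

// Simplified RowVectorBatch: row-major storage, Batch(i) returns row i.
template <typename T>
class RowVectorBatch {
 public:
  RowVectorBatch() = default;
  explicit RowVectorBatch(Extents2D extents)
      : extents_(extents), data_(extents.rows * extents.cols) {}
  T* Batch(size_t row) { return data_.data() + row * extents_.cols; }
  size_t BatchSize() const { return extents_.rows; }
  size_t Cols() const { return extents_.cols; }  // was Len() before this change

 private:
  Extents2D extents_ = {0, 0};
  std::vector<T> data_;
};

// Usage mirroring Activations::Allocate:
//   x = RowVectorBatch<float>(Extents2D{batch_size, model_dim});
```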


@@ -119,6 +119,13 @@ enum class Model {
 struct LayerConfig {
   size_t CacheLayerSize() const { return kv_heads * qkv_dim * 2; }

+  // Multi-Head Attention?
+  bool IsMHA() const { return heads == kv_heads; }
+
+  // Stride between subsequent queries. Each of Q, K, V are of length kQKVDim,
+  // but for MHA we store them as Q,K,V, Q,K,V, .. instead of Q..Q, K..K, V..V.
+  size_t QStride() const { return qkv_dim * (IsMHA() ? 3 : 1); }
+
   size_t model_dim = 0;
   size_t griffin_dim = 0;
   size_t ff_hidden_dim = 0;
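QStride() encodes the Q/K/V packing described in the comment: with MHA each head's chunk holds Q, K, V back to back, otherwise only Q. A small worked example under assumed dimensions:

```cpp
#include <cstddef>
#include <cstdio>

struct LayerConfig {
  size_t heads, kv_heads, qkv_dim;
  bool IsMHA() const { return heads == kv_heads; }
  size_t QStride() const { return qkv_dim * (IsMHA() ? 3 : 1); }
};

int main() {
  // MHA: each head's chunk is [Q | K | V], so the row holding all heads has
  // heads * 3 * qkv_dim floats and K for head h starts at h*QStride()+qkv_dim.
  LayerConfig mha{8, 8, 256};
  // GQA-style (kv_heads < heads): the row only holds Q; K/V go to the KV cache.
  LayerConfig gqa{8, 1, 256};
  std::printf("MHA q row: %zu, GQA q row: %zu\n",
              mha.heads * mha.QStride(), gqa.heads * gqa.QStride());
  // Prints: MHA q row: 6144, GQA q row: 2048
}
```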


@@ -20,9 +20,9 @@
 #include <stdio.h>

 #include <algorithm>  // std::min
-#include <type_traits>
 #include <vector>

+#include "compression/compress.h"
 #include "gemma/activations.h"
 #include "gemma/common.h"
 #include "gemma/configs.h"
@@ -31,6 +31,7 @@
 // Placeholder for internal test4, do not remove
 #include "paligemma/image.h"
 #include "util/allocator.h"
+#include "util/basics.h"
 #include "util/threading.h"
 #include "hwy/aligned_allocator.h"
 #include "hwy/base.h"
@@ -232,49 +233,49 @@ class GemmaAttention {
   // KV directly to KVCache.
   HWY_NOINLINE void ComputeQKV(const size_t num_interleaved) {
     PROFILER_ZONE("Gen.Attention.QKV");
-    // For the computation of Q, K, and V, it is useful to remember that
-    // qkv_einsum_w has shape [(layer_config_.heads + layer_config_.kv_heads *
-    // 2), kKQVDim, layer_config_.model_dim] and q_stride_ =
-    // layer_config_.qkv_dim * (is_mha_ ? 3 : 1);
+    const size_t model_dim = layer_config_.model_dim;
+    const size_t qkv_dim = layer_config_.qkv_dim;
+    const size_t heads = layer_config_.heads;
+    const size_t kv_heads = layer_config_.kv_heads;

     const auto pre_att_rms_out =
-        ConstMat(activations_.pre_att_rms_out.All(), layer_config_.model_dim);
-    const auto w_q1 = layer_weights_.qkv_einsum_w.data() == nullptr
-                          ? ConstMat(layer_weights_.qkv_einsum_w1.data(),
-                                     layer_config_.model_dim)
-                          : ConstMat(layer_weights_.qkv_einsum_w.data(),
-                                     layer_config_.model_dim);
-    const auto w_q2 =
-        layer_weights_.qkv_einsum_w.data() == nullptr
-            ? ConstMat(layer_weights_.qkv_einsum_w2.data(),
-                       layer_config_.model_dim)
-            : ConstMat(layer_weights_.qkv_einsum_w.data(),
-                       layer_config_.model_dim, layer_config_.model_dim,
-                       layer_config_.heads * layer_config_.qkv_dim *
-                           layer_config_.model_dim);
-    MatMul</*kAdd=*/false>(
-        num_interleaved, pre_att_rms_out, w_q1,
-        layer_weights_.qkv_einsum_w.scale(), /*add=*/nullptr, activations_.env,
-        MutableMat(activations_.q.All(), layer_config_.heads * q_stride_));
+        ConstMatFromBatch(num_interleaved, activations_.pre_att_rms_out);
+    auto w_q1 = layer_weights_.qkv_einsum_w.data()
+                    ? ConstMatFromWeights(layer_weights_.qkv_einsum_w)
+                    : ConstMatFromWeights(layer_weights_.qkv_einsum_w1);
+    // The original qkv_einsum_w has shape [(heads + kv_heads * 2), kKQVDim,
+    // model_dim], which we reshaped to (heads + kv_heads * 2) * kKQVDim rows.
+    // We must shrink to the actual size because MatMul verifies
+    // `B.extents.rows == C.Cols()`. If MHA, `QStride() == 3 * qkv_dim` and all
+    // rows are used. Otherwise, `QStride() == qkv_dim` and KV will be
+    // computed in the second MatMul.
+    const size_t w1_rows = heads * layer_config_.QStride();
+    w_q1.ShrinkRows(w1_rows);
+    MatMul(pre_att_rms_out, w_q1,
+           /*add=*/nullptr, activations_.env, RowPtrFromBatch(activations_.q));

     if (is_mha_) {
       // Multi-Head Attention a.k.a. "use_qkv_einsum" computed QKV already.
     } else {
+      auto w_q2 = layer_weights_.qkv_einsum_w.data()
+                      ? ConstMatFromWeights(layer_weights_.qkv_einsum_w,
+                                            w1_rows * model_dim)
+                      : ConstMatFromWeights(layer_weights_.qkv_einsum_w2);
+      // KV structure is [k, v, k, v, ....] = kv_heads pairs of (k, v).
+      const size_t w_rows_kv_cols = kv_heads * 2 * qkv_dim;
+      w_q2.ShrinkRows(w_rows_kv_cols);
       // Single query and no wraparound means we can use a matmul and write
       // directly into the KV cache with a stride of cache_pos_size_.
       if (num_queries_ == 1 &&
           queries_pos_[0] + num_tokens_ <= div_seq_len_.GetDivisor()) {
         const size_t kv_ofs =
             queries_pos_[0] * cache_pos_size_ + layer_ * cache_layer_size_;
-        // KV structure is [k, v, k, v, ....] = layer_config_.kv_heads pairs of
-        // (k, v).
         float* HWY_RESTRICT kv = kv_caches_[0].kv_cache.get() + kv_ofs;
-        MatMul</*kAdd=*/false>(
-            num_tokens_, pre_att_rms_out, w_q2,
-            layer_weights_.qkv_einsum_w.scale(), /*add=*/nullptr,
-            activations_.env,
-            MutableMat(kv, layer_config_.kv_heads * 2 * layer_config_.qkv_dim,
-                       cache_pos_size_));
+        RowPtrF kv_rows(kv, w_rows_kv_cols);
+        kv_rows.SetStride(cache_pos_size_);
+        MatMul(pre_att_rms_out, w_q2,
+               /*add=*/nullptr, activations_.env, kv_rows);
       } else {
         // Proceed row by row because there will be wraparound.
         for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
@@ -288,40 +289,34 @@ class GemmaAttention {
           const size_t kv_offset =
               cache_pos * cache_pos_size_ + layer_ * cache_layer_size_;
           float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
-          // KV structure is [k, v, k, v, ....] = layer_config_.kv_heads pairs
-          // of (k, v).
-          if (layer_weights_.qkv_einsum_w.data() == nullptr) {
-            MatVec(layer_weights_.qkv_einsum_w2, 0,
-                   layer_config_.kv_heads * 2 * layer_config_.qkv_dim,
-                   layer_config_.model_dim, x, kv, pool_);
-          } else {
-            MatVec(layer_weights_.qkv_einsum_w,
-                   layer_config_.heads * layer_config_.qkv_dim *
-                       layer_config_.model_dim,
-                   layer_config_.kv_heads * 2 * layer_config_.qkv_dim,
-                   layer_config_.model_dim, x, kv, pool_);
+          if (layer_weights_.qkv_einsum_w.data()) {
+            MatVec(layer_weights_.qkv_einsum_w, heads * qkv_dim * model_dim,
+                   w_rows_kv_cols, model_dim, x, kv, pool_);
+          } else {
+            MatVec(layer_weights_.qkv_einsum_w2, 0,  //
+                   w_rows_kv_cols, model_dim, x, kv, pool_);
           }
         }
       }
-    }
+    }  // !is_mha_

     // Apply positional encodings for K (and copy KV to cache if MHA).
-    pool_.Run(0, layer_config_.kv_heads * num_interleaved,
+    pool_.Run(0, kv_heads * num_interleaved,
               [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
-                const size_t head = task % layer_config_.kv_heads;
-                const size_t interleaved_idx = task / layer_config_.kv_heads;
+                const size_t head = task % kv_heads;
+                const size_t interleaved_idx = task / kv_heads;
                 const size_t query_idx = interleaved_idx % num_queries_;
                 const size_t batch_idx = interleaved_idx / num_queries_;
                 const size_t pos = queries_pos_[query_idx] + batch_idx;
                 const size_t cache_pos = div_seq_len_.Remainder(pos);
                 const size_t kv_offset = cache_pos * cache_pos_size_ +
                                          layer_ * cache_layer_size_ +
-                                         head * layer_config_.qkv_dim * 2;
+                                         head * qkv_dim * 2;
                 KVCache& kv_cache = kv_caches_[query_idx];
                 float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
                 const float* HWY_RESTRICT mha_kv =
                     activations_.q.Batch(interleaved_idx) + head * q_stride_ +
-                    layer_config_.qkv_dim;
+                    qkv_dim;

                 // Copy from `q` if MHA, or apply in-place.
                 PositionalEncodingQK(is_mha_ ? mha_kv : kv, pos, layer_, 1.0f,
@@ -329,9 +324,8 @@ class GemmaAttention {
                 // If MHA, also copy V into KVCache.
                 if (is_mha_) {
-                  hwy::CopyBytes(mha_kv + layer_config_.qkv_dim,
-                                 kv + layer_config_.qkv_dim,
-                                 layer_config_.qkv_dim * sizeof(*kv));
+                  hwy::CopyBytes(mha_kv + qkv_dim, kv + qkv_dim,
+                                 qkv_dim * sizeof(*kv));
                 }
               });
   }
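The single-query fast path above wraps the KV cache in a RowPtrF and overrides its stride with cache_pos_size_, so MatMul writes row t of its output at cache position t. A toy stand-in of that idea (field and method names are assumptions; the real RowPtr lives in util/basics.h):

```cpp
#include <cstddef>
#include <vector>

// Minimal RowPtr: cols consecutive floats per row, rows separated by stride.
struct RowPtrF {
  RowPtrF(float* p, size_t c) : ptr(p), cols(c), stride(c) {}
  void SetStride(size_t s) { stride = s; }
  float* Row(size_t r) const { return ptr + r * stride; }
  float* ptr;
  size_t cols, stride;
};

int main() {
  const size_t cache_pos_size = 512;  // floats per cached position (all layers)
  const size_t w_rows_kv_cols = 128;  // kv_heads * 2 * qkv_dim for this layer
  std::vector<float> kv_cache(4 * cache_pos_size);

  // Writing "output row t" at Row(t) places token t's K/V block at cache
  // position t, even though only w_rows_kv_cols floats are written per row.
  RowPtrF kv_rows(kv_cache.data(), w_rows_kv_cols);
  kv_rows.SetStride(cache_pos_size);
  for (size_t t = 0; t < 4; ++t) {
    for (size_t i = 0; i < w_rows_kv_cols; ++i) kv_rows.Row(t)[i] = float(t);
  }
}
```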
@@ -463,27 +457,14 @@ class GemmaAttention {
     HWY_DASSERT(layer_weights_.att_weights.data() != nullptr);
     HWY_DASSERT(activations_.att_out.All() != nullptr);
     HWY_DASSERT(activations_.att_sums.All() != nullptr);
-    if (layer_weights_.layer_config.softmax_attn_output_biases) {
-      MatMul</*kAdd=*/true>(
-          num_interleaved,
-          ConstMat(activations_.att_out.All(),
-                   layer_config_.heads * layer_config_.qkv_dim),
-          ConstMat(layer_weights_.att_weights.data(),
-                   layer_config_.heads * layer_config_.qkv_dim),
-          layer_weights_.att_weights.scale(),
-          layer_weights_.attention_output_biases.data_scale1(),
-          activations_.env,
-          MutableMat(activations_.att_sums.All(), layer_config_.model_dim));
-    } else {
-      MatMul</*kAdd=*/false>(
-          num_interleaved,
-          ConstMat(activations_.att_out.All(),
-                   layer_config_.heads * layer_config_.qkv_dim),
-          ConstMat(layer_weights_.att_weights.data(),
-                   layer_config_.heads * layer_config_.qkv_dim),
-          layer_weights_.att_weights.scale(), nullptr, activations_.env,
-          MutableMat(activations_.att_sums.All(), layer_config_.model_dim));
-    }
+
+    const float* add =
+        layer_weights_.layer_config.softmax_attn_output_biases
+            ? layer_weights_.attention_output_biases.data_scale1()
+            : nullptr;
+    MatMul(ConstMatFromBatch(num_interleaved, activations_.att_out),
+           ConstMatFromWeights(layer_weights_.att_weights), add,
+           activations_.env, RowPtrFromBatch(activations_.att_sums));
   }

  public:
@@ -524,13 +505,13 @@ class GemmaAttention {
         num_queries_(queries_pos.size()),
         num_tokens_(num_tokens),
         layer_(layer),
-        q_stride_(activations.QStride()),
+        layer_config_(layer_weights->layer_config),
+        q_stride_(layer_config_.QStride()),
         cache_layer_size_(layer_weights->layer_config.CacheLayerSize()),
         cache_pos_size_(activations.cache_pos_size),
-        is_mha_(activations.IsMHA()),
+        is_mha_(layer_config_.IsMHA()),
         activations_(activations),
         layer_weights_(*layer_weights),
-        layer_config_(layer_weights->layer_config),
         div_seq_len_(div_seq_len),
         kv_caches_(kv_caches),
         pool_(activations.env.Pool()) {
@@ -552,6 +533,7 @@ class GemmaAttention {
   const size_t num_queries_;
   const size_t num_tokens_;
   const size_t layer_;
+  const LayerConfig& layer_config_;
   const size_t q_stride_ = 0;
   const size_t cache_layer_size_ = 0;
   const size_t cache_pos_size_ = 0;
@@ -559,7 +541,6 @@ class GemmaAttention {
   Activations& activations_;
   const LayerWeightsPtrs<T>& layer_weights_;
-  const LayerConfig& layer_config_;
   const hwy::Divisor& div_seq_len_;
   const KVCaches& kv_caches_;
   hwy::ThreadPool& pool_;
@@ -601,17 +582,13 @@ class VitAttention {
   // Computes Q, K, V for all heads, stored in activations_.q.
   HWY_NOINLINE void ComputeQKV() {
     PROFILER_ZONE("Gen.VitAttention.QKV");
-    const auto y =
-        ConstMat(activations_.pre_att_rms_out.All(), layer_config_.model_dim);
     auto& qkv = activations_.q;
     HWY_ASSERT(qkv.BatchSize() == num_tokens_);
-    HWY_ASSERT(qkv.Len() == layer_config_.heads * 3 * layer_config_.qkv_dim);
-    MatMul</*kAdd=*/true>(
-        num_tokens_, y,
-        ConstMat(layer_weights_.vit.qkv_einsum_w.data_scale1(),
-                 layer_config_.model_dim),
-        /*scale=*/1.0f, layer_weights_.vit.qkv_einsum_b.data_scale1(),
-        activations_.env, MutableMat(qkv.All(), qkv.Len()));
+    HWY_ASSERT(qkv.Cols() == layer_config_.heads * 3 * layer_config_.qkv_dim);
+    MatMul(ConstMatFromBatch(num_tokens_, activations_.pre_att_rms_out),
+           ConstMatFromWeights(layer_weights_.vit.qkv_einsum_w),
+           layer_weights_.vit.qkv_einsum_b.data_scale1(), activations_.env,
+           RowPtrFromBatch(qkv));
   }

   HWY_NOINLINE void DotSoftmaxWeightedSum() {
@@ -658,17 +635,13 @@ class VitAttention {
   HWY_NOINLINE void SumHeads() {
     PROFILER_ZONE("Gen.VitAttention.SumHeads");
     auto* bias = layer_weights_.vit.attn_out_b.data_scale1();
-    auto att_out = ConstMat(activations_.att_out.All(),
-                            layer_config_.heads * layer_config_.qkv_dim);
-    auto att_weights = ConstMat(layer_weights_.vit.attn_out_w.data_scale1(),
-                                layer_config_.heads * layer_config_.qkv_dim);
-    auto att_sums =
-        MutableMat(activations_.att_sums.All(), layer_config_.model_dim);
     // att_weights and att_out are concatenated heads, each of length
     // layer_config_.qkv_dim. Thus the [num_tokens_, layer_config_.model_dim]
     // matmul output is the sum over heads.
-    MatMul</*kAdd=*/true>(num_tokens_, att_out, att_weights, /*scale=*/1.0f,
-                          bias, activations_.env, att_sums);
+    auto att_out = ConstMatFromBatch(num_tokens_, activations_.att_out);
+    auto att_weights = ConstMatFromWeights(layer_weights_.vit.attn_out_w);
+    auto att_sums = RowPtrFromBatch(activations_.att_sums);
+    MatMul(att_out, att_weights, bias, activations_.env, att_sums);
   }

  public:
@@ -720,125 +693,94 @@ HWY_NOINLINE void FFWNoVit(Activations& activations, size_t num_interleaved,
   PROFILER_ZONE("Gen.FFW");
   const size_t model_dim = layer_weights->layer_config.model_dim;
   const size_t ffh_hidden_dim = layer_weights->layer_config.ff_hidden_dim;
-  const bool add_bias = layer_weights->layer_config.ff_biases;
   using WeightType = T;
   HWY_DASSERT(num_interleaved <= activations.bf_pre_ffw_rms_out.BatchSize());

-  // Define slightly more readable names for the weights and activations.
-  const auto x = ConstMat(activations.bf_pre_ffw_rms_out.All(), model_dim);
-  Mat<const WeightType> w1;
-  const float* bias1 = nullptr;
-  Mat<const WeightType> w2;
-  const float* bias2 = nullptr;
-  float scale = 1.0f;
-  Mat<const WeightType> w_output;
-  const float* output_bias = nullptr;
-  float output_scale = 1.0f;
-  auto hidden_activations = MutableMat(activations.C1.All(), ffh_hidden_dim);
-  auto multiplier = MutableMat(activations.C2.All(), ffh_hidden_dim);
-  auto ffw_out = MutableMat(activations.ffw_out.All(), model_dim);
-
-  // For some of the weights and activations, it depends on the config where to
-  // get them from or whether to use them at all.
-  bias1 = layer_weights->ffw_gating_biases.data_scale1();
-  bias2 = bias1 + ffh_hidden_dim;
-  output_bias = layer_weights->ffw_output_biases.data_scale1();
-  w1 = layer_weights->gating_einsum_w.data() == nullptr
-           ? ConstMat(layer_weights->gating_einsum_w1.data(), model_dim)
-           : ConstMat(layer_weights->gating_einsum_w.data(), model_dim);
-  w2 = layer_weights->gating_einsum_w.data() == nullptr
-           ? ConstMat(layer_weights->gating_einsum_w2.data(), model_dim)
-           : ConstMat(layer_weights->gating_einsum_w.data(), model_dim,
-                      model_dim, model_dim * ffh_hidden_dim);
-  scale = layer_weights->gating_einsum_w.data() == nullptr
-              ? layer_weights->gating_einsum_w1.scale()
-              : layer_weights->gating_einsum_w.scale();
-  w_output = ConstMat(layer_weights->linear_w.data(), ffh_hidden_dim);
-  output_scale = layer_weights->linear_w.scale();
+  const bool add_bias = layer_weights->layer_config.ff_biases;
+  const float* bias1 =
+      add_bias ? layer_weights->ffw_gating_biases.data_scale1() : nullptr;
+  const float* bias2 = add_bias ? bias1 + ffh_hidden_dim : nullptr;
+  const float* output_bias =
+      add_bias ? layer_weights->ffw_output_biases.data_scale1() : nullptr;
+
+  // Define slightly more readable names for the weights and activations.
+  const auto x =
+      ConstMatFromBatch(num_interleaved, activations.bf_pre_ffw_rms_out);
+
+  auto hidden_activations = RowPtrFromBatch(activations.C1);
+  auto multiplier = RowPtrFromBatch(activations.C2);
+  auto ffw_out = RowPtrFromBatch(activations.ffw_out);
+
+  // gating_einsum_w holds two half-matrices. We plan to change the importer to
+  // avoid this confusion by splitting into gating_einsum_w1 and
+  // gating_einsum_w2.
+  const bool split = !!layer_weights->gating_einsum_w.data();
+  auto w1 = split ? ConstMatFromWeights(layer_weights->gating_einsum_w)
+                  : ConstMatFromWeights(layer_weights->gating_einsum_w1);
+  auto w2 = split ? ConstMatFromWeights(layer_weights->gating_einsum_w,
+                                        model_dim * ffh_hidden_dim)
+                  : ConstMatFromWeights(layer_weights->gating_einsum_w2);
+  if (split) {
+    // Ensure that B.Extents().row matches C.Cols() because MatMul checks that.
+    w1.ShrinkRows(ffh_hidden_dim);
+    w2.ShrinkRows(ffh_hidden_dim);
+  }
+  auto w_output = ConstMatFromWeights(layer_weights->linear_w);

   // Compute the hidden layer activations.
-  if (add_bias) {
-    MatMul</*kAddBias=*/true>(num_interleaved, x, w1, scale, bias1,
-                              activations.env, hidden_activations);
-    MatMul</*kAddBias=*/true>(num_interleaved, x, w2, scale, bias2,
-                              activations.env, multiplier);
-  } else {
-    MatMul</*kAddBias=*/false>(num_interleaved, x, w1, scale, bias1,
-                               activations.env, hidden_activations);
-    MatMul</*kAddBias=*/false>(num_interleaved, x, w2, scale, bias2,
-                               activations.env, multiplier);
-  }
+  MatMul(x, w1, bias1, activations.env, hidden_activations);
+  MatMul(x, w2, bias2, activations.env, multiplier);

   // Activation (Gelu) and maybe multiply by gate. Store activations in act.
-  Activation(layer_weights->layer_config.activation, hidden_activations.ptr,
-             multiplier.ptr, ffh_hidden_dim * num_interleaved);
+  Activation(layer_weights->layer_config.activation, hidden_activations.Row(0),
+             multiplier.Row(0), ffh_hidden_dim * num_interleaved);

   // Hidden layer -> output layer.
-  if (add_bias) {
-    MatMul</*kAddBias=*/true>(num_interleaved, ConstMat(hidden_activations),
-                              w_output, output_scale, output_bias,
-                              activations.env, ffw_out);
-  } else {
-    MatMul</*kAddBias=*/false>(num_interleaved, ConstMat(hidden_activations),
-                               w_output, output_scale, output_bias,
-                               activations.env, ffw_out);
-  }
+  auto activations_mat = MakeConstMat(
+      hidden_activations.Row(0), Extents2D(num_interleaved, ffh_hidden_dim));
+
+  MatMul(activations_mat, w_output, output_bias, activations.env, ffw_out);
 }

+// Same as FFWNoVit, but with different layer_weights members and no second
+// gating matrix.
 template <typename T>
 HWY_NOINLINE void FFWVit(Activations& activations, size_t num_interleaved,
                          const LayerWeightsPtrs<T>* layer_weights) {
   PROFILER_ZONE("Gen.FFW");
-  const size_t model_dim = layer_weights->layer_config.model_dim;
   const size_t ff_hidden_dim = layer_weights->layer_config.ff_hidden_dim;
-  const bool add_bias = layer_weights->layer_config.ff_biases;
   using WeightType = typename LayerWeightsPtrs<T>::WeightF32OrBF16;
   HWY_DASSERT(num_interleaved <= activations.bf_pre_ffw_rms_out.BatchSize());

-  // Define slightly more readable names for the weights and activations.
-  const auto x = ConstMat(activations.bf_pre_ffw_rms_out.All(), model_dim);
-  Mat<const WeightType> w1;
-  const float* bias1 = nullptr;
-  float scale = 1.0f;
-  Mat<const WeightType> w_output;
-  const float* output_bias = nullptr;
-  float output_scale = 1.0f;
-  auto hidden_activations = MutableMat(activations.C1.All(), ff_hidden_dim);
-  auto multiplier = MutableMat(activations.C2.All(), ff_hidden_dim);
-  auto ffw_out = MutableMat(activations.ffw_out.All(), model_dim);
-
-  // For some of the weights and activations, it depends on the config where to
-  // get them from or whether to use them at all.
-  w1 = ConstMat(layer_weights->vit.linear_0_w.data_scale1(), model_dim);
-  bias1 = layer_weights->vit.linear_0_b.data_scale1();
-  multiplier.ptr = nullptr;
-  w_output =
-      ConstMat(layer_weights->vit.linear_1_w.data_scale1(), ff_hidden_dim);
-  output_bias = layer_weights->vit.linear_1_b.data_scale1();
+  const bool add_bias = layer_weights->layer_config.ff_biases;
+  const float* bias1 =
+      add_bias ? layer_weights->vit.linear_0_b.data_scale1() : nullptr;
+  const float* output_bias =
+      add_bias ? layer_weights->vit.linear_1_b.data_scale1() : nullptr;
+
+  // Define slightly more readable names for the weights and activations.
+  const auto x =
+      ConstMatFromBatch(num_interleaved, activations.bf_pre_ffw_rms_out);
+
+  auto hidden_activations = RowPtrFromBatch(activations.C1);
+  auto ffw_out = RowPtrFromBatch(activations.ffw_out);
+
+  auto w1 = ConstMatFromWeights(layer_weights->vit.linear_0_w);
+  auto w_output = ConstMatFromWeights(layer_weights->vit.linear_1_w);

   // Compute the hidden layer activations.
-  if (add_bias) {
-    MatMul</*kAddBias=*/true>(num_interleaved, x, w1, scale, bias1,
-                              activations.env, hidden_activations);
-  } else {
-    MatMul</*kAddBias=*/false>(num_interleaved, x, w1, scale, bias1,
-                               activations.env, hidden_activations);
-  }
+  MatMul(x, w1, bias1, activations.env, hidden_activations);

-  // Activation (Gelu) and maybe multiply by gate. Store activations in act.
-  Activation(layer_weights->layer_config.activation, hidden_activations.ptr,
-             multiplier.ptr, ff_hidden_dim * num_interleaved);
+  // Activation (Gelu), store in act.
+  RowPtrF multiplier = RowPtrF(nullptr, 0);
+  Activation(layer_weights->layer_config.activation, hidden_activations.Row(0),
+             multiplier.Row(0), ff_hidden_dim * num_interleaved);

   // Hidden layer -> output layer.
-  if (add_bias) {
-    MatMul</*kAddBias=*/true>(num_interleaved, ConstMat(hidden_activations),
-                              w_output, output_scale, output_bias,
-                              activations.env, ffw_out);
-  } else {
-    MatMul</*kAddBias=*/false>(num_interleaved, ConstMat(hidden_activations),
-                               w_output, output_scale, output_bias,
-                               activations.env, ffw_out);
-  }
+  auto activations_mat = MakeConstMat(
+      hidden_activations.Row(0), Extents2D(num_interleaved, ff_hidden_dim));
+
+  MatMul(activations_mat, w_output, output_bias, activations.env, ffw_out);
 }
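When the fused gating_einsum_w is present, the code above views the same buffer twice with different offsets and then shrinks each view to ff_hidden_dim rows so that MatMul's `B.Extents().rows == C.Cols()` check passes. A toy sketch of that slicing with a hypothetical View type:

```cpp
#include <cstddef>
#include <vector>

// Hypothetical row-major view over a weight buffer: ofs selects where the
// view starts, rows/cols give its logical shape.
struct View {
  const float* ptr;
  size_t ofs, rows, cols;
  void ShrinkRows(size_t new_rows) { if (new_rows < rows) rows = new_rows; }
  const float* Row(size_t r) const { return ptr + ofs + r * cols; }
};

int main() {
  const size_t model_dim = 4, ff_hidden_dim = 3;
  // Fused gating tensor: w1 (ff_hidden_dim rows) followed by w2.
  std::vector<float> fused(2 * ff_hidden_dim * model_dim);

  View w1{fused.data(), /*ofs=*/0, 2 * ff_hidden_dim, model_dim};
  View w2{fused.data(), /*ofs=*/ff_hidden_dim * model_dim, 2 * ff_hidden_dim,
          model_dim};
  // Both views start out covering too many rows; shrink to what MatMul needs.
  w1.ShrinkRows(ff_hidden_dim);
  w2.ShrinkRows(ff_hidden_dim);
  return int(w1.rows + w2.rows);  // 6
}
```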
 // `batch_idx` indicates which row of `x` to write to.
@@ -853,7 +795,7 @@ HWY_NOINLINE void EmbedToken(int token, size_t batch_idx, size_t pos,
   // Image tokens just need to be copied.
   if (image_tokens != nullptr && pos_in_prompt < image_tokens->BatchSize()) {
     hwy::CopyBytes(image_tokens->Batch(pos_in_prompt), x.Batch(batch_idx),
-                   x.Len() * sizeof(x.Const()[0]));
+                   x.Cols() * sizeof(x.Const()[0]));
     return;
   }
@@ -942,7 +884,7 @@ HWY_NOINLINE void TransformerLayer(const QueriesPos& queries_pos,
 // the Big Vision codebase. See
 // github.com/google-research/big_vision/blob/main/big_vision/models/vit.py
 // TODO(keysers): consider adding a wrapper for both LayerNorm with RMSNorm and
-// try mergig this with TransformerLayer.
+// try merging this with TransformerLayer.
 template <typename T>
 HWY_NOINLINE void VitTransformerLayer(size_t num_tokens, size_t layer,
                                       const LayerWeightsPtrs<T>* layer_weights,
@@ -953,7 +895,7 @@ HWY_NOINLINE void VitTransformerLayer(size_t num_tokens, size_t layer,
   auto& x = activations.x;
   HWY_DASSERT(x.BatchSize() == num_tokens);
-  HWY_DASSERT(x.Len() == model_dim);
+  HWY_DASSERT(x.Cols() == model_dim);

   // y = nn.LayerNorm()(x)
   // y ~ pre_att_rms_out
@@ -1106,7 +1048,7 @@ HWY_NOINLINE void EmbedImagePatches(const Image& image,
   const size_t patch_size = patch_width * patch_width * 3;
   HWY_DASSERT(weights.vit_img_embedding_kernel.NumElements() ==
               patch_size * model_dim);
-  HWY_DASSERT(activations.x.Len() == model_dim);
+  HWY_DASSERT(activations.x.Cols() == model_dim);
   std::vector<hwy::AlignedFreeUniquePtr<float[]>> image_patches(seq_len);
   for (size_t i = 0; i < seq_len; ++i) {
     image_patches[i] = hwy::AllocateAligned<float>(patch_size);
@@ -1118,11 +1060,11 @@ HWY_NOINLINE void EmbedImagePatches(const Image& image,
   // This could be done as one MatMul like:
   // RowVectorBatch<float> image_patches(kSeqLen, kPatchSize);
   // [Get patches]
-  // MatMul</*kAdd=*/true>(
-  //     kVitSeqLen, ConstMat(image_patches.All(), kPatchSize),
-  //     ConstMat(weights.vit_img_embedding_kernel.data_scale1(), kPatchSize),
-  //     /*scale=*/1.0f, weights.vit_img_embedding_bias.data_scale1(),
-  //     activations.env, MutableMat(activations.x.All(), kVitModelDim));
+  // MatMul(
+  //     MatFromBatch(kVitSeqLen, image_patches),
+  //     MatFromWeights(weights.vit_img_embedding_kernel),
+  //     weights.vit_img_embedding_bias.data_scale1(), activations.env,
+  //     RowPtrF(activations.x.All(), kVitModelDim));
   // However, MatMul currently requires that
   // A.cols % (2 * hn::Lanes(hn::ScalableTag<MulT>())) == 0
   // which is not the case here. We should relax that requirement on MatMul and
@@ -1163,11 +1105,10 @@ HWY_NOINLINE void PrefillVit(const ModelWeightsPtrs<T>& weights,
                 activations.x.All(), vit_model_dim);
   // Apply head embedding into image_tokens of size of the LLM kModelDim.
-  MatMul</*kAdd=*/true>(
-      num_tokens, ConstMat(activations.x.All(), vit_model_dim),
-      ConstMat(weights.vit_img_head_kernel.data_scale1(), vit_model_dim),
-      /*scale=*/1.0f, weights.vit_img_head_bias.data_scale1(), activations.env,
-      MutableMat(image_tokens.All(), weights.weights_config.model_dim));
+  MatMul(ConstMatFromBatch(num_tokens, activations.x),
+         ConstMatFromWeights(weights.vit_img_head_kernel),
+         weights.vit_img_head_bias.data_scale1(), activations.env,
+         RowPtrFromBatch(image_tokens));
 }

 // Generates one token for each query. `queries_token` is the previous token
@@ -1299,7 +1240,6 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations,
                const QueriesPos& queries_prefix_end,
                const size_t query_idx_start, const KVCaches& kv_caches,
                TimingInfo& timing_info) {
-  const size_t model_dim = model.Config().model_dim;
   const size_t vocab_size = model.Config().vocab_size;
   const ModelWeightsPtrs<T>& weights = *model.GetWeightsOfType<T>();
@@ -1387,11 +1327,10 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations,
   {
     PROFILER_ZONE("Gen.EmbeddingMatmul");
     // Compute logits from last layer activations.
-    MatMul</*kAdd=*/false>(
-        num_queries, ConstMat(activations.x.All(), model_dim),
-        ConstMat(weights.embedder_input_embedding.data(), model_dim),
-        weights.embedder_input_embedding.scale(), /*add=*/nullptr,
-        activations.env, MutableMat(activations.logits.All(), vocab_size));
+    MatMul(ConstMatFromBatch(num_queries, activations.x),
+           ConstMatFromWeights(weights.embedder_input_embedding),
+           /*add=*/nullptr, activations.env,
+           RowPtrFromBatch(activations.logits));
   }
   PROFILER_ZONE("Gen.Softcap+Sample+Stream");
   for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {


@@ -35,7 +35,6 @@
 #include "util/threading.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 #include "hwy/highway.h"
-#include "hwy/profiler.h"  // also uses SIMD

 namespace gcpp {
@@ -119,12 +118,12 @@ struct GenerateImageTokensT {
 void Gemma::Generate(const RuntimeConfig& runtime_config,
                      const PromptTokens& prompt, size_t pos, size_t prefix_end,
                      KVCache& kv_cache, TimingInfo& timing_info) {
-  if (runtime_config.use_spinning) pools_.StartSpinning();
+  pools_.MaybeStartSpinning(runtime_config.use_spinning);

   model_.CallForModelWeight<GenerateSingleT>(
       runtime_config, prompt, pos, prefix_end, kv_cache, pools_, timing_info);

-  if (runtime_config.use_spinning) pools_.StopSpinning();
+  pools_.MaybeStopSpinning(runtime_config.use_spinning);
 }

 void Gemma::GenerateBatch(const RuntimeConfig& runtime_config,
@@ -141,23 +140,23 @@ void Gemma::GenerateBatch(const RuntimeConfig& runtime_config,
         QueriesPos(prefix_end_vec.data(), prefix_end_vec.size());
   }

-  if (runtime_config.use_spinning) pools_.StartSpinning();
+  pools_.MaybeStartSpinning(runtime_config.use_spinning);

   model_.CallForModelWeight<GenerateBatchT>(
       runtime_config, queries_prompt, queries_pos, mutable_queries_prefix_end,
       kv_caches, pools_, timing_info);

-  if (runtime_config.use_spinning) pools_.StopSpinning();
+  pools_.MaybeStopSpinning(runtime_config.use_spinning);
 }

 void Gemma::GenerateImageTokens(const RuntimeConfig& runtime_config,
                                 const Image& image, ImageTokens& image_tokens) {
-  if (runtime_config.use_spinning) pools_.StartSpinning();
+  pools_.MaybeStartSpinning(runtime_config.use_spinning);

   model_.CallForModelWeight<GenerateImageTokensT>(runtime_config, image,
                                                   image_tokens, pools_);

-  if (runtime_config.use_spinning) pools_.StopSpinning();
+  pools_.MaybeStopSpinning(runtime_config.use_spinning);
 }

 // Non-template functions moved from gemma-inl.h to avoid ODR violations.


@@ -121,7 +121,11 @@ struct RuntimeConfig {
   const ImageTokens *image_tokens = nullptr;

   // Whether to use thread spinning to reduce barrier synchronization latency.
-  bool use_spinning = true;
+  // Mutable so we can change kDefault to kTrue/kFalse during Generate, because
+  // RuntimeConfig is const there and is not passed to the Gemma ctor. This
+  // default decision is likely sufficient because it is based on whether
+  // threads are successfully pinned.
+  mutable Tristate use_spinning = Tristate::kDefault;

   // End-of-sequence token.
   int eos_id = EOS_ID;
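use_spinning changes from bool to a mutable Tristate so that Generate can resolve kDefault once and remember the decision. A sketch of the intended three-state logic; the enum values and the MaybeStartSpinning/MaybeStopSpinning names are taken from this diff, while the resolution rule shown here (spin only when pinning succeeded) is an assumption based on the comment above:

```cpp
enum class Tristate { kFalse, kTrue, kDefault };

// Sketch of NestedPools::MaybeStartSpinning-style logic: only kDefault is
// resolved here; an explicit kTrue/kFalse from the caller is respected.
class Pools {
 public:
  explicit Pools(bool pinned) : pinned_(pinned) {}

  void MaybeStartSpinning(Tristate& use_spinning) {
    if (use_spinning == Tristate::kDefault) {
      // Assumed heuristic: spinning only pays off when threads are pinned.
      use_spinning = pinned_ ? Tristate::kTrue : Tristate::kFalse;
    }
    if (use_spinning == Tristate::kTrue) spinning_ = true;
  }
  void MaybeStopSpinning(Tristate use_spinning) {
    if (use_spinning == Tristate::kTrue) spinning_ = false;
  }

 private:
  bool pinned_;
  bool spinning_ = false;
};

// Usage mirroring Gemma::Generate in this commit:
//   pools_.MaybeStartSpinning(runtime_config.use_spinning);
//   ... run generation ...
//   pools_.MaybeStopSpinning(runtime_config.use_spinning);
```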


@@ -16,7 +16,6 @@
 // Command line text interface to gemma.

 #include <iostream>
-#include <memory>
 #include <random>
 #include <string>
 #include <string_view>
@@ -79,8 +78,8 @@ std::string GetPrompt(std::istream& input, int verbosity,
 }

 // The main Read-Eval-Print Loop.
-void ReplGemma(Gemma& model, KVCache& kv_cache, const InferenceArgs& args,
-               int verbosity, const AcceptFunc& accept_token,
+void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
+               const InferenceArgs& args, const AcceptFunc& accept_token,
                std::string& eot_line) {
   PROFILER_ZONE("Gen.misc");
   size_t abs_pos = 0;  // across turns
@@ -92,17 +91,18 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const InferenceArgs& args,
   const bool have_image = !args.image_file.path.empty();
   Image image;
-  std::unique_ptr<ImageTokens> image_tokens;
+  ImageTokens image_tokens;
   if (have_image) {
-    image_tokens = std::make_unique<ImageTokens>(
-        model.GetModelConfig().vit_seq_len, model.GetModelConfig().model_dim);
+    image_tokens = ImageTokens(Extents2D(model.GetModelConfig().vit_seq_len,
+                                         model.GetModelConfig().model_dim));
     HWY_ASSERT(model.Info().training == ModelTraining::PALIGEMMA);
     HWY_ASSERT(image.ReadPPM(args.image_file.path));
     image.Resize();
-    RuntimeConfig runtime_config = {.verbosity = verbosity, .gen = &gen};
+    RuntimeConfig runtime_config = {
+        .verbosity = app.verbosity, .gen = &gen, .use_spinning = app.spin};
     double image_tokens_start = hwy::platform::Now();
-    model.GenerateImageTokens(runtime_config, image, *image_tokens);
-    if (verbosity >= 1) {
+    model.GenerateImageTokens(runtime_config, image, image_tokens);
+    if (app.verbosity >= 1) {
       double image_tokens_duration = hwy::platform::Now() - image_tokens_start;
       fprintf(stderr,
               "\n\n[ Timing info ] Image token generation took: %d ms\n",
@@ -122,7 +122,7 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const InferenceArgs& args,
       abs_pos = 0;
       InitGenerator(args, gen);
     }
-    if (verbosity >= 2) {
+    if (app.verbosity >= 2) {
       std::cout << "\n[ End ]\n";
     }
   } else {
@@ -133,7 +133,7 @@
     if (tokens_generated_this_turn == prompt_size + 1) {
       // first token of response
       token_text.erase(0, token_text.find_first_not_of(" \t\n"));
-      if (verbosity >= 1) {
+      if (app.verbosity >= 1) {
         std::cout << "\n\n";
       }
     }
@@ -144,7 +144,7 @@
   while (true) {  // Loop until user quits.
     tokens_generated_this_turn = 0;

-    std::string prompt_string = GetPrompt(std::cin, verbosity, eot_line);
+    std::string prompt_string = GetPrompt(std::cin, app.verbosity, eot_line);
     if (!std::cin) return;
     // If !eot_line.empty(), we append \n, so only look at the first 2 chars.
     if (prompt_string.size() >= 2 && prompt_string[0] == '%') {
@@ -171,18 +171,17 @@
       }
     }

-    TimingInfo timing_info = {.verbosity = verbosity};
-    RuntimeConfig runtime_config = {
-        .verbosity = verbosity,
-        .gen = &gen,
-        .stream_token = stream_token,
-        .accept_token = accept_token,
-    };
+    TimingInfo timing_info = {.verbosity = app.verbosity};
+    RuntimeConfig runtime_config = {.verbosity = app.verbosity,
+                                    .gen = &gen,
+                                    .stream_token = stream_token,
+                                    .accept_token = accept_token,
+                                    .use_spinning = app.spin};
     args.CopyTo(runtime_config);
     size_t prefix_end = 0;
     if (have_image) {
-      runtime_config.image_tokens = image_tokens.get();
-      prompt.insert(prompt.begin(), image_tokens->BatchSize(), 0);
+      runtime_config.image_tokens = &image_tokens;
+      prompt.insert(prompt.begin(), image_tokens.BatchSize(), 0);
       prompt_size = prompt.size();
       // The end of the prefix for prefix-LM style attention in Paligemma.
       // See Figure 2 of https://arxiv.org/abs/2407.07726.
@@ -237,8 +236,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
     std::cout << "\n" << instructions << "\n";
   }

-  ReplGemma(model, kv_cache, inference, app.verbosity, AcceptFunc(),
-            app.eot_line);
+  ReplGemma(model, kv_cache, app, inference, AcceptFunc(), app.eot_line);
 }

 }  // namespace gcpp


@@ -95,11 +95,11 @@ struct LayerWeightsPtrs {
                        config.model_dim},
         .qkv_einsum_b = {"qkv_ein_b", (config.heads + 2 * config.kv_heads),
                          config.qkv_dim},
-        .linear_0_w = {"linear_0_w", config.model_dim,
-                       config.ff_hidden_dim},
-        .linear_0_b = {"linear_0_b", 1, config.ff_hidden_dim},
-        .linear_1_w = {"linear_1_w", config.ff_hidden_dim,
+        .linear_0_w = {"linear_0_w", config.ff_hidden_dim,
                        config.model_dim},
+        .linear_0_b = {"linear_0_b", 1, config.ff_hidden_dim},
+        .linear_1_w = {"linear_1_w", config.model_dim,
+                       config.ff_hidden_dim},
         .linear_1_b = {"linear_1_b", 1, config.model_dim},
         .layer_norm_0_bias = {"ln_0_bias", 1, config.model_dim},
         .layer_norm_0_scale = {"ln_0_scale", 1, config.model_dim},
@@ -349,14 +349,13 @@ struct ModelWeightsPtrs {
         vit_encoder_norm_bias("enc_norm_bias", 1, config.vit_model_dim),
         vit_encoder_norm_scale("enc_norm_scale", 1, config.vit_model_dim),
         vit_img_embedding_bias("img_emb_bias", 1, config.vit_model_dim),
-        vit_img_embedding_kernel(
-            "img_emb_kernel",
-            config.patch_width * config.patch_width * 3,
-            config.vit_model_dim),
+        vit_img_embedding_kernel("img_emb_kernel",
+                                 config.patch_width * config.patch_width * 3,
+                                 config.vit_model_dim),
         vit_img_pos_embedding("img_pos_emb", 256, config.vit_model_dim),
         vit_img_head_bias("img_head_bias", 1, config.model_dim),
-        vit_img_head_kernel("img_head_kernel", config.vit_model_dim,
-                            config.model_dim),
+        vit_img_head_kernel("img_head_kernel", config.model_dim,
+                            config.vit_model_dim),
         scale_names(config.scale_names),
         weights_config(config) {
     c_layers.reserve(config.layer_configs.size());
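The ViT linear and image-head kernels swap their two dimensions here, which matches the MatMul contract referenced elsewhere in this change: B is stored transposed, so its row count must equal the output's column count. A toy shape check under that assumption (dimension values are illustrative, not from a real config):

```cpp
#include <cassert>
#include <cstddef>

// Assumed convention from this change: A is [batch x in], B is stored
// transposed as [out x in], C is [batch x out], and MatMul verifies
// B.rows == C.cols.
struct Shape { size_t rows, cols; };

inline Shape MatMulShape(Shape A, Shape B) {
  assert(A.cols == B.cols);      // shared inner dimension
  return Shape{A.rows, B.rows};  // C
}

int main() {
  const size_t model_dim = 2048, ff_hidden_dim = 8192, batch = 4;
  // linear_0_w is now declared as {ff_hidden_dim, model_dim}: out x in.
  Shape C =
      MatMulShape(Shape{batch, model_dim}, Shape{ff_hidden_dim, model_dim});
  assert(C.cols == ff_hidden_dim);
  return 0;
}
```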


@@ -1011,14 +1011,14 @@ struct TestShortDotsT {
     // hence they require padding to one vector.
     const size_t padded_num = hwy::RoundUpTo(num, N);
     const size_t packed_num = CompressedArrayElements<Packed>(num);
-    RowVectorBatch<float> raw_w(1, padded_num);
-    RowVectorBatch<float> raw_v(1, padded_num);
-    RowVectorBatch<Packed> weights(1, packed_num);
+    RowVectorBatch<float> raw_w(Extents2D(1, padded_num));
+    RowVectorBatch<float> raw_v(Extents2D(1, padded_num));
+    RowVectorBatch<Packed> weights(Extents2D(1, packed_num));
     const PackedSpan<Packed> w(weights.Batch(0), packed_num);
-    RowVectorBatch<T> vectors(1, num);
+    RowVectorBatch<T> vectors(Extents2D(1, num));
     const PackedSpan<T> v(vectors.Batch(0), num);
-    RowVectorBatch<double> bufs(1, num);
+    RowVectorBatch<double> bufs(Extents2D(1, num));
     double* HWY_RESTRICT buf = bufs.Batch(0);

     for (size_t rep = 0; rep < hn::AdjustedReps(20); ++rep) {
@@ -1107,11 +1107,11 @@ void TestAllDot() {
   constexpr size_t kReps = hn::AdjustedReps(40);
   const size_t num = 24 * 1024;
-  NestedPools pools(kMaxWorkers - 1, /*pin=*/1, BoundedSlice(0, 1),
-                    BoundedSlice(0, 1));
-  RowVectorBatch<float> a(kMaxWorkers, num);
-  RowVectorBatch<float> b(kMaxWorkers, num);
-  RowVectorBatch<double> bufs(kMaxWorkers, num);
+  NestedPools pools(kMaxWorkers - 1, /*pin=*/Tristate::kDefault,
+                    BoundedSlice(0, 1), BoundedSlice(0, 1));
+  RowVectorBatch<float> a(Extents2D(kMaxWorkers, num));
+  RowVectorBatch<float> b(Extents2D(kMaxWorkers, num));
+  RowVectorBatch<double> bufs(Extents2D(kMaxWorkers, num));
   std::array<DotStats, kMaxWorkers> all_stats;
   pools.Cluster(0, 0).Run(0, kReps, [&](const uint32_t rep, size_t thread) {


@@ -16,8 +16,9 @@
 #include <stddef.h>
 #include <stdint.h>

+#include "compression/compress.h"  // IWYU pragma: keep, b/conditionally used
 #include "ops/matmul.h"  // IWYU pragma: export
-#include "util/allocator.h"
+#include "util/basics.h"

 // Include guard for (potentially) SIMD code.
 #if defined(THIRD_PARTY_GEMMA_CPP_MATMUL_TOGGLE) == defined(HWY_TARGET_TOGGLE)
@@ -30,7 +31,7 @@
 #include "hwy/highway.h"
 // After highway.h
 #include "compression/compress-inl.h"
-#include "hwy/contrib/math/math-inl.h"
+#include "ops/ops-inl.h"

 HWY_BEFORE_NAMESPACE();
 namespace gcpp {
@@ -53,38 +54,20 @@ constexpr size_t kRegCols = 4;
 // generally `kRegRows`, but `batch_size % kRegRows` on the last row (if != 0).
 constexpr size_t kRegRows = kRegCols;

-// NEON_BF16/SVE/AVX3_ZEN4 have instructions for bf16 * bf16 + f32 which are
-// more efficient than f32 * f32 + f32 because they process twice as many lanes
-// at a time. Any combination of A and B can be bf16: activations may already be
-// bf16, and weights can be decompressed to bf16.
-//
-// The corresponding op is `ReorderWidenMulAccumulate`, and it is always
-// supported, but only useful if it returns a single vector of pairwise sums
-// `a[0] * b[0] + a[1] * b[1]`. On other targets, `ReorderWidenMulAccumulate`
-// insteads return `a[1] * b[1]` in its `sum1` output. We cannot afford to keep
-// a `sum1` for each of the `kRegRows * kRegCols` C vectors, and it would be
-// expensive to add each `sum0` and `sum1`, hence we only 'decompress' A and B
-// to bf16 if the native op is available. This will actually demote f32
-// activations to bf16. Otherwise, we decompress to f32 and use normal FMA.
-using MulT = hwy::If<HWY_NATIVE_DOT_BF16, BF16, float>;
-
-// Loads two vectors at a time with element type MulT from a row of transposed
-// B. Called in a loop over col_ab. No bounds checking because `kRow` is
-// actually from B columns, which we checked is a multiple of `kRegCols`.
+// Loads two vectors at a time with element type hn::TFromD<DR> from a row of
+// transposed B. Called in a loop over col_ab. No bounds checking because
+// `kRow` is from B columns, which we checked is a multiple of `kRegCols`.
 template <size_t kRow, typename MatTB>
 class BRow {
   static_assert(kRow < kRegRows);  // which unrolled instance we are

  public:
-  BRow(const Mat<const MatTB>& B, size_t row_b, size_t cols_c)
-      // B.cols * C.cols is the total number of elements, required for
-      // PackedSpan::BoundsCheck.
-      : B_(MakeSpan(B.ptr, B.ofs + B.cols * cols_c)),
-        B_ofs_(B.Row(row_b + kRow)) {}
+  BRow(const ConstMat<MatTB>& B, size_t row_b)
+      : B_(MakeSpan(B.ptr, B.ofs + B.Extents().Area())),
+        B_ofs_(B.Row(HWY_MIN(row_b + kRow, B.Extents().rows - 1))) {}

-  template <class DM, class VM = hn::Vec<DM>>
-  HWY_INLINE void Load2(DM d, size_t col_ab, VM& b0, VM& b1) const {
-    static_assert(hwy::IsSame<hn::TFromD<DM>, MulT>());
+  template <class DR, class VR = hn::Vec<DR>>
+  HWY_INLINE void Load2(DR d, size_t col_ab, VR& b0, VR& b1) const {
     Decompress2(d, B_, B_ofs_ + col_ab, b0, b1);
   }
@@ -93,11 +76,11 @@ class BRow {
const size_t B_ofs_;
};
- // Loads *two* row vectors from A via `Decompress2`, multiplies element-wise
- // with `kRegRows` x 2 row vectors from transposed B, and adds them to
- // `kRegRows` x `kRegCols` C vectors. The lanes of `C[r,c]` are thus a subset of
- // the terms of the dot products that make up the MatMul result at `r,c`.
- // No-op for the bottom-most tile where kRow >= kNumRows.
+ // Loads *two* row vectors from A via `Decompress2`, widens to f32, multiplies
+ // element-wise with `kRegRows` x 2 row vectors from transposed B, and adds
+ // them to `kRegRows` x `kRegCols` C vectors. The lanes of `C[r,c]` are thus a
+ // subset of the terms of the dot products that make up the MatMul result at
+ // `r,c`. No-op for the bottom-most rows whose `kRow >= kNumRows`.
//
// This approach is atypical because it requires a horizontal sum, for which we
// introduce a fast and new(?) vector-length agnostic 'transpose', see
@@ -107,22 +90,24 @@ class BRow {
// - `Decompress2` decompresses two vectors at a time;
// - B is column-major, so unit-stride SIMD loads return a column, not values
//   from different columns, i.e. a row.
- // Both could be fixed in a packing stage, which is not implemented yet, and
- // might not be necessary otherwise. However, `ReorderWidenMulAccumulate` is
- // important for bf16 performance and incompatible with the conventional
- // approach, because its pairwise adds would add together unrelated terms.
- // By contrast, pairwise adds are fine when our C lanes are the terms of a
- // single dot product, which can be reordered or pre-reduced.
+ // - `ReorderWidenMulAccumulate` is important for bf16 performance, but its
+ //   pairwise adds would add together unrelated terms.
+ // The first two could be fixed in a packing stage, which is not implemented
+ // yet, and might not be necessary otherwise. The third seems a fundamental
+ // mismatch. However, pairwise adds are fine in our setting because C lanes are
+ // the terms of a single dot product, which can be reordered or pre-reduced.
template <size_t kRow, typename MatTA>
class ALoadAccumulate {
- static_assert(kRow < kRegRows); // which unrolled instance we are
public:
- ALoadAccumulate(const Mat<const MatTA>& A, size_t row_ac, size_t batch_size)
-     // A.cols * batch_size is the total number of elements, required for
-     // PackedSpan::BoundsCheck.
-     : A_(MakeSpan(A.ptr, A.ofs + A.cols * batch_size)),
-       A_ofs_(A.Row(row_ac + kRow)) {}
+ static_assert(kRow < kRegRows); // which unrolled instance we are
+ // `First` and `Next` handle a single row of A, so the horizontal sums of
+ // their `C0..3` are the (partial) dot products for 4 consecutive values in
+ // one row of C.
+ static_assert(kRegCols == 4);
+ ALoadAccumulate(const ConstMat<MatTA>& A, size_t row_ac)
+     : A_(MakeSpan(A.ptr, A.ofs + A.Extents().Area())),
+       A_ofs_(A.Row(HWY_MIN(row_ac + kRow, A.Extents().rows - 1))) {}
// First iteration, col_ab = 0: initialize C0..3 instead of updating them.
template <size_t kNumRows, class DM, class VM = hn::Vec<DM>, HWY_IF_F32_D(DM)>
@ -161,20 +146,27 @@ class ALoadAccumulate {
Decompress2(dm, A_, A_ofs_, a0, a1); Decompress2(dm, A_, A_ofs_, a0, a1);
const DF df; const DF df;
VF unused_sum1 = hn::Zero(df);
static_assert(kRegCols == 4); static_assert(kRegCols == 4);
C0 = hn::WidenMulPairwiseAdd(df, a0, b00); C0 = hn::WidenMulPairwiseAdd(df, a0, b00);
C1 = hn::WidenMulPairwiseAdd(df, a0, b10); C1 = hn::WidenMulPairwiseAdd(df, a0, b10);
C2 = hn::WidenMulPairwiseAdd(df, a0, b20); C2 = hn::WidenMulPairwiseAdd(df, a0, b20);
C3 = hn::WidenMulPairwiseAdd(df, a0, b30); C3 = hn::WidenMulPairwiseAdd(df, a0, b30);
if constexpr (HWY_NATIVE_DOT_BF16) {
// Native ReorderWidenMulAccumulate adds to C0..3 for free.
VF unused_sum1 = hn::Zero(df);
C0 = hn::ReorderWidenMulAccumulate(df, a1, b01, C0, unused_sum1); C0 = hn::ReorderWidenMulAccumulate(df, a1, b01, C0, unused_sum1);
C1 = hn::ReorderWidenMulAccumulate(df, a1, b11, C1, unused_sum1); C1 = hn::ReorderWidenMulAccumulate(df, a1, b11, C1, unused_sum1);
C2 = hn::ReorderWidenMulAccumulate(df, a1, b21, C2, unused_sum1); C2 = hn::ReorderWidenMulAccumulate(df, a1, b21, C2, unused_sum1);
C3 = hn::ReorderWidenMulAccumulate(df, a1, b31, C3, unused_sum1); C3 = hn::ReorderWidenMulAccumulate(df, a1, b31, C3, unused_sum1);
// Ensure sum1 was indeed unused. // Ensure sum1 was indeed unused.
HWY_DASSERT(hn::AllTrue(df, hn::Eq(unused_sum1, hn::Zero(df)))); HWY_DASSERT(hn::AllTrue(df, hn::Eq(unused_sum1, hn::Zero(df))));
} else {
C0 = hn::Add(C0, hn::WidenMulPairwiseAdd(df, a1, b01));
C1 = hn::Add(C1, hn::WidenMulPairwiseAdd(df, a1, b11));
C2 = hn::Add(C2, hn::WidenMulPairwiseAdd(df, a1, b21));
C3 = hn::Add(C3, hn::WidenMulPairwiseAdd(df, a1, b31));
}
} }
} }
@ -217,9 +209,11 @@ class ALoadAccumulate {
Decompress2(dm, A_, A_ofs_ + col_ab, a0, a1); Decompress2(dm, A_, A_ofs_ + col_ab, a0, a1);
const DF df; const DF df;
hn::Vec<DF> unused_sum1 = hn::Zero(df);
static_assert(kRegCols == 4); static_assert(kRegCols == 4);
if constexpr (HWY_NATIVE_DOT_BF16) {
// Native ReorderWidenMulAccumulate adds to C0..3 for free.
VF unused_sum1 = hn::Zero(df);
C0 = hn::ReorderWidenMulAccumulate(df, a0, b00, C0, unused_sum1); C0 = hn::ReorderWidenMulAccumulate(df, a0, b00, C0, unused_sum1);
C1 = hn::ReorderWidenMulAccumulate(df, a0, b10, C1, unused_sum1); C1 = hn::ReorderWidenMulAccumulate(df, a0, b10, C1, unused_sum1);
C2 = hn::ReorderWidenMulAccumulate(df, a0, b20, C2, unused_sum1); C2 = hn::ReorderWidenMulAccumulate(df, a0, b20, C2, unused_sum1);
@ -228,9 +222,18 @@ class ALoadAccumulate {
C1 = hn::ReorderWidenMulAccumulate(df, a1, b11, C1, unused_sum1); C1 = hn::ReorderWidenMulAccumulate(df, a1, b11, C1, unused_sum1);
C2 = hn::ReorderWidenMulAccumulate(df, a1, b21, C2, unused_sum1); C2 = hn::ReorderWidenMulAccumulate(df, a1, b21, C2, unused_sum1);
C3 = hn::ReorderWidenMulAccumulate(df, a1, b31, C3, unused_sum1); C3 = hn::ReorderWidenMulAccumulate(df, a1, b31, C3, unused_sum1);
// Ensure sum1 was indeed unused. // Ensure sum1 was indeed unused.
HWY_DASSERT(hn::AllTrue(df, hn::Eq(unused_sum1, hn::Zero(df)))); HWY_DASSERT(hn::AllTrue(df, hn::Eq(unused_sum1, hn::Zero(df))));
} else {
C0 = hn::Add(C0, hn::WidenMulPairwiseAdd(df, a0, b00));
C1 = hn::Add(C1, hn::WidenMulPairwiseAdd(df, a0, b10));
C2 = hn::Add(C2, hn::WidenMulPairwiseAdd(df, a0, b20));
C3 = hn::Add(C3, hn::WidenMulPairwiseAdd(df, a0, b30));
C0 = hn::Add(C0, hn::WidenMulPairwiseAdd(df, a1, b01));
C1 = hn::Add(C1, hn::WidenMulPairwiseAdd(df, a1, b11));
C2 = hn::Add(C2, hn::WidenMulPairwiseAdd(df, a1, b21));
C3 = hn::Add(C3, hn::WidenMulPairwiseAdd(df, a1, b31));
}
} }
} }
@@ -356,116 +359,113 @@ class AddHorizontalSums {
// Streams a `kNumRows` high strip of `A` and the transposed `B`, then writes a
// *finished* tile of f32 `C` whose top left is (row_ac, row_b_col_c).
// TODO: loop over sections instead of full rows and accumulate into `tile_c`.
+ // `buf` is 16 vectors of thread-local storage.
template <size_t kNumRows, bool kAdd, typename MatTA, typename MatTB>
- HWY_INLINE void MatMulTile(const size_t batch_size, const Mat<const MatTA>& A,
-                            const Mat<const MatTB>& B, const size_t row_ac,
-                            const size_t row_b_col_c, const float scale,
-                            const float* HWY_RESTRICT add,
-                            float* HWY_RESTRICT buf, const Mat<float>& C) {
-   // For 'decompressing' A and B into BF16 or float.
-   const hn::ScalableTag<MulT> dm;
-   using VM = hn::Vec<decltype(dm)>;
-   const size_t NM = hn::Lanes(dm);
+ HWY_INLINE void MatMulTile(const ConstMat<MatTA>& A, const size_t row_ac,
+                            const ConstMat<MatTB>& B, const size_t row_b_col_c,
+                            const float scale, const float* HWY_RESTRICT add,
+                            float* HWY_RESTRICT buf, const RowPtr<float>& C) {
+   // Decompress A and B to which type, which will then be widened to f32,
+   // multiplied, added once into f32, then promoted to f64 and accumulated.
+   // NEON_BF16/SVE/AVX3_ZEN4 have instructions for bf16 * bf16 + f32 which are
+   // more efficient than f32 * f32 + f32 because they process twice as many
+   // lanes at a time. If available, we definitely want to use them. Otherwise,
+   // bf16 is still worthwhile if A (activations) are bf16: SFP weights are
+   // cheaper to decode to bf16, relative to the minor extra cost of promoting
+   // bf16 when multiplying. However, if A is f32, demoting to bf16 can be
+   // expensive unless we also have native bf16 dot.
+   using Raw = hwy::If<HWY_NATIVE_DOT_BF16 || !IsF32<MatTA>(), BF16, float>;
+   const hn::ScalableTag<Raw> dr;
+   using VR = hn::Vec<decltype(dr)>;
+   const size_t NR = hn::Lanes(dr);
+   const Range1D cols_ab(0, A.Extents().cols);
+   HWY_DASSERT(row_ac + kNumRows <= A.Extents().rows);
+   HWY_DASSERT(row_b_col_c + kNumRows <= B.Extents().rows);
+   HWY_DASSERT(cols_ab.end() % (2 * NR) == 0);
static_assert(kRegRows == 4); static_assert(kRegRows == 4);
const BRow<0, MatTB> b_row0(B, row_b_col_c, C.cols); const BRow<0, MatTB> b_row0(B, row_b_col_c);
const BRow<1, MatTB> b_row1(B, row_b_col_c, C.cols); const BRow<1, MatTB> b_row1(B, row_b_col_c);
const BRow<2, MatTB> b_row2(B, row_b_col_c, C.cols); const BRow<2, MatTB> b_row2(B, row_b_col_c);
const BRow<3, MatTB> b_row3(B, row_b_col_c, C.cols); const BRow<3, MatTB> b_row3(B, row_b_col_c);
const ALoadAccumulate<0, MatTA> a_row0(A, row_ac, batch_size); const ALoadAccumulate<0, MatTA> a_row0(A, row_ac);
const ALoadAccumulate<1, MatTA> a_row1(A, row_ac, batch_size); const ALoadAccumulate<1, MatTA> a_row1(A, row_ac);
const ALoadAccumulate<2, MatTA> a_row2(A, row_ac, batch_size); const ALoadAccumulate<2, MatTA> a_row2(A, row_ac);
const ALoadAccumulate<3, MatTA> a_row3(A, row_ac, batch_size); const ALoadAccumulate<3, MatTA> a_row3(A, row_ac);
const hn::Repartition<float, decltype(dm)> df; const hn::Repartition<float, decltype(dr)> df;
using VF = hn::Vec<decltype(df)>; using VF = hn::Vec<decltype(df)>;
VF C00, C01, C02, C03; VF C00, C01, C02, C03;
VF C10, C11, C12, C13; VF C10, C11, C12, C13;
VF C20, C21, C22, C23; VF C20, C21, C22, C23;
VF C30, C31, C32, C33; VF C30, C31, C32, C33;
size_t col_ab = cols_ab.begin();
{ // First iteration initializes the `Crc` vectors. { // First iteration initializes the `Crc` vectors.
VM b00, b01, b10, b11, b20, b21, b30, b31; VR b00, b01, b10, b11, b20, b21, b30, b31;
b_row0.Load2(dm, /*col_ab=*/0, b00, b01); b_row0.Load2(dr, col_ab, b00, b01);
b_row1.Load2(dm, /*col_ab=*/0, b10, b11); b_row1.Load2(dr, col_ab, b10, b11);
b_row2.Load2(dm, /*col_ab=*/0, b20, b21); b_row2.Load2(dr, col_ab, b20, b21);
b_row3.Load2(dm, /*col_ab=*/0, b30, b31); b_row3.Load2(dr, col_ab, b30, b31);
a_row0.template First<kNumRows>(dm, b00, b01, b10, b11, b20, b21, b30, b31, a_row0.template First<kNumRows>(dr, b00, b01, b10, b11, b20, b21, b30, b31,
C00, C01, C02, C03); C00, C01, C02, C03);
a_row1.template First<kNumRows>(dm, b00, b01, b10, b11, b20, b21, b30, b31, a_row1.template First<kNumRows>(dr, b00, b01, b10, b11, b20, b21, b30, b31,
C10, C11, C12, C13); C10, C11, C12, C13);
a_row2.template First<kNumRows>(dm, b00, b01, b10, b11, b20, b21, b30, b31, a_row2.template First<kNumRows>(dr, b00, b01, b10, b11, b20, b21, b30, b31,
C20, C21, C22, C23); C20, C21, C22, C23);
a_row3.template First<kNumRows>(dm, b00, b01, b10, b11, b20, b21, b30, b31, a_row3.template First<kNumRows>(dr, b00, b01, b10, b11, b20, b21, b30, b31,
C30, C31, C32, C33); C30, C31, C32, C33);
col_ab += 2 * NR;
} }
// `2 * NM` per iteration because `Load2` returns two vectors. // `2 * NR` per iteration because `Load2` returns two vectors.
HWY_UNROLL(1) HWY_UNROLL(1)
for (size_t col_ab = 2 * NM; col_ab <= A.cols - 2 * NM; col_ab += 2 * NM) { for (; col_ab < cols_ab.end(); col_ab += 2 * NR) {
VM b00, b01, b10, b11, b20, b21, b30, b31; VR b00, b01, b10, b11, b20, b21, b30, b31;
b_row0.Load2(dm, col_ab, b00, b01); b_row0.Load2(dr, col_ab, b00, b01);
b_row1.Load2(dm, col_ab, b10, b11); b_row1.Load2(dr, col_ab, b10, b11);
b_row2.Load2(dm, col_ab, b20, b21); b_row2.Load2(dr, col_ab, b20, b21);
b_row3.Load2(dm, col_ab, b30, b31); b_row3.Load2(dr, col_ab, b30, b31);
a_row0.template Next<kNumRows>(dm, col_ab, b00, b01, b10, b11, b20, b21, a_row0.template Next<kNumRows>(dr, col_ab, b00, b01, b10, b11, b20, b21,
b30, b31, C00, C01, C02, C03); b30, b31, C00, C01, C02, C03);
a_row1.template Next<kNumRows>(dm, col_ab, b00, b01, b10, b11, b20, b21, a_row1.template Next<kNumRows>(dr, col_ab, b00, b01, b10, b11, b20, b21,
b30, b31, C10, C11, C12, C13); b30, b31, C10, C11, C12, C13);
a_row2.template Next<kNumRows>(dm, col_ab, b00, b01, b10, b11, b20, b21, a_row2.template Next<kNumRows>(dr, col_ab, b00, b01, b10, b11, b20, b21,
b30, b31, C20, C21, C22, C23); b30, b31, C20, C21, C22, C23);
a_row3.template Next<kNumRows>(dm, col_ab, b00, b01, b10, b11, b20, b21, a_row3.template Next<kNumRows>(dr, col_ab, b00, b01, b10, b11, b20, b21,
b30, b31, C30, C31, C32, C33); b30, b31, C30, C31, C32, C33);
} }
// TODO: hoist into outer loop. // TODO: hoist into outer loop.
float* HWY_RESTRICT C_tile = C.ptr + C.Row(row_ac) + row_b_col_c; float* HWY_RESTRICT C_tile = C.Row(row_ac) + row_b_col_c;
InitC<kNumRows, kAdd>(add, row_b_col_c, C_tile, C.stride); InitC<kNumRows, kAdd>(add, row_b_col_c, C_tile, C.Stride());
AddHorizontalSums<kNumRows>()(df, scale, C00, C01, C02, C03, C10, C11, C12, AddHorizontalSums<kNumRows>()(df, scale, C00, C01, C02, C03, C10, C11, C12,
C13, C20, C21, C22, C23, C30, C31, C32, C33, C13, C20, C21, C22, C23, C30, C31, C32, C33,
buf, C_tile, C.stride); buf, C_tile, C.Stride());
} }
// Computes the matrix product `A * B * scale [+ add]` and stores it in `C`.
//
// `A` is a row-major matrix of shape `(batch_size, A.cols)`.
// `B` is transposed; `B.cols`, which must match `A.cols`, denotes the number of
// rows in the original B, and `C.cols` the number of columns in the original B.
//
// `scale` allows expanding the smaller range of `SfpStream` to the original
// values. When `A` and/or `B` are from CompressedArray, `scale` should be the
// product of their `.scale()` values, otherwise 1.0f.
//
// If `kAdd` is true, the row-vector `add` is added to each row of `C`,
// otherwise `add` is ignored and can be nullptr. A scale for `add` is not
// supported, so make sure its scale is 1.
//
// `C` is a row-major matrix of size `(batch_size, C.cols)`.
//
// Updates 4x4 tiles of C in parallel using a work-stealing thread pool.
// Typically `batch_size` is 1..512, `A.cols` and `C.cols` are 3k or 24k.
// Must not be called concurrently with the same `env`.
template <bool kAdd, typename MatTA, typename MatTB>
- HWY_NOINLINE void MatMul(const size_t batch_size, const Mat<const MatTA>& A,
-                          const Mat<const MatTB>& B, const float scale,
+ HWY_NOINLINE void MatMulImpl(const ConstMat<MatTA>& A, const ConstMat<MatTB>& B,
                          const float* HWY_RESTRICT add, MatMulEnv& env,
-                          const Mat<float>& C) {
+                          const RowPtr<float>& C) {
// PROFILER_ZONE("Matmul");
- HWY_DASSERT(A.NotEmpty() && B.NotEmpty() && C.NotEmpty());
- HWY_DASSERT(A.cols == B.cols);
- // Must be a multiple of two vectors because we Decompress2.
- HWY_DASSERT(A.cols % (2 * hn::Lanes(hn::ScalableTag<MulT>())) == 0);
- HWY_DASSERT(C.cols % kRegCols == 0);
+ HWY_DASSERT(A.Extents().cols == B.Extents().cols);
+ const size_t batch_size = A.Extents().rows;
+ HWY_DASSERT(C.Cols() % kRegCols == 0);
+ HWY_DASSERT(C.Stride() >= C.Cols());
+ HWY_DASSERT(B.Extents().rows == C.Cols());
+ const float scale = A.scale * B.scale;
// We currently write C directly, which touches more memory than fits in L3.
// TODO: add another level of loops to finish L3-sized pieces of C at a time.
const size_t tilesY = hwy::DivCeil(batch_size, kRegRows);
- const size_t tilesX = C.cols / kRegCols;
+ const size_t tilesX = C.Cols() / kRegCols;
env.Pool().Run(
    0, tilesX * tilesY, [&](const uint64_t idx_tile, size_t thread) HWY_ATTR {
@ -481,24 +481,45 @@ HWY_NOINLINE void MatMul(const size_t batch_size, const Mat<const MatTA>& A,
HWY_DASSERT(num_rows != 0); HWY_DASSERT(num_rows != 0);
switch (num_rows) { switch (num_rows) {
case 1: case 1:
MatMulTile<1, kAdd>(batch_size, A, B, row_ac, row_b_col_c, scale, MatMulTile<1, kAdd>(A, row_ac, B, row_b_col_c, scale, add, buf, C);
add, buf, C);
break; break;
case 2: case 2:
MatMulTile<2, kAdd>(batch_size, A, B, row_ac, row_b_col_c, scale, MatMulTile<2, kAdd>(A, row_ac, B, row_b_col_c, scale, add, buf, C);
add, buf, C);
break; break;
case 3: case 3:
MatMulTile<3, kAdd>(batch_size, A, B, row_ac, row_b_col_c, scale, MatMulTile<3, kAdd>(A, row_ac, B, row_b_col_c, scale, add, buf, C);
add, buf, C);
break; break;
default: default:
MatMulTile<4, kAdd>(batch_size, A, B, row_ac, row_b_col_c, scale, MatMulTile<4, kAdd>(A, row_ac, B, row_b_col_c, scale, add, buf, C);
add, buf, C);
} }
}); });
} }
// Computes the matrix product `A * B * scale [+ add]` and stores it in `C`.
//
// `A` is a row-major matrix and `B` is transposed. Its `B.Extents().cols`,
// which must match `A.Extents().cols`, is the number of rows in the original B.
//
// If `add` is non-null, the row-vector `add` is added to each row of `C`.
// A scale for `add` is not supported, so make sure its scale is 1.
//
// `C` is a row-major matrix of size `(A.rows, C.Cols())` with support for
// arbitrary strides.
//
// Updates 4x4 tiles of C in parallel using a work-stealing thread pool.
// Typically `A.rows` is 1..512, `A.Extents().cols` and `B.Extents().rows` are
// 3k or 24k. Must not be called concurrently with the same `env`.
template <typename MatTA, typename MatTB>
HWY_NOINLINE void MatMul(const ConstMat<MatTA>& A, const ConstMat<MatTB>& B,
const float* HWY_RESTRICT add, MatMulEnv& env,
const RowPtr<float>& C) {
if (add) {
MatMulImpl<true>(A, B, add, env, C);
} else {
MatMulImpl<false>(A, B, nullptr, env, C);
}
}
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace gcpp
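
For illustration, a minimal sketch of a call site under the new interface. The helpers `ConstMatFromWeights` and `RowPtrFromBatch` are the ones used by matmul_test.cc further below; `env`, `activations`, `w_transposed` and the sizes are assumed placeholders, not part of this change:

  // Minimal sketch, not from this commit: A is (batch_size x cols_ab) activations,
  // B is the transposed weight matrix (cols_c x cols_ab), C receives the result.
  RowVectorBatch<float> c_batch(Extents2D(batch_size, cols_c));
  const RowPtr<float> C = RowPtrFromBatch(c_batch);
  const auto A = ConstMatFromWeights(activations);
  const auto B = ConstMatFromWeights(w_transposed);
  MatMul(A, B, /*add=*/nullptr, env, C);  // scale is taken from A.scale * B.scale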


@@ -19,73 +19,22 @@
#include <stddef.h>
// IWYU pragma: begin_exports
+ #include "util/basics.h"
#include "util/threading.h"
- #include "hwy/aligned_allocator.h"
- #include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
// IWYU pragma: end_exports
- #include "util/allocator.h"  // RowVectorBatch
#include "hwy/per_target.h"  // VectorBytes
namespace gcpp {
// Bundles ptr/size/stride arguments to simplify MatMul call sites. T can be
// const or non-const. Create via ConstMat/MutableMat.
// TODO(rays): Replace with MatPtr and get rid of stride, which is only != cols
// in one place.
template <typename T>
struct Mat {
bool NotEmpty() const {
return ptr != nullptr && cols != 0 && stride >= cols;
}
size_t Row(size_t r) const { return ofs + stride * r; }
T* HWY_RESTRICT ptr;
size_t cols;
// elements between rows, which is typically the same as `cols`.
size_t stride;
// Offset to add to `ptr`; separate because T=NuqStream does not support
// pointer arithmetic.
size_t ofs;
};
template <typename T>
Mat<T> MutableMat(T* HWY_RESTRICT ptr, size_t cols, size_t stride,
size_t ofs = 0) {
return Mat<T>{.ptr = ptr, .cols = cols, .stride = stride, .ofs = ofs};
}
template <typename T>
Mat<const T> ConstMat(const T* HWY_RESTRICT ptr, size_t cols, size_t stride,
size_t ofs = 0) {
return Mat<const T>{.ptr = ptr, .cols = cols, .stride = stride, .ofs = ofs};
}
template <typename T>
Mat<const T> ConstMat(Mat<T> mat) {
return ConstMat(mat.ptr, mat.cols, mat.stride, mat.ofs);
}
template <typename T>
Mat<T> MutableMat(T* HWY_RESTRICT ptr, size_t cols) {
return MutableMat(ptr, cols, cols);
}
template <typename T>
Mat<const T> ConstMat(const T* HWY_RESTRICT ptr, size_t cols) {
return ConstMat(ptr, cols, cols);
}
// Allocations and threads, shared across MatMul calls.
class MatMulEnv {
public:
MatMulEnv() : pools_(nullptr) {}
explicit MatMulEnv(NestedPools& pools) : pools_(&pools) {
const size_t N = hwy::VectorBytes() / sizeof(float);
- buf_ = RowVectorBatch<float>(pools.MaxWorkers(), 16 * N);
+ buf_ = RowVectorBatch<float>(Extents2D(pools.MaxWorkers(), 16 * N));
}
RowVectorBatch<float>& Buf() { return buf_; }
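
A minimal sketch of how the env-owned scratch is sized and accessed after this change (the pool parameters are illustrative, not from this commit):

  // Sketch only: one row of 16 * N floats per worker, now expressed via Extents2D.
  NestedPools pools(/*max_threads=*/0, /*pin=*/Tristate::kDefault);
  MatMulEnv env(pools);
  float* scratch = env.Buf().Batch(/*worker=*/0);  // thread-local buffer for MatMulTile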


@@ -32,6 +32,7 @@
#include "compression/compress.h"
#include "util/allocator.h"
+ #include "util/basics.h"
#include "util/threading.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
@@ -55,19 +56,23 @@ namespace HWY_NAMESPACE {
using FloatPtr = hwy::AlignedFreeUniquePtr<float[]>;
+ template <typename MatT>
+ using MatStoragePtr = std::unique_ptr<MatStorageT<MatT>>;
// Generates inputs: deterministic, within max SfpStream range.
- template <typename MatT, size_t kRows, size_t kCols,
-           class MatPtr = std::unique_ptr<MatStorageT<MatT>>>
- MatPtr GenerateMat(size_t offset, hwy::ThreadPool& pool) {
+ template <typename MatT>
+ MatStoragePtr<MatT> GenerateMat(const Extents2D extents,
+                                 hwy::ThreadPool& pool) {
gcpp::CompressWorkingSet ws;
- auto mat = std::make_unique<MatStorageT<MatT>>("test", kRows, kCols);
+ auto mat =
+     std::make_unique<MatStorageT<MatT>>("mat", extents.rows, extents.cols);
FloatPtr content = hwy::AllocateAligned<float>(mat->NumElements());
HWY_ASSERT(content);
- const float scale = SfpStream::kMax / (mat->NumElements() + offset);
- pool.Run(0, kRows, [&](const size_t i, size_t /*thread*/) {
-   for (size_t j = 0; j < kCols; j++) {
-     content[i * kCols + j] =
-         static_cast<float>((i * kCols + j + offset) * scale);
+ const float scale = SfpStream::kMax / (mat->NumElements());
+ pool.Run(0, extents.rows, [&](const size_t r, size_t /*thread*/) {
+   for (size_t c = 0; c < extents.cols; c++) {
+     content[r * extents.cols + c] =
+         static_cast<float>(r * extents.cols + c) * scale;
}
});
@ -76,185 +81,173 @@ MatPtr GenerateMat(size_t offset, hwy::ThreadPool& pool) {
return mat; return mat;
} }
template <typename MatT, size_t kRows, size_t kCols, // extents describes the transposed matrix.
class MatPtr = std::unique_ptr<MatStorageT<MatT>>> template <typename MatT>
MatPtr GenerateTransposedMat(size_t offset, hwy::ThreadPool& pool) { MatStoragePtr<MatT> GenerateTransposedMat(const Extents2D extents,
hwy::ThreadPool& pool) {
gcpp::CompressWorkingSet ws; gcpp::CompressWorkingSet ws;
MatPtr mat = std::make_unique<MatStorageT<MatT>>("test", kCols, kRows); auto mat =
std::make_unique<MatStorageT<MatT>>("trans", extents.rows, extents.cols);
FloatPtr content = hwy::AllocateAligned<float>(mat->NumElements()); FloatPtr content = hwy::AllocateAligned<float>(mat->NumElements());
const float scale = SfpStream::kMax / (mat->NumElements() + offset); const float scale = SfpStream::kMax / (mat->NumElements());
pool.Run(0, kRows, [&](const size_t i, size_t /*thread*/) { pool.Run(0, extents.rows, [&](const size_t r, size_t /*thread*/) {
for (size_t j = 0; j < kCols; j++) { for (size_t c = 0; c < extents.cols; c++) {
content[j * kRows + i] = content[r * extents.cols + c] =
static_cast<float>((i * kCols + j + offset) * scale); static_cast<float>(c * extents.rows + r) * scale;
} }
}); });
CompressScaled(content.get(), mat->NumElements(), ws, *mat, pool); CompressScaled(content.get(), mat->NumElements(), ws, *mat, pool);
// Arbitrary value, different from 1, must match GenerateMatHeap. // Arbitrary value, different from 1, must match GenerateMat.
mat->set_scale(0.6f); mat->set_scale(0.6f);
return mat; return mat;
} }
template <typename MatT, size_t kRows, size_t kCols,
class MatPtr = std::unique_ptr<MatStorageT<MatT>>>
MatPtr GenerateZeroMat(hwy::ThreadPool& pool) {
gcpp::CompressWorkingSet ws;
auto mat = std::make_unique<MatStorageT<MatT>>("Array", kRows, kCols);
FloatPtr content = hwy::AllocateAligned<float>(mat->NumElements());
HWY_ASSERT(content);
pool.Run(0, kRows, [&](const size_t i, size_t thread) {
hwy::ZeroBytes(&content[i * kCols], kCols * sizeof(content[0]));
});
CompressScaled(content.get(), mat->NumElements(), ws, *mat, pool);
mat->set_scale(1.2f); // Arbitrary value, different from 1.
return mat;
}
// Returns 1-norm, used for estimating tolerable numerical differences. // Returns 1-norm, used for estimating tolerable numerical differences.
double MaxColAbsSum(const float* HWY_RESTRICT a, size_t rows, size_t cols) { double MaxColAbsSum(const float* HWY_RESTRICT a, const Extents2D& extents) {
double max_col_abs_sum = 0.0; double max_col_abs_sum = 0.0;
for (size_t c = 0; c < cols; c++) { for (size_t c = 0; c < extents.cols; c++) {
double col_abs_sum = 0.0; double col_abs_sum = 0.0;
for (size_t r = 0; r < rows; r++) { for (size_t r = 0; r < extents.rows; r++) {
col_abs_sum += hwy::ScalarAbs(a[r * cols + c]); col_abs_sum += hwy::ScalarAbs(a[r * extents.cols + c]);
} }
max_col_abs_sum = HWY_MAX(max_col_abs_sum, col_abs_sum); max_col_abs_sum = HWY_MAX(max_col_abs_sum, col_abs_sum);
} }
return max_col_abs_sum; return max_col_abs_sum;
} }
// B is already transposed.
template <typename MatTA, typename MatTB> template <typename MatTA, typename MatTB>
void AssertClose(size_t rows_ac, size_t cols_ab, size_t cols_c_rows_b, void AssertClose(const ConstMat<MatTA>& A, const ConstMat<MatTB>& B,
const MatTA* HWY_RESTRICT pa, const RowPtrF& C_slow, const RowPtrF& C) {
const MatTB* HWY_RESTRICT pb_trans,
const float* HWY_RESTRICT expected_c,
const float* HWY_RESTRICT actual_c) {
const hn::ScalableTag<float> df; const hn::ScalableTag<float> df;
const size_t num_a = rows_ac * cols_ab; const size_t num_a = A.extents.Area();
const size_t num_b = cols_c_rows_b * cols_ab; const size_t num_b = B.extents.Area();
HWY_ASSERT(num_a % hn::Lanes(df) == 0); // for DecompressAndZeroPad HWY_ASSERT(num_a % hn::Lanes(df) == 0); // for DecompressAndZeroPad
HWY_ASSERT(num_b % hn::Lanes(df) == 0); // for DecompressAndZeroPad HWY_ASSERT(num_b % hn::Lanes(df) == 0); // for DecompressAndZeroPad
const size_t num_c = rows_ac * cols_c_rows_b;
FloatPtr a = hwy::AllocateAligned<float>(num_a); FloatPtr a = hwy::AllocateAligned<float>(num_a);
FloatPtr b_trans = hwy::AllocateAligned<float>(num_b); FloatPtr b_trans = hwy::AllocateAligned<float>(num_b);
HWY_ASSERT(a && b_trans); HWY_ASSERT(a && b_trans);
DecompressAndZeroPad(df, MakeSpan(pa, num_a), 0, a.get(), num_a); HWY_ASSERT(A.ofs == 0 && B.ofs == 0);
DecompressAndZeroPad(df, MakeSpan(pb_trans, num_b), 0, b_trans.get(), num_b); DecompressAndZeroPad(df, MakeSpan(A.ptr, num_a), 0, a.get(), num_a);
DecompressAndZeroPad(df, MakeSpan(B.ptr, num_b), 0, b_trans.get(), num_b);
const double norm = MaxColAbsSum(a.get(), rows_ac, cols_ab) * const double norm = MaxColAbsSum(a.get(), A.Extents()) *
MaxColAbsSum(b_trans.get(), cols_c_rows_b, cols_ab); MaxColAbsSum(b_trans.get(), B.Extents());
// Dot(float,BF16) rounds both to BF16. // Dot(float,BF16) rounds both to BF16.
using RefType = hwy::If<IsF32<MatTA>() && IsF32<MatTB>(), float, BF16>; using RefType = hwy::If<IsF32<MatTA>() && IsF32<MatTB>(), float, BF16>;
const double epsilon = hwy::ConvertScalarTo<double>(hwy::Epsilon<RefType>()); const double epsilon = hwy::ConvertScalarTo<double>(hwy::Epsilon<RefType>());
const double tolerance = 200.0 * norm * epsilon; const double tolerance = 200.0 * norm * epsilon;
for (size_t idx = 0; idx < num_c; idx++) { for (size_t r = 0; r < A.extents.rows; r++) {
const double expected_value = expected_c[idx]; const float* expected_row = C_slow.Row(r);
const double actual_value = actual_c[idx]; const float* actual_row = C.Row(r);
for (size_t c = 0; c < B.extents.rows; c++) {
const double expected_value = static_cast<double>(expected_row[c]);
const double actual_value = static_cast<double>(actual_row[c]);
if (!(expected_value - tolerance <= actual_value && if (!(expected_value - tolerance <= actual_value &&
actual_value <= expected_value + tolerance)) { actual_value <= expected_value + tolerance)) {
fprintf( fprintf(
stderr, stderr,
"expected[%lu]: %f, actual[%lu]: %f, norm %f eps %E tolerance %f\n", "(%zu,%zu): expected %f, actual %f, norm %f eps %E tolerance %f\n",
idx, expected_value, idx, actual_value, norm, epsilon, tolerance); r, c, expected_value, actual_value, norm, epsilon, tolerance);
HWY_ASSERT(0); }
} }
} }
} }
// B is already transposed.
template <typename MatTA, typename MatTB> template <typename MatTA, typename MatTB>
HWY_INLINE void MatMulSlow(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, HWY_INLINE void MatMulSlow(const ConstMat<MatTA> A, const ConstMat<MatTB> B,
const MatTA* HWY_RESTRICT a,
const MatTB* HWY_RESTRICT b_trans, const float scale,
const float* HWY_RESTRICT add_row, MatMulEnv& env, const float* HWY_RESTRICT add_row, MatMulEnv& env,
float* HWY_RESTRICT out) { const RowPtrF& C) {
// MatTA can be any Packed except NuqStream because it uses pointer // MatTA can be any Packed except NuqStream because it uses pointer
// arithmetic, because it is the second argument to Dot, which does not // arithmetic, because it is the second argument to Dot, which does not
// support a v_ofs. // support a v_ofs.
static_assert(sizeof(MatTA) >= sizeof(BF16), "A matrix must be BF16/f32"); static_assert(sizeof(MatTA) >= sizeof(BF16), "A matrix must be BF16/f32");
const float scale = A.scale * B.scale;
const hn::ScalableTag<float> df; // lane type is ignored const hn::ScalableTag<float> df; // lane type is ignored
const PackedSpan<const MatTB> b_span = const PackedSpan<const MatTB> b_span =
MakeSpan(b_trans, cols_a_rows_b * cols_bc); MakeSpan(B.ptr, B.ofs + B.extents.Area());
const Extents2D C_extents(A.extents.rows, C.Cols());
StaticPartitionRowsAndCols( StaticPartitionRowsAndCols(
env.Pools(), rows_ac, cols_bc, sizeof(MatTB), env.Pools(), C_extents, sizeof(MatTB),
[&](size_t /*node*/, hwy::ThreadPool& pool, [&](const Range2D& C_range, const TaskLocation& loc) {
const size_t /*worker_offset*/, const size_t row_begin, loc.cluster.Run(
const size_t row_end, const size_t col_begin, const size_t col_end) { C_range.rows.begin(), C_range.rows.end(),
pool.Run(row_begin, row_end,
[&](const uint64_t row, size_t /*thread*/) { [&](const uint64_t row, size_t /*thread*/) {
for (size_t col = col_begin; col < col_end; ++col) { float* HWY_RESTRICT C_row = C.Row(row);
const float add = add_row ? add_row[col] : 0.0f; for (size_t row_b_col_c : C_range.cols) {
out[row * cols_bc + col] = const float add = add_row ? add_row[row_b_col_c] : 0.0f;
scale * Dot(df, b_span, col * cols_a_rows_b, C_row[row_b_col_c] =
a + row * cols_a_rows_b, cols_a_rows_b) + add + scale * Dot(df, b_span, row_b_col_c * B.extents.cols,
add; A.ptr + A.Row(row), A.extents.cols);
} }
}); });
}); });
} }
void PrintSpeed(const char* algo, size_t rows_ac, size_t cols_a_rows_b, void PrintSpeed(const char* algo, const Extents2D& A_extents,
size_t cols_bc, double elapsed) { const Extents2D& B_extents, double elapsed) {
const size_t num_b = cols_a_rows_b * cols_bc; const size_t num_b = B_extents.Area();
// 2x because of FMA. // 2x because of FMA.
fprintf(stderr, " %10s: %f seconds, %.1f GFLOPS.\n", algo, fprintf(stderr, " %10s: %f seconds, %.1f GFLOPS.\n", algo,
elapsed, 2 * 1E-9 * rows_ac * num_b / elapsed); elapsed, 2 * 1E-9 * A_extents.rows * num_b / elapsed);
} }
- template <size_t kRowsAC, size_t kColsARowsB, size_t kColsBC, bool kAdd,
-           typename MatTA, typename MatTB = MatTA>
- void TestMatMul(MatMulEnv& env) {
+ template <typename MatTA, typename MatTB = MatTA>
+ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
+                 MatMulEnv& env) {
hwy::ThreadPool& pool = env.Pool();
- const bool want_bench = kColsBC > 2000; // avoid spam for small matrices
+ const bool want_bench = cols_bc > 2000; // avoid spam for small matrices
fprintf(stderr, "TestMatMul %lu, %lu, %lu, add=%d, MatTA=%s, MatTB=%s\n",
-         kRowsAC, kColsARowsB, kColsBC, kAdd, TypeName<MatTA>(),
+         rows_ac, cols_a_rows_b, cols_bc, add, TypeName<MatTA>(),
        TypeName<MatTB>());
- std::unique_ptr<MatStorageT<MatTA>> a =
-     GenerateMat<MatTA, kRowsAC, kColsARowsB>(0, pool);
- std::unique_ptr<MatStorageT<MatTB>> b_trans =
-     GenerateTransposedMat<MatTB, kColsARowsB, kColsBC>(0, pool);
- FloatPtr c = hwy::AllocateAligned<float>(kRowsAC * kColsBC);
- HWY_ASSERT(c);
- const float scale = a->scale() * b_trans->scale();
- std::unique_ptr<MatStorageT<float>> add;
- if (kAdd) {
-   add = GenerateMat<float, 1, kColsBC>(0, pool);
-   add->set_scale(1.0f);
+ const Extents2D A_extents(rows_ac, cols_a_rows_b);
+ const Extents2D B_extents(cols_bc, cols_a_rows_b); // already transposed
+ const Extents2D C_extents(rows_ac, cols_bc);
+ MatStoragePtr<MatTA> a = GenerateMat<MatTA>(A_extents, pool);
+ MatStoragePtr<MatTB> b_trans = GenerateTransposedMat<MatTB>(B_extents, pool);
+ RowVectorBatch<float> c_slow_batch(C_extents);
+ RowVectorBatch<float> c_batch(C_extents);
+ HWY_ASSERT(a && b_trans);
+ std::unique_ptr<MatStorageT<float>> add_storage;
+ if (add) {
+   add_storage = GenerateMat<float>(Extents2D(1, cols_bc), pool);
+   HWY_ASSERT(add_storage);
+   add_storage->set_scale(1.0f);
}
- std::unique_ptr<MatStorageT<float>> c_slow =
-     GenerateZeroMat<float, kRowsAC, kColsBC>(pool);
+ const auto A = ConstMatFromWeights(*a);
+ const auto B = ConstMatFromWeights(*b_trans);
+ const float* add_row = add ? add_storage->data_scale1() : nullptr;
+ const RowPtrF C_slow = RowPtrFromBatch(c_slow_batch);
+ const RowPtrF C = RowPtrFromBatch(c_batch);
const double start_slow = hwy::platform::Now();
- MatMulSlow(kRowsAC, kColsARowsB, kColsBC, a->data(), b_trans->data(), scale,
-            kAdd ? add->data() : nullptr, env, c_slow->data());
+ MatMulSlow(A, B, add_row, env, C_slow);
if (want_bench) {
-   PrintSpeed("MatMulSlow", kRowsAC, kColsARowsB, kColsBC,
+   PrintSpeed("MatMulSlow", A_extents, B_extents,
             hwy::platform::Now() - start_slow);
}
double min_elapsed = hwy::HighestValue<double>();
for (int rep = 0; rep < (want_bench ? 3 : 1); ++rep) {
  const double start_tiled = hwy::platform::Now();
-   MatMul<kAdd>(kRowsAC, ConstMat(a->data(), kColsARowsB),
-                ConstMat(b_trans->data(), kColsARowsB), scale,
-                kAdd ? add->data_scale1() : nullptr, env,
-                MutableMat(c.get(), kColsBC));
+   MatMul(A, B, add_row, env, C);
  min_elapsed = HWY_MIN(min_elapsed, hwy::platform::Now() - start_tiled);
}
if (want_bench) {
-   PrintSpeed("MatMul", kRowsAC, kColsARowsB, kColsBC, min_elapsed);
+   PrintSpeed("MatMul", A_extents, B_extents, min_elapsed);
}
- AssertClose(kRowsAC, kColsARowsB, kColsBC, a->data(), b_trans->data(),
-             c_slow->data(), c.get());
+ AssertClose(A, B, C_slow, C);
}
void TestAllMatMul() { void TestAllMatMul() {
@@ -264,8 +257,9 @@ void TestAllMatMul() {
  return;
}
- NestedPools pools(4, /*pin=*/1);
- pools.StartSpinning();
+ NestedPools pools(4, /*pin=*/Tristate::kDefault);
+ Tristate use_spinning = Tristate::kDefault;
+ pools.MaybeStartSpinning(use_spinning);
Allocator::Init(pools.Topology());
MatMulEnv env(pools);
@ -273,52 +267,54 @@ void TestAllMatMul() {
using SFP = SfpStream; using SFP = SfpStream;
// large-scale test: batch_size=128 is better than 64 or 256 for SKX. // large-scale test: batch_size=128 is better than 64 or 256 for SKX.
TestMatMul<128, 24576, 3072, /*kAdd=*/false, F32, SFP>(env); // TestMatMul<F32, SFP>(128, 24576, 3072, /*add=*/false, env);
TestMatMul<128, 3072, 24576, /*kAdd=*/false, F32, SFP>(env); // TestMatMul<F32, SFP>(128, 3072, 24576, /*add=*/false, env);
TestMatMul<1, 24576, 3072, /*kAdd=*/false, F32, F32>(env); TestMatMul<F32, F32>(1, 24576, 3072, /*add=*/false, env);
TestMatMul<1, 3072, 24576, /*kAdd=*/false, F32, F32>(env); TestMatMul<F32, F32>(1, 3072, 24576, /*add=*/false, env);
TestMatMul<F32, SFP>(1, 24576, 3072, /*add=*/false, env);
TestMatMul<F32, SFP>(1, 3072, 24576, /*add=*/false, env);
// medium-sized square test - temporarily disabled for faster testing. // medium-sized square test - temporarily disabled for faster testing.
if constexpr (false) { if constexpr (false) {
TestMatMul<512, 512, 512, /*kAdd=*/false, F32>(env); TestMatMul<F32>(512, 512, 512, /*add=*/false, env);
TestMatMul<512, 512, 512, /*kAdd=*/true, BF16>(env); TestMatMul<BF16>(512, 512, 512, /*add=*/true, env);
TestMatMul<512, 512, 512, /*kAdd=*/false, F32, BF16>(env); TestMatMul<F32, BF16>(512, 512, 512, /*add=*/false, env);
TestMatMul<512, 512, 512, /*kAdd=*/true, BF16, F32>(env); TestMatMul<BF16, F32>(512, 512, 512, /*add=*/true, env);
TestMatMul<512, 512, 512, /*kAdd=*/false, F32, SFP>(env); TestMatMul<F32, SFP>(512, 512, 512, /*add=*/false, env);
TestMatMul<512, 512, 512, /*kAdd=*/true, BF16, SFP>(env); TestMatMul<BF16, SFP>(512, 512, 512, /*add=*/true, env);
} }
// minimal non-square test. kColsARowsB must be at least 2 vectors. // minimal non-square test. kColsARowsB must be at least 2 vectors.
TestMatMul<35, 128, 32, /*kAdd=*/false, F32>(env); TestMatMul<F32>(35, 128, 32, /*add=*/false, env);
TestMatMul<34, 128, 32, /*kAdd=*/true, BF16>(env); TestMatMul<BF16>(34, 128, 32, /*add=*/true, env);
TestMatMul<33, 128, 32, /*kAdd=*/false, F32, BF16>(env); TestMatMul<F32, BF16>(33, 128, 32, /*add=*/false, env);
TestMatMul<33, 128, 32, /*kAdd=*/true, BF16, F32>(env); TestMatMul<BF16, F32>(33, 128, 32, /*add=*/true, env);
TestMatMul<31, 128, 32, /*kAdd=*/false, F32, SFP>(env); TestMatMul<F32, SFP>(31, 128, 32, /*add=*/false, env);
TestMatMul<29, 128, 32, /*kAdd=*/true, BF16, SFP>(env); TestMatMul<BF16, SFP>(29, 128, 32, /*add=*/true, env);
TestMatMul<4, 128, 32, /*kAdd=*/true, F32>(env); TestMatMul<F32>(4, 128, 32, /*add=*/true, env);
TestMatMul<4, 128, 32, /*kAdd=*/false, BF16>(env); TestMatMul<BF16>(4, 128, 32, /*add=*/false, env);
TestMatMul<4, 128, 32, /*kAdd=*/true, F32, BF16>(env); TestMatMul<F32, BF16>(4, 128, 32, /*add=*/true, env);
TestMatMul<4, 128, 32, /*kAdd=*/false, BF16, F32>(env); TestMatMul<BF16, F32>(4, 128, 32, /*add=*/false, env);
TestMatMul<4, 128, 32, /*kAdd=*/true, F32, SFP>(env); TestMatMul<F32, SFP>(4, 128, 32, /*add=*/true, env);
TestMatMul<4, 128, 32, /*kAdd=*/false, BF16, SFP>(env); TestMatMul<BF16, SFP>(4, 128, 32, /*add=*/false, env);
TestMatMul<3, 128, 32, /*kAdd=*/false, F32>(env); TestMatMul<F32>(3, 128, 32, /*add=*/false, env);
TestMatMul<3, 128, 32, /*kAdd=*/true, BF16>(env); TestMatMul<BF16>(3, 128, 32, /*add=*/true, env);
TestMatMul<3, 128, 32, /*kAdd=*/false, F32, BF16>(env); TestMatMul<F32, BF16>(3, 128, 32, /*add=*/false, env);
TestMatMul<3, 128, 32, /*kAdd=*/true, BF16, F32>(env); TestMatMul<BF16, F32>(3, 128, 32, /*add=*/true, env);
TestMatMul<3, 128, 32, /*kAdd=*/false, F32, SFP>(env); TestMatMul<F32, SFP>(3, 128, 32, /*add=*/false, env);
TestMatMul<3, 128, 32, /*kAdd=*/true, BF16, SFP>(env); TestMatMul<BF16, SFP>(3, 128, 32, /*add=*/true, env);
TestMatMul<2, 128, 64, /*kAdd=*/true, F32>(env); TestMatMul<F32>(2, 128, 64, /*add=*/true, env);
TestMatMul<2, 128, 64, /*kAdd=*/false, BF16>(env); TestMatMul<BF16>(2, 128, 64, /*add=*/false, env);
TestMatMul<2, 128, 64, /*kAdd=*/true, F32, BF16>(env); TestMatMul<F32, BF16>(2, 128, 64, /*add=*/true, env);
TestMatMul<2, 128, 64, /*kAdd=*/false, BF16, F32>(env); TestMatMul<BF16, F32>(2, 128, 64, /*add=*/false, env);
TestMatMul<2, 128, 64, /*kAdd=*/true, F32, SFP>(env); TestMatMul<F32, SFP>(2, 128, 64, /*add=*/true, env);
TestMatMul<2, 128, 64, /*kAdd=*/false, BF16, SFP>(env); TestMatMul<BF16, SFP>(2, 128, 64, /*add=*/false, env);
TestMatMul<1, 128, 32, /*kAdd=*/false, F32>(env); TestMatMul<F32>(1, 128, 32, /*add=*/false, env);
TestMatMul<1, 128, 32, /*kAdd=*/true, BF16>(env); TestMatMul<BF16>(1, 128, 32, /*add=*/true, env);
TestMatMul<1, 128, 32, /*kAdd=*/false, F32, BF16>(env); TestMatMul<F32, BF16>(1, 128, 32, /*add=*/false, env);
TestMatMul<1, 128, 32, /*kAdd=*/true, BF16, F32>(env); TestMatMul<BF16, F32>(1, 128, 32, /*add=*/true, env);
TestMatMul<1, 128, 32, /*kAdd=*/false, F32, SFP>(env); TestMatMul<F32, SFP>(1, 128, 32, /*add=*/false, env);
TestMatMul<1, 128, 32, /*kAdd=*/true, BF16, SFP>(env); TestMatMul<BF16, SFP>(1, 128, 32, /*add=*/true, env);
} }
// NOLINTNEXTLINE(google-readability-namespace-comments)
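
Because the dimensions are now ordinary runtime arguments, adding a case no longer instantiates a new template; a hypothetical extra case would simply be another call (dimensions below are illustrative, not from this commit):

  TestMatMul<BF16, SFP>(8, 256, 64, /*add=*/true, env);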


@@ -389,7 +389,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void ScalarRopeAndMulBy(
void TestRopeAndMulBy() {
ModelConfig config = ConfigFromModel(Model::GEMMA2_9B);
int dim_qkv = config.layer_configs[0].qkv_dim;
- RowVectorBatch<float> x(1, dim_qkv);
+ RowVectorBatch<float> x(Extents2D(1, dim_qkv));
std::mt19937 gen;
gen.seed(0x12345678);


@@ -14,7 +14,6 @@
// limitations under the License.
#include <cstdio>
- #include <memory>
#include <string>
#include <vector>
@@ -45,20 +44,20 @@ class PaliGemmaTest : public ::testing::Test {
std::string GemmaReply(const std::string& prompt_text) const;
void TestQuestions(const char* kQA[][2], size_t num_questions);
- std::unique_ptr<ImageTokens> image_tokens_;
+ ImageTokens image_tokens_;
};
void PaliGemmaTest::InitVit(const std::string& path) {
ASSERT_NE(s_env->GetModel(), nullptr);
Gemma& model = *(s_env->GetModel());
- image_tokens_ = std::make_unique<ImageTokens>(
-     model.GetModelConfig().vit_seq_len, model.GetModelConfig().model_dim);
+ image_tokens_ = ImageTokens(Extents2D(model.GetModelConfig().vit_seq_len,
+                                       model.GetModelConfig().model_dim));
Image image;
HWY_ASSERT(model.Info().training == ModelTraining::PALIGEMMA);
HWY_ASSERT(image.ReadPPM(path));
image.Resize();
RuntimeConfig runtime_config = {.verbosity = 0, .gen = &s_env->MutableGen()};
- model.GenerateImageTokens(runtime_config, image, *image_tokens_);
+ model.GenerateImageTokens(runtime_config, image, image_tokens_);
}
std::string PaliGemmaTest::GemmaReply(const std::string& prompt_text) const{
@@ -67,7 +66,7 @@ std::string PaliGemmaTest::GemmaReply(const std::string& prompt_text) const{
RuntimeConfig runtime_config = {.max_generated_tokens = 512,
                                .verbosity = 0,
                                .gen = &s_env->MutableGen()};
- runtime_config.image_tokens = image_tokens_.get();
+ runtime_config.image_tokens = &image_tokens_;
size_t abs_pos = 0;
std::string mutable_prompt = prompt_text;
std::vector<int> tokens = s_env->WrapAndTokenize(mutable_prompt);
@@ -79,7 +78,7 @@ std::string PaliGemmaTest::GemmaReply(const std::string& prompt_text) const{
  return true;
};
runtime_config.stream_token = stream_token,
- tokens.insert(tokens.begin(), image_tokens_->BatchSize(), 0);
+ tokens.insert(tokens.begin(), image_tokens_.BatchSize(), 0);
size_t num_tokens = tokens.size();
size_t prefix_end = num_tokens;
runtime_config.prefill_tbatch_size = num_tokens;


@@ -162,20 +162,19 @@ static void BindMemory(void* ptr, size_t bytes, size_t node) {
static void BindMemory(void*, size_t, size_t) {}
#endif // GEMMA_NUMA && HWY_OS_LINUX
- void BindTensor(NestedPools& nested, size_t rows, size_t cols,
+ void BindTensor(NestedPools& nested, const Extents2D& extents,
                size_t bytes_per_col, void* ptr) {
if (!Allocator::UseNUMA()) return;
uint8_t* p8 = static_cast<uint8_t*>(ptr);
- const size_t bytes_per_row = cols * bytes_per_col;
+ const size_t bytes_per_row = extents.cols * bytes_per_col;
StaticPartitionRowsAndCols(
-     nested, rows, cols, bytes_per_col,
-     [&](size_t node, hwy::ThreadPool&, const size_t /*worker_offset*/,
-         const size_t row_begin, const size_t row_end, const size_t col_begin,
-         const size_t col_end) {
-       for (size_t row = row_begin; row < row_end; ++row) {
-         uint8_t* slice = p8 + row * bytes_per_row + col_begin * bytes_per_col;
-         const size_t slice_size = (col_end - col_begin) * bytes_per_col;
-         BindMemory(slice, slice_size, node);
+     nested, extents, bytes_per_col,
+     [&](const Range2D& r, const TaskLocation& loc) {
+       for (size_t row : r.rows) {
+         uint8_t* slice =
+             p8 + row * bytes_per_row + r.cols.begin() * bytes_per_col;
+         const size_t slice_size = r.cols.Num() * bytes_per_col;
+         BindMemory(slice, slice_size, loc.node);
      }
    });
}
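
A caller now passes the tensor shape as a single Extents2D; a minimal sketch (`tensor_ptr`, `rows` and `cols` are assumed placeholders):

  // Bind each NUMA node's row/column slice of a rows x cols float tensor.
  BindTensor(nested, Extents2D(rows, cols), /*bytes_per_col=*/sizeof(float), tensor_ptr);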


@@ -22,6 +22,7 @@
#include <cstdlib> // std::aligned_alloc / _aligned_malloc
// IWYU pragma: begin_exports
+ #include "util/basics.h"
#include "util/threading.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
@@ -52,49 +53,6 @@ ByteStorageT AllocateSizeof() {
return hwy::AllocateAligned<uint8_t>(sizeof(T));
}
// Owns dynamically-allocated aligned memory for a batch of row vectors.
// This can be seen as a (batch_size x len) matrix.
template <typename T>
class RowVectorBatch {
public:
// Default ctor for Activations ctor.
RowVectorBatch() : batch_size_(0), len_(0) {}
// Main ctor, called from Activations::Allocate.
RowVectorBatch(size_t batch_size, size_t len)
: batch_size_(batch_size), len_(len) {
mem_ = hwy::AllocateAligned<T>(batch_size * len);
}
// Move-only
RowVectorBatch(RowVectorBatch&) noexcept = delete;
RowVectorBatch& operator=(RowVectorBatch&) noexcept = delete;
RowVectorBatch(RowVectorBatch&&) noexcept = default;
RowVectorBatch& operator=(RowVectorBatch&&) noexcept = default;
size_t BatchSize() const { return batch_size_; }
size_t Len() const { return len_; }
// Returns the given row vector of length `Len()`.
T* Batch(size_t batch_idx) {
HWY_DASSERT(batch_idx < batch_size_);
return mem_.get() + batch_idx * len_;
}
const T* Batch(size_t batch_idx) const {
HWY_DASSERT(batch_idx < batch_size_);
return mem_.get() + batch_idx * len_;
}
// For MatMul or other operations that process the entire batch at once.
T* All() { return mem_.get(); }
const T* Const() const { return mem_.get(); }
size_t NumBytes() const { return batch_size_ * len_ * sizeof(T); }
private:
hwy::AlignedFreeUniquePtr<T[]> mem_;
size_t batch_size_; // rows in the matrix
size_t len_; // columns in the matrix = vector length
};
// Stateful in order to know whether to bind to NUMA nodes. `Monostate` for
// convenience - avoids passing around a reference.
class Allocator {
@@ -167,10 +125,24 @@ class Allocator {
static size_t alignment_;
};
// For shorter arguments to the StaticPartitionRowsAndCols functor.
struct TaskLocation {
TaskLocation(size_t node, size_t package_idx, hwy::ThreadPool& cluster,
size_t worker_offset)
: node(node),
package_idx(package_idx),
cluster(cluster),
worker_offset(worker_offset) {}
size_t node;
size_t package_idx;
hwy::ThreadPool& cluster;
const size_t worker_offset;
};
// Used in MatMul and allocator.h. Defined here because it depends on
// Allocator::Alignment().
template <class Func>
- void StaticPartitionRowsAndCols(NestedPools& nested, size_t rows, size_t cols,
+ void StaticPartitionRowsAndCols(NestedPools& nested, Extents2D extents,
                                size_t bytes_per_element, const Func& func) {
// Both rows and cols must be a multiple of the alignment to avoid
// touching remote pages.
@@ -183,14 +155,15 @@ void StaticPartitionRowsAndCols(NestedPools& nested, size_t rows, size_t cols,
hwy::ThreadPool& all_packages = nested.AllPackages();
const size_t num_packages = all_packages.NumWorkers();
const size_t cols_per_package =
-     hwy::RoundUpTo(hwy::DivCeil(cols, num_packages), multiple);
- const size_t col_tasks = hwy::DivCeil(cols, cols_per_package);
+     hwy::RoundUpTo(hwy::DivCeil(extents.cols, num_packages), multiple);
+ const size_t col_tasks = hwy::DivCeil(extents.cols, cols_per_package);
HWY_ASSERT(col_tasks <= num_packages);
all_packages.Run(
    0, col_tasks, [&](uint64_t package_idx, size_t package_thread) {
      HWY_ASSERT(package_idx == package_thread); // one task per worker
      const size_t col_begin = package_idx * cols_per_package;
-       const size_t col_end = HWY_MIN(col_begin + cols_per_package, cols);
+       const Range1D col_range =
+           MakeRange1D(col_begin, extents.cols, cols_per_package);
      // Static partitioning of rows across the package's clusters. We assume
      // that row sharding is cheaper. In MatMul, results can indeed be
@@ -198,8 +171,8 @@ void StaticPartitionRowsAndCols(NestedPools& nested, size_t rows, size_t cols,
      hwy::ThreadPool& all_clusters = nested.AllClusters(package_idx);
      const size_t num_clusters = all_clusters.NumWorkers();
      const size_t rows_per_cluster =
-           hwy::RoundUpTo(hwy::DivCeil(rows, num_clusters), multiple);
-       const size_t row_tasks = hwy::DivCeil(rows, rows_per_cluster);
+           hwy::RoundUpTo(hwy::DivCeil(extents.rows, num_clusters), multiple);
+       const size_t row_tasks = hwy::DivCeil(extents.rows, rows_per_cluster);
      HWY_ASSERT(row_tasks <= num_clusters);
      all_clusters.Run(
          0, row_tasks, [&](uint64_t cluster_idx, size_t cluster_thread) {
@@ -217,11 +190,11 @@ void StaticPartitionRowsAndCols(NestedPools& nested, size_t rows, size_t cols,
            nested.WorkerOffset(package_idx, cluster_idx);
            const size_t row_begin = cluster_idx * rows_per_cluster;
-             const size_t row_end =
-                 HWY_MIN(row_begin + rows_per_cluster, rows);
-             func(node, cluster, worker_offset, row_begin, row_end, col_begin,
-                  col_end);
+             const Range1D row_range =
+                 MakeRange1D(row_begin, extents.rows, rows_per_cluster);
+             func(Range2D(row_range, col_range),
+                  TaskLocation(node, package_idx, cluster, worker_offset));
          });
    });
}
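
The functor now receives a Range2D plus a TaskLocation instead of six scalars; a minimal sketch of the new callback shape, mirroring the BindTensor caller above (`nested`, `rows` and `cols` are assumed placeholders):

  StaticPartitionRowsAndCols(
      nested, Extents2D(rows, cols), /*bytes_per_element=*/sizeof(float),
      [&](const Range2D& r, const TaskLocation& loc) {
        // r.rows / r.cols are the Range1D assigned to loc.cluster on NUMA node
        // loc.node; loc.worker_offset maps cluster threads to global workers.
        for (size_t row : r.rows) {
          for (size_t col : r.cols) {
            // ... touch element (row, col) ...
          }
        }
      });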


@@ -28,6 +28,7 @@
#include "gemma/common.h"
#include "gemma/gemma.h" // For CreateGemma
#include "util/args.h"
+ #include "util/basics.h" // Tristate
#include "util/threading.h"
#include "hwy/base.h" // HWY_IS_ASAN
@@ -59,7 +60,9 @@ class AppArgs : public ArgsBase<AppArgs> {
int verbosity;
size_t max_threads; // divided among the detected clusters
- int pin; // -1 = auto, 0 = no, 1 = yes
+ Tristate pin;  // pin threads?
+ Tristate spin; // use spin waits?
// For BoundedSlice:
size_t skip_packages;
size_t max_packages;
@@ -81,7 +84,10 @@ class AppArgs : public ArgsBase<AppArgs> {
// The exact meaning is more subtle: see the comment at NestedPools ctor.
visitor(max_threads, "num_threads", size_t{0},
        "Maximum number of threads to use; default 0 = unlimited.", 2);
- visitor(pin, "pin", -1, "Pin threads? -1 = auto, 0 = no, 1 = yes.", 2);
+ visitor(pin, "pin", Tristate::kDefault,
+         "Pin threads? -1 = auto, 0 = no, 1 = yes.", 2);
+ visitor(spin, "spin", Tristate::kDefault,
+         "Use spin waits? -1 = auto, 0 = no, 1 = yes.", 2);
// These can be used to partition CPU sockets/packages and their
// clusters/CCXs across several program instances. The default is to use
// all available resources.


@ -24,6 +24,7 @@
#include <string> #include <string>
#include "compression/io.h" #include "compression/io.h"
#include "util/basics.h" // Tristate
#include "hwy/base.h" // HWY_ABORT #include "hwy/base.h" // HWY_ABORT
namespace gcpp { namespace gcpp {
@ -62,6 +63,13 @@ class ArgsBase {
} }
} }
void operator()(const Tristate& t, const char* name,
const Tristate& /*init*/, const char* /*help*/,
int print_verbosity = 0) const {
if (verbosity_ >= print_verbosity) {
fprintf(stderr, "%-30s: %s\n", name, ToString(t));
}
}
void operator()(const std::string& t, const char* name, void operator()(const std::string& t, const char* name,
const std::string& /*init*/, const char* /*help*/, const std::string& /*init*/, const char* /*help*/,
int print_verbosity = 0) const { int print_verbosity = 0) const {
@ -127,13 +135,33 @@ class ArgsBase {
return true; return true;
} }
static bool SetValue(const char* string, bool& t) { // Returns lower-cased string. Arg names are expected to be ASCII-only.
static std::string ToLower(const char* string) {
std::string value(string); std::string value(string);
// Lower-case. Arg names are expected to be ASCII-only.
std::transform(value.begin(), value.end(), value.begin(), [](char c) { std::transform(value.begin(), value.end(), value.begin(), [](char c) {
return 'A' <= c && c <= 'Z' ? c - ('Z' - 'z') : c; return 'A' <= c && c <= 'Z' ? c - ('Z' - 'z') : c;
}); });
return value;
}
static bool SetValue(const char* string, Tristate& t) {
const std::string value = ToLower(string);
if (value == "true" || value == "on" || value == "1") {
t = Tristate::kTrue;
return true;
} else if (value == "false" || value == "off" || value == "0") {
t = Tristate::kFalse;
return true;
} else if (value == "default" || value == "auto" || value == "-1") {
t = Tristate::kDefault;
return true;
} else {
return false;
}
}
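// Quick reference for the accepted spellings (mirrors the branches above):
//   true|on|1 -> kTrue, false|off|0 -> kFalse, default|auto|-1 -> kDefault,
// so e.g. `--pin auto` and `--pin -1` both yield Tristate::kDefault.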
static bool SetValue(const char* string, bool& t) {
const std::string value = ToLower(string);
if (value == "true" || value == "on" || value == "1") { if (value == "true" || value == "on" || value == "1") {
t = true; t = true;
return true; return true;


@ -20,7 +20,8 @@
#include <stddef.h> #include <stddef.h>
#include <stdint.h> #include <stdint.h>
#include "hwy/base.h" #include "hwy/aligned_allocator.h"
#include "hwy/base.h" // HWY_IS_MSAN
// IWYU pragma: end_exports // IWYU pragma: end_exports
#if HWY_IS_MSAN #if HWY_IS_MSAN
@ -29,6 +30,19 @@
namespace gcpp { namespace gcpp {
enum class Tristate : int32_t { kFalse = 0, kTrue = 1, kDefault = -1 };
static inline const char* ToString(Tristate t) {
switch (t) {
case Tristate::kFalse:
return "false";
case Tristate::kTrue:
return "true";
case Tristate::kDefault:
return "default";
  }
  return "?";  // unreachable, but avoids -Wreturn-type warnings
}
using BF16 = hwy::bfloat16_t; using BF16 = hwy::bfloat16_t;
static inline void MaybeCheckInitialized(const void* ptr, size_t size) { static inline void MaybeCheckInitialized(const void* ptr, size_t size) {
@ -46,6 +60,195 @@ struct TokenAndProb {
float prob; float prob;
}; };
// Entire size of a 2D array. By contrast, Range2D is a subrange.
struct Extents2D {
Extents2D() : rows(0), cols(0) {}
Extents2D(size_t rows, size_t cols) : rows(rows), cols(cols) {
HWY_DASSERT(rows != 0);
HWY_DASSERT(cols != 0);
}
size_t Area() const { return rows * cols; }
size_t rows;
size_t cols;
};
// Range2D consists of two Range1D.
struct Range1D {
Range1D(size_t begin, size_t end) : begin_(begin), end_(end) {
HWY_DASSERT(begin < end);
}
size_t Num() const { return end_ - begin_; }
// Enable range-based for loops.
class Iterator {
public:
Iterator(size_t i) : i_(i) {}
Iterator& operator++() {
++i_;
return *this;
}
bool operator!=(const Iterator& other) const { return i_ != other.i_; }
size_t operator*() const { return i_; }
// Enable using begin() directly as a size_t.
operator size_t() const { return i_; }
private:
size_t i_;
};
Iterator begin() const { return Iterator(begin_); }
Iterator end() const { return Iterator(end_); }
const size_t begin_;
const size_t end_;
};
static inline Range1D MakeRange1D(size_t begin, size_t end, size_t max_size) {
return Range1D(begin, HWY_MIN(begin + max_size, end));
}
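// Example (illustrative): MakeRange1D(6, 10, 4) clamps to [6, 10). Range1D
// also supports range-based for loops:
//   for (size_t i : Range1D(2, 5)) { ... }  // visits 2, 3, 4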
// In MatMul, the two axes are used independently, hence we do not define
// Range2D as a top-left and extents.
struct Range2D {
Range2D(Range1D rows, Range1D cols) : rows(rows), cols(cols) {}
const Range1D rows;
const Range1D cols;
};
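// Sketch of forming one worker's subrange from row/col tiles, as in the
// partitioning code in this change (rows_per_task/cols_per_task are
// hypothetical tile sizes):
//   const Range2D tile(MakeRange1D(row_begin, extents.rows, rows_per_task),
//                      MakeRange1D(col_begin, extents.cols, cols_per_task));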
// Lightweight version of `MatPtr` used for the C argument of `MatMul`, because
// it is always float and does not support compressed T, but does support an
// arbitrary stride >= cols.
template <typename T>
class RowPtr {
public:
RowPtr(T* HWY_RESTRICT row0, size_t cols)
: row0_(row0), cols_(cols), stride_(cols) {}
T* HWY_RESTRICT Row(size_t r) const { return row0_ + stride_ * r; }
size_t Cols() const { return cols_; }
size_t Stride() const { return stride_; }
void SetStride(size_t stride) {
HWY_DASSERT(stride >= Cols());
stride_ = stride;
}
private:
T* HWY_RESTRICT row0_;
size_t cols_;
size_t stride_;
};
using RowPtrF = RowPtr<float>;
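// Usage sketch (hypothetical buffer `out`): rows may be padded so that each
// starts at a multiple of `padded_cols` >= cols:
//   RowPtrF C(out, cols);
//   C.SetStride(padded_cols);
//   C.Row(r)[c] = 0.0f;  // element (r, c)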
// Owns dynamically-allocated aligned memory for a batch of row vectors.
// This can be seen as a (batch_size x cols) matrix. Unlike `RowPtr`, this owns
// the memory.
template <typename T>
class RowVectorBatch {
public:
// Default ctor for Activations ctor.
RowVectorBatch() = default;
// Main ctor, called from Activations::Allocate.
RowVectorBatch(Extents2D extents) : extents_(extents) {
mem_ = hwy::AllocateAligned<T>(extents_.rows * extents_.cols);
}
// Move-only
RowVectorBatch(const RowVectorBatch&) = delete;
RowVectorBatch& operator=(const RowVectorBatch&) = delete;
RowVectorBatch(RowVectorBatch&&) noexcept = default;
RowVectorBatch& operator=(RowVectorBatch&&) noexcept = default;
size_t BatchSize() const { return extents_.rows; }
size_t Cols() const { return extents_.cols; }
Extents2D Extents() const { return extents_; }
// Returns the given row vector of length `Cols()`.
T* Batch(size_t batch_idx) {
HWY_DASSERT(batch_idx < BatchSize());
return mem_.get() + batch_idx * Cols();
}
const T* Batch(size_t batch_idx) const {
HWY_DASSERT(batch_idx < BatchSize());
return mem_.get() + batch_idx * Cols();
}
// For MatMul or other operations that process the entire batch at once.
// TODO: remove once we only use Mat.
T* All() { return mem_.get(); }
const T* Const() const { return mem_.get(); }
size_t NumBytes() const { return BatchSize() * Cols() * sizeof(T); }
private:
hwy::AlignedFreeUniquePtr<T[]> mem_;
Extents2D extents_;
};
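// Usage sketch (batch_size/model_dim are placeholders): allocate activations
// and write into one row:
//   RowVectorBatch<float> x(Extents2D(batch_size, model_dim));
//   float* HWY_RESTRICT row = x.Batch(i);
//   row[0] = 1.0f;  // row has x.Cols() elements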
// Used for the A and B arguments of `MatMul`, which are always const.
// Create via MakeConstMat. This differs from `RowPtr` in that it supports the
// `ofs` required for compressed T.
template <typename T>
struct ConstMat {
ConstMat(const T* ptr, Extents2D extents, size_t ofs = 0)
: ptr(ptr), extents(extents), ofs(ofs) {
HWY_DASSERT(ptr != nullptr);
}
// TODO: support stride for page alignment.
size_t Row(size_t r) const {
if constexpr (HWY_IS_DEBUG_BUILD) {
if (r >= extents.rows) {
HWY_ABORT("ConstMat::Row %zu out of bounds %zu", r, extents.rows);
}
}
return ofs + extents.cols * r;
}
const Extents2D& Extents() const { return extents; }
// Shrinks the row-extent of this matrix view, i.e. reduces the view to a
// subrange of the original rows starting at row 0.
void ShrinkRows(size_t rows) {
HWY_ASSERT(rows <= extents.rows);
extents.rows = rows;
}
const T* HWY_RESTRICT ptr;
Extents2D extents;
// `scale` allows expanding the smaller range of `SfpStream` to the original
// values. MatFromWeights sets this from `MatPtr`.
float scale = 1.0f;
// Offset to add to `ptr`; separate because T=NuqStream does not support
// pointer arithmetic.
size_t ofs;
};
// For deducing T.
template <typename T>
ConstMat<T> MakeConstMat(T* HWY_RESTRICT ptr, Extents2D extents,
size_t ofs = 0) {
return ConstMat<T>(ptr, extents, ofs);
}
// For A argument to MatMul (activations).
template <typename T>
ConstMat<T> ConstMatFromBatch(size_t batch_size,
const RowVectorBatch<T>& row_vectors) {
HWY_DASSERT(batch_size <= row_vectors.BatchSize());
return MakeConstMat(const_cast<T*>(row_vectors.Const()),
Extents2D(batch_size, row_vectors.Cols()));
}
// For C argument to MatMul.
template <typename T>
RowPtr<T> RowPtrFromBatch(RowVectorBatch<T>& row_vectors) {
return RowPtr<T>(row_vectors.All(), row_vectors.Cols());
}
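// Putting the adapters together (sketch; MatMul's exact parameter list is not
// shown here, so the call below is an assumption): A is activations, B is a
// weight matrix, C receives the output rows.
//   const auto A = ConstMatFromBatch(batch_size, activations);
//   RowPtrF C = RowPtrFromBatch(out);
//   MatMul(A, B, /*add=*/nullptr, env, C);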
} // namespace gcpp } // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_UTIL_BASICS_H_ #endif // THIRD_PARTY_GEMMA_CPP_UTIL_BASICS_H_

util/threading.cc Normal file

@ -0,0 +1,400 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "util/threading.h"
#include <stdio.h>
#include <algorithm> // std::sort
#include <atomic>
#include <memory> // std::make_unique
#include <utility> // std::move
#include <vector>
// Placeholder for container detection, do not remove
#include "util/basics.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/contrib/thread_pool/topology.h"
namespace gcpp {
// Sort T := packages/clusters by descending 'size' so that users who only use
// one Group get the largest.
template <class T>
static void SortByDescendingSize(std::vector<T>& groups) {
std::sort(groups.begin(), groups.end(),
[](const T& a, const T& b) { return a.Size() > b.Size(); });
}
BoundedTopology::BoundedTopology(BoundedSlice package_slice,
BoundedSlice cluster_slice,
BoundedSlice lp_slice) {
// Regardless of topology, ignore LPs disabled via OS, taskset, or numactl.
LPS enabled_lps;
if (HWY_UNLIKELY(!GetThreadAffinity(enabled_lps))) {
const size_t num_lps = hwy::TotalLogicalProcessors();
fprintf(stderr,
"Warning, unknown OS affinity, considering all %zu LPs enabled\n.",
num_lps);
for (size_t lp = 0; lp < num_lps; ++lp) {
enabled_lps.Set(lp);
}
}
// Without threading support, only keep the first enabled LP; it might still
// make sense to pin the main thread to avoid migrations.
if (HWY_UNLIKELY(!hwy::HaveThreadingSupport())) {
HWY_ASSERT(enabled_lps.Any());
const size_t lp = enabled_lps.First();
enabled_lps = LPS();
enabled_lps.Set(lp);
fprintf(stderr,
"Warning, threads not supported, using only the main thread\n.");
}
#if !GEMMA_DISABLE_TOPOLOGY
if (HWY_LIKELY(!topology_.packages.empty())) {
InitFromTopology(enabled_lps, package_slice, cluster_slice);
}
#endif
// Topology unknown or no packages with enabled LPs: create a single
// package with one cluster, and one node.
if (HWY_UNLIKELY(NumPackages() == 0)) {
InitFromSlice(enabled_lps, lp_slice);
}
HWY_ASSERT(NumPackages() != 0 && NumClusters(0) != 0 && NumNodes() != 0);
}
// Topology is unknown; rely on OS affinity and the user-specified slice.
BoundedTopology::Cluster::Cluster(const LPS& enabled_lps,
BoundedSlice lp_slice) {
// Interpret `lp_slice` as a slice of the 1-bits of `enabled_lps`, so
// we honor both the OS affinity and the user-specified slice. Note that
// this can be used to exclude hyperthreads because Linux groups LPs by
// sibling index. For example, the first `num_cores` are not siblings.
const size_t detected = enabled_lps.Count();
size_t enabled_idx = 0;
enabled_lps.Foreach([&](size_t lp) {
if (lp_slice.Contains(detected, enabled_idx++)) {
AddLP(lp);
}
});
// lp_slice can only reduce the number of `enabled_lps`, and not below 1.
HWY_ASSERT(num_workers_ != 0);
}
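// Example (assuming Linux's LP numbering): with 8 cores x 2 SMT, LPs 0-7 are
// distinct cores and 8-15 their siblings, so lp_slice = BoundedSlice(0, 8)
// keeps one LP per core and thereby excludes hyperthreads.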
BoundedTopology::Cluster::Cluster(const LPS& enabled_lps,
const std::vector<hwy::Topology::LP>& all_lps,
const hwy::Topology::Cluster& tcluster) {
bool is_first_lp = true;
tcluster.lps.Foreach([&](size_t lp) {
// Skip if not first-hyperthread or disabled.
if (all_lps[lp].smt != 0 || !enabled_lps.Get(lp)) return;
AddLP(lp);
// Set `node` once, and ensure subsequent nodes match - we assume there
// is only one NUMA node per cluster.
const size_t lp_node = static_cast<size_t>(all_lps[lp].node);
if (is_first_lp) {
is_first_lp = false;
node_ = lp_node;
} else {
static bool warned = false;
if (lp_node != node_ && !warned) {
warned = true;
fprintf(stderr, "WARNING: lp %zu on node %zu != cluster node %zu.\n",
lp, lp_node, node_);
}
}
});
}
// NOTE: caller is responsible for checking whether `clusters` is empty.
BoundedTopology::Package::Package(const LPS& enabled_lps,
const hwy::Topology& topology,
size_t package_idx,
BoundedSlice cluster_slice) {
const hwy::Topology::Package& tpackage = topology.packages[package_idx];
// Populate `clusters` with the subset of clusters in `cluster_slice` that
// have any enabled LPs. If `clusters` remains empty, the caller will
// skip this `Package`.
clusters.reserve(cluster_slice.Num(tpackage.clusters.size()));
cluster_slice.Foreach(
"cluster", tpackage.clusters.size(), [&](size_t cluster_idx) {
const hwy::Topology::Cluster& tcluster = tpackage.clusters[cluster_idx];
Cluster cluster(enabled_lps, topology.lps, tcluster);
// Skip if empty, i.e. too few `enabled_lps`.
if (HWY_LIKELY(cluster.Size() != 0)) {
clusters.push_back(std::move(cluster));
}
});
SortByDescendingSize(clusters);
}
#if !GEMMA_DISABLE_TOPOLOGY
static size_t CoresFromLPs(const LPS& lps, const hwy::Topology& topology) {
LPS cores;
lps.Foreach([&](size_t lp) {
if (topology.lps[lp].smt == 0) cores.Set(lp);
});
return cores.Count();
}
// Scans hwy::Topology for clusters and their sizes, for use in topology_string_.
static void ScanTClusters(const hwy::Topology& topology_, size_t& max_tclusters,
size_t& max_tcluster_cores,
size_t& max_tcluster_lps) {
max_tclusters = 0;
max_tcluster_cores = 0;
max_tcluster_lps = 0;
for (size_t package_idx = 0; package_idx < topology_.packages.size();
++package_idx) {
const std::vector<hwy::Topology::Cluster>& tclusters =
topology_.packages[package_idx].clusters;
max_tclusters = HWY_MAX(max_tclusters, tclusters.size());
size_t tcluster_cores = 0;
size_t tcluster_lps = 0;
for (size_t cluster_idx = 0; cluster_idx < tclusters.size();
++cluster_idx) {
const size_t cores = CoresFromLPs(tclusters[cluster_idx].lps, topology_);
const size_t lps = tclusters[cluster_idx].lps.Count();
tcluster_cores = HWY_MAX(tcluster_cores, cores);
tcluster_lps = HWY_MAX(tcluster_lps, lps);
}
if (tclusters.size() > 1 && tcluster_cores > 8) {
fprintf(stderr,
"Package %zu: multiple clusters with max size %zu, whereas CCX "
"only have 8, may indicate a bug in hwy::Topology.\n",
package_idx, tcluster_cores);
}
max_tcluster_cores = HWY_MAX(max_tcluster_cores, tcluster_cores);
max_tcluster_lps = HWY_MAX(max_tcluster_lps, tcluster_lps);
}
HWY_ASSERT(max_tclusters != 0);
HWY_ASSERT(max_tcluster_cores != 0);
HWY_ASSERT(max_tcluster_lps >= max_tcluster_cores);
}
// Main part of ctor, called when topology is known.
void BoundedTopology::InitFromTopology(const LPS& enabled_lps,
BoundedSlice package_slice,
BoundedSlice cluster_slice) {
size_t max_tclusters, max_tcluster_cores, max_tcluster_lps;
ScanTClusters(topology_, max_tclusters, max_tcluster_cores, max_tcluster_lps);
// (Possibly empty) subset of `Topology` packages that have `enabled_lps`.
package_slice.Foreach(
"package", topology_.packages.size(), [&](size_t package_idx) {
Package package(enabled_lps, topology_, package_idx, cluster_slice);
// Skip if empty, i.e. too few `enabled_lps`.
if (HWY_LIKELY(!package.clusters.empty())) {
packages_.push_back(std::move(package));
}
});
if (NumPackages() == 0) return;
SortByDescendingSize(packages_);
// Remember NUMA nodes that we are actually using (not just enabled).
for (const Package& p : packages_) {
for (const Cluster& c : p.clusters) {
nodes_.Set(c.Node());
}
}
// Scan for max BoundedTopology clusters and their size, for topology_string_.
size_t all_max_cluster_size = 0;
for (size_t package_idx = 0; package_idx < NumPackages(); ++package_idx) {
size_t max_cluster_size = 0;
for (size_t cluster_idx = 0; cluster_idx < NumClusters(package_idx);
++cluster_idx) {
max_cluster_size = HWY_MAX(max_cluster_size,
GetCluster(package_idx, cluster_idx).Size());
}
if (NumClusters(package_idx) > 1 && max_cluster_size > 8) {
fprintf(stderr,
"Package %zu: multiple clusters with max size %zu, whereas CCX "
"only have 8, may indicate a bug in BoundedTopology.\n",
package_idx, max_cluster_size);
}
all_max_cluster_size = HWY_MAX(all_max_cluster_size, max_cluster_size);
}
snprintf(topology_string_, sizeof(topology_string_),
"%zuS %zuX %zuC %zuH, using %zuS %zuX %zuC (nodes=%zu)",
topology_.packages.size(), max_tclusters, max_tcluster_cores,
max_tcluster_lps / max_tcluster_cores, packages_.size(),
NumClusters(0), all_max_cluster_size, nodes_.Count());
}
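// Example topology_string_ for a hypothetical machine: "1S 2X 8C 2H, using
// 1S 2X 8C (nodes=1)" = 1 socket with 2 clusters of 8 cores and 2 hyperthreads
// per core, all of which are in use, on a single NUMA node.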
#endif // !GEMMA_DISABLE_TOPOLOGY
void BoundedTopology::InitFromSlice(const LPS& enabled_lps,
BoundedSlice lp_slice) {
packages_.push_back(Package(enabled_lps, lp_slice));
snprintf(topology_string_, sizeof(topology_string_), "LPs=%zu",
GetCluster(0, 0).Size());
// Assume a single NUMA node.
nodes_.Set(0);
HWY_ASSERT(NumNodes() == 1);
}
static PoolPtr MakePool(size_t num_workers) {
// `ThreadPool` expects the number of threads to create, which is one less
// than the number of workers, but avoid underflow if zero.
const size_t num_threads = num_workers == 0 ? 0 : num_workers - 1;
return std::make_unique<hwy::ThreadPool>(num_threads);
}
static bool InContainer() { return false; }
class NestedPools::Pinning {
public:
Pinning(Tristate pin, const BoundedTopology& topology) {
if (pin == Tristate::kDefault) {
// Pinning is unreliable inside containers because the hypervisor might
// periodically change our affinity mask, or other processes might also
// pin themselves to the same LPs.
pin = InContainer() ? Tristate::kFalse : Tristate::kTrue;
}
want_pin_ = (pin == Tristate::kTrue);
}
// If want_pin_, tries to pin each worker in `pool` to an LP in `cluster`,
// and sets `any_error_` if any fails.
void MaybePin(const BoundedTopology::Cluster& cluster, PoolPtr& pool) {
if (HWY_UNLIKELY(!want_pin_)) return;
const std::vector<size_t> lps = cluster.LPVector();
HWY_ASSERT(pool->NumWorkers() <= lps.size());
pool->Run(
0, pool->NumWorkers(),
[this, &pool, &lps](uint64_t task, size_t thread) {
HWY_ASSERT(task == thread); // each worker has one task
if (HWY_UNLIKELY(!hwy::PinThreadToLogicalProcessor(lps[task]))) {
fprintf(stderr,
"Pinning failed for task %zu of %zu to %zu (size %zu)\n",
task, pool->NumWorkers(), lps[task], lps.size());
(void)any_error_.test_and_set();
}
});
}
bool WantPin() const { return want_pin_; }
// Called ONCE after all MaybePin because it invalidates the error status.
bool AllPinned() {
// If !want_pin_, MaybePin will return without setting any_error_, but in
// that case we still want to return false to avoid spinning.
// .test() was only added in C++20, so we use .test_and_set() instead.
return want_pin_ && !any_error_.test_and_set();
}
private:
std::atomic_flag any_error_ = ATOMIC_FLAG_INIT;
bool want_pin_; // set in ctor
}; // Pinning
// Used to divide max_threads and max_workers_per_package across packages and
// clusters. Ensures small upper bounds are respected.
static size_t DivideMaxAcross(const size_t max, const size_t instances) {
// No limit.
if (max == 0) return 0;
// We have enough to distribute.
if (max >= instances) return max / instances;
// Use max as the upper bound for each instance because division would return
// zero, which means 'unlimited'.
return max;
}
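// E.g. DivideMaxAcross(0, 4) == 0 (unlimited), DivideMaxAcross(12, 4) == 3,
// and DivideMaxAcross(2, 4) == 2: the small bound applies to each instance.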
NestedPools::NestedPools(size_t max_threads, Tristate pin,
BoundedSlice package_slice, BoundedSlice cluster_slice,
BoundedSlice lp_slice)
: topology_(package_slice, cluster_slice, lp_slice) {
Pinning pinning(pin, topology_);
packages_.resize(topology_.NumPackages());
all_packages_ = MakePool(packages_.size());
const size_t max_workers_per_package =
DivideMaxAcross(max_threads, packages_.size());
// Each worker in all_packages_, including the main thread, will be the
// calling thread of an all_clusters->Run, and hence pinned to one of the
// `cluster.lps` if `pin`.
all_packages_->Run(
0, all_packages_->NumWorkers(), [&](uint64_t package_idx, size_t thread) {
HWY_ASSERT(package_idx == thread); // each thread has one task
packages_[package_idx] = Package(
topology_, package_idx, max_workers_per_package, pinning, lp_slice);
});
all_pinned_ = pinning.AllPinned();
pin_string_ = all_pinned_ ? "pinned"
: pinning.WantPin() ? "pinning failed"
: "pinning skipped";
// For mapping package/cluster/thread to noncontiguous TLS indices, in case
// cluster/thread counts differ.
HWY_ASSERT(!packages_.empty() && packages_.size() <= 16);
for (const Package& p : packages_) {
max_clusters_per_package_ =
HWY_MAX(max_clusters_per_package_, p.NumClusters());
max_workers_per_cluster_ =
HWY_MAX(max_workers_per_cluster_, p.MaxWorkersPerCluster());
}
HWY_ASSERT(max_clusters_per_package_ >= 1);
HWY_ASSERT(max_clusters_per_package_ <= 64);
HWY_ASSERT(max_workers_per_cluster_ >= 1);
HWY_ASSERT(max_workers_per_cluster_ <= 256);
}
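// Construction sketch (flag values are illustrative): use all detected
// resources and decide pinning automatically:
//   NestedPools pools(/*max_threads=*/0, /*pin=*/Tristate::kDefault);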
// `max_or_zero` == 0 means no limit.
static inline size_t CapIfNonZero(size_t num, size_t max_or_zero) {
return (max_or_zero == 0) ? num : HWY_MIN(num, max_or_zero);
}
NestedPools::Package::Package(const BoundedTopology& topology,
size_t package_idx,
size_t max_workers_per_package, Pinning& pinning,
BoundedSlice lp_slice) {
// Pre-allocate because elements are set concurrently.
clusters_.resize(topology.NumClusters(package_idx));
const size_t max_workers_per_cluster =
DivideMaxAcross(max_workers_per_package, clusters_.size());
all_clusters_ = MakePool(clusters_.size());
// Parallel so we also pin the calling worker in `all_clusters` to
// `cluster.lps`.
all_clusters_->Run(
0, all_clusters_->NumWorkers(), [&](size_t cluster_idx, size_t thread) {
HWY_ASSERT(cluster_idx == thread); // each thread has one task
const BoundedTopology::Cluster& cluster =
topology.GetCluster(package_idx, cluster_idx);
clusters_[cluster_idx] =
MakePool(CapIfNonZero(cluster.Size(), max_workers_per_cluster));
// Pin workers AND the calling thread from `all_clusters`.
pinning.MaybePin(cluster, clusters_[cluster_idx]);
});
}
} // namespace gcpp


@ -17,13 +17,11 @@
#define THIRD_PARTY_GEMMA_CPP_UTIL_THREADING_H_ #define THIRD_PARTY_GEMMA_CPP_UTIL_THREADING_H_
#include <stddef.h> #include <stddef.h>
#include <stdio.h>
#include <algorithm> // std::sort
#include <memory> // std::unique_ptr #include <memory> // std::unique_ptr
#include <utility> // std::move
#include <vector> #include <vector>
#include "util/basics.h" // Tristate
#include "hwy/base.h" // HWY_ASSERT #include "hwy/base.h" // HWY_ASSERT
#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/contrib/thread_pool/topology.h" #include "hwy/contrib/thread_pool/topology.h"
@ -78,6 +76,10 @@ class BoundedSlice {
// "LP" is a logical processor, a 0-based index passed to the OS. // "LP" is a logical processor, a 0-based index passed to the OS.
using LPS = hwy::LogicalProcessorSet; using LPS = hwy::LogicalProcessorSet;
// We want vectors of hwy::ThreadPool, which is unfortunately not movable,
// hence we wrap them in unique_ptr.
using PoolPtr = std::unique_ptr<hwy::ThreadPool>;
// Wraps hwy::Topology and only keeps the subset of packages and clusters // Wraps hwy::Topology and only keeps the subset of packages and clusters
// apportioned by BoundedSlice, further limited by the OS affinity mask. // apportioned by BoundedSlice, further limited by the OS affinity mask.
// NOTE: if topology is unknown or the OS affinity is too restrictive, we fall // NOTE: if topology is unknown or the OS affinity is too restrictive, we fall
@ -85,96 +87,18 @@ using LPS = hwy::LogicalProcessorSet;
class BoundedTopology { class BoundedTopology {
public: public:
BoundedTopology(BoundedSlice package_slice, BoundedSlice cluster_slice, BoundedTopology(BoundedSlice package_slice, BoundedSlice cluster_slice,
BoundedSlice lp_slice) { BoundedSlice lp_slice);
// Regardless of topology, ignore LPs disabled via OS, taskset, or numactl.
LPS enabled_lps;
if (HWY_UNLIKELY(!GetThreadAffinity(enabled_lps))) {
const size_t num_lps = hwy::TotalLogicalProcessors();
fprintf(
stderr,
"Warning, unknown OS affinity, considering all %zu LPs enabled\n.",
num_lps);
for (size_t lp = 0; lp < num_lps; ++lp) {
enabled_lps.Set(lp);
}
}
// Without threading support, only keep the first enabled LP; it might still
// make sense to pin the main thread.
if (HWY_UNLIKELY(!hwy::HaveThreadingSupport())) {
HWY_ASSERT(enabled_lps.Any());
const size_t lp = enabled_lps.First();
enabled_lps = LPS();
enabled_lps.Set(lp);
}
#if !GEMMA_DISABLE_TOPOLOGY
if (HWY_LIKELY(!topology_.packages.empty())) {
InitFromTopology(enabled_lps, package_slice, cluster_slice);
}
#endif
// Topology unknown, disabled or no packages with enabled LPs: create a
// single package with one cluster, and one node.
if (HWY_UNLIKELY(NumPackages() == 0)) {
InitFromSlice(enabled_lps, lp_slice);
}
HWY_ASSERT(NumPackages() != 0 && NumClusters(0) != 0 && NumNodes() != 0);
}
size_t NumPackages() const { return packages_.size(); } size_t NumPackages() const { return packages_.size(); }
const char* TopologyString() const { return topology_string_; }
size_t NumNodes() const { return nodes_.Count(); } size_t NumNodes() const { return nodes_.Count(); }
const char* TopologyString() const { return topology_string_; }
class Cluster { class Cluster {
public: public:
// Topology is unknown, rely on OS affinity and user-specified slice. Cluster(const LPS& enabled_lps, BoundedSlice lp_slice);
Cluster(const LPS& enabled_lps, BoundedSlice lp_slice) {
// Interpret `lp_slice` as a slice of the 1-bits of `enabled_lps`, so
// we honor both the OS affinity and the user-specified slice. Note that
// this can be used to exclude hyperthreads because Linux groups LPs by
// sibling index. For example, the first `num_cores` are not siblings.
const size_t detected = enabled_lps.Count();
size_t enabled_idx = 0;
enabled_lps.Foreach([&](size_t lp) {
if (lp_slice.Contains(detected, enabled_idx++)) {
AddLP(lp);
}
});
// lp_slice can only reduce the number of `enabled_lps`, and not below 1.
HWY_ASSERT(num_workers_ != 0);
}
Cluster(const LPS& enabled_lps, Cluster(const LPS& enabled_lps,
const std::vector<hwy::Topology::LP>& all_lps, const std::vector<hwy::Topology::LP>& all_lps,
const hwy::Topology::Cluster& tcluster) { const hwy::Topology::Cluster& tcluster);
bool is_first_lp = true;
tcluster.lps.Foreach([&](size_t lp) {
// Skip if not first-hyperthread or disabled.
if (all_lps[lp].smt != 0 || !enabled_lps.Get(lp)) return;
AddLP(lp);
// Set `node` once, and ensure subsequent nodes match - we assume there
// is only one NUMA node per cluster.
const size_t lp_node = static_cast<size_t>(all_lps[lp].node);
if (is_first_lp) {
is_first_lp = false;
node_ = lp_node;
} else {
static bool warned = false;
if (lp_node != node_ && !warned) {
warned = true;
fprintf(stderr,
"WARNING: lp %zu on node %zu != cluster node %zu.\n", lp,
lp_node, node_);
}
}
});
}
// For SortByDescendingSize. // For SortByDescendingSize.
size_t Size() const { return num_workers_; } size_t Size() const { return num_workers_; }
@ -221,53 +145,15 @@ class BoundedTopology {
return package.clusters[cluster_idx]; return package.clusters[cluster_idx];
} }
// Returns total number of cluster workers, for deciding whether to pin.
size_t TotalWorkers() const {
size_t total_workers = 0;
for (size_t package_idx = 0; package_idx < NumPackages(); ++package_idx) {
const size_t num_clusters = NumClusters(package_idx);
for (size_t cluster_idx = 0; cluster_idx < num_clusters; ++cluster_idx) {
total_workers += GetCluster(package_idx, cluster_idx).Size();
}
}
return total_workers;
}
private: private:
// Sort T := packages/clusters by descending 'size' so that users who only use
// one Group get the largest.
template <class T>
static void SortByDescendingSize(std::vector<T>& groups) {
std::sort(groups.begin(), groups.end(),
[](const T& a, const T& b) { return a.Size() > b.Size(); });
}
struct Package { struct Package {
// Topology is unknown, rely on OS affinity and user-specified slice. // Topology is unknown, rely on OS affinity and user-specified slice.
Package(const LPS& enabled_lps, BoundedSlice lp_slice) { Package(const LPS& enabled_lps, BoundedSlice lp_slice) {
clusters.push_back(Cluster(enabled_lps, lp_slice)); clusters.push_back(Cluster(enabled_lps, lp_slice));
} }
// NOTE: caller is responsible for checking whether `clusters` is empty.
Package(const LPS& enabled_lps, const hwy::Topology& topology, Package(const LPS& enabled_lps, const hwy::Topology& topology,
size_t package_idx, BoundedSlice cluster_slice) { size_t package_idx, BoundedSlice cluster_slice);
const hwy::Topology::Package& tpackage = topology.packages[package_idx];
// Populate `clusters` with the subset of clusters in `cluster_slice` that
// have any enabled LPs. If `clusters` remains empty, the caller will
// skip this `Package`.
clusters.reserve(cluster_slice.Num(tpackage.clusters.size()));
cluster_slice.Foreach(
"cluster", tpackage.clusters.size(), [&](size_t cluster_idx) {
const hwy::Topology::Cluster& tcluster =
tpackage.clusters[cluster_idx];
Cluster cluster(enabled_lps, topology.lps, tcluster);
// Skip if empty, i.e. too few `enabled_lps`.
if (HWY_LIKELY(cluster.Size() != 0)) {
clusters.push_back(std::move(cluster));
}
});
SortByDescendingSize(clusters);
}
// For SortByDescendingSize. // For SortByDescendingSize.
size_t Size() const { return clusters.size(); } size_t Size() const { return clusters.size(); }
@ -275,48 +161,9 @@ class BoundedTopology {
std::vector<Cluster> clusters; std::vector<Cluster> clusters;
}; // Package }; // Package
#if !GEMMA_DISABLE_TOPOLOGY
// Main part of ctor, called when topology is known.
void InitFromTopology(const LPS& enabled_lps, BoundedSlice package_slice, void InitFromTopology(const LPS& enabled_lps, BoundedSlice package_slice,
BoundedSlice cluster_slice) { BoundedSlice cluster_slice);
// (Possibly empty) subset of `Topology` packages that have `enabled_lps`. void InitFromSlice(const LPS& enabled_lps, BoundedSlice lp_slice);
package_slice.Foreach(
"package", topology_.packages.size(), [&](size_t package_idx) {
Package package(enabled_lps, topology_, package_idx, cluster_slice);
// Skip if empty, i.e. too few `enabled_lps`.
if (HWY_LIKELY(!package.clusters.empty())) {
packages_.push_back(std::move(package));
}
});
if (NumPackages() == 0) return;
SortByDescendingSize(packages_);
const hwy::Topology::Package& tpackage0 = topology_.packages[0];
HWY_ASSERT(!tpackage0.clusters.empty());
const hwy::Topology::Cluster& tcluster0 = tpackage0.clusters[0];
// GetCluster(0, 0) is valid because only non-empty Packages were kept.
snprintf(topology_string_, sizeof(topology_string_),
"%zux%zux%zu, using %zux%zux%zu", topology_.packages.size(),
tpackage0.clusters.size(), tcluster0.lps.Count(), packages_.size(),
NumClusters(0), GetCluster(0, 0).Size());
// Remember NUMA nodes of *enabled* LPs.
enabled_lps.Foreach([&](size_t lp) {
nodes_.Set(static_cast<size_t>(topology_.lps[lp].node));
});
}
#endif
void InitFromSlice(const LPS& enabled_lps, BoundedSlice lp_slice) {
packages_.push_back(Package(enabled_lps, lp_slice));
snprintf(topology_string_, sizeof(topology_string_), "LPs=%zu",
GetCluster(0, 0).Size());
// Assume a single NUMA node.
nodes_.Set(0);
HWY_ASSERT(NumNodes() == 1);
}
#if !GEMMA_DISABLE_TOPOLOGY #if !GEMMA_DISABLE_TOPOLOGY
hwy::Topology topology_; hwy::Topology topology_;
@ -360,51 +207,32 @@ class NestedPools {
// would cause huge slowdowns when spinning, the `BoundedSlice` arguments // would cause huge slowdowns when spinning, the `BoundedSlice` arguments
// only impose upper bounds on the number of detected packages and clusters // only impose upper bounds on the number of detected packages and clusters
// rather than defining the actual number of threads. // rather than defining the actual number of threads.
// NestedPools(size_t max_threads, Tristate pin = Tristate::kDefault,
// `pin` is 0 or 1 to force disable/enable, or -1 to choose automatically.
NestedPools(size_t max_threads, int pin = -1,
BoundedSlice package_slice = BoundedSlice(), BoundedSlice package_slice = BoundedSlice(),
BoundedSlice cluster_slice = BoundedSlice(), BoundedSlice cluster_slice = BoundedSlice(),
BoundedSlice lp_slice = BoundedSlice()) BoundedSlice lp_slice = BoundedSlice());
: topology_(package_slice, cluster_slice, lp_slice) {
if (pin == -1) pin = topology_.TotalWorkers() >= 12;
packages_.resize(topology_.NumPackages()); // Subject to `use_spinning`, enables spin waits with the goal of reducing the
all_packages_ = MakePool(packages_.size()); // latency of barrier synchronization. We only spin during Generate to avoid
const size_t max_workers_per_package = max_threads / packages_.size(); // wasting energy during long waits. If `use_spinning` is kDefault, we first
// Each worker in all_packages_, including the main thread, will be the // set it to kTrue or kFalse based on a heuristic.
// calling thread of an all_clusters->Run, and hence pinned to one of the void MaybeStartSpinning(Tristate& use_spinning) {
// `cluster.lps` if `pin`. if (HWY_UNLIKELY(use_spinning == Tristate::kDefault)) {
all_packages_->Run( // The default is to only spin when pinning was enabled and supported by
0, all_packages_->NumWorkers(), // the OS. Unless spin-waits have near-exclusive use of a core, the tail
[&](uint64_t package_idx, size_t thread) { // latency can be higher than blocking waits.
HWY_ASSERT(package_idx == thread); // each thread has one task use_spinning = all_pinned_ ? Tristate::kTrue : Tristate::kFalse;
packages_[package_idx] = Package( }
topology_, package_idx, max_workers_per_package, pin, lp_slice); if (use_spinning == Tristate::kTrue) {
}); SetWaitMode(hwy::PoolWaitMode::kSpin);
}
// For mapping package/cluster/thread to noncontiguous TLS indices, in case }
// cluster/thread counts differ. void MaybeStopSpinning(const Tristate use_spinning) {
HWY_ASSERT(!packages_.empty() && packages_.size() <= 16); HWY_DASSERT(use_spinning != Tristate::kDefault); // see MaybeStartSpinning
for (const Package& p : packages_) { if (use_spinning == Tristate::kTrue) {
max_clusters_per_package_ = SetWaitMode(hwy::PoolWaitMode::kBlock);
HWY_MAX(max_clusters_per_package_, p.NumClusters());
max_workers_per_cluster_ =
HWY_MAX(max_workers_per_cluster_, p.MaxWorkersPerCluster());
} }
HWY_ASSERT(max_clusters_per_package_ >= 1);
HWY_ASSERT(max_clusters_per_package_ <= 64);
HWY_ASSERT(max_workers_per_cluster_ >= 1);
HWY_ASSERT(max_workers_per_cluster_ <= 256);
} }
// Spinning reduces the latency of barrier synchronization, but wastes lots
// of energy for long waits, so only do it during generation. Spinning might
// also be unsafe in virtualized environments because we require threads to
// be running on their own core and thus responsive to the barrier
// synchronization.
void StartSpinning() { SetWaitMode(hwy::PoolWaitMode::kSpin); }
void StopSpinning() { SetWaitMode(hwy::PoolWaitMode::kBlock); }
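  // Caller-side sketch (e.g. around token generation; `app.spin` comes from
  // the --spin flag shown above):
  //   Tristate spin = app.spin;
  //   pools.MaybeStartSpinning(spin);  // may replace kDefault with kTrue/kFalse
  //   ... generate ...
  //   pools.MaybeStopSpinning(spin);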
hwy::ThreadPool& AllPackages() { return *all_packages_; } hwy::ThreadPool& AllPackages() { return *all_packages_; }
hwy::ThreadPool& AllClusters(size_t package_idx) { hwy::ThreadPool& AllClusters(size_t package_idx) {
@ -435,7 +263,9 @@ class NestedPools {
// For Allocator // For Allocator
const BoundedTopology& Topology() const { return topology_; } const BoundedTopology& Topology() const { return topology_; }
// For ShowConfig
const char* TopologyString() const { return topology_.TopologyString(); } const char* TopologyString() const { return topology_.TopologyString(); }
const char* PinString() const { return pin_string_; }
// Returns a single pool on the first package: either one thread per cluster // Returns a single pool on the first package: either one thread per cluster
// if there is more than one, which maximizes available memory bandwidth, or // if there is more than one, which maximizes available memory bandwidth, or
@ -449,56 +279,14 @@ class NestedPools {
} }
private: private:
// `max_or_zero` == 0 means no limit. class Pinning;
static inline size_t CapIfNonZero(size_t num, size_t max_or_zero) {
return (max_or_zero == 0) ? num : HWY_MIN(num, max_or_zero);
}
// We want vectors of hwy::ThreadPool, which is unfortunately not movable,
// hence we wrap them in unique_ptr.
using PoolPtr = std::unique_ptr<hwy::ThreadPool>;
static PoolPtr MakePool(size_t num_workers) {
// `ThreadPool` expects the number of threads to create, which is one less
// than the number of workers, but avoid underflow if zero.
const size_t num_threads = num_workers == 0 ? 0 : num_workers - 1;
return std::make_unique<hwy::ThreadPool>(num_threads);
}
class Package { class Package {
public: public:
Package() = default; // for vector Package() = default; // for vector
Package(const BoundedTopology& topology, size_t package_idx, Package(const BoundedTopology& topology, size_t package_idx,
size_t max_workers_per_package, int pin, BoundedSlice lp_slice) { size_t max_workers_per_package, Pinning& pinning,
// Pre-allocate because elements are set concurrently. BoundedSlice lp_slice);
clusters_.resize(topology.NumClusters(package_idx));
const size_t max_workers_per_cluster =
max_workers_per_package / clusters_.size();
all_clusters_ = MakePool(clusters_.size());
// Parallel so we also pin the calling worker in `all_clusters` to
// `cluster.lps`.
all_clusters_->Run(
0, all_clusters_->NumWorkers(),
[&](size_t cluster_idx, size_t thread) {
HWY_ASSERT(cluster_idx == thread); // each thread has one task
const BoundedTopology::Cluster& cluster =
topology.GetCluster(package_idx, cluster_idx);
clusters_[cluster_idx] =
MakePool(CapIfNonZero(cluster.Size(), max_workers_per_cluster));
if (HWY_LIKELY(pin)) {
// Pin threads AND the calling thread from `all_clusters` to lps.
const std::vector<size_t> lps = cluster.LPVector();
HWY_ASSERT(clusters_[cluster_idx]->NumWorkers() <= lps.size());
clusters_[cluster_idx]->Run(
0, clusters_[cluster_idx]->NumWorkers(),
[&lps](uint64_t task, size_t thread) {
HWY_ASSERT(task == thread); // each worker has one task
hwy::PinThreadToLogicalProcessor(lps[task]);
});
}
});
}
size_t NumClusters() const { return clusters_.size(); } size_t NumClusters() const { return clusters_.size(); }
size_t MaxWorkersPerCluster() const { size_t MaxWorkersPerCluster() const {
@ -536,6 +324,8 @@ class NestedPools {
} }
BoundedTopology topology_; BoundedTopology topology_;
bool all_pinned_;
const char* pin_string_;
std::vector<Package> packages_; std::vector<Package> packages_;
PoolPtr all_packages_; PoolPtr all_packages_;