mirror of https://github.com/google/gemma.cpp.git
Simpler MatMul interface, vocab types, Tristate for use_spinning
Add Extents2D, Range2D vocab types Matmul uses ConstMat for inputs and RowPtr for output Move RowVectorBatch to basics.h Separate threading.cc Fix topology string: report cores not LPs, and #HT Move QStride/IsMHA into LayerConfig ImageTokens does not require make_unique. matmul_test: no longer require template args PiperOrigin-RevId: 692963605
This commit is contained in:
parent
baaa221787
commit
868b01601f
11
BUILD.bazel
11
BUILD.bazel
|
|
@ -30,8 +30,11 @@ cc_library(
|
|||
|
||||
cc_library(
|
||||
name = "threading",
|
||||
srcs = ["util/threading.cc"],
|
||||
hdrs = ["util/threading.h"],
|
||||
deps = [
|
||||
":basics",
|
||||
# Placeholder for container detection, do not remove
|
||||
"@highway//:hwy",
|
||||
"@highway//:thread_pool",
|
||||
"@highway//:topology",
|
||||
|
|
@ -173,7 +176,9 @@ cc_test(
|
|||
tags = ["hwy_ops_test"],
|
||||
deps = [
|
||||
":allocator",
|
||||
":basics",
|
||||
":ops",
|
||||
":test_util",
|
||||
":threading",
|
||||
"@googletest//:gtest_main", # buildcleaner: keep
|
||||
"//compression:compress",
|
||||
|
|
@ -280,6 +285,7 @@ cc_library(
|
|||
":kv_cache",
|
||||
":weights",
|
||||
":threading",
|
||||
"//compression:compress",
|
||||
"//compression:io",
|
||||
"//compression:sfp",
|
||||
"//paligemma:image",
|
||||
|
|
@ -307,6 +313,7 @@ cc_library(
|
|||
name = "args",
|
||||
hdrs = ["util/args.h"],
|
||||
deps = [
|
||||
":basics",
|
||||
"//compression:io",
|
||||
"@highway//:hwy",
|
||||
],
|
||||
|
|
@ -317,6 +324,7 @@ cc_library(
|
|||
hdrs = ["util/app.h"],
|
||||
deps = [
|
||||
":args",
|
||||
":basics",
|
||||
":common",
|
||||
":gemma_lib",
|
||||
":threading",
|
||||
|
|
@ -342,8 +350,6 @@ cc_library(
|
|||
"//compression:compress",
|
||||
"@highway//:hwy",
|
||||
"@highway//:nanobenchmark",
|
||||
"@highway//:thread_pool",
|
||||
"@highway//:topology",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
@ -583,6 +589,7 @@ cc_test(
|
|||
},
|
||||
deps = [
|
||||
":backprop",
|
||||
":basics",
|
||||
":common",
|
||||
":gemma_lib",
|
||||
":optimizer",
|
||||
|
|
|
|||
|
|
@ -101,6 +101,7 @@ set(SOURCES
|
|||
util/args.h
|
||||
util/basics.h
|
||||
util/test_util.h
|
||||
util/threading.cc
|
||||
util/threading.h
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -33,13 +33,15 @@
|
|||
#include "gemma/configs.h"
|
||||
#include "gemma/gemma.h"
|
||||
#include "gemma/weights.h"
|
||||
#include "util/basics.h"
|
||||
#include "util/threading.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
TEST(OptimizeTest, GradientDescent) {
|
||||
NestedPools pools(1, /*pin=*/0, BoundedSlice(0, 1), BoundedSlice(0, 1));
|
||||
NestedPools pools(1, /*pin=*/Tristate::kFalse, BoundedSlice(0, 1),
|
||||
BoundedSlice(0, 1));
|
||||
hwy::ThreadPool& pool = pools.Pool();
|
||||
std::mt19937 gen(42);
|
||||
|
||||
|
|
|
|||
|
|
@ -33,6 +33,7 @@
|
|||
#include "compression/blob_store.h"
|
||||
#include "compression/io.h"
|
||||
#include "compression/shared.h"
|
||||
#include "util/basics.h"
|
||||
// IWYU pragma: end_exports
|
||||
#include "util/allocator.h"
|
||||
#if COMPRESS_STATS
|
||||
|
|
@ -62,7 +63,9 @@ class MatPtr {
|
|||
num_elements_(rows * cols),
|
||||
rows_(rows),
|
||||
cols_(cols),
|
||||
ptr_(nullptr) {}
|
||||
ptr_(nullptr) {
|
||||
stride_ = cols;
|
||||
}
|
||||
// Default is to leave all fields default-initialized.
|
||||
MatPtr() = default;
|
||||
virtual ~MatPtr();
|
||||
|
|
@ -85,7 +88,9 @@ class MatPtr {
|
|||
element_size_(key2.hi),
|
||||
num_elements_(key2.lo),
|
||||
rows_(key3.lo),
|
||||
cols_(key3.hi) {}
|
||||
cols_(key3.hi) {
|
||||
stride_ = cols_;
|
||||
}
|
||||
|
||||
// Adds the contents entry to the table of contents.
|
||||
void AddToToc(std::vector<hwy::uint128_t>& toc) const {
|
||||
|
|
@ -137,6 +142,12 @@ class MatPtr {
|
|||
// Returns the number of columns in the 2-d array (inner dimension).
|
||||
size_t Cols() const { return cols_; }
|
||||
|
||||
Extents2D Extents() const { return Extents2D(rows_, cols_); }
|
||||
|
||||
// Currently same as cols, but may differ in the future. This is the offset by
|
||||
// which to advance pointers to the next row.
|
||||
size_t Stride() const { return stride_; }
|
||||
|
||||
// Decoded elements should be multiplied by this to restore their original
|
||||
// range. This is required because SfpStream can only encode a limited range
|
||||
// of magnitudes.
|
||||
|
|
@ -187,6 +198,8 @@ class MatPtr {
|
|||
// freed. The underlying memory is owned by a subclass or some external class
|
||||
// and must outlive this object.
|
||||
void* ptr_ = nullptr;
|
||||
|
||||
size_t stride_;
|
||||
};
|
||||
|
||||
// MatPtrT adds a single template argument to MatPtr for an explicit type.
|
||||
|
|
@ -288,7 +301,15 @@ decltype(auto) MatPtr::CallUpcasted(FuncT& func, TArgs&&... args) {
|
|||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
ConstMat<T> ConstMatFromWeights(const MatPtrT<T>& m, size_t ofs = 0) {
|
||||
ConstMat<T> mat = MakeConstMat(const_cast<T*>(m.data()), m.Extents(), ofs);
|
||||
mat.scale = m.scale();
|
||||
return mat;
|
||||
}
|
||||
|
||||
// MatStorageT adds the actual data storage to MatPtrT.
|
||||
// TODO: use Extents2D instead of rows and cols.
|
||||
template <typename MatT>
|
||||
class MatStorageT : public MatPtrT<MatT> {
|
||||
public:
|
||||
|
|
|
|||
|
|
@ -267,8 +267,12 @@ struct PackedSpan {
|
|||
// check the compressed count and ensure we have that many.
|
||||
const size_t required =
|
||||
CompressedArrayElements<Packed>(packed_ofs + num_accessible);
|
||||
HWY_DASSERT(num >= required);
|
||||
(void)required;
|
||||
if constexpr (HWY_IS_DEBUG_BUILD) {
|
||||
if (num < required) {
|
||||
HWY_ABORT("PackedSpan: ofs %zu, want %zu, req %zu > %zu packed",
|
||||
packed_ofs, num_accessible, required, num);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Packed* HWY_RESTRICT ptr;
|
||||
|
|
|
|||
|
|
@ -229,12 +229,12 @@ void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app,
|
|||
fprintf(stderr,
|
||||
"Date & Time : %s" // dt includes \n
|
||||
"CPU : %s\n"
|
||||
"CPU topology : %s\n"
|
||||
"CPU topology : %s, %s\n"
|
||||
"Instruction set : %s (%zu bits)\n"
|
||||
"Compiled config : %s\n"
|
||||
"Weight Type : %s\n"
|
||||
"EmbedderInput Type : %s\n",
|
||||
dt, cpu100, pools.TopologyString(),
|
||||
dt, cpu100, pools.TopologyString(), pools.PinString(),
|
||||
hwy::TargetName(hwy::DispatchedTarget()), hwy::VectorBytes() * 8,
|
||||
CompiledConfig(), StringFromType(loader.Info().weight),
|
||||
TypeName<EmbedderInputT>());
|
||||
|
|
|
|||
|
|
@ -72,18 +72,11 @@ struct Activations {
|
|||
size_t seq_len;
|
||||
size_t cache_pos_size = 0;
|
||||
|
||||
// Multi-Head Attention?
|
||||
bool IsMHA() const { return layer_config.heads == layer_config.kv_heads; }
|
||||
|
||||
// Stride between subsequent queries. Each of Q, K, V are of length kQKVDim,
|
||||
// but for MHA we store them as Q,K,V, Q,K,V, .. instead of Q..Q, K..K, V..V.
|
||||
size_t QStride() const { return layer_config.qkv_dim * (IsMHA() ? 3 : 1); }
|
||||
|
||||
static RowVectorBatch<float> CreateInvTimescale(size_t qkv_dim,
|
||||
PostQKType post_qk) {
|
||||
const size_t rope_dim =
|
||||
post_qk == PostQKType::HalfRope ? qkv_dim / 2 : qkv_dim;
|
||||
RowVectorBatch<float> inv_timescale(1, rope_dim / 2);
|
||||
RowVectorBatch<float> inv_timescale(Extents2D(1, rope_dim / 2));
|
||||
for (size_t dim = 0; dim < rope_dim / 2; ++dim) {
|
||||
const float freq_exponents =
|
||||
static_cast<float>(2 * dim) / static_cast<float>(rope_dim);
|
||||
|
|
@ -100,29 +93,31 @@ struct Activations {
|
|||
const size_t ff_hidden_dim = layer_config.ff_hidden_dim;
|
||||
const size_t vocab_size = weights_config.vocab_size;
|
||||
|
||||
x = RowVectorBatch<float>(batch_size, model_dim);
|
||||
q = RowVectorBatch<float>(batch_size, layer_config.heads * QStride());
|
||||
x = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
|
||||
q = RowVectorBatch<float>(
|
||||
Extents2D(batch_size, layer_config.heads * layer_config.QStride()));
|
||||
if (vocab_size > 0) {
|
||||
logits = RowVectorBatch<float>(batch_size, vocab_size);
|
||||
logits = RowVectorBatch<float>(Extents2D(batch_size, vocab_size));
|
||||
}
|
||||
|
||||
pre_att_rms_out = RowVectorBatch<float>(batch_size, model_dim);
|
||||
att = RowVectorBatch<float>(batch_size,
|
||||
layer_config.heads * weights_config.seq_len);
|
||||
att_out = RowVectorBatch<float>(batch_size,
|
||||
layer_config.heads * layer_config.qkv_dim);
|
||||
att_sums = RowVectorBatch<float>(batch_size, model_dim);
|
||||
pre_att_rms_out = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
|
||||
att = RowVectorBatch<float>(
|
||||
Extents2D(batch_size, layer_config.heads * weights_config.seq_len));
|
||||
att_out = RowVectorBatch<float>(
|
||||
Extents2D(batch_size, layer_config.heads * layer_config.qkv_dim));
|
||||
att_sums = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
|
||||
|
||||
bf_pre_ffw_rms_out = RowVectorBatch<BF16>(batch_size, model_dim);
|
||||
C1 = RowVectorBatch<float>(batch_size, ff_hidden_dim);
|
||||
C2 = RowVectorBatch<float>(batch_size, ff_hidden_dim);
|
||||
ffw_out = RowVectorBatch<float>(batch_size, model_dim);
|
||||
bf_pre_ffw_rms_out = RowVectorBatch<BF16>(Extents2D(batch_size, model_dim));
|
||||
C1 = RowVectorBatch<float>(Extents2D(batch_size, ff_hidden_dim));
|
||||
C2 = RowVectorBatch<float>(Extents2D(batch_size, ff_hidden_dim));
|
||||
ffw_out = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
|
||||
|
||||
if (layer_config.type == LayerAttentionType::kGriffinRecurrentBlock) {
|
||||
griffin_x = RowVectorBatch<float>(batch_size, model_dim);
|
||||
griffin_y = RowVectorBatch<float>(batch_size, model_dim);
|
||||
griffin_gate_x = RowVectorBatch<float>(batch_size, model_dim);
|
||||
griffin_multiplier = RowVectorBatch<float>(batch_size, model_dim);
|
||||
griffin_x = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
|
||||
griffin_y = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
|
||||
griffin_gate_x = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
|
||||
griffin_multiplier =
|
||||
RowVectorBatch<float>(Extents2D(batch_size, model_dim));
|
||||
}
|
||||
|
||||
inv_timescale = CreateInvTimescale(layer_config.qkv_dim, post_qk);
|
||||
|
|
|
|||
|
|
@ -119,6 +119,13 @@ enum class Model {
|
|||
struct LayerConfig {
|
||||
size_t CacheLayerSize() const { return kv_heads * qkv_dim * 2; }
|
||||
|
||||
// Multi-Head Attention?
|
||||
bool IsMHA() const { return heads == kv_heads; }
|
||||
|
||||
// Stride between subsequent queries. Each of Q, K, V are of length kQKVDim,
|
||||
// but for MHA we store them as Q,K,V, Q,K,V, .. instead of Q..Q, K..K, V..V.
|
||||
size_t QStride() const { return qkv_dim * (IsMHA() ? 3 : 1); }
|
||||
|
||||
size_t model_dim = 0;
|
||||
size_t griffin_dim = 0;
|
||||
size_t ff_hidden_dim = 0;
|
||||
|
|
|
|||
|
|
@ -20,9 +20,9 @@
|
|||
#include <stdio.h>
|
||||
|
||||
#include <algorithm> // std::min
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
#include "compression/compress.h"
|
||||
#include "gemma/activations.h"
|
||||
#include "gemma/common.h"
|
||||
#include "gemma/configs.h"
|
||||
|
|
@ -31,6 +31,7 @@
|
|||
// Placeholder for internal test4, do not remove
|
||||
#include "paligemma/image.h"
|
||||
#include "util/allocator.h"
|
||||
#include "util/basics.h"
|
||||
#include "util/threading.h"
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/base.h"
|
||||
|
|
@ -232,49 +233,49 @@ class GemmaAttention {
|
|||
// KV directly to KVCache.
|
||||
HWY_NOINLINE void ComputeQKV(const size_t num_interleaved) {
|
||||
PROFILER_ZONE("Gen.Attention.QKV");
|
||||
// For the computation of Q, K, and V, it is useful to remember that
|
||||
// qkv_einsum_w has shape [(layer_config_.heads + layer_config_.kv_heads *
|
||||
// 2), kKQVDim, layer_config_.model_dim] and q_stride_ =
|
||||
// layer_config_.qkv_dim * (is_mha_ ? 3 : 1);
|
||||
const size_t model_dim = layer_config_.model_dim;
|
||||
const size_t qkv_dim = layer_config_.qkv_dim;
|
||||
const size_t heads = layer_config_.heads;
|
||||
const size_t kv_heads = layer_config_.kv_heads;
|
||||
|
||||
const auto pre_att_rms_out =
|
||||
ConstMat(activations_.pre_att_rms_out.All(), layer_config_.model_dim);
|
||||
const auto w_q1 = layer_weights_.qkv_einsum_w.data() == nullptr
|
||||
? ConstMat(layer_weights_.qkv_einsum_w1.data(),
|
||||
layer_config_.model_dim)
|
||||
: ConstMat(layer_weights_.qkv_einsum_w.data(),
|
||||
layer_config_.model_dim);
|
||||
const auto w_q2 =
|
||||
layer_weights_.qkv_einsum_w.data() == nullptr
|
||||
? ConstMat(layer_weights_.qkv_einsum_w2.data(),
|
||||
layer_config_.model_dim)
|
||||
: ConstMat(layer_weights_.qkv_einsum_w.data(),
|
||||
layer_config_.model_dim, layer_config_.model_dim,
|
||||
layer_config_.heads * layer_config_.qkv_dim *
|
||||
layer_config_.model_dim);
|
||||
MatMul</*kAdd=*/false>(
|
||||
num_interleaved, pre_att_rms_out, w_q1,
|
||||
layer_weights_.qkv_einsum_w.scale(), /*add=*/nullptr, activations_.env,
|
||||
MutableMat(activations_.q.All(), layer_config_.heads * q_stride_));
|
||||
ConstMatFromBatch(num_interleaved, activations_.pre_att_rms_out);
|
||||
auto w_q1 = layer_weights_.qkv_einsum_w.data()
|
||||
? ConstMatFromWeights(layer_weights_.qkv_einsum_w)
|
||||
: ConstMatFromWeights(layer_weights_.qkv_einsum_w1);
|
||||
// The original qkv_einsum_w has shape [(heads + kv_heads * 2), kKQVDim,
|
||||
// model_dim], which we reshaped to (heads + kv_heads * 2) * kKQVDim rows.
|
||||
// We must shrink to the actual size because MatMul verifies
|
||||
// `B.extents.rows == C.Cols()`. If MHA, `QStride() == 3 * qkv_dim` and all
|
||||
// rows are used. Otherwise, `QStride() == qkv_dim` and KV will be
|
||||
// computed in the second MatMul.
|
||||
const size_t w1_rows = heads * layer_config_.QStride();
|
||||
w_q1.ShrinkRows(w1_rows);
|
||||
MatMul(pre_att_rms_out, w_q1,
|
||||
/*add=*/nullptr, activations_.env, RowPtrFromBatch(activations_.q));
|
||||
|
||||
if (is_mha_) {
|
||||
// Multi-Head Attention a.k.a. "use_qkv_einsum" computed QKV already.
|
||||
} else {
|
||||
auto w_q2 = layer_weights_.qkv_einsum_w.data()
|
||||
? ConstMatFromWeights(layer_weights_.qkv_einsum_w,
|
||||
w1_rows * model_dim)
|
||||
: ConstMatFromWeights(layer_weights_.qkv_einsum_w2);
|
||||
// KV structure is [k, v, k, v, ....] = kv_heads pairs of (k, v).
|
||||
const size_t w_rows_kv_cols = kv_heads * 2 * qkv_dim;
|
||||
w_q2.ShrinkRows(w_rows_kv_cols);
|
||||
|
||||
// Single query and no wraparound means we can use a matmul and write
|
||||
// directly into the KV cache with a stride of cache_pos_size_.
|
||||
if (num_queries_ == 1 &&
|
||||
queries_pos_[0] + num_tokens_ <= div_seq_len_.GetDivisor()) {
|
||||
const size_t kv_ofs =
|
||||
queries_pos_[0] * cache_pos_size_ + layer_ * cache_layer_size_;
|
||||
// KV structure is [k, v, k, v, ....] = layer_config_.kv_heads pairs of
|
||||
// (k, v).
|
||||
float* HWY_RESTRICT kv = kv_caches_[0].kv_cache.get() + kv_ofs;
|
||||
MatMul</*kAdd=*/false>(
|
||||
num_tokens_, pre_att_rms_out, w_q2,
|
||||
layer_weights_.qkv_einsum_w.scale(), /*add=*/nullptr,
|
||||
activations_.env,
|
||||
MutableMat(kv, layer_config_.kv_heads * 2 * layer_config_.qkv_dim,
|
||||
cache_pos_size_));
|
||||
RowPtrF kv_rows(kv, w_rows_kv_cols);
|
||||
kv_rows.SetStride(cache_pos_size_);
|
||||
MatMul(pre_att_rms_out, w_q2,
|
||||
/*add=*/nullptr, activations_.env, kv_rows);
|
||||
} else {
|
||||
// Proceed row by row because there will be wraparound.
|
||||
for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
|
||||
|
|
@ -288,40 +289,34 @@ class GemmaAttention {
|
|||
const size_t kv_offset =
|
||||
cache_pos * cache_pos_size_ + layer_ * cache_layer_size_;
|
||||
float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
|
||||
// KV structure is [k, v, k, v, ....] = layer_config_.kv_heads pairs
|
||||
// of (k, v).
|
||||
if (layer_weights_.qkv_einsum_w.data() == nullptr) {
|
||||
MatVec(layer_weights_.qkv_einsum_w2, 0,
|
||||
layer_config_.kv_heads * 2 * layer_config_.qkv_dim,
|
||||
layer_config_.model_dim, x, kv, pool_);
|
||||
if (layer_weights_.qkv_einsum_w.data()) {
|
||||
MatVec(layer_weights_.qkv_einsum_w, heads * qkv_dim * model_dim,
|
||||
w_rows_kv_cols, model_dim, x, kv, pool_);
|
||||
} else {
|
||||
MatVec(layer_weights_.qkv_einsum_w,
|
||||
layer_config_.heads * layer_config_.qkv_dim *
|
||||
layer_config_.model_dim,
|
||||
layer_config_.kv_heads * 2 * layer_config_.qkv_dim,
|
||||
layer_config_.model_dim, x, kv, pool_);
|
||||
MatVec(layer_weights_.qkv_einsum_w2, 0, //
|
||||
w_rows_kv_cols, model_dim, x, kv, pool_);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} // !is_mha_
|
||||
|
||||
// Apply positional encodings for K (and copy KV to cache if MHA).
|
||||
pool_.Run(0, layer_config_.kv_heads * num_interleaved,
|
||||
pool_.Run(0, kv_heads * num_interleaved,
|
||||
[&](uint64_t task, size_t /*thread*/) HWY_ATTR {
|
||||
const size_t head = task % layer_config_.kv_heads;
|
||||
const size_t interleaved_idx = task / layer_config_.kv_heads;
|
||||
const size_t head = task % kv_heads;
|
||||
const size_t interleaved_idx = task / kv_heads;
|
||||
const size_t query_idx = interleaved_idx % num_queries_;
|
||||
const size_t batch_idx = interleaved_idx / num_queries_;
|
||||
const size_t pos = queries_pos_[query_idx] + batch_idx;
|
||||
const size_t cache_pos = div_seq_len_.Remainder(pos);
|
||||
const size_t kv_offset = cache_pos * cache_pos_size_ +
|
||||
layer_ * cache_layer_size_ +
|
||||
head * layer_config_.qkv_dim * 2;
|
||||
head * qkv_dim * 2;
|
||||
KVCache& kv_cache = kv_caches_[query_idx];
|
||||
float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
|
||||
const float* HWY_RESTRICT mha_kv =
|
||||
activations_.q.Batch(interleaved_idx) + head * q_stride_ +
|
||||
layer_config_.qkv_dim;
|
||||
qkv_dim;
|
||||
|
||||
// Copy from `q` if MHA, or apply in-place.
|
||||
PositionalEncodingQK(is_mha_ ? mha_kv : kv, pos, layer_, 1.0f,
|
||||
|
|
@ -329,9 +324,8 @@ class GemmaAttention {
|
|||
|
||||
// If MHA, also copy V into KVCache.
|
||||
if (is_mha_) {
|
||||
hwy::CopyBytes(mha_kv + layer_config_.qkv_dim,
|
||||
kv + layer_config_.qkv_dim,
|
||||
layer_config_.qkv_dim * sizeof(*kv));
|
||||
hwy::CopyBytes(mha_kv + qkv_dim, kv + qkv_dim,
|
||||
qkv_dim * sizeof(*kv));
|
||||
}
|
||||
});
|
||||
}
|
||||
|
|
@ -463,27 +457,14 @@ class GemmaAttention {
|
|||
HWY_DASSERT(layer_weights_.att_weights.data() != nullptr);
|
||||
HWY_DASSERT(activations_.att_out.All() != nullptr);
|
||||
HWY_DASSERT(activations_.att_sums.All() != nullptr);
|
||||
if (layer_weights_.layer_config.softmax_attn_output_biases) {
|
||||
MatMul</*kAdd=*/true>(
|
||||
num_interleaved,
|
||||
ConstMat(activations_.att_out.All(),
|
||||
layer_config_.heads * layer_config_.qkv_dim),
|
||||
ConstMat(layer_weights_.att_weights.data(),
|
||||
layer_config_.heads * layer_config_.qkv_dim),
|
||||
layer_weights_.att_weights.scale(),
|
||||
layer_weights_.attention_output_biases.data_scale1(),
|
||||
activations_.env,
|
||||
MutableMat(activations_.att_sums.All(), layer_config_.model_dim));
|
||||
} else {
|
||||
MatMul</*kAdd=*/false>(
|
||||
num_interleaved,
|
||||
ConstMat(activations_.att_out.All(),
|
||||
layer_config_.heads * layer_config_.qkv_dim),
|
||||
ConstMat(layer_weights_.att_weights.data(),
|
||||
layer_config_.heads * layer_config_.qkv_dim),
|
||||
layer_weights_.att_weights.scale(), nullptr, activations_.env,
|
||||
MutableMat(activations_.att_sums.All(), layer_config_.model_dim));
|
||||
}
|
||||
|
||||
const float* add =
|
||||
layer_weights_.layer_config.softmax_attn_output_biases
|
||||
? layer_weights_.attention_output_biases.data_scale1()
|
||||
: nullptr;
|
||||
MatMul(ConstMatFromBatch(num_interleaved, activations_.att_out),
|
||||
ConstMatFromWeights(layer_weights_.att_weights), add,
|
||||
activations_.env, RowPtrFromBatch(activations_.att_sums));
|
||||
}
|
||||
|
||||
public:
|
||||
|
|
@ -524,13 +505,13 @@ class GemmaAttention {
|
|||
num_queries_(queries_pos.size()),
|
||||
num_tokens_(num_tokens),
|
||||
layer_(layer),
|
||||
q_stride_(activations.QStride()),
|
||||
layer_config_(layer_weights->layer_config),
|
||||
q_stride_(layer_config_.QStride()),
|
||||
cache_layer_size_(layer_weights->layer_config.CacheLayerSize()),
|
||||
cache_pos_size_(activations.cache_pos_size),
|
||||
is_mha_(activations.IsMHA()),
|
||||
is_mha_(layer_config_.IsMHA()),
|
||||
activations_(activations),
|
||||
layer_weights_(*layer_weights),
|
||||
layer_config_(layer_weights->layer_config),
|
||||
div_seq_len_(div_seq_len),
|
||||
kv_caches_(kv_caches),
|
||||
pool_(activations.env.Pool()) {
|
||||
|
|
@ -552,6 +533,7 @@ class GemmaAttention {
|
|||
const size_t num_queries_;
|
||||
const size_t num_tokens_;
|
||||
const size_t layer_;
|
||||
const LayerConfig& layer_config_;
|
||||
const size_t q_stride_ = 0;
|
||||
const size_t cache_layer_size_ = 0;
|
||||
const size_t cache_pos_size_ = 0;
|
||||
|
|
@ -559,7 +541,6 @@ class GemmaAttention {
|
|||
|
||||
Activations& activations_;
|
||||
const LayerWeightsPtrs<T>& layer_weights_;
|
||||
const LayerConfig& layer_config_;
|
||||
const hwy::Divisor& div_seq_len_;
|
||||
const KVCaches& kv_caches_;
|
||||
hwy::ThreadPool& pool_;
|
||||
|
|
@ -601,17 +582,13 @@ class VitAttention {
|
|||
// Computes Q, K, V for all heads, stored in activations_.q.
|
||||
HWY_NOINLINE void ComputeQKV() {
|
||||
PROFILER_ZONE("Gen.VitAttention.QKV");
|
||||
const auto y =
|
||||
ConstMat(activations_.pre_att_rms_out.All(), layer_config_.model_dim);
|
||||
auto& qkv = activations_.q;
|
||||
HWY_ASSERT(qkv.BatchSize() == num_tokens_);
|
||||
HWY_ASSERT(qkv.Len() == layer_config_.heads * 3 * layer_config_.qkv_dim);
|
||||
MatMul</*kAdd=*/true>(
|
||||
num_tokens_, y,
|
||||
ConstMat(layer_weights_.vit.qkv_einsum_w.data_scale1(),
|
||||
layer_config_.model_dim),
|
||||
/*scale=*/1.0f, layer_weights_.vit.qkv_einsum_b.data_scale1(),
|
||||
activations_.env, MutableMat(qkv.All(), qkv.Len()));
|
||||
HWY_ASSERT(qkv.Cols() == layer_config_.heads * 3 * layer_config_.qkv_dim);
|
||||
MatMul(ConstMatFromBatch(num_tokens_, activations_.pre_att_rms_out),
|
||||
ConstMatFromWeights(layer_weights_.vit.qkv_einsum_w),
|
||||
layer_weights_.vit.qkv_einsum_b.data_scale1(), activations_.env,
|
||||
RowPtrFromBatch(qkv));
|
||||
}
|
||||
|
||||
HWY_NOINLINE void DotSoftmaxWeightedSum() {
|
||||
|
|
@ -658,17 +635,13 @@ class VitAttention {
|
|||
HWY_NOINLINE void SumHeads() {
|
||||
PROFILER_ZONE("Gen.VitAttention.SumHeads");
|
||||
auto* bias = layer_weights_.vit.attn_out_b.data_scale1();
|
||||
auto att_out = ConstMat(activations_.att_out.All(),
|
||||
layer_config_.heads * layer_config_.qkv_dim);
|
||||
auto att_weights = ConstMat(layer_weights_.vit.attn_out_w.data_scale1(),
|
||||
layer_config_.heads * layer_config_.qkv_dim);
|
||||
auto att_sums =
|
||||
MutableMat(activations_.att_sums.All(), layer_config_.model_dim);
|
||||
// att_weights and att_out are concatenated heads, each of length
|
||||
// layer_config_.qkv_dim. Thus the [num_tokens_, layer_config_.model_dim]
|
||||
// matmul output is the sum over heads.
|
||||
MatMul</*kAdd=*/true>(num_tokens_, att_out, att_weights, /*scale=*/1.0f,
|
||||
bias, activations_.env, att_sums);
|
||||
auto att_out = ConstMatFromBatch(num_tokens_, activations_.att_out);
|
||||
auto att_weights = ConstMatFromWeights(layer_weights_.vit.attn_out_w);
|
||||
auto att_sums = RowPtrFromBatch(activations_.att_sums);
|
||||
MatMul(att_out, att_weights, bias, activations_.env, att_sums);
|
||||
}
|
||||
|
||||
public:
|
||||
|
|
@ -720,125 +693,94 @@ HWY_NOINLINE void FFWNoVit(Activations& activations, size_t num_interleaved,
|
|||
PROFILER_ZONE("Gen.FFW");
|
||||
const size_t model_dim = layer_weights->layer_config.model_dim;
|
||||
const size_t ffh_hidden_dim = layer_weights->layer_config.ff_hidden_dim;
|
||||
const bool add_bias = layer_weights->layer_config.ff_biases;
|
||||
using WeightType = T;
|
||||
HWY_DASSERT(num_interleaved <= activations.bf_pre_ffw_rms_out.BatchSize());
|
||||
|
||||
// Define slightly more readable names for the weights and activations.
|
||||
const auto x = ConstMat(activations.bf_pre_ffw_rms_out.All(), model_dim);
|
||||
Mat<const WeightType> w1;
|
||||
const float* bias1 = nullptr;
|
||||
Mat<const WeightType> w2;
|
||||
const float* bias2 = nullptr;
|
||||
float scale = 1.0f;
|
||||
Mat<const WeightType> w_output;
|
||||
const float* output_bias = nullptr;
|
||||
float output_scale = 1.0f;
|
||||
auto hidden_activations = MutableMat(activations.C1.All(), ffh_hidden_dim);
|
||||
auto multiplier = MutableMat(activations.C2.All(), ffh_hidden_dim);
|
||||
auto ffw_out = MutableMat(activations.ffw_out.All(), model_dim);
|
||||
const bool add_bias = layer_weights->layer_config.ff_biases;
|
||||
const float* bias1 =
|
||||
add_bias ? layer_weights->ffw_gating_biases.data_scale1() : nullptr;
|
||||
const float* bias2 = add_bias ? bias1 + ffh_hidden_dim : nullptr;
|
||||
const float* output_bias =
|
||||
add_bias ? layer_weights->ffw_output_biases.data_scale1() : nullptr;
|
||||
|
||||
// For some of the weights and activations, it depends on the config where to
|
||||
// get them from or whether to use them at all.
|
||||
bias1 = layer_weights->ffw_gating_biases.data_scale1();
|
||||
bias2 = bias1 + ffh_hidden_dim;
|
||||
output_bias = layer_weights->ffw_output_biases.data_scale1();
|
||||
w1 = layer_weights->gating_einsum_w.data() == nullptr
|
||||
? ConstMat(layer_weights->gating_einsum_w1.data(), model_dim)
|
||||
: ConstMat(layer_weights->gating_einsum_w.data(), model_dim);
|
||||
w2 = layer_weights->gating_einsum_w.data() == nullptr
|
||||
? ConstMat(layer_weights->gating_einsum_w2.data(), model_dim)
|
||||
: ConstMat(layer_weights->gating_einsum_w.data(), model_dim,
|
||||
model_dim, model_dim * ffh_hidden_dim);
|
||||
scale = layer_weights->gating_einsum_w.data() == nullptr
|
||||
? layer_weights->gating_einsum_w1.scale()
|
||||
: layer_weights->gating_einsum_w.scale();
|
||||
w_output = ConstMat(layer_weights->linear_w.data(), ffh_hidden_dim);
|
||||
output_scale = layer_weights->linear_w.scale();
|
||||
// Define slightly more readable names for the weights and activations.
|
||||
const auto x =
|
||||
ConstMatFromBatch(num_interleaved, activations.bf_pre_ffw_rms_out);
|
||||
|
||||
auto hidden_activations = RowPtrFromBatch(activations.C1);
|
||||
auto multiplier = RowPtrFromBatch(activations.C2);
|
||||
auto ffw_out = RowPtrFromBatch(activations.ffw_out);
|
||||
|
||||
// gating_einsum_w holds two half-matrices. We plan to change the importer to
|
||||
// avoid this confusion by splitting into gating_einsum_w1 and
|
||||
// gating_einsum_w2.
|
||||
const bool split = !!layer_weights->gating_einsum_w.data();
|
||||
auto w1 = split ? ConstMatFromWeights(layer_weights->gating_einsum_w)
|
||||
: ConstMatFromWeights(layer_weights->gating_einsum_w1);
|
||||
auto w2 = split ? ConstMatFromWeights(layer_weights->gating_einsum_w,
|
||||
model_dim * ffh_hidden_dim)
|
||||
: ConstMatFromWeights(layer_weights->gating_einsum_w2);
|
||||
if (split) {
|
||||
// Ensure that B.Extents().row matches C.Cols() because MatMul checks that.
|
||||
w1.ShrinkRows(ffh_hidden_dim);
|
||||
w2.ShrinkRows(ffh_hidden_dim);
|
||||
}
|
||||
auto w_output = ConstMatFromWeights(layer_weights->linear_w);
|
||||
|
||||
// Compute the hidden layer activations.
|
||||
if (add_bias) {
|
||||
MatMul</*kAddBias=*/true>(num_interleaved, x, w1, scale, bias1,
|
||||
activations.env, hidden_activations);
|
||||
MatMul</*kAddBias=*/true>(num_interleaved, x, w2, scale, bias2,
|
||||
activations.env, multiplier);
|
||||
} else {
|
||||
MatMul</*kAddBias=*/false>(num_interleaved, x, w1, scale, bias1,
|
||||
activations.env, hidden_activations);
|
||||
MatMul</*kAddBias=*/false>(num_interleaved, x, w2, scale, bias2,
|
||||
activations.env, multiplier);
|
||||
}
|
||||
MatMul(x, w1, bias1, activations.env, hidden_activations);
|
||||
MatMul(x, w2, bias2, activations.env, multiplier);
|
||||
|
||||
// Activation (Gelu) and maybe multiply by gate. Store activations in act.
|
||||
Activation(layer_weights->layer_config.activation, hidden_activations.ptr,
|
||||
multiplier.ptr, ffh_hidden_dim * num_interleaved);
|
||||
Activation(layer_weights->layer_config.activation, hidden_activations.Row(0),
|
||||
multiplier.Row(0), ffh_hidden_dim * num_interleaved);
|
||||
|
||||
// Hidden layer -> output layer.
|
||||
if (add_bias) {
|
||||
MatMul</*kAddBias=*/true>(num_interleaved, ConstMat(hidden_activations),
|
||||
w_output, output_scale, output_bias,
|
||||
activations.env, ffw_out);
|
||||
} else {
|
||||
MatMul</*kAddBias=*/false>(num_interleaved, ConstMat(hidden_activations),
|
||||
w_output, output_scale, output_bias,
|
||||
activations.env, ffw_out);
|
||||
}
|
||||
auto activations_mat = MakeConstMat(
|
||||
hidden_activations.Row(0), Extents2D(num_interleaved, ffh_hidden_dim));
|
||||
|
||||
MatMul(activations_mat, w_output, output_bias, activations.env, ffw_out);
|
||||
}
|
||||
|
||||
// Same as FFWNoVit, but with different layer_weights members and no second
|
||||
// gating matrix.
|
||||
template <typename T>
|
||||
HWY_NOINLINE void FFWVit(Activations& activations, size_t num_interleaved,
|
||||
const LayerWeightsPtrs<T>* layer_weights) {
|
||||
PROFILER_ZONE("Gen.FFW");
|
||||
const size_t model_dim = layer_weights->layer_config.model_dim;
|
||||
const size_t ff_hidden_dim = layer_weights->layer_config.ff_hidden_dim;
|
||||
const bool add_bias = layer_weights->layer_config.ff_biases;
|
||||
using WeightType = typename LayerWeightsPtrs<T>::WeightF32OrBF16;
|
||||
HWY_DASSERT(num_interleaved <= activations.bf_pre_ffw_rms_out.BatchSize());
|
||||
|
||||
// Define slightly more readable names for the weights and activations.
|
||||
const auto x = ConstMat(activations.bf_pre_ffw_rms_out.All(), model_dim);
|
||||
Mat<const WeightType> w1;
|
||||
const float* bias1 = nullptr;
|
||||
float scale = 1.0f;
|
||||
Mat<const WeightType> w_output;
|
||||
const float* output_bias = nullptr;
|
||||
float output_scale = 1.0f;
|
||||
auto hidden_activations = MutableMat(activations.C1.All(), ff_hidden_dim);
|
||||
auto multiplier = MutableMat(activations.C2.All(), ff_hidden_dim);
|
||||
auto ffw_out = MutableMat(activations.ffw_out.All(), model_dim);
|
||||
const bool add_bias = layer_weights->layer_config.ff_biases;
|
||||
const float* bias1 =
|
||||
add_bias ? layer_weights->vit.linear_0_b.data_scale1() : nullptr;
|
||||
const float* output_bias =
|
||||
add_bias ? layer_weights->vit.linear_1_b.data_scale1() : nullptr;
|
||||
|
||||
// For some of the weights and activations, it depends on the config where to
|
||||
// get them from or whether to use them at all.
|
||||
w1 = ConstMat(layer_weights->vit.linear_0_w.data_scale1(), model_dim);
|
||||
bias1 = layer_weights->vit.linear_0_b.data_scale1();
|
||||
multiplier.ptr = nullptr;
|
||||
w_output =
|
||||
ConstMat(layer_weights->vit.linear_1_w.data_scale1(), ff_hidden_dim);
|
||||
output_bias = layer_weights->vit.linear_1_b.data_scale1();
|
||||
// Define slightly more readable names for the weights and activations.
|
||||
const auto x =
|
||||
ConstMatFromBatch(num_interleaved, activations.bf_pre_ffw_rms_out);
|
||||
|
||||
auto hidden_activations = RowPtrFromBatch(activations.C1);
|
||||
auto ffw_out = RowPtrFromBatch(activations.ffw_out);
|
||||
|
||||
auto w1 = ConstMatFromWeights(layer_weights->vit.linear_0_w);
|
||||
auto w_output = ConstMatFromWeights(layer_weights->vit.linear_1_w);
|
||||
|
||||
// Compute the hidden layer activations.
|
||||
if (add_bias) {
|
||||
MatMul</*kAddBias=*/true>(num_interleaved, x, w1, scale, bias1,
|
||||
activations.env, hidden_activations);
|
||||
} else {
|
||||
MatMul</*kAddBias=*/false>(num_interleaved, x, w1, scale, bias1,
|
||||
activations.env, hidden_activations);
|
||||
}
|
||||
MatMul(x, w1, bias1, activations.env, hidden_activations);
|
||||
|
||||
// Activation (Gelu) and maybe multiply by gate. Store activations in act.
|
||||
Activation(layer_weights->layer_config.activation, hidden_activations.ptr,
|
||||
multiplier.ptr, ff_hidden_dim * num_interleaved);
|
||||
// Activation (Gelu), store in act.
|
||||
RowPtrF multiplier = RowPtrF(nullptr, 0);
|
||||
Activation(layer_weights->layer_config.activation, hidden_activations.Row(0),
|
||||
multiplier.Row(0), ff_hidden_dim * num_interleaved);
|
||||
|
||||
// Hidden layer -> output layer.
|
||||
if (add_bias) {
|
||||
MatMul</*kAddBias=*/true>(num_interleaved, ConstMat(hidden_activations),
|
||||
w_output, output_scale, output_bias,
|
||||
activations.env, ffw_out);
|
||||
} else {
|
||||
MatMul</*kAddBias=*/false>(num_interleaved, ConstMat(hidden_activations),
|
||||
w_output, output_scale, output_bias,
|
||||
activations.env, ffw_out);
|
||||
}
|
||||
auto activations_mat = MakeConstMat(
|
||||
hidden_activations.Row(0), Extents2D(num_interleaved, ff_hidden_dim));
|
||||
|
||||
MatMul(activations_mat, w_output, output_bias, activations.env, ffw_out);
|
||||
}
|
||||
|
||||
// `batch_idx` indicates which row of `x` to write to.
|
||||
|
|
@ -853,7 +795,7 @@ HWY_NOINLINE void EmbedToken(int token, size_t batch_idx, size_t pos,
|
|||
// Image tokens just need to be copied.
|
||||
if (image_tokens != nullptr && pos_in_prompt < image_tokens->BatchSize()) {
|
||||
hwy::CopyBytes(image_tokens->Batch(pos_in_prompt), x.Batch(batch_idx),
|
||||
x.Len() * sizeof(x.Const()[0]));
|
||||
x.Cols() * sizeof(x.Const()[0]));
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
@ -942,7 +884,7 @@ HWY_NOINLINE void TransformerLayer(const QueriesPos& queries_pos,
|
|||
// the Big Vision codebase. See
|
||||
// github.com/google-research/big_vision/blob/main/big_vision/models/vit.py
|
||||
// TODO(keysers): consider adding a wrapper for both LayerNorm with RMSNorm and
|
||||
// try mergig this with TransformerLayer.
|
||||
// try merging this with TransformerLayer.
|
||||
template <typename T>
|
||||
HWY_NOINLINE void VitTransformerLayer(size_t num_tokens, size_t layer,
|
||||
const LayerWeightsPtrs<T>* layer_weights,
|
||||
|
|
@ -953,7 +895,7 @@ HWY_NOINLINE void VitTransformerLayer(size_t num_tokens, size_t layer,
|
|||
|
||||
auto& x = activations.x;
|
||||
HWY_DASSERT(x.BatchSize() == num_tokens);
|
||||
HWY_DASSERT(x.Len() == model_dim);
|
||||
HWY_DASSERT(x.Cols() == model_dim);
|
||||
|
||||
// y = nn.LayerNorm()(x)
|
||||
// y ~ pre_att_rms_out
|
||||
|
|
@ -1106,7 +1048,7 @@ HWY_NOINLINE void EmbedImagePatches(const Image& image,
|
|||
const size_t patch_size = patch_width * patch_width * 3;
|
||||
HWY_DASSERT(weights.vit_img_embedding_kernel.NumElements() ==
|
||||
patch_size * model_dim);
|
||||
HWY_DASSERT(activations.x.Len() == model_dim);
|
||||
HWY_DASSERT(activations.x.Cols() == model_dim);
|
||||
std::vector<hwy::AlignedFreeUniquePtr<float[]>> image_patches(seq_len);
|
||||
for (size_t i = 0; i < seq_len; ++i) {
|
||||
image_patches[i] = hwy::AllocateAligned<float>(patch_size);
|
||||
|
|
@ -1118,11 +1060,11 @@ HWY_NOINLINE void EmbedImagePatches(const Image& image,
|
|||
// This could be done as one MatMul like:
|
||||
// RowVectorBatch<float> image_patches(kSeqLen, kPatchSize);
|
||||
// [Get patches]
|
||||
// MatMul</*kAdd=*/true>(
|
||||
// kVitSeqLen, ConstMat(image_patches.All(), kPatchSize),
|
||||
// ConstMat(weights.vit_img_embedding_kernel.data_scale1(), kPatchSize),
|
||||
// /*scale=*/1.0f, weights.vit_img_embedding_bias.data_scale1(),
|
||||
// activations.env, MutableMat(activations.x.All(), kVitModelDim));
|
||||
// MatMul(
|
||||
// MatFromBatch(kVitSeqLen, image_patches),
|
||||
// MatFromWeights(weights.vit_img_embedding_kernel),
|
||||
// weights.vit_img_embedding_bias.data_scale1(), activations.env,
|
||||
// RowPtrF(activations.x.All(), kVitModelDim));
|
||||
// However, MatMul currently requires that
|
||||
// A.cols % (2 * hn::Lanes(hn::ScalableTag<MulT>())) == 0
|
||||
// which is not the case here. We should relax that requirement on MatMul and
|
||||
|
|
@ -1163,11 +1105,10 @@ HWY_NOINLINE void PrefillVit(const ModelWeightsPtrs<T>& weights,
|
|||
activations.x.All(), vit_model_dim);
|
||||
|
||||
// Apply head embedding into image_tokens of size of the LLM kModelDim.
|
||||
MatMul</*kAdd=*/true>(
|
||||
num_tokens, ConstMat(activations.x.All(), vit_model_dim),
|
||||
ConstMat(weights.vit_img_head_kernel.data_scale1(), vit_model_dim),
|
||||
/*scale=*/1.0f, weights.vit_img_head_bias.data_scale1(), activations.env,
|
||||
MutableMat(image_tokens.All(), weights.weights_config.model_dim));
|
||||
MatMul(ConstMatFromBatch(num_tokens, activations.x),
|
||||
ConstMatFromWeights(weights.vit_img_head_kernel),
|
||||
weights.vit_img_head_bias.data_scale1(), activations.env,
|
||||
RowPtrFromBatch(image_tokens));
|
||||
}
|
||||
|
||||
// Generates one token for each query. `queries_token` is the previous token
|
||||
|
|
@ -1299,7 +1240,6 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations,
|
|||
const QueriesPos& queries_prefix_end,
|
||||
const size_t query_idx_start, const KVCaches& kv_caches,
|
||||
TimingInfo& timing_info) {
|
||||
const size_t model_dim = model.Config().model_dim;
|
||||
const size_t vocab_size = model.Config().vocab_size;
|
||||
const ModelWeightsPtrs<T>& weights = *model.GetWeightsOfType<T>();
|
||||
|
||||
|
|
@ -1387,11 +1327,10 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations,
|
|||
{
|
||||
PROFILER_ZONE("Gen.EmbeddingMatmul");
|
||||
// Compute logits from last layer activations.
|
||||
MatMul</*kAdd=*/false>(
|
||||
num_queries, ConstMat(activations.x.All(), model_dim),
|
||||
ConstMat(weights.embedder_input_embedding.data(), model_dim),
|
||||
weights.embedder_input_embedding.scale(), /*add=*/nullptr,
|
||||
activations.env, MutableMat(activations.logits.All(), vocab_size));
|
||||
MatMul(ConstMatFromBatch(num_queries, activations.x),
|
||||
ConstMatFromWeights(weights.embedder_input_embedding),
|
||||
/*add=*/nullptr, activations.env,
|
||||
RowPtrFromBatch(activations.logits));
|
||||
}
|
||||
PROFILER_ZONE("Gen.Softcap+Sample+Stream");
|
||||
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
|
||||
|
|
|
|||
|
|
@ -35,7 +35,6 @@
|
|||
#include "util/threading.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
#include "hwy/highway.h"
|
||||
#include "hwy/profiler.h" // also uses SIMD
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
|
|
@ -119,12 +118,12 @@ struct GenerateImageTokensT {
|
|||
void Gemma::Generate(const RuntimeConfig& runtime_config,
|
||||
const PromptTokens& prompt, size_t pos, size_t prefix_end,
|
||||
KVCache& kv_cache, TimingInfo& timing_info) {
|
||||
if (runtime_config.use_spinning) pools_.StartSpinning();
|
||||
pools_.MaybeStartSpinning(runtime_config.use_spinning);
|
||||
|
||||
model_.CallForModelWeight<GenerateSingleT>(
|
||||
runtime_config, prompt, pos, prefix_end, kv_cache, pools_, timing_info);
|
||||
|
||||
if (runtime_config.use_spinning) pools_.StopSpinning();
|
||||
pools_.MaybeStopSpinning(runtime_config.use_spinning);
|
||||
}
|
||||
|
||||
void Gemma::GenerateBatch(const RuntimeConfig& runtime_config,
|
||||
|
|
@ -141,23 +140,23 @@ void Gemma::GenerateBatch(const RuntimeConfig& runtime_config,
|
|||
QueriesPos(prefix_end_vec.data(), prefix_end_vec.size());
|
||||
}
|
||||
|
||||
if (runtime_config.use_spinning) pools_.StartSpinning();
|
||||
pools_.MaybeStartSpinning(runtime_config.use_spinning);
|
||||
|
||||
model_.CallForModelWeight<GenerateBatchT>(
|
||||
runtime_config, queries_prompt, queries_pos, mutable_queries_prefix_end,
|
||||
kv_caches, pools_, timing_info);
|
||||
|
||||
if (runtime_config.use_spinning) pools_.StopSpinning();
|
||||
pools_.MaybeStopSpinning(runtime_config.use_spinning);
|
||||
}
|
||||
|
||||
void Gemma::GenerateImageTokens(const RuntimeConfig& runtime_config,
|
||||
const Image& image, ImageTokens& image_tokens) {
|
||||
if (runtime_config.use_spinning) pools_.StartSpinning();
|
||||
pools_.MaybeStartSpinning(runtime_config.use_spinning);
|
||||
|
||||
model_.CallForModelWeight<GenerateImageTokensT>(runtime_config, image,
|
||||
image_tokens, pools_);
|
||||
|
||||
if (runtime_config.use_spinning) pools_.StopSpinning();
|
||||
pools_.MaybeStopSpinning(runtime_config.use_spinning);
|
||||
}
|
||||
|
||||
// Non-template functions moved from gemma-inl.h to avoid ODR violations.
|
||||
|
|
|
|||
|
|
@ -121,7 +121,11 @@ struct RuntimeConfig {
|
|||
const ImageTokens *image_tokens = nullptr;
|
||||
|
||||
// Whether to use thread spinning to reduce barrier synchronization latency.
|
||||
bool use_spinning = true;
|
||||
// Mutable so we can change kDefault to kTrue/kFalse during Generate, because
|
||||
// RuntimeConfig is const there and is not passed to the Gemma ctor. This
|
||||
// default decision is likely sufficient because it is based on whether
|
||||
// threads are successfully pinned.
|
||||
mutable Tristate use_spinning = Tristate::kDefault;
|
||||
|
||||
// End-of-sequence token.
|
||||
int eos_id = EOS_ID;
|
||||
|
|
|
|||
44
gemma/run.cc
44
gemma/run.cc
|
|
@ -16,7 +16,6 @@
|
|||
// Command line text interface to gemma.
|
||||
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <random>
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
|
|
@ -79,8 +78,8 @@ std::string GetPrompt(std::istream& input, int verbosity,
|
|||
}
|
||||
|
||||
// The main Read-Eval-Print Loop.
|
||||
void ReplGemma(Gemma& model, KVCache& kv_cache, const InferenceArgs& args,
|
||||
int verbosity, const AcceptFunc& accept_token,
|
||||
void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
|
||||
const InferenceArgs& args, const AcceptFunc& accept_token,
|
||||
std::string& eot_line) {
|
||||
PROFILER_ZONE("Gen.misc");
|
||||
size_t abs_pos = 0; // across turns
|
||||
|
|
@ -92,17 +91,18 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const InferenceArgs& args,
|
|||
|
||||
const bool have_image = !args.image_file.path.empty();
|
||||
Image image;
|
||||
std::unique_ptr<ImageTokens> image_tokens;
|
||||
ImageTokens image_tokens;
|
||||
if (have_image) {
|
||||
image_tokens = std::make_unique<ImageTokens>(
|
||||
model.GetModelConfig().vit_seq_len, model.GetModelConfig().model_dim);
|
||||
image_tokens = ImageTokens(Extents2D(model.GetModelConfig().vit_seq_len,
|
||||
model.GetModelConfig().model_dim));
|
||||
HWY_ASSERT(model.Info().training == ModelTraining::PALIGEMMA);
|
||||
HWY_ASSERT(image.ReadPPM(args.image_file.path));
|
||||
image.Resize();
|
||||
RuntimeConfig runtime_config = {.verbosity = verbosity, .gen = &gen};
|
||||
RuntimeConfig runtime_config = {
|
||||
.verbosity = app.verbosity, .gen = &gen, .use_spinning = app.spin};
|
||||
double image_tokens_start = hwy::platform::Now();
|
||||
model.GenerateImageTokens(runtime_config, image, *image_tokens);
|
||||
if (verbosity >= 1) {
|
||||
model.GenerateImageTokens(runtime_config, image, image_tokens);
|
||||
if (app.verbosity >= 1) {
|
||||
double image_tokens_duration = hwy::platform::Now() - image_tokens_start;
|
||||
fprintf(stderr,
|
||||
"\n\n[ Timing info ] Image token generation took: %d ms\n",
|
||||
|
|
@ -122,7 +122,7 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const InferenceArgs& args,
|
|||
abs_pos = 0;
|
||||
InitGenerator(args, gen);
|
||||
}
|
||||
if (verbosity >= 2) {
|
||||
if (app.verbosity >= 2) {
|
||||
std::cout << "\n[ End ]\n";
|
||||
}
|
||||
} else {
|
||||
|
|
@ -133,7 +133,7 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const InferenceArgs& args,
|
|||
if (tokens_generated_this_turn == prompt_size + 1) {
|
||||
// first token of response
|
||||
token_text.erase(0, token_text.find_first_not_of(" \t\n"));
|
||||
if (verbosity >= 1) {
|
||||
if (app.verbosity >= 1) {
|
||||
std::cout << "\n\n";
|
||||
}
|
||||
}
|
||||
|
|
@ -144,7 +144,7 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const InferenceArgs& args,
|
|||
|
||||
while (true) { // Loop until user quits.
|
||||
tokens_generated_this_turn = 0;
|
||||
std::string prompt_string = GetPrompt(std::cin, verbosity, eot_line);
|
||||
std::string prompt_string = GetPrompt(std::cin, app.verbosity, eot_line);
|
||||
if (!std::cin) return;
|
||||
// If !eot_line.empty(), we append \n, so only look at the first 2 chars.
|
||||
if (prompt_string.size() >= 2 && prompt_string[0] == '%') {
|
||||
|
|
@ -171,18 +171,17 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const InferenceArgs& args,
|
|||
}
|
||||
}
|
||||
|
||||
TimingInfo timing_info = {.verbosity = verbosity};
|
||||
RuntimeConfig runtime_config = {
|
||||
.verbosity = verbosity,
|
||||
.gen = &gen,
|
||||
.stream_token = stream_token,
|
||||
.accept_token = accept_token,
|
||||
};
|
||||
TimingInfo timing_info = {.verbosity = app.verbosity};
|
||||
RuntimeConfig runtime_config = {.verbosity = app.verbosity,
|
||||
.gen = &gen,
|
||||
.stream_token = stream_token,
|
||||
.accept_token = accept_token,
|
||||
.use_spinning = app.spin};
|
||||
args.CopyTo(runtime_config);
|
||||
size_t prefix_end = 0;
|
||||
if (have_image) {
|
||||
runtime_config.image_tokens = image_tokens.get();
|
||||
prompt.insert(prompt.begin(), image_tokens->BatchSize(), 0);
|
||||
runtime_config.image_tokens = &image_tokens;
|
||||
prompt.insert(prompt.begin(), image_tokens.BatchSize(), 0);
|
||||
prompt_size = prompt.size();
|
||||
// The end of the prefix for prefix-LM style attention in Paligemma.
|
||||
// See Figure 2 of https://arxiv.org/abs/2407.07726.
|
||||
|
|
@ -237,8 +236,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
|
|||
std::cout << "\n" << instructions << "\n";
|
||||
}
|
||||
|
||||
ReplGemma(model, kv_cache, inference, app.verbosity, AcceptFunc(),
|
||||
app.eot_line);
|
||||
ReplGemma(model, kv_cache, app, inference, AcceptFunc(), app.eot_line);
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
|
|||
|
|
@ -95,11 +95,11 @@ struct LayerWeightsPtrs {
|
|||
config.model_dim},
|
||||
.qkv_einsum_b = {"qkv_ein_b", (config.heads + 2 * config.kv_heads),
|
||||
config.qkv_dim},
|
||||
.linear_0_w = {"linear_0_w", config.model_dim,
|
||||
config.ff_hidden_dim},
|
||||
.linear_0_b = {"linear_0_b", 1, config.ff_hidden_dim},
|
||||
.linear_1_w = {"linear_1_w", config.ff_hidden_dim,
|
||||
.linear_0_w = {"linear_0_w", config.ff_hidden_dim,
|
||||
config.model_dim},
|
||||
.linear_0_b = {"linear_0_b", 1, config.ff_hidden_dim},
|
||||
.linear_1_w = {"linear_1_w", config.model_dim,
|
||||
config.ff_hidden_dim},
|
||||
.linear_1_b = {"linear_1_b", 1, config.model_dim},
|
||||
.layer_norm_0_bias = {"ln_0_bias", 1, config.model_dim},
|
||||
.layer_norm_0_scale = {"ln_0_scale", 1, config.model_dim},
|
||||
|
|
@ -349,14 +349,13 @@ struct ModelWeightsPtrs {
|
|||
vit_encoder_norm_bias("enc_norm_bias", 1, config.vit_model_dim),
|
||||
vit_encoder_norm_scale("enc_norm_scale", 1, config.vit_model_dim),
|
||||
vit_img_embedding_bias("img_emb_bias", 1, config.vit_model_dim),
|
||||
vit_img_embedding_kernel(
|
||||
"img_emb_kernel",
|
||||
config.patch_width * config.patch_width * 3,
|
||||
config.vit_model_dim),
|
||||
vit_img_embedding_kernel("img_emb_kernel",
|
||||
config.patch_width * config.patch_width * 3,
|
||||
config.vit_model_dim),
|
||||
vit_img_pos_embedding("img_pos_emb", 256, config.vit_model_dim),
|
||||
vit_img_head_bias("img_head_bias", 1, config.model_dim),
|
||||
vit_img_head_kernel("img_head_kernel", config.vit_model_dim,
|
||||
config.model_dim),
|
||||
vit_img_head_kernel("img_head_kernel", config.model_dim,
|
||||
config.vit_model_dim),
|
||||
scale_names(config.scale_names),
|
||||
weights_config(config) {
|
||||
c_layers.reserve(config.layer_configs.size());
|
||||
|
|
|
|||
|
|
@ -1011,14 +1011,14 @@ struct TestShortDotsT {
|
|||
// hence they require padding to one vector.
|
||||
const size_t padded_num = hwy::RoundUpTo(num, N);
|
||||
const size_t packed_num = CompressedArrayElements<Packed>(num);
|
||||
RowVectorBatch<float> raw_w(1, padded_num);
|
||||
RowVectorBatch<float> raw_v(1, padded_num);
|
||||
RowVectorBatch<Packed> weights(1, packed_num);
|
||||
RowVectorBatch<float> raw_w(Extents2D(1, padded_num));
|
||||
RowVectorBatch<float> raw_v(Extents2D(1, padded_num));
|
||||
RowVectorBatch<Packed> weights(Extents2D(1, packed_num));
|
||||
const PackedSpan<Packed> w(weights.Batch(0), packed_num);
|
||||
RowVectorBatch<T> vectors(1, num);
|
||||
RowVectorBatch<T> vectors(Extents2D(1, num));
|
||||
const PackedSpan<T> v(vectors.Batch(0), num);
|
||||
|
||||
RowVectorBatch<double> bufs(1, num);
|
||||
RowVectorBatch<double> bufs(Extents2D(1, num));
|
||||
double* HWY_RESTRICT buf = bufs.Batch(0);
|
||||
|
||||
for (size_t rep = 0; rep < hn::AdjustedReps(20); ++rep) {
|
||||
|
|
@ -1107,11 +1107,11 @@ void TestAllDot() {
|
|||
|
||||
constexpr size_t kReps = hn::AdjustedReps(40);
|
||||
const size_t num = 24 * 1024;
|
||||
NestedPools pools(kMaxWorkers - 1, /*pin=*/1, BoundedSlice(0, 1),
|
||||
BoundedSlice(0, 1));
|
||||
RowVectorBatch<float> a(kMaxWorkers, num);
|
||||
RowVectorBatch<float> b(kMaxWorkers, num);
|
||||
RowVectorBatch<double> bufs(kMaxWorkers, num);
|
||||
NestedPools pools(kMaxWorkers - 1, /*pin=*/Tristate::kDefault,
|
||||
BoundedSlice(0, 1), BoundedSlice(0, 1));
|
||||
RowVectorBatch<float> a(Extents2D(kMaxWorkers, num));
|
||||
RowVectorBatch<float> b(Extents2D(kMaxWorkers, num));
|
||||
RowVectorBatch<double> bufs(Extents2D(kMaxWorkers, num));
|
||||
std::array<DotStats, kMaxWorkers> all_stats;
|
||||
|
||||
pools.Cluster(0, 0).Run(0, kReps, [&](const uint32_t rep, size_t thread) {
|
||||
|
|
|
|||
309
ops/matmul-inl.h
309
ops/matmul-inl.h
|
|
@ -16,8 +16,9 @@
|
|||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include "compression/compress.h" // IWYU pragma: keep, b/conditionally used
|
||||
#include "ops/matmul.h" // IWYU pragma: export
|
||||
#include "util/allocator.h"
|
||||
#include "util/basics.h"
|
||||
|
||||
// Include guard for (potentially) SIMD code.
|
||||
#if defined(THIRD_PARTY_GEMMA_CPP_MATMUL_TOGGLE) == defined(HWY_TARGET_TOGGLE)
|
||||
|
|
@ -30,7 +31,7 @@
|
|||
#include "hwy/highway.h"
|
||||
// After highway.h
|
||||
#include "compression/compress-inl.h"
|
||||
#include "hwy/contrib/math/math-inl.h"
|
||||
#include "ops/ops-inl.h"
|
||||
|
||||
HWY_BEFORE_NAMESPACE();
|
||||
namespace gcpp {
|
||||
|
|
@ -53,38 +54,20 @@ constexpr size_t kRegCols = 4;
|
|||
// generally `kRegRows`, but `batch_size % kRegRows` on the last row (if != 0).
|
||||
constexpr size_t kRegRows = kRegCols;
|
||||
|
||||
// NEON_BF16/SVE/AVX3_ZEN4 have instructions for bf16 * bf16 + f32 which are
|
||||
// more efficient than f32 * f32 + f32 because they process twice as many lanes
|
||||
// at a time. Any combination of A and B can be bf16: activations may already be
|
||||
// bf16, and weights can be decompressed to bf16.
|
||||
//
|
||||
// The corresponding op is `ReorderWidenMulAccumulate`, and it is always
|
||||
// supported, but only useful if it returns a single vector of pairwise sums
|
||||
// `a[0] * b[0] + a[1] * b[1]`. On other targets, `ReorderWidenMulAccumulate`
|
||||
// insteads return `a[1] * b[1]` in its `sum1` output. We cannot afford to keep
|
||||
// a `sum1` for each of the `kRegRows * kRegCols` C vectors, and it would be
|
||||
// expensive to add each `sum0` and `sum1`, hence we only 'decompress' A and B
|
||||
// to bf16 if the native op is available. This will actually demote f32
|
||||
// activations to bf16. Otherwise, we decompress to f32 and use normal FMA.
|
||||
using MulT = hwy::If<HWY_NATIVE_DOT_BF16, BF16, float>;
|
||||
|
||||
// Loads two vectors at a time with element type MulT from a row of transposed
|
||||
// B. Called in a loop over col_ab. No bounds checking because `kRow` is
|
||||
// actually from B columns, which we checked is a multiple of `kRegCols`.
|
||||
// Loads two vectors at a time with element type hn::TFromD<DR> from a row of
|
||||
// transposed B. Called in a loop over col_ab. No bounds checking because
|
||||
// `kRow` is from B columns, which we checked is a multiple of `kRegCols`.
|
||||
template <size_t kRow, typename MatTB>
|
||||
class BRow {
|
||||
static_assert(kRow < kRegRows); // which unrolled instance we are
|
||||
|
||||
public:
|
||||
BRow(const Mat<const MatTB>& B, size_t row_b, size_t cols_c)
|
||||
// B.cols * C.cols is the total number of elements, required for
|
||||
// PackedSpan::BoundsCheck.
|
||||
: B_(MakeSpan(B.ptr, B.ofs + B.cols * cols_c)),
|
||||
B_ofs_(B.Row(row_b + kRow)) {}
|
||||
BRow(const ConstMat<MatTB>& B, size_t row_b)
|
||||
: B_(MakeSpan(B.ptr, B.ofs + B.Extents().Area())),
|
||||
B_ofs_(B.Row(HWY_MIN(row_b + kRow, B.Extents().rows - 1))) {}
|
||||
|
||||
template <class DM, class VM = hn::Vec<DM>>
|
||||
HWY_INLINE void Load2(DM d, size_t col_ab, VM& b0, VM& b1) const {
|
||||
static_assert(hwy::IsSame<hn::TFromD<DM>, MulT>());
|
||||
template <class DR, class VR = hn::Vec<DR>>
|
||||
HWY_INLINE void Load2(DR d, size_t col_ab, VR& b0, VR& b1) const {
|
||||
Decompress2(d, B_, B_ofs_ + col_ab, b0, b1);
|
||||
}
|
||||
|
||||
|
|
@ -93,11 +76,11 @@ class BRow {
|
|||
const size_t B_ofs_;
|
||||
};
|
||||
|
||||
// Loads *two* row vectors from A via `Decompress2`, multiplies element-wise
|
||||
// with `kRegRows` x 2 row vectors from transposed B, and adds them to
|
||||
// `kRegRows` x `kRegCols` C vectors. The lanes of `C[r,c]` are thus a subset of
|
||||
// the terms of the dot products that make up the MatMul result at `r,c`.
|
||||
// No-op for the bottom-most tile where kRow >= kNumRows.
|
||||
// Loads *two* row vectors from A via `Decompress2`, widens to f32, multiplies
|
||||
// element-wise with `kRegRows` x 2 row vectors from transposed B, and adds
|
||||
// them to `kRegRows` x `kRegCols` C vectors. The lanes of `C[r,c]` are thus a
|
||||
// subset of the terms of the dot products that make up the MatMul result at
|
||||
// `r,c`. No-op for the bottom-most rows whose `kRow >= kNumRows`.
|
||||
//
|
||||
// This approach is atypical because it requires a horizontal sum, for which we
|
||||
// introduce a fast and new(?) vector-length agnostic 'transpose', see
|
||||
|
|
@ -107,22 +90,24 @@ class BRow {
|
|||
// - `Decompress2` decompresses two vectors at a time;
|
||||
// - B is column-major, so unit-stride SIMD loads return a column, not values
|
||||
// from different columns, i.e. a row.
|
||||
// Both could be fixed in a packing stage, which is not implemented yet, and
|
||||
// might not be necessary otherwise. However, `ReorderWidenMulAccumulate` is
|
||||
// important for bf16 performance and incompatible with the conventional
|
||||
// approach, because its pairwise adds would add together unrelated terms.
|
||||
// By contrast, pairwise adds are fine when our C lanes are the terms of a
|
||||
// single dot product, which can be reordered or pre-reduced.
|
||||
// - `ReorderWidenMulAccumulate` is important for bf16 performance, but its
|
||||
// pairwise adds would add together unrelated terms.
|
||||
// The first two could be fixed in a packing stage, which is not implemented
|
||||
// yet, and might not be necessary otherwise. The third seems a fundamental
|
||||
// mismatch. However, pairwise adds are fine in our setting because C lanes are
|
||||
// the terms of a single dot product, which can be reordered or pre-reduced.
|
||||
template <size_t kRow, typename MatTA>
|
||||
class ALoadAccumulate {
|
||||
static_assert(kRow < kRegRows); // which unrolled instance we are
|
||||
|
||||
public:
|
||||
ALoadAccumulate(const Mat<const MatTA>& A, size_t row_ac, size_t batch_size)
|
||||
// A.cols * batch_size is the total number of elements, required for
|
||||
// PackedSpan::BoundsCheck.
|
||||
: A_(MakeSpan(A.ptr, A.ofs + A.cols * batch_size)),
|
||||
A_ofs_(A.Row(row_ac + kRow)) {}
|
||||
static_assert(kRow < kRegRows); // which unrolled instance we are
|
||||
// `First` and `Next` handle a single row of A, so the horizontal sums of
|
||||
// their `C0..3` are the (partial) dot products for 4 consecutive values in
|
||||
// one row of C.
|
||||
static_assert(kRegCols == 4);
|
||||
|
||||
ALoadAccumulate(const ConstMat<MatTA>& A, size_t row_ac)
|
||||
: A_(MakeSpan(A.ptr, A.ofs + A.Extents().Area())),
|
||||
A_ofs_(A.Row(HWY_MIN(row_ac + kRow, A.Extents().rows - 1))) {}
|
||||
|
||||
// First iteration, col_ab = 0: initialize C0..3 instead of updating them.
|
||||
template <size_t kNumRows, class DM, class VM = hn::Vec<DM>, HWY_IF_F32_D(DM)>
|
||||
|
|
@ -161,20 +146,27 @@ class ALoadAccumulate {
|
|||
Decompress2(dm, A_, A_ofs_, a0, a1);
|
||||
|
||||
const DF df;
|
||||
VF unused_sum1 = hn::Zero(df);
|
||||
|
||||
static_assert(kRegCols == 4);
|
||||
C0 = hn::WidenMulPairwiseAdd(df, a0, b00);
|
||||
C1 = hn::WidenMulPairwiseAdd(df, a0, b10);
|
||||
C2 = hn::WidenMulPairwiseAdd(df, a0, b20);
|
||||
C3 = hn::WidenMulPairwiseAdd(df, a0, b30);
|
||||
C0 = hn::ReorderWidenMulAccumulate(df, a1, b01, C0, unused_sum1);
|
||||
C1 = hn::ReorderWidenMulAccumulate(df, a1, b11, C1, unused_sum1);
|
||||
C2 = hn::ReorderWidenMulAccumulate(df, a1, b21, C2, unused_sum1);
|
||||
C3 = hn::ReorderWidenMulAccumulate(df, a1, b31, C3, unused_sum1);
|
||||
|
||||
// Ensure sum1 was indeed unused.
|
||||
HWY_DASSERT(hn::AllTrue(df, hn::Eq(unused_sum1, hn::Zero(df))));
|
||||
if constexpr (HWY_NATIVE_DOT_BF16) {
|
||||
// Native ReorderWidenMulAccumulate adds to C0..3 for free.
|
||||
VF unused_sum1 = hn::Zero(df);
|
||||
C0 = hn::ReorderWidenMulAccumulate(df, a1, b01, C0, unused_sum1);
|
||||
C1 = hn::ReorderWidenMulAccumulate(df, a1, b11, C1, unused_sum1);
|
||||
C2 = hn::ReorderWidenMulAccumulate(df, a1, b21, C2, unused_sum1);
|
||||
C3 = hn::ReorderWidenMulAccumulate(df, a1, b31, C3, unused_sum1);
|
||||
// Ensure sum1 was indeed unused.
|
||||
HWY_DASSERT(hn::AllTrue(df, hn::Eq(unused_sum1, hn::Zero(df))));
|
||||
} else {
|
||||
C0 = hn::Add(C0, hn::WidenMulPairwiseAdd(df, a1, b01));
|
||||
C1 = hn::Add(C1, hn::WidenMulPairwiseAdd(df, a1, b11));
|
||||
C2 = hn::Add(C2, hn::WidenMulPairwiseAdd(df, a1, b21));
|
||||
C3 = hn::Add(C3, hn::WidenMulPairwiseAdd(df, a1, b31));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -217,20 +209,31 @@ class ALoadAccumulate {
|
|||
Decompress2(dm, A_, A_ofs_ + col_ab, a0, a1);
|
||||
|
||||
const DF df;
|
||||
hn::Vec<DF> unused_sum1 = hn::Zero(df);
|
||||
|
||||
static_assert(kRegCols == 4);
|
||||
C0 = hn::ReorderWidenMulAccumulate(df, a0, b00, C0, unused_sum1);
|
||||
C1 = hn::ReorderWidenMulAccumulate(df, a0, b10, C1, unused_sum1);
|
||||
C2 = hn::ReorderWidenMulAccumulate(df, a0, b20, C2, unused_sum1);
|
||||
C3 = hn::ReorderWidenMulAccumulate(df, a0, b30, C3, unused_sum1);
|
||||
C0 = hn::ReorderWidenMulAccumulate(df, a1, b01, C0, unused_sum1);
|
||||
C1 = hn::ReorderWidenMulAccumulate(df, a1, b11, C1, unused_sum1);
|
||||
C2 = hn::ReorderWidenMulAccumulate(df, a1, b21, C2, unused_sum1);
|
||||
C3 = hn::ReorderWidenMulAccumulate(df, a1, b31, C3, unused_sum1);
|
||||
|
||||
// Ensure sum1 was indeed unused.
|
||||
HWY_DASSERT(hn::AllTrue(df, hn::Eq(unused_sum1, hn::Zero(df))));
|
||||
if constexpr (HWY_NATIVE_DOT_BF16) {
|
||||
// Native ReorderWidenMulAccumulate adds to C0..3 for free.
|
||||
VF unused_sum1 = hn::Zero(df);
|
||||
C0 = hn::ReorderWidenMulAccumulate(df, a0, b00, C0, unused_sum1);
|
||||
C1 = hn::ReorderWidenMulAccumulate(df, a0, b10, C1, unused_sum1);
|
||||
C2 = hn::ReorderWidenMulAccumulate(df, a0, b20, C2, unused_sum1);
|
||||
C3 = hn::ReorderWidenMulAccumulate(df, a0, b30, C3, unused_sum1);
|
||||
C0 = hn::ReorderWidenMulAccumulate(df, a1, b01, C0, unused_sum1);
|
||||
C1 = hn::ReorderWidenMulAccumulate(df, a1, b11, C1, unused_sum1);
|
||||
C2 = hn::ReorderWidenMulAccumulate(df, a1, b21, C2, unused_sum1);
|
||||
C3 = hn::ReorderWidenMulAccumulate(df, a1, b31, C3, unused_sum1);
|
||||
// Ensure sum1 was indeed unused.
|
||||
HWY_DASSERT(hn::AllTrue(df, hn::Eq(unused_sum1, hn::Zero(df))));
|
||||
} else {
|
||||
C0 = hn::Add(C0, hn::WidenMulPairwiseAdd(df, a0, b00));
|
||||
C1 = hn::Add(C1, hn::WidenMulPairwiseAdd(df, a0, b10));
|
||||
C2 = hn::Add(C2, hn::WidenMulPairwiseAdd(df, a0, b20));
|
||||
C3 = hn::Add(C3, hn::WidenMulPairwiseAdd(df, a0, b30));
|
||||
C0 = hn::Add(C0, hn::WidenMulPairwiseAdd(df, a1, b01));
|
||||
C1 = hn::Add(C1, hn::WidenMulPairwiseAdd(df, a1, b11));
|
||||
C2 = hn::Add(C2, hn::WidenMulPairwiseAdd(df, a1, b21));
|
||||
C3 = hn::Add(C3, hn::WidenMulPairwiseAdd(df, a1, b31));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -356,116 +359,113 @@ class AddHorizontalSums {
|
|||
// Streams a `kNumRows` high strip of `A` and the transposed `B`, then writes a
|
||||
// *finished* tile of f32 `C` whose top left is (row_ac, row_b_col_c).
|
||||
// TODO: loop over sections instead of full rows and accumulate into `tile_c`.
|
||||
// `buf` is 16 vectors of thread-local storage.
|
||||
template <size_t kNumRows, bool kAdd, typename MatTA, typename MatTB>
|
||||
HWY_INLINE void MatMulTile(const size_t batch_size, const Mat<const MatTA>& A,
|
||||
const Mat<const MatTB>& B, const size_t row_ac,
|
||||
const size_t row_b_col_c, const float scale,
|
||||
const float* HWY_RESTRICT add,
|
||||
float* HWY_RESTRICT buf, const Mat<float>& C) {
|
||||
// For 'decompressing' A and B into BF16 or float.
|
||||
const hn::ScalableTag<MulT> dm;
|
||||
using VM = hn::Vec<decltype(dm)>;
|
||||
const size_t NM = hn::Lanes(dm);
|
||||
HWY_INLINE void MatMulTile(const ConstMat<MatTA>& A, const size_t row_ac,
|
||||
const ConstMat<MatTB>& B, const size_t row_b_col_c,
|
||||
const float scale, const float* HWY_RESTRICT add,
|
||||
float* HWY_RESTRICT buf, const RowPtr<float>& C) {
|
||||
// Decompress A and B to which type, which will then be widened to f32,
|
||||
// multiplied, added once into f32, then promoted to f64 and accumulated.
|
||||
// NEON_BF16/SVE/AVX3_ZEN4 have instructions for bf16 * bf16 + f32 which are
|
||||
// more efficient than f32 * f32 + f32 because they process twice as many
|
||||
// lanes at a time. If available, we definitely want to use them. Otherwise,
|
||||
// bf16 is still worthwhile if A (activations) are bf16: SFP weights are
|
||||
// cheaper to decode to bf16, relative to the minor extra cost of promoting
|
||||
// bf16 when multiplying. However, if A is f32, demoting to bf16 can be
|
||||
// expensive unless we also have native bf16 dot.
|
||||
using Raw = hwy::If<HWY_NATIVE_DOT_BF16 || !IsF32<MatTA>(), BF16, float>;
|
||||
const hn::ScalableTag<Raw> dr;
|
||||
using VR = hn::Vec<decltype(dr)>;
|
||||
const size_t NR = hn::Lanes(dr);
|
||||
|
||||
const Range1D cols_ab(0, A.Extents().cols);
|
||||
HWY_DASSERT(row_ac + kNumRows <= A.Extents().rows);
|
||||
HWY_DASSERT(row_b_col_c + kNumRows <= B.Extents().rows);
|
||||
HWY_DASSERT(cols_ab.end() % (2 * NR) == 0);
|
||||
|
||||
static_assert(kRegRows == 4);
|
||||
const BRow<0, MatTB> b_row0(B, row_b_col_c, C.cols);
|
||||
const BRow<1, MatTB> b_row1(B, row_b_col_c, C.cols);
|
||||
const BRow<2, MatTB> b_row2(B, row_b_col_c, C.cols);
|
||||
const BRow<3, MatTB> b_row3(B, row_b_col_c, C.cols);
|
||||
const BRow<0, MatTB> b_row0(B, row_b_col_c);
|
||||
const BRow<1, MatTB> b_row1(B, row_b_col_c);
|
||||
const BRow<2, MatTB> b_row2(B, row_b_col_c);
|
||||
const BRow<3, MatTB> b_row3(B, row_b_col_c);
|
||||
|
||||
const ALoadAccumulate<0, MatTA> a_row0(A, row_ac, batch_size);
|
||||
const ALoadAccumulate<1, MatTA> a_row1(A, row_ac, batch_size);
|
||||
const ALoadAccumulate<2, MatTA> a_row2(A, row_ac, batch_size);
|
||||
const ALoadAccumulate<3, MatTA> a_row3(A, row_ac, batch_size);
|
||||
const ALoadAccumulate<0, MatTA> a_row0(A, row_ac);
|
||||
const ALoadAccumulate<1, MatTA> a_row1(A, row_ac);
|
||||
const ALoadAccumulate<2, MatTA> a_row2(A, row_ac);
|
||||
const ALoadAccumulate<3, MatTA> a_row3(A, row_ac);
|
||||
|
||||
const hn::Repartition<float, decltype(dm)> df;
|
||||
const hn::Repartition<float, decltype(dr)> df;
|
||||
using VF = hn::Vec<decltype(df)>;
|
||||
VF C00, C01, C02, C03;
|
||||
VF C10, C11, C12, C13;
|
||||
VF C20, C21, C22, C23;
|
||||
VF C30, C31, C32, C33;
|
||||
|
||||
size_t col_ab = cols_ab.begin();
|
||||
{ // First iteration initializes the `Crc` vectors.
|
||||
VM b00, b01, b10, b11, b20, b21, b30, b31;
|
||||
b_row0.Load2(dm, /*col_ab=*/0, b00, b01);
|
||||
b_row1.Load2(dm, /*col_ab=*/0, b10, b11);
|
||||
b_row2.Load2(dm, /*col_ab=*/0, b20, b21);
|
||||
b_row3.Load2(dm, /*col_ab=*/0, b30, b31);
|
||||
VR b00, b01, b10, b11, b20, b21, b30, b31;
|
||||
b_row0.Load2(dr, col_ab, b00, b01);
|
||||
b_row1.Load2(dr, col_ab, b10, b11);
|
||||
b_row2.Load2(dr, col_ab, b20, b21);
|
||||
b_row3.Load2(dr, col_ab, b30, b31);
|
||||
|
||||
a_row0.template First<kNumRows>(dm, b00, b01, b10, b11, b20, b21, b30, b31,
|
||||
a_row0.template First<kNumRows>(dr, b00, b01, b10, b11, b20, b21, b30, b31,
|
||||
C00, C01, C02, C03);
|
||||
a_row1.template First<kNumRows>(dm, b00, b01, b10, b11, b20, b21, b30, b31,
|
||||
a_row1.template First<kNumRows>(dr, b00, b01, b10, b11, b20, b21, b30, b31,
|
||||
C10, C11, C12, C13);
|
||||
a_row2.template First<kNumRows>(dm, b00, b01, b10, b11, b20, b21, b30, b31,
|
||||
a_row2.template First<kNumRows>(dr, b00, b01, b10, b11, b20, b21, b30, b31,
|
||||
C20, C21, C22, C23);
|
||||
a_row3.template First<kNumRows>(dm, b00, b01, b10, b11, b20, b21, b30, b31,
|
||||
a_row3.template First<kNumRows>(dr, b00, b01, b10, b11, b20, b21, b30, b31,
|
||||
C30, C31, C32, C33);
|
||||
col_ab += 2 * NR;
|
||||
}
|
||||
|
||||
// `2 * NM` per iteration because `Load2` returns two vectors.
|
||||
// `2 * NR` per iteration because `Load2` returns two vectors.
|
||||
HWY_UNROLL(1)
|
||||
for (size_t col_ab = 2 * NM; col_ab <= A.cols - 2 * NM; col_ab += 2 * NM) {
|
||||
VM b00, b01, b10, b11, b20, b21, b30, b31;
|
||||
b_row0.Load2(dm, col_ab, b00, b01);
|
||||
b_row1.Load2(dm, col_ab, b10, b11);
|
||||
b_row2.Load2(dm, col_ab, b20, b21);
|
||||
b_row3.Load2(dm, col_ab, b30, b31);
|
||||
for (; col_ab < cols_ab.end(); col_ab += 2 * NR) {
|
||||
VR b00, b01, b10, b11, b20, b21, b30, b31;
|
||||
b_row0.Load2(dr, col_ab, b00, b01);
|
||||
b_row1.Load2(dr, col_ab, b10, b11);
|
||||
b_row2.Load2(dr, col_ab, b20, b21);
|
||||
b_row3.Load2(dr, col_ab, b30, b31);
|
||||
|
||||
a_row0.template Next<kNumRows>(dm, col_ab, b00, b01, b10, b11, b20, b21,
|
||||
a_row0.template Next<kNumRows>(dr, col_ab, b00, b01, b10, b11, b20, b21,
|
||||
b30, b31, C00, C01, C02, C03);
|
||||
a_row1.template Next<kNumRows>(dm, col_ab, b00, b01, b10, b11, b20, b21,
|
||||
a_row1.template Next<kNumRows>(dr, col_ab, b00, b01, b10, b11, b20, b21,
|
||||
b30, b31, C10, C11, C12, C13);
|
||||
a_row2.template Next<kNumRows>(dm, col_ab, b00, b01, b10, b11, b20, b21,
|
||||
a_row2.template Next<kNumRows>(dr, col_ab, b00, b01, b10, b11, b20, b21,
|
||||
b30, b31, C20, C21, C22, C23);
|
||||
a_row3.template Next<kNumRows>(dm, col_ab, b00, b01, b10, b11, b20, b21,
|
||||
a_row3.template Next<kNumRows>(dr, col_ab, b00, b01, b10, b11, b20, b21,
|
||||
b30, b31, C30, C31, C32, C33);
|
||||
}
|
||||
|
||||
// TODO: hoist into outer loop.
|
||||
float* HWY_RESTRICT C_tile = C.ptr + C.Row(row_ac) + row_b_col_c;
|
||||
InitC<kNumRows, kAdd>(add, row_b_col_c, C_tile, C.stride);
|
||||
float* HWY_RESTRICT C_tile = C.Row(row_ac) + row_b_col_c;
|
||||
InitC<kNumRows, kAdd>(add, row_b_col_c, C_tile, C.Stride());
|
||||
|
||||
AddHorizontalSums<kNumRows>()(df, scale, C00, C01, C02, C03, C10, C11, C12,
|
||||
C13, C20, C21, C22, C23, C30, C31, C32, C33,
|
||||
buf, C_tile, C.stride);
|
||||
buf, C_tile, C.Stride());
|
||||
}
|
||||
|
||||
// Computes the matrix product `A * B * scale [+ add]` and stores it in `C`.
|
||||
//
|
||||
// `A` is a row-major matrix of shape `(batch_size, A.cols)`.
|
||||
// `B` is transposed; `B.cols`, which must match `A.cols`, denotes the number of
|
||||
// rows in the original B, and `C.cols` the number of columns in the original B.
|
||||
//
|
||||
// `scale` allows expanding the smaller range of `SfpStream` to the original
|
||||
// values. When `A` and/or `B` are from CompressedArray, `scale` should be the
|
||||
// product of their `.scale()` values, otherwise 1.0f.
|
||||
//
|
||||
// If `kAdd` is true, the row-vector `add` is added to each row of `C`,
|
||||
// otherwise `add` is ignored and can be nullptr. A scale for `add` is not
|
||||
// supported, so make sure its scale is 1.
|
||||
//
|
||||
// `C` is a row-major matrix of size `(batch_size, C.cols)`.
|
||||
//
|
||||
// Updates 4x4 tiles of C in parallel using a work-stealing thread pool.
|
||||
// Typically `batch_size` is 1..512, `A.cols` and `C.cols` are 3k or 24k.
|
||||
// Must not be called concurrently with the same `env`.
|
||||
template <bool kAdd, typename MatTA, typename MatTB>
|
||||
HWY_NOINLINE void MatMul(const size_t batch_size, const Mat<const MatTA>& A,
|
||||
const Mat<const MatTB>& B, const float scale,
|
||||
const float* HWY_RESTRICT add, MatMulEnv& env,
|
||||
const Mat<float>& C) {
|
||||
HWY_NOINLINE void MatMulImpl(const ConstMat<MatTA>& A, const ConstMat<MatTB>& B,
|
||||
const float* HWY_RESTRICT add, MatMulEnv& env,
|
||||
const RowPtr<float>& C) {
|
||||
// PROFILER_ZONE("Matmul");
|
||||
HWY_DASSERT(A.NotEmpty() && B.NotEmpty() && C.NotEmpty());
|
||||
HWY_DASSERT(A.cols == B.cols);
|
||||
HWY_DASSERT(A.Extents().cols == B.Extents().cols);
|
||||
const size_t batch_size = A.Extents().rows;
|
||||
HWY_DASSERT(C.Cols() % kRegCols == 0);
|
||||
HWY_DASSERT(C.Stride() >= C.Cols());
|
||||
HWY_DASSERT(B.Extents().rows == C.Cols());
|
||||
|
||||
// Must be a multiple of two vectors because we Decompress2.
|
||||
HWY_DASSERT(A.cols % (2 * hn::Lanes(hn::ScalableTag<MulT>())) == 0);
|
||||
HWY_DASSERT(C.cols % kRegCols == 0);
|
||||
const float scale = A.scale * B.scale;
|
||||
|
||||
// We currently write C directly, which touches more memory than fits in L3.
|
||||
// TODO: add another level of loops to finish L3-sized pieces of C at a time.
|
||||
const size_t tilesY = hwy::DivCeil(batch_size, kRegRows);
|
||||
const size_t tilesX = C.cols / kRegCols;
|
||||
const size_t tilesX = C.Cols() / kRegCols;
|
||||
|
||||
env.Pool().Run(
|
||||
0, tilesX * tilesY, [&](const uint64_t idx_tile, size_t thread) HWY_ATTR {
|
||||
|
|
@ -481,24 +481,45 @@ HWY_NOINLINE void MatMul(const size_t batch_size, const Mat<const MatTA>& A,
|
|||
HWY_DASSERT(num_rows != 0);
|
||||
switch (num_rows) {
|
||||
case 1:
|
||||
MatMulTile<1, kAdd>(batch_size, A, B, row_ac, row_b_col_c, scale,
|
||||
add, buf, C);
|
||||
MatMulTile<1, kAdd>(A, row_ac, B, row_b_col_c, scale, add, buf, C);
|
||||
break;
|
||||
case 2:
|
||||
MatMulTile<2, kAdd>(batch_size, A, B, row_ac, row_b_col_c, scale,
|
||||
add, buf, C);
|
||||
MatMulTile<2, kAdd>(A, row_ac, B, row_b_col_c, scale, add, buf, C);
|
||||
break;
|
||||
case 3:
|
||||
MatMulTile<3, kAdd>(batch_size, A, B, row_ac, row_b_col_c, scale,
|
||||
add, buf, C);
|
||||
MatMulTile<3, kAdd>(A, row_ac, B, row_b_col_c, scale, add, buf, C);
|
||||
break;
|
||||
default:
|
||||
MatMulTile<4, kAdd>(batch_size, A, B, row_ac, row_b_col_c, scale,
|
||||
add, buf, C);
|
||||
MatMulTile<4, kAdd>(A, row_ac, B, row_b_col_c, scale, add, buf, C);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Computes the matrix product `A * B * scale [+ add]` and stores it in `C`.
|
||||
//
|
||||
// `A` is a row-major matrix and `B` is transposed. Its `B.Extents().cols`,
|
||||
// which must match `A.Extents().cols`, is the number of rows in the original B.
|
||||
//
|
||||
// If `add` is non-null, the row-vector `add` is added to each row of `C`.
|
||||
// A scale for `add` is not supported, so make sure its scale is 1.
|
||||
//
|
||||
// `C` is a row-major matrix of size `(A.rows, C.Cols())` with support for
|
||||
// arbitrary strides.
|
||||
//
|
||||
// Updates 4x4 tiles of C in parallel using a work-stealing thread pool.
|
||||
// Typically `A.rows` is 1..512, `A.Extents().cols` and `B.Extents().rows` are
|
||||
// 3k or 24k. Must not be called concurrently with the same `env`.
|
||||
template <typename MatTA, typename MatTB>
|
||||
HWY_NOINLINE void MatMul(const ConstMat<MatTA>& A, const ConstMat<MatTB>& B,
|
||||
const float* HWY_RESTRICT add, MatMulEnv& env,
|
||||
const RowPtr<float>& C) {
|
||||
if (add) {
|
||||
MatMulImpl<true>(A, B, add, env, C);
|
||||
} else {
|
||||
MatMulImpl<false>(A, B, nullptr, env, C);
|
||||
}
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
||||
} // namespace HWY_NAMESPACE
|
||||
} // namespace gcpp
|
||||
|
|
|
|||
55
ops/matmul.h
55
ops/matmul.h
|
|
@ -19,73 +19,22 @@
|
|||
#include <stddef.h>
|
||||
|
||||
// IWYU pragma: begin_exports
|
||||
#include "util/basics.h"
|
||||
#include "util/threading.h"
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
// IWYU pragma: end_exports
|
||||
|
||||
#include "util/allocator.h" // RowVectorBatch
|
||||
#include "hwy/per_target.h" // VectorBytes
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
// Bundles ptr/size/stride arguments to simplify MatMul call sites. T can be
|
||||
// const or non-const. Create via ConstMat/MutableMat.
|
||||
// TODO(rays): Replace with MatPtr and get rid of stride, which is only != cols
|
||||
// in one place.
|
||||
template <typename T>
|
||||
struct Mat {
|
||||
bool NotEmpty() const {
|
||||
return ptr != nullptr && cols != 0 && stride >= cols;
|
||||
}
|
||||
size_t Row(size_t r) const { return ofs + stride * r; }
|
||||
|
||||
T* HWY_RESTRICT ptr;
|
||||
size_t cols;
|
||||
|
||||
// elements between rows, which is typically the same as `cols`.
|
||||
size_t stride;
|
||||
|
||||
// Offset to add to `ptr`; separate because T=NuqStream does not support
|
||||
// pointer arithmetic.
|
||||
size_t ofs;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
Mat<T> MutableMat(T* HWY_RESTRICT ptr, size_t cols, size_t stride,
|
||||
size_t ofs = 0) {
|
||||
return Mat<T>{.ptr = ptr, .cols = cols, .stride = stride, .ofs = ofs};
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Mat<const T> ConstMat(const T* HWY_RESTRICT ptr, size_t cols, size_t stride,
|
||||
size_t ofs = 0) {
|
||||
return Mat<const T>{.ptr = ptr, .cols = cols, .stride = stride, .ofs = ofs};
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Mat<const T> ConstMat(Mat<T> mat) {
|
||||
return ConstMat(mat.ptr, mat.cols, mat.stride, mat.ofs);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Mat<T> MutableMat(T* HWY_RESTRICT ptr, size_t cols) {
|
||||
return MutableMat(ptr, cols, cols);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Mat<const T> ConstMat(const T* HWY_RESTRICT ptr, size_t cols) {
|
||||
return ConstMat(ptr, cols, cols);
|
||||
}
|
||||
|
||||
// Allocations and threads, shared across MatMul calls.
|
||||
class MatMulEnv {
|
||||
public:
|
||||
MatMulEnv() : pools_(nullptr) {}
|
||||
explicit MatMulEnv(NestedPools& pools) : pools_(&pools) {
|
||||
const size_t N = hwy::VectorBytes() / sizeof(float);
|
||||
buf_ = RowVectorBatch<float>(pools.MaxWorkers(), 16 * N);
|
||||
buf_ = RowVectorBatch<float>(Extents2D(pools.MaxWorkers(), 16 * N));
|
||||
}
|
||||
|
||||
RowVectorBatch<float>& Buf() { return buf_; }
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@
|
|||
|
||||
#include "compression/compress.h"
|
||||
#include "util/allocator.h"
|
||||
#include "util/basics.h"
|
||||
#include "util/threading.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
|
@ -55,19 +56,23 @@ namespace HWY_NAMESPACE {
|
|||
|
||||
using FloatPtr = hwy::AlignedFreeUniquePtr<float[]>;
|
||||
|
||||
template <typename MatT>
|
||||
using MatStoragePtr = std::unique_ptr<MatStorageT<MatT>>;
|
||||
|
||||
// Generates inputs: deterministic, within max SfpStream range.
|
||||
template <typename MatT, size_t kRows, size_t kCols,
|
||||
class MatPtr = std::unique_ptr<MatStorageT<MatT>>>
|
||||
MatPtr GenerateMat(size_t offset, hwy::ThreadPool& pool) {
|
||||
template <typename MatT>
|
||||
MatStoragePtr<MatT> GenerateMat(const Extents2D extents,
|
||||
hwy::ThreadPool& pool) {
|
||||
gcpp::CompressWorkingSet ws;
|
||||
auto mat = std::make_unique<MatStorageT<MatT>>("test", kRows, kCols);
|
||||
auto mat =
|
||||
std::make_unique<MatStorageT<MatT>>("mat", extents.rows, extents.cols);
|
||||
FloatPtr content = hwy::AllocateAligned<float>(mat->NumElements());
|
||||
HWY_ASSERT(content);
|
||||
const float scale = SfpStream::kMax / (mat->NumElements() + offset);
|
||||
pool.Run(0, kRows, [&](const size_t i, size_t /*thread*/) {
|
||||
for (size_t j = 0; j < kCols; j++) {
|
||||
content[i * kCols + j] =
|
||||
static_cast<float>((i * kCols + j + offset) * scale);
|
||||
const float scale = SfpStream::kMax / (mat->NumElements());
|
||||
pool.Run(0, extents.rows, [&](const size_t r, size_t /*thread*/) {
|
||||
for (size_t c = 0; c < extents.cols; c++) {
|
||||
content[r * extents.cols + c] =
|
||||
static_cast<float>(r * extents.cols + c) * scale;
|
||||
}
|
||||
});
|
||||
|
||||
|
|
@ -76,185 +81,173 @@ MatPtr GenerateMat(size_t offset, hwy::ThreadPool& pool) {
|
|||
return mat;
|
||||
}
|
||||
|
||||
template <typename MatT, size_t kRows, size_t kCols,
|
||||
class MatPtr = std::unique_ptr<MatStorageT<MatT>>>
|
||||
MatPtr GenerateTransposedMat(size_t offset, hwy::ThreadPool& pool) {
|
||||
// extents describes the transposed matrix.
|
||||
template <typename MatT>
|
||||
MatStoragePtr<MatT> GenerateTransposedMat(const Extents2D extents,
|
||||
hwy::ThreadPool& pool) {
|
||||
gcpp::CompressWorkingSet ws;
|
||||
MatPtr mat = std::make_unique<MatStorageT<MatT>>("test", kCols, kRows);
|
||||
auto mat =
|
||||
std::make_unique<MatStorageT<MatT>>("trans", extents.rows, extents.cols);
|
||||
FloatPtr content = hwy::AllocateAligned<float>(mat->NumElements());
|
||||
const float scale = SfpStream::kMax / (mat->NumElements() + offset);
|
||||
pool.Run(0, kRows, [&](const size_t i, size_t /*thread*/) {
|
||||
for (size_t j = 0; j < kCols; j++) {
|
||||
content[j * kRows + i] =
|
||||
static_cast<float>((i * kCols + j + offset) * scale);
|
||||
const float scale = SfpStream::kMax / (mat->NumElements());
|
||||
pool.Run(0, extents.rows, [&](const size_t r, size_t /*thread*/) {
|
||||
for (size_t c = 0; c < extents.cols; c++) {
|
||||
content[r * extents.cols + c] =
|
||||
static_cast<float>(c * extents.rows + r) * scale;
|
||||
}
|
||||
});
|
||||
|
||||
CompressScaled(content.get(), mat->NumElements(), ws, *mat, pool);
|
||||
// Arbitrary value, different from 1, must match GenerateMatHeap.
|
||||
// Arbitrary value, different from 1, must match GenerateMat.
|
||||
mat->set_scale(0.6f);
|
||||
return mat;
|
||||
}
|
||||
|
||||
template <typename MatT, size_t kRows, size_t kCols,
|
||||
class MatPtr = std::unique_ptr<MatStorageT<MatT>>>
|
||||
MatPtr GenerateZeroMat(hwy::ThreadPool& pool) {
|
||||
gcpp::CompressWorkingSet ws;
|
||||
auto mat = std::make_unique<MatStorageT<MatT>>("Array", kRows, kCols);
|
||||
FloatPtr content = hwy::AllocateAligned<float>(mat->NumElements());
|
||||
HWY_ASSERT(content);
|
||||
|
||||
pool.Run(0, kRows, [&](const size_t i, size_t thread) {
|
||||
hwy::ZeroBytes(&content[i * kCols], kCols * sizeof(content[0]));
|
||||
});
|
||||
|
||||
CompressScaled(content.get(), mat->NumElements(), ws, *mat, pool);
|
||||
mat->set_scale(1.2f); // Arbitrary value, different from 1.
|
||||
return mat;
|
||||
}
|
||||
|
||||
// Returns 1-norm, used for estimating tolerable numerical differences.
|
||||
double MaxColAbsSum(const float* HWY_RESTRICT a, size_t rows, size_t cols) {
|
||||
double MaxColAbsSum(const float* HWY_RESTRICT a, const Extents2D& extents) {
|
||||
double max_col_abs_sum = 0.0;
|
||||
for (size_t c = 0; c < cols; c++) {
|
||||
for (size_t c = 0; c < extents.cols; c++) {
|
||||
double col_abs_sum = 0.0;
|
||||
for (size_t r = 0; r < rows; r++) {
|
||||
col_abs_sum += hwy::ScalarAbs(a[r * cols + c]);
|
||||
for (size_t r = 0; r < extents.rows; r++) {
|
||||
col_abs_sum += hwy::ScalarAbs(a[r * extents.cols + c]);
|
||||
}
|
||||
max_col_abs_sum = HWY_MAX(max_col_abs_sum, col_abs_sum);
|
||||
}
|
||||
return max_col_abs_sum;
|
||||
}
|
||||
|
||||
// B is already transposed.
|
||||
template <typename MatTA, typename MatTB>
|
||||
void AssertClose(size_t rows_ac, size_t cols_ab, size_t cols_c_rows_b,
|
||||
const MatTA* HWY_RESTRICT pa,
|
||||
const MatTB* HWY_RESTRICT pb_trans,
|
||||
const float* HWY_RESTRICT expected_c,
|
||||
const float* HWY_RESTRICT actual_c) {
|
||||
void AssertClose(const ConstMat<MatTA>& A, const ConstMat<MatTB>& B,
|
||||
const RowPtrF& C_slow, const RowPtrF& C) {
|
||||
const hn::ScalableTag<float> df;
|
||||
const size_t num_a = rows_ac * cols_ab;
|
||||
const size_t num_b = cols_c_rows_b * cols_ab;
|
||||
const size_t num_a = A.extents.Area();
|
||||
const size_t num_b = B.extents.Area();
|
||||
HWY_ASSERT(num_a % hn::Lanes(df) == 0); // for DecompressAndZeroPad
|
||||
HWY_ASSERT(num_b % hn::Lanes(df) == 0); // for DecompressAndZeroPad
|
||||
const size_t num_c = rows_ac * cols_c_rows_b;
|
||||
FloatPtr a = hwy::AllocateAligned<float>(num_a);
|
||||
FloatPtr b_trans = hwy::AllocateAligned<float>(num_b);
|
||||
HWY_ASSERT(a && b_trans);
|
||||
DecompressAndZeroPad(df, MakeSpan(pa, num_a), 0, a.get(), num_a);
|
||||
DecompressAndZeroPad(df, MakeSpan(pb_trans, num_b), 0, b_trans.get(), num_b);
|
||||
HWY_ASSERT(A.ofs == 0 && B.ofs == 0);
|
||||
DecompressAndZeroPad(df, MakeSpan(A.ptr, num_a), 0, a.get(), num_a);
|
||||
DecompressAndZeroPad(df, MakeSpan(B.ptr, num_b), 0, b_trans.get(), num_b);
|
||||
|
||||
const double norm = MaxColAbsSum(a.get(), rows_ac, cols_ab) *
|
||||
MaxColAbsSum(b_trans.get(), cols_c_rows_b, cols_ab);
|
||||
const double norm = MaxColAbsSum(a.get(), A.Extents()) *
|
||||
MaxColAbsSum(b_trans.get(), B.Extents());
|
||||
// Dot(float,BF16) rounds both to BF16.
|
||||
using RefType = hwy::If<IsF32<MatTA>() && IsF32<MatTB>(), float, BF16>;
|
||||
const double epsilon = hwy::ConvertScalarTo<double>(hwy::Epsilon<RefType>());
|
||||
const double tolerance = 200.0 * norm * epsilon;
|
||||
|
||||
for (size_t idx = 0; idx < num_c; idx++) {
|
||||
const double expected_value = expected_c[idx];
|
||||
const double actual_value = actual_c[idx];
|
||||
for (size_t r = 0; r < A.extents.rows; r++) {
|
||||
const float* expected_row = C_slow.Row(r);
|
||||
const float* actual_row = C.Row(r);
|
||||
for (size_t c = 0; c < B.extents.rows; c++) {
|
||||
const double expected_value = static_cast<double>(expected_row[c]);
|
||||
const double actual_value = static_cast<double>(actual_row[c]);
|
||||
|
||||
if (!(expected_value - tolerance <= actual_value &&
|
||||
actual_value <= expected_value + tolerance)) {
|
||||
fprintf(
|
||||
stderr,
|
||||
"expected[%lu]: %f, actual[%lu]: %f, norm %f eps %E tolerance %f\n",
|
||||
idx, expected_value, idx, actual_value, norm, epsilon, tolerance);
|
||||
HWY_ASSERT(0);
|
||||
if (!(expected_value - tolerance <= actual_value &&
|
||||
actual_value <= expected_value + tolerance)) {
|
||||
fprintf(
|
||||
stderr,
|
||||
"(%zu,%zu): expected %f, actual %f, norm %f eps %E tolerance %f\n",
|
||||
r, c, expected_value, actual_value, norm, epsilon, tolerance);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// B is already transposed.
|
||||
template <typename MatTA, typename MatTB>
|
||||
HWY_INLINE void MatMulSlow(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc,
|
||||
const MatTA* HWY_RESTRICT a,
|
||||
const MatTB* HWY_RESTRICT b_trans, const float scale,
|
||||
HWY_INLINE void MatMulSlow(const ConstMat<MatTA> A, const ConstMat<MatTB> B,
|
||||
const float* HWY_RESTRICT add_row, MatMulEnv& env,
|
||||
float* HWY_RESTRICT out) {
|
||||
const RowPtrF& C) {
|
||||
// MatTA can be any Packed except NuqStream because it uses pointer
|
||||
// arithmetic, because it is the second argument to Dot, which does not
|
||||
// support a v_ofs.
|
||||
static_assert(sizeof(MatTA) >= sizeof(BF16), "A matrix must be BF16/f32");
|
||||
const float scale = A.scale * B.scale;
|
||||
|
||||
const hn::ScalableTag<float> df; // lane type is ignored
|
||||
const PackedSpan<const MatTB> b_span =
|
||||
MakeSpan(b_trans, cols_a_rows_b * cols_bc);
|
||||
MakeSpan(B.ptr, B.ofs + B.extents.Area());
|
||||
const Extents2D C_extents(A.extents.rows, C.Cols());
|
||||
|
||||
StaticPartitionRowsAndCols(
|
||||
env.Pools(), rows_ac, cols_bc, sizeof(MatTB),
|
||||
[&](size_t /*node*/, hwy::ThreadPool& pool,
|
||||
const size_t /*worker_offset*/, const size_t row_begin,
|
||||
const size_t row_end, const size_t col_begin, const size_t col_end) {
|
||||
pool.Run(row_begin, row_end,
|
||||
[&](const uint64_t row, size_t /*thread*/) {
|
||||
for (size_t col = col_begin; col < col_end; ++col) {
|
||||
const float add = add_row ? add_row[col] : 0.0f;
|
||||
out[row * cols_bc + col] =
|
||||
scale * Dot(df, b_span, col * cols_a_rows_b,
|
||||
a + row * cols_a_rows_b, cols_a_rows_b) +
|
||||
add;
|
||||
}
|
||||
});
|
||||
env.Pools(), C_extents, sizeof(MatTB),
|
||||
[&](const Range2D& C_range, const TaskLocation& loc) {
|
||||
loc.cluster.Run(
|
||||
C_range.rows.begin(), C_range.rows.end(),
|
||||
[&](const uint64_t row, size_t /*thread*/) {
|
||||
float* HWY_RESTRICT C_row = C.Row(row);
|
||||
for (size_t row_b_col_c : C_range.cols) {
|
||||
const float add = add_row ? add_row[row_b_col_c] : 0.0f;
|
||||
C_row[row_b_col_c] =
|
||||
add + scale * Dot(df, b_span, row_b_col_c * B.extents.cols,
|
||||
A.ptr + A.Row(row), A.extents.cols);
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
void PrintSpeed(const char* algo, size_t rows_ac, size_t cols_a_rows_b,
|
||||
size_t cols_bc, double elapsed) {
|
||||
const size_t num_b = cols_a_rows_b * cols_bc;
|
||||
void PrintSpeed(const char* algo, const Extents2D& A_extents,
|
||||
const Extents2D& B_extents, double elapsed) {
|
||||
const size_t num_b = B_extents.Area();
|
||||
// 2x because of FMA.
|
||||
fprintf(stderr, " %10s: %f seconds, %.1f GFLOPS.\n", algo,
|
||||
elapsed, 2 * 1E-9 * rows_ac * num_b / elapsed);
|
||||
elapsed, 2 * 1E-9 * A_extents.rows * num_b / elapsed);
|
||||
}
|
||||
|
||||
template <size_t kRowsAC, size_t kColsARowsB, size_t kColsBC, bool kAdd,
|
||||
typename MatTA, typename MatTB = MatTA>
|
||||
void TestMatMul(MatMulEnv& env) {
|
||||
template <typename MatTA, typename MatTB = MatTA>
|
||||
void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
|
||||
MatMulEnv& env) {
|
||||
hwy::ThreadPool& pool = env.Pool();
|
||||
const bool want_bench = kColsBC > 2000; // avoid spam for small matrices
|
||||
const bool want_bench = cols_bc > 2000; // avoid spam for small matrices
|
||||
fprintf(stderr, "TestMatMul %lu, %lu, %lu, add=%d, MatTA=%s, MatTB=%s\n",
|
||||
kRowsAC, kColsARowsB, kColsBC, kAdd, TypeName<MatTA>(),
|
||||
rows_ac, cols_a_rows_b, cols_bc, add, TypeName<MatTA>(),
|
||||
TypeName<MatTB>());
|
||||
|
||||
std::unique_ptr<MatStorageT<MatTA>> a =
|
||||
GenerateMat<MatTA, kRowsAC, kColsARowsB>(0, pool);
|
||||
std::unique_ptr<MatStorageT<MatTB>> b_trans =
|
||||
GenerateTransposedMat<MatTB, kColsARowsB, kColsBC>(0, pool);
|
||||
FloatPtr c = hwy::AllocateAligned<float>(kRowsAC * kColsBC);
|
||||
HWY_ASSERT(c);
|
||||
const Extents2D A_extents(rows_ac, cols_a_rows_b);
|
||||
const Extents2D B_extents(cols_bc, cols_a_rows_b); // already transposed
|
||||
const Extents2D C_extents(rows_ac, cols_bc);
|
||||
|
||||
const float scale = a->scale() * b_trans->scale();
|
||||
std::unique_ptr<MatStorageT<float>> add;
|
||||
if (kAdd) {
|
||||
add = GenerateMat<float, 1, kColsBC>(0, pool);
|
||||
add->set_scale(1.0f);
|
||||
MatStoragePtr<MatTA> a = GenerateMat<MatTA>(A_extents, pool);
|
||||
MatStoragePtr<MatTB> b_trans = GenerateTransposedMat<MatTB>(B_extents, pool);
|
||||
RowVectorBatch<float> c_slow_batch(C_extents);
|
||||
RowVectorBatch<float> c_batch(C_extents);
|
||||
HWY_ASSERT(a && b_trans);
|
||||
|
||||
std::unique_ptr<MatStorageT<float>> add_storage;
|
||||
if (add) {
|
||||
add_storage = GenerateMat<float>(Extents2D(1, cols_bc), pool);
|
||||
HWY_ASSERT(add_storage);
|
||||
add_storage->set_scale(1.0f);
|
||||
}
|
||||
|
||||
std::unique_ptr<MatStorageT<float>> c_slow =
|
||||
GenerateZeroMat<float, kRowsAC, kColsBC>(pool);
|
||||
const auto A = ConstMatFromWeights(*a);
|
||||
const auto B = ConstMatFromWeights(*b_trans);
|
||||
const float* add_row = add ? add_storage->data_scale1() : nullptr;
|
||||
const RowPtrF C_slow = RowPtrFromBatch(c_slow_batch);
|
||||
const RowPtrF C = RowPtrFromBatch(c_batch);
|
||||
|
||||
const double start_slow = hwy::platform::Now();
|
||||
MatMulSlow(kRowsAC, kColsARowsB, kColsBC, a->data(), b_trans->data(), scale,
|
||||
kAdd ? add->data() : nullptr, env, c_slow->data());
|
||||
MatMulSlow(A, B, add_row, env, C_slow);
|
||||
if (want_bench) {
|
||||
PrintSpeed("MatMulSlow", kRowsAC, kColsARowsB, kColsBC,
|
||||
PrintSpeed("MatMulSlow", A_extents, B_extents,
|
||||
hwy::platform::Now() - start_slow);
|
||||
}
|
||||
|
||||
double min_elapsed = hwy::HighestValue<double>();
|
||||
for (int rep = 0; rep < (want_bench ? 3 : 1); ++rep) {
|
||||
const double start_tiled = hwy::platform::Now();
|
||||
MatMul<kAdd>(kRowsAC, ConstMat(a->data(), kColsARowsB),
|
||||
ConstMat(b_trans->data(), kColsARowsB), scale,
|
||||
kAdd ? add->data_scale1() : nullptr, env,
|
||||
MutableMat(c.get(), kColsBC));
|
||||
MatMul(A, B, add_row, env, C);
|
||||
min_elapsed = HWY_MIN(min_elapsed, hwy::platform::Now() - start_tiled);
|
||||
}
|
||||
if (want_bench) {
|
||||
PrintSpeed("MatMul", kRowsAC, kColsARowsB, kColsBC, min_elapsed);
|
||||
PrintSpeed("MatMul", A_extents, B_extents, min_elapsed);
|
||||
}
|
||||
|
||||
AssertClose(kRowsAC, kColsARowsB, kColsBC, a->data(), b_trans->data(),
|
||||
c_slow->data(), c.get());
|
||||
AssertClose(A, B, C_slow, C);
|
||||
}
|
||||
|
||||
void TestAllMatMul() {
|
||||
|
|
@ -264,8 +257,9 @@ void TestAllMatMul() {
|
|||
return;
|
||||
}
|
||||
|
||||
NestedPools pools(4, /*pin=*/1);
|
||||
pools.StartSpinning();
|
||||
NestedPools pools(4, /*pin=*/Tristate::kDefault);
|
||||
Tristate use_spinning = Tristate::kDefault;
|
||||
pools.MaybeStartSpinning(use_spinning);
|
||||
Allocator::Init(pools.Topology());
|
||||
MatMulEnv env(pools);
|
||||
|
||||
|
|
@ -273,52 +267,54 @@ void TestAllMatMul() {
|
|||
using SFP = SfpStream;
|
||||
|
||||
// large-scale test: batch_size=128 is better than 64 or 256 for SKX.
|
||||
TestMatMul<128, 24576, 3072, /*kAdd=*/false, F32, SFP>(env);
|
||||
TestMatMul<128, 3072, 24576, /*kAdd=*/false, F32, SFP>(env);
|
||||
TestMatMul<1, 24576, 3072, /*kAdd=*/false, F32, F32>(env);
|
||||
TestMatMul<1, 3072, 24576, /*kAdd=*/false, F32, F32>(env);
|
||||
// TestMatMul<F32, SFP>(128, 24576, 3072, /*add=*/false, env);
|
||||
// TestMatMul<F32, SFP>(128, 3072, 24576, /*add=*/false, env);
|
||||
TestMatMul<F32, F32>(1, 24576, 3072, /*add=*/false, env);
|
||||
TestMatMul<F32, F32>(1, 3072, 24576, /*add=*/false, env);
|
||||
TestMatMul<F32, SFP>(1, 24576, 3072, /*add=*/false, env);
|
||||
TestMatMul<F32, SFP>(1, 3072, 24576, /*add=*/false, env);
|
||||
|
||||
// medium-sized square test - temporarily disabled for faster testing.
|
||||
if constexpr (false) {
|
||||
TestMatMul<512, 512, 512, /*kAdd=*/false, F32>(env);
|
||||
TestMatMul<512, 512, 512, /*kAdd=*/true, BF16>(env);
|
||||
TestMatMul<512, 512, 512, /*kAdd=*/false, F32, BF16>(env);
|
||||
TestMatMul<512, 512, 512, /*kAdd=*/true, BF16, F32>(env);
|
||||
TestMatMul<512, 512, 512, /*kAdd=*/false, F32, SFP>(env);
|
||||
TestMatMul<512, 512, 512, /*kAdd=*/true, BF16, SFP>(env);
|
||||
TestMatMul<F32>(512, 512, 512, /*add=*/false, env);
|
||||
TestMatMul<BF16>(512, 512, 512, /*add=*/true, env);
|
||||
TestMatMul<F32, BF16>(512, 512, 512, /*add=*/false, env);
|
||||
TestMatMul<BF16, F32>(512, 512, 512, /*add=*/true, env);
|
||||
TestMatMul<F32, SFP>(512, 512, 512, /*add=*/false, env);
|
||||
TestMatMul<BF16, SFP>(512, 512, 512, /*add=*/true, env);
|
||||
}
|
||||
|
||||
// minimal non-square test. kColsARowsB must be at least 2 vectors.
|
||||
TestMatMul<35, 128, 32, /*kAdd=*/false, F32>(env);
|
||||
TestMatMul<34, 128, 32, /*kAdd=*/true, BF16>(env);
|
||||
TestMatMul<33, 128, 32, /*kAdd=*/false, F32, BF16>(env);
|
||||
TestMatMul<33, 128, 32, /*kAdd=*/true, BF16, F32>(env);
|
||||
TestMatMul<31, 128, 32, /*kAdd=*/false, F32, SFP>(env);
|
||||
TestMatMul<29, 128, 32, /*kAdd=*/true, BF16, SFP>(env);
|
||||
TestMatMul<4, 128, 32, /*kAdd=*/true, F32>(env);
|
||||
TestMatMul<4, 128, 32, /*kAdd=*/false, BF16>(env);
|
||||
TestMatMul<4, 128, 32, /*kAdd=*/true, F32, BF16>(env);
|
||||
TestMatMul<4, 128, 32, /*kAdd=*/false, BF16, F32>(env);
|
||||
TestMatMul<4, 128, 32, /*kAdd=*/true, F32, SFP>(env);
|
||||
TestMatMul<4, 128, 32, /*kAdd=*/false, BF16, SFP>(env);
|
||||
TestMatMul<3, 128, 32, /*kAdd=*/false, F32>(env);
|
||||
TestMatMul<3, 128, 32, /*kAdd=*/true, BF16>(env);
|
||||
TestMatMul<3, 128, 32, /*kAdd=*/false, F32, BF16>(env);
|
||||
TestMatMul<3, 128, 32, /*kAdd=*/true, BF16, F32>(env);
|
||||
TestMatMul<3, 128, 32, /*kAdd=*/false, F32, SFP>(env);
|
||||
TestMatMul<3, 128, 32, /*kAdd=*/true, BF16, SFP>(env);
|
||||
TestMatMul<2, 128, 64, /*kAdd=*/true, F32>(env);
|
||||
TestMatMul<2, 128, 64, /*kAdd=*/false, BF16>(env);
|
||||
TestMatMul<2, 128, 64, /*kAdd=*/true, F32, BF16>(env);
|
||||
TestMatMul<2, 128, 64, /*kAdd=*/false, BF16, F32>(env);
|
||||
TestMatMul<2, 128, 64, /*kAdd=*/true, F32, SFP>(env);
|
||||
TestMatMul<2, 128, 64, /*kAdd=*/false, BF16, SFP>(env);
|
||||
TestMatMul<1, 128, 32, /*kAdd=*/false, F32>(env);
|
||||
TestMatMul<1, 128, 32, /*kAdd=*/true, BF16>(env);
|
||||
TestMatMul<1, 128, 32, /*kAdd=*/false, F32, BF16>(env);
|
||||
TestMatMul<1, 128, 32, /*kAdd=*/true, BF16, F32>(env);
|
||||
TestMatMul<1, 128, 32, /*kAdd=*/false, F32, SFP>(env);
|
||||
TestMatMul<1, 128, 32, /*kAdd=*/true, BF16, SFP>(env);
|
||||
TestMatMul<F32>(35, 128, 32, /*add=*/false, env);
|
||||
TestMatMul<BF16>(34, 128, 32, /*add=*/true, env);
|
||||
TestMatMul<F32, BF16>(33, 128, 32, /*add=*/false, env);
|
||||
TestMatMul<BF16, F32>(33, 128, 32, /*add=*/true, env);
|
||||
TestMatMul<F32, SFP>(31, 128, 32, /*add=*/false, env);
|
||||
TestMatMul<BF16, SFP>(29, 128, 32, /*add=*/true, env);
|
||||
TestMatMul<F32>(4, 128, 32, /*add=*/true, env);
|
||||
TestMatMul<BF16>(4, 128, 32, /*add=*/false, env);
|
||||
TestMatMul<F32, BF16>(4, 128, 32, /*add=*/true, env);
|
||||
TestMatMul<BF16, F32>(4, 128, 32, /*add=*/false, env);
|
||||
TestMatMul<F32, SFP>(4, 128, 32, /*add=*/true, env);
|
||||
TestMatMul<BF16, SFP>(4, 128, 32, /*add=*/false, env);
|
||||
TestMatMul<F32>(3, 128, 32, /*add=*/false, env);
|
||||
TestMatMul<BF16>(3, 128, 32, /*add=*/true, env);
|
||||
TestMatMul<F32, BF16>(3, 128, 32, /*add=*/false, env);
|
||||
TestMatMul<BF16, F32>(3, 128, 32, /*add=*/true, env);
|
||||
TestMatMul<F32, SFP>(3, 128, 32, /*add=*/false, env);
|
||||
TestMatMul<BF16, SFP>(3, 128, 32, /*add=*/true, env);
|
||||
TestMatMul<F32>(2, 128, 64, /*add=*/true, env);
|
||||
TestMatMul<BF16>(2, 128, 64, /*add=*/false, env);
|
||||
TestMatMul<F32, BF16>(2, 128, 64, /*add=*/true, env);
|
||||
TestMatMul<BF16, F32>(2, 128, 64, /*add=*/false, env);
|
||||
TestMatMul<F32, SFP>(2, 128, 64, /*add=*/true, env);
|
||||
TestMatMul<BF16, SFP>(2, 128, 64, /*add=*/false, env);
|
||||
TestMatMul<F32>(1, 128, 32, /*add=*/false, env);
|
||||
TestMatMul<BF16>(1, 128, 32, /*add=*/true, env);
|
||||
TestMatMul<F32, BF16>(1, 128, 32, /*add=*/false, env);
|
||||
TestMatMul<BF16, F32>(1, 128, 32, /*add=*/true, env);
|
||||
TestMatMul<F32, SFP>(1, 128, 32, /*add=*/false, env);
|
||||
TestMatMul<BF16, SFP>(1, 128, 32, /*add=*/true, env);
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
||||
|
|
|
|||
|
|
@ -389,7 +389,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void ScalarRopeAndMulBy(
|
|||
void TestRopeAndMulBy() {
|
||||
ModelConfig config = ConfigFromModel(Model::GEMMA2_9B);
|
||||
int dim_qkv = config.layer_configs[0].qkv_dim;
|
||||
RowVectorBatch<float> x(1, dim_qkv);
|
||||
RowVectorBatch<float> x(Extents2D(1, dim_qkv));
|
||||
|
||||
std::mt19937 gen;
|
||||
gen.seed(0x12345678);
|
||||
|
|
|
|||
|
|
@ -14,7 +14,6 @@
|
|||
// limitations under the License.
|
||||
|
||||
#include <cstdio>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
|
|
@ -45,20 +44,20 @@ class PaliGemmaTest : public ::testing::Test {
|
|||
std::string GemmaReply(const std::string& prompt_text) const;
|
||||
void TestQuestions(const char* kQA[][2], size_t num_questions);
|
||||
|
||||
std::unique_ptr<ImageTokens> image_tokens_;
|
||||
ImageTokens image_tokens_;
|
||||
};
|
||||
|
||||
void PaliGemmaTest::InitVit(const std::string& path) {
|
||||
ASSERT_NE(s_env->GetModel(), nullptr);
|
||||
Gemma& model = *(s_env->GetModel());
|
||||
image_tokens_ = std::make_unique<ImageTokens>(
|
||||
model.GetModelConfig().vit_seq_len, model.GetModelConfig().model_dim);
|
||||
image_tokens_ = ImageTokens(Extents2D(model.GetModelConfig().vit_seq_len,
|
||||
model.GetModelConfig().model_dim));
|
||||
Image image;
|
||||
HWY_ASSERT(model.Info().training == ModelTraining::PALIGEMMA);
|
||||
HWY_ASSERT(image.ReadPPM(path));
|
||||
image.Resize();
|
||||
RuntimeConfig runtime_config = {.verbosity = 0, .gen = &s_env->MutableGen()};
|
||||
model.GenerateImageTokens(runtime_config, image, *image_tokens_);
|
||||
model.GenerateImageTokens(runtime_config, image, image_tokens_);
|
||||
}
|
||||
|
||||
std::string PaliGemmaTest::GemmaReply(const std::string& prompt_text) const{
|
||||
|
|
@ -67,7 +66,7 @@ std::string PaliGemmaTest::GemmaReply(const std::string& prompt_text) const{
|
|||
RuntimeConfig runtime_config = {.max_generated_tokens = 512,
|
||||
.verbosity = 0,
|
||||
.gen = &s_env->MutableGen()};
|
||||
runtime_config.image_tokens = image_tokens_.get();
|
||||
runtime_config.image_tokens = &image_tokens_;
|
||||
size_t abs_pos = 0;
|
||||
std::string mutable_prompt = prompt_text;
|
||||
std::vector<int> tokens = s_env->WrapAndTokenize(mutable_prompt);
|
||||
|
|
@ -79,7 +78,7 @@ std::string PaliGemmaTest::GemmaReply(const std::string& prompt_text) const{
|
|||
return true;
|
||||
};
|
||||
runtime_config.stream_token = stream_token,
|
||||
tokens.insert(tokens.begin(), image_tokens_->BatchSize(), 0);
|
||||
tokens.insert(tokens.begin(), image_tokens_.BatchSize(), 0);
|
||||
size_t num_tokens = tokens.size();
|
||||
size_t prefix_end = num_tokens;
|
||||
runtime_config.prefill_tbatch_size = num_tokens;
|
||||
|
|
|
|||
|
|
@ -162,20 +162,19 @@ static void BindMemory(void* ptr, size_t bytes, size_t node) {
|
|||
static void BindMemory(void*, size_t, size_t) {}
|
||||
#endif // GEMMA_NUMA && HWY_OS_LINUX
|
||||
|
||||
void BindTensor(NestedPools& nested, size_t rows, size_t cols,
|
||||
void BindTensor(NestedPools& nested, const Extents2D& extents,
|
||||
size_t bytes_per_col, void* ptr) {
|
||||
if (!Allocator::UseNUMA()) return;
|
||||
uint8_t* p8 = static_cast<uint8_t*>(ptr);
|
||||
const size_t bytes_per_row = cols * bytes_per_col;
|
||||
const size_t bytes_per_row = extents.cols * bytes_per_col;
|
||||
StaticPartitionRowsAndCols(
|
||||
nested, rows, cols, bytes_per_col,
|
||||
[&](size_t node, hwy::ThreadPool&, const size_t /*worker_offset*/,
|
||||
const size_t row_begin, const size_t row_end, const size_t col_begin,
|
||||
const size_t col_end) {
|
||||
for (size_t row = row_begin; row < row_end; ++row) {
|
||||
uint8_t* slice = p8 + row * bytes_per_row + col_begin * bytes_per_col;
|
||||
const size_t slice_size = (col_end - col_begin) * bytes_per_col;
|
||||
BindMemory(slice, slice_size, node);
|
||||
nested, extents, bytes_per_col,
|
||||
[&](const Range2D& r, const TaskLocation& loc) {
|
||||
for (size_t row : r.rows) {
|
||||
uint8_t* slice =
|
||||
p8 + row * bytes_per_row + r.cols.begin() * bytes_per_col;
|
||||
const size_t slice_size = r.cols.Num() * bytes_per_col;
|
||||
BindMemory(slice, slice_size, loc.node);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@
|
|||
#include <cstdlib> // std::aligned_alloc / _aligned_malloc
|
||||
|
||||
// IWYU pragma: begin_exports
|
||||
#include "util/basics.h"
|
||||
#include "util/threading.h"
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/base.h"
|
||||
|
|
@ -52,49 +53,6 @@ ByteStorageT AllocateSizeof() {
|
|||
return hwy::AllocateAligned<uint8_t>(sizeof(T));
|
||||
}
|
||||
|
||||
// Owns dynamically-allocated aligned memory for a batch of row vectors.
|
||||
// This can be seen as a (batch_size x len) matrix.
|
||||
template <typename T>
|
||||
class RowVectorBatch {
|
||||
public:
|
||||
// Default ctor for Activations ctor.
|
||||
RowVectorBatch() : batch_size_(0), len_(0) {}
|
||||
// Main ctor, called from Activations::Allocate.
|
||||
RowVectorBatch(size_t batch_size, size_t len)
|
||||
: batch_size_(batch_size), len_(len) {
|
||||
mem_ = hwy::AllocateAligned<T>(batch_size * len);
|
||||
}
|
||||
|
||||
// Move-only
|
||||
RowVectorBatch(RowVectorBatch&) noexcept = delete;
|
||||
RowVectorBatch& operator=(RowVectorBatch&) noexcept = delete;
|
||||
RowVectorBatch(RowVectorBatch&&) noexcept = default;
|
||||
RowVectorBatch& operator=(RowVectorBatch&&) noexcept = default;
|
||||
|
||||
size_t BatchSize() const { return batch_size_; }
|
||||
size_t Len() const { return len_; }
|
||||
|
||||
// Returns the given row vector of length `Len()`.
|
||||
T* Batch(size_t batch_idx) {
|
||||
HWY_DASSERT(batch_idx < batch_size_);
|
||||
return mem_.get() + batch_idx * len_;
|
||||
}
|
||||
const T* Batch(size_t batch_idx) const {
|
||||
HWY_DASSERT(batch_idx < batch_size_);
|
||||
return mem_.get() + batch_idx * len_;
|
||||
}
|
||||
|
||||
// For MatMul or other operations that process the entire batch at once.
|
||||
T* All() { return mem_.get(); }
|
||||
const T* Const() const { return mem_.get(); }
|
||||
size_t NumBytes() const { return batch_size_ * len_ * sizeof(T); }
|
||||
|
||||
private:
|
||||
hwy::AlignedFreeUniquePtr<T[]> mem_;
|
||||
size_t batch_size_; // rows in the matrix
|
||||
size_t len_; // columns in the matrix = vector length
|
||||
};
|
||||
|
||||
// Stateful in order to know whether to bind to NUMA nodes. `Monostate` for
|
||||
// convenience - avoids passing around a reference.
|
||||
class Allocator {
|
||||
|
|
@ -167,10 +125,24 @@ class Allocator {
|
|||
static size_t alignment_;
|
||||
};
|
||||
|
||||
// For shorter arguments to the StaticPartitionRowsAndCols functor.
|
||||
struct TaskLocation {
|
||||
TaskLocation(size_t node, size_t package_idx, hwy::ThreadPool& cluster,
|
||||
size_t worker_offset)
|
||||
: node(node),
|
||||
package_idx(package_idx),
|
||||
cluster(cluster),
|
||||
worker_offset(worker_offset) {}
|
||||
size_t node;
|
||||
size_t package_idx;
|
||||
hwy::ThreadPool& cluster;
|
||||
const size_t worker_offset;
|
||||
};
|
||||
|
||||
// Used in MatMul and allocator.h. Defined here because it depends on
|
||||
// Allocator::Alignment().
|
||||
template <class Func>
|
||||
void StaticPartitionRowsAndCols(NestedPools& nested, size_t rows, size_t cols,
|
||||
void StaticPartitionRowsAndCols(NestedPools& nested, Extents2D extents,
|
||||
size_t bytes_per_element, const Func& func) {
|
||||
// Both rows and cols must be a multiple of the alignment to avoid
|
||||
// touching remote pages.
|
||||
|
|
@ -183,14 +155,15 @@ void StaticPartitionRowsAndCols(NestedPools& nested, size_t rows, size_t cols,
|
|||
hwy::ThreadPool& all_packages = nested.AllPackages();
|
||||
const size_t num_packages = all_packages.NumWorkers();
|
||||
const size_t cols_per_package =
|
||||
hwy::RoundUpTo(hwy::DivCeil(cols, num_packages), multiple);
|
||||
const size_t col_tasks = hwy::DivCeil(cols, cols_per_package);
|
||||
hwy::RoundUpTo(hwy::DivCeil(extents.cols, num_packages), multiple);
|
||||
const size_t col_tasks = hwy::DivCeil(extents.cols, cols_per_package);
|
||||
HWY_ASSERT(col_tasks <= num_packages);
|
||||
all_packages.Run(
|
||||
0, col_tasks, [&](uint64_t package_idx, size_t package_thread) {
|
||||
HWY_ASSERT(package_idx == package_thread); // one task per worker
|
||||
const size_t col_begin = package_idx * cols_per_package;
|
||||
const size_t col_end = HWY_MIN(col_begin + cols_per_package, cols);
|
||||
const Range1D col_range =
|
||||
MakeRange1D(col_begin, extents.cols, cols_per_package);
|
||||
|
||||
// Static partitioning of rows across the package's clusters. We assume
|
||||
// that row sharding is cheaper. In MatMul, results can indeed be
|
||||
|
|
@ -198,8 +171,8 @@ void StaticPartitionRowsAndCols(NestedPools& nested, size_t rows, size_t cols,
|
|||
hwy::ThreadPool& all_clusters = nested.AllClusters(package_idx);
|
||||
const size_t num_clusters = all_clusters.NumWorkers();
|
||||
const size_t rows_per_cluster =
|
||||
hwy::RoundUpTo(hwy::DivCeil(rows, num_clusters), multiple);
|
||||
const size_t row_tasks = hwy::DivCeil(rows, rows_per_cluster);
|
||||
hwy::RoundUpTo(hwy::DivCeil(extents.rows, num_clusters), multiple);
|
||||
const size_t row_tasks = hwy::DivCeil(extents.rows, rows_per_cluster);
|
||||
HWY_ASSERT(row_tasks <= num_clusters);
|
||||
all_clusters.Run(
|
||||
0, row_tasks, [&](uint64_t cluster_idx, size_t cluster_thread) {
|
||||
|
|
@ -217,11 +190,11 @@ void StaticPartitionRowsAndCols(NestedPools& nested, size_t rows, size_t cols,
|
|||
nested.WorkerOffset(package_idx, cluster_idx);
|
||||
|
||||
const size_t row_begin = cluster_idx * rows_per_cluster;
|
||||
const size_t row_end =
|
||||
HWY_MIN(row_begin + rows_per_cluster, rows);
|
||||
const Range1D row_range =
|
||||
MakeRange1D(row_begin, extents.rows, rows_per_cluster);
|
||||
|
||||
func(node, cluster, worker_offset, row_begin, row_end, col_begin,
|
||||
col_end);
|
||||
func(Range2D(row_range, col_range),
|
||||
TaskLocation(node, package_idx, cluster, worker_offset));
|
||||
});
|
||||
});
|
||||
}
|
||||
|
|
|
|||
10
util/app.h
10
util/app.h
|
|
@ -28,6 +28,7 @@
|
|||
#include "gemma/common.h"
|
||||
#include "gemma/gemma.h" // For CreateGemma
|
||||
#include "util/args.h"
|
||||
#include "util/basics.h" // Tristate
|
||||
#include "util/threading.h"
|
||||
#include "hwy/base.h" // HWY_IS_ASAN
|
||||
|
||||
|
|
@ -59,7 +60,9 @@ class AppArgs : public ArgsBase<AppArgs> {
|
|||
int verbosity;
|
||||
|
||||
size_t max_threads; // divided among the detected clusters
|
||||
int pin; // -1 = auto, 0 = no, 1 = yes
|
||||
Tristate pin; // pin threads?
|
||||
Tristate spin; // use spin waits?
|
||||
|
||||
// For BoundedSlice:
|
||||
size_t skip_packages;
|
||||
size_t max_packages;
|
||||
|
|
@ -81,7 +84,10 @@ class AppArgs : public ArgsBase<AppArgs> {
|
|||
// The exact meaning is more subtle: see the comment at NestedPools ctor.
|
||||
visitor(max_threads, "num_threads", size_t{0},
|
||||
"Maximum number of threads to use; default 0 = unlimited.", 2);
|
||||
visitor(pin, "pin", -1, "Pin threads? -1 = auto, 0 = no, 1 = yes.", 2);
|
||||
visitor(pin, "pin", Tristate::kDefault,
|
||||
"Pin threads? -1 = auto, 0 = no, 1 = yes.", 2);
|
||||
visitor(spin, "spin", Tristate::kDefault,
|
||||
"Use spin waits? -1 = auto, 0 = no, 1 = yes.", 2);
|
||||
// These can be used to partition CPU sockets/packages and their
|
||||
// clusters/CCXs across several program instances. The default is to use
|
||||
// all available resources.
|
||||
|
|
|
|||
32
util/args.h
32
util/args.h
|
|
@ -24,6 +24,7 @@
|
|||
#include <string>
|
||||
|
||||
#include "compression/io.h"
|
||||
#include "util/basics.h" // Tristate
|
||||
#include "hwy/base.h" // HWY_ABORT
|
||||
|
||||
namespace gcpp {
|
||||
|
|
@ -62,6 +63,13 @@ class ArgsBase {
|
|||
}
|
||||
}
|
||||
|
||||
void operator()(const Tristate& t, const char* name,
|
||||
const Tristate& /*init*/, const char* /*help*/,
|
||||
int print_verbosity = 0) const {
|
||||
if (verbosity_ >= print_verbosity) {
|
||||
fprintf(stderr, "%-30s: %s\n", name, ToString(t));
|
||||
}
|
||||
}
|
||||
void operator()(const std::string& t, const char* name,
|
||||
const std::string& /*init*/, const char* /*help*/,
|
||||
int print_verbosity = 0) const {
|
||||
|
|
@ -127,13 +135,33 @@ class ArgsBase {
|
|||
return true;
|
||||
}
|
||||
|
||||
static bool SetValue(const char* string, bool& t) {
|
||||
// Returns lower-cased string. Arg names are expected to be ASCII-only.
|
||||
static std::string ToLower(const char* string) {
|
||||
std::string value(string);
|
||||
// Lower-case. Arg names are expected to be ASCII-only.
|
||||
std::transform(value.begin(), value.end(), value.begin(), [](char c) {
|
||||
return 'A' <= c && c <= 'Z' ? c - ('Z' - 'z') : c;
|
||||
});
|
||||
return value;
|
||||
}
|
||||
|
||||
static bool SetValue(const char* string, Tristate& t) {
|
||||
const std::string value = ToLower(string);
|
||||
if (value == "true" || value == "on" || value == "1") {
|
||||
t = Tristate::kTrue;
|
||||
return true;
|
||||
} else if (value == "false" || value == "off" || value == "0") {
|
||||
t = Tristate::kFalse;
|
||||
return true;
|
||||
} else if (value == "default" || value == "auto" || value == "-1") {
|
||||
t = Tristate::kDefault;
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
static bool SetValue(const char* string, bool& t) {
|
||||
const std::string value = ToLower(string);
|
||||
if (value == "true" || value == "on" || value == "1") {
|
||||
t = true;
|
||||
return true;
|
||||
|
|
|
|||
205
util/basics.h
205
util/basics.h
|
|
@ -20,7 +20,8 @@
|
|||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/base.h" // HWY_IS_MSAN
|
||||
// IWYU pragma: end_exports
|
||||
|
||||
#if HWY_IS_MSAN
|
||||
|
|
@ -29,6 +30,19 @@
|
|||
|
||||
namespace gcpp {
|
||||
|
||||
enum class Tristate : int32_t { kFalse = 0, kTrue = 1, kDefault = -1 };
|
||||
|
||||
static inline const char* ToString(Tristate t) {
|
||||
switch (t) {
|
||||
case Tristate::kFalse:
|
||||
return "false";
|
||||
case Tristate::kTrue:
|
||||
return "true";
|
||||
case Tristate::kDefault:
|
||||
return "default";
|
||||
}
|
||||
}
|
||||
|
||||
using BF16 = hwy::bfloat16_t;
|
||||
|
||||
static inline void MaybeCheckInitialized(const void* ptr, size_t size) {
|
||||
|
|
@ -46,6 +60,195 @@ struct TokenAndProb {
|
|||
float prob;
|
||||
};
|
||||
|
||||
// Entire size of a 2D array. By contrast, Range2D is a subrange.
|
||||
struct Extents2D {
|
||||
Extents2D() : rows(0), cols(0) {}
|
||||
Extents2D(size_t rows, size_t cols) : rows(rows), cols(cols) {
|
||||
HWY_DASSERT(rows != 0);
|
||||
HWY_DASSERT(cols != 0);
|
||||
}
|
||||
|
||||
size_t Area() const { return rows * cols; }
|
||||
|
||||
size_t rows;
|
||||
size_t cols;
|
||||
};
|
||||
|
||||
// Range2D consists of two Range1D.
|
||||
struct Range1D {
|
||||
Range1D(size_t begin, size_t end) : begin_(begin), end_(end) {
|
||||
HWY_DASSERT(begin < end);
|
||||
}
|
||||
size_t Num() const { return end_ - begin_; }
|
||||
|
||||
// Enable range-based for loops.
|
||||
class Iterator {
|
||||
public:
|
||||
Iterator(size_t i) : i_(i) {}
|
||||
|
||||
Iterator& operator++() {
|
||||
++i_;
|
||||
return *this;
|
||||
}
|
||||
bool operator!=(const Iterator& other) const { return i_ != other.i_; }
|
||||
size_t operator*() const { return i_; }
|
||||
// Enable using begin() directly as a size_t.
|
||||
operator size_t() const { return i_; }
|
||||
|
||||
private:
|
||||
size_t i_;
|
||||
};
|
||||
Iterator begin() const { return Iterator(begin_); }
|
||||
Iterator end() const { return Iterator(end_); }
|
||||
|
||||
const size_t begin_;
|
||||
const size_t end_;
|
||||
};
|
||||
|
||||
static inline Range1D MakeRange1D(size_t begin, size_t end, size_t max_size) {
|
||||
return Range1D(begin, HWY_MIN(begin + max_size, end));
|
||||
}
|
||||
|
||||
// In MatMul, the two axes are used independently, hence we do not define
|
||||
// Range2D as a top-left and extents.
|
||||
struct Range2D {
|
||||
Range2D(Range1D rows, Range1D cols) : rows(rows), cols(cols) {}
|
||||
const Range1D rows;
|
||||
const Range1D cols;
|
||||
};
|
||||
|
||||
// Lightweight version of `MatPtr` used for the C argument of `MatMul`, because
|
||||
// it is always float and does not support compressed T, but does support an
|
||||
// arbitrary stride >= cols.
|
||||
template <typename T>
|
||||
class RowPtr {
|
||||
public:
|
||||
RowPtr(T* HWY_RESTRICT row0, size_t cols)
|
||||
: row0_(row0), cols_(cols), stride_(cols) {}
|
||||
|
||||
T* HWY_RESTRICT Row(size_t r) const { return row0_ + stride_ * r; }
|
||||
size_t Cols() const { return cols_; }
|
||||
|
||||
size_t Stride() const { return stride_; }
|
||||
void SetStride(size_t stride) {
|
||||
HWY_DASSERT(stride >= Cols());
|
||||
stride_ = stride;
|
||||
}
|
||||
|
||||
private:
|
||||
T* HWY_RESTRICT row0_;
|
||||
size_t stride_;
|
||||
size_t cols_;
|
||||
};
|
||||
|
||||
using RowPtrF = RowPtr<float>;
|
||||
|
||||
// Owns dynamically-allocated aligned memory for a batch of row vectors.
|
||||
// This can be seen as a (batch_size x cols) matrix. Unlike `RowPtr`, this owns
|
||||
// the memory.
|
||||
template <typename T>
|
||||
class RowVectorBatch {
|
||||
public:
|
||||
// Default ctor for Activations ctor.
|
||||
RowVectorBatch() = default;
|
||||
// Main ctor, called from Activations::Allocate.
|
||||
RowVectorBatch(Extents2D extents) : extents_(extents) {
|
||||
mem_ = hwy::AllocateAligned<T>(extents_.rows * extents_.cols);
|
||||
}
|
||||
|
||||
// Move-only
|
||||
RowVectorBatch(RowVectorBatch&) noexcept = delete;
|
||||
RowVectorBatch& operator=(RowVectorBatch&) noexcept = delete;
|
||||
RowVectorBatch(RowVectorBatch&&) noexcept = default;
|
||||
RowVectorBatch& operator=(RowVectorBatch&&) noexcept = default;
|
||||
|
||||
size_t BatchSize() const { return extents_.rows; }
|
||||
size_t Cols() const { return extents_.cols; }
|
||||
Extents2D Extents() const { return extents_; }
|
||||
|
||||
// Returns the given row vector of length `Cols()`.
|
||||
T* Batch(size_t batch_idx) {
|
||||
HWY_DASSERT(batch_idx < BatchSize());
|
||||
return mem_.get() + batch_idx * Cols();
|
||||
}
|
||||
const T* Batch(size_t batch_idx) const {
|
||||
HWY_DASSERT(batch_idx < BatchSize());
|
||||
return mem_.get() + batch_idx * Cols();
|
||||
}
|
||||
|
||||
// For MatMul or other operations that process the entire batch at once.
|
||||
// TODO: remove once we only use Mat.
|
||||
T* All() { return mem_.get(); }
|
||||
const T* Const() const { return mem_.get(); }
|
||||
size_t NumBytes() const { return BatchSize() * Cols() * sizeof(T); }
|
||||
|
||||
private:
|
||||
hwy::AlignedFreeUniquePtr<T[]> mem_;
|
||||
Extents2D extents_;
|
||||
};
|
||||
|
||||
// Used for the A and B arguments of `MatMul`, which are always const.
|
||||
// Create via MakeConstMat. This differs from `RowPtr` in that it supports the
|
||||
// `ofs` required for compressed T.
|
||||
template <typename T>
|
||||
struct ConstMat {
|
||||
ConstMat(const T* ptr, Extents2D extents, size_t ofs = 0)
|
||||
: ptr(ptr), extents(extents), ofs(ofs) {
|
||||
HWY_DASSERT(ptr != nullptr);
|
||||
}
|
||||
// TODO: support stride for page alignment.
|
||||
size_t Row(size_t r) const {
|
||||
if constexpr (HWY_IS_DEBUG_BUILD) {
|
||||
if (r >= extents.rows) {
|
||||
HWY_ABORT("ConstMat::Row %zu out of bounds %zu", r, extents.rows);
|
||||
}
|
||||
}
|
||||
return ofs + extents.cols * r;
|
||||
}
|
||||
|
||||
const Extents2D& Extents() const { return extents; }
|
||||
|
||||
// Shrinks the row-extent of this matrix view, i.e. reduces the view to a
|
||||
// subrange of the original rows starting at row 0.
|
||||
void ShrinkRows(size_t rows) {
|
||||
HWY_ASSERT(rows <= extents.rows);
|
||||
extents.rows = rows;
|
||||
}
|
||||
|
||||
const T* HWY_RESTRICT ptr;
|
||||
Extents2D extents;
|
||||
|
||||
// `scale` allows expanding the smaller range of `SfpStream` to the original
|
||||
// values. MatFromWeights sets this from `MatPtr`.
|
||||
float scale = 1.0f;
|
||||
|
||||
// Offset to add to `ptr`; separate because T=NuqStream does not support
|
||||
// pointer arithmetic.
|
||||
size_t ofs;
|
||||
};
|
||||
|
||||
// For deducing T.
|
||||
template <typename T>
|
||||
ConstMat<T> MakeConstMat(T* HWY_RESTRICT ptr, Extents2D extents,
|
||||
size_t ofs = 0) {
|
||||
return ConstMat<T>(ptr, extents, ofs);
|
||||
}
|
||||
|
||||
// For A argument to MatMul (activations).
|
||||
template <typename T>
|
||||
ConstMat<T> ConstMatFromBatch(size_t batch_size,
|
||||
const RowVectorBatch<T>& row_vectors) {
|
||||
HWY_DASSERT(batch_size <= row_vectors.BatchSize());
|
||||
return MakeConstMat(const_cast<T*>(row_vectors.Const()),
|
||||
Extents2D(batch_size, row_vectors.Cols()));
|
||||
}
|
||||
|
||||
// For C argument to MatMul.
|
||||
template <typename T>
|
||||
RowPtr<T> RowPtrFromBatch(RowVectorBatch<T>& row_vectors) {
|
||||
return RowPtr<T>(row_vectors.All(), row_vectors.Cols());
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_UTIL_BASICS_H_
|
||||
|
|
|
|||
|
|
@ -0,0 +1,400 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "util/threading.h"
|
||||
|
||||
#include <stdio.h>
|
||||
|
||||
#include <algorithm> // std::sort
|
||||
#include <atomic>
|
||||
#include <memory> // std::make_unique
|
||||
#include <utility> // std::move
|
||||
#include <vector>
|
||||
|
||||
// Placeholder for container detection, do not remove
|
||||
#include "util/basics.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
#include "hwy/contrib/thread_pool/topology.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
// Sort T := packages/clusters by descending 'size' so that users who only use
|
||||
// one Group get the largest.
|
||||
template <class T>
|
||||
static void SortByDescendingSize(std::vector<T>& groups) {
|
||||
std::sort(groups.begin(), groups.end(),
|
||||
[](const T& a, const T& b) { return a.Size() > b.Size(); });
|
||||
}
|
||||
|
||||
BoundedTopology::BoundedTopology(BoundedSlice package_slice,
|
||||
BoundedSlice cluster_slice,
|
||||
BoundedSlice lp_slice) {
|
||||
// Regardless of topology, ignore LPs disabled via OS, taskset, or numactl.
|
||||
LPS enabled_lps;
|
||||
if (HWY_UNLIKELY(!GetThreadAffinity(enabled_lps))) {
|
||||
const size_t num_lps = hwy::TotalLogicalProcessors();
|
||||
fprintf(stderr,
|
||||
"Warning, unknown OS affinity, considering all %zu LPs enabled\n.",
|
||||
num_lps);
|
||||
for (size_t lp = 0; lp < num_lps; ++lp) {
|
||||
enabled_lps.Set(lp);
|
||||
}
|
||||
}
|
||||
|
||||
// Without threading support, only keep the first enabled LP; it might still
|
||||
// make sense to pin the main thread to avoid migrations.
|
||||
if (HWY_UNLIKELY(!hwy::HaveThreadingSupport())) {
|
||||
HWY_ASSERT(enabled_lps.Any());
|
||||
const size_t lp = enabled_lps.First();
|
||||
enabled_lps = LPS();
|
||||
enabled_lps.Set(lp);
|
||||
fprintf(stderr,
|
||||
"Warning, threads not supported, using only the main thread\n.");
|
||||
}
|
||||
|
||||
#if !GEMMA_DISABLE_TOPOLOGY
|
||||
if (HWY_LIKELY(!topology_.packages.empty())) {
|
||||
InitFromTopology(enabled_lps, package_slice, cluster_slice);
|
||||
}
|
||||
#endif
|
||||
|
||||
// Topology unknown or no packages with enabled LPs: create a single
|
||||
// package with one cluster, and one node.
|
||||
if (HWY_UNLIKELY(NumPackages() == 0)) {
|
||||
InitFromSlice(enabled_lps, lp_slice);
|
||||
}
|
||||
|
||||
HWY_ASSERT(NumPackages() != 0 && NumClusters(0) != 0 && NumNodes() != 0);
|
||||
}
|
||||
|
||||
// Topology is unknown, rely on OS affinity and user-specified slice.
|
||||
BoundedTopology::Cluster::Cluster(const LPS& enabled_lps,
|
||||
BoundedSlice lp_slice) {
|
||||
// Interpret `lp_slice` as a slice of the 1-bits of `enabled_lps`, so
|
||||
// we honor both the OS affinity and the user-specified slice. Note that
|
||||
// this can be used to exclude hyperthreads because Linux groups LPs by
|
||||
// sibling index. For example, the first `num_cores` are not siblings.
|
||||
const size_t detected = enabled_lps.Count();
|
||||
size_t enabled_idx = 0;
|
||||
enabled_lps.Foreach([&](size_t lp) {
|
||||
if (lp_slice.Contains(detected, enabled_idx++)) {
|
||||
AddLP(lp);
|
||||
}
|
||||
});
|
||||
|
||||
// lp_slice can only reduce the number of `enabled_lps`, and not below 1.
|
||||
HWY_ASSERT(num_workers_ != 0);
|
||||
}
|
||||
|
||||
BoundedTopology::Cluster::Cluster(const LPS& enabled_lps,
|
||||
const std::vector<hwy::Topology::LP>& all_lps,
|
||||
const hwy::Topology::Cluster& tcluster) {
|
||||
bool is_first_lp = true;
|
||||
|
||||
tcluster.lps.Foreach([&](size_t lp) {
|
||||
// Skip if not first-hyperthread or disabled.
|
||||
if (all_lps[lp].smt != 0 || !enabled_lps.Get(lp)) return;
|
||||
|
||||
AddLP(lp);
|
||||
|
||||
// Set `node` once, and ensure subsequent nodes match - we assume there
|
||||
// is only one NUMA node per cluster.
|
||||
const size_t lp_node = static_cast<size_t>(all_lps[lp].node);
|
||||
if (is_first_lp) {
|
||||
is_first_lp = false;
|
||||
node_ = lp_node;
|
||||
} else {
|
||||
static bool warned = false;
|
||||
if (lp_node != node_ && !warned) {
|
||||
warned = true;
|
||||
fprintf(stderr, "WARNING: lp %zu on node %zu != cluster node %zu.\n",
|
||||
lp, lp_node, node_);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// NOTE: caller is responsible for checking whether `clusters` is empty.
|
||||
BoundedTopology::Package::Package(const LPS& enabled_lps,
|
||||
const hwy::Topology& topology,
|
||||
size_t package_idx,
|
||||
BoundedSlice cluster_slice) {
|
||||
const hwy::Topology::Package& tpackage = topology.packages[package_idx];
|
||||
// Populate `clusters` with the subset of clusters in `cluster_slice` that
|
||||
// have any enabled LPs. If `clusters` remains empty, the caller will
|
||||
// skip this `Package`.
|
||||
clusters.reserve(cluster_slice.Num(tpackage.clusters.size()));
|
||||
cluster_slice.Foreach(
|
||||
"cluster", tpackage.clusters.size(), [&](size_t cluster_idx) {
|
||||
const hwy::Topology::Cluster& tcluster = tpackage.clusters[cluster_idx];
|
||||
Cluster cluster(enabled_lps, topology.lps, tcluster);
|
||||
// Skip if empty, i.e. too few `enabled_lps`.
|
||||
if (HWY_LIKELY(cluster.Size() != 0)) {
|
||||
clusters.push_back(std::move(cluster));
|
||||
}
|
||||
});
|
||||
SortByDescendingSize(clusters);
|
||||
}
|
||||
|
||||
#if !GEMMA_DISABLE_TOPOLOGY
|
||||
|
||||
static size_t CoresFromLPs(const LPS& lps, const hwy::Topology& topology) {
|
||||
LPS cores;
|
||||
lps.Foreach([&](size_t lp) {
|
||||
if (topology.lps[lp].smt == 0) cores.Set(lp);
|
||||
});
|
||||
return cores.Count();
|
||||
}
|
||||
|
||||
// Scans hwy::Topology for clusters and their size, for use by topology_string_.
|
||||
static void ScanTClusters(hwy::Topology& topology_, size_t& max_tclusters,
|
||||
size_t& max_tcluster_cores,
|
||||
size_t& max_tcluster_lps) {
|
||||
max_tclusters = 0;
|
||||
max_tcluster_cores = 0;
|
||||
max_tcluster_lps = 0;
|
||||
for (size_t package_idx = 0; package_idx < topology_.packages.size();
|
||||
++package_idx) {
|
||||
const std::vector<hwy::Topology::Cluster>& tclusters =
|
||||
topology_.packages[package_idx].clusters;
|
||||
max_tclusters = HWY_MAX(max_tclusters, tclusters.size());
|
||||
size_t tcluster_cores = 0;
|
||||
size_t tcluster_lps = 0;
|
||||
for (size_t cluster_idx = 0; cluster_idx < tclusters.size();
|
||||
++cluster_idx) {
|
||||
const size_t cores = CoresFromLPs(tclusters[cluster_idx].lps, topology_);
|
||||
const size_t lps = tclusters[cluster_idx].lps.Count();
|
||||
tcluster_cores = HWY_MAX(tcluster_cores, cores);
|
||||
tcluster_lps = HWY_MAX(tcluster_lps, lps);
|
||||
}
|
||||
|
||||
if (tclusters.size() > 1 && tcluster_cores > 8) {
|
||||
fprintf(stderr,
|
||||
"Package %zu: multiple clusters with max size %zu, whereas CCX "
|
||||
"only have 8, may indicate a bug in hwy::Topology.\n",
|
||||
package_idx, tcluster_cores);
|
||||
}
|
||||
max_tcluster_cores = HWY_MAX(max_tcluster_cores, tcluster_cores);
|
||||
max_tcluster_lps = HWY_MAX(max_tcluster_lps, tcluster_lps);
|
||||
}
|
||||
HWY_ASSERT(max_tclusters != 0);
|
||||
HWY_ASSERT(max_tcluster_cores != 0);
|
||||
HWY_ASSERT(max_tcluster_lps >= max_tcluster_cores);
|
||||
}
|
||||
|
||||
// Main part of ctor, called when topology is known.
|
||||
void BoundedTopology::InitFromTopology(const LPS& enabled_lps,
|
||||
BoundedSlice package_slice,
|
||||
BoundedSlice cluster_slice) {
|
||||
size_t max_tclusters, max_tcluster_cores, max_tcluster_lps;
|
||||
ScanTClusters(topology_, max_tclusters, max_tcluster_cores, max_tcluster_lps);
|
||||
|
||||
// (Possibly empty) subset of `Topology` packages that have `enabled_lps`.
|
||||
package_slice.Foreach(
|
||||
"package", topology_.packages.size(), [&](size_t package_idx) {
|
||||
Package package(enabled_lps, topology_, package_idx, cluster_slice);
|
||||
// Skip if empty, i.e. too few `enabled_lps`.
|
||||
if (HWY_LIKELY(!package.clusters.empty())) {
|
||||
packages_.push_back(std::move(package));
|
||||
}
|
||||
});
|
||||
if (NumPackages() == 0) return;
|
||||
SortByDescendingSize(packages_);
|
||||
|
||||
// Remember NUMA nodes that we are actually using (not just enabled).
|
||||
for (const Package& p : packages_) {
|
||||
for (const Cluster& c : p.clusters) {
|
||||
nodes_.Set(c.Node());
|
||||
}
|
||||
}
|
||||
|
||||
// Scan for max BoundedTopology clusters and their size, for topology_string_.
|
||||
size_t all_max_cluster_size = 0;
|
||||
for (size_t package_idx = 0; package_idx < NumPackages(); ++package_idx) {
|
||||
size_t max_cluster_size = 0;
|
||||
for (size_t cluster_idx = 0; cluster_idx < NumClusters(package_idx);
|
||||
++cluster_idx) {
|
||||
max_cluster_size = HWY_MAX(max_cluster_size,
|
||||
GetCluster(package_idx, cluster_idx).Size());
|
||||
}
|
||||
if (NumClusters(package_idx) > 1 && max_cluster_size > 8) {
|
||||
fprintf(stderr,
|
||||
"Package %zu: multiple clusters with max size %zu, whereas CCX "
|
||||
"only have 8, may indicate a bug in BoundedTopology.\n",
|
||||
package_idx, max_cluster_size);
|
||||
}
|
||||
all_max_cluster_size = HWY_MAX(all_max_cluster_size, max_cluster_size);
|
||||
}
|
||||
|
||||
snprintf(topology_string_, sizeof(topology_string_),
|
||||
"%zuS %zuX %zuC %zuH, using %zuS %zuX %zuC (nodes=%zu)",
|
||||
topology_.packages.size(), max_tclusters, max_tcluster_cores,
|
||||
max_tcluster_lps / max_tcluster_cores, packages_.size(),
|
||||
NumClusters(0), all_max_cluster_size, nodes_.Count());
|
||||
}
|
||||
|
||||
#endif // !GEMMA_DISABLE_TOPOLOGY
|
||||
|
||||
void BoundedTopology::InitFromSlice(const LPS& enabled_lps,
|
||||
BoundedSlice lp_slice) {
|
||||
packages_.push_back(Package(enabled_lps, lp_slice));
|
||||
|
||||
snprintf(topology_string_, sizeof(topology_string_), "LPs=%zu",
|
||||
GetCluster(0, 0).Size());
|
||||
|
||||
// Assume a single NUMA node.
|
||||
nodes_.Set(0);
|
||||
HWY_ASSERT(NumNodes() == 1);
|
||||
}
|
||||
|
||||
static PoolPtr MakePool(size_t num_workers) {
|
||||
// `ThreadPool` expects the number of threads to create, which is one less
|
||||
// than the number of workers, but avoid underflow if zero.
|
||||
const size_t num_threads = num_workers == 0 ? 0 : num_workers - 1;
|
||||
return std::make_unique<hwy::ThreadPool>(num_threads);
|
||||
}
|
||||
|
||||
static bool InContainer() {
|
||||
return false;}
|
||||
|
||||
class NestedPools::Pinning {
|
||||
public:
|
||||
Pinning(Tristate pin, const BoundedTopology& topology) {
|
||||
if (pin == Tristate::kDefault) {
|
||||
// Pinning is unreliable inside containers because the hypervisor might
|
||||
// periodically change our affinity mask, or other processes might also
|
||||
// pin themselves to the same LPs.
|
||||
pin = InContainer() ? Tristate::kFalse : Tristate::kTrue;
|
||||
}
|
||||
want_pin_ = (pin == Tristate::kTrue);
|
||||
}
|
||||
|
||||
// If want_pin_, tries to pin each worker in `pool` to an LP in `cluster`,
|
||||
// and sets `any_error_` if any fails.
|
||||
void MaybePin(const BoundedTopology::Cluster& cluster, PoolPtr& pool) {
|
||||
if (HWY_UNLIKELY(!want_pin_)) return;
|
||||
|
||||
const std::vector<size_t> lps = cluster.LPVector();
|
||||
HWY_ASSERT(pool->NumWorkers() <= lps.size());
|
||||
pool->Run(
|
||||
0, pool->NumWorkers(),
|
||||
[this, &pool, &lps](uint64_t task, size_t thread) {
|
||||
HWY_ASSERT(task == thread); // each worker has one task
|
||||
if (HWY_UNLIKELY(!hwy::PinThreadToLogicalProcessor(lps[task]))) {
|
||||
fprintf(stderr,
|
||||
"Pinning failed for task %zu of %zu to %zu (size %zu)\n",
|
||||
task, pool->NumWorkers(), lps[task], lps.size());
|
||||
(void)any_error_.test_and_set();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
bool WantPin() const { return want_pin_; }
|
||||
|
||||
// Called ONCE after all MaybePin because it invalidates the error status.
|
||||
bool AllPinned() {
|
||||
// If !want_pin_, MaybePin will return without setting any_error_, but in
|
||||
// that case we still want to return false to avoid spinning.
|
||||
// .test() was only added in C++20, so we use .test_and_set() instead.
|
||||
return want_pin_ && !any_error_.test_and_set();
|
||||
}
|
||||
|
||||
private:
|
||||
std::atomic_flag any_error_ = ATOMIC_FLAG_INIT;
|
||||
bool want_pin_; // set in ctor
|
||||
}; // Pinning
|
||||
|
||||
// Used to divide max_threads and max_workers_per_package across packages and
|
||||
// clusters. Ensures small upper bounds are respected.
|
||||
static size_t DivideMaxAcross(const size_t max, const size_t instances) {
|
||||
// No limit.
|
||||
if (max == 0) return 0;
|
||||
// We have enough to distribute.
|
||||
if (max >= instances) return max / instances;
|
||||
// Use max as the upper bound for each instance because division would return
|
||||
// zero, which means 'unlimited'.
|
||||
return max;
|
||||
}
|
||||
|
||||
NestedPools::NestedPools(size_t max_threads, Tristate pin,
|
||||
BoundedSlice package_slice, BoundedSlice cluster_slice,
|
||||
BoundedSlice lp_slice)
|
||||
: topology_(package_slice, cluster_slice, lp_slice) {
|
||||
Pinning pinning(pin, topology_);
|
||||
packages_.resize(topology_.NumPackages());
|
||||
all_packages_ = MakePool(packages_.size());
|
||||
const size_t max_workers_per_package =
|
||||
DivideMaxAcross(max_threads, packages_.size());
|
||||
// Each worker in all_packages_, including the main thread, will be the
|
||||
// calling thread of an all_clusters->Run, and hence pinned to one of the
|
||||
// `cluster.lps` if `pin`.
|
||||
all_packages_->Run(
|
||||
0, all_packages_->NumWorkers(), [&](uint64_t package_idx, size_t thread) {
|
||||
HWY_ASSERT(package_idx == thread); // each thread has one task
|
||||
packages_[package_idx] = Package(
|
||||
topology_, package_idx, max_workers_per_package, pinning, lp_slice);
|
||||
});
|
||||
|
||||
all_pinned_ = pinning.AllPinned();
|
||||
pin_string_ = all_pinned_ ? "pinned"
|
||||
: pinning.WantPin() ? "pinning failed"
|
||||
: "pinning skipped";
|
||||
|
||||
// For mapping package/cluster/thread to noncontiguous TLS indices, in case
|
||||
// cluster/thread counts differ.
|
||||
HWY_ASSERT(!packages_.empty() && packages_.size() <= 16);
|
||||
for (const Package& p : packages_) {
|
||||
max_clusters_per_package_ =
|
||||
HWY_MAX(max_clusters_per_package_, p.NumClusters());
|
||||
max_workers_per_cluster_ =
|
||||
HWY_MAX(max_workers_per_cluster_, p.MaxWorkersPerCluster());
|
||||
}
|
||||
HWY_ASSERT(max_clusters_per_package_ >= 1);
|
||||
HWY_ASSERT(max_clusters_per_package_ <= 64);
|
||||
HWY_ASSERT(max_workers_per_cluster_ >= 1);
|
||||
HWY_ASSERT(max_workers_per_cluster_ <= 256);
|
||||
}
|
||||
|
||||
// `max_or_zero` == 0 means no limit.
|
||||
static inline size_t CapIfNonZero(size_t num, size_t max_or_zero) {
|
||||
return (max_or_zero == 0) ? num : HWY_MIN(num, max_or_zero);
|
||||
}
|
||||
|
||||
NestedPools::Package::Package(const BoundedTopology& topology,
|
||||
size_t package_idx,
|
||||
size_t max_workers_per_package, Pinning& pinning,
|
||||
BoundedSlice lp_slice) {
|
||||
// Pre-allocate because elements are set concurrently.
|
||||
clusters_.resize(topology.NumClusters(package_idx));
|
||||
const size_t max_workers_per_cluster =
|
||||
DivideMaxAcross(max_workers_per_package, clusters_.size());
|
||||
|
||||
all_clusters_ = MakePool(clusters_.size());
|
||||
// Parallel so we also pin the calling worker in `all_clusters` to
|
||||
// `cluster.lps`.
|
||||
all_clusters_->Run(
|
||||
0, all_clusters_->NumWorkers(), [&](size_t cluster_idx, size_t thread) {
|
||||
HWY_ASSERT(cluster_idx == thread); // each thread has one task
|
||||
const BoundedTopology::Cluster& cluster =
|
||||
topology.GetCluster(package_idx, cluster_idx);
|
||||
clusters_[cluster_idx] =
|
||||
MakePool(CapIfNonZero(cluster.Size(), max_workers_per_cluster));
|
||||
// Pin workers AND the calling thread from `all_clusters`.
|
||||
pinning.MaybePin(cluster, clusters_[cluster_idx]);
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
294
util/threading.h
294
util/threading.h
|
|
@ -17,14 +17,12 @@
|
|||
#define THIRD_PARTY_GEMMA_CPP_UTIL_THREADING_H_
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include <algorithm> // std::sort
|
||||
#include <memory> // std::unique_ptr
|
||||
#include <utility> // std::move
|
||||
#include <memory> // std::unique_ptr
|
||||
#include <vector>
|
||||
|
||||
#include "hwy/base.h" // HWY_ASSERT
|
||||
#include "util/basics.h" // Tristate
|
||||
#include "hwy/base.h" // HWY_ASSERT
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
#include "hwy/contrib/thread_pool/topology.h"
|
||||
|
||||
|
|
@ -78,6 +76,10 @@ class BoundedSlice {
|
|||
// "LP" is a logical processor, a 0-based index passed to the OS.
|
||||
using LPS = hwy::LogicalProcessorSet;
|
||||
|
||||
// We want vectors of hwy::ThreadPool, which is unfortunately not movable,
|
||||
// hence we wrap them in unique_ptr.
|
||||
using PoolPtr = std::unique_ptr<hwy::ThreadPool>;
|
||||
|
||||
// Wraps hwy::Topology and only keeps the subset of packages and clusters
|
||||
// apportioned by BoundedSlice, further limited by the OS affinity mask.
|
||||
// NOTE: if topology is unknown or the OS affinity is too restrictive, we fall
|
||||
|
|
@ -85,96 +87,18 @@ using LPS = hwy::LogicalProcessorSet;
|
|||
class BoundedTopology {
|
||||
public:
|
||||
BoundedTopology(BoundedSlice package_slice, BoundedSlice cluster_slice,
|
||||
BoundedSlice lp_slice) {
|
||||
// Regardless of topology, ignore LPs disabled via OS, taskset, or numactl.
|
||||
LPS enabled_lps;
|
||||
if (HWY_UNLIKELY(!GetThreadAffinity(enabled_lps))) {
|
||||
const size_t num_lps = hwy::TotalLogicalProcessors();
|
||||
fprintf(
|
||||
stderr,
|
||||
"Warning, unknown OS affinity, considering all %zu LPs enabled\n.",
|
||||
num_lps);
|
||||
for (size_t lp = 0; lp < num_lps; ++lp) {
|
||||
enabled_lps.Set(lp);
|
||||
}
|
||||
}
|
||||
|
||||
// Without threading support, only keep the first enabled LP; it might still
|
||||
// make sense to pin the main thread.
|
||||
if (HWY_UNLIKELY(!hwy::HaveThreadingSupport())) {
|
||||
HWY_ASSERT(enabled_lps.Any());
|
||||
const size_t lp = enabled_lps.First();
|
||||
enabled_lps = LPS();
|
||||
enabled_lps.Set(lp);
|
||||
}
|
||||
|
||||
#if !GEMMA_DISABLE_TOPOLOGY
|
||||
if (HWY_LIKELY(!topology_.packages.empty())) {
|
||||
InitFromTopology(enabled_lps, package_slice, cluster_slice);
|
||||
}
|
||||
#endif
|
||||
|
||||
// Topology unknown, disabled or no packages with enabled LPs: create a
|
||||
// single package with one cluster, and one node.
|
||||
if (HWY_UNLIKELY(NumPackages() == 0)) {
|
||||
InitFromSlice(enabled_lps, lp_slice);
|
||||
}
|
||||
|
||||
HWY_ASSERT(NumPackages() != 0 && NumClusters(0) != 0 && NumNodes() != 0);
|
||||
}
|
||||
BoundedSlice lp_slice);
|
||||
|
||||
size_t NumPackages() const { return packages_.size(); }
|
||||
const char* TopologyString() const { return topology_string_; }
|
||||
size_t NumNodes() const { return nodes_.Count(); }
|
||||
const char* TopologyString() const { return topology_string_; }
|
||||
|
||||
class Cluster {
|
||||
public:
|
||||
// Topology is unknown, rely on OS affinity and user-specified slice.
|
||||
Cluster(const LPS& enabled_lps, BoundedSlice lp_slice) {
|
||||
// Interpret `lp_slice` as a slice of the 1-bits of `enabled_lps`, so
|
||||
// we honor both the OS affinity and the user-specified slice. Note that
|
||||
// this can be used to exclude hyperthreads because Linux groups LPs by
|
||||
// sibling index. For example, the first `num_cores` are not siblings.
|
||||
const size_t detected = enabled_lps.Count();
|
||||
size_t enabled_idx = 0;
|
||||
enabled_lps.Foreach([&](size_t lp) {
|
||||
if (lp_slice.Contains(detected, enabled_idx++)) {
|
||||
AddLP(lp);
|
||||
}
|
||||
});
|
||||
|
||||
// lp_slice can only reduce the number of `enabled_lps`, and not below 1.
|
||||
HWY_ASSERT(num_workers_ != 0);
|
||||
}
|
||||
|
||||
Cluster(const LPS& enabled_lps, BoundedSlice lp_slice);
|
||||
Cluster(const LPS& enabled_lps,
|
||||
const std::vector<hwy::Topology::LP>& all_lps,
|
||||
const hwy::Topology::Cluster& tcluster) {
|
||||
bool is_first_lp = true;
|
||||
|
||||
tcluster.lps.Foreach([&](size_t lp) {
|
||||
// Skip if not first-hyperthread or disabled.
|
||||
if (all_lps[lp].smt != 0 || !enabled_lps.Get(lp)) return;
|
||||
|
||||
AddLP(lp);
|
||||
|
||||
// Set `node` once, and ensure subsequent nodes match - we assume there
|
||||
// is only one NUMA node per cluster.
|
||||
const size_t lp_node = static_cast<size_t>(all_lps[lp].node);
|
||||
if (is_first_lp) {
|
||||
is_first_lp = false;
|
||||
node_ = lp_node;
|
||||
} else {
|
||||
static bool warned = false;
|
||||
if (lp_node != node_ && !warned) {
|
||||
warned = true;
|
||||
fprintf(stderr,
|
||||
"WARNING: lp %zu on node %zu != cluster node %zu.\n", lp,
|
||||
lp_node, node_);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
const hwy::Topology::Cluster& tcluster);
|
||||
|
||||
// For SortByDescendingSize.
|
||||
size_t Size() const { return num_workers_; }
|
||||
|
|
@ -221,53 +145,15 @@ class BoundedTopology {
|
|||
return package.clusters[cluster_idx];
|
||||
}
|
||||
|
||||
// Returns total number of cluster workers, for deciding whether to pin.
|
||||
size_t TotalWorkers() const {
|
||||
size_t total_workers = 0;
|
||||
for (size_t package_idx = 0; package_idx < NumPackages(); ++package_idx) {
|
||||
const size_t num_clusters = NumClusters(package_idx);
|
||||
for (size_t cluster_idx = 0; cluster_idx < num_clusters; ++cluster_idx) {
|
||||
total_workers += GetCluster(package_idx, cluster_idx).Size();
|
||||
}
|
||||
}
|
||||
return total_workers;
|
||||
}
|
||||
|
||||
private:
|
||||
// Sort T := packages/clusters by descending 'size' so that users who only use
|
||||
// one Group get the largest.
|
||||
template <class T>
|
||||
static void SortByDescendingSize(std::vector<T>& groups) {
|
||||
std::sort(groups.begin(), groups.end(),
|
||||
[](const T& a, const T& b) { return a.Size() > b.Size(); });
|
||||
}
|
||||
|
||||
struct Package {
|
||||
// Topology is unknown, rely on OS affinity and user-specified slice.
|
||||
Package(const LPS& enabled_lps, BoundedSlice lp_slice) {
|
||||
clusters.push_back(Cluster(enabled_lps, lp_slice));
|
||||
}
|
||||
|
||||
// NOTE: caller is responsible for checking whether `clusters` is empty.
|
||||
Package(const LPS& enabled_lps, const hwy::Topology& topology,
|
||||
size_t package_idx, BoundedSlice cluster_slice) {
|
||||
const hwy::Topology::Package& tpackage = topology.packages[package_idx];
|
||||
// Populate `clusters` with the subset of clusters in `cluster_slice` that
|
||||
// have any enabled LPs. If `clusters` remains empty, the caller will
|
||||
// skip this `Package`.
|
||||
clusters.reserve(cluster_slice.Num(tpackage.clusters.size()));
|
||||
cluster_slice.Foreach(
|
||||
"cluster", tpackage.clusters.size(), [&](size_t cluster_idx) {
|
||||
const hwy::Topology::Cluster& tcluster =
|
||||
tpackage.clusters[cluster_idx];
|
||||
Cluster cluster(enabled_lps, topology.lps, tcluster);
|
||||
// Skip if empty, i.e. too few `enabled_lps`.
|
||||
if (HWY_LIKELY(cluster.Size() != 0)) {
|
||||
clusters.push_back(std::move(cluster));
|
||||
}
|
||||
});
|
||||
SortByDescendingSize(clusters);
|
||||
}
|
||||
size_t package_idx, BoundedSlice cluster_slice);
|
||||
|
||||
// For SortByDescendingSize.
|
||||
size_t Size() const { return clusters.size(); }
|
||||
|
|
@ -275,48 +161,9 @@ class BoundedTopology {
|
|||
std::vector<Cluster> clusters;
|
||||
}; // Package
|
||||
|
||||
#if !GEMMA_DISABLE_TOPOLOGY
|
||||
// Main part of ctor, called when topology is known.
|
||||
void InitFromTopology(const LPS& enabled_lps, BoundedSlice package_slice,
|
||||
BoundedSlice cluster_slice) {
|
||||
// (Possibly empty) subset of `Topology` packages that have `enabled_lps`.
|
||||
package_slice.Foreach(
|
||||
"package", topology_.packages.size(), [&](size_t package_idx) {
|
||||
Package package(enabled_lps, topology_, package_idx, cluster_slice);
|
||||
// Skip if empty, i.e. too few `enabled_lps`.
|
||||
if (HWY_LIKELY(!package.clusters.empty())) {
|
||||
packages_.push_back(std::move(package));
|
||||
}
|
||||
});
|
||||
if (NumPackages() == 0) return;
|
||||
SortByDescendingSize(packages_);
|
||||
|
||||
const hwy::Topology::Package& tpackage0 = topology_.packages[0];
|
||||
HWY_ASSERT(!tpackage0.clusters.empty());
|
||||
const hwy::Topology::Cluster& tcluster0 = tpackage0.clusters[0];
|
||||
// GetCluster(0, 0) is valid because only non-empty Packages were kept.
|
||||
snprintf(topology_string_, sizeof(topology_string_),
|
||||
"%zux%zux%zu, using %zux%zux%zu", topology_.packages.size(),
|
||||
tpackage0.clusters.size(), tcluster0.lps.Count(), packages_.size(),
|
||||
NumClusters(0), GetCluster(0, 0).Size());
|
||||
|
||||
// Remember NUMA nodes of *enabled* LPs.
|
||||
enabled_lps.Foreach([&](size_t lp) {
|
||||
nodes_.Set(static_cast<size_t>(topology_.lps[lp].node));
|
||||
});
|
||||
}
|
||||
#endif
|
||||
|
||||
void InitFromSlice(const LPS& enabled_lps, BoundedSlice lp_slice) {
|
||||
packages_.push_back(Package(enabled_lps, lp_slice));
|
||||
|
||||
snprintf(topology_string_, sizeof(topology_string_), "LPs=%zu",
|
||||
GetCluster(0, 0).Size());
|
||||
|
||||
// Assume a single NUMA node.
|
||||
nodes_.Set(0);
|
||||
HWY_ASSERT(NumNodes() == 1);
|
||||
}
|
||||
BoundedSlice cluster_slice);
|
||||
void InitFromSlice(const LPS& enabled_lps, BoundedSlice lp_slice);
|
||||
|
||||
#if !GEMMA_DISABLE_TOPOLOGY
|
||||
hwy::Topology topology_;
|
||||
|
|
@ -360,51 +207,32 @@ class NestedPools {
|
|||
// would cause huge slowdowns when spinning, the `BoundedSlice` arguments
|
||||
// only impose upper bounds on the number of detected packages and clusters
|
||||
// rather than defining the actual number of threads.
|
||||
//
|
||||
// `pin` is 0 or 1 to force disable/enable, or -1 to choose automatically.
|
||||
NestedPools(size_t max_threads, int pin = -1,
|
||||
NestedPools(size_t max_threads, Tristate pin = Tristate::kDefault,
|
||||
BoundedSlice package_slice = BoundedSlice(),
|
||||
BoundedSlice cluster_slice = BoundedSlice(),
|
||||
BoundedSlice lp_slice = BoundedSlice())
|
||||
: topology_(package_slice, cluster_slice, lp_slice) {
|
||||
if (pin == -1) pin = topology_.TotalWorkers() >= 12;
|
||||
BoundedSlice lp_slice = BoundedSlice());
|
||||
|
||||
packages_.resize(topology_.NumPackages());
|
||||
all_packages_ = MakePool(packages_.size());
|
||||
const size_t max_workers_per_package = max_threads / packages_.size();
|
||||
// Each worker in all_packages_, including the main thread, will be the
|
||||
// calling thread of an all_clusters->Run, and hence pinned to one of the
|
||||
// `cluster.lps` if `pin`.
|
||||
all_packages_->Run(
|
||||
0, all_packages_->NumWorkers(),
|
||||
[&](uint64_t package_idx, size_t thread) {
|
||||
HWY_ASSERT(package_idx == thread); // each thread has one task
|
||||
packages_[package_idx] = Package(
|
||||
topology_, package_idx, max_workers_per_package, pin, lp_slice);
|
||||
});
|
||||
|
||||
// For mapping package/cluster/thread to noncontiguous TLS indices, in case
|
||||
// cluster/thread counts differ.
|
||||
HWY_ASSERT(!packages_.empty() && packages_.size() <= 16);
|
||||
for (const Package& p : packages_) {
|
||||
max_clusters_per_package_ =
|
||||
HWY_MAX(max_clusters_per_package_, p.NumClusters());
|
||||
max_workers_per_cluster_ =
|
||||
HWY_MAX(max_workers_per_cluster_, p.MaxWorkersPerCluster());
|
||||
// Subject to `use_spinning`, enables spin waits with the goal of reducing the
|
||||
// latency of barrier synchronization. We only spin during Generate to avoid
|
||||
// wasting energy during long waits. If `use_spinning` is kDefault, we first
|
||||
// set it to kTrue or kFalse based on a heuristic.
|
||||
void MaybeStartSpinning(Tristate& use_spinning) {
|
||||
if (HWY_UNLIKELY(use_spinning == Tristate::kDefault)) {
|
||||
// The default is to only spin when pinning was enabled and supported by
|
||||
// the OS. Unless spin-waits have near-exclusive use of a core, the tail
|
||||
// latency can be higher than blocking waits.
|
||||
use_spinning = all_pinned_ ? Tristate::kTrue : Tristate::kFalse;
|
||||
}
|
||||
if (use_spinning == Tristate::kTrue) {
|
||||
SetWaitMode(hwy::PoolWaitMode::kSpin);
|
||||
}
|
||||
}
|
||||
void MaybeStopSpinning(const Tristate use_spinning) {
|
||||
HWY_DASSERT(use_spinning != Tristate::kDefault); // see MaybeStartSpinning
|
||||
if (use_spinning == Tristate::kTrue) {
|
||||
SetWaitMode(hwy::PoolWaitMode::kBlock);
|
||||
}
|
||||
HWY_ASSERT(max_clusters_per_package_ >= 1);
|
||||
HWY_ASSERT(max_clusters_per_package_ <= 64);
|
||||
HWY_ASSERT(max_workers_per_cluster_ >= 1);
|
||||
HWY_ASSERT(max_workers_per_cluster_ <= 256);
|
||||
}
|
||||
|
||||
// Spinning reduces the latency of barrier synchronization, but wastes lots
|
||||
// of energy for long waits, so only do it during generation. Spinning might
|
||||
// also be unsafe in virtualized environments because we require threads to
|
||||
// be running on their own core and thus responsive to the barrier
|
||||
// synchronization.
|
||||
void StartSpinning() { SetWaitMode(hwy::PoolWaitMode::kSpin); }
|
||||
void StopSpinning() { SetWaitMode(hwy::PoolWaitMode::kBlock); }
|
||||
|
||||
hwy::ThreadPool& AllPackages() { return *all_packages_; }
|
||||
hwy::ThreadPool& AllClusters(size_t package_idx) {
|
||||
|
|
@ -435,7 +263,9 @@ class NestedPools {
|
|||
|
||||
// For Allocator
|
||||
const BoundedTopology& Topology() const { return topology_; }
|
||||
// For ShowConfig
|
||||
const char* TopologyString() const { return topology_.TopologyString(); }
|
||||
const char* PinString() const { return pin_string_; }
|
||||
|
||||
// Returns a single pool on the first package: either one thread per cluster
|
||||
// if there is more than one, which maximizes available memory bandwidth, or
|
||||
|
|
@ -449,56 +279,14 @@ class NestedPools {
|
|||
}
|
||||
|
||||
private:
|
||||
// `max_or_zero` == 0 means no limit.
|
||||
static inline size_t CapIfNonZero(size_t num, size_t max_or_zero) {
|
||||
return (max_or_zero == 0) ? num : HWY_MIN(num, max_or_zero);
|
||||
}
|
||||
|
||||
// We want vectors of hwy::ThreadPool, which is unfortunately not movable,
|
||||
// hence we wrap them in unique_ptr.
|
||||
using PoolPtr = std::unique_ptr<hwy::ThreadPool>;
|
||||
|
||||
static PoolPtr MakePool(size_t num_workers) {
|
||||
// `ThreadPool` expects the number of threads to create, which is one less
|
||||
// than the number of workers, but avoid underflow if zero.
|
||||
const size_t num_threads = num_workers == 0 ? 0 : num_workers - 1;
|
||||
return std::make_unique<hwy::ThreadPool>(num_threads);
|
||||
}
|
||||
class Pinning;
|
||||
|
||||
class Package {
|
||||
public:
|
||||
Package() = default; // for vector
|
||||
Package(const BoundedTopology& topology, size_t package_idx,
|
||||
size_t max_workers_per_package, int pin, BoundedSlice lp_slice) {
|
||||
// Pre-allocate because elements are set concurrently.
|
||||
clusters_.resize(topology.NumClusters(package_idx));
|
||||
const size_t max_workers_per_cluster =
|
||||
max_workers_per_package / clusters_.size();
|
||||
|
||||
all_clusters_ = MakePool(clusters_.size());
|
||||
// Parallel so we also pin the calling worker in `all_clusters` to
|
||||
// `cluster.lps`.
|
||||
all_clusters_->Run(
|
||||
0, all_clusters_->NumWorkers(),
|
||||
[&](size_t cluster_idx, size_t thread) {
|
||||
HWY_ASSERT(cluster_idx == thread); // each thread has one task
|
||||
const BoundedTopology::Cluster& cluster =
|
||||
topology.GetCluster(package_idx, cluster_idx);
|
||||
clusters_[cluster_idx] =
|
||||
MakePool(CapIfNonZero(cluster.Size(), max_workers_per_cluster));
|
||||
if (HWY_LIKELY(pin)) {
|
||||
// Pin threads AND the calling thread from `all_clusters` to lps.
|
||||
const std::vector<size_t> lps = cluster.LPVector();
|
||||
HWY_ASSERT(clusters_[cluster_idx]->NumWorkers() <= lps.size());
|
||||
clusters_[cluster_idx]->Run(
|
||||
0, clusters_[cluster_idx]->NumWorkers(),
|
||||
[&lps](uint64_t task, size_t thread) {
|
||||
HWY_ASSERT(task == thread); // each worker has one task
|
||||
hwy::PinThreadToLogicalProcessor(lps[task]);
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
size_t max_workers_per_package, Pinning& pinning,
|
||||
BoundedSlice lp_slice);
|
||||
|
||||
size_t NumClusters() const { return clusters_.size(); }
|
||||
size_t MaxWorkersPerCluster() const {
|
||||
|
|
@ -536,6 +324,8 @@ class NestedPools {
|
|||
}
|
||||
|
||||
BoundedTopology topology_;
|
||||
bool all_pinned_;
|
||||
const char* pin_string_;
|
||||
|
||||
std::vector<Package> packages_;
|
||||
PoolPtr all_packages_;
|
||||
|
|
|
|||
Loading…
Reference in New Issue