Cleanup: remove unused kCyclic, remove 2 suffix

Also remove the now-unused allocator arg and fix warnings (cast, struct/class mismatch).

PiperOrigin-RevId: 758098495
Author: Jan Wassenberg, committed by Copybara-Service
Date: 2025-05-13 01:05:42 -07:00
Commit: d538a6d6c6 (parent ba21e3beb4)
18 changed files with 114 additions and 213 deletions
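In short, call sites drop the Allocator argument and the transitional `2` suffix disappears from the aligned-pointer and deleter types. A before/after sketch distilled from the hunks below (illustrative, not a complete program):

// Before: auto att_sums = RowPtrFromMat(activations.env->ctx.allocator, activations.att_sums);
// After:
auto att_sums = RowPtrFromMat(activations.att_sums);  // RowPtr<float> view of the MatPtrT
// Before: AlignedPtr2 / AlignedClassPtr2 / DeleterFunc2 / DeleterDtor2
// After:  AlignedPtr  / AlignedClassPtr  / DeleterFunc  / DeleterDtor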

View File

@@ -42,8 +42,6 @@ static ModelConfig ConfigNoSSM() {
   return config;
 }
 
-static ModelConfig ConfigBaseGemmaV1() { return ConfigNoSSM(); }
-
 static ModelConfig ConfigBaseGemmaV2() {
   ModelConfig config = ConfigNoSSM();
   config.att_cap = 50.0f;

View File

@@ -31,7 +31,6 @@
 #include "gemma/kv_cache.h"
 #include "gemma/weights.h"
 #include "paligemma/image.h"
-#include "util/allocator.h"
 #include "util/mat.h"
 #include "util/threading_context.h"
 #include "hwy/aligned_allocator.h"  // Span
@@ -260,8 +259,7 @@ class GemmaAttention {
     const size_t w1_rows = heads * layer_config_.QStride();
     w_q1.ShrinkRows(w1_rows);
     MatMul(activations_.pre_att_rms_out, w_q1,
-           /*add=*/nullptr, *activations_.env,
-           RowPtrFromMat(allocator_, activations_.q));
+           /*add=*/nullptr, *activations_.env, RowPtrFromMat(activations_.q));
 
     if (is_mha_) {
       // Multi-Head Attention a.k.a. "use_qkv_einsum" computed QKV already.
@@ -284,7 +282,7 @@ class GemmaAttention {
       const size_t kv_ofs =
          queries_pos_[0] * cache_pos_size_ + layer_ * cache_layer_size_;
       float* HWY_RESTRICT kv = kv_caches_[0].kv_cache.get() + kv_ofs;
-      RowPtrF kv_rows(allocator_, kv, w_rows_kv_cols);
+      RowPtrF kv_rows(kv, w_rows_kv_cols);
       kv_rows.SetStride(cache_pos_size_);
       MatMul(activations_.pre_att_rms_out, w_q2,
              /*add=*/nullptr, *activations_.env, kv_rows);
@@ -490,7 +488,7 @@ class GemmaAttention {
             ? layer_weights_.attention_output_biases.PackedScale1()
             : nullptr;
     MatMul(activations_.att_out, layer_weights_.att_weights, add,
-           *activations_.env, RowPtrFromMat(allocator_, activations_.att_sums));
+           *activations_.env, RowPtrFromMat(activations_.att_sums));
   }
 
  public:
@@ -556,7 +554,6 @@ class GemmaAttention {
         layer_weights_(*layer_weights),
         div_seq_len_(div_seq_len),
         kv_caches_(kv_caches),
-        allocator_(ctx.allocator),
         pool_(ctx.pools.Pool(0)) {
     HWY_DASSERT(num_queries_ <= kv_caches_.size());
     HWY_DASSERT_M((layer_config_.heads % layer_config_.kv_heads) == 0,
@@ -586,7 +583,6 @@ class GemmaAttention {
   const LayerWeightsPtrs<T>& layer_weights_;
   const hwy::Divisor& div_seq_len_;
   const KVCaches& kv_caches_;
-  const Allocator& allocator_;
   hwy::ThreadPool& pool_;
 };
 
@@ -631,7 +627,7 @@ class VitAttention {
     HWY_ASSERT(qkv.Cols() == layer_config_.heads * 3 * layer_config_.qkv_dim);
     MatMul(activations_.pre_att_rms_out, layer_weights_.vit.qkv_einsum_w,
            layer_weights_.vit.qkv_einsum_b.PackedScale1(), *activations_.env,
-           RowPtrFromMat(allocator_, qkv));
+           RowPtrFromMat(qkv));
   }
 
   // TODO(philculliton): transition fully to MatMul.
@@ -671,7 +667,7 @@ class VitAttention {
     });
 
     // this produces C, a (num_tokens_, seq_len) matrix of dot products
-    MatMul(Q, K, nullptr, *activations_.env, RowPtrFromMat(allocator_, C));
+    MatMul(Q, K, nullptr, *activations_.env, RowPtrFromMat(C));
 
     pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
       float* HWY_RESTRICT c = C.Row(task);
@@ -737,7 +733,7 @@ class VitAttention {
     // att_weights and att_out are concatenated heads, each of length
     // qkv_dim. Thus the [num_tokens_, layer_config_.model_dim]
     // matmul output is the sum over heads.
-    auto att_sums = RowPtrFromMat(allocator_, activations_.att_sums);
+    auto att_sums = RowPtrFromMat(activations_.att_sums);
     MatMul(activations_.att_out, layer_weights_.vit.attn_out_w, bias,
            *activations_.env, att_sums);
   }
@@ -750,7 +746,6 @@ class VitAttention {
         activations_(activations),
         layer_weights_(*layer_weights),
         layer_config_(layer_weights->layer_config),
-        allocator_(activations.env->ctx.allocator),
         pool_(activations.env->ctx.pools.Pool(0)) {}
 
   HWY_INLINE void operator()() {
@@ -769,7 +764,6 @@ class VitAttention {
   Activations& activations_;
   const LayerWeightsPtrs<T>& layer_weights_;
   const LayerConfig& layer_config_;
-  const Allocator& allocator_;
   hwy::ThreadPool& pool_;
 };
 
@@ -832,10 +826,9 @@ HWY_NOINLINE void FFWNoVit(Activations& activations,
       add_bias ? layer_weights->ffw_output_biases.PackedScale1() : nullptr;
 
   // Define slightly more readable names for the weights and activations.
-  const Allocator& allocator = activations.env->ctx.allocator;
-  auto hidden_activations = RowPtrFromMat(allocator, activations.C1);
-  auto multiplier = RowPtrFromMat(allocator, activations.C2);
-  auto ffw_out = RowPtrFromMat(allocator, activations.ffw_out);
+  auto hidden_activations = RowPtrFromMat(activations.C1);
+  auto multiplier = RowPtrFromMat(activations.C2);
+  auto ffw_out = RowPtrFromMat(activations.ffw_out);
 
   using WeightT = typename decltype(layer_weights->gating_einsum_w)::T;
@@ -881,22 +874,16 @@ HWY_NOINLINE void FFWVit(Activations& activations,
   const float* output_bias =
       add_bias ? layer_weights->vit.linear_1_b.PackedScale1() : nullptr;
 
-  // Define slightly more readable names for the weights and activations.
-  const Allocator& allocator = activations.env->ctx.allocator;
-  auto hidden_activations = RowPtrFromMat(allocator, activations.C1);
-  auto ffw_out = RowPtrFromMat(allocator, activations.ffw_out);
-
   // Compute the hidden layer activations.
   MatMul(activations.pre_ffw_rms_out, layer_weights->vit.linear_0_w, bias1,
-         *activations.env, hidden_activations);
+         *activations.env, RowPtrFromMat(activations.C1));
 
-  // Activation (Gelu), store in act.
-  RowPtrF multiplier = RowPtrF(allocator, nullptr, 0);
+  // Activation (Gelu), store in C1.
   ActivationBatched(layer_weights->layer_config.activation, activations.C1);
 
   // Hidden layer -> output layer.
   MatMul(activations.C1, layer_weights->vit.linear_1_w, output_bias,
-         *activations.env, ffw_out);
+         *activations.env, RowPtrFromMat(activations.ffw_out));
 }
@@ -932,11 +919,10 @@ HWY_NOINLINE void EmbedMMToken(int token, size_t batch_idx, size_t pos,
   }
 
   const size_t model_dim = weights.weights_config.model_dim;
-  const size_t vocab_size = weights.weights_config.vocab_size;
   const float emb_scaling = EmbeddingScaling(model_dim);
 
   HWY_DASSERT(token >= 0);
-  HWY_DASSERT(token < static_cast<int>(vocab_size));
+  HWY_DASSERT(token < static_cast<int>(weights.weights_config.vocab_size));
 
   const hn::ScalableTag<float> df;
   // Using `Stride` to compute the offset works for both NUQ (because we use an
@@ -1263,7 +1249,7 @@ HWY_NOINLINE void PrefillVit(const ModelWeightsPtrs<T>& weights,
   // Apply head embedding into image_tokens of size of the LLM kModelDim.
   MatMul(activations.x, weights.vit_img_head_kernel,
          weights.vit_img_head_bias.PackedScale1(), *activations.env,
-         RowPtrFromMat(activations.env->ctx.allocator, image_tokens));
+         RowPtrFromMat(image_tokens));
 }
 
 // Generates one token for each query. `queries_token` is the previous token
@@ -1403,7 +1389,7 @@ bool DecodeStepT(const ModelConfig& config, const ModelWeightsPtrs<T>& weights,
     // Compute logits from last layer activations.
     MatMul(activations.x, weights.embedder_input_embedding,
            /*add=*/nullptr, *activations.env,
-           RowPtrFromMat(activations.env->ctx.allocator, activations.logits));
+           RowPtrFromMat(activations.logits));
   }
   PROFILER_ZONE("Gen.Softcap+Sample+Stream");
   for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
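For reference, the KV-cache write path above in sketch form (identifiers are the ones in the hunk): the RowPtrF now wraps only the raw pointer, and SetStride spaces rows by the per-position cache size so MatMul writes straight into the cache.

RowPtrF kv_rows(kv, w_rows_kv_cols);  // stride defaults to cols
kv_rows.SetStride(cache_pos_size_);   // consecutive rows are one cache position apart
MatMul(activations_.pre_att_rms_out, w_q2,
       /*add=*/nullptr, *activations_.env, kv_rows);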

View File

@@ -107,7 +107,7 @@ using LayersOutputFunc = std::function<void(size_t, size_t, const std::string&,
 // - per-query position within the tokens sequence
 // - layer index (or -1 for post-norm output)
 // - activations
-class Activations;
+struct Activations;
 using ActivationsObserverFunc =
     std::function<void(const QueriesPos& queries_pos, int, const Activations&)>;

View File

@@ -25,7 +25,6 @@
 namespace gcpp {
 
 struct KVCache {
-  KVCache() = default;  // for std::vector.
   KVCache(const ModelConfig& weights_config, size_t prefill_tbatch_size);
 
   // Returns a deep copy of the KVCache.

View File

@@ -115,7 +115,7 @@ class FilePosix : public File {
 #endif
 
     return MapPtr(static_cast<const uint8_t*>(mapping),
-                  DeleterFunc2([mapping_size](void* ptr) {
+                  DeleterFunc([mapping_size](void* ptr) {
                     HWY_ASSERT(munmap(ptr, mapping_size) == 0);
                   }));
   }

View File

@@ -33,7 +33,7 @@ namespace gcpp {
 // prefer to define Exists inline because there are multiple io*.cc files.
 struct Path;
 
-using MapPtr = AlignedPtr2<const uint8_t[]>;
+using MapPtr = AlignedPtr<const uint8_t[]>;
 
 // Abstract base class enables multiple I/O backends in the same binary.
 class File {

View File

@@ -108,7 +108,7 @@ class FileWin : public File {
     void* ptr = MapViewOfFile(hMapping, FILE_MAP_READ, 0, 0, 0);
     if (!ptr) return MapPtr();
     return MapPtr(static_cast<const uint8_t*>(ptr),
-                  DeleterFunc2([hMapping](void* ptr) {
+                  DeleterFunc([hMapping](void* ptr) {
                     HWY_ASSERT(UnmapViewOfFile(ptr));
                     HWY_ASSERT(CloseHandle(hMapping));
                   }));

View File

@@ -105,7 +105,7 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
   MatStorageT<TB> b_trans = GenerateTransposedMat<TB>(B_extents, pool);
 
   const float* add_row = add ? add_storage.PackedScale1() : nullptr;
-  const RowPtr<TC> C = RowPtrFromMat(allocator, c_batch);
+  const RowPtr<TC> C = RowPtrFromMat(c_batch);
 
   // Fewer reps for large batch sizes, which take longer.
   const size_t num_samples = M < 32 ? 20 : 12;

View File

@@ -1140,7 +1140,7 @@ void TestAllDot() {
     for (size_t variant = 0; variant < kVariants; ++variant) {
       constexpr size_t kTimeReps = hn::AdjustedReps(10);
       std::array<double, kTimeReps> elapsed;
-      for (int time_rep = 0; time_rep < kTimeReps; ++time_rep) {
+      for (size_t time_rep = 0; time_rep < kTimeReps; ++time_rep) {
        const double start = hwy::platform::Now();
        dots[variant] +=
            CallDot(df, variant, a_span, /*w_ofs=*/0, pb, num);

View File

@@ -864,7 +864,7 @@ class MMPerPackage {
       : args_(args),
         pkg_idx_(pkg_idx),
         // May be overwritten with a view of A, if already BF16.
-        A_(args_.env->storage.A(args.env->ctx.allocator, pkg_idx, A.Extents())),
+        A_(args_.env->storage.A(pkg_idx, A.Extents())),
         range_np_(range_np),
         mr_(config.MR()),
         ranges_mc_(config.RangesOfMC(A.Extents().rows)),
@@ -905,9 +905,8 @@
   // Compute size of per-worker storage for `kNR` row ranges of B. Stack
   // allocation avoids passing a worker index.
   static constexpr size_t B_stride_max_ =
-      MaxStrideForCyclicOffsets<BF16>(MMStorage::kMaxKC);
-  static constexpr size_t B_storage_max_ =
-      kNR * B_stride_max_ + Allocator::MaxQuantum<BF16>();
+      MMStorage::kMaxKC + 2 * Allocator::MaxLineBytes() / sizeof(BF16);
+  static constexpr size_t B_storage_max_ = kNR * B_stride_max_;
 
   // Granularity of `ForNP`. B rows produce C columns, so we
   // want a multiple of the line size to prevent false sharing.
@@ -928,15 +927,14 @@
     const size_t K = range_K.Num();
     const RowPtrBF& A_view = A_.View(range_M.begin(), 0, K);
     const size_t B_stride =
-        StrideForCyclicOffsets(K, args_.env->ctx.allocator.Quantum<BF16>());
+        Stride(MatPadding::kOdd, K, sizeof(BF16), line_bytes_);
 
     // Similar to `loop_nc` below, but here we hoisted `A_view`.
     args_.env->parallel.ForNP(
         range_np_, MultipleNP(sizeof(TC)), inner_tasks_, pkg_idx_,
         [&](const IndexRange& range_nc) HWY_ATTR {
           HWY_ALIGN BF16 B_storage[B_storage_max_];  // TLS
-          const RowPtrBF B_view(args_.env->ctx.allocator, B_storage, K,
-                                B_stride);
+          const RowPtrBF B_view(B_storage, K, B_stride);
 
           for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
               row_b += kNR) {
@@ -971,8 +969,8 @@
     const size_t kc = range_kc.Num();
     const RowPtrBF& A_view = A_.View(range_mc.begin(), range_kc.begin(), kc);
     const RowPtrBF B_view(
-        args_.env->ctx.allocator, B_storage, kc,
-        StrideForCyclicOffsets(kc, args_.env->ctx.allocator.Quantum<BF16>()));
+        B_storage, kc,
+        Stride(MatPadding::kOdd, kc, sizeof(BF16), line_bytes_));
 
     for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
         row_b += kNR) {
@@ -1028,7 +1026,7 @@
     const IndexRange& range_K = ranges_kc_.Range(0);
     const size_t K = range_K.Num();
     const size_t B_stride =
-        StrideForCyclicOffsets(K, args_.env->ctx.allocator.Quantum<BF16>());
+        Stride(MatPadding::kOdd, K, sizeof(BF16), line_bytes_);
 
     // Sequential loop over NC/MC/KC, similar to `loop_nc` below
     // except for the profiler strings and `out_tag`.
@@ -1037,8 +1035,7 @@
         [&](const IndexRange& range_mc, const IndexRange& range_nc) HWY_ATTR {
           const RowPtrBF& A_view = A_.View(range_mc.begin(), 0, K);
           HWY_ALIGN BF16 B_storage[B_storage_max_];  // TLS
-          const RowPtrBF B_view(args_.env->ctx.allocator, B_storage, K,
-                                B_stride);
+          const RowPtrBF B_view(B_storage, K, B_stride);
 
           for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
               row_b += kNR) {
@@ -1064,8 +1061,8 @@
     zone.MaybeEnter("MM.NT_MT_K", args_);
     const size_t kc_max = ranges_kc_.TaskSize();
     HWY_DASSERT(kc_max <= MMStorage::kMaxKC);
-    const size_t B_stride = StrideForCyclicOffsets(
-        kc_max, args_.env->ctx.allocator.Quantum<BF16>());
+    const size_t B_stride =
+        Stride(MatPadding::kOdd, kc_max, sizeof(BF16), line_bytes_);
     // Sequential loop over NC/MC/KC, for when the M/N loops are
     // already parallel. This is B3A2C0 in MOMMS terminology: we read
     // `mc x kc` of A, `nc x kc` of B, update `mc x nc` of `partial`.
@@ -1091,8 +1088,7 @@
         ranges_mc_, ranges_nc_, pkg_idx_,
         [&](const IndexRange& range_mc, const IndexRange& range_nc) HWY_ATTR {
           HWY_ALIGN BF16 B_storage[B_storage_max_];  // TLS
-          const RowPtrBF B_view(args_.env->ctx.allocator, B_storage, kc_max,
-                                B_stride);
+          const RowPtrBF B_view(B_storage, kc_max, B_stride);
 
           // Peel off the first iteration of the kc loop: avoid
           // zero-initializing `partial` by writing into it.
@@ -1172,13 +1168,12 @@
   // Autotuning wrapper for `DoDecompressA`.
   template <typename TA>
   HWY_INLINE RowPtrBF DecompressA(const MatPtrT<TA>& A) const {
-    const Allocator& allocator = args_.env->ctx.allocator;
     MMAutoTune<MMParA>& autotune = args_.per_key->autotune_par_a[pkg_idx_];
 
     // If already BF16, maybe return a view:
     if constexpr (hwy::IsSame<TA, BF16>()) {
      // Only if no zero-padding required.
      const size_t NBF = hn::Lanes(hn::ScalableTag<BF16>());
-      if (HWY_LIKELY(A.Cols() % NBF == 0)) return RowPtrFromMat(allocator, A);
+      if (HWY_LIKELY(A.Cols() % NBF == 0)) return RowPtrFromMat(A);
     }
 
     if (HWY_LIKELY(autotune.Best())) {
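A sketch of how the per-worker B buffer above is sized and then viewed (all identifiers are taken from the hunks): MatPadding::kOdd adds less than two cache lines of padding per row, and allocator.cc asserts that MaxLineBytes() is an upper bound on the detected line size, so the constexpr bound holds for any K <= MMStorage::kMaxKC.

static constexpr size_t B_stride_max_ =
    MMStorage::kMaxKC + 2 * Allocator::MaxLineBytes() / sizeof(BF16);
static constexpr size_t B_storage_max_ = kNR * B_stride_max_;  // kNR rows of B

HWY_ALIGN BF16 B_storage[B_storage_max_];  // on the stack, no worker index needed
const size_t B_stride = Stride(MatPadding::kOdd, K, sizeof(BF16), line_bytes_);
const RowPtrBF B_view(B_storage, K, B_stride);  // B_stride <= B_stride_max_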

View File

@@ -217,8 +217,7 @@ class MMStorage {
       : partial_storage_("partial_storage", Extents2D(kMaxM, kMaxN),
                          MatPadding::kOdd),
         // Same stride independent of the actual C.Cols() so we can pre-bind.
-        partial_(allocator, partial_storage_.Row(0), kMaxN,
-                 partial_storage_.Stride()) {
+        partial_(partial_storage_.Row(0), kMaxN, partial_storage_.Stride()) {
     // Per-package allocation so each can decompress A into its own copy.
     parallel.ForPkg(MMParallel::kMaxPackages, [&](size_t pkg_idx) {
       pkg_A_[pkg_idx].reset(new MatStorageT<BF16>(
@@ -240,12 +239,11 @@
   }
 
   // Returns per-package matrix view.
-  RowPtrBF A(const Allocator& allocator, size_t pkg_idx,
-             const Extents2D& extents) const {
+  RowPtrBF A(size_t pkg_idx, const Extents2D& extents) const {
     HWY_DASSERT(extents.rows <= kMaxM);
     HWY_DASSERT(extents.cols <= kMaxK);
-    return RowPtrBF(allocator, const_cast<BF16*>(pkg_A_[pkg_idx]->Row(0)),
-                    extents.cols, pkg_A_[pkg_idx]->Stride());
+    return RowPtrBF(const_cast<BF16*>(pkg_A_[pkg_idx]->Row(0)), extents.cols,
+                    pkg_A_[pkg_idx]->Stride());
   }
 
   RowPtrD Partial() const { return partial_; }

View File

@@ -205,7 +205,6 @@ void PrintSpeed(const char* algo, const Extents2D& A_extents,
 template <typename TA, typename TB = TA, typename TC = float>
 void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
                 MatMulEnv& env, int line) {
-  const Allocator& allocator = env.ctx.allocator;
   hwy::ThreadPool& pool = env.ctx.pools.Pool();
   fprintf(stderr, "TestMatMul %zu, K=%zu, %zu, add=%d, TA=%s, TB=%s, TC=%s\n",
           rows_ac, cols_a_rows_b, cols_bc, add, TypeName<TA>(), TypeName<TB>(),
@@ -229,8 +228,8 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
   add_storage.SetScale(1.0f);
 
   const float* add_row = add ? add_storage.PackedScale1() : nullptr;
-  const RowPtr<TC> C_slow = RowPtrFromMat(allocator, c_slow_batch);
-  const RowPtr<TC> C = RowPtrFromMat(allocator, c_batch);
+  const RowPtr<TC> C_slow = RowPtrFromMat(c_slow_batch);
+  const RowPtr<TC> C = RowPtrFromMat(c_batch);
 
   MatMulSlow(a, b_trans, add_row, env, C_slow);
   // A few reps to get coverage of the various autotuned code paths.

View File

@@ -507,7 +507,7 @@ void TestLayerNormSimple() {
   const size_t kSize = 52;
   std::vector<float> values(kSize);
   // Alternating 1.0/-1.0, so mean=0.0, var=1.0, rsqrt(var+epsilon)=0.9999995
-  for (int i = 0; i < kSize; ++i) {
+  for (size_t i = 0; i < kSize; ++i) {
     values[i] = (i % 2 == 0) ? 1.0f : -1.0f;
   }
   std::vector<float> scale(kSize, 1.2f);

View File

@@ -132,7 +132,11 @@ size_t DetectTotalMiB(size_t page_bytes) {
 
 Allocator::Allocator(const BoundedTopology& topology, bool enable_bind) {
   line_bytes_ = DetectLineBytes();
+  // Ensure MaxLineBytes() is an upper bound.
+  HWY_ASSERT(MaxLineBytes() >= LineBytes());
   vector_bytes_ = hwy::VectorBytes();
   step_bytes_ = HWY_MAX(line_bytes_, vector_bytes_);
   base_page_bytes_ = DetectPageSize();
   quantum_bytes_ = step_bytes_;  // may overwrite below
@@ -165,8 +169,6 @@ Allocator::Allocator(const BoundedTopology& topology, bool enable_bind) {
       // Ensure pages meet the alignment requirements of `AllocBytes`.
       HWY_ASSERT(base_page_bytes_ >= quantum_bytes_);
       quantum_bytes_ = base_page_bytes_;
-      // Ensure MaxQuantum() is an upper bound.
-      HWY_ASSERT(MaxQuantum<uint8_t>() >= Quantum<uint8_t>());
       should_bind_ = true;
     } else {
       HWY_WARN(
@@ -175,9 +177,6 @@
       }
     }
   }
-
-  HWY_DASSERT(quantum_bytes_ % step_bytes_ == 0);
-  quantum_step_mask_ = quantum_bytes_ / step_bytes_ - 1;
 }
 
 size_t Allocator::FreeMiB() const {
@@ -201,7 +200,7 @@
 #endif
 }
 
-AlignedPtr2<uint8_t[]> Allocator::AllocBytes(size_t bytes) const {
+AlignedPtr<uint8_t[]> Allocator::AllocBytes(size_t bytes) const {
   // If we are not binding, the Highway allocator is cheaper than `mmap`, and
   // defends against 2K aliasing.
   if (!should_bind_) {
@@ -217,9 +216,8 @@ AlignedPtr2<uint8_t[]> Allocator::AllocBytes(size_t bytes) const {
     // alignment scheme in aligned_allocator.cc and does not work for
     // already-aligned pointers as returned by `mmap`, hence we wrap the Highway
    // pointer in our own deleter.
-    return AlignedPtr2<uint8_t[]>(p.release(), DeleterFunc2([](void* ptr) {
-                                    hwy::FreeAlignedBytes(ptr, nullptr,
-                                                          nullptr);
-                                  }));
+    return AlignedPtr<uint8_t[]>(p.release(), DeleterFunc([](void* ptr) {
+                                   hwy::FreeAlignedBytes(ptr, nullptr, nullptr);
+                                 }));
   }
@@ -234,17 +232,16 @@ AlignedPtr2<uint8_t[]> Allocator::AllocBytes(size_t bytes) const {
   const int fd = -1;
   void* p = mmap(0, bytes, prot, flags, fd, off_t{0});
   if (p == MAP_FAILED) p = nullptr;
-  return AlignedPtr2<uint8_t[]>(static_cast<uint8_t*>(p),
-                                DeleterFunc2([bytes](void* ptr) {
-                                  HWY_ASSERT(munmap(ptr, bytes) == 0);
-                                }));
+  return AlignedPtr<uint8_t[]>(
+      static_cast<uint8_t*>(p),
+      DeleterFunc([bytes](void* ptr) { HWY_ASSERT(munmap(ptr, bytes) == 0); }));
 #elif HWY_OS_WIN
   const size_t alignment = HWY_MAX(vector_bytes_, line_bytes_);
-  return AlignedPtr2<uint8_t[]>(
+  return AlignedPtr<uint8_t[]>(
      static_cast<uint8_t*>(_aligned_malloc(bytes, alignment)),
-      DeleterFunc2([](void* ptr) { _aligned_free(ptr); }));
+      DeleterFunc([](void* ptr) { _aligned_free(ptr); }));
 #else
-  return AlignedPtr2<uint8_t[]>(nullptr, DeleterFunc2());
+  return AlignedPtr<uint8_t[]>(nullptr, DeleterFunc());
 #endif
 }

View File

@@ -34,13 +34,13 @@ namespace gcpp {
 
 // Custom deleter for types without a dtor, but where the deallocation requires
 // state, e.g. a lambda with *by-value* capture.
-class DeleterFunc2 {
+class DeleterFunc {
  public:
   // `MatOwnerT` requires this to be default-constructible.
-  DeleterFunc2() = default;
+  DeleterFunc() = default;
 
   template <class Closure>
-  DeleterFunc2(const Closure& free_closure) : free_func_(free_closure) {}
+  DeleterFunc(const Closure& free_closure) : free_func_(free_closure) {}
 
   template <typename T>
   void operator()(T* p) const {
@@ -52,10 +52,10 @@ class DeleterFunc2 {
 };
 
 // Wrapper that also calls the destructor for each element being deallocated.
-class DeleterDtor2 {
+class DeleterDtor {
  public:
-  DeleterDtor2() {}
-  DeleterDtor2(size_t num, DeleterFunc2 free) : num_(num), free_(free) {}
+  DeleterDtor() {}
+  DeleterDtor(size_t num, DeleterFunc free) : num_(num), free_(free) {}
 
   template <typename T>
   void operator()(T* p) const {
@@ -67,15 +67,15 @@ class DeleterDtor2 {
  private:
   size_t num_;
-  DeleterFunc2 free_;
+  DeleterFunc free_;
 };
 
 // Unique (move-only) pointer to aligned POD T, which can be an array or class.
 template <typename T>
-using AlignedPtr2 = std::unique_ptr<T, DeleterFunc2>;
+using AlignedPtr = std::unique_ptr<T, DeleterFunc>;
 
 // Unique (move-only) pointer to an aligned array of non-POD T.
 template <typename T>
-using AlignedClassPtr2 = std::unique_ptr<T, DeleterDtor2>;
+using AlignedClassPtr = std::unique_ptr<T, DeleterDtor>;
 
 // Both allocation, binding, and row accessors depend on the sizes of memory
 // pages and cache lines. To avoid having to pass `Allocator&` everywhere, we
@@ -90,26 +90,24 @@ class Allocator {
   // Bytes per cache line, or a reasonable guess if unknown. Used to choose
   // ranges such that there will be no false sharing.
   size_t LineBytes() const { return line_bytes_; }
+  // Upper bound on `LineBytes()`, for stack allocations.
+  static constexpr size_t MaxLineBytes() { return 256; }
   // Bytes per full vector. Used to compute loop steps.
   size_t VectorBytes() const { return vector_bytes_; }
   // Work granularity that avoids false sharing and partial vectors.
   // = HWY_MAX(LineBytes(), VectorBytes())
   size_t StepBytes() const { return step_bytes_; }
   // File size multiple required for memory mapping.
   size_t BasePageBytes() const { return base_page_bytes_; }
   // Either StepBytes or BasePageBytes if NUMA.
   size_t QuantumBytes() const { return quantum_bytes_; }
+  // For rounding down elements to the page size in `BindB/BindC`.
   template <typename T>
   size_t Quantum() const {
     return QuantumBytes() / sizeof(T);
   }
-  // Upper bound on `Quantum()`, for stack allocations.
-  template <typename T>
-  static constexpr size_t MaxQuantum() {
-    return 4096 / sizeof(T);
-  }
-
-  // = QuantumBytes() / StepBytes() - 1
-  size_t QuantumStepMask() const { return quantum_step_mask_; }
 
   // L1 and L2 are typically per core.
   size_t L1Bytes() const { return l1_bytes_; }
@@ -123,35 +121,35 @@ class Allocator {
   // Returns byte pointer aligned to `QuantumBytes()`, without calling
   // constructors nor destructors on deletion. Type-erased so this can be
   // implemented in `allocator.cc` and called by `MatOwner`.
-  AlignedPtr2<uint8_t[]> AllocBytes(size_t bytes) const;
+  AlignedPtr<uint8_t[]> AllocBytes(size_t bytes) const;
 
   // Returns pointer aligned to `QuantumBytes()`, without calling constructors
   // nor destructors on deletion.
   template <typename T>
-  AlignedPtr2<T[]> Alloc(size_t num) const {
+  AlignedPtr<T[]> Alloc(size_t num) const {
     const size_t bytes = num * sizeof(T);
     // Fail if the `bytes = num * sizeof(T)` computation overflowed.
     HWY_ASSERT(bytes / sizeof(T) == num);
-    AlignedPtr2<uint8_t[]> p8 = AllocBytes(bytes);
-    return AlignedPtr2<T[]>(HWY_RCAST_ALIGNED(T*, p8.release()),
-                            p8.get_deleter());
+    AlignedPtr<uint8_t[]> p8 = AllocBytes(bytes);
+    return AlignedPtr<T[]>(HWY_RCAST_ALIGNED(T*, p8.release()),
+                           p8.get_deleter());
   }
 
   // Same as Alloc, but calls constructor(s) with `args` and the deleter will
   // call destructor(s).
   template <typename T, class... Args>
-  AlignedClassPtr2<T> AllocClasses(size_t num, Args&&... args) const {
+  AlignedClassPtr<T> AllocClasses(size_t num, Args&&... args) const {
     const size_t bytes = num * sizeof(T);
     // Fail if the `bytes = num * sizeof(T)` computation overflowed.
     HWY_ASSERT(bytes / sizeof(T) == num);
-    AlignedPtr2<uint8_t[]> p8 = AllocBytes(bytes);
+    AlignedPtr<uint8_t[]> p8 = AllocBytes(bytes);
     T* p = HWY_RCAST_ALIGNED(T*, p8.release());
     for (size_t i = 0; i < num; ++i) {
       new (p + i) T(std::forward<Args>(args)...);
     }
-    return AlignedClassPtr2<T>(p, DeleterDtor2(num, p8.get_deleter()));
+    return AlignedClassPtr<T>(p, DeleterDtor(num, p8.get_deleter()));
   }
 
   // Returns whether `BindMemory` can/should be called, i.e. we have page-level
@@ -170,7 +168,6 @@ class Allocator {
   size_t step_bytes_;
   size_t base_page_bytes_;
   size_t quantum_bytes_;
-  size_t quantum_step_mask_;
   size_t l1_bytes_ = 0;
   size_t l2_bytes_ = 0;
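A minimal usage sketch of the renamed allocation helpers (the buffer name is illustrative; the ThreadingContext accessor is the one used in util/mat.cc below):

const Allocator& allocator = ThreadingContext::Get().allocator;
AlignedPtr<float[]> buf = allocator.Alloc<float>(1024);  // freed via DeleterFunc
buf[0] = 1.0f;
// AllocClasses<T>(num, args...) additionally runs constructors, and the
// resulting AlignedClassPtr<T> runs destructors via DeleterDtor (see PoolPtr
// in threading_context.h below).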

View File

@@ -89,38 +89,31 @@ void RandInit(MatPtr& mat, float stddev, std::mt19937& gen) {
   }
 }
 
-// Returns `num` rounded up to an odd number of cache lines. This would also
-// prevent 4K aliasing and is coprime with the cache associativity, which
-// might reduce conflict misses, but we instead use `StrideForCyclicOffsets`.
-static size_t RoundUpToOddLines(size_t num, size_t line_bytes,
-                                size_t element_bytes) {
-  HWY_DASSERT(line_bytes >= 32);
-  HWY_DASSERT(line_bytes % element_bytes == 0);
-  const size_t lines = hwy::DivCeil(num * element_bytes, line_bytes);
-  const size_t padded_num = (lines | 1) * line_bytes / element_bytes;
-  HWY_DASSERT(padded_num >= num);
-  return padded_num;
-}
-
-static size_t Stride(const Allocator& allocator, const MatPtr& mat,
-                     MatPadding padding) {
+size_t Stride(MatPadding padding, size_t cols, size_t element_bytes,
+              size_t line_bytes) {
   switch (padding) {
     case MatPadding::kPacked:
     default:
-      return mat.Cols();
-    case MatPadding::kOdd:
-      return RoundUpToOddLines(mat.Cols(), allocator.LineBytes(),
-                               mat.ElementBytes());
-    case MatPadding::kCyclic:
-      return StrideForCyclicOffsets(
-          mat.Cols(), allocator.QuantumBytes() / mat.ElementBytes());
+      return cols;
+    case MatPadding::kOdd: {
+      // Round up to an odd number of cache lines to prevent 4K aliasing and
+      // reduce conflict misses (coprime with the cache associativity).
+      HWY_DASSERT(line_bytes >= 32);
+      HWY_DASSERT(line_bytes % element_bytes == 0);
+      const size_t lines = hwy::DivCeil(cols * element_bytes, line_bytes);
+      const size_t padded_cols = (lines | 1) * line_bytes / element_bytes;
+      HWY_DASSERT(padded_cols >= cols);
+      return padded_cols;
+    }
   }
 }
 
-void MatOwner::AllocateFor(MatPtr& mat, const MatPadding padding) {
+void MatOwner::AllocateFor(MatPtr& mat, MatPadding padding) {
   const bool is_nuq = mat.GetType() == Type::kNUQ;
+  if (is_nuq) padding = MatPadding::kPacked;
   const Allocator& allocator = ThreadingContext::Get().allocator;
-  const size_t stride = is_nuq ? mat.Cols() : Stride(allocator, mat, padding);
+  const size_t stride =
+      Stride(padding, mat.Cols(), mat.ElementBytes(), allocator.LineBytes());
   const size_t num = is_nuq ? mat.PackedBytes() : mat.Rows() * stride;
   // `compress-inl` requires up to 2 BF16 vectors of padding. `MatPadding`
   // might not be enough, hence add extra. `MatT` is at least one byte, which
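A worked example of the consolidated Stride helper above (illustrative numbers, assuming 64-byte cache lines):

// 2048 floats = 8192 bytes = 128 lines; (128 | 1) = 129 lines, i.e. padded to
// an odd line count to avoid 4K aliasing and reduce conflict misses.
const size_t stride = Stride(MatPadding::kOdd, /*cols=*/2048,
                             /*element_bytes=*/sizeof(float),
                             /*line_bytes=*/64);
// stride == 129 * 64 / 4 == 2064 elements; MatPadding::kPacked would return 2048.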

View File

@@ -28,7 +28,7 @@
 #include "compression/shared.h"  // Type
 #include "gemma/tensor_info.h"
 #include "io/fields.h"
-#include "util/allocator.h"  // AlignedPtr2
+#include "util/allocator.h"  // AlignedPtr
 #include "util/basics.h"  // Extents2D
 // IWYU pragma: end_exports
 #include "hwy/base.h"
@@ -339,24 +339,6 @@ void ZeroInit(MatPtr& mat);
 // F32/F64 only.
 void RandInit(MatPtr& mat, float stddev, std::mt19937& gen);
 
-// Sufficient value of `stride` to enable the "cyclic offsets" optimization. If
-// `Allocator::ShouldBind()`, `Allocator::QuantumBytes()` is typically 4KiB.
-// To avoid remote accesses, we would thus pad each row to that, which results
-// in 4K aliasing and/or cache conflict misses. `RowPtr` is able to prevent that
-// by pulling rows forward by a cyclic offset, which is still a multiple of the
-// cache line size. This requires an additional `Allocator::QuantumBytes()` of
-// padding after also rounding up to that, which considerably increases size for
-// tall and skinny tensors.
-static inline size_t StrideForCyclicOffsets(size_t cols, size_t quantum) {
-  return hwy::RoundUpTo(cols, quantum) + quantum;
-}
-
-// Constexpr version (upper bound) for allocating storage in MatMul.
-template <typename T>
-constexpr size_t MaxStrideForCyclicOffsets(size_t cols) {
-  constexpr size_t quantum = Allocator::MaxQuantum<T>();
-  return hwy::RoundUpTo(cols, quantum) + quantum;
-}
-
 // Our tensors are always row-major. This enum indicates how much (if any)
 // padding comes after each row.
 enum class MatPadding {
@@ -373,11 +355,14 @@ enum class MatPadding {
   // Enough to round up to an odd number of cache lines, which can reduce
   // cache conflict misses or 4K aliasing.
   kOdd,
-  // Enough to enable the "cyclic offsets" optimization for `MatMul`.
-  kCyclic,
 };
 
-// Type-erased, allows storing `AlignedPtr2<T[]>` for various T in the same
+// The stride (offset in elements between rows) that `MatOwner/MatStorageT`
+// will use.
+size_t Stride(MatPadding padding, size_t cols, size_t element_bytes,
+              size_t line_bytes);
+
+// Type-erased, allows storing `AlignedPtr<T[]>` for various T in the same
 // vector.
 class MatOwner {
  public:
@@ -390,7 +375,7 @@ class MatOwner {
   void AllocateFor(MatPtr& mat, MatPadding padding);
 
  private:
-  AlignedPtr2<uint8_t[]> storage_;
+  AlignedPtr<uint8_t[]> storage_;
 };
 
 // Multiple `MatOwner`, with support for parallel allocation.
@@ -443,84 +428,40 @@ MatStorageT<T> MakePacked(const char* name, size_t rows, size_t cols) {
 }
 
 // Lightweight version of `MatPtr` used by matmul-inl.h for padded tensors with
-// seekable (non-NUQ) T. This has less metadata, but support for cyclic offsets.
+// seekable (non-NUQ) T.
 #pragma pack(push, 1)  // power of two size
 template <typename T>
 class RowPtr {
  public:
-  RowPtr(const Allocator& allocator, T* HWY_RESTRICT row0, size_t cols,
-         size_t stride)
+  RowPtr(T* HWY_RESTRICT row0, size_t cols, size_t stride)
       : row0_(row0),
-        stride_(stride),
-        // TODO: disabled because otherwise we see non-deterministic results.
-        row_mask_(0),
-        // static_cast<uint32_t>(allocator.QuantumStepMask() & 0xFFFFFFFFu)),
         cols_(static_cast<uint32_t>(cols)),
-        step_bytes_(static_cast<uint32_t>(allocator.StepBytes())),
-        quantum_bytes_(allocator.QuantumBytes()) {
+        stride_(static_cast<uint32_t>(stride)) {
     HWY_DASSERT(stride >= cols);
-    HWY_DASSERT(row_mask_ != ~uint32_t{0});
-    if (stride < StrideForCyclicOffsets(cols, quantum_bytes_ / sizeof(T))) {
-      row_mask_ = 0;
-      if constexpr (HWY_IS_DEBUG_BUILD) {
-        static bool once;
-        if (stride != cols && !once) {
-          once = true;
-          HWY_WARN(
-              "Check why RowPtr stride=%zu < StrideForCyclicOffsets(cols=%zu), "
-              "T=%zu; this forces us to disable cyclic offsets.",
-              stride, cols, sizeof(T));
-        }
-      }
-    }
   }
-  RowPtr(const Allocator& allocator, T* HWY_RESTRICT row0, size_t cols)
-      : RowPtr(allocator, row0, cols, cols) {}
+  RowPtr(T* HWY_RESTRICT row0, size_t cols) : RowPtr(row0, cols, cols) {}
 
-  T* HWY_RESTRICT Row(size_t r) const {
-    // How much of the previous row's padding to consume.
-    const size_t pad_bytes = (r & row_mask_) * step_bytes_;
-    HWY_DASSERT(pad_bytes < static_cast<size_t>(quantum_bytes_));
-    return row0_ + stride_ * r - pad_bytes;
-  }
+  T* HWY_RESTRICT Row(size_t r) const { return row0_ + stride_ * r; }
   size_t Cols() const { return static_cast<size_t>(cols_); }
 
-  size_t Stride() const { return stride_; }
+  size_t Stride() const { return static_cast<size_t>(stride_); }
   void SetStride(size_t stride) {
     HWY_DASSERT(stride >= Cols());
     stride_ = stride;
-    // The caller might not have padded enough, so disable the padding in Row().
-    // Rows will now be exactly `stride` elements apart. This is used when
-    // writing to the KV cache via MatMul.
-    row_mask_ = 0;
   }
 
   // Returns 2D subrange whose top-left is `r, c` and width is `cols`.
   RowPtr<T> View(size_t r, size_t c, size_t cols) const {
     HWY_DASSERT(c < Cols());
     HWY_DASSERT(cols <= Cols() - c);
-    return RowPtr<T>(Row(r) + c, cols, stride_, row_mask_, step_bytes_,
-                     quantum_bytes_);
+    return RowPtr<T>(Row(r) + c, cols, stride_);
   }
 
  private:
-  // For `View()`.
-  RowPtr(T* new_row0, size_t new_cols, size_t stride, uint32_t row_mask,
-         uint32_t step_bytes, uint32_t quantum_bytes)
-      : row0_(new_row0),
-        stride_(stride),
-        row_mask_(row_mask),
-        cols_(new_cols),
-        step_bytes_(step_bytes),
-        quantum_bytes_(quantum_bytes) {}
-
   T* HWY_RESTRICT row0_;
-  size_t stride_;
-  uint32_t row_mask_;
   uint32_t cols_;
-  uint32_t step_bytes_;
-  uint32_t quantum_bytes_;
+  uint32_t stride_;
 };
 #pragma pack(pop)
@@ -528,14 +469,12 @@ using RowPtrBF = RowPtr<BF16>;
 using RowPtrF = RowPtr<float>;
 using RowPtrD = RowPtr<double>;
 
-// TODO: remove allocator arg once kCyclic is removed.
 template <typename T>
-RowPtr<T> RowPtrFromMat(const Allocator& allocator,
-                        const MatPtrT<T>& row_vectors) {
+RowPtr<T> RowPtrFromMat(const MatPtrT<T>& row_vectors) {
   // RowPtr is non-const for MatMul C, but is also used for A which is const.
   // Callers are responsible for checking their usage of RowPtr.
-  return RowPtr<T>(allocator, const_cast<T*>(row_vectors.Row(0)),
-                   row_vectors.Cols(), row_vectors.Stride());
+  return RowPtr<T>(const_cast<T*>(row_vectors.Row(0)), row_vectors.Cols(),
+                   row_vectors.Stride());
 }
 
 }  // namespace gcpp
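With the cleanup, RowPtr is a plain (pointer, cols, stride) view; a hedged sketch of the accessors kept above (`raw` is a caller-owned float array, not part of the API):

RowPtrF rows(raw, /*cols=*/256, /*stride=*/264);            // stride >= cols
float* third = rows.Row(3);                                 // raw + 3 * 264
RowPtrF tile = rows.View(/*r=*/3, /*c=*/128, /*cols=*/64);  // same stride, narrower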

View File

@@ -38,7 +38,7 @@ namespace gcpp {
 
 // Page-aligned on NUMA systems so we can bind to a NUMA node. This also allows
 // moving because it is a typedef to `std::unique_ptr`.
-using PoolPtr = AlignedClassPtr2<hwy::ThreadPool>;
+using PoolPtr = AlignedClassPtr<hwy::ThreadPool>;
 
 // Creates a hierarchy of thread pools according to `BoundedTopology`: one with
 // a thread per enabled package; for each of those, one with a thread per