mirror of https://github.com/google/gemma.cpp.git
Cleanup: remove unused kCyclic, remove 2 suffix

Also remove the now-unused allocator argument and fix warnings (cast, struct/class mismatch).

PiperOrigin-RevId: 758098495

This commit is contained in:
parent ba21e3beb4
commit d538a6d6c6
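For orientation, the calling-convention change repeated by most hunks below is that helpers such as RowPtrFromMat no longer take an `const Allocator&` argument. The following is a minimal, self-contained sketch of that shape only; `Mat` and `RowSpan` are hypothetical stand-ins, not the real gcpp classes.

#include <cstddef>
#include <vector>

// Hypothetical stand-ins for gcpp's MatPtrT/RowPtr, used only to illustrate
// the direction of the cleanup: the factory derives everything it needs from
// the matrix itself instead of taking an Allocator.
struct Mat {
  float* data;
  size_t cols, stride;
  float* Row(size_t r) const { return data + r * stride; }
};

struct RowSpan {
  float* row0;
  size_t cols, stride;
};

// New shape (this commit): no Allocator parameter.
RowSpan RowPtrFromMat(const Mat& m) {
  return RowSpan{m.Row(0), m.cols, m.stride};
}

int main() {
  std::vector<float> storage(4 * 8);
  Mat m{storage.data(), /*cols=*/6, /*stride=*/8};
  RowSpan v = RowPtrFromMat(m);  // previously: RowPtrFromMat(allocator, m)
  return v.cols == 6 ? 0 : 1;
}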
@@ -42,8 +42,6 @@ static ModelConfig ConfigNoSSM() {
  return config;
}

static ModelConfig ConfigBaseGemmaV1() { return ConfigNoSSM(); }

static ModelConfig ConfigBaseGemmaV2() {
  ModelConfig config = ConfigNoSSM();
  config.att_cap = 50.0f;
@@ -31,7 +31,6 @@
#include "gemma/kv_cache.h"
#include "gemma/weights.h"
#include "paligemma/image.h"
#include "util/allocator.h"
#include "util/mat.h"
#include "util/threading_context.h"
#include "hwy/aligned_allocator.h"  // Span
@@ -260,8 +259,7 @@ class GemmaAttention {
    const size_t w1_rows = heads * layer_config_.QStride();
    w_q1.ShrinkRows(w1_rows);
    MatMul(activations_.pre_att_rms_out, w_q1,
-          /*add=*/nullptr, *activations_.env,
-          RowPtrFromMat(allocator_, activations_.q));
+          /*add=*/nullptr, *activations_.env, RowPtrFromMat(activations_.q));

    if (is_mha_) {
      // Multi-Head Attention a.k.a. "use_qkv_einsum" computed QKV already.
@@ -284,7 +282,7 @@ class GemmaAttention {
      const size_t kv_ofs =
          queries_pos_[0] * cache_pos_size_ + layer_ * cache_layer_size_;
      float* HWY_RESTRICT kv = kv_caches_[0].kv_cache.get() + kv_ofs;
-     RowPtrF kv_rows(allocator_, kv, w_rows_kv_cols);
+     RowPtrF kv_rows(kv, w_rows_kv_cols);
      kv_rows.SetStride(cache_pos_size_);
      MatMul(activations_.pre_att_rms_out, w_q2,
             /*add=*/nullptr, *activations_.env, kv_rows);
@@ -490,7 +488,7 @@ class GemmaAttention {
            ? layer_weights_.attention_output_biases.PackedScale1()
            : nullptr;
    MatMul(activations_.att_out, layer_weights_.att_weights, add,
-          *activations_.env, RowPtrFromMat(allocator_, activations_.att_sums));
+          *activations_.env, RowPtrFromMat(activations_.att_sums));
  }

 public:
@@ -556,7 +554,6 @@ class GemmaAttention {
        layer_weights_(*layer_weights),
        div_seq_len_(div_seq_len),
        kv_caches_(kv_caches),
-       allocator_(ctx.allocator),
        pool_(ctx.pools.Pool(0)) {
    HWY_DASSERT(num_queries_ <= kv_caches_.size());
    HWY_DASSERT_M((layer_config_.heads % layer_config_.kv_heads) == 0,
@@ -586,7 +583,6 @@ class GemmaAttention {
  const LayerWeightsPtrs<T>& layer_weights_;
  const hwy::Divisor& div_seq_len_;
  const KVCaches& kv_caches_;
- const Allocator& allocator_;
  hwy::ThreadPool& pool_;
};
@@ -631,7 +627,7 @@ class VitAttention {
    HWY_ASSERT(qkv.Cols() == layer_config_.heads * 3 * layer_config_.qkv_dim);
    MatMul(activations_.pre_att_rms_out, layer_weights_.vit.qkv_einsum_w,
           layer_weights_.vit.qkv_einsum_b.PackedScale1(), *activations_.env,
-          RowPtrFromMat(allocator_, qkv));
+          RowPtrFromMat(qkv));
  }

  // TODO(philculliton): transition fully to MatMul.
@@ -671,7 +667,7 @@ class VitAttention {
    });

    // this produces C, a (num_tokens_, seq_len) matrix of dot products
-   MatMul(Q, K, nullptr, *activations_.env, RowPtrFromMat(allocator_, C));
+   MatMul(Q, K, nullptr, *activations_.env, RowPtrFromMat(C));

    pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
      float* HWY_RESTRICT c = C.Row(task);
@@ -737,7 +733,7 @@ class VitAttention {
    // att_weights and att_out are concatenated heads, each of length
    // qkv_dim. Thus the [num_tokens_, layer_config_.model_dim]
    // matmul output is the sum over heads.
-   auto att_sums = RowPtrFromMat(allocator_, activations_.att_sums);
+   auto att_sums = RowPtrFromMat(activations_.att_sums);
    MatMul(activations_.att_out, layer_weights_.vit.attn_out_w, bias,
           *activations_.env, att_sums);
  }
@@ -750,7 +746,6 @@ class VitAttention {
        activations_(activations),
        layer_weights_(*layer_weights),
        layer_config_(layer_weights->layer_config),
-       allocator_(activations.env->ctx.allocator),
        pool_(activations.env->ctx.pools.Pool(0)) {}

  HWY_INLINE void operator()() {
@@ -769,7 +764,6 @@ class VitAttention {
  Activations& activations_;
  const LayerWeightsPtrs<T>& layer_weights_;
  const LayerConfig& layer_config_;
- const Allocator& allocator_;
  hwy::ThreadPool& pool_;
};
@@ -832,10 +826,9 @@ HWY_NOINLINE void FFWNoVit(Activations& activations,
      add_bias ? layer_weights->ffw_output_biases.PackedScale1() : nullptr;

  // Define slightly more readable names for the weights and activations.
- const Allocator& allocator = activations.env->ctx.allocator;
- auto hidden_activations = RowPtrFromMat(allocator, activations.C1);
- auto multiplier = RowPtrFromMat(allocator, activations.C2);
- auto ffw_out = RowPtrFromMat(allocator, activations.ffw_out);
+ auto hidden_activations = RowPtrFromMat(activations.C1);
+ auto multiplier = RowPtrFromMat(activations.C2);
+ auto ffw_out = RowPtrFromMat(activations.ffw_out);

  using WeightT = typename decltype(layer_weights->gating_einsum_w)::T;

@@ -881,22 +874,16 @@ HWY_NOINLINE void FFWVit(Activations& activations,
  const float* output_bias =
      add_bias ? layer_weights->vit.linear_1_b.PackedScale1() : nullptr;

- // Define slightly more readable names for the weights and activations.
- const Allocator& allocator = activations.env->ctx.allocator;
- auto hidden_activations = RowPtrFromMat(allocator, activations.C1);
- auto ffw_out = RowPtrFromMat(allocator, activations.ffw_out);

  // Compute the hidden layer activations.
  MatMul(activations.pre_ffw_rms_out, layer_weights->vit.linear_0_w, bias1,
-        *activations.env, hidden_activations);
+        *activations.env, RowPtrFromMat(activations.C1));

- // Activation (Gelu), store in act.
- RowPtrF multiplier = RowPtrF(allocator, nullptr, 0);
+ // Activation (Gelu), store in C1.
  ActivationBatched(layer_weights->layer_config.activation, activations.C1);

  // Hidden layer -> output layer.
  MatMul(activations.C1, layer_weights->vit.linear_1_w, output_bias,
-        *activations.env, ffw_out);
+        *activations.env, RowPtrFromMat(activations.ffw_out));
}

// `batch_idx` indicates which row of `x` to write to.
@@ -932,11 +919,10 @@ HWY_NOINLINE void EmbedMMToken(int token, size_t batch_idx, size_t pos,
  }

  const size_t model_dim = weights.weights_config.model_dim;
- const size_t vocab_size = weights.weights_config.vocab_size;
  const float emb_scaling = EmbeddingScaling(model_dim);

  HWY_DASSERT(token >= 0);
- HWY_DASSERT(token < static_cast<int>(vocab_size));
+ HWY_DASSERT(token < static_cast<int>(weights.weights_config.vocab_size));

  const hn::ScalableTag<float> df;
  // Using `Stride` to compute the offset works for both NUQ (because we use an
@@ -1263,7 +1249,7 @@ HWY_NOINLINE void PrefillVit(const ModelWeightsPtrs<T>& weights,
  // Apply head embedding into image_tokens of size of the LLM kModelDim.
  MatMul(activations.x, weights.vit_img_head_kernel,
         weights.vit_img_head_bias.PackedScale1(), *activations.env,
-        RowPtrFromMat(activations.env->ctx.allocator, image_tokens));
+        RowPtrFromMat(image_tokens));
}

// Generates one token for each query. `queries_token` is the previous token
@@ -1403,7 +1389,7 @@ bool DecodeStepT(const ModelConfig& config, const ModelWeightsPtrs<T>& weights,
    // Compute logits from last layer activations.
    MatMul(activations.x, weights.embedder_input_embedding,
           /*add=*/nullptr, *activations.env,
-          RowPtrFromMat(activations.env->ctx.allocator, activations.logits));
+          RowPtrFromMat(activations.logits));
  }
  PROFILER_ZONE("Gen.Softcap+Sample+Stream");
  for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
@@ -107,7 +107,7 @@ using LayersOutputFunc = std::function<void(size_t, size_t, const std::string&,
// - per-query position within the tokens sequence
// - layer index (or -1 for post-norm output)
// - activations
- class Activations;
+ struct Activations;
using ActivationsObserverFunc =
    std::function<void(const QueriesPos& queries_pos, int, const Activations&)>;
@@ -25,7 +25,6 @@
namespace gcpp {

struct KVCache {
  KVCache() = default;  // for std::vector.
  KVCache(const ModelConfig& weights_config, size_t prefill_tbatch_size);

  // Returns a deep copy of the KVCache.
io/io.cc
@@ -115,7 +115,7 @@ class FilePosix : public File {
#endif

    return MapPtr(static_cast<const uint8_t*>(mapping),
-                 DeleterFunc2([mapping_size](void* ptr) {
+                 DeleterFunc([mapping_size](void* ptr) {
                    HWY_ASSERT(munmap(ptr, mapping_size) == 0);
                  }));
  }
io/io.h
@@ -33,7 +33,7 @@ namespace gcpp {
// prefer to define Exists inline because there are multiple io*.cc files.
struct Path;

- using MapPtr = AlignedPtr2<const uint8_t[]>;
+ using MapPtr = AlignedPtr<const uint8_t[]>;

// Abstract base class enables multiple I/O backends in the same binary.
class File {
@@ -108,7 +108,7 @@ class FileWin : public File {
    void* ptr = MapViewOfFile(hMapping, FILE_MAP_READ, 0, 0, 0);
    if (!ptr) return MapPtr();
    return MapPtr(static_cast<const uint8_t*>(ptr),
-                 DeleterFunc2([hMapping](void* ptr) {
+                 DeleterFunc([hMapping](void* ptr) {
                    HWY_ASSERT(UnmapViewOfFile(ptr));
                    HWY_ASSERT(CloseHandle(hMapping));
                  }));
@@ -105,7 +105,7 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
  MatStorageT<TB> b_trans = GenerateTransposedMat<TB>(B_extents, pool);

  const float* add_row = add ? add_storage.PackedScale1() : nullptr;
- const RowPtr<TC> C = RowPtrFromMat(allocator, c_batch);
+ const RowPtr<TC> C = RowPtrFromMat(c_batch);

  // Fewer reps for large batch sizes, which take longer.
  const size_t num_samples = M < 32 ? 20 : 12;
@@ -1140,7 +1140,7 @@ void TestAllDot() {
    for (size_t variant = 0; variant < kVariants; ++variant) {
      constexpr size_t kTimeReps = hn::AdjustedReps(10);
      std::array<double, kTimeReps> elapsed;
-     for (int time_rep = 0; time_rep < kTimeReps; ++time_rep) {
+     for (size_t time_rep = 0; time_rep < kTimeReps; ++time_rep) {
        const double start = hwy::platform::Now();
        dots[variant] +=
            CallDot(df, variant, a_span, /*w_ofs=*/0, pb, num);
@@ -864,7 +864,7 @@ class MMPerPackage {
      : args_(args),
        pkg_idx_(pkg_idx),
        // May be overwritten with a view of A, if already BF16.
-       A_(args_.env->storage.A(args.env->ctx.allocator, pkg_idx, A.Extents())),
+       A_(args_.env->storage.A(pkg_idx, A.Extents())),
        range_np_(range_np),
        mr_(config.MR()),
        ranges_mc_(config.RangesOfMC(A.Extents().rows)),
@@ -905,9 +905,8 @@ class MMPerPackage {
  // Compute size of per-worker storage for `kNR` row ranges of B. Stack
  // allocation avoids passing a worker index.
  static constexpr size_t B_stride_max_ =
-     MaxStrideForCyclicOffsets<BF16>(MMStorage::kMaxKC);
- static constexpr size_t B_storage_max_ =
-     kNR * B_stride_max_ + Allocator::MaxQuantum<BF16>();
+     MMStorage::kMaxKC + 2 * Allocator::MaxLineBytes() / sizeof(BF16);
+ static constexpr size_t B_storage_max_ = kNR * B_stride_max_;

  // Granularity of `ForNP`. B rows produce C columns, so we
  // want a multiple of the line size to prevent false sharing.
@@ -928,15 +927,14 @@ class MMPerPackage {
    const size_t K = range_K.Num();
    const RowPtrBF& A_view = A_.View(range_M.begin(), 0, K);
    const size_t B_stride =
-       StrideForCyclicOffsets(K, args_.env->ctx.allocator.Quantum<BF16>());
+       Stride(MatPadding::kOdd, K, sizeof(BF16), line_bytes_);

    // Similar to `loop_nc` below, but here we hoisted `A_view`.
    args_.env->parallel.ForNP(
        range_np_, MultipleNP(sizeof(TC)), inner_tasks_, pkg_idx_,
        [&](const IndexRange& range_nc) HWY_ATTR {
          HWY_ALIGN BF16 B_storage[B_storage_max_];  // TLS
-         const RowPtrBF B_view(args_.env->ctx.allocator, B_storage, K,
-                               B_stride);
+         const RowPtrBF B_view(B_storage, K, B_stride);

          for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
               row_b += kNR) {
@@ -971,8 +969,8 @@ class MMPerPackage {
    const size_t kc = range_kc.Num();
    const RowPtrBF& A_view = A_.View(range_mc.begin(), range_kc.begin(), kc);
    const RowPtrBF B_view(
-       args_.env->ctx.allocator, B_storage, kc,
-       StrideForCyclicOffsets(kc, args_.env->ctx.allocator.Quantum<BF16>()));
+       B_storage, kc,
+       Stride(MatPadding::kOdd, kc, sizeof(BF16), line_bytes_));

    for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
         row_b += kNR) {
@@ -1028,7 +1026,7 @@ class MMPerPackage {
    const IndexRange& range_K = ranges_kc_.Range(0);
    const size_t K = range_K.Num();
    const size_t B_stride =
-       StrideForCyclicOffsets(K, args_.env->ctx.allocator.Quantum<BF16>());
+       Stride(MatPadding::kOdd, K, sizeof(BF16), line_bytes_);

    // Sequential loop over NC/MC/KC, similar to `loop_nc` below
    // except for the profiler strings and `out_tag`.
@@ -1037,8 +1035,7 @@ class MMPerPackage {
        [&](const IndexRange& range_mc, const IndexRange& range_nc) HWY_ATTR {
          const RowPtrBF& A_view = A_.View(range_mc.begin(), 0, K);
          HWY_ALIGN BF16 B_storage[B_storage_max_];  // TLS
-         const RowPtrBF B_view(args_.env->ctx.allocator, B_storage, K,
-                               B_stride);
+         const RowPtrBF B_view(B_storage, K, B_stride);

          for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
               row_b += kNR) {
@@ -1064,8 +1061,8 @@ class MMPerPackage {
    zone.MaybeEnter("MM.NT_MT_K", args_);
    const size_t kc_max = ranges_kc_.TaskSize();
    HWY_DASSERT(kc_max <= MMStorage::kMaxKC);
-   const size_t B_stride = StrideForCyclicOffsets(
-       kc_max, args_.env->ctx.allocator.Quantum<BF16>());
+   const size_t B_stride =
+       Stride(MatPadding::kOdd, kc_max, sizeof(BF16), line_bytes_);
    // Sequential loop over NC/MC/KC, for when the M/N loops are
    // already parallel. This is B3A2C0 in MOMMS terminology: we read
    // `mc x kc` of A, `nc x kc` of B, update `mc x nc` of `partial`.
@@ -1091,8 +1088,7 @@ class MMPerPackage {
        ranges_mc_, ranges_nc_, pkg_idx_,
        [&](const IndexRange& range_mc, const IndexRange& range_nc) HWY_ATTR {
          HWY_ALIGN BF16 B_storage[B_storage_max_];  // TLS
-         const RowPtrBF B_view(args_.env->ctx.allocator, B_storage, kc_max,
-                               B_stride);
+         const RowPtrBF B_view(B_storage, kc_max, B_stride);

          // Peel off the first iteration of the kc loop: avoid
          // zero-initializing `partial` by writing into it.
@@ -1172,13 +1168,12 @@ class MMPerPackage {
  // Autotuning wrapper for `DoDecompressA`.
  template <typename TA>
  HWY_INLINE RowPtrBF DecompressA(const MatPtrT<TA>& A) const {
-   const Allocator& allocator = args_.env->ctx.allocator;
    MMAutoTune<MMParA>& autotune = args_.per_key->autotune_par_a[pkg_idx_];
    // If already BF16, maybe return a view:
    if constexpr (hwy::IsSame<TA, BF16>()) {
      // Only if no zero-padding required.
      const size_t NBF = hn::Lanes(hn::ScalableTag<BF16>());
-     if (HWY_LIKELY(A.Cols() % NBF == 0)) return RowPtrFromMat(allocator, A);
+     if (HWY_LIKELY(A.Cols() % NBF == 0)) return RowPtrFromMat(A);
    }

    if (HWY_LIKELY(autotune.Best())) {
ops/matmul.h
@@ -217,8 +217,7 @@ class MMStorage {
      : partial_storage_("partial_storage", Extents2D(kMaxM, kMaxN),
                         MatPadding::kOdd),
        // Same stride independent of the actual C.Cols() so we can pre-bind.
-       partial_(allocator, partial_storage_.Row(0), kMaxN,
-                partial_storage_.Stride()) {
+       partial_(partial_storage_.Row(0), kMaxN, partial_storage_.Stride()) {
    // Per-package allocation so each can decompress A into its own copy.
    parallel.ForPkg(MMParallel::kMaxPackages, [&](size_t pkg_idx) {
      pkg_A_[pkg_idx].reset(new MatStorageT<BF16>(
@@ -240,12 +239,11 @@ class MMStorage {
  }

  // Returns per-package matrix view.
- RowPtrBF A(const Allocator& allocator, size_t pkg_idx,
-            const Extents2D& extents) const {
+ RowPtrBF A(size_t pkg_idx, const Extents2D& extents) const {
    HWY_DASSERT(extents.rows <= kMaxM);
    HWY_DASSERT(extents.cols <= kMaxK);
-   return RowPtrBF(allocator, const_cast<BF16*>(pkg_A_[pkg_idx]->Row(0)),
-                   extents.cols, pkg_A_[pkg_idx]->Stride());
+   return RowPtrBF(const_cast<BF16*>(pkg_A_[pkg_idx]->Row(0)), extents.cols,
+                   pkg_A_[pkg_idx]->Stride());
  }

  RowPtrD Partial() const { return partial_; }
@@ -205,7 +205,6 @@ void PrintSpeed(const char* algo, const Extents2D& A_extents,
template <typename TA, typename TB = TA, typename TC = float>
void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
                MatMulEnv& env, int line) {
- const Allocator& allocator = env.ctx.allocator;
  hwy::ThreadPool& pool = env.ctx.pools.Pool();
  fprintf(stderr, "TestMatMul %zu, K=%zu, %zu, add=%d, TA=%s, TB=%s, TC=%s\n",
          rows_ac, cols_a_rows_b, cols_bc, add, TypeName<TA>(), TypeName<TB>(),
@@ -229,8 +228,8 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
  add_storage.SetScale(1.0f);

  const float* add_row = add ? add_storage.PackedScale1() : nullptr;
- const RowPtr<TC> C_slow = RowPtrFromMat(allocator, c_slow_batch);
- const RowPtr<TC> C = RowPtrFromMat(allocator, c_batch);
+ const RowPtr<TC> C_slow = RowPtrFromMat(c_slow_batch);
+ const RowPtr<TC> C = RowPtrFromMat(c_batch);

  MatMulSlow(a, b_trans, add_row, env, C_slow);
  // A few reps to get coverage of the various autotuned code paths.
@@ -507,7 +507,7 @@ void TestLayerNormSimple() {
  const size_t kSize = 52;
  std::vector<float> values(kSize);
  // Alternating 1.0/-1.0, so mean=0.0, var=1.0, rsqrt(var+epsilon)=0.9999995
- for (int i = 0; i < kSize; ++i) {
+ for (size_t i = 0; i < kSize; ++i) {
    values[i] = (i % 2 == 0) ? 1.0f : -1.0f;
  }
  std::vector<float> scale(kSize, 1.2f);
@@ -132,7 +132,11 @@ size_t DetectTotalMiB(size_t page_bytes) {

Allocator::Allocator(const BoundedTopology& topology, bool enable_bind) {
  line_bytes_ = DetectLineBytes();
+ // Ensure MaxLineBytes() is an upper bound.
+ HWY_ASSERT(MaxLineBytes() >= LineBytes());
+
  vector_bytes_ = hwy::VectorBytes();

  step_bytes_ = HWY_MAX(line_bytes_, vector_bytes_);
  base_page_bytes_ = DetectPageSize();
  quantum_bytes_ = step_bytes_;  // may overwrite below
@@ -165,8 +169,6 @@ Allocator::Allocator(const BoundedTopology& topology, bool enable_bind) {
      // Ensure pages meet the alignment requirements of `AllocBytes`.
      HWY_ASSERT(base_page_bytes_ >= quantum_bytes_);
      quantum_bytes_ = base_page_bytes_;
-     // Ensure MaxQuantum() is an upper bound.
-     HWY_ASSERT(MaxQuantum<uint8_t>() >= Quantum<uint8_t>());
      should_bind_ = true;
    } else {
      HWY_WARN(
@@ -175,9 +177,6 @@ Allocator::Allocator(const BoundedTopology& topology, bool enable_bind) {
      }
    }
  }

- HWY_DASSERT(quantum_bytes_ % step_bytes_ == 0);
- quantum_step_mask_ = quantum_bytes_ / step_bytes_ - 1;
}

size_t Allocator::FreeMiB() const {
@@ -201,7 +200,7 @@ size_t Allocator::FreeMiB() const {
#endif
}

- AlignedPtr2<uint8_t[]> Allocator::AllocBytes(size_t bytes) const {
+ AlignedPtr<uint8_t[]> Allocator::AllocBytes(size_t bytes) const {
  // If we are not binding, the Highway allocator is cheaper than `mmap`, and
  // defends against 2K aliasing.
  if (!should_bind_) {
@@ -217,10 +216,9 @@ AlignedPtr2<uint8_t[]> Allocator::AllocBytes(size_t bytes) const {
    // alignment scheme in aligned_allocator.cc and does not work for
    // already-aligned pointers as returned by `mmap`, hence we wrap the Highway
    // pointer in our own deleter.
-   return AlignedPtr2<uint8_t[]>(p.release(), DeleterFunc2([](void* ptr) {
-                                   hwy::FreeAlignedBytes(ptr, nullptr,
-                                                         nullptr);
-                                 }));
+   return AlignedPtr<uint8_t[]>(p.release(), DeleterFunc([](void* ptr) {
+                                  hwy::FreeAlignedBytes(ptr, nullptr, nullptr);
+                                }));
  }

  // Binding, or large vector/cache line size: use platform-specific allocator.
@@ -234,17 +232,16 @@ AlignedPtr2<uint8_t[]> Allocator::AllocBytes(size_t bytes) const {
  const int fd = -1;
  void* p = mmap(0, bytes, prot, flags, fd, off_t{0});
  if (p == MAP_FAILED) p = nullptr;
- return AlignedPtr2<uint8_t[]>(static_cast<uint8_t*>(p),
-                               DeleterFunc2([bytes](void* ptr) {
-                                 HWY_ASSERT(munmap(ptr, bytes) == 0);
-                               }));
+ return AlignedPtr<uint8_t[]>(
+     static_cast<uint8_t*>(p),
+     DeleterFunc([bytes](void* ptr) { HWY_ASSERT(munmap(ptr, bytes) == 0); }));
#elif HWY_OS_WIN
  const size_t alignment = HWY_MAX(vector_bytes_, line_bytes_);
- return AlignedPtr2<uint8_t[]>(
+ return AlignedPtr<uint8_t[]>(
      static_cast<uint8_t*>(_aligned_malloc(bytes, alignment)),
-     DeleterFunc2([](void* ptr) { _aligned_free(ptr); }));
+     DeleterFunc([](void* ptr) { _aligned_free(ptr); }));
#else
- return AlignedPtr2<uint8_t[]>(nullptr, DeleterFunc2());
+ return AlignedPtr<uint8_t[]>(nullptr, DeleterFunc());
#endif
}
@@ -34,13 +34,13 @@ namespace gcpp {

// Custom deleter for types without a dtor, but where the deallocation requires
// state, e.g. a lambda with *by-value* capture.
- class DeleterFunc2 {
+ class DeleterFunc {
 public:
  // `MatOwnerT` requires this to be default-constructible.
- DeleterFunc2() = default;
+ DeleterFunc() = default;

  template <class Closure>
- DeleterFunc2(const Closure& free_closure) : free_func_(free_closure) {}
+ DeleterFunc(const Closure& free_closure) : free_func_(free_closure) {}

  template <typename T>
  void operator()(T* p) const {
@@ -52,10 +52,10 @@ class DeleterFunc2 {
};

// Wrapper that also calls the destructor for each element being deallocated.
- class DeleterDtor2 {
+ class DeleterDtor {
 public:
- DeleterDtor2() {}
- DeleterDtor2(size_t num, DeleterFunc2 free) : num_(num), free_(free) {}
+ DeleterDtor() {}
+ DeleterDtor(size_t num, DeleterFunc free) : num_(num), free_(free) {}

  template <typename T>
  void operator()(T* p) const {
@@ -67,15 +67,15 @@ class DeleterDtor2 {

 private:
  size_t num_;
- DeleterFunc2 free_;
+ DeleterFunc free_;
};

// Unique (move-only) pointer to aligned POD T, which can be an array or class.
template <typename T>
- using AlignedPtr2 = std::unique_ptr<T, DeleterFunc2>;
+ using AlignedPtr = std::unique_ptr<T, DeleterFunc>;
// Unique (move-only) pointer to an aligned array of non-POD T.
template <typename T>
- using AlignedClassPtr2 = std::unique_ptr<T, DeleterDtor2>;
+ using AlignedClassPtr = std::unique_ptr<T, DeleterDtor>;

// Both allocation, binding, and row accessors depend on the sizes of memory
// pages and cache lines. To avoid having to pass `Allocator&` everywhere, we
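A minimal sketch of the deleter pattern described above, using simplified stand-ins for the renamed DeleterFunc/AlignedPtr (the real definitions live in util/allocator.h; the std::function member and the empty-deleter check are assumptions, not shown in this diff):

#include <cstddef>
#include <cstdlib>
#include <functional>
#include <memory>

// The deleter type-erases a callable, so one unique_ptr type can own memory
// obtained from malloc, mmap or _aligned_malloc alike.
class DeleterFunc {
 public:
  DeleterFunc() = default;
  template <class Closure>
  DeleterFunc(const Closure& free_closure) : free_func_(free_closure) {}

  template <typename T>
  void operator()(T* p) const {
    // Assumption: a default-constructed deleter is a no-op.
    if (free_func_) free_func_(const_cast<void*>(static_cast<const void*>(p)));
  }

 private:
  std::function<void(void*)> free_func_;
};

template <typename T>
using AlignedPtr = std::unique_ptr<T, DeleterFunc>;

int main() {
  const size_t bytes = 1024;
  // An mmap-based deleter would capture the byte count *by value*, as in
  // AllocBytes above; for free() no state is needed.
  AlignedPtr<unsigned char[]> p(
      static_cast<unsigned char*>(std::malloc(bytes)),
      DeleterFunc([](void* ptr) { std::free(ptr); }));
  return p ? 0 : 1;
}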
@@ -90,26 +90,24 @@ class Allocator {
  // Bytes per cache line, or a reasonable guess if unknown. Used to choose
  // ranges such that there will be no false sharing.
  size_t LineBytes() const { return line_bytes_; }
+ // Upper bound on `LineBytes()`, for stack allocations.
+ static constexpr size_t MaxLineBytes() { return 256; }
  // Bytes per full vector. Used to compute loop steps.
  size_t VectorBytes() const { return vector_bytes_; }
  // Work granularity that avoids false sharing and partial vectors.
  // = HWY_MAX(LineBytes(), VectorBytes())
  size_t StepBytes() const { return step_bytes_; }

  // File size multiple required for memory mapping.
  size_t BasePageBytes() const { return base_page_bytes_; }

  // Either StepBytes or BasePageBytes if NUMA.
  size_t QuantumBytes() const { return quantum_bytes_; }
  template <typename T>
  // For rounding down elements to the page size in `BindB/BindC`.
  size_t Quantum() const {
    return QuantumBytes() / sizeof(T);
  }
- // Upper bound on `Quantum()`, for stack allocations.
- template <typename T>
- static constexpr size_t MaxQuantum() {
-   return 4096 / sizeof(T);
- }
- // = QuantumBytes() / StepBytes() - 1
- size_t QuantumStepMask() const { return quantum_step_mask_; }

  // L1 and L2 are typically per core.
  size_t L1Bytes() const { return l1_bytes_; }
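The accessors above encode a few simple size relationships. The following standalone illustration uses assumed example values (64-byte cache lines, 32-byte vectors); the real values are detected at runtime by the Allocator constructor.

#include <algorithm>
#include <cassert>
#include <cstddef>

int main() {
  const size_t line_bytes = 64;    // LineBytes()
  const size_t vector_bytes = 32;  // VectorBytes()
  // StepBytes() = HWY_MAX(LineBytes(), VectorBytes()).
  const size_t step_bytes = std::max(line_bytes, vector_bytes);

  // Without NUMA binding, QuantumBytes() stays at StepBytes(); with binding it
  // is raised to the base page size (e.g. 4096).
  const size_t quantum_bytes = step_bytes;

  // Quantum<T>() is QuantumBytes() expressed in elements of T.
  assert(quantum_bytes / sizeof(float) == 16);
  // MaxLineBytes() == 256 is the compile-time upper bound used when sizing
  // stack storage in MatMul.
  assert(line_bytes <= 256);
  (void)quantum_bytes;
  return 0;
}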
@@ -123,35 +121,35 @@ class Allocator {
  // Returns byte pointer aligned to `QuantumBytes()`, without calling
  // constructors nor destructors on deletion. Type-erased so this can be
  // implemented in `allocator.cc` and called by `MatOwner`.
- AlignedPtr2<uint8_t[]> AllocBytes(size_t bytes) const;
+ AlignedPtr<uint8_t[]> AllocBytes(size_t bytes) const;

  // Returns pointer aligned to `QuantumBytes()`, without calling constructors
  // nor destructors on deletion.
  template <typename T>
- AlignedPtr2<T[]> Alloc(size_t num) const {
+ AlignedPtr<T[]> Alloc(size_t num) const {
    const size_t bytes = num * sizeof(T);
    // Fail if the `bytes = num * sizeof(T)` computation overflowed.
    HWY_ASSERT(bytes / sizeof(T) == num);

-   AlignedPtr2<uint8_t[]> p8 = AllocBytes(bytes);
-   return AlignedPtr2<T[]>(HWY_RCAST_ALIGNED(T*, p8.release()),
-                           p8.get_deleter());
+   AlignedPtr<uint8_t[]> p8 = AllocBytes(bytes);
+   return AlignedPtr<T[]>(HWY_RCAST_ALIGNED(T*, p8.release()),
+                          p8.get_deleter());
  }

  // Same as Alloc, but calls constructor(s) with `args` and the deleter will
  // call destructor(s).
  template <typename T, class... Args>
- AlignedClassPtr2<T> AllocClasses(size_t num, Args&&... args) const {
+ AlignedClassPtr<T> AllocClasses(size_t num, Args&&... args) const {
    const size_t bytes = num * sizeof(T);
    // Fail if the `bytes = num * sizeof(T)` computation overflowed.
    HWY_ASSERT(bytes / sizeof(T) == num);

-   AlignedPtr2<uint8_t[]> p8 = AllocBytes(bytes);
+   AlignedPtr<uint8_t[]> p8 = AllocBytes(bytes);
    T* p = HWY_RCAST_ALIGNED(T*, p8.release());
    for (size_t i = 0; i < num; ++i) {
      new (p + i) T(std::forward<Args>(args)...);
    }
-   return AlignedClassPtr2<T>(p, DeleterDtor2(num, p8.get_deleter()));
+   return AlignedClassPtr<T>(p, DeleterDtor(num, p8.get_deleter()));
  }

  // Returns whether `BindMemory` can/should be called, i.e. we have page-level
@@ -170,7 +168,6 @@ class Allocator {
  size_t step_bytes_;
  size_t base_page_bytes_;
  size_t quantum_bytes_;
- size_t quantum_step_mask_;

  size_t l1_bytes_ = 0;
  size_t l2_bytes_ = 0;
util/mat.cc
@@ -89,38 +89,31 @@ void RandInit(MatPtr& mat, float stddev, std::mt19937& gen) {
  }
}

- // Returns `num` rounded up to an odd number of cache lines. This would also
- // prevent 4K aliasing and is coprime with the cache associativity, which
- // might reduce conflict misses, but we instead use `StrideForCyclicOffsets`.
- static size_t RoundUpToOddLines(size_t num, size_t line_bytes,
-                                 size_t element_bytes) {
-   HWY_DASSERT(line_bytes >= 32);
-   HWY_DASSERT(line_bytes % element_bytes == 0);
-   const size_t lines = hwy::DivCeil(num * element_bytes, line_bytes);
-   const size_t padded_num = (lines | 1) * line_bytes / element_bytes;
-   HWY_DASSERT(padded_num >= num);
-   return padded_num;
- }
-
- static size_t Stride(const Allocator& allocator, const MatPtr& mat,
-                      MatPadding padding) {
+ size_t Stride(MatPadding padding, size_t cols, size_t element_bytes,
+               size_t line_bytes) {
  switch (padding) {
    case MatPadding::kPacked:
    default:
-     return mat.Cols();
-   case MatPadding::kOdd:
-     return RoundUpToOddLines(mat.Cols(), allocator.LineBytes(),
-                              mat.ElementBytes());
-   case MatPadding::kCyclic:
-     return StrideForCyclicOffsets(
-         mat.Cols(), allocator.QuantumBytes() / mat.ElementBytes());
+     return cols;
+   case MatPadding::kOdd: {
+     // Round up to an odd number of cache lines to prevent 4K aliasing and
+     // reduce conflict misses (coprime with the cache associativity).
+     HWY_DASSERT(line_bytes >= 32);
+     HWY_DASSERT(line_bytes % element_bytes == 0);
+     const size_t lines = hwy::DivCeil(cols * element_bytes, line_bytes);
+     const size_t padded_cols = (lines | 1) * line_bytes / element_bytes;
+     HWY_DASSERT(padded_cols >= cols);
+     return padded_cols;
+   }
  }
}

- void MatOwner::AllocateFor(MatPtr& mat, const MatPadding padding) {
+ void MatOwner::AllocateFor(MatPtr& mat, MatPadding padding) {
  const bool is_nuq = mat.GetType() == Type::kNUQ;
+ if (is_nuq) padding = MatPadding::kPacked;
  const Allocator& allocator = ThreadingContext::Get().allocator;
- const size_t stride = is_nuq ? mat.Cols() : Stride(allocator, mat, padding);
+ const size_t stride =
+     Stride(padding, mat.Cols(), mat.ElementBytes(), allocator.LineBytes());
  const size_t num = is_nuq ? mat.PackedBytes() : mat.Rows() * stride;
  // `compress-inl` requires up to 2 BF16 vectors of padding. `MatPadding`
  // might not be enough, hence add extra. `MatT` is at least one byte, which
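The kOdd branch above is the only non-trivial case. The following standalone sketch repeats the same arithmetic with example inputs (2048 float columns, 64-byte lines; values chosen for illustration, not taken from the repository) to show how a row is padded to an odd number of cache lines:

#include <cstddef>
#include <cstdio>

static size_t DivCeil(size_t a, size_t b) { return (a + b - 1) / b; }

static size_t OddStride(size_t cols, size_t element_bytes, size_t line_bytes) {
  const size_t lines = DivCeil(cols * element_bytes, line_bytes);
  // `lines | 1` forces an odd line count, which avoids 4K aliasing and is
  // coprime with typical cache associativities.
  return (lines | 1) * line_bytes / element_bytes;
}

int main() {
  // 2048 floats fill exactly 128 lines; 128 | 1 == 129, so the padded row is
  // 129 * 64 / 4 = 2064 elements, i.e. one extra line (16 floats) of padding.
  printf("stride = %zu\n", OddStride(2048, sizeof(float), 64));
  return 0;
}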
util/mat.h
@@ -28,7 +28,7 @@
#include "compression/shared.h"  // Type
#include "gemma/tensor_info.h"
#include "io/fields.h"
- #include "util/allocator.h"  // AlignedPtr2
+ #include "util/allocator.h"  // AlignedPtr
#include "util/basics.h"  // Extents2D
// IWYU pragma: end_exports
#include "hwy/base.h"
@@ -339,24 +339,6 @@ void ZeroInit(MatPtr& mat);
// F32/F64 only.
void RandInit(MatPtr& mat, float stddev, std::mt19937& gen);

- // Sufficient value of `stride` to enable the "cyclic offsets" optimization. If
- // `Allocator::ShouldBind()`, `Allocator::QuantumBytes()` is typically 4KiB.
- // To avoid remote accesses, we would thus pad each row to that, which results
- // in 4K aliasing and/or cache conflict misses. `RowPtr` is able to prevent that
- // by pulling rows forward by a cyclic offset, which is still a multiple of the
- // cache line size. This requires an additional `Allocator::QuantumBytes()` of
- // padding after also rounding up to that, which considerably increases size for
- // tall and skinny tensors.
- static inline size_t StrideForCyclicOffsets(size_t cols, size_t quantum) {
-   return hwy::RoundUpTo(cols, quantum) + quantum;
- }
- // Constexpr version (upper bound) for allocating storage in MatMul.
- template <typename T>
- constexpr size_t MaxStrideForCyclicOffsets(size_t cols) {
-   constexpr size_t quantum = Allocator::MaxQuantum<T>();
-   return hwy::RoundUpTo(cols, quantum) + quantum;
- }

// Our tensors are always row-major. This enum indicates how much (if any)
// padding comes after each row.
enum class MatPadding {
@@ -373,11 +355,14 @@ enum class MatPadding {
  // Enough to round up to an odd number of cache lines, which can reduce
  // cache conflict misses or 4K aliasing.
  kOdd,
- // Enough to enable the "cyclic offsets" optimization for `MatMul`.
- kCyclic,
};

- // Type-erased, allows storing `AlignedPtr2<T[]>` for various T in the same
+ // The stride (offset in elements between rows) that `MatOwner/MatStorageT`
+ // will use.
+ size_t Stride(MatPadding padding, size_t cols, size_t element_bytes,
+               size_t line_bytes);
+
+ // Type-erased, allows storing `AlignedPtr<T[]>` for various T in the same
// vector.
class MatOwner {
 public:
@@ -390,7 +375,7 @@ class MatOwner {
  void AllocateFor(MatPtr& mat, MatPadding padding);

 private:
- AlignedPtr2<uint8_t[]> storage_;
+ AlignedPtr<uint8_t[]> storage_;
};

// Multiple `MatOwner`, with support for parallel allocation.
@@ -443,84 +428,40 @@ MatStorageT<T> MakePacked(const char* name, size_t rows, size_t cols) {
}

// Lightweight version of `MatPtr` used by matmul-inl.h for padded tensors with
- // seekable (non-NUQ) T. This has less metadata, but support for cyclic offsets.
+ // seekable (non-NUQ) T.
#pragma pack(push, 1)  // power of two size
template <typename T>
class RowPtr {
 public:
- RowPtr(const Allocator& allocator, T* HWY_RESTRICT row0, size_t cols,
-        size_t stride)
+ RowPtr(T* HWY_RESTRICT row0, size_t cols, size_t stride)
      : row0_(row0),
-       stride_(stride),
-       // TODO: disabled because otherwise we see non-deterministic results.
-       row_mask_(0),
-       // static_cast<uint32_t>(allocator.QuantumStepMask() & 0xFFFFFFFFu)),
        cols_(static_cast<uint32_t>(cols)),
-       step_bytes_(static_cast<uint32_t>(allocator.StepBytes())),
-       quantum_bytes_(allocator.QuantumBytes()) {
+       stride_(static_cast<uint32_t>(stride)) {
    HWY_DASSERT(stride >= cols);
-   HWY_DASSERT(row_mask_ != ~uint32_t{0});
-   if (stride < StrideForCyclicOffsets(cols, quantum_bytes_ / sizeof(T))) {
-     row_mask_ = 0;
-     if constexpr (HWY_IS_DEBUG_BUILD) {
-       static bool once;
-       if (stride != cols && !once) {
-         once = true;
-         HWY_WARN(
-             "Check why RowPtr stride=%zu < StrideForCyclicOffsets(cols=%zu), "
-             "T=%zu; this forces us to disable cyclic offsets.",
-             stride, cols, sizeof(T));
-       }
-     }
-   }
  }

- RowPtr(const Allocator& allocator, T* HWY_RESTRICT row0, size_t cols)
-     : RowPtr(allocator, row0, cols, cols) {}
+ RowPtr(T* HWY_RESTRICT row0, size_t cols) : RowPtr(row0, cols, cols) {}

- T* HWY_RESTRICT Row(size_t r) const {
-   // How much of the previous row's padding to consume.
-   const size_t pad_bytes = (r & row_mask_) * step_bytes_;
-   HWY_DASSERT(pad_bytes < static_cast<size_t>(quantum_bytes_));
-   return row0_ + stride_ * r - pad_bytes;
- }
+ T* HWY_RESTRICT Row(size_t r) const { return row0_ + stride_ * r; }
  size_t Cols() const { return static_cast<size_t>(cols_); }

- size_t Stride() const { return stride_; }
+ size_t Stride() const { return static_cast<size_t>(stride_); }
  void SetStride(size_t stride) {
    HWY_DASSERT(stride >= Cols());
    stride_ = stride;
-   // The caller might not have padded enough, so disable the padding in Row().
+   // Rows will now be exactly `stride` elements apart. This is used when
+   // writing to the KV cache via MatMul.
-   row_mask_ = 0;
  }

  // Returns 2D subrange whose top-left is `r, c` and width is `cols`.
  RowPtr<T> View(size_t r, size_t c, size_t cols) const {
    HWY_DASSERT(c < Cols());
    HWY_DASSERT(cols <= Cols() - c);
-   return RowPtr<T>(Row(r) + c, cols, stride_, row_mask_, step_bytes_,
-                    quantum_bytes_);
+   return RowPtr<T>(Row(r) + c, cols, stride_);
  }

 private:
- // For `View()`.
- RowPtr(T* new_row0, size_t new_cols, size_t stride, uint32_t row_mask,
-        uint32_t step_bytes, uint32_t quantum_bytes)
-     : row0_(new_row0),
-       stride_(stride),
-       row_mask_(row_mask),
-       cols_(new_cols),
-       step_bytes_(step_bytes),
-       quantum_bytes_(quantum_bytes) {}
-
  T* HWY_RESTRICT row0_;
- size_t stride_;
- uint32_t row_mask_;
  uint32_t cols_;
- uint32_t step_bytes_;
- uint32_t quantum_bytes_;
+ uint32_t stride_;
};
#pragma pack(pop)
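After this change, RowPtr is plain strided-row arithmetic. The following hypothetical stand-in (not the real gcpp class) illustrates how Row() addresses element (r, c) when rows are padded to a fixed stride:

#include <cstddef>
#include <vector>

// Rows are exactly `stride` elements apart; no cyclic-offset masking.
template <typename T>
class RowView {
 public:
  RowView(T* row0, size_t cols, size_t stride)
      : row0_(row0), cols_(cols), stride_(stride) {}
  T* Row(size_t r) const { return row0_ + stride_ * r; }
  size_t Cols() const { return cols_; }
  size_t Stride() const { return stride_; }

 private:
  T* row0_;
  size_t cols_, stride_;
};

int main() {
  // 4 rows of 6 valid columns, padded to a stride of 8 elements.
  std::vector<float> storage(4 * 8, 0.0f);
  RowView<float> rows(storage.data(), /*cols=*/6, /*stride=*/8);
  rows.Row(2)[5] = 1.0f;  // element (2, 5) lives at flat offset 2 * 8 + 5 = 21
  return storage[21] == 1.0f ? 0 : 1;
}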
@@ -528,14 +469,12 @@ using RowPtrBF = RowPtr<BF16>;
using RowPtrF = RowPtr<float>;
using RowPtrD = RowPtr<double>;

- // TODO: remove allocator arg once kCyclic is removed.
template <typename T>
- RowPtr<T> RowPtrFromMat(const Allocator& allocator,
-                         const MatPtrT<T>& row_vectors) {
+ RowPtr<T> RowPtrFromMat(const MatPtrT<T>& row_vectors) {
  // RowPtr is non-const for MatMul C, but is also used for A which is const.
  // Callers are responsible for checking their usage of RowPtr.
- return RowPtr<T>(allocator, const_cast<T*>(row_vectors.Row(0)),
-                  row_vectors.Cols(), row_vectors.Stride());
+ return RowPtr<T>(const_cast<T*>(row_vectors.Row(0)), row_vectors.Cols(),
+                  row_vectors.Stride());
}

}  // namespace gcpp
@@ -38,7 +38,7 @@ namespace gcpp {

// Page-aligned on NUMA systems so we can bind to a NUMA node. This also allows
// moving because it is a typedef to `std::unique_ptr`.
- using PoolPtr = AlignedClassPtr2<hwy::ThreadPool>;
+ using PoolPtr = AlignedClassPtr<hwy::ThreadPool>;

// Creates a hierarchy of thread pools according to `BoundedTopology`: one with
// a thread per enabled package; for each of those, one with a thread per