Major MatMul update, 1.9-2.3x speedup on Zen4 via bf16 mul

Supports converting all weight/activation formats to native MulT (bf16/f32)

Also:
- ConstMat/MutableMat for const correctness
- Move RowVectorBatch to allocator.h so it can be used from MatMul
- Add matmul.h so MatMulEnv can be used from Activations
- Remove kMaxThreads, detect from PerClusterPools
- Build fix: -inl.h files must be textual_hdrs, and highway.h should precede -inl.h
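
Call sites change with ConstMat/MutableMat and MatMulEnv roughly as follows. This is an illustrative before/after sketch distilled from the gemma-inl.h hunks below; `a`, `b`, `c`, `cols`, `out_cols`, `num_rows`, `scale`, `pool`, and `env` are placeholder names, not the actual variables:

```
// Before: all operands wrapped via MakeMat, explicit ThreadPool argument.
MatMul_4x4</*kAdd=*/false>(num_rows, MakeMat(a, cols), MakeMat(b, cols),
                           scale, /*add=*/nullptr, MakeMat(c, out_cols), pool);

// After: const inputs, mutable output, thread pools passed via MatMulEnv.
MatMul</*kAdd=*/false>(num_rows, ConstMat(a, cols), ConstMat(b, cols),
                       scale, /*add=*/nullptr, env, MutableMat(c, out_cols));
```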

```
zen4 new
64, 24576, 3072, add=0, MatTA=bf16, MatTB=sfp:   616.6 GFLOPS.
64, 3072, 24576, add=0, MatTA=bf16, MatTB=sfp:   460.7 GFLOPS.
64, 24576, 3072, add=0, MatTA=f32, MatTB=sfp:    598.6 GFLOPS.
64, 3072, 24576, add=0, MatTA=f32, MatTB=sfp:    435.6 GFLOPS.

zen4 old
64, 24576, 3072, add=0, MatTA=f32, MatTB=sfp:    257.5 GFLOPS.
64, 3072, 24576, add=0, MatTA=f32, MatTB=sfp:    231.9 GFLOPS.
```
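
The headline 1.9-2.3x appears to come from the f32-activation rows, the only shapes measured in both runs:

```
64 x 24576 x 3072: 598.6 / 257.5 ≈ 2.3x
64 x 3072 x 24576: 435.6 / 231.9 ≈ 1.9x
```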

PiperOrigin-RevId: 663729812
Jan Wassenberg 2024-08-16 07:51:40 -07:00 committed by Copybara-Service
parent 6c57feb52f
commit 301dc8067a
20 changed files with 862 additions and 687 deletions

View File

@ -20,14 +20,37 @@ licenses(["notice"])
exports_files(["LICENSE"])
cc_library(
name = "allocator",
hdrs = ["util/allocator.h"],
deps = [
"@hwy//:hwy",
],
)
cc_library(
name = "threading",
hdrs = ["util/threading.h"],
deps = [
"@hwy//:hwy",
"@hwy//:thread_pool",
"@hwy//:topology",
],
)
cc_library(
name = "ops",
hdrs = [
"ops/matmul.h",
],
textual_hdrs = [
"ops/ops-inl.h",
"ops/matmul-inl.h",
"ops/matvec-inl.h",
],
deps = [
":allocator",
":threading",
"//compression:compress",
"//compression:sfp",
"@hwy//:algo",
@ -86,6 +109,7 @@ cc_test(
tags = ["hwy_ops_test"],
deps = [
":ops",
":threading",
"@googletest//:gtest_main", # buildcleaner: keep
"//compression:compress",
"@hwy//:hwy",
@ -114,6 +138,7 @@ cc_library(
srcs = ["gemma/weights.cc"],
hdrs = ["gemma/weights.h"],
deps = [
":allocator",
":common",
"//compression:compress",
"//compression:io",
@ -148,16 +173,6 @@ cc_library(
],
)
cc_library(
name = "threading",
hdrs = ["util/threading.h"],
deps = [
"@hwy//:hwy",
"@hwy//:thread_pool",
"@hwy//:topology",
],
)
cc_library(
name = "gemma_lib",
srcs = [
@ -197,6 +212,7 @@ cc_library(
# Placeholder for internal file2, do not remove,
],
deps = [
":allocator",
":common",
":ops",
":tokenizer",
@ -389,11 +405,14 @@ cc_library(
hdrs = [
"backprop/activations.h",
"backprop/backward.h",
"backprop/backward-inl.h",
"backprop/forward.h",
],
textual_hdrs = [
"backprop/backward-inl.h",
"backprop/forward-inl.h",
],
deps = [
":allocator",
":common",
":gemma_lib",
":ops",
@ -413,6 +432,7 @@ cc_library(
"backprop/forward_scalar.h",
],
deps = [
":allocator",
":common",
":gemma_lib",
":prompt",
@ -467,13 +487,10 @@ cc_test(
cc_library(
name = "optimizer",
srcs = [
"backprop/optimizer.cc",
],
hdrs = [
"backprop/optimizer.h",
],
srcs = ["backprop/optimizer.cc"],
hdrs = ["backprop/optimizer.h"],
deps = [
":allocator",
":common",
":weights",
"//compression:compress",

View File

@ -103,8 +103,10 @@ set(SOURCES
ops/matmul-inl.h
ops/matvec-inl.h
ops/ops-inl.h
util/allocator.h
util/app.h
util/args.h
util/threading.h
)
if(NOT CMAKE_BUILD_TYPE)

View File

@ -20,7 +20,7 @@
#include <array>
#include "gemma/common.h" // ByteStorageT
#include "util/allocator.h" // ByteStorageT
namespace gcpp {

View File

@ -27,7 +27,6 @@
#include "backprop/activations.h"
#include "backprop/prompt.h"
#include "gemma/activations.h" // CreateInvTimescale
#include "gemma/common.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
@ -42,9 +41,10 @@
#define THIRD_PARTY_GEMMA_CPP_BACKWARD_TOGGLE
#endif
#include "hwy/highway.h"
// After highway.h
#include "ops/matmul-inl.h"
#include "ops/ops-inl.h"
#include "hwy/highway.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {

View File

@ -18,6 +18,8 @@
#include "backprop/activations.h"
#include "backprop/prompt.h"
#include "gemma/common.h"
#include "gemma/weights.h"
#include "util/allocator.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
// Compiles this file for multiple architectures via "foreach_target.h", to
@ -29,8 +31,6 @@
#include "hwy/highway.h"
// After highway.h
#include "backprop/backward-inl.h"
#include "gemma/activations.h"
#include "gemma/weights.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {

View File

@ -19,6 +19,7 @@
#include <random>
#include "gemma/common.h"
#include "util/allocator.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp {

View File

@ -185,6 +185,7 @@ cc_library(
name = "weights_raw",
hdrs = ["weights_raw.h"],
deps = [
"//:allocator",
"//:common",
"//compression:compress",
"@hwy//:hwy",

View File

@ -89,6 +89,22 @@ struct CompressTraits<float> {
f1 = hn::LoadU(df, in + in_ofs + N);
}
// Called by MatMul for f32 weights or activations if native
// `ReorderWidenMulAccumulate` is available.
template <class DBF16, HWY_IF_BF16_D(DBF16), class VBF16 = hn::Vec<DBF16>>
static HWY_INLINE void Decompress2(DBF16 dbf16, const MatT* HWY_RESTRICT in,
size_t in_ofs, VBF16& v0, VBF16& v1) {
const hn::Repartition<float, decltype(dbf16)> df;
using VF = hn::Vec<decltype(df)>;
const size_t NF = hn::Lanes(df);
const VF f0 = hn::LoadU(df, in + in_ofs + 0 * NF);
const VF f1 = hn::LoadU(df, in + in_ofs + 1 * NF);
const VF f2 = hn::LoadU(df, in + in_ofs + 2 * NF);
const VF f3 = hn::LoadU(df, in + in_ofs + 3 * NF);
v0 = hn::OrderedDemote2To(dbf16, f0, f1);
v1 = hn::OrderedDemote2To(dbf16, f2, f3);
}
template <class DF, HWY_IF_F32_D(DF)>
static HWY_INLINE void Decompress(DF df, size_t /*in_capacity*/,
const MatT* HWY_RESTRICT in, size_t in_ofs,
@ -196,6 +212,14 @@ struct CompressTraits<hwy::bfloat16_t> {
f1 = hn::PromoteUpperTo(df, in16);
}
template <class DBF16, HWY_IF_BF16_D(DBF16)>
static HWY_INLINE void Decompress2(DBF16 dbf16, const MatT* HWY_RESTRICT in,
size_t in_ofs, hn::Vec<DBF16>& v0,
hn::Vec<DBF16>& v1) {
v0 = hn::LoadU(dbf16, in + in_ofs);
v1 = hn::LoadU(dbf16, in + in_ofs + hn::Lanes(dbf16));
}
template <class DF, HWY_IF_F32_D(DF)>
static HWY_INLINE void Decompress(DF df, size_t /*in_capacity*/,
const MatT* HWY_RESTRICT in, size_t in_ofs,
@ -318,14 +342,14 @@ struct CompressTraits<SfpStream> {
}
}
template <class DF, HWY_IF_F32_D(DF)>
static HWY_INLINE void Decompress2(DF df, const MatT* HWY_RESTRICT in,
size_t in_ofs, hn::Vec<DF>& f0,
hn::Vec<DF>& f1) {
const hn::Twice<hn::Rebind<uint8_t, DF>> d8;
template <class D> // f32 or bf16
static HWY_INLINE void Decompress2(D d, const MatT* HWY_RESTRICT in,
size_t in_ofs, hn::Vec<D>& v0,
hn::Vec<D>& v1) {
const hn::Twice<hn::Rebind<uint8_t, D>> d8;
using V8 = hn::Vec<decltype(d8)>;
const V8 packed = hn::LoadU(d8, &in->byte + in_ofs);
SfpCodec::Dec2F(df, packed, f0, f1);
SfpCodec::Dec2(d, packed, v0, v1);
}
template <class D, typename OutT>

View File

@ -533,7 +533,7 @@ class SfpCodec {
template <class DF, HWY_IF_F32_D(DF),
class V8 = hn::Vec<hn::Twice<hn::Rebind<uint8_t, DF>>>>
static HWY_INLINE void Dec2F(DF df, V8 packed, hn::Vec<DF>& f0,
static HWY_INLINE void Dec2(DF df, V8 packed, hn::Vec<DF>& f0,
hn::Vec<DF>& f1) {
const hn::Rebind<hwy::bfloat16_t, DF> dbf;
using VBF = hn::Vec<decltype(dbf)>;
@ -543,6 +543,13 @@ class SfpCodec {
f1 = hn::PromoteTo(df, bf1);
}
template <class DBF16, HWY_IF_BF16_D(DBF16),
class V8 = hn::Vec<hn::Repartition<uint8_t, DBF16>>>
static HWY_INLINE void Dec2(DBF16 dbf16, V8 packed, hn::Vec<DBF16>& bf0,
hn::Vec<DBF16>& bf1) {
Dec2B(dbf16, packed, bf0, bf1);
}
private:
// Wrappers to avoid code duplication across float/bf16 input types and
// the main loop/remainder.

View File

@ -26,8 +26,8 @@
#include <random>
#include "gemma/common.h"
#include "gemma/configs.h"
#include "util/allocator.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"

View File

@ -20,51 +20,14 @@
#include <cmath>
#include "gemma/common.h" // kMaxThreads - TODO: remove
#include "hwy/aligned_allocator.h"
#include "ops/matmul.h" // MatMulEnv
#include "util/allocator.h" // RowVectorBatch
#include "util/threading.h"
#include "hwy/base.h" // HWY_DASSERT
#include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp {
// Owns dynamically-allocated aligned memory for a batch of row vectors.
// This can be seen as a (batch_size x len) matrix.
template <typename T>
class RowVectorBatch {
public:
// Default ctor for Activations ctor.
RowVectorBatch() : batch_size_(0), len_(0) {}
// Main ctor, called from Activations::Allocate.
RowVectorBatch(size_t batch_size, size_t len)
: batch_size_(batch_size), len_(len) {
mem_ = hwy::AllocateAligned<T>(batch_size * len);
}
// Move-only
RowVectorBatch(RowVectorBatch&) noexcept = delete;
RowVectorBatch& operator=(RowVectorBatch&) noexcept = delete;
RowVectorBatch(RowVectorBatch&&) noexcept = default;
RowVectorBatch& operator=(RowVectorBatch&&) noexcept = default;
size_t BatchSize() const { return batch_size_; }
size_t Len() const { return len_; }
// Returns the given row vector of length `Len()`.
T* Batch(size_t batch_idx) {
HWY_DASSERT(batch_idx < batch_size_);
return mem_.get() + batch_idx * len_;
}
// For MatMul or other operations that process the entire batch at once.
T* All() { return mem_.get(); }
const T* Const() const { return mem_.get(); }
size_t NumBytes() const { return batch_size_ * len_ * sizeof(T); }
private:
hwy::AlignedFreeUniquePtr<T[]> mem_;
size_t batch_size_; // rows in the matrix
size_t len_; // columns in the matrix = vector length
};
struct Activations {
RowVectorBatch<float> x; // input
RowVectorBatch<float> q; // query, also KV if MHA.
@ -94,9 +57,11 @@ struct Activations {
// For bf16/f32 vectors * bf16 matrix: faster to unpack once beforehand, into
// per-thread storage.
// TODO: remove once MatVec is gone.
// TODO: remove once MatVec is no longer used.
RowVectorBatch<float> even_odd;
MatMulEnv env;
// Multi-Head Attention?
template <class TConfig>
static constexpr bool IsMHA() {
@ -126,7 +91,7 @@ struct Activations {
}
template <class TConfig>
void Allocate(size_t batch_size) {
void Allocate(size_t batch_size, PerClusterPools& pools) {
constexpr size_t kModelDim = TConfig::kModelDim;
constexpr size_t kQKVDim = TConfig::kQKVDim;
constexpr size_t kHeads = TConfig::kHeads;
@ -158,7 +123,10 @@ struct Activations {
inv_timescale = CreateInvTimescale<TConfig>();
even_odd = RowVectorBatch<float>(1, kModelDim * kMaxThreads);
const size_t num_lp = pools.NumLP();
even_odd = RowVectorBatch<float>(1, kModelDim * num_lp);
env = MatMulEnv(pools);
}
};

View File

@ -18,24 +18,15 @@
#include <math.h> // sqrtf
#include <stddef.h>
#include <stdint.h>
#include <string>
#include "compression/compress.h"
#include "gemma/configs.h" // IWYU pragma: export
#include "hwy/aligned_allocator.h"
#include "hwy/base.h" // ConvertScalarTo
namespace gcpp {
using ByteStorageT = hwy::AlignedFreeUniquePtr<uint8_t[]>;
template <typename T>
ByteStorageT AllocateSizeof() {
return hwy::AllocateAligned<uint8_t>(sizeof(T));
}
// Model variants: see configs.h for details. When adding a new one, also
// update GEMMA_FOREACH* and Call* below, and add instantiations/*.cc.
enum class Model {

View File

@ -36,14 +36,8 @@ namespace gcpp {
#define GEMMA_TOPK 1
#endif // !GEMMA_TOPK
// Allow changing upper bound on threads as a compiler flag
#ifndef GEMMA_MAX_THREADS
#define GEMMA_MAX_THREADS 128
#endif // !GEMMA_MAX_THREADS
static constexpr size_t kSeqLen = GEMMA_MAX_SEQLEN;
static constexpr size_t kTopK = GEMMA_TOPK;
static constexpr size_t kMaxThreads = GEMMA_MAX_THREADS;
using EmbedderInputT = hwy::bfloat16_t;

View File

@ -41,6 +41,7 @@
#include "ops/matmul-inl.h"
#include "ops/matvec-inl.h"
#include "ops/ops-inl.h"
#include "util/allocator.h"
#include "util/threading.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
@ -73,9 +74,10 @@ template <class TConfig>
HWY_NOINLINE void GriffinRecurrent(
size_t batch_start, size_t num_tokens, size_t layer,
Activations& activations, const CompressedLayer<TConfig>* layer_weights,
const KVCaches& kv_caches, hwy::ThreadPool& pool) {
const KVCaches& kv_caches) {
PROFILER_ZONE("Gen.Griffin");
KVCache& kv_cache = kv_caches[0];
hwy::ThreadPool& pool = activations.env.Pool();
namespace hn = hwy::HWY_NAMESPACE;
using D = hn::ScalableTag<float>;
static constexpr size_t kModelDim = TConfig::kModelDim;
@ -240,12 +242,12 @@ class GemmaAttention {
// and kQStride = kQKVDim * (kIsMHA ? 3 : 1);
const auto pre_att_rms_out =
MakeMat(activations_.pre_att_rms_out.All(), kModelDim);
MatMul_4x4</*kAdd=*/false>(
ConstMat(activations_.pre_att_rms_out.All(), kModelDim);
MatMul</*kAdd=*/false>(
num_interleaved, pre_att_rms_out,
MakeMat(layer_weights_.qkv_einsum_w.data(), kModelDim),
layer_weights_.qkv_einsum_w.scale(), /*add=*/nullptr,
MakeMat(activations_.q.All(), kHeads * kQStride), pool_);
ConstMat(layer_weights_.qkv_einsum_w.data(), kModelDim),
layer_weights_.qkv_einsum_w.scale(), /*add=*/nullptr, activations_.env,
MutableMat(activations_.q.All(), kHeads * kQStride));
if constexpr (kIsMHA) {
static_assert(TConfig::kInterleaveQKV, "MHA implies interleaved");
@ -259,12 +261,13 @@ class GemmaAttention {
queries_pos_[0] * kCachePosSize + layer_ * kCacheLayerSize;
// KV structure is [k, v, k, v, ....] = kKVHeads pairs of (k, v).
float* HWY_RESTRICT kv = kv_caches_[0].kv_cache.get() + kv_ofs;
MatMul_4x4</*kAdd=*/false>(
MatMul</*kAdd=*/false>(
num_tokens_, pre_att_rms_out,
MakeMat(layer_weights_.qkv_einsum_w.data(), kModelDim, kModelDim,
ConstMat(layer_weights_.qkv_einsum_w.data(), kModelDim, kModelDim,
kHeads * kQKVDim * kModelDim),
layer_weights_.qkv_einsum_w.scale(), /*add=*/nullptr,
MakeMat(kv, kKVHeads * 2 * kQKVDim, kCachePosSize), pool_);
activations_.env,
MutableMat(kv, kKVHeads * 2 * kQKVDim, kCachePosSize));
} else {
// Proceed row by row because there will be wraparound.
for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
@ -430,19 +433,18 @@ class GemmaAttention {
// Thus the [num_interleaved, kModelDim] matmul output is the sum over
// heads. Compare gemma/modules.py:
// attn_output = self.attn_vec_einsum('BTNH,NHD->BTD', encoded)
MatMul_4x4<kAdd>(
num_interleaved, MakeMat(activations_.att_out.All(), kHeads * kQKVDim),
MakeMat(layer_weights_.att_weights.data(), kHeads * kQKVDim),
layer_weights_.attn_vec_einsum_w.scale(), bias,
MakeMat(activations_.att_sums.All(), kModelDim), pool_);
MatMul<kAdd>(
num_interleaved, ConstMat(activations_.att_out.All(), kHeads * kQKVDim),
ConstMat(layer_weights_.att_weights.data(), kHeads * kQKVDim),
layer_weights_.attn_vec_einsum_w.scale(), bias, activations_.env,
MutableMat(activations_.att_sums.All(), kModelDim));
}
public:
GemmaAttention(const QueriesPos& queries_pos, size_t num_tokens, size_t layer,
Activations& activations,
const CompressedLayer<TConfig>* layer_weights,
const hwy::Divisor& div_seq_len, const KVCaches& kv_caches,
hwy::ThreadPool& pool)
const hwy::Divisor& div_seq_len, const KVCaches& kv_caches)
: queries_pos_(queries_pos),
num_queries_(queries_pos.size()),
num_tokens_(num_tokens),
@ -451,7 +453,7 @@ class GemmaAttention {
layer_weights_(*layer_weights),
div_seq_len_(div_seq_len),
kv_caches_(kv_caches),
pool_(pool) {
pool_(activations.env.Pool()) {
HWY_DASSERT(num_queries_ <= kv_caches_.size());
}
@ -480,17 +482,17 @@ HWY_NOINLINE void Attention(LayerAttentionType type,
size_t layer, Activations& activations,
const CompressedLayer<TConfig>* layer_weights,
const hwy::Divisor& div_seq_len,
const KVCaches& kv_caches, hwy::ThreadPool& pool) {
const KVCaches& kv_caches) {
if (type == LayerAttentionType::kGemma) {
GemmaAttention<TConfig>(queries_pos, num_tokens, layer, activations,
layer_weights, div_seq_len, kv_caches, pool)();
layer_weights, div_seq_len, kv_caches)();
} else {
// Only reached if the model is Griffin. `if constexpr` prevents generating
// this code for non-Griffin models.
if constexpr (TConfig::kGriffinLayers > 0) {
HWY_ASSERT(queries_pos.size() == 1);
GriffinRecurrent<TConfig>(queries_pos[0], num_tokens, layer, activations,
layer_weights, kv_caches, pool);
layer_weights, kv_caches);
}
}
}
@ -510,8 +512,7 @@ HWY_NOINLINE void Activation(T* HWY_RESTRICT c1, T* HWY_RESTRICT c2,
template <class TConfig>
HWY_NOINLINE void FFW(Activations& activations, size_t num_interleaved,
const CompressedLayer<TConfig>* layer_weights,
hwy::ThreadPool& pool) {
const CompressedLayer<TConfig>* layer_weights) {
PROFILER_ZONE("Gen.FFW");
constexpr size_t kModelDim = TConfig::kModelDim;
constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
@ -519,9 +520,9 @@ HWY_NOINLINE void FFW(Activations& activations, size_t num_interleaved,
// MatMul expects col-major B, which is what we have: kModelDim consecutive
// elements in memory, repeated kFFHiddenDim times.
HWY_DASSERT(num_interleaved <= activations.bf_pre_ffw_rms_out.BatchSize());
const auto A = MakeMat(activations.bf_pre_ffw_rms_out.All(), kModelDim);
const auto B1 = MakeMat(layer_weights->gating_einsum_w.data(), kModelDim);
const auto B2 = MakeMat(layer_weights->gating_einsum_w.data(), kModelDim,
const auto A = ConstMat(activations.bf_pre_ffw_rms_out.All(), kModelDim);
const auto B1 = ConstMat(layer_weights->gating_einsum_w.data(), kModelDim);
const auto B2 = ConstMat(layer_weights->gating_einsum_w.data(), kModelDim,
kModelDim, kModelDim * kFFHiddenDim);
const float scale = layer_weights->gating_einsum_w.scale();
constexpr bool kAddBias = TConfig::kFFBiases;
@ -533,22 +534,23 @@ HWY_NOINLINE void FFW(Activations& activations, size_t num_interleaved,
bias2 = bias1 + kFFHiddenDim;
output_bias = layer_weights->ffw_output_biases.data_scale1();
}
auto C1 = MakeMat(activations.C1.All(), kFFHiddenDim);
auto C2 = MakeMat(activations.C2.All(), kFFHiddenDim);
auto C1 = MutableMat(activations.C1.All(), kFFHiddenDim);
auto C2 = MutableMat(activations.C2.All(), kFFHiddenDim);
// Will go through GELU.
MatMul_4x4<kAddBias>(num_interleaved, A, B1, scale, bias1, C1, pool);
MatMul<kAddBias>(num_interleaved, A, B1, scale, bias1, activations.env, C1);
// What to multiply by.
MatMul_4x4<kAddBias>(num_interleaved, A, B2, scale, bias2, C2, pool);
MatMul<kAddBias>(num_interleaved, A, B2, scale, bias2, activations.env, C2);
// Activation (Gelu) and multiply by gate. Store activations in C1.
Activation<TConfig>(C1.ptr, C2.ptr, kFFHiddenDim * num_interleaved);
// Hidden layer -> output layer.
MatMul_4x4<kAddBias>(num_interleaved, C1,
MakeMat(layer_weights->linear_w.data(), kFFHiddenDim),
MatMul<kAddBias>(num_interleaved, ConstMat(C1),
ConstMat(layer_weights->linear_w.data(), kFFHiddenDim),
layer_weights->linear_w.scale(), output_bias,
MakeMat(activations.ffw_out.All(), kModelDim), pool);
activations.env,
MutableMat(activations.ffw_out.All(), kModelDim));
}
// `batch_idx` indicates which row of `x` to write to.
@ -594,8 +596,7 @@ template <class TConfig>
HWY_NOINLINE void TransformerLayer(
const QueriesPos& queries_pos, size_t num_tokens, size_t layer,
const CompressedLayer<TConfig>* layer_weights, Activations& activations,
const hwy::Divisor& div_seq_len, const KVCaches& kv_caches,
hwy::ThreadPool& pool) {
const hwy::Divisor& div_seq_len, const KVCaches& kv_caches) {
constexpr size_t kModelDim = TConfig::kModelDim;
const size_t num_interleaved = num_tokens * queries_pos.size();
auto type = TConfig::kLayerConfig[layer];
@ -607,7 +608,7 @@ HWY_NOINLINE void TransformerLayer(
activations.pre_att_rms_out.All(), kModelDim);
Attention<TConfig>(type, queries_pos, num_tokens, layer_of_type, activations,
layer_weights, div_seq_len, kv_caches, pool);
layer_weights, div_seq_len, kv_caches);
PostNorm<TConfig>(num_interleaved, layer_weights->post_attention_norm_scale,
activations.att_sums.All());
@ -620,7 +621,7 @@ HWY_NOINLINE void TransformerLayer(
layer_weights->pre_ffw_norm_scale.data_scale1(),
activations.bf_pre_ffw_rms_out.All(), kModelDim);
FFW<TConfig>(activations, num_interleaved, layer_weights, pool);
FFW<TConfig>(activations, num_interleaved, layer_weights);
PostNorm<TConfig>(num_interleaved, layer_weights->post_ffw_norm_scale,
activations.ffw_out.All());
@ -630,83 +631,39 @@ HWY_NOINLINE void TransformerLayer(
/*is_attention=*/false);
}
// Prefill and Transformer() advance positions in-place.
// Prefill() and Transformer() increment positions in-place.
using QueriesMutablePos = hwy::Span<size_t>;
// Batches are important for amortizing loading weights over multiple tokens.
// This is possible in prefill because we know all tokens beforehand, whereas
// decode depends on the previous output token. However, each prefill batch of a
// query requires that preceding batches already wrote to the KV cache, hence we
// sequentially loop over token batches. We can reduce the number of iterations
// by increasing the batch size, but this also increases arithmetic intensity,
// and so we are eventually compute-limited. The tensor parallelism (number of
// threads collaborating on MatMul) is also limited by the CPU topology:
// fork/join barriers are slow(er) when some threads reside in a different NUMA
// node. To allow more threads to help, we also support parallelizing over
// queries in case GenerateBatch was called.
//
// Thus we have two-level parallelism:
// - Outer: handles one 'qbatch' of entire queries. The set of outer workers
// includes the main thread because it is the one that calls `Prefill`, and is
// determined by the number of 'clusters' (shared L3 caches or sockets).
// - Inner: each `outer` worker passes `inner_pools_[outer]` to
// `TransformerLayer` for tensor-level parallelism, and processes
// `tbatch_size` tokens from a single query at a time.
//
// This class holds the thread pools and one activation per outer worker. It is
// NOT reused across calls to GenerateSingle/GenerateBatch so that we can adapt
// to their num_queries.
class PrefillState {
public:
// `tbatch_size` is the number of tokens from one query to prefill at a time.
template <class TConfig>
void Init(size_t num_queries, size_t tbatch_size, PerClusterPools& pools) {
PROFILER_ZONE("Init.Prefill");
HWY_ASSERT(num_queries != 0);
HWY_ASSERT(activations_.empty()); // only call once.
// Allocate one activation per query, not outer worker, because the common
// case is a single query. If we allocate the lesser of the two, it is
// unclear how to choose an unused activation in Prefill.
activations_.resize(num_queries);
if (num_queries == 1) {
activations_[0].Allocate<TConfig>(tbatch_size);
} else {
// Allocating in parallel can save 30 ms. We might have more workers than
// queries/tasks, so do not check the `thread` argument.
pools.Outer().Run(0, num_queries,
[this, tbatch_size](uint64_t qi, size_t /*thread*/) {
activations_[qi].Allocate<TConfig>(tbatch_size);
});
}
}
template <class TConfig>
HWY_NOINLINE void Prefill(const QueriesPromptTokens& queries_prompt,
const size_t prefill_per_query,
const QueriesMutablePos& queries_pos,
const size_t query_idx_start,
const CompressedWeights<TConfig>& weights,
const RuntimeConfig& runtime_config,
const hwy::Divisor& div_seq_len,
const KVCaches& kv_caches, PerClusterPools& pools) {
// Populates KV cache for batches of tokens from one query at a time.
template <class TConfig>
HWY_NOINLINE void Prefill(
const QueriesPromptTokens& queries_prompt, const size_t prefill_per_query,
const QueriesMutablePos& queries_pos, const size_t query_idx_start,
const CompressedWeights<TConfig>& weights, Activations& activations,
const RuntimeConfig& runtime_config, const hwy::Divisor& div_seq_len,
const KVCaches& kv_caches) {
PROFILER_ZONE("Gen.Prefill");
const size_t num_queries = queries_prompt.size();
HWY_ASSERT(queries_pos.size() == num_queries);
HWY_ASSERT(kv_caches.size() == num_queries);
const size_t max_tbatch_size = activations_[0].x.BatchSize();
// For each query (parallel): an outer worker processes all its tokens.
// `qi` is relative to the batch, not the global query index.
pools.Outer().Run(
0, num_queries, [&](const uint64_t qi, size_t qthread) HWY_ATTR {
Activations& activations = activations_[qi];
hwy::ThreadPool& inner_pool = pools.Inner(qthread);
// Batches are important for amortizing loading weights over multiple tokens.
// This is possible in prefill because we know all tokens beforehand, whereas
// decode depends on the previous output token. However, each prefill batch of
// a query requires that preceding batches already wrote to the KV cache,
// hence we sequentially loop over token batches. We can reduce the number of
// iterations by increasing the batch size, but this also increases arithmetic
// intensity, and so we are eventually compute-limited. We could devote some
// threads to parallelizing over queries, but for simplicity we assign them
// all to MatMul.
const size_t max_tbatch_size = activations.x.BatchSize();
// For each query. `qi` is within the batch, not the global query index.
for (size_t qi = 0; qi < num_queries; ++qi) {
// Single query at a time, so pass slices of the spans because
// GemmaAttention will only access the first KV cache and position.
KVCaches single_kv_cache(&kv_caches[qi], 1);
QueriesPos single_query_pos(&queries_pos[qi], 1);
KVCaches single_kv_cache(&kv_caches[qi], 1);
// For each batch of tokens in the query:
for (size_t tbatch_start = 0; tbatch_start < prefill_per_query;
@ -725,25 +682,20 @@ class PrefillState {
const auto* layer_weights = weights.GetLayer(layer);
TransformerLayer<TConfig>(single_query_pos, tbatch_size, layer,
layer_weights, activations, div_seq_len,
single_kv_cache, inner_pool);
single_kv_cache);
}
// NOTE: we unconditionally call StreamToken, even if EOS.
for (size_t ti = 0; ti < tbatch_size; ++ti) {
const size_t pos = queries_pos[qi] + ti;
const int token = queries_prompt[qi][pos];
runtime_config.StreamToken(query_idx_start + qi, pos, token,
0.0f);
runtime_config.StreamToken(query_idx_start + qi, pos, token, 0.0f);
}
queries_pos[qi] += tbatch_size;
} // for tbatch_start
});
}
private:
std::vector<Activations> activations_; // One per query, filled by Init.
};
}
// Generates one token for each query. `queries_token` is the previous token
// from each query, and `queries_pos` are their position in the sequence.
@ -752,7 +704,7 @@ HWY_NOINLINE void Transformer(
const QueriesToken& queries_token, const QueriesMutablePos& queries_pos,
const CompressedWeights<TConfig>& weights, Activations& activations,
const hwy::Divisor& div_seq_len, const KVCaches& kv_caches,
hwy::ThreadPool& pool, const LayersOutputFunc& layers_output,
const LayersOutputFunc& layers_output,
const ActivationsObserverFunc& activations_observer) {
constexpr size_t kModelDim = TConfig::kModelDim;
const size_t num_queries = queries_token.size();
@ -775,7 +727,7 @@ HWY_NOINLINE void Transformer(
const CompressedLayer<TConfig>* layer_weights = weights.GetLayer(layer);
TransformerLayer<TConfig>(queries_pos, /*num_tokens=*/1, layer,
layer_weights, activations, div_seq_len,
kv_caches, pool);
kv_caches);
if (activations_observer) {
activations_observer(queries_pos, layer, activations);
@ -880,16 +832,12 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
const RuntimeConfig& runtime_config,
const QueriesPromptTokens& queries_prompt,
const QueriesPos& queries_pos_in, const size_t query_idx_start,
const KVCaches& kv_caches, PerClusterPools& pools,
TimingInfo& timing_info) {
const KVCaches& kv_caches, TimingInfo& timing_info) {
constexpr size_t kModelDim = TConfig::kModelDim;
constexpr size_t kVocabSize = TConfig::kVocabSize;
const CompressedWeights<TConfig>& weights =
*reinterpret_cast<const CompressedWeights<TConfig>*>(weights_u8.get());
// TODO: remove once all parallel sections support hierarchical parallelism.
hwy::ThreadPool& pool = pools.Inner(0);
// Copy so we can increment without requiring users to pass in a mutable span.
std::vector<size_t> queries_pos_copy(queries_pos_in.cbegin(),
queries_pos_in.cend());
@ -930,19 +878,22 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
// Prefill stops before min_prompt_size - 1 because the last prompt token is
// the first input token for generation.
const size_t prefill_per_query = min_prompt_size - 1;
double prefill_start;
{
// TODO: move to Gemma, reuse across calls to Generate.
PrefillState prefill;
prefill.Init<TConfig>(num_queries, runtime_config.prefill_tbatch_size,
pools);
prefill_start = hwy::platform::Now();
prefill.Prefill<TConfig>(queries_prompt, prefill_per_query,
queries_mutable_pos, query_idx_start, weights,
runtime_config, div_seq_len, kv_caches, pools);
const double prefill_start = hwy::platform::Now();
// If tbatch is larger than the qbatch we already have in `activations`, then
// allocate prefill_activations, otherwise reuse.
const bool use_prefill_activations =
runtime_config.prefill_tbatch_size > activations.x.BatchSize();
Activations prefill_activations;
if (use_prefill_activations) {
prefill_activations.Allocate<TConfig>(runtime_config.prefill_tbatch_size,
activations.env.Pools());
}
Prefill<TConfig>(queries_prompt, prefill_per_query, queries_mutable_pos,
query_idx_start, weights,
use_prefill_activations ? prefill_activations : activations,
runtime_config, div_seq_len, kv_caches);
timing_info.NotifyPrefill(prefill_per_query * num_queries, prefill_start);
// queries_pos are incremented by Prefill.
}
// Storage for the last generated token from each query, passed to the next
// Transformer() call.
@ -962,18 +913,18 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
// Decode generates one token per query and increments queries_mutable_pos.
Transformer<TConfig>(QueriesToken(gen_tokens.data(), num_queries),
queries_mutable_pos, weights, activations, div_seq_len,
kv_caches, pool, runtime_config.layers_output,
kv_caches, runtime_config.layers_output,
runtime_config.activations_observer);
// queries_pos are incremented by Transformer.
bool all_queries_eos = true;
PROFILER_ZONE("Gen.Embedding");
// Compute logits from last layer activations.
MatMul_4x4</*kAdd=*/false>(
num_queries, MakeMat(activations.x.All(), kModelDim),
MakeMat(weights.embedder_input_embedding.data(), kModelDim),
MatMul</*kAdd=*/false>(
num_queries, ConstMat(activations.x.All(), kModelDim),
ConstMat(weights.embedder_input_embedding.data(), kModelDim),
weights.embedder_input_embedding.scale(), /*add=*/nullptr,
MakeMat(activations.logits.All(), kVocabSize), pool);
activations.env, MutableMat(activations.logits.All(), kVocabSize));
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
float* HWY_RESTRICT logits = activations.logits.Batch(query_idx);
MaybeLogitsSoftCap(TConfig::kFinalCap, logits, kVocabSize);
@ -1001,15 +952,16 @@ void GenerateSingleT(const ByteStorageT& weights_u8,
constexpr size_t kNumQueries = 1;
const size_t qbatch_start = 0;
// TODO: move into Gemma?
Activations activations;
activations.Allocate<TConfig>(kNumQueries);
activations.Allocate<TConfig>(kNumQueries, pools);
const QueriesPromptTokens prompt_span(&prompt, kNumQueries);
QueriesPos pos_span(&pos, kNumQueries);
const KVCaches kv_caches{&kv_cache, kNumQueries};
GenerateT<TConfig>(weights_u8, activations, runtime_config, prompt_span,
pos_span, qbatch_start, kv_caches, pools, timing_info);
pos_span, qbatch_start, kv_caches, timing_info);
}
template <class TConfig>
@ -1026,7 +978,7 @@ void GenerateBatchT(const ByteStorageT& weights_u8,
(TConfig::kGriffinLayers > 0) ? 1 : runtime_config.decode_qbatch_size;
Activations activations;
activations.Allocate<TConfig>(max_qbatch_size);
activations.Allocate<TConfig>(max_qbatch_size, pools);
for (size_t qbatch_start = 0; qbatch_start < num_queries;
qbatch_start += max_qbatch_size) {
@ -1038,7 +990,7 @@ void GenerateBatchT(const ByteStorageT& weights_u8,
QueriesPos qbatch_pos(&queries_pos[qbatch_start], qbatch_size);
const KVCaches qbatch_kv(&kv_caches[qbatch_start], qbatch_size);
GenerateT<TConfig>(weights_u8, activations, runtime_config, qbatch_prompts,
qbatch_pos, qbatch_start, qbatch_kv, pools, timing_info);
qbatch_pos, qbatch_start, qbatch_kv, timing_info);
}
}

View File

@ -21,6 +21,7 @@
#include "compression/compress.h"
#include "gemma/common.h"
#include "gemma/configs.h"
#include "util/allocator.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"

View File

@ -13,19 +13,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// Include guard for non-SIMD code.
#ifndef THIRD_PARTY_GEMMA_CPP_OPS_MATMUL_INL_H_
#define THIRD_PARTY_GEMMA_CPP_OPS_MATMUL_INL_H_
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/profiler.h" // temporarily disabled
#endif // THIRD_PARTY_GEMMA_CPP_OPS_MATMUL_INL_H_
#include "compression/compress.h" // IWYU pragma: keep, b/conditionally used
#include "ops/matmul.h" // IWYU pragma: export
// Include guard for (potentially) SIMD code.
#if defined(THIRD_PARTY_GEMMA_CPP_MATMUL_TOGGLE) == defined(HWY_TARGET_TOGGLE)
@ -35,6 +26,8 @@
#define THIRD_PARTY_GEMMA_CPP_MATMUL_TOGGLE
#endif
#include "hwy/highway.h"
// After highway.h
#include "compression/compress-inl.h"
#include "hwy/contrib/math/math-inl.h"
@ -43,355 +36,392 @@ namespace gcpp {
namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;
// A square kernel minimizes the ratio of loads to FMA. 4x 128-bit corresponds
// to one cache line.
constexpr size_t kRegRows = 4;
// The MatMul result C[r,c] is Dot(A.Row(r), B.Col(c)). To reduce the number of
// loads, we reuse the same A row for several B columns, which are also loaded
// once for several rows of C. Thus we produce one 'tile' of C at a time of
// dimensions `kRegRows` x `kRegCols`. The Reg naming is because these are
// limited by the number of registers: 32 for NEON/SVE/AVX-512. `kRegCols` == 4
// enables the `StoreInterleaved4` transpose in `AddHorizontalSums`. We assume
// and verify that `C.cols % kRegCols == 0`.
constexpr size_t kRegCols = 4;
// Initializes a reg-tile of C: if kAdd, `add[add_ofs + c]`; otherwise 0.
// `add` has no scale, and if `kAdd` is a row vector with A.cols entries,
// Choosing `kRegRows == kRegCols` minimizes the ratio of loads to FMA, because
// we load `kRegCols + kRegRows` vectors per `kRegRows * kRegCols` element tile.
// In general, `batch_size` (C rows) is not a multiple of `kRegRows`. Thus
// functions that load or store a tile are parameterized on `kNumRows`, which is
// generally `kRegRows`, but `batch_size % kRegRows` on the last row (if != 0).
constexpr size_t kRegRows = kRegCols;
// NEON_BF16/SVE/AVX3_ZEN4 have instructions for bf16 * bf16 + f32 which are
// more efficient than f32 * f32 + f32 because they process twice as many lanes
// at a time. Any combination of A and B can be bf16: activations may already be
// bf16, and weights can be decompressed to bf16.
//
// The corresponding op is `ReorderWidenMulAccumulate`, and it is always
// supported, but only useful if it returns a single vector of pairwise sums
// `a[0] * b[0] + a[1] * b[1]`. On other targets, `ReorderWidenMulAccumulate`
// instead returns `a[1] * b[1]` in its `sum1` output. We cannot afford to keep
// a `sum1` for each of the `kRegRows * kRegCols` C vectors, and it would be
// expensive to add each `sum0` and `sum1`, hence we only 'decompress' A and B
// to bf16 if the native op is available. This will actually demote f32
// activations to bf16. Otherwise, we decompress to f32 and use normal FMA.
using MulT = hwy::If<HWY_NATIVE_DOT_BF16, BF16, float>;
// Loads two vectors at a time with element type MulT from a row of transposed
// B. Called in a loop over col_ab. No bounds checking because `kRow` is
// actually from B columns, which we checked is a multiple of `kRegCols`.
template <size_t kRow, typename MatTB>
class BRow {
static_assert(kRow < kRegRows); // which unrolled instance we are
using TraitsB = CompressTraits<MatTB>;
public:
BRow(const Mat<const MatTB>& B, size_t row_b)
: B_(B.ptr), B_ofs_(B.Row(row_b + kRow)) {}
template <class DM, class VM = hn::Vec<DM>>
HWY_INLINE void Load2(DM d, size_t col_ab, VM& b0, VM& b1) const {
static_assert(hwy::IsSame<hn::TFromD<DM>, MulT>());
TraitsB::Decompress2(d, B_, B_ofs_ + col_ab, b0, b1);
}
private:
const MatTB* HWY_RESTRICT B_;
const size_t B_ofs_;
};
// Loads *two* row vectors from A via `Decompress2`, multiplies element-wise
// with `kRegRows` x 2 row vectors from transposed B, and adds them to
// `kRegRows` x `kRegCols` C vectors. The lanes of `C[r,c]` are thus a subset of
// the terms of the dot products that make up the MatMul result at `r,c`.
// No-op for the bottom-most tile where kRow >= kNumRows.
//
// This approach is atypical because it requires a horizontal sum, for which we
// introduce a fast and new(?) vector-length agnostic 'transpose', see
// `AddHorizontalSums`. Most MatMul instead broadcast one element from A and
// multiply with one element from N columns in B to obtain N columns of C.
// This is a poor fit for our setting:
// - `CompressTraits` decompresses two vectors at a time;
// - B is column-major, so unit-stride SIMD loads return a column, not values
// from different columns, i.e. a row.
// Both could be fixed in a packing stage, which is not implemented yet, and
// might not be necessary otherwise. However, `ReorderWidenMulAccumulate` is
// important for bf16 performance and incompatible with the conventional
// approach, because its pairwise adds would add together unrelated terms.
// By contrast, pairwise adds are fine when our C lanes are the terms of a
// single dot product, which can be reordered or pre-reduced.
template <size_t kRow, typename MatTA>
class ALoadAccumulate {
static_assert(kRow < kRegRows); // which unrolled instance we are
using TraitsA = CompressTraits<MatTA>;
public:
ALoadAccumulate(const Mat<const MatTA>& A, size_t row_ac)
: A_(A.ptr), A_ofs_(A.Row(row_ac + kRow)) {}
// First iteration, col_ab = 0: initialize C0..3 instead of updating them.
template <size_t kNumRows, class DM, class VM = hn::Vec<DM>, HWY_IF_F32_D(DM)>
HWY_INLINE void First(DM dm, //
const VM b00, const VM b01, const VM b10, const VM b11,
const VM b20, const VM b21, const VM b30, const VM b31,
VM& C0, VM& C1, VM& C2, VM& C3) const {
static_assert(kNumRows <= kRegRows); // How many rows actually present
if constexpr (kRow < kNumRows) {
VM a0, a1;
TraitsA::Decompress2(dm, A_, A_ofs_, a0, a1);
static_assert(kRegCols == 4);
C0 = hn::Mul(a0, b00);
C1 = hn::Mul(a0, b10);
C2 = hn::Mul(a0, b20);
C3 = hn::Mul(a0, b30);
C0 = hn::MulAdd(a1, b01, C0);
C1 = hn::MulAdd(a1, b11, C1);
C2 = hn::MulAdd(a1, b21, C2);
C3 = hn::MulAdd(a1, b31, C3);
}
}
// Same as above, only called if MulT == BF16.
template <size_t kNumRows, class DM, class VM = hn::Vec<DM>,
HWY_IF_BF16_D(DM), class DF = hn::Repartition<float, DM>,
class VF = hn::Vec<DF>>
HWY_INLINE void First(DM dm, //
const VM b00, const VM b01, const VM b10, const VM b11,
const VM b20, const VM b21, const VM b30, const VM b31,
VF& C0, VF& C1, VF& C2, VF& C3) const {
static_assert(kNumRows <= kRegRows); // How many rows actually present
if constexpr (kRow < kNumRows) {
VM a0, a1;
TraitsA::Decompress2(dm, A_, A_ofs_, a0, a1);
const DF df;
VF unused_sum1 = hn::Zero(df);
static_assert(kRegCols == 4);
C0 = hn::WidenMulPairwiseAdd(df, a0, b00);
C1 = hn::WidenMulPairwiseAdd(df, a0, b10);
C2 = hn::WidenMulPairwiseAdd(df, a0, b20);
C3 = hn::WidenMulPairwiseAdd(df, a0, b30);
C0 = hn::ReorderWidenMulAccumulate(df, a1, b01, C0, unused_sum1);
C1 = hn::ReorderWidenMulAccumulate(df, a1, b11, C1, unused_sum1);
C2 = hn::ReorderWidenMulAccumulate(df, a1, b21, C2, unused_sum1);
C3 = hn::ReorderWidenMulAccumulate(df, a1, b31, C3, unused_sum1);
// Ensure sum1 was indeed unused.
HWY_DASSERT(hn::AllTrue(df, hn::Eq(unused_sum1, hn::Zero(df))));
}
}
// Non-first iteration: accumulate into C0..3.
template <size_t kNumRows, class DM, class VM = hn::Vec<DM>, HWY_IF_F32_D(DM)>
HWY_INLINE void Next(DM dm, size_t col_ab, const VM b00, const VM b01,
const VM b10, const VM b11, const VM b20, const VM b21,
const VM b30, const VM b31, VM& C0, VM& C1, VM& C2,
VM& C3) const {
static_assert(kNumRows <= kRegRows); // How many rows actually present
HWY_DASSERT(col_ab >= 2 * hn::Lanes(dm)); // Should not be first iteration.
if constexpr (kRow < kNumRows) {
VM a0, a1;
TraitsA::Decompress2(dm, A_, A_ofs_ + col_ab, a0, a1);
static_assert(kRegCols == 4);
C0 = hn::MulAdd(a0, b00, C0);
C1 = hn::MulAdd(a0, b10, C1);
C2 = hn::MulAdd(a0, b20, C2);
C3 = hn::MulAdd(a0, b30, C3);
C0 = hn::MulAdd(a1, b01, C0);
C1 = hn::MulAdd(a1, b11, C1);
C2 = hn::MulAdd(a1, b21, C2);
C3 = hn::MulAdd(a1, b31, C3);
}
}
// Same as above, only called if MulT == BF16.
template <size_t kNumRows, class DM, class VM = hn::Vec<DM>,
HWY_IF_BF16_D(DM), class DF = hn::Repartition<float, DM>,
class VF = hn::Vec<DF>>
HWY_INLINE void Next(DM dm, size_t col_ab, const VM b00, const VM b01,
const VM b10, const VM b11, const VM b20, const VM b21,
const VM b30, const VM b31, VF& C0, VF& C1, VF& C2,
VF& C3) const {
static_assert(kNumRows <= kRegRows); // How many rows actually present
HWY_DASSERT(col_ab >= 2 * hn::Lanes(dm)); // Should not be first iteration.
if constexpr (kRow < kNumRows) {
VM a0, a1;
TraitsA::Decompress2(dm, A_, A_ofs_ + col_ab, a0, a1);
const DF df;
hn::Vec<DF> unused_sum1 = hn::Zero(df);
static_assert(kRegCols == 4);
C0 = hn::ReorderWidenMulAccumulate(df, a0, b00, C0, unused_sum1);
C1 = hn::ReorderWidenMulAccumulate(df, a0, b10, C1, unused_sum1);
C2 = hn::ReorderWidenMulAccumulate(df, a0, b20, C2, unused_sum1);
C3 = hn::ReorderWidenMulAccumulate(df, a0, b30, C3, unused_sum1);
C0 = hn::ReorderWidenMulAccumulate(df, a1, b01, C0, unused_sum1);
C1 = hn::ReorderWidenMulAccumulate(df, a1, b11, C1, unused_sum1);
C2 = hn::ReorderWidenMulAccumulate(df, a1, b21, C2, unused_sum1);
C3 = hn::ReorderWidenMulAccumulate(df, a1, b31, C3, unused_sum1);
// Ensure sum1 was indeed unused.
HWY_DASSERT(hn::AllTrue(df, hn::Eq(unused_sum1, hn::Zero(df))));
}
}
private:
const MatTA* HWY_RESTRICT A_;
const size_t A_ofs_;
}; // ALoadAccumulate
// Sets a `kRegRows` x `kRegCols` tile of C to `add[add_ofs + c]` if kAdd,
// otherwise 0.
// `add` has no scale and is a row vector with A.cols entries if `kAdd`,
// otherwise nullptr. In the latter case, adding `add_ofs` to it would be UB,
// hence we pass it as a separate argument.
template <size_t kNumRows, bool kAdd>
HWY_INLINE void InitC(const float* HWY_RESTRICT add, size_t add_ofs,
float* HWY_RESTRICT pos_c, size_t stride_c) {
const hn::FixedTag<float, kRegCols> d4;
for (size_t r = 0; r < HWY_MIN(kNumRows, kRegRows); ++r) {
for (size_t c = 0; c < kRegCols; ++c) {
if constexpr (kAdd) {
pos_c[r * stride_c + c] = add[add_ofs + c];
hn::StoreU(hn::LoadU(d4, add + add_ofs), d4, pos_c + r * stride_c);
} else {
pos_c[r * stride_c + c] = 0.0f;
}
hn::StoreU(hn::Zero(d4), d4, pos_c + r * stride_c);
}
}
}
// c## are partial sums of the products of A and B; their horizontal sums are
// the final matmul result, stored in C, which is always f32.
template <size_t kNumRows, class DF, class VF = hn::Vec<DF>>
HWY_INLINE void AddHorizontalSums(DF df, float scale, //
VF c00, VF c01, VF c02, VF c03, //
VF c10, VF c11, VF c12, VF c13, //
VF c20, VF c21, VF c22, VF c23, //
VF c30, VF c31, VF c32, VF c33, //
float* HWY_RESTRICT tile_c, size_t stride_c) {
// We are computing the product of (4, 4N) * (4N, 4) = (4, 4) tiles.
// Each entry of C[r,c] is a dot product of A.row and B.col, which reside in
// the lanes of `c$r$c`, so we store their horizontal sum (ReduceSum). This is
// expensive, but only a fraction of the A.cols/N FMAs.
// TODO: 4x4 transpose, then 128-bit vector FMA?
tile_c[stride_c * 0 + 0] += scale * hn::ReduceSum(df, c00);
tile_c[stride_c * 0 + 1] += scale * hn::ReduceSum(df, c01);
tile_c[stride_c * 0 + 2] += scale * hn::ReduceSum(df, c02);
tile_c[stride_c * 0 + 3] += scale * hn::ReduceSum(df, c03);
if (kNumRows == 1) return;
tile_c[stride_c * 1 + 0] += scale * hn::ReduceSum(df, c10);
tile_c[stride_c * 1 + 1] += scale * hn::ReduceSum(df, c11);
tile_c[stride_c * 1 + 2] += scale * hn::ReduceSum(df, c12);
tile_c[stride_c * 1 + 3] += scale * hn::ReduceSum(df, c13);
if (kNumRows == 2) return;
tile_c[stride_c * 2 + 0] += scale * hn::ReduceSum(df, c20);
tile_c[stride_c * 2 + 1] += scale * hn::ReduceSum(df, c21);
tile_c[stride_c * 2 + 2] += scale * hn::ReduceSum(df, c22);
tile_c[stride_c * 2 + 3] += scale * hn::ReduceSum(df, c23);
if (kNumRows == 3) return;
tile_c[stride_c * 3 + 0] += scale * hn::ReduceSum(df, c30);
tile_c[stride_c * 3 + 1] += scale * hn::ReduceSum(df, c31);
tile_c[stride_c * 3 + 2] += scale * hn::ReduceSum(df, c32);
tile_c[stride_c * 3 + 3] += scale * hn::ReduceSum(df, c33);
}
// Wrapper to simplify call sites. T can be const or non-const.
template <typename T>
struct Mat {
bool NotEmpty() const {
return ptr != nullptr && cols != 0 && stride >= cols;
// Accumulates into a tile of C.
template <size_t kNumRows>
class AddHorizontalSums {
// These helper functions hoist if() out of the main code below. They have no
// effect if kRow >= kNumRows.
template <size_t kRow, class DF, class VF = hn::Vec<DF>>
static void MaybeStoreInterleaved4(DF df, size_t N, VF Cr0, VF Cr1, VF Cr2,
VF Cr3, float* HWY_RESTRICT buf) {
if constexpr (kRow < kNumRows) {
hn::StoreInterleaved4(Cr0, Cr1, Cr2, Cr3, df, buf + 4 * kRow * N);
}
}
size_t Row(size_t r) const { return ofs + stride * r; }
T* HWY_RESTRICT ptr;
size_t cols;
// Note: N is the number of lanes in the StoreInterleaved4 vectors, not V4.
template <size_t kRow, class D4, class V4 = hn::Vec<D4>>
static V4 MaybeLoad(D4 df, size_t N, const float* HWY_RESTRICT buf) {
if constexpr (kRow < kNumRows) {
return hn::Load(df, buf + 4 * kRow * N);
} else {
return hn::Zero(df);
}
}
// elements between rows, which is typically the same as `cols`.
size_t stride;
template <size_t kRow, class D4, class V4 = hn::Vec<D4>>
static V4 MaybeAdd(D4 df, size_t N, V4 sum, const float* HWY_RESTRICT buf) {
if constexpr (kRow < kNumRows) {
return hn::Add(sum, hn::Load(df, buf + 4 * kRow * N));
} else {
return sum;
}
}
// Offset to add to `ptr`; separate because T=NuqStream does not support
// pointer arithmetic.
size_t ofs;
template <size_t kRow, class D4, class V4 = hn::Vec<D4>>
static void MaybeMulAdd(D4 df, V4 sum, V4 scale, float* HWY_RESTRICT tile_c,
const size_t stride_c) {
if constexpr (kRow < kNumRows) {
const V4 prev_c = hn::LoadU(df, tile_c + kRow * stride_c);
hn::StoreU(hn::MulAdd(sum, scale, prev_c), df, tile_c + kRow * stride_c);
}
}
public:
// Adds the contribution from `Crc` accumulators to the 4x4 tile of C whose
// top left is `tile_c`, after multiplying by `scale`, which is the product of
// the scales of A and B. C is always f32 to ensure sufficient precision.
//
// `Crc` are the 16 combinations of an A row vector indexed by `r`, times a
// B column vector indexed by `c`. Their elements are thus a subset of the
// terms of the dot product constituting the final `C[r, c]` result. Thus we
// compute the horizontal sums of each `Crc`. The elements may be permuted
// because we multiply bf16 via `ReorderWidenMulAccumulate`, but this does
// not change their horizontal sum. `buf` is thread-local space for 16 `VF`.
template <class DF, class VF = hn::Vec<DF>>
HWY_INLINE void operator()(DF df, float scale, //
VF C00, VF C01, VF C02, VF C03, //
VF C10, VF C11, VF C12, VF C13, //
VF C20, VF C21, VF C22, VF C23, //
VF C30, VF C31, VF C32, VF C33, //
float* HWY_RESTRICT buf,
float* HWY_RESTRICT tile_c,
size_t stride_c) const {
const size_t N = hn::Lanes(df);
// Horizontal reductions (`ReduceSum`) are rather expensive, entailing
// log(N) operations for vectors of length N. Because kRegCols == 4, we can
// instead use `StoreInterleaved4` for a vector length-agnostic 'transpose':
// `buf[0, 4 * N)` holds C00[0], C01[0], C02[0], C03[0],
// C00[1], C01[1], C02[1], C03[1] .. C00[N-1], C01[N-1], C02[N-1], C03[N-1].
MaybeStoreInterleaved4<0>(df, N, C00, C01, C02, C03, buf);
MaybeStoreInterleaved4<1>(df, N, C10, C11, C12, C13, buf);
MaybeStoreInterleaved4<2>(df, N, C20, C21, C22, C23, buf);
MaybeStoreInterleaved4<3>(df, N, C30, C31, C32, C33, buf);
// Adding N consecutive V4 yields four horizontal sums of Cr0, Cr1, Cr2, Cr3
// in the elements of one V4. We have four independent rows `r`, hence the
// code is effectively unrolled, which increases throughput.
const hn::FixedTag<float, 4> d4;
using V4 = hn::Vec<decltype(d4)>;
V4 sum0 = MaybeLoad<0>(d4, N, buf);
V4 sum1 = MaybeLoad<1>(d4, N, buf);
V4 sum2 = MaybeLoad<2>(d4, N, buf);
V4 sum3 = MaybeLoad<3>(d4, N, buf);
for (size_t i = 1; i < N; ++i) {
sum0 = MaybeAdd<0>(d4, N, sum0, buf + 4 * i);
sum1 = MaybeAdd<1>(d4, N, sum1, buf + 4 * i);
sum2 = MaybeAdd<2>(d4, N, sum2, buf + 4 * i);
sum3 = MaybeAdd<3>(d4, N, sum3, buf + 4 * i);
}
// Scale, then store to four elements per row of `tile_c`.
const V4 vscale = hn::Set(d4, scale);
MaybeMulAdd<0>(d4, sum0, vscale, tile_c, stride_c);
MaybeMulAdd<1>(d4, sum1, vscale, tile_c, stride_c);
MaybeMulAdd<2>(d4, sum2, vscale, tile_c, stride_c);
MaybeMulAdd<3>(d4, sum3, vscale, tile_c, stride_c);
}
};
template <typename T>
Mat<T> MakeMat(T* HWY_RESTRICT ptr, size_t cols, size_t stride,
size_t ofs = 0) {
return Mat<T>{.ptr = ptr, .cols = cols, .stride = stride, .ofs = ofs};
}
template <typename T>
Mat<T> MakeMat(T* HWY_RESTRICT ptr, size_t cols) {
return MakeMat(ptr, cols, cols);
}
// Inner loop of the kernel, called once per kRegRows. c[r] += a[c] * b[r,c].
// The col_ab loop is unrolled 2x, so we have a0/a1 and b00/b01 etc.
template <class VF>
HWY_INLINE void UpdateTileRow(const VF& a0, const VF& a1, const VF& b00,
const VF& b01, const VF& b10, const VF& b11,
const VF& b20, const VF& b21, const VF& b30,
const VF& b31, VF& c0, VF& c1, VF& c2, VF& c3) {
c0 = hn::MulAdd(a0, b00, c0);
c1 = hn::MulAdd(a0, b10, c1);
c2 = hn::MulAdd(a0, b20, c2);
c3 = hn::MulAdd(a0, b30, c3);
c0 = hn::MulAdd(a1, b01, c0);
c1 = hn::MulAdd(a1, b11, c1);
c2 = hn::MulAdd(a1, b21, c2);
c3 = hn::MulAdd(a1, b31, c3);
}
// Special case for the first iteration: c## are zero, so skip the first add.
template <class VF>
HWY_INLINE void FirstTileRow(const VF& a0, const VF& a1, const VF& b00,
const VF& b01, const VF& b10, const VF& b11,
const VF& b20, const VF& b21, const VF& b30,
const VF& b31, VF& c0, VF& c1, VF& c2, VF& c3) {
c0 = hn::Mul(a0, b00);
c1 = hn::Mul(a0, b10);
c2 = hn::Mul(a0, b20);
c3 = hn::Mul(a0, b30);
c0 = hn::MulAdd(a1, b01, c0);
c1 = hn::MulAdd(a1, b11, c1);
c2 = hn::MulAdd(a1, b21, c2);
c3 = hn::MulAdd(a1, b31, c3);
}
#undef GEMMA_NATIVE_BF16
#if HWY_IDE || (defined(HWY_NATIVE_REORDER_WIDEN_MUL_ACC_BF16) == \
defined(HWY_TARGET_TOGGLE))
#define GEMMA_NATIVE_BF16 1
#else
#define GEMMA_NATIVE_BF16 0
#endif
#if GEMMA_NATIVE_BF16
// Specializations for f32 += bf16 * bf16 that avoid promoting to f32.
// Inner loop as above, but not unrolled. c[r] += a * b[r].
template <class DF, class VF = hn::Vec<DF>,
class VBF16 = hn::Vec<hn::Repartition<BF16, DF>>>
HWY_INLINE void UpdateTileRow(DF df, const VBF16& a, const VBF16& b0,
const VBF16& b1, const VBF16& b2, const VBF16& b3,
VF& c0, VF& c1, VF& c2, VF& c3) {
VF unused_sum1 = hn::Zero(df);
c0 = hn::ReorderWidenMulAccumulate(df, a, b0, c0, unused_sum1);
c1 = hn::ReorderWidenMulAccumulate(df, a, b1, c1, unused_sum1);
c2 = hn::ReorderWidenMulAccumulate(df, a, b2, c2, unused_sum1);
c3 = hn::ReorderWidenMulAccumulate(df, a, b3, c3, unused_sum1);
// Ensure sum1 was indeed unused.
HWY_DASSERT(hn::AllTrue(df, hn::Eq(unused_sum1, hn::Zero(df))));
}
// Special case for the first iteration: c## are zero, so skip the first add.
template <class DF, class VF = hn::Vec<DF>,
class VBF16 = hn::Vec<hn::Repartition<BF16, DF>>>
HWY_INLINE void FirstTileRow(DF df, const VBF16& a, const VBF16& b0,
const VBF16& b1, const VBF16& b2, const VBF16& b3,
VF& c0, VF& c1, VF& c2, VF& c3) {
c0 = hn::WidenMulPairwiseAdd(df, a, b0);
c1 = hn::WidenMulPairwiseAdd(df, a, b1);
c2 = hn::WidenMulPairwiseAdd(df, a, b2);
c3 = hn::WidenMulPairwiseAdd(df, a, b3);
}
template <size_t kNumRows, bool kAdd>
HWY_INLINE void MatMulTile(const Mat<const BF16>& A, const Mat<const BF16>& B,
const size_t row_ac, const size_t row_b_col_c,
const float scale, const float* HWY_RESTRICT add,
const Mat<float>& C) {
const hn::ScalableTag<float> df;
using VF = hn::Vec<decltype(df)>;
// ReorderWidenMulAccumulate does not use its sum1 arg and we can use full
// bf16 vectors.
const hn::Repartition<BF16, decltype(df)> d;
const size_t N = Lanes(d);
using V = hn::Vec<decltype(d)>;
V b0, b1, b2, b3; // one from each row
VF c00, c01, c02, c03;
VF c10, c11, c12, c13;
VF c20, c21, c22, c23;
VF c30, c31, c32, c33;
const BF16* HWY_RESTRICT A_tile = A.ptr + A.Row(row_ac);
const BF16* HWY_RESTRICT B_tile = B.ptr + B.Row(row_b_col_c);
float* HWY_RESTRICT C_tile = C.ptr + C.Row(row_ac) + row_b_col_c;
InitC<kNumRows, kAdd>(add, row_b_col_c, C_tile, C.stride);
size_t col_ab = 0;
// First iteration initializes the c## vectors.
{
b0 = hn::LoadU(d, B_tile + B.stride * 0 + col_ab);
b1 = hn::LoadU(d, B_tile + B.stride * 1 + col_ab);
b2 = hn::LoadU(d, B_tile + B.stride * 2 + col_ab);
b3 = hn::LoadU(d, B_tile + B.stride * 3 + col_ab);
{
const V a = hn::LoadU(d, A_tile + A.stride * 0 + col_ab);
FirstTileRow(df, a, b0, b1, b2, b3, c00, c01, c02, c03);
}
if constexpr (kNumRows > 1) {
const V a = hn::LoadU(d, A_tile + A.stride * 1 + col_ab);
FirstTileRow(df, a, b0, b1, b2, b3, c10, c11, c12, c13);
}
if constexpr (kNumRows > 2) {
const V a = hn::LoadU(d, A_tile + A.stride * 2 + col_ab);
FirstTileRow(df, a, b0, b1, b2, b3, c20, c21, c22, c23);
}
if constexpr (kNumRows > 3) {
const V a = hn::LoadU(d, A_tile + A.stride * 3 + col_ab);
FirstTileRow(df, a, b0, b1, b2, b3, c30, c31, c32, c33);
}
}
// Loop over columns of A and columns of the transposed B, in steps of N.
// Accumulates into the c## vectors.
HWY_UNROLL(1)
for (col_ab += N; col_ab < A.cols; col_ab += N) {
b0 = hn::LoadU(d, B_tile + B.stride * 0 + col_ab);
b1 = hn::LoadU(d, B_tile + B.stride * 1 + col_ab);
b2 = hn::LoadU(d, B_tile + B.stride * 2 + col_ab);
b3 = hn::LoadU(d, B_tile + B.stride * 3 + col_ab);
{
const V a = hn::LoadU(d, A_tile + A.stride * 0 + col_ab);
UpdateTileRow(df, a, b0, b1, b2, b3, c00, c01, c02, c03);
}
if constexpr (kNumRows > 1) {
const V a = hn::LoadU(d, A_tile + A.stride * 1 + col_ab);
UpdateTileRow(df, a, b0, b1, b2, b3, c10, c11, c12, c13);
}
if constexpr (kNumRows > 2) {
const V a = hn::LoadU(d, A_tile + A.stride * 2 + col_ab);
UpdateTileRow(df, a, b0, b1, b2, b3, c20, c21, c22, c23);
}
if constexpr (kNumRows > 3) {
const V a = hn::LoadU(d, A_tile + A.stride * 3 + col_ab);
UpdateTileRow(df, a, b0, b1, b2, b3, c30, c31, c32, c33);
}
}
AddHorizontalSums<kNumRows>(df, scale, c00, c01, c02, c03, c10, c11, c12, c13,
c20, c21, c22, c23, c30, c31, c32, c33, C_tile,
C.stride);
}
#endif // GEMMA_NATIVE_BF16
// Streams a `(kNumRows, 4)` strip of `A` and the transposed `B`, then writes a
// finished tile of `C`.
// General case: uses CompressTraits to load from A and B.
// Streams a `kNumRows` high strip of `A` and the transposed `B`, then writes a
// *finished* tile of f32 `C` whose top left is (row_ac, row_b_col_c).
// TODO: loop over sections instead of full rows and accumulate into `tile_c`.
template <size_t kNumRows, bool kAdd, typename MatTA, typename MatTB>
HWY_INLINE void MatMulTile(const Mat<MatTA>& A, const Mat<MatTB>& B,
HWY_INLINE void MatMulTile(const Mat<const MatTA>& A, const Mat<const MatTB>& B,
const size_t row_ac, const size_t row_b_col_c,
const float scale, const float* HWY_RESTRICT add,
const Mat<float>& C) {
using TraitsA = CompressTraits<hwy::RemoveConst<MatTA>>;
using TraitsB = CompressTraits<hwy::RemoveConst<MatTB>>;
float* HWY_RESTRICT buf, const Mat<float>& C) {
// For 'decompressing' A and B into BF16 or float.
const hn::ScalableTag<MulT> dm;
using VM = hn::Vec<decltype(dm)>;
const size_t NM = hn::Lanes(dm);
const hn::ScalableTag<float> d32;
const size_t N = hn::Lanes(d32);
using V = hn::Vec<decltype(d32)>;
V b00, b01, b10, b11, b20, b21, b30, b31; // two from each row
V c00, c01, c02, c03;
V c10, c11, c12, c13;
V c20, c21, c22, c23;
V c30, c31, c32, c33;
static_assert(kRegRows == 4);
const BRow<0, MatTB> b_row0(B, row_b_col_c);
const BRow<1, MatTB> b_row1(B, row_b_col_c);
const BRow<2, MatTB> b_row2(B, row_b_col_c);
const BRow<3, MatTB> b_row3(B, row_b_col_c);
const size_t A_ofs = A.Row(row_ac);
const size_t B_ofs = B.Row(row_b_col_c);
const ALoadAccumulate<0, MatTA> a_row0(A, row_ac);
const ALoadAccumulate<1, MatTA> a_row1(A, row_ac);
const ALoadAccumulate<2, MatTA> a_row2(A, row_ac);
const ALoadAccumulate<3, MatTA> a_row3(A, row_ac);
const hn::Repartition<float, decltype(dm)> df;
using VF = hn::Vec<decltype(df)>;
VF C00, C01, C02, C03;
VF C10, C11, C12, C13;
VF C20, C21, C22, C23;
VF C30, C31, C32, C33;
{ // First iteration initializes the `Crc` vectors.
VM b00, b01, b10, b11, b20, b21, b30, b31;
b_row0.Load2(dm, /*col_ab=*/0, b00, b01);
b_row1.Load2(dm, /*col_ab=*/0, b10, b11);
b_row2.Load2(dm, /*col_ab=*/0, b20, b21);
b_row3.Load2(dm, /*col_ab=*/0, b30, b31);
a_row0.template First<kNumRows>(dm, b00, b01, b10, b11, b20, b21, b30, b31,
C00, C01, C02, C03);
a_row1.template First<kNumRows>(dm, b00, b01, b10, b11, b20, b21, b30, b31,
C10, C11, C12, C13);
a_row2.template First<kNumRows>(dm, b00, b01, b10, b11, b20, b21, b30, b31,
C20, C21, C22, C23);
a_row3.template First<kNumRows>(dm, b00, b01, b10, b11, b20, b21, b30, b31,
C30, C31, C32, C33);
}
// `2 * NM` per iteration because `Load2` returns two vectors.
HWY_UNROLL(1)
for (size_t col_ab = 2 * NM; col_ab <= A.cols - 2 * NM; col_ab += 2 * NM) {
VM b00, b01, b10, b11, b20, b21, b30, b31;
b_row0.Load2(dm, col_ab, b00, b01);
b_row1.Load2(dm, col_ab, b10, b11);
b_row2.Load2(dm, col_ab, b20, b21);
b_row3.Load2(dm, col_ab, b30, b31);
a_row0.template Next<kNumRows>(dm, col_ab, b00, b01, b10, b11, b20, b21,
b30, b31, C00, C01, C02, C03);
a_row1.template Next<kNumRows>(dm, col_ab, b00, b01, b10, b11, b20, b21,
b30, b31, C10, C11, C12, C13);
a_row2.template Next<kNumRows>(dm, col_ab, b00, b01, b10, b11, b20, b21,
b30, b31, C20, C21, C22, C23);
a_row3.template Next<kNumRows>(dm, col_ab, b00, b01, b10, b11, b20, b21,
b30, b31, C30, C31, C32, C33);
}
// TODO: hoist into outer loop.
float* HWY_RESTRICT C_tile = C.ptr + C.Row(row_ac) + row_b_col_c;
InitC<kNumRows, kAdd>(add, row_b_col_c, C_tile, C.stride);
// Loop over columns of A and columns of the transposed B, in steps of 2*N
// (since we are decoding consecutive bytes at each iteration).
// Top-left of tile is (row_ac, col_ab) for A, and (row_b_col_c,
// col_ab) for B. First iteration initializes the c## vectors.
size_t col_ab = 0;
{
TraitsB::Decompress2(d32, B.ptr, B_ofs + B.stride * 0 + col_ab, b00, b01);
TraitsB::Decompress2(d32, B.ptr, B_ofs + B.stride * 1 + col_ab, b10, b11);
TraitsB::Decompress2(d32, B.ptr, B_ofs + B.stride * 2 + col_ab, b20, b21);
TraitsB::Decompress2(d32, B.ptr, B_ofs + B.stride * 3 + col_ab, b30, b31);
{
V a0, a1;
TraitsA::Decompress2(d32, A.ptr, A_ofs + A.stride * 0 + col_ab, a0, a1);
FirstTileRow(a0, a1, b00, b01, b10, b11, b20, b21, b30, b31, c00, c01,
c02, c03);
}
if constexpr (kNumRows > 1) {
V a0, a1;
TraitsA::Decompress2(d32, A.ptr, A_ofs + A.stride * 1 + col_ab, a0, a1);
FirstTileRow(a0, a1, b00, b01, b10, b11, b20, b21, b30, b31, c10, c11,
c12, c13);
}
if constexpr (kNumRows > 2) {
V a0, a1;
TraitsA::Decompress2(d32, A.ptr, A_ofs + A.stride * 2 + col_ab, a0, a1);
FirstTileRow(a0, a1, b00, b01, b10, b11, b20, b21, b30, b31, c20, c21,
c22, c23);
}
if constexpr (kNumRows > 3) {
V a0, a1;
TraitsA::Decompress2(d32, A.ptr, A_ofs + A.stride * 3 + col_ab, a0, a1);
FirstTileRow(a0, a1, b00, b01, b10, b11, b20, b21, b30, b31, c30, c31,
c32, c33);
}
}
// Main loop: accumulates into the c## vectors.
HWY_UNROLL(1)
for (col_ab += 2 * N; col_ab <= A.cols - 2 * N; col_ab += 2 * N) {
TraitsB::Decompress2(d32, B.ptr, B_ofs + B.stride * 0 + col_ab, b00, b01);
TraitsB::Decompress2(d32, B.ptr, B_ofs + B.stride * 1 + col_ab, b10, b11);
TraitsB::Decompress2(d32, B.ptr, B_ofs + B.stride * 2 + col_ab, b20, b21);
TraitsB::Decompress2(d32, B.ptr, B_ofs + B.stride * 3 + col_ab, b30, b31);
{
V a0, a1;
TraitsA::Decompress2(d32, A.ptr, A_ofs + A.stride * 0 + col_ab, a0, a1);
UpdateTileRow(a0, a1, b00, b01, b10, b11, b20, b21, b30, b31, c00, c01,
c02, c03);
}
if constexpr (kNumRows > 1) {
V a0, a1;
TraitsA::Decompress2(d32, A.ptr, A_ofs + A.stride * 1 + col_ab, a0, a1);
UpdateTileRow(a0, a1, b00, b01, b10, b11, b20, b21, b30, b31, c10, c11,
c12, c13);
}
if constexpr (kNumRows > 2) {
V a0, a1;
TraitsA::Decompress2(d32, A.ptr, A_ofs + A.stride * 2 + col_ab, a0, a1);
UpdateTileRow(a0, a1, b00, b01, b10, b11, b20, b21, b30, b31, c20, c21,
c22, c23);
}
if constexpr (kNumRows > 3) {
V a0, a1;
TraitsA::Decompress2(d32, A.ptr, A_ofs + A.stride * 3 + col_ab, a0, a1);
UpdateTileRow(a0, a1, b00, b01, b10, b11, b20, b21, b30, b31, c30, c31,
c32, c33);
}
}
AddHorizontalSums<kNumRows>(d32, scale, c00, c01, c02, c03, c10, c11, c12,
c13, c20, c21, c22, c23, c30, c31, c32, c33,
C_tile, C.stride);
AddHorizontalSums<kNumRows>()(df, scale, C00, C01, C02, C03, C10, C11, C12,
C13, C20, C21, C22, C23, C30, C31, C32, C33,
buf, C_tile, C.stride);
}
// Computes the matrix product `A * B * scale [+ add]` and stores it in `C`.
@ -402,28 +432,28 @@ HWY_INLINE void MatMulTile(const Mat<MatTA>& A, const Mat<MatTB>& B,
//
// `scale` allows expanding the smaller range of `SfpStream` to the original
// values. When `A` and/or `B` are from CompressedArray, `scale` should be the
// product of their `.scale()` values.
// product of their `.scale()` values, otherwise 1.0f.
//
// If `kAdd` is true, the row-vector `add` is added to each row of `C`,
// otherwise `add` is ignored and can be nullptr. A scale for `add` is not
// supported, so ensure the scale of `add` is 1.
//
// `C` is a row-major matrix of size `(batch_size, C.cols)`.
// Writes 4x4 tiles of C in parallel using a work-stealing thread pool.
// Typically batch_size is 1..512, A.cols and C.cols are 3k or 24k.
//
// Updates 4x4 tiles of C in parallel using a work-stealing thread pool.
// Typically `batch_size` is 1..512, `A.cols` and `C.cols` are 3k or 24k.
// Must not be called concurrently with the same `env`.
template <bool kAdd, typename MatTA, typename MatTB>
HWY_NOINLINE void MatMul_4x4(const size_t batch_size, const Mat<MatTA>& A,
const Mat<MatTB>& B, const float scale,
const float* HWY_RESTRICT add, const Mat<float>& C,
hwy::ThreadPool& pool) {
HWY_NOINLINE void MatMul(const size_t batch_size, const Mat<const MatTA>& A,
const Mat<const MatTB>& B, const float scale,
const float* HWY_RESTRICT add, MatMulEnv& env,
const Mat<float>& C) {
// PROFILER_ZONE("Matmul");
HWY_DASSERT(A.NotEmpty() && B.NotEmpty() && C.NotEmpty());
HWY_DASSERT(A.cols == B.cols);
// Use float instead of MatTA/MatTB because we decompress to float here.
const size_t N = hn::Lanes(hn::ScalableTag<float>());
(void)N;
HWY_DASSERT(A.cols % (N * 2) == 0); // For Decompress2.
// Must be a multiple of two vectors because we Decompress2.
HWY_DASSERT(A.cols % (2 * hn::Lanes(hn::ScalableTag<MulT>())) == 0);
HWY_DASSERT(C.cols % kRegCols == 0);
// We currently write C directly, which touches more memory than fits in L3.
@ -431,8 +461,10 @@ HWY_NOINLINE void MatMul_4x4(const size_t batch_size, const Mat<MatTA>& A,
const size_t tilesY = hwy::DivCeil(batch_size, kRegRows);
const size_t tilesX = C.cols / kRegCols;
pool.Run(0, tilesX * tilesY,
[&](const uint64_t idx_tile, size_t /*thread*/) HWY_ATTR {
env.Pool().Run(
0, tilesX * tilesY, [&](const uint64_t idx_tile, size_t thread) HWY_ATTR {
// TODO: when using PerClusterPools, compute lp from outer and inner.
float* HWY_RESTRICT buf = env.Buf(thread);
const size_t tx = idx_tile % tilesX;
const size_t ty = idx_tile / tilesX;
const size_t row_ac = ty * kRegRows;
@ -443,16 +475,16 @@ HWY_NOINLINE void MatMul_4x4(const size_t batch_size, const Mat<MatTA>& A,
HWY_DASSERT(num_rows != 0);
switch (num_rows) {
case 1:
MatMulTile<1, kAdd>(A, B, row_ac, row_b_col_c, scale, add, C);
MatMulTile<1, kAdd>(A, B, row_ac, row_b_col_c, scale, add, buf, C);
break;
case 2:
MatMulTile<2, kAdd>(A, B, row_ac, row_b_col_c, scale, add, C);
MatMulTile<2, kAdd>(A, B, row_ac, row_b_col_c, scale, add, buf, C);
break;
case 3:
MatMulTile<3, kAdd>(A, B, row_ac, row_b_col_c, scale, add, C);
MatMulTile<3, kAdd>(A, B, row_ac, row_b_col_c, scale, add, buf, C);
break;
default:
MatMulTile<4, kAdd>(A, B, row_ac, row_b_col_c, scale, add, C);
MatMulTile<4, kAdd>(A, B, row_ac, row_b_col_c, scale, add, buf, C);
}
});
}
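
For readers skimming the diff, a minimal call-site sketch of the new convention (ConstMat for inputs, MutableMat for the output, MatMulEnv instead of a bare pool). This assumes SIMD-namespace context as in the test below; sizes and names are illustrative only:

```
// Sketch: A is (batch_size x cols_ab), B is stored transposed, C is f32.
const size_t batch_size = 4, cols_ab = 128, cols_c = 32;
auto a = hwy::AllocateAligned<float>(batch_size * cols_ab);
auto bt = hwy::AllocateAligned<float>(cols_c * cols_ab);  // B, transposed
auto c = hwy::AllocateAligned<float>(batch_size * cols_c);
// ... fill a and bt ...

PerClusterPools pools(/*max_clusters=*/1, /*max_threads=*/4, /*pin=*/1);
MatMulEnv env(pools);

// scale = 1.0f because neither input comes from a CompressedArray here;
// add = nullptr because kAdd is false.
MatMul</*kAdd=*/false>(batch_size, ConstMat(a.get(), cols_ab),
                       ConstMat(bt.get(), cols_ab), /*scale=*/1.0f,
                       /*add=*/nullptr, env, MutableMat(c.get(), cols_c));
```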

ops/matmul.h (new file, 97 lines)

@ -0,0 +1,97 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_OPS_MATMUL_H_
#define THIRD_PARTY_GEMMA_CPP_OPS_MATMUL_H_
#include <stddef.h>
#include "util/allocator.h" // RowVectorBatch
#include "util/threading.h" // PerClusterPools
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/per_target.h"
namespace gcpp {
// Bundles ptr/size/stride arguments to simplify MatMul call sites. T can be
// const or non-const. Create via ConstMat/MutableMat.
template <typename T>
struct Mat {
bool NotEmpty() const {
return ptr != nullptr && cols != 0 && stride >= cols;
}
size_t Row(size_t r) const { return ofs + stride * r; }
T* HWY_RESTRICT ptr;
size_t cols;
// Number of elements between the starts of consecutive rows; typically
// equal to `cols`.
size_t stride;
// Offset to add to `ptr`; separate because T=NuqStream does not support
// pointer arithmetic.
size_t ofs;
};
template <typename T>
Mat<T> MutableMat(T* HWY_RESTRICT ptr, size_t cols, size_t stride,
size_t ofs = 0) {
return Mat<T>{.ptr = ptr, .cols = cols, .stride = stride, .ofs = ofs};
}
template <typename T>
Mat<const T> ConstMat(const T* HWY_RESTRICT ptr, size_t cols, size_t stride,
size_t ofs = 0) {
return Mat<const T>{.ptr = ptr, .cols = cols, .stride = stride, .ofs = ofs};
}
template <typename T>
Mat<const T> ConstMat(Mat<T> mat) {
return ConstMat(mat.ptr, mat.cols, mat.stride, mat.ofs);
}
template <typename T>
Mat<T> MutableMat(T* HWY_RESTRICT ptr, size_t cols) {
return MutableMat(ptr, cols, cols);
}
template <typename T>
Mat<const T> ConstMat(const T* HWY_RESTRICT ptr, size_t cols) {
return ConstMat(ptr, cols, cols);
}
// Allocations and threads, shared across MatMul calls.
class MatMulEnv {
public:
MatMulEnv() : pools_(nullptr) {}
explicit MatMulEnv(PerClusterPools& pools) : pools_(&pools) {
const size_t num_lp = pools.NumLP();
const size_t NF = hwy::VectorBytes() / sizeof(float);
buf_ = RowVectorBatch<float>(num_lp, 16 * NF);
}
float* HWY_RESTRICT Buf(size_t lp) { return buf_.Batch(lp); }
PerClusterPools& Pools() const { return *pools_; }
hwy::ThreadPool& Pool() const { return pools_->Inner(0); }
private:
RowVectorBatch<float> buf_;
PerClusterPools* pools_;
};
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_OPS_MATMUL_H_
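
To make `stride` and `ofs` concrete, a small sketch (values made up) of viewing a row range of a larger row-major buffer without copying:

```
// 8 x 256 row-major buffer; view rows 2..5 of it as a Mat.
// stride stays 256 (the full row width); ofs skips the first two rows.
float buffer[8 * 256] = {};
Mat<const float> view =
    ConstMat(buffer, /*cols=*/256, /*stride=*/256, /*ofs=*/2 * 256);
// Element (r, c) of the view is at view.ptr[view.Row(r) + c]
// = buffer[ofs + stride * r + c]; e.g. (1, 7) maps to buffer row 3, col 7.
const float x = view.ptr[view.Row(1) + 7];
(void)x;
```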


@ -24,6 +24,7 @@
#include <memory>
#include "compression/compress.h"
#include "util/threading.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
@ -35,9 +36,10 @@
// clang-format on
#include "hwy/foreach_target.h" // IWYU pragma: keep
#include "hwy/highway.h"
#include "hwy/tests/test_util-inl.h"
// After highway.h
#include "compression/compress-inl.h"
#include "ops/matmul-inl.h"
#include "hwy/tests/test_util-inl.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {
@ -149,7 +151,7 @@ void AssertClose(size_t rows_ac, size_t cols_ab, size_t cols_c_rows_b,
const double norm = MaxColAbsSum(a.get(), rows_ac, cols_ab) *
MaxColAbsSum(b_trans.get(), cols_c_rows_b, cols_ab);
const double epsilon = hwy::ConvertScalarTo<double>(hwy::Epsilon<float>());
const double tolerance = 50.0 * norm * epsilon;
const double tolerance = 200.0 * norm * epsilon;
for (size_t idx = 0; idx < num_c; idx++) {
const double expected_value = expected_c[idx];
@ -157,8 +159,10 @@ void AssertClose(size_t rows_ac, size_t cols_ab, size_t cols_c_rows_b,
if (!(expected_value - tolerance <= actual_value &&
actual_value <= expected_value + tolerance)) {
fprintf(stderr, "expected[%lu]: %f, actual[%lu]: %f\n", idx,
expected_value, idx, actual_value);
fprintf(
stderr,
"expected[%lu]: %f, actual[%lu]: %f, norm %f eps %E tolerance %f\n",
idx, expected_value, idx, actual_value, norm, epsilon, tolerance);
HWY_ASSERT(0);
}
}
@ -202,14 +206,15 @@ HWY_INLINE void MatMulSlow(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc,
void PrintSpeed(const char* algo, size_t rows_ac, size_t cols_a_rows_b,
size_t cols_bc, double elapsed) {
// 2 because of FMA.
fprintf(stderr, "%s: %f seconds, %f GFLOPS.\n", algo, elapsed,
2E-9 * rows_ac * cols_a_rows_b * cols_bc / elapsed);
// 2x because of FMA.
fprintf(stderr, " %10s: %f seconds, %.1f GFLOPS.\n", algo,
elapsed, 2 * 1E-9 * rows_ac * cols_a_rows_b * cols_bc / elapsed);
}
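
As a quick sanity check of the reported figure, using the dimensions of the large TestMatMul case later in this file:

```
// 2 FLOPs per multiply-accumulate (FMA).
// M=64, K=24576, N=3072  =>  2 * 64 * 24576 * 3072 ≈ 9.66e9 FLOPs,
// so PrintSpeed reports roughly 9.66 / elapsed_seconds GFLOPS for that case.
```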
template <size_t kRowsAC, size_t kColsARowsB, size_t kColsBC, bool kAdd,
typename MatTA, typename MatTB = MatTA>
void TestMatMul(hwy::ThreadPool& pool) {
void TestMatMul(MatMulEnv& env) {
hwy::ThreadPool& pool = env.Pool();
using TraitsA = CompressTraits<MatTA>;
using TraitsB = CompressTraits<MatTB>;
const bool want_bench = kColsBC > 2000; // avoid spam for small matrices
@ -247,14 +252,14 @@ void TestMatMul(hwy::ThreadPool& pool) {
double min_elapsed = hwy::HighestValue<double>();
for (int rep = 0; rep < (want_bench ? 3 : 1); ++rep) {
const double start_tiled = hwy::platform::Now();
MatMul_4x4<kAdd>(kRowsAC, MakeMat(a->data(), kColsARowsB),
MakeMat(b_trans->data(), kColsARowsB), scale,
kAdd ? add->data_scale1() : nullptr,
MakeMat(c.get(), kColsBC), pool);
MatMul<kAdd>(kRowsAC, ConstMat(a->data(), kColsARowsB),
ConstMat(b_trans->data(), kColsARowsB), scale,
kAdd ? add->data_scale1() : nullptr, env,
MutableMat(c.get(), kColsBC));
min_elapsed = HWY_MIN(min_elapsed, hwy::platform::Now() - start_tiled);
}
if (want_bench) {
PrintSpeed("MatMul_4x4", kRowsAC, kColsARowsB, kColsBC, min_elapsed);
PrintSpeed("MatMul", kRowsAC, kColsARowsB, kColsBC, min_elapsed);
}
AssertClose(kRowsAC, kColsARowsB, kColsBC, a->data(), b_trans->data(),
@ -268,53 +273,56 @@ void TestAllMatMul() {
return;
}
hwy::ThreadPool pool(4);
PerClusterPools pools(/*max_clusters=*/1, /*max_threads=*/4, /*pin=*/1);
MatMulEnv env(pools);
using F32 = float;
using SFP = SfpStream;
// large-scale test
TestMatMul<64, 24576, 3072, /*kAdd=*/false, F32, SFP>(pool);
TestMatMul<64, 3072, 24576, /*kAdd=*/false, F32, SFP>(pool);
TestMatMul<64, 24576, 3072, /*kAdd=*/false, BF16, SFP>(env);
TestMatMul<64, 3072, 24576, /*kAdd=*/false, BF16, SFP>(env);
TestMatMul<64, 24576, 3072, /*kAdd=*/false, F32, SFP>(env);
TestMatMul<64, 3072, 24576, /*kAdd=*/false, F32, SFP>(env);
// medium-sized square test
TestMatMul<512, 512, 512, /*kAdd=*/false, F32>(pool);
TestMatMul<512, 512, 512, /*kAdd=*/true, BF16>(pool);
TestMatMul<512, 512, 512, /*kAdd=*/false, F32, BF16>(pool);
TestMatMul<512, 512, 512, /*kAdd=*/true, BF16, F32>(pool);
TestMatMul<512, 512, 512, /*kAdd=*/false, F32, SFP>(pool);
TestMatMul<512, 512, 512, /*kAdd=*/true, BF16, SFP>(pool);
TestMatMul<512, 512, 512, /*kAdd=*/false, F32>(env);
TestMatMul<512, 512, 512, /*kAdd=*/true, BF16>(env);
TestMatMul<512, 512, 512, /*kAdd=*/false, F32, BF16>(env);
TestMatMul<512, 512, 512, /*kAdd=*/true, BF16, F32>(env);
TestMatMul<512, 512, 512, /*kAdd=*/false, F32, SFP>(env);
TestMatMul<512, 512, 512, /*kAdd=*/true, BF16, SFP>(env);
// minimal non-square test. kColsARowsB must be at least 2 vectors.
TestMatMul<35, 128, 32, /*kAdd=*/false, F32>(pool);
TestMatMul<34, 128, 32, /*kAdd=*/true, BF16>(pool);
TestMatMul<33, 128, 32, /*kAdd=*/false, F32, BF16>(pool);
TestMatMul<33, 128, 32, /*kAdd=*/true, BF16, F32>(pool);
TestMatMul<31, 128, 32, /*kAdd=*/false, F32, SFP>(pool);
TestMatMul<29, 128, 32, /*kAdd=*/true, BF16, SFP>(pool);
TestMatMul<4, 128, 32, /*kAdd=*/true, F32>(pool);
TestMatMul<4, 128, 32, /*kAdd=*/false, BF16>(pool);
TestMatMul<4, 128, 32, /*kAdd=*/true, F32, BF16>(pool);
TestMatMul<4, 128, 32, /*kAdd=*/false, BF16, F32>(pool);
TestMatMul<4, 128, 32, /*kAdd=*/true, F32, SFP>(pool);
TestMatMul<4, 128, 32, /*kAdd=*/false, BF16, SFP>(pool);
TestMatMul<3, 128, 32, /*kAdd=*/false, F32>(pool);
TestMatMul<3, 128, 32, /*kAdd=*/true, BF16>(pool);
TestMatMul<3, 128, 32, /*kAdd=*/false, F32, BF16>(pool);
TestMatMul<3, 128, 32, /*kAdd=*/true, BF16, F32>(pool);
TestMatMul<3, 128, 32, /*kAdd=*/false, F32, SFP>(pool);
TestMatMul<3, 128, 32, /*kAdd=*/true, BF16, SFP>(pool);
TestMatMul<2, 128, 64, /*kAdd=*/true, F32>(pool);
TestMatMul<2, 128, 64, /*kAdd=*/false, BF16>(pool);
TestMatMul<2, 128, 64, /*kAdd=*/true, F32, BF16>(pool);
TestMatMul<2, 128, 64, /*kAdd=*/false, BF16, F32>(pool);
TestMatMul<2, 128, 64, /*kAdd=*/true, F32, SFP>(pool);
TestMatMul<2, 128, 64, /*kAdd=*/false, BF16, SFP>(pool);
TestMatMul<1, 128, 32, /*kAdd=*/false, F32>(pool);
TestMatMul<1, 128, 32, /*kAdd=*/true, BF16>(pool);
TestMatMul<1, 128, 32, /*kAdd=*/false, F32, BF16>(pool);
TestMatMul<1, 128, 32, /*kAdd=*/true, BF16, F32>(pool);
TestMatMul<1, 128, 32, /*kAdd=*/false, F32, SFP>(pool);
TestMatMul<1, 128, 32, /*kAdd=*/true, BF16, SFP>(pool);
TestMatMul<35, 128, 32, /*kAdd=*/false, F32>(env);
TestMatMul<34, 128, 32, /*kAdd=*/true, BF16>(env);
TestMatMul<33, 128, 32, /*kAdd=*/false, F32, BF16>(env);
TestMatMul<33, 128, 32, /*kAdd=*/true, BF16, F32>(env);
TestMatMul<31, 128, 32, /*kAdd=*/false, F32, SFP>(env);
TestMatMul<29, 128, 32, /*kAdd=*/true, BF16, SFP>(env);
TestMatMul<4, 128, 32, /*kAdd=*/true, F32>(env);
TestMatMul<4, 128, 32, /*kAdd=*/false, BF16>(env);
TestMatMul<4, 128, 32, /*kAdd=*/true, F32, BF16>(env);
TestMatMul<4, 128, 32, /*kAdd=*/false, BF16, F32>(env);
TestMatMul<4, 128, 32, /*kAdd=*/true, F32, SFP>(env);
TestMatMul<4, 128, 32, /*kAdd=*/false, BF16, SFP>(env);
TestMatMul<3, 128, 32, /*kAdd=*/false, F32>(env);
TestMatMul<3, 128, 32, /*kAdd=*/true, BF16>(env);
TestMatMul<3, 128, 32, /*kAdd=*/false, F32, BF16>(env);
TestMatMul<3, 128, 32, /*kAdd=*/true, BF16, F32>(env);
TestMatMul<3, 128, 32, /*kAdd=*/false, F32, SFP>(env);
TestMatMul<3, 128, 32, /*kAdd=*/true, BF16, SFP>(env);
TestMatMul<2, 128, 64, /*kAdd=*/true, F32>(env);
TestMatMul<2, 128, 64, /*kAdd=*/false, BF16>(env);
TestMatMul<2, 128, 64, /*kAdd=*/true, F32, BF16>(env);
TestMatMul<2, 128, 64, /*kAdd=*/false, BF16, F32>(env);
TestMatMul<2, 128, 64, /*kAdd=*/true, F32, SFP>(env);
TestMatMul<2, 128, 64, /*kAdd=*/false, BF16, SFP>(env);
TestMatMul<1, 128, 32, /*kAdd=*/false, F32>(env);
TestMatMul<1, 128, 32, /*kAdd=*/true, BF16>(env);
TestMatMul<1, 128, 32, /*kAdd=*/false, F32, BF16>(env);
TestMatMul<1, 128, 32, /*kAdd=*/true, BF16, F32>(env);
TestMatMul<1, 128, 32, /*kAdd=*/false, F32, SFP>(env);
TestMatMul<1, 128, 32, /*kAdd=*/true, BF16, SFP>(env);
}
// NOLINTNEXTLINE(google-readability-namespace-comments)

util/allocator.h (new file, 75 lines)

@ -0,0 +1,75 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_UTIL_ALLOCATOR_H_
#define THIRD_PARTY_GEMMA_CPP_UTIL_ALLOCATOR_H_
#include <stddef.h>
#include <stdint.h>
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
namespace gcpp {
using ByteStorageT = hwy::AlignedFreeUniquePtr<uint8_t[]>;
template <typename T>
ByteStorageT AllocateSizeof() {
return hwy::AllocateAligned<uint8_t>(sizeof(T));
}
// Owns dynamically-allocated aligned memory for a batch of row vectors.
// This can be seen as a (batch_size x len) matrix.
template <typename T>
class RowVectorBatch {
public:
// Default ctor for Activations ctor.
RowVectorBatch() : batch_size_(0), len_(0) {}
// Main ctor, called from Activations::Allocate.
RowVectorBatch(size_t batch_size, size_t len)
: batch_size_(batch_size), len_(len) {
mem_ = hwy::AllocateAligned<T>(batch_size * len);
}
// Move-only
RowVectorBatch(RowVectorBatch&) noexcept = delete;
RowVectorBatch& operator=(RowVectorBatch&) noexcept = delete;
RowVectorBatch(RowVectorBatch&&) noexcept = default;
RowVectorBatch& operator=(RowVectorBatch&&) noexcept = default;
size_t BatchSize() const { return batch_size_; }
size_t Len() const { return len_; }
// Returns the given row vector of length `Len()`.
T* Batch(size_t batch_idx) {
HWY_DASSERT(batch_idx < batch_size_);
return mem_.get() + batch_idx * len_;
}
// For MatMul or other operations that process the entire batch at once.
T* All() { return mem_.get(); }
const T* Const() const { return mem_.get(); }
size_t NumBytes() const { return batch_size_ * len_ * sizeof(T); }
private:
hwy::AlignedFreeUniquePtr<T[]> mem_;
size_t batch_size_; // rows in the matrix
size_t len_; // columns in the matrix = vector length
};
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_UTIL_ALLOCATOR_H_
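
A short usage sketch of `RowVectorBatch` (sizes made up; mirrors how Activations and MatMulEnv use it):

```
// A (4 x 3072) batch of f32 row vectors.
RowVectorBatch<float> x(/*batch_size=*/4, /*len=*/3072);
float* row2 = x.Batch(2);                 // third row, length x.Len()
for (size_t i = 0; i < x.Len(); ++i) row2[i] = 0.0f;
float* flat = x.All();                    // whole batch as one row-major matrix
const size_t bytes = x.NumBytes();        // 4 * 3072 * sizeof(float)
(void)flat; (void)bytes;
```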


@ -197,6 +197,11 @@ class PerClusterPools {
return *inner_pools_[outer];
}
// Returns number of logical processors, for allocating per-thread buffers.
size_t NumLP() const {
return outer_pool_.NumWorkers() * inner_pools_[0]->NumWorkers();
}
private:
bool have_threading_support_;
CoreBitSets cores_per_cluster_;
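
Relating this to the TODO in MatMul about computing the logical-processor index from (outer, inner): one possible flat mapping, consistent with how `NumLP()` multiplies the two worker counts, is sketched below. This is not part of the change and assumes every cluster's inner pool has the same worker count, as `NumLP()` itself does:

```
// Hypothetical helper: flat LP index for a worker in cluster `outer`.
size_t FlatLpIndex(PerClusterPools& pools, size_t outer, size_t inner) {
  const size_t inner_workers = pools.Inner(0).NumWorkers();
  HWY_DASSERT(outer * inner_workers + inner < pools.NumLP());
  return outer * inner_workers + inner;
}
```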