mirror of https://github.com/google/gemma.cpp.git
Major MatMul update, 1.9-2.3x speedup on Zen4 via bf16 mul

Supports converting all weight/activation formats to native MulT (bf16/f32).
Also:
- ConstMat/MutableMat for const correctness
- Move RowVectorBatch to allocator.h so it can be used from Matmul
- Add matmul.h so MatMulEnv can be used from Activations
- Remove kMaxThreads, detect from PerClusterPools
- Build fix: -inl.h files must be textual_hdrs, and highway.h should precede -inl.h

```
zen4 new
64, 24576, 3072, add=0, MatTA=bf16, MatTB=sfp: 616.6 GFLOPS.
64, 3072, 24576, add=0, MatTA=bf16, MatTB=sfp: 460.7 GFLOPS.
64, 24576, 3072, add=0, MatTA=f32, MatTB=sfp: 598.6 GFLOPS.
64, 3072, 24576, add=0, MatTA=f32, MatTB=sfp: 435.6 GFLOPS.

zen4 old
64, 24576, 3072, add=0, MatTA=f32, MatTB=sfp: 257.5 GFLOPS.
64, 3072, 24576, add=0, MatTA=f32, MatTB=sfp: 231.9 GFLOPS.
```

PiperOrigin-RevId: 663729812
Parent: 6c57feb52f
Commit: 301dc8067a
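For reference, the call-site shape introduced by this change looks roughly as follows. This is a minimal sketch with hypothetical dimensions and tensor names, omitting the foreach_target.h boilerplate that real SIMD translation units use; the signature is taken from the gemma/gemma-inl.h call sites in the diff below, where MatMul_4x4/MakeMat/ThreadPool become MatMul/ConstMat/MutableMat/MatMulEnv.

```cpp
#include <stddef.h>

#include "hwy/highway.h"
// After highway.h, per the build fix in this commit.
#include "ops/matmul-inl.h"  // MatMul, ConstMat, MutableMat, MatMulEnv

namespace gcpp {
namespace HWY_NAMESPACE {

// Hypothetical helper: num_rows rows of activations times a column-major
// weight matrix, written to f32 outputs. Dimensions are made up.
template <typename MatTA, typename MatTB>
void ExampleMatMul(size_t num_rows, const MatTA* activations,
                   const MatTB* weights, float scale, float* out,
                   MatMulEnv& env) {
  constexpr size_t kModelDim = 3072;      // hypothetical
  constexpr size_t kFFHiddenDim = 24576;  // hypothetical
  MatMul</*kAdd=*/false>(num_rows, ConstMat(activations, kModelDim),
                         ConstMat(weights, kModelDim), scale,
                         /*add=*/nullptr, env, MutableMat(out, kFFHiddenDim));
}

}  // namespace HWY_NAMESPACE
}  // namespace gcpp
```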
BUILD.bazel (51 changed lines)
@ -20,14 +20,37 @@ licenses(["notice"])
|
|||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
cc_library(
|
||||
name = "allocator",
|
||||
hdrs = ["util/allocator.h"],
|
||||
deps = [
|
||||
"@hwy//:hwy",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "threading",
|
||||
hdrs = ["util/threading.h"],
|
||||
deps = [
|
||||
"@hwy//:hwy",
|
||||
"@hwy//:thread_pool",
|
||||
"@hwy//:topology",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "ops",
|
||||
hdrs = [
|
||||
"ops/matmul.h",
|
||||
],
|
||||
textual_hdrs = [
|
||||
"ops/ops-inl.h",
|
||||
"ops/matmul-inl.h",
|
||||
"ops/matvec-inl.h",
|
||||
],
|
||||
deps = [
|
||||
":allocator",
|
||||
":threading",
|
||||
"//compression:compress",
|
||||
"//compression:sfp",
|
||||
"@hwy//:algo",
|
||||
|
|
@ -86,6 +109,7 @@ cc_test(
|
|||
tags = ["hwy_ops_test"],
|
||||
deps = [
|
||||
":ops",
|
||||
":threading",
|
||||
"@googletest//:gtest_main", # buildcleaner: keep
|
||||
"//compression:compress",
|
||||
"@hwy//:hwy",
|
||||
|
|
@ -114,6 +138,7 @@ cc_library(
|
|||
srcs = ["gemma/weights.cc"],
|
||||
hdrs = ["gemma/weights.h"],
|
||||
deps = [
|
||||
":allocator",
|
||||
":common",
|
||||
"//compression:compress",
|
||||
"//compression:io",
|
||||
|
|
@ -148,16 +173,6 @@ cc_library(
|
|||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "threading",
|
||||
hdrs = ["util/threading.h"],
|
||||
deps = [
|
||||
"@hwy//:hwy",
|
||||
"@hwy//:thread_pool",
|
||||
"@hwy//:topology",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "gemma_lib",
|
||||
srcs = [
|
||||
|
|
@ -197,6 +212,7 @@ cc_library(
|
|||
# Placeholder for internal file2, do not remove,
|
||||
],
|
||||
deps = [
|
||||
":allocator",
|
||||
":common",
|
||||
":ops",
|
||||
":tokenizer",
|
||||
|
|
@ -389,11 +405,14 @@ cc_library(
|
|||
hdrs = [
|
||||
"backprop/activations.h",
|
||||
"backprop/backward.h",
|
||||
"backprop/backward-inl.h",
|
||||
"backprop/forward.h",
|
||||
],
|
||||
textual_hdrs = [
|
||||
"backprop/backward-inl.h",
|
||||
"backprop/forward-inl.h",
|
||||
],
|
||||
deps = [
|
||||
":allocator",
|
||||
":common",
|
||||
":gemma_lib",
|
||||
":ops",
|
||||
|
|
@ -413,6 +432,7 @@ cc_library(
|
|||
"backprop/forward_scalar.h",
|
||||
],
|
||||
deps = [
|
||||
":allocator",
|
||||
":common",
|
||||
":gemma_lib",
|
||||
":prompt",
|
||||
|
|
@ -467,13 +487,10 @@ cc_test(
|
|||
|
||||
cc_library(
|
||||
name = "optimizer",
|
||||
srcs = [
|
||||
"backprop/optimizer.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"backprop/optimizer.h",
|
||||
],
|
||||
srcs = ["backprop/optimizer.cc"],
|
||||
hdrs = ["backprop/optimizer.h"],
|
||||
deps = [
|
||||
":allocator",
|
||||
":common",
|
||||
":weights",
|
||||
"//compression:compress",
|
||||
|
|
|
|||
|
|
@ -103,8 +103,10 @@ set(SOURCES
|
|||
ops/matmul-inl.h
|
||||
ops/matvec-inl.h
|
||||
ops/ops-inl.h
|
||||
util/allocator.h
|
||||
util/app.h
|
||||
util/args.h
|
||||
util/threading.h
|
||||
)
|
||||
|
||||
if(NOT CMAKE_BUILD_TYPE)
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@
|
|||
|
||||
#include <array>
|
||||
|
||||
#include "gemma/common.h" // ByteStorageT
|
||||
#include "util/allocator.h" // ByteStorageT
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
|
|
|
|||
|
|
@ -27,7 +27,6 @@
|
|||
|
||||
#include "backprop/activations.h"
|
||||
#include "backprop/prompt.h"
|
||||
#include "gemma/activations.h" // CreateInvTimescale
|
||||
#include "gemma/common.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
|
@ -42,9 +41,10 @@
|
|||
#define THIRD_PARTY_GEMMA_CPP_BACKWARD_TOGGLE
|
||||
#endif
|
||||
|
||||
#include "hwy/highway.h"
|
||||
// After highway.h
|
||||
#include "ops/matmul-inl.h"
|
||||
#include "ops/ops-inl.h"
|
||||
#include "hwy/highway.h"
|
||||
|
||||
HWY_BEFORE_NAMESPACE();
|
||||
namespace gcpp {
|
||||
|
|
|
|||
|
|
@ -18,6 +18,8 @@
|
|||
#include "backprop/activations.h"
|
||||
#include "backprop/prompt.h"
|
||||
#include "gemma/common.h"
|
||||
#include "gemma/weights.h"
|
||||
#include "util/allocator.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
||||
// Compiles this file for multiple architectures via "foreach_target.h", to
|
||||
|
|
@ -29,8 +31,6 @@
|
|||
#include "hwy/highway.h"
|
||||
// After highway.h
|
||||
#include "backprop/backward-inl.h"
|
||||
#include "gemma/activations.h"
|
||||
#include "gemma/weights.h"
|
||||
|
||||
HWY_BEFORE_NAMESPACE();
|
||||
namespace gcpp {
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@
|
|||
#include <random>
|
||||
|
||||
#include "gemma/common.h"
|
||||
#include "util/allocator.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
|
|
|||
|
|
@ -185,6 +185,7 @@ cc_library(
|
|||
name = "weights_raw",
|
||||
hdrs = ["weights_raw.h"],
|
||||
deps = [
|
||||
"//:allocator",
|
||||
"//:common",
|
||||
"//compression:compress",
|
||||
"@hwy//:hwy",
|
||||
|
|
|
|||
|
|
@ -89,6 +89,22 @@ struct CompressTraits<float> {
|
|||
f1 = hn::LoadU(df, in + in_ofs + N);
|
||||
}
|
||||
|
||||
// Called by MatMul for f32 weights or activations if native
|
||||
// `ReorderWidenMulAccumulate` is available.
|
||||
template <class DBF16, HWY_IF_BF16_D(DBF16), class VBF16 = hn::Vec<DBF16>>
|
||||
static HWY_INLINE void Decompress2(DBF16 dbf16, const MatT* HWY_RESTRICT in,
|
||||
size_t in_ofs, VBF16& v0, VBF16& v1) {
|
||||
const hn::Repartition<float, decltype(dbf16)> df;
|
||||
using VF = hn::Vec<decltype(df)>;
|
||||
const size_t NF = hn::Lanes(df);
|
||||
const VF f0 = hn::LoadU(df, in + in_ofs + 0 * NF);
|
||||
const VF f1 = hn::LoadU(df, in + in_ofs + 1 * NF);
|
||||
const VF f2 = hn::LoadU(df, in + in_ofs + 2 * NF);
|
||||
const VF f3 = hn::LoadU(df, in + in_ofs + 3 * NF);
|
||||
v0 = hn::OrderedDemote2To(dbf16, f0, f1);
|
||||
v1 = hn::OrderedDemote2To(dbf16, f2, f3);
|
||||
}
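The overload above packs four f32 vectors into two bf16 vectors so MatMul can use the native bf16 path. A scalar sketch of the underlying f32-to-bf16 conversion (illustrative only; Highway's OrderedDemote2To may also round, whereas plain truncation is shown here):

```cpp
#include <cstdint>
#include <cstring>

// bf16 keeps the sign, exponent, and top 7 mantissa bits of an f32, so a
// plain truncation is just the upper 16 bits of the float's bit pattern.
uint16_t F32ToBF16Truncate(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  return static_cast<uint16_t>(bits >> 16);
}
```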
|
||||
|
||||
template <class DF, HWY_IF_F32_D(DF)>
|
||||
static HWY_INLINE void Decompress(DF df, size_t /*in_capacity*/,
|
||||
const MatT* HWY_RESTRICT in, size_t in_ofs,
|
||||
|
|
@ -196,6 +212,14 @@ struct CompressTraits<hwy::bfloat16_t> {
|
|||
f1 = hn::PromoteUpperTo(df, in16);
|
||||
}
|
||||
|
||||
template <class DBF16, HWY_IF_BF16_D(DBF16)>
|
||||
static HWY_INLINE void Decompress2(DBF16 dbf16, const MatT* HWY_RESTRICT in,
|
||||
size_t in_ofs, hn::Vec<DBF16>& v0,
|
||||
hn::Vec<DBF16>& v1) {
|
||||
v0 = hn::LoadU(dbf16, in + in_ofs);
|
||||
v1 = hn::LoadU(dbf16, in + in_ofs + hn::Lanes(dbf16));
|
||||
}
|
||||
|
||||
template <class DF, HWY_IF_F32_D(DF)>
|
||||
static HWY_INLINE void Decompress(DF df, size_t /*in_capacity*/,
|
||||
const MatT* HWY_RESTRICT in, size_t in_ofs,
|
||||
|
|
@ -318,14 +342,14 @@ struct CompressTraits<SfpStream> {
|
|||
}
|
||||
}
|
||||
|
||||
template <class DF, HWY_IF_F32_D(DF)>
|
||||
static HWY_INLINE void Decompress2(DF df, const MatT* HWY_RESTRICT in,
|
||||
size_t in_ofs, hn::Vec<DF>& f0,
|
||||
hn::Vec<DF>& f1) {
|
||||
const hn::Twice<hn::Rebind<uint8_t, DF>> d8;
|
||||
template <class D> // f32 or bf16
|
||||
static HWY_INLINE void Decompress2(D d, const MatT* HWY_RESTRICT in,
|
||||
size_t in_ofs, hn::Vec<D>& v0,
|
||||
hn::Vec<D>& v1) {
|
||||
const hn::Twice<hn::Rebind<uint8_t, D>> d8;
|
||||
using V8 = hn::Vec<decltype(d8)>;
|
||||
const V8 packed = hn::LoadU(d8, &in->byte + in_ofs);
|
||||
SfpCodec::Dec2F(df, packed, f0, f1);
|
||||
SfpCodec::Dec2(d, packed, v0, v1);
|
||||
}
|
||||
|
||||
template <class D, typename OutT>
|
||||
|
|
|
|||
|
|
@ -533,8 +533,8 @@ class SfpCodec {
|
|||
|
||||
template <class DF, HWY_IF_F32_D(DF),
|
||||
class V8 = hn::Vec<hn::Twice<hn::Rebind<uint8_t, DF>>>>
|
||||
static HWY_INLINE void Dec2F(DF df, V8 packed, hn::Vec<DF>& f0,
|
||||
hn::Vec<DF>& f1) {
|
||||
static HWY_INLINE void Dec2(DF df, V8 packed, hn::Vec<DF>& f0,
|
||||
hn::Vec<DF>& f1) {
|
||||
const hn::Rebind<hwy::bfloat16_t, DF> dbf;
|
||||
using VBF = hn::Vec<decltype(dbf)>;
|
||||
VBF bf0, bf1;
|
||||
|
|
@ -543,6 +543,13 @@ class SfpCodec {
|
|||
f1 = hn::PromoteTo(df, bf1);
|
||||
}
|
||||
|
||||
template <class DBF16, HWY_IF_BF16_D(DBF16),
|
||||
class V8 = hn::Vec<hn::Repartition<uint8_t, DBF16>>>
|
||||
static HWY_INLINE void Dec2(DBF16 dbf16, V8 packed, hn::Vec<DBF16>& bf0,
|
||||
hn::Vec<DBF16>& bf1) {
|
||||
Dec2B(dbf16, packed, bf0, bf1);
|
||||
}
|
||||
|
||||
private:
|
||||
// Wrappers to avoid code duplication across float/bf16 input types and
|
||||
// the main loop/remainder.
|
||||
|
|
|
|||
|
|
@ -26,8 +26,8 @@
|
|||
|
||||
#include <random>
|
||||
|
||||
#include "gemma/common.h"
|
||||
#include "gemma/configs.h"
|
||||
#include "util/allocator.h"
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
|
|
|||
|
|
@ -20,51 +20,14 @@
|
|||
|
||||
#include <cmath>
|
||||
|
||||
#include "gemma/common.h" // kMaxThreads - TODO: remove
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "ops/matmul.h" // MatMulEnv
|
||||
#include "util/allocator.h" // RowVectorBatch
|
||||
#include "util/threading.h"
|
||||
#include "hwy/base.h" // HWY_DASSERT
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
// Owns dynamically-allocated aligned memory for a batch of row vectors.
|
||||
// This can be seen as a (batch_size x len) matrix.
|
||||
template <typename T>
|
||||
class RowVectorBatch {
|
||||
public:
|
||||
// Default ctor for Activations ctor.
|
||||
RowVectorBatch() : batch_size_(0), len_(0) {}
|
||||
// Main ctor, called from Activations::Allocate.
|
||||
RowVectorBatch(size_t batch_size, size_t len)
|
||||
: batch_size_(batch_size), len_(len) {
|
||||
mem_ = hwy::AllocateAligned<T>(batch_size * len);
|
||||
}
|
||||
|
||||
// Move-only
|
||||
RowVectorBatch(RowVectorBatch&) noexcept = delete;
|
||||
RowVectorBatch& operator=(RowVectorBatch&) noexcept = delete;
|
||||
RowVectorBatch(RowVectorBatch&&) noexcept = default;
|
||||
RowVectorBatch& operator=(RowVectorBatch&&) noexcept = default;
|
||||
|
||||
size_t BatchSize() const { return batch_size_; }
|
||||
size_t Len() const { return len_; }
|
||||
|
||||
// Returns the given row vector of length `Len()`.
|
||||
T* Batch(size_t batch_idx) {
|
||||
HWY_DASSERT(batch_idx < batch_size_);
|
||||
return mem_.get() + batch_idx * len_;
|
||||
}
|
||||
|
||||
// For MatMul or other operations that process the entire batch at once.
|
||||
T* All() { return mem_.get(); }
|
||||
const T* Const() const { return mem_.get(); }
|
||||
size_t NumBytes() const { return batch_size_ * len_ * sizeof(T); }
|
||||
|
||||
private:
|
||||
hwy::AlignedFreeUniquePtr<T[]> mem_;
|
||||
size_t batch_size_; // rows in the matrix
|
||||
size_t len_; // columns in the matrix = vector length
|
||||
};
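A brief usage sketch of RowVectorBatch (hypothetical fill loop; per this commit the class now lives in util/allocator.h so MatMul can use it):

```cpp
#include <stddef.h>

#include "util/allocator.h"  // RowVectorBatch (moved here by this commit)

// Treats the batch as a (batch_size x model_dim) matrix and zeroes each row.
void ExampleFill(size_t batch_size, size_t model_dim) {
  gcpp::RowVectorBatch<float> x(batch_size, model_dim);
  for (size_t i = 0; i < x.BatchSize(); ++i) {
    float* row = x.Batch(i);  // row vector of length Len()
    for (size_t j = 0; j < x.Len(); ++j) row[j] = 0.0f;
  }
}
```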
|
||||
|
||||
struct Activations {
|
||||
RowVectorBatch<float> x; // input
|
||||
RowVectorBatch<float> q; // query, also KV if MHA.
|
||||
|
|
@ -94,9 +57,11 @@ struct Activations {
|
|||
|
||||
// For bf16/f32 vectors * bf16 matrix: faster to unpack once beforehand, into
|
||||
// per-thread storage.
|
||||
// TODO: remove once MatVec is gone.
|
||||
// TODO: remove once MatVec is no longer used.
|
||||
RowVectorBatch<float> even_odd;
|
||||
|
||||
MatMulEnv env;
|
||||
|
||||
// Multi-Head Attention?
|
||||
template <class TConfig>
|
||||
static constexpr bool IsMHA() {
|
||||
|
|
@ -126,7 +91,7 @@ struct Activations {
|
|||
}
|
||||
|
||||
template <class TConfig>
|
||||
void Allocate(size_t batch_size) {
|
||||
void Allocate(size_t batch_size, PerClusterPools& pools) {
|
||||
constexpr size_t kModelDim = TConfig::kModelDim;
|
||||
constexpr size_t kQKVDim = TConfig::kQKVDim;
|
||||
constexpr size_t kHeads = TConfig::kHeads;
|
||||
|
|
@ -158,7 +123,10 @@ struct Activations {
|
|||
|
||||
inv_timescale = CreateInvTimescale<TConfig>();
|
||||
|
||||
even_odd = RowVectorBatch<float>(1, kModelDim * kMaxThreads);
|
||||
const size_t num_lp = pools.NumLP();
|
||||
even_odd = RowVectorBatch<float>(1, kModelDim * num_lp);
|
||||
|
||||
env = MatMulEnv(pools);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -18,24 +18,15 @@
|
|||
|
||||
#include <math.h> // sqrtf
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "compression/compress.h"
|
||||
#include "gemma/configs.h" // IWYU pragma: export
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/base.h" // ConvertScalarTo
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
using ByteStorageT = hwy::AlignedFreeUniquePtr<uint8_t[]>;
|
||||
|
||||
template <typename T>
|
||||
ByteStorageT AllocateSizeof() {
|
||||
return hwy::AllocateAligned<uint8_t>(sizeof(T));
|
||||
}
|
||||
|
||||
// Model variants: see configs.h for details. When adding a new one, also
|
||||
// update GEMMA_FOREACH* and Call* below, and add instantiations/*.cc.
|
||||
enum class Model {
|
||||
|
|
|
|||
|
|
@ -36,14 +36,8 @@ namespace gcpp {
|
|||
#define GEMMA_TOPK 1
|
||||
#endif // !GEMMA_TOPK
|
||||
|
||||
// Allow changing upper bound on threads as a compiler flag
|
||||
#ifndef GEMMA_MAX_THREADS
|
||||
#define GEMMA_MAX_THREADS 128
|
||||
#endif // !GEMMA_MAX_THREADS
|
||||
|
||||
static constexpr size_t kSeqLen = GEMMA_MAX_SEQLEN;
|
||||
static constexpr size_t kTopK = GEMMA_TOPK;
|
||||
static constexpr size_t kMaxThreads = GEMMA_MAX_THREADS;
|
||||
|
||||
using EmbedderInputT = hwy::bfloat16_t;
|
||||
|
||||
|
|
|
|||
|
|
@ -41,6 +41,7 @@
|
|||
#include "ops/matmul-inl.h"
|
||||
#include "ops/matvec-inl.h"
|
||||
#include "ops/ops-inl.h"
|
||||
#include "util/allocator.h"
|
||||
#include "util/threading.h"
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/base.h"
|
||||
|
|
@ -73,9 +74,10 @@ template <class TConfig>
|
|||
HWY_NOINLINE void GriffinRecurrent(
|
||||
size_t batch_start, size_t num_tokens, size_t layer,
|
||||
Activations& activations, const CompressedLayer<TConfig>* layer_weights,
|
||||
const KVCaches& kv_caches, hwy::ThreadPool& pool) {
|
||||
const KVCaches& kv_caches) {
|
||||
PROFILER_ZONE("Gen.Griffin");
|
||||
KVCache& kv_cache = kv_caches[0];
|
||||
hwy::ThreadPool& pool = activations.env.Pool();
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
using D = hn::ScalableTag<float>;
|
||||
static constexpr size_t kModelDim = TConfig::kModelDim;
|
||||
|
|
@ -240,12 +242,12 @@ class GemmaAttention {
|
|||
// and kQStride = kQKVDim * (kIsMHA ? 3 : 1);
|
||||
|
||||
const auto pre_att_rms_out =
|
||||
MakeMat(activations_.pre_att_rms_out.All(), kModelDim);
|
||||
MatMul_4x4</*kAdd=*/false>(
|
||||
ConstMat(activations_.pre_att_rms_out.All(), kModelDim);
|
||||
MatMul</*kAdd=*/false>(
|
||||
num_interleaved, pre_att_rms_out,
|
||||
MakeMat(layer_weights_.qkv_einsum_w.data(), kModelDim),
|
||||
layer_weights_.qkv_einsum_w.scale(), /*add=*/nullptr,
|
||||
MakeMat(activations_.q.All(), kHeads * kQStride), pool_);
|
||||
ConstMat(layer_weights_.qkv_einsum_w.data(), kModelDim),
|
||||
layer_weights_.qkv_einsum_w.scale(), /*add=*/nullptr, activations_.env,
|
||||
MutableMat(activations_.q.All(), kHeads * kQStride));
|
||||
|
||||
if constexpr (kIsMHA) {
|
||||
static_assert(TConfig::kInterleaveQKV, "MHA implies interleaved");
|
||||
|
|
@ -259,12 +261,13 @@ class GemmaAttention {
|
|||
queries_pos_[0] * kCachePosSize + layer_ * kCacheLayerSize;
|
||||
// KV structure is [k, v, k, v, ....] = kKVHeads pairs of (k, v).
|
||||
float* HWY_RESTRICT kv = kv_caches_[0].kv_cache.get() + kv_ofs;
|
||||
MatMul_4x4</*kAdd=*/false>(
|
||||
MatMul</*kAdd=*/false>(
|
||||
num_tokens_, pre_att_rms_out,
|
||||
MakeMat(layer_weights_.qkv_einsum_w.data(), kModelDim, kModelDim,
|
||||
kHeads * kQKVDim * kModelDim),
|
||||
ConstMat(layer_weights_.qkv_einsum_w.data(), kModelDim, kModelDim,
|
||||
kHeads * kQKVDim * kModelDim),
|
||||
layer_weights_.qkv_einsum_w.scale(), /*add=*/nullptr,
|
||||
MakeMat(kv, kKVHeads * 2 * kQKVDim, kCachePosSize), pool_);
|
||||
activations_.env,
|
||||
MutableMat(kv, kKVHeads * 2 * kQKVDim, kCachePosSize));
|
||||
} else {
|
||||
// Proceed row by row because there will be wraparound.
|
||||
for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
|
||||
|
|
@ -430,19 +433,18 @@ class GemmaAttention {
|
|||
// Thus the [num_interleaved, kModelDim] matmul output is the sum over
|
||||
// heads. Compare gemma/modules.py:
|
||||
// attn_output = self.attn_vec_einsum('BTNH,NHD->BTD', encoded)
|
||||
MatMul_4x4<kAdd>(
|
||||
num_interleaved, MakeMat(activations_.att_out.All(), kHeads * kQKVDim),
|
||||
MakeMat(layer_weights_.att_weights.data(), kHeads * kQKVDim),
|
||||
layer_weights_.attn_vec_einsum_w.scale(), bias,
|
||||
MakeMat(activations_.att_sums.All(), kModelDim), pool_);
|
||||
MatMul<kAdd>(
|
||||
num_interleaved, ConstMat(activations_.att_out.All(), kHeads * kQKVDim),
|
||||
ConstMat(layer_weights_.att_weights.data(), kHeads * kQKVDim),
|
||||
layer_weights_.attn_vec_einsum_w.scale(), bias, activations_.env,
|
||||
MutableMat(activations_.att_sums.All(), kModelDim));
|
||||
}
|
||||
|
||||
public:
|
||||
GemmaAttention(const QueriesPos& queries_pos, size_t num_tokens, size_t layer,
|
||||
Activations& activations,
|
||||
const CompressedLayer<TConfig>* layer_weights,
|
||||
const hwy::Divisor& div_seq_len, const KVCaches& kv_caches,
|
||||
hwy::ThreadPool& pool)
|
||||
const hwy::Divisor& div_seq_len, const KVCaches& kv_caches)
|
||||
: queries_pos_(queries_pos),
|
||||
num_queries_(queries_pos.size()),
|
||||
num_tokens_(num_tokens),
|
||||
|
|
@ -451,7 +453,7 @@ class GemmaAttention {
|
|||
layer_weights_(*layer_weights),
|
||||
div_seq_len_(div_seq_len),
|
||||
kv_caches_(kv_caches),
|
||||
pool_(pool) {
|
||||
pool_(activations.env.Pool()) {
|
||||
HWY_DASSERT(num_queries_ <= kv_caches_.size());
|
||||
}
|
||||
|
||||
|
|
@ -480,17 +482,17 @@ HWY_NOINLINE void Attention(LayerAttentionType type,
|
|||
size_t layer, Activations& activations,
|
||||
const CompressedLayer<TConfig>* layer_weights,
|
||||
const hwy::Divisor& div_seq_len,
|
||||
const KVCaches& kv_caches, hwy::ThreadPool& pool) {
|
||||
const KVCaches& kv_caches) {
|
||||
if (type == LayerAttentionType::kGemma) {
|
||||
GemmaAttention<TConfig>(queries_pos, num_tokens, layer, activations,
|
||||
layer_weights, div_seq_len, kv_caches, pool)();
|
||||
layer_weights, div_seq_len, kv_caches)();
|
||||
} else {
|
||||
// Only reached if the model is Griffin. `if constexpr` prevents generating
|
||||
// this code for non-Griffin models.
|
||||
if constexpr (TConfig::kGriffinLayers > 0) {
|
||||
HWY_ASSERT(queries_pos.size() == 1);
|
||||
GriffinRecurrent<TConfig>(queries_pos[0], num_tokens, layer, activations,
|
||||
layer_weights, kv_caches, pool);
|
||||
layer_weights, kv_caches);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -510,8 +512,7 @@ HWY_NOINLINE void Activation(T* HWY_RESTRICT c1, T* HWY_RESTRICT c2,
|
|||
|
||||
template <class TConfig>
|
||||
HWY_NOINLINE void FFW(Activations& activations, size_t num_interleaved,
|
||||
const CompressedLayer<TConfig>* layer_weights,
|
||||
hwy::ThreadPool& pool) {
|
||||
const CompressedLayer<TConfig>* layer_weights) {
|
||||
PROFILER_ZONE("Gen.FFW");
|
||||
constexpr size_t kModelDim = TConfig::kModelDim;
|
||||
constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
|
||||
|
|
@ -519,10 +520,10 @@ HWY_NOINLINE void FFW(Activations& activations, size_t num_interleaved,
|
|||
// MatMul expects col-major B, which is what we have: kModelDim consecutive
|
||||
// elements in memory, repeated kFFHiddenDim times.
|
||||
HWY_DASSERT(num_interleaved <= activations.bf_pre_ffw_rms_out.BatchSize());
|
||||
const auto A = MakeMat(activations.bf_pre_ffw_rms_out.All(), kModelDim);
|
||||
const auto B1 = MakeMat(layer_weights->gating_einsum_w.data(), kModelDim);
|
||||
const auto B2 = MakeMat(layer_weights->gating_einsum_w.data(), kModelDim,
|
||||
kModelDim, kModelDim * kFFHiddenDim);
|
||||
const auto A = ConstMat(activations.bf_pre_ffw_rms_out.All(), kModelDim);
|
||||
const auto B1 = ConstMat(layer_weights->gating_einsum_w.data(), kModelDim);
|
||||
const auto B2 = ConstMat(layer_weights->gating_einsum_w.data(), kModelDim,
|
||||
kModelDim, kModelDim * kFFHiddenDim);
|
||||
const float scale = layer_weights->gating_einsum_w.scale();
|
||||
constexpr bool kAddBias = TConfig::kFFBiases;
|
||||
const float* bias1 = nullptr;
|
||||
|
|
@ -533,22 +534,23 @@ HWY_NOINLINE void FFW(Activations& activations, size_t num_interleaved,
|
|||
bias2 = bias1 + kFFHiddenDim;
|
||||
output_bias = layer_weights->ffw_output_biases.data_scale1();
|
||||
}
|
||||
auto C1 = MakeMat(activations.C1.All(), kFFHiddenDim);
|
||||
auto C2 = MakeMat(activations.C2.All(), kFFHiddenDim);
|
||||
auto C1 = MutableMat(activations.C1.All(), kFFHiddenDim);
|
||||
auto C2 = MutableMat(activations.C2.All(), kFFHiddenDim);
|
||||
|
||||
// Will go through GELU.
|
||||
MatMul_4x4<kAddBias>(num_interleaved, A, B1, scale, bias1, C1, pool);
|
||||
MatMul<kAddBias>(num_interleaved, A, B1, scale, bias1, activations.env, C1);
|
||||
// What to multiply by.
|
||||
MatMul_4x4<kAddBias>(num_interleaved, A, B2, scale, bias2, C2, pool);
|
||||
MatMul<kAddBias>(num_interleaved, A, B2, scale, bias2, activations.env, C2);
|
||||
|
||||
// Activation (Gelu) and multiply by gate. Store activations in C1.
|
||||
Activation<TConfig>(C1.ptr, C2.ptr, kFFHiddenDim * num_interleaved);
|
||||
|
||||
// Hidden layer -> output layer.
|
||||
MatMul_4x4<kAddBias>(num_interleaved, C1,
|
||||
MakeMat(layer_weights->linear_w.data(), kFFHiddenDim),
|
||||
layer_weights->linear_w.scale(), output_bias,
|
||||
MakeMat(activations.ffw_out.All(), kModelDim), pool);
|
||||
MatMul<kAddBias>(num_interleaved, ConstMat(C1),
|
||||
ConstMat(layer_weights->linear_w.data(), kFFHiddenDim),
|
||||
layer_weights->linear_w.scale(), output_bias,
|
||||
activations.env,
|
||||
MutableMat(activations.ffw_out.All(), kModelDim));
|
||||
}
|
||||
|
||||
// `batch_idx` indicates which row of `x` to write to.
|
||||
|
|
@ -594,8 +596,7 @@ template <class TConfig>
|
|||
HWY_NOINLINE void TransformerLayer(
|
||||
const QueriesPos& queries_pos, size_t num_tokens, size_t layer,
|
||||
const CompressedLayer<TConfig>* layer_weights, Activations& activations,
|
||||
const hwy::Divisor& div_seq_len, const KVCaches& kv_caches,
|
||||
hwy::ThreadPool& pool) {
|
||||
const hwy::Divisor& div_seq_len, const KVCaches& kv_caches) {
|
||||
constexpr size_t kModelDim = TConfig::kModelDim;
|
||||
const size_t num_interleaved = num_tokens * queries_pos.size();
|
||||
auto type = TConfig::kLayerConfig[layer];
|
||||
|
|
@ -607,7 +608,7 @@ HWY_NOINLINE void TransformerLayer(
|
|||
activations.pre_att_rms_out.All(), kModelDim);
|
||||
|
||||
Attention<TConfig>(type, queries_pos, num_tokens, layer_of_type, activations,
|
||||
layer_weights, div_seq_len, kv_caches, pool);
|
||||
layer_weights, div_seq_len, kv_caches);
|
||||
|
||||
PostNorm<TConfig>(num_interleaved, layer_weights->post_attention_norm_scale,
|
||||
activations.att_sums.All());
|
||||
|
|
@ -620,7 +621,7 @@ HWY_NOINLINE void TransformerLayer(
|
|||
layer_weights->pre_ffw_norm_scale.data_scale1(),
|
||||
activations.bf_pre_ffw_rms_out.All(), kModelDim);
|
||||
|
||||
FFW<TConfig>(activations, num_interleaved, layer_weights, pool);
|
||||
FFW<TConfig>(activations, num_interleaved, layer_weights);
|
||||
|
||||
PostNorm<TConfig>(num_interleaved, layer_weights->post_ffw_norm_scale,
|
||||
activations.ffw_out.All());
|
||||
|
|
@ -630,120 +631,71 @@ HWY_NOINLINE void TransformerLayer(
|
|||
/*is_attention=*/false);
|
||||
}
|
||||
|
||||
// Prefill and Transformer() advance positions in-place.
|
||||
// Prefill() and Transformer() increment positions in-place.
|
||||
using QueriesMutablePos = hwy::Span<size_t>;
|
||||
|
||||
// Batches are important for amortizing loading weights over multiple tokens.
|
||||
// This is possible in prefill because we know all tokens beforehand, whereas
|
||||
// decode depends on the previous output token. However, each prefill batch of a
|
||||
// query requires that preceding batches already wrote to the KV cache, hence we
|
||||
// sequentially loop over token batches. We can reduce the number of iterations
|
||||
// by increasing the batch size, but this also increases arithmetic intensity,
|
||||
// and so we are eventually compute-limited. The tensor parallelism (number of
|
||||
// threads collaborating on MatMul) is also limited by the CPU topology:
|
||||
// fork/join barriers are slow(er) when some threads reside in a different NUMA
|
||||
// node. To allow more threads to help, we also support parallelizing over
|
||||
// queries in case GenerateBatch was called.
|
||||
//
|
||||
// Thus we have two-level parallelism:
|
||||
// - Outer: handles one 'qbatch' of entire queries. The set of outer workers
|
||||
// includes the main thread because it is the one that calls `Prefill`, and is
|
||||
// determined by the number of 'clusters' (shared L3 caches or sockets).
|
||||
// - Inner: each `outer` worker passes `inner_pools_[outer]` to
|
||||
// `TransformerLayer` for tensor-level parallelism, and processes
|
||||
// `tbatch_size` tokens from a single query at a time.
|
||||
//
|
||||
// This class holds the thread pools and one activation per outer worker. It is
|
||||
// NOT reused across calls to GenerateSingle/GenerateBatch so that we can adapt
|
||||
// to their num_queries.
|
||||
class PrefillState {
|
||||
public:
|
||||
// `tbatch_size` is the number of tokens from one query to prefill at a time.
|
||||
template <class TConfig>
|
||||
void Init(size_t num_queries, size_t tbatch_size, PerClusterPools& pools) {
|
||||
PROFILER_ZONE("Init.Prefill");
|
||||
HWY_ASSERT(num_queries != 0);
|
||||
HWY_ASSERT(activations_.empty()); // only call once.
|
||||
// Populates KV cache for batches of tokens from one query at a time.
|
||||
template <class TConfig>
|
||||
HWY_NOINLINE void Prefill(
|
||||
const QueriesPromptTokens& queries_prompt, const size_t prefill_per_query,
|
||||
const QueriesMutablePos& queries_pos, const size_t query_idx_start,
|
||||
const CompressedWeights<TConfig>& weights, Activations& activations,
|
||||
const RuntimeConfig& runtime_config, const hwy::Divisor& div_seq_len,
|
||||
const KVCaches& kv_caches) {
|
||||
PROFILER_ZONE("Gen.Prefill");
|
||||
const size_t num_queries = queries_prompt.size();
|
||||
HWY_ASSERT(queries_pos.size() == num_queries);
|
||||
HWY_ASSERT(kv_caches.size() == num_queries);
|
||||
|
||||
// Allocate one activation per query, not outer worker, because the common
|
||||
// case is a single query. If we allocate the lesser of the two, it is
|
||||
// unclear how to choose an unused activation in Prefill.
|
||||
activations_.resize(num_queries);
|
||||
// Batches are important for amortizing loading weights over multiple tokens.
|
||||
// This is possible in prefill because we know all tokens beforehand, whereas
|
||||
// decode depends on the previous output token. However, each prefill batch of
|
||||
// a query requires that preceding batches already wrote to the KV cache,
|
||||
// hence we sequentially loop over token batches. We can reduce the number of
|
||||
// iterations by increasing the batch size, but this also increases arithmetic
|
||||
// intensity, and so we are eventually compute-limited. We could devote some
|
||||
// threads to parallelizing over queries, but for simplicity we assign them
|
||||
// all to MatMul.
|
||||
const size_t max_tbatch_size = activations.x.BatchSize();
|
||||
|
||||
if (num_queries == 1) {
|
||||
activations_[0].Allocate<TConfig>(tbatch_size);
|
||||
} else {
|
||||
// Allocating in parallel can save 30 ms. We might have more workers than
|
||||
// queries/tasks, so do not check the `thread` argument.
|
||||
pools.Outer().Run(0, num_queries,
|
||||
[this, tbatch_size](uint64_t qi, size_t /*thread*/) {
|
||||
activations_[qi].Allocate<TConfig>(tbatch_size);
|
||||
});
|
||||
}
|
||||
// For each query. `qi` is within the batch, not the global query index.
|
||||
for (size_t qi = 0; qi < num_queries; ++qi) {
|
||||
// Single query at a time, so pass slices of the spans because
|
||||
// GemmaAttention will only access the first KV cache and position.
|
||||
QueriesPos single_query_pos(&queries_pos[qi], 1);
|
||||
KVCaches single_kv_cache(&kv_caches[qi], 1);
|
||||
|
||||
// For each batch of tokens in the query:
|
||||
for (size_t tbatch_start = 0; tbatch_start < prefill_per_query;
|
||||
tbatch_start += max_tbatch_size) {
|
||||
// Fill activations.x (much faster than TransformerLayer).
|
||||
const size_t tbatch_size =
|
||||
HWY_MIN(max_tbatch_size, prefill_per_query - tbatch_start);
|
||||
for (size_t ti = 0; ti < tbatch_size; ++ti) {
|
||||
const int token = queries_prompt[qi][tbatch_start + ti];
|
||||
const size_t pos = queries_pos[qi] + ti;
|
||||
EmbedToken<TConfig>(token, ti, pos, weights, activations.x);
|
||||
}
|
||||
|
||||
// Transformer with one batch of tokens from a single query.
|
||||
for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
|
||||
const auto* layer_weights = weights.GetLayer(layer);
|
||||
TransformerLayer<TConfig>(single_query_pos, tbatch_size, layer,
|
||||
layer_weights, activations, div_seq_len,
|
||||
single_kv_cache);
|
||||
}
|
||||
|
||||
// NOTE: we unconditionally call StreamToken, even if EOS.
|
||||
for (size_t ti = 0; ti < tbatch_size; ++ti) {
|
||||
const size_t pos = queries_pos[qi] + ti;
|
||||
const int token = queries_prompt[qi][pos];
|
||||
runtime_config.StreamToken(query_idx_start + qi, pos, token, 0.0f);
|
||||
}
|
||||
|
||||
queries_pos[qi] += tbatch_size;
|
||||
} // for tbatch_start
|
||||
}
|
||||
|
||||
template <class TConfig>
|
||||
HWY_NOINLINE void Prefill(const QueriesPromptTokens& queries_prompt,
|
||||
const size_t prefill_per_query,
|
||||
const QueriesMutablePos& queries_pos,
|
||||
const size_t query_idx_start,
|
||||
const CompressedWeights<TConfig>& weights,
|
||||
const RuntimeConfig& runtime_config,
|
||||
const hwy::Divisor& div_seq_len,
|
||||
const KVCaches& kv_caches, PerClusterPools& pools) {
|
||||
PROFILER_ZONE("Gen.Prefill");
|
||||
const size_t num_queries = queries_prompt.size();
|
||||
HWY_ASSERT(kv_caches.size() == num_queries);
|
||||
const size_t max_tbatch_size = activations_[0].x.BatchSize();
|
||||
|
||||
// For each query (parallel): an outer worker processes all its tokens.
|
||||
// `qi` is relative to the batch, not the global query index.
|
||||
pools.Outer().Run(
|
||||
0, num_queries, [&](const uint64_t qi, size_t qthread) HWY_ATTR {
|
||||
Activations& activations = activations_[qi];
|
||||
hwy::ThreadPool& inner_pool = pools.Inner(qthread);
|
||||
|
||||
// Single query at a time, so pass slices of the spans because
|
||||
// GemmaAttention will only access the first KV cache and position.
|
||||
KVCaches single_kv_cache(&kv_caches[qi], 1);
|
||||
QueriesPos single_query_pos(&queries_pos[qi], 1);
|
||||
|
||||
// For each batch of tokens in the query:
|
||||
for (size_t tbatch_start = 0; tbatch_start < prefill_per_query;
|
||||
tbatch_start += max_tbatch_size) {
|
||||
// Fill activations.x (much faster than TransformerLayer).
|
||||
const size_t tbatch_size =
|
||||
HWY_MIN(max_tbatch_size, prefill_per_query - tbatch_start);
|
||||
for (size_t ti = 0; ti < tbatch_size; ++ti) {
|
||||
const int token = queries_prompt[qi][tbatch_start + ti];
|
||||
const size_t pos = queries_pos[qi] + ti;
|
||||
EmbedToken<TConfig>(token, ti, pos, weights, activations.x);
|
||||
}
|
||||
|
||||
// Transformer with one batch of tokens from a single query.
|
||||
for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
|
||||
const auto* layer_weights = weights.GetLayer(layer);
|
||||
TransformerLayer<TConfig>(single_query_pos, tbatch_size, layer,
|
||||
layer_weights, activations, div_seq_len,
|
||||
single_kv_cache, inner_pool);
|
||||
}
|
||||
|
||||
// NOTE: we unconditionally call StreamToken, even if EOS.
|
||||
for (size_t ti = 0; ti < tbatch_size; ++ti) {
|
||||
const size_t pos = queries_pos[qi] + ti;
|
||||
const int token = queries_prompt[qi][pos];
|
||||
runtime_config.StreamToken(query_idx_start + qi, pos, token,
|
||||
0.0f);
|
||||
}
|
||||
|
||||
queries_pos[qi] += tbatch_size;
|
||||
} // for tbatch_start
|
||||
});
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<Activations> activations_; // One per query, filled by Init.
|
||||
};
|
||||
}
|
||||
|
||||
// Generates one token for each query. `queries_token` is the previous token
|
||||
// from each query, and `queries_pos` are their position in the sequence.
|
||||
|
|
@ -752,7 +704,7 @@ HWY_NOINLINE void Transformer(
|
|||
const QueriesToken& queries_token, const QueriesMutablePos& queries_pos,
|
||||
const CompressedWeights<TConfig>& weights, Activations& activations,
|
||||
const hwy::Divisor& div_seq_len, const KVCaches& kv_caches,
|
||||
hwy::ThreadPool& pool, const LayersOutputFunc& layers_output,
|
||||
const LayersOutputFunc& layers_output,
|
||||
const ActivationsObserverFunc& activations_observer) {
|
||||
constexpr size_t kModelDim = TConfig::kModelDim;
|
||||
const size_t num_queries = queries_token.size();
|
||||
|
|
@ -775,7 +727,7 @@ HWY_NOINLINE void Transformer(
|
|||
const CompressedLayer<TConfig>* layer_weights = weights.GetLayer(layer);
|
||||
TransformerLayer<TConfig>(queries_pos, /*num_tokens=*/1, layer,
|
||||
layer_weights, activations, div_seq_len,
|
||||
kv_caches, pool);
|
||||
kv_caches);
|
||||
|
||||
if (activations_observer) {
|
||||
activations_observer(queries_pos, layer, activations);
|
||||
|
|
@ -880,16 +832,12 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
|
|||
const RuntimeConfig& runtime_config,
|
||||
const QueriesPromptTokens& queries_prompt,
|
||||
const QueriesPos& queries_pos_in, const size_t query_idx_start,
|
||||
const KVCaches& kv_caches, PerClusterPools& pools,
|
||||
TimingInfo& timing_info) {
|
||||
const KVCaches& kv_caches, TimingInfo& timing_info) {
|
||||
constexpr size_t kModelDim = TConfig::kModelDim;
|
||||
constexpr size_t kVocabSize = TConfig::kVocabSize;
|
||||
const CompressedWeights<TConfig>& weights =
|
||||
*reinterpret_cast<const CompressedWeights<TConfig>*>(weights_u8.get());
|
||||
|
||||
// TODO: remove once all parallel sections support hierarchical parallelism.
|
||||
hwy::ThreadPool& pool = pools.Inner(0);
|
||||
|
||||
// Copy so we can increment without requiring users to pass in a mutable span.
|
||||
std::vector<size_t> queries_pos_copy(queries_pos_in.cbegin(),
|
||||
queries_pos_in.cend());
|
||||
|
|
@ -930,19 +878,22 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
|
|||
// Prefill stops before min_prompt_size - 1 because the last prompt token is
|
||||
// the first input token for generation.
|
||||
const size_t prefill_per_query = min_prompt_size - 1;
|
||||
double prefill_start;
|
||||
{
|
||||
// TODO: move to Gemma, reuse across calls to Generate.
|
||||
PrefillState prefill;
|
||||
prefill.Init<TConfig>(num_queries, runtime_config.prefill_tbatch_size,
|
||||
pools);
|
||||
prefill_start = hwy::platform::Now();
|
||||
prefill.Prefill<TConfig>(queries_prompt, prefill_per_query,
|
||||
queries_mutable_pos, query_idx_start, weights,
|
||||
runtime_config, div_seq_len, kv_caches, pools);
|
||||
timing_info.NotifyPrefill(prefill_per_query * num_queries, prefill_start);
|
||||
// queries_pos are incremented by Prefill.
|
||||
const double prefill_start = hwy::platform::Now();
|
||||
// If tbatch is larger than the qbatch we already have in `activations`, then
|
||||
// allocate prefill_activations, otherwise reuse.
|
||||
const bool use_prefill_activations =
|
||||
runtime_config.prefill_tbatch_size > activations.x.BatchSize();
|
||||
Activations prefill_activations;
|
||||
if (use_prefill_activations) {
|
||||
prefill_activations.Allocate<TConfig>(runtime_config.prefill_tbatch_size,
|
||||
activations.env.Pools());
|
||||
}
|
||||
Prefill<TConfig>(queries_prompt, prefill_per_query, queries_mutable_pos,
|
||||
query_idx_start, weights,
|
||||
use_prefill_activations ? prefill_activations : activations,
|
||||
runtime_config, div_seq_len, kv_caches);
|
||||
timing_info.NotifyPrefill(prefill_per_query * num_queries, prefill_start);
|
||||
// queries_pos are incremented by Prefill.
|
||||
|
||||
// Storage for the last generated token from each query, passed to the next
|
||||
// Transformer() call.
|
||||
|
|
@ -962,18 +913,18 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
|
|||
// Decode generates one token per query and increments queries_mutable_pos.
|
||||
Transformer<TConfig>(QueriesToken(gen_tokens.data(), num_queries),
|
||||
queries_mutable_pos, weights, activations, div_seq_len,
|
||||
kv_caches, pool, runtime_config.layers_output,
|
||||
kv_caches, runtime_config.layers_output,
|
||||
runtime_config.activations_observer);
|
||||
// queries_pos are incremented by Transformer.
|
||||
|
||||
bool all_queries_eos = true;
|
||||
PROFILER_ZONE("Gen.Embedding");
|
||||
// Compute logits from last layer activations.
|
||||
MatMul_4x4</*kAdd=*/false>(
|
||||
num_queries, MakeMat(activations.x.All(), kModelDim),
|
||||
MakeMat(weights.embedder_input_embedding.data(), kModelDim),
|
||||
MatMul</*kAdd=*/false>(
|
||||
num_queries, ConstMat(activations.x.All(), kModelDim),
|
||||
ConstMat(weights.embedder_input_embedding.data(), kModelDim),
|
||||
weights.embedder_input_embedding.scale(), /*add=*/nullptr,
|
||||
MakeMat(activations.logits.All(), kVocabSize), pool);
|
||||
activations.env, MutableMat(activations.logits.All(), kVocabSize));
|
||||
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
|
||||
float* HWY_RESTRICT logits = activations.logits.Batch(query_idx);
|
||||
MaybeLogitsSoftCap(TConfig::kFinalCap, logits, kVocabSize);
|
||||
|
|
@ -1001,15 +952,16 @@ void GenerateSingleT(const ByteStorageT& weights_u8,
|
|||
constexpr size_t kNumQueries = 1;
|
||||
const size_t qbatch_start = 0;
|
||||
|
||||
// TODO: move into Gemma?
|
||||
Activations activations;
|
||||
activations.Allocate<TConfig>(kNumQueries);
|
||||
activations.Allocate<TConfig>(kNumQueries, pools);
|
||||
|
||||
const QueriesPromptTokens prompt_span(&prompt, kNumQueries);
|
||||
QueriesPos pos_span(&pos, kNumQueries);
|
||||
const KVCaches kv_caches{&kv_cache, kNumQueries};
|
||||
|
||||
GenerateT<TConfig>(weights_u8, activations, runtime_config, prompt_span,
|
||||
pos_span, qbatch_start, kv_caches, pools, timing_info);
|
||||
pos_span, qbatch_start, kv_caches, timing_info);
|
||||
}
|
||||
|
||||
template <class TConfig>
|
||||
|
|
@ -1026,7 +978,7 @@ void GenerateBatchT(const ByteStorageT& weights_u8,
|
|||
(TConfig::kGriffinLayers > 0) ? 1 : runtime_config.decode_qbatch_size;
|
||||
|
||||
Activations activations;
|
||||
activations.Allocate<TConfig>(max_qbatch_size);
|
||||
activations.Allocate<TConfig>(max_qbatch_size, pools);
|
||||
|
||||
for (size_t qbatch_start = 0; qbatch_start < num_queries;
|
||||
qbatch_start += max_qbatch_size) {
|
||||
|
|
@ -1038,7 +990,7 @@ void GenerateBatchT(const ByteStorageT& weights_u8,
|
|||
QueriesPos qbatch_pos(&queries_pos[qbatch_start], qbatch_size);
|
||||
const KVCaches qbatch_kv(&kv_caches[qbatch_start], qbatch_size);
|
||||
GenerateT<TConfig>(weights_u8, activations, runtime_config, qbatch_prompts,
|
||||
qbatch_pos, qbatch_start, qbatch_kv, pools, timing_info);
|
||||
qbatch_pos, qbatch_start, qbatch_kv, timing_info);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@
|
|||
#include "compression/compress.h"
|
||||
#include "gemma/common.h"
|
||||
#include "gemma/configs.h"
|
||||
#include "util/allocator.h"
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
|
|
|||
ops/matmul-inl.h (770 changed lines)
@ -13,19 +13,10 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Include guard for non-SIMD code.
|
||||
#ifndef THIRD_PARTY_GEMMA_CPP_OPS_MATMUL_INL_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_OPS_MATMUL_INL_H_
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
#include "hwy/profiler.h" // temporarily disabled
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_OPS_MATMUL_INL_H_
|
||||
#include "compression/compress.h" // IWYU pragma: keep, b/conditionally used
|
||||
#include "ops/matmul.h" // IWYU pragma: export
|
||||
|
||||
// Include guard for (potentially) SIMD code.
|
||||
#if defined(THIRD_PARTY_GEMMA_CPP_MATMUL_TOGGLE) == defined(HWY_TARGET_TOGGLE)
|
||||
|
|
@ -35,6 +26,8 @@
|
|||
#define THIRD_PARTY_GEMMA_CPP_MATMUL_TOGGLE
|
||||
#endif
|
||||
|
||||
#include "hwy/highway.h"
|
||||
// After highway.h
|
||||
#include "compression/compress-inl.h"
|
||||
#include "hwy/contrib/math/math-inl.h"
|
||||
|
||||
|
|
@ -43,355 +36,392 @@ namespace gcpp {
|
|||
namespace HWY_NAMESPACE {
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
|
||||
// A square kernel minimizes the ratio of loads to FMA. 4x 128-bit corresponds
|
||||
// to one cache line.
|
||||
constexpr size_t kRegRows = 4;
|
||||
// The MatMul result C[r,c] is Dot(A.Row(r), B.Col(c)). To reduce the number of
|
||||
// loads, we reuse the same A row for several B columns, which are also loaded
|
||||
// once for several rows of C. Thus we produce one 'tile' of C at a time of
|
||||
// dimensions `kRegRows` x `kRegCols`. The Reg naming is because these are
|
||||
// limited by the number of registers: 32 for NEON/SVE/AVX-512. `kRegCols` == 4
|
||||
// enables the `StoreInterleaved4` transpose in `AddHorizontalSums`. We assume
|
||||
// and verify that `C.cols % kRegCols == 0`.
|
||||
constexpr size_t kRegCols = 4;
|
||||
|
||||
// Initializes a reg-tile of C: if kAdd, `add[add_ofs + c]`; otherwise 0.
|
||||
// `add` has no scale, and if `kAdd` is a row vector with A.cols entries,
|
||||
// Choosing `kRegRows == kRegCols` minimizes the ratio of loads to FMA, because
|
||||
// we load `kRegCols + kRegRows` vectors per `kRegRows * kRegCols` element tile.
|
||||
// In general, `batch_size` (C rows) is not a multiple of `kRegRows`. Thus
|
||||
// functions that load or store a tile are parameterized on `kNumRows`, which is
|
||||
// generally `kRegRows`, but `batch_size % kRegRows` on the last row (if != 0).
|
||||
constexpr size_t kRegRows = kRegCols;
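As a rough illustration of the comment above (not code from this file): the batch_size x C.cols output is covered by kRegRows x kRegCols tiles, and the bottom row of tiles uses the remainder when batch_size is not a multiple of kRegRows.

```cpp
#include <stddef.h>

#include <algorithm>

constexpr size_t kRegCols = 4;
constexpr size_t kRegRows = kRegCols;

// Visits each output tile; `rows_in_tile` corresponds to kNumRows in the
// real kernel. Assumes C.cols % kRegCols == 0, as MatMul verifies.
template <class Func>
void ForEachTile(size_t batch_size, size_t c_cols, const Func& func) {
  for (size_t r = 0; r < batch_size; r += kRegRows) {
    const size_t rows_in_tile = std::min(kRegRows, batch_size - r);
    for (size_t c = 0; c < c_cols; c += kRegCols) {
      func(r, c, rows_in_tile);  // compute the rows_in_tile x kRegCols tile
    }
  }
}
```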
|
||||
|
||||
// NEON_BF16/SVE/AVX3_ZEN4 have instructions for bf16 * bf16 + f32 which are
|
||||
// more efficient than f32 * f32 + f32 because they process twice as many lanes
|
||||
// at a time. Any combination of A and B can be bf16: activations may already be
|
||||
// bf16, and weights can be decompressed to bf16.
|
||||
//
|
||||
// The corresponding op is `ReorderWidenMulAccumulate`, and it is always
|
||||
// supported, but only useful if it returns a single vector of pairwise sums
|
||||
// `a[0] * b[0] + a[1] * b[1]`. On other targets, `ReorderWidenMulAccumulate`
|
||||
// instead returns `a[1] * b[1]` in its `sum1` output. We cannot afford to keep
|
||||
// a `sum1` for each of the `kRegRows * kRegCols` C vectors, and it would be
|
||||
// expensive to add each `sum0` and `sum1`, hence we only 'decompress' A and B
|
||||
// to bf16 if the native op is available. This will actually demote f32
|
||||
// activations to bf16. Otherwise, we decompress to f32 and use normal FMA.
|
||||
using MulT = hwy::If<HWY_NATIVE_DOT_BF16, BF16, float>;
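A scalar sketch (illustrative, not from the repository) of the per-lane accumulation this relies on when MulT is bf16: each f32 accumulator lane receives two adjacent products, so the lanes hold a permutation of the dot-product terms, which is harmless because only their horizontal sum is used.

```cpp
#include <stddef.h>

// One f32 lane of ReorderWidenMulAccumulate contributes
// a[2i]*b[2i] + a[2i+1]*b[2i+1]; summing all lanes gives the dot product.
// Assumes num is even.
float PairwiseDot(const float* a, const float* b, size_t num) {
  float sum = 0.0f;
  for (size_t i = 0; i + 1 < num; i += 2) {
    sum += a[i] * b[i] + a[i + 1] * b[i + 1];
  }
  return sum;
}
```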
|
||||
|
||||
// Loads two vectors at a time with element type MulT from a row of transposed
|
||||
// B. Called in a loop over col_ab. No bounds checking because `kRow` is
|
||||
// actually from B columns, which we checked is a multiple of `kRegCols`.
|
||||
template <size_t kRow, typename MatTB>
|
||||
class BRow {
|
||||
static_assert(kRow < kRegRows); // which unrolled instance we are
|
||||
using TraitsB = CompressTraits<MatTB>;
|
||||
|
||||
public:
|
||||
BRow(const Mat<const MatTB>& B, size_t row_b)
|
||||
: B_(B.ptr), B_ofs_(B.Row(row_b + kRow)) {}
|
||||
|
||||
template <class DM, class VM = hn::Vec<DM>>
|
||||
HWY_INLINE void Load2(DM d, size_t col_ab, VM& b0, VM& b1) const {
|
||||
static_assert(hwy::IsSame<hn::TFromD<DM>, MulT>());
|
||||
TraitsB::Decompress2(d, B_, B_ofs_ + col_ab, b0, b1);
|
||||
}
|
||||
|
||||
private:
|
||||
const MatTB* HWY_RESTRICT B_;
|
||||
const size_t B_ofs_;
|
||||
};
|
||||
|
||||
// Loads *two* row vectors from A via `Decompress2`, multiplies element-wise
|
||||
// with `kRegRows` x 2 row vectors from transposed B, and adds them to
|
||||
// `kRegRows` x `kRegCols` C vectors. The lanes of `C[r,c]` are thus a subset of
|
||||
// the terms of the dot products that make up the MatMul result at `r,c`.
|
||||
// No-op for the bottom-most tile where kRow >= kNumRows.
|
||||
//
|
||||
// This approach is atypical because it requires a horizontal sum, for which we
|
||||
// introduce a fast and new(?) vector-length agnostic 'transpose', see
|
||||
// `AddHorizontalSums`. Most MatMul instead broadcast one element from A and
|
||||
// multiply with one element from N columns in B to obtain N columns of C.
|
||||
// This is a poor fit for our setting:
|
||||
// - `CompressTraits` decompresses two vectors at a time;
|
||||
// - B is column-major, so unit-stride SIMD loads return a column, not values
|
||||
// from different columns, i.e. a row.
|
||||
// Both could be fixed in a packing stage, which is not implemented yet, and
|
||||
// might not be necessary otherwise. However, `ReorderWidenMulAccumulate` is
|
||||
// important for bf16 performance and incompatible with the conventional
|
||||
// approach, because its pairwise adds would add together unrelated terms.
|
||||
// By contrast, pairwise adds are fine when our C lanes are the terms of a
|
||||
// single dot product, which can be reordered or pre-reduced.
|
||||
template <size_t kRow, typename MatTA>
|
||||
class ALoadAccumulate {
|
||||
static_assert(kRow < kRegRows); // which unrolled instance we are
|
||||
using TraitsA = CompressTraits<MatTA>;
|
||||
|
||||
public:
|
||||
ALoadAccumulate(const Mat<const MatTA>& A, size_t row_ac)
|
||||
: A_(A.ptr), A_ofs_(A.Row(row_ac + kRow)) {}
|
||||
|
||||
// First iteration, col_ab = 0: initialize C0..3 instead of updating them.
|
||||
template <size_t kNumRows, class DM, class VM = hn::Vec<DM>, HWY_IF_F32_D(DM)>
|
||||
HWY_INLINE void First(DM dm, //
|
||||
const VM b00, const VM b01, const VM b10, const VM b11,
|
||||
const VM b20, const VM b21, const VM b30, const VM b31,
|
||||
VM& C0, VM& C1, VM& C2, VM& C3) const {
|
||||
static_assert(kNumRows <= kRegRows); // How many rows actually present
|
||||
if constexpr (kRow < kNumRows) {
|
||||
VM a0, a1;
|
||||
TraitsA::Decompress2(dm, A_, A_ofs_, a0, a1);
|
||||
|
||||
static_assert(kRegCols == 4);
|
||||
C0 = hn::Mul(a0, b00);
|
||||
C1 = hn::Mul(a0, b10);
|
||||
C2 = hn::Mul(a0, b20);
|
||||
C3 = hn::Mul(a0, b30);
|
||||
C0 = hn::MulAdd(a1, b01, C0);
|
||||
C1 = hn::MulAdd(a1, b11, C1);
|
||||
C2 = hn::MulAdd(a1, b21, C2);
|
||||
C3 = hn::MulAdd(a1, b31, C3);
|
||||
}
|
||||
}
|
||||
|
||||
// Same as above, only called if MulT == BF16.
|
||||
template <size_t kNumRows, class DM, class VM = hn::Vec<DM>,
|
||||
HWY_IF_BF16_D(DM), class DF = hn::Repartition<float, DM>,
|
||||
class VF = hn::Vec<DF>>
|
||||
HWY_INLINE void First(DM dm, //
|
||||
const VM b00, const VM b01, const VM b10, const VM b11,
|
||||
const VM b20, const VM b21, const VM b30, const VM b31,
|
||||
VF& C0, VF& C1, VF& C2, VF& C3) const {
|
||||
static_assert(kNumRows <= kRegRows); // How many rows actually present
|
||||
if constexpr (kRow < kNumRows) {
|
||||
VM a0, a1;
|
||||
TraitsA::Decompress2(dm, A_, A_ofs_, a0, a1);
|
||||
|
||||
const DF df;
|
||||
VF unused_sum1 = hn::Zero(df);
|
||||
|
||||
static_assert(kRegCols == 4);
|
||||
C0 = hn::WidenMulPairwiseAdd(df, a0, b00);
|
||||
C1 = hn::WidenMulPairwiseAdd(df, a0, b10);
|
||||
C2 = hn::WidenMulPairwiseAdd(df, a0, b20);
|
||||
C3 = hn::WidenMulPairwiseAdd(df, a0, b30);
|
||||
C0 = hn::ReorderWidenMulAccumulate(df, a1, b01, C0, unused_sum1);
|
||||
C1 = hn::ReorderWidenMulAccumulate(df, a1, b11, C1, unused_sum1);
|
||||
C2 = hn::ReorderWidenMulAccumulate(df, a1, b21, C2, unused_sum1);
|
||||
C3 = hn::ReorderWidenMulAccumulate(df, a1, b31, C3, unused_sum1);
|
||||
|
||||
// Ensure sum1 was indeed unused.
|
||||
HWY_DASSERT(hn::AllTrue(df, hn::Eq(unused_sum1, hn::Zero(df))));
|
||||
}
|
||||
}
|
||||
|
||||
// Non-first iteration: accumulate into C0..3.
|
||||
template <size_t kNumRows, class DM, class VM = hn::Vec<DM>, HWY_IF_F32_D(DM)>
|
||||
HWY_INLINE void Next(DM dm, size_t col_ab, const VM b00, const VM b01,
|
||||
const VM b10, const VM b11, const VM b20, const VM b21,
|
||||
const VM b30, const VM b31, VM& C0, VM& C1, VM& C2,
|
||||
VM& C3) const {
|
||||
static_assert(kNumRows <= kRegRows); // How many rows actually present
|
||||
HWY_DASSERT(col_ab >= 2 * hn::Lanes(dm)); // Should not be first iteration.
|
||||
if constexpr (kRow < kNumRows) {
|
||||
VM a0, a1;
|
||||
TraitsA::Decompress2(dm, A_, A_ofs_ + col_ab, a0, a1);
|
||||
|
||||
static_assert(kRegCols == 4);
|
||||
C0 = hn::MulAdd(a0, b00, C0);
|
||||
C1 = hn::MulAdd(a0, b10, C1);
|
||||
C2 = hn::MulAdd(a0, b20, C2);
|
||||
C3 = hn::MulAdd(a0, b30, C3);
|
||||
C0 = hn::MulAdd(a1, b01, C0);
|
||||
C1 = hn::MulAdd(a1, b11, C1);
|
||||
C2 = hn::MulAdd(a1, b21, C2);
|
||||
C3 = hn::MulAdd(a1, b31, C3);
|
||||
}
|
||||
}
|
||||
|
||||
// Same as above, only called if MulT == BF16.
|
||||
template <size_t kNumRows, class DM, class VM = hn::Vec<DM>,
|
||||
HWY_IF_BF16_D(DM), class DF = hn::Repartition<float, DM>,
|
||||
class VF = hn::Vec<DF>>
|
||||
HWY_INLINE void Next(DM dm, size_t col_ab, const VM b00, const VM b01,
|
||||
const VM b10, const VM b11, const VM b20, const VM b21,
|
||||
const VM b30, const VM b31, VF& C0, VF& C1, VF& C2,
|
||||
VF& C3) const {
|
||||
static_assert(kNumRows <= kRegRows); // How many rows actually present
|
||||
HWY_DASSERT(col_ab >= 2 * hn::Lanes(dm)); // Should not be first iteration.
|
||||
if constexpr (kRow < kNumRows) {
|
||||
VM a0, a1;
|
||||
TraitsA::Decompress2(dm, A_, A_ofs_ + col_ab, a0, a1);
|
||||
|
||||
const DF df;
|
||||
hn::Vec<DF> unused_sum1 = hn::Zero(df);
|
||||
|
||||
static_assert(kRegCols == 4);
|
||||
C0 = hn::ReorderWidenMulAccumulate(df, a0, b00, C0, unused_sum1);
|
||||
C1 = hn::ReorderWidenMulAccumulate(df, a0, b10, C1, unused_sum1);
|
||||
C2 = hn::ReorderWidenMulAccumulate(df, a0, b20, C2, unused_sum1);
|
||||
C3 = hn::ReorderWidenMulAccumulate(df, a0, b30, C3, unused_sum1);
|
||||
C0 = hn::ReorderWidenMulAccumulate(df, a1, b01, C0, unused_sum1);
|
||||
C1 = hn::ReorderWidenMulAccumulate(df, a1, b11, C1, unused_sum1);
|
||||
C2 = hn::ReorderWidenMulAccumulate(df, a1, b21, C2, unused_sum1);
|
||||
C3 = hn::ReorderWidenMulAccumulate(df, a1, b31, C3, unused_sum1);
|
||||
|
||||
// Ensure sum1 was indeed unused.
|
||||
HWY_DASSERT(hn::AllTrue(df, hn::Eq(unused_sum1, hn::Zero(df))));
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
const MatTA* HWY_RESTRICT A_;
|
||||
const size_t A_ofs_;
|
||||
}; // ALoadAccumulate
|
||||
|
||||
// Sets a `kRegRows` x `kRegCols` tile of C to `add[add_ofs + c]` if kAdd,
|
||||
// otherwise 0.
|
||||
// `add` has no scale and is a row vector with A.cols entries if `kAdd`,
|
||||
// otherwise nullptr. In the latter case, adding `add_ofs` to it would be UB,
|
||||
// hence we pass it as a separate argument.
|
||||
template <size_t kNumRows, bool kAdd>
|
||||
HWY_INLINE void InitC(const float* HWY_RESTRICT add, size_t add_ofs,
|
||||
float* HWY_RESTRICT pos_c, size_t stride_c) {
|
||||
const hn::FixedTag<float, kRegCols> d4;
|
||||
for (size_t r = 0; r < HWY_MIN(kNumRows, kRegRows); ++r) {
|
||||
for (size_t c = 0; c < kRegCols; ++c) {
|
||||
if constexpr (kAdd) {
|
||||
pos_c[r * stride_c + c] = add[add_ofs + c];
|
||||
} else {
|
||||
pos_c[r * stride_c + c] = 0.0f;
|
||||
}
|
||||
if constexpr (kAdd) {
|
||||
hn::StoreU(hn::LoadU(d4, add + add_ofs), d4, pos_c + r * stride_c);
|
||||
} else {
|
||||
hn::StoreU(hn::Zero(d4), d4, pos_c + r * stride_c);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// c## are partial sums of the products of A and B; their horizontal sums are
|
||||
// the final matmul result, stored in C, which is always f32.
|
||||
template <size_t kNumRows, class DF, class VF = hn::Vec<DF>>
|
||||
HWY_INLINE void AddHorizontalSums(DF df, float scale, //
|
||||
VF c00, VF c01, VF c02, VF c03, //
|
||||
VF c10, VF c11, VF c12, VF c13, //
|
||||
VF c20, VF c21, VF c22, VF c23, //
|
||||
VF c30, VF c31, VF c32, VF c33, //
|
||||
float* HWY_RESTRICT tile_c, size_t stride_c) {
|
||||
// We are computing the product of (4, 4N) * (4N, 4) = (4, 4) tiles.
|
||||
// Each entry of C[r,c] is a dot product of A.row and B.col, which reside in
|
||||
// the lanes of `c$r$c`, so we store their horizontal sum (ReduceSum). This is
|
||||
// expensive, but only a fraction of the A.cols/N FMAs.
|
||||
// TODO: 4x4 transpose, then 128-bit vector FMA?
|
||||
tile_c[stride_c * 0 + 0] += scale * hn::ReduceSum(df, c00);
|
||||
tile_c[stride_c * 0 + 1] += scale * hn::ReduceSum(df, c01);
|
||||
tile_c[stride_c * 0 + 2] += scale * hn::ReduceSum(df, c02);
|
||||
tile_c[stride_c * 0 + 3] += scale * hn::ReduceSum(df, c03);
|
||||
if (kNumRows == 1) return;
|
||||
|
||||
tile_c[stride_c * 1 + 0] += scale * hn::ReduceSum(df, c10);
|
||||
tile_c[stride_c * 1 + 1] += scale * hn::ReduceSum(df, c11);
|
||||
tile_c[stride_c * 1 + 2] += scale * hn::ReduceSum(df, c12);
|
||||
tile_c[stride_c * 1 + 3] += scale * hn::ReduceSum(df, c13);
|
||||
if (kNumRows == 2) return;
|
||||
|
||||
tile_c[stride_c * 2 + 0] += scale * hn::ReduceSum(df, c20);
|
||||
tile_c[stride_c * 2 + 1] += scale * hn::ReduceSum(df, c21);
|
||||
tile_c[stride_c * 2 + 2] += scale * hn::ReduceSum(df, c22);
|
||||
tile_c[stride_c * 2 + 3] += scale * hn::ReduceSum(df, c23);
|
||||
if (kNumRows == 3) return;
|
||||
|
||||
tile_c[stride_c * 3 + 0] += scale * hn::ReduceSum(df, c30);
|
||||
tile_c[stride_c * 3 + 1] += scale * hn::ReduceSum(df, c31);
|
||||
tile_c[stride_c * 3 + 2] += scale * hn::ReduceSum(df, c32);
|
||||
tile_c[stride_c * 3 + 3] += scale * hn::ReduceSum(df, c33);
|
||||
}
|
||||
|
||||
// Wrapper to simplify call sites. T can be const or non-const.
template <typename T>
struct Mat {
  bool NotEmpty() const {
    return ptr != nullptr && cols != 0 && stride >= cols;
  }
  size_t Row(size_t r) const { return ofs + stride * r; }

  T* HWY_RESTRICT ptr;
  size_t cols;

  // elements between rows, which is typically the same as `cols`.
  size_t stride;

  // Offset to add to `ptr`; separate because T=NuqStream does not support
  // pointer arithmetic.
  size_t ofs;
};

// Accumulates into a tile of C.
template <size_t kNumRows>
class AddHorizontalSums {
  // These helper functions hoist if() out of the main code below. They have no
  // effect if kRow >= kNumRows.
  template <size_t kRow, class DF, class VF = hn::Vec<DF>>
  static void MaybeStoreInterleaved4(DF df, size_t N, VF Cr0, VF Cr1, VF Cr2,
                                     VF Cr3, float* HWY_RESTRICT buf) {
    if constexpr (kRow < kNumRows) {
      hn::StoreInterleaved4(Cr0, Cr1, Cr2, Cr3, df, buf + 4 * kRow * N);
    }
  }

  // Note: N is the number of lanes in the StoreInterleaved4 vectors, not V4.
  template <size_t kRow, class D4, class V4 = hn::Vec<D4>>
  static V4 MaybeLoad(D4 df, size_t N, const float* HWY_RESTRICT buf) {
    if constexpr (kRow < kNumRows) {
      return hn::Load(df, buf + 4 * kRow * N);
    } else {
      return hn::Zero(df);
    }
  }

  template <size_t kRow, class D4, class V4 = hn::Vec<D4>>
  static V4 MaybeAdd(D4 df, size_t N, V4 sum, const float* HWY_RESTRICT buf) {
    if constexpr (kRow < kNumRows) {
      return hn::Add(sum, hn::Load(df, buf + 4 * kRow * N));
    } else {
      return sum;
    }
  }

  template <size_t kRow, class D4, class V4 = hn::Vec<D4>>
  static void MaybeMulAdd(D4 df, V4 sum, V4 scale, float* HWY_RESTRICT tile_c,
                          const size_t stride_c) {
    if constexpr (kRow < kNumRows) {
      const V4 prev_c = hn::LoadU(df, tile_c + kRow * stride_c);
      hn::StoreU(hn::MulAdd(sum, scale, prev_c), df, tile_c + kRow * stride_c);
    }
  }

 public:
// Adds the contribution from `Crc` accumulators to the 4x4 tile of C whose
|
||||
// top left is `tile_c`, after multiplying by `scale`, which is the product of
|
||||
// the scales of A and B. C is always f32 to ensure sufficient precision.
|
||||
//
|
||||
// `Crc` are the 16 combinations of an A row vector indexed by `r`, times a
|
||||
// B column vector indexed by `c`. Their elements are thus a subset of the
|
||||
// terms of the dot product constituting the final `C[r, c]` result. Thus we
|
||||
// compute the horizontal sums of each `Crc`. The elements may be permuted
|
||||
// because we multiply bf16 via `ReorderWidenMulAccumulate`, but this does
|
||||
// not change their horizontal sum. `buf` is thread-local space for 16 `VF`.
|
||||
template <class DF, class VF = hn::Vec<DF>>
|
||||
HWY_INLINE void operator()(DF df, float scale, //
|
||||
VF C00, VF C01, VF C02, VF C03, //
|
||||
VF C10, VF C11, VF C12, VF C13, //
|
||||
VF C20, VF C21, VF C22, VF C23, //
|
||||
VF C30, VF C31, VF C32, VF C33, //
|
||||
float* HWY_RESTRICT buf,
|
||||
float* HWY_RESTRICT tile_c,
|
||||
size_t stride_c) const {
|
||||
const size_t N = hn::Lanes(df);
|
||||
// Horizontal reductions (`ReduceSum`) are rather expensive, entailing
|
||||
// log(N) operations for vectors of length N. Because kRegCols == 4, we can
|
||||
// instead use `StoreInterleaved4` for a vector length-agnostic 'transpose':
|
||||
// `buf[0, 4 * N)` holds C00[0], C01[0], C02[0], C03[0],
|
||||
// C00[1], C01[1], C02[1], C03[1] .. C00[N-1], C01[N-1], C02[N-1], C03[N-1].
|
||||
MaybeStoreInterleaved4<0>(df, N, C00, C01, C02, C03, buf);
|
||||
MaybeStoreInterleaved4<1>(df, N, C10, C11, C12, C13, buf);
|
||||
MaybeStoreInterleaved4<2>(df, N, C20, C21, C22, C23, buf);
|
||||
MaybeStoreInterleaved4<3>(df, N, C30, C31, C32, C33, buf);
|
||||
// Adding N consecutive V4 yields four horizontal sums of Cr0, Cr1, Cr2, Cr3
|
||||
// in the elements of one V4. We have four independent rows `r`, hence the
|
||||
// code is effectively unrolled, which increases throughput.
|
||||
const hn::FixedTag<float, 4> d4;
|
||||
using V4 = hn::Vec<decltype(d4)>;
|
||||
V4 sum0 = MaybeLoad<0>(d4, N, buf);
|
||||
V4 sum1 = MaybeLoad<1>(d4, N, buf);
|
||||
V4 sum2 = MaybeLoad<2>(d4, N, buf);
|
||||
V4 sum3 = MaybeLoad<3>(d4, N, buf);
|
||||
|
||||
for (size_t i = 1; i < N; ++i) {
|
||||
sum0 = MaybeAdd<0>(d4, N, sum0, buf + 4 * i);
|
||||
sum1 = MaybeAdd<1>(d4, N, sum1, buf + 4 * i);
|
||||
sum2 = MaybeAdd<2>(d4, N, sum2, buf + 4 * i);
|
||||
sum3 = MaybeAdd<3>(d4, N, sum3, buf + 4 * i);
|
||||
}
|
||||
// Scale, then store to four elements per row of `tile_c`.
|
||||
const V4 vscale = hn::Set(d4, scale);
|
||||
MaybeMulAdd<0>(d4, sum0, vscale, tile_c, stride_c);
|
||||
MaybeMulAdd<1>(d4, sum1, vscale, tile_c, stride_c);
|
||||
MaybeMulAdd<2>(d4, sum2, vscale, tile_c, stride_c);
|
||||
MaybeMulAdd<3>(d4, sum3, vscale, tile_c, stride_c);
|
||||
}
|
||||
};
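The class above avoids per-accumulator `ReduceSum` by interleaving the four accumulators of each row and then summing groups of four lanes. The standalone sketch below models that reduction in plain C++; the sizes, values, and buffer layout are made up for illustration and this is not the Highway implementation.

```
// Scalar model of the StoreInterleaved4 "transpose" reduction (illustration only).
#include <array>
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  constexpr size_t kRows = 4, kCols = 4, N = 8;  // N stands in for hn::Lanes(df)
  // C[r][c][i]: lane i of the accumulator for row r, column c of the tile.
  float C[kRows][kCols][N];
  for (size_t r = 0; r < kRows; ++r)
    for (size_t c = 0; c < kCols; ++c)
      for (size_t i = 0; i < N; ++i)
        C[r][c][i] = 0.01f * (r + 1) * (c + 1) * (i + 1);

  // Emulate StoreInterleaved4: for each row r, buf holds
  // Cr0[i], Cr1[i], Cr2[i], Cr3[i] for i = 0..N-1.
  std::vector<float> buf(kRows * kCols * N);
  for (size_t r = 0; r < kRows; ++r)
    for (size_t i = 0; i < N; ++i)
      for (size_t c = 0; c < kCols; ++c) buf[4 * r * N + 4 * i + c] = C[r][c][i];

  // Summing the N groups of four yields all four horizontal sums of a row at once.
  for (size_t r = 0; r < kRows; ++r) {
    std::array<float, 4> sum{};  // plays the role of one V4
    for (size_t i = 0; i < N; ++i)
      for (size_t c = 0; c < kCols; ++c) sum[c] += buf[4 * r * N + 4 * i + c];
    // Cross-check against a direct horizontal reduction of each accumulator.
    for (size_t c = 0; c < kCols; ++c) {
      float direct = 0.0f;
      for (size_t i = 0; i < N; ++i) direct += C[r][c][i];
      std::printf("row %zu col %zu: interleaved %.3f direct %.3f\n", r, c,
                  sum[c], direct);
    }
  }
  return 0;
}
```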
|
||||
|
||||
template <typename T>
|
||||
Mat<T> MakeMat(T* HWY_RESTRICT ptr, size_t cols, size_t stride,
|
||||
size_t ofs = 0) {
|
||||
return Mat<T>{.ptr = ptr, .cols = cols, .stride = stride, .ofs = ofs};
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Mat<T> MakeMat(T* HWY_RESTRICT ptr, size_t cols) {
|
||||
return MakeMat(ptr, cols, cols);
|
||||
}
|
||||
|
||||
// Inner loop of the kernel, called once per kRegRows. c[r] += a[c] * b[r,c].
|
||||
// The col_ab loop is unrolled 2x, so we have a0/a1 and b00/b01 etc.
|
||||
template <class VF>
|
||||
HWY_INLINE void UpdateTileRow(const VF& a0, const VF& a1, const VF& b00,
|
||||
const VF& b01, const VF& b10, const VF& b11,
|
||||
const VF& b20, const VF& b21, const VF& b30,
|
||||
const VF& b31, VF& c0, VF& c1, VF& c2, VF& c3) {
|
||||
c0 = hn::MulAdd(a0, b00, c0);
|
||||
c1 = hn::MulAdd(a0, b10, c1);
|
||||
c2 = hn::MulAdd(a0, b20, c2);
|
||||
c3 = hn::MulAdd(a0, b30, c3);
|
||||
c0 = hn::MulAdd(a1, b01, c0);
|
||||
c1 = hn::MulAdd(a1, b11, c1);
|
||||
c2 = hn::MulAdd(a1, b21, c2);
|
||||
c3 = hn::MulAdd(a1, b31, c3);
|
||||
}
|
||||
|
||||
// Special case for the first iteration: c## are zero, so skip the first add.
|
||||
template <class VF>
|
||||
HWY_INLINE void FirstTileRow(const VF& a0, const VF& a1, const VF& b00,
|
||||
const VF& b01, const VF& b10, const VF& b11,
|
||||
const VF& b20, const VF& b21, const VF& b30,
|
||||
const VF& b31, VF& c0, VF& c1, VF& c2, VF& c3) {
|
||||
c0 = hn::Mul(a0, b00);
|
||||
c1 = hn::Mul(a0, b10);
|
||||
c2 = hn::Mul(a0, b20);
|
||||
c3 = hn::Mul(a0, b30);
|
||||
c0 = hn::MulAdd(a1, b01, c0);
|
||||
c1 = hn::MulAdd(a1, b11, c1);
|
||||
c2 = hn::MulAdd(a1, b21, c2);
|
||||
c3 = hn::MulAdd(a1, b31, c3);
|
||||
}
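For reference, the quantity these tile kernels accumulate is sixteen dot products over the shared `col_ab` dimension, with B stored transposed. The plain-C++ sketch below computes one 4x4 tile under assumed row-major layouts and made-up sizes; it is a reference model, not the vectorized kernel itself.

```
// Reference for one 4x4 tile: C[r][c] = dot(A.row(r), Bt.row(c)), Bt = transposed B.
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  const size_t kCols = 16;  // shared inner dimension (A.cols), example value
  std::vector<float> A(4 * kCols), Bt(4 * kCols);
  for (size_t i = 0; i < A.size(); ++i) A[i] = 0.5f + 0.01f * i;
  for (size_t i = 0; i < Bt.size(); ++i) Bt[i] = 1.0f - 0.005f * i;

  float C[4][4] = {};
  for (size_t r = 0; r < 4; ++r)          // rows of A and C
    for (size_t c = 0; c < 4; ++c)        // rows of transposed B = columns of C
      for (size_t k = 0; k < kCols; ++k)  // the loop the kernel strides by N (or 2N)
        C[r][c] += A[r * kCols + k] * Bt[c * kCols + k];

  for (size_t r = 0; r < 4; ++r)
    std::printf("%8.3f %8.3f %8.3f %8.3f\n", C[r][0], C[r][1], C[r][2], C[r][3]);
  return 0;
}
```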
|
||||
|
||||
#undef GEMMA_NATIVE_BF16
|
||||
#if HWY_IDE || (defined(HWY_NATIVE_REORDER_WIDEN_MUL_ACC_BF16) == \
|
||||
defined(HWY_TARGET_TOGGLE))
|
||||
#define GEMMA_NATIVE_BF16 1
|
||||
#else
|
||||
#define GEMMA_NATIVE_BF16 0
|
||||
#endif
|
||||
|
||||
#if GEMMA_NATIVE_BF16
|
||||
|
||||
// Specializations for f32 += bf16 * bf16 that avoid promoting to f32.
|
||||
|
||||
// Inner loop as above, but not unrolled. c[r] += a * b[r].
|
||||
template <class DF, class VF = hn::Vec<DF>,
|
||||
class VBF16 = hn::Vec<hn::Repartition<BF16, DF>>>
|
||||
HWY_INLINE void UpdateTileRow(DF df, const VBF16& a, const VBF16& b0,
|
||||
const VBF16& b1, const VBF16& b2, const VBF16& b3,
|
||||
VF& c0, VF& c1, VF& c2, VF& c3) {
|
||||
|
||||
VF unused_sum1 = hn::Zero(df);
|
||||
c0 = hn::ReorderWidenMulAccumulate(df, a, b0, c0, unused_sum1);
|
||||
c1 = hn::ReorderWidenMulAccumulate(df, a, b1, c1, unused_sum1);
|
||||
c2 = hn::ReorderWidenMulAccumulate(df, a, b2, c2, unused_sum1);
|
||||
c3 = hn::ReorderWidenMulAccumulate(df, a, b3, c3, unused_sum1);
|
||||
|
||||
// Ensure sum1 was indeed unused.
|
||||
HWY_DASSERT(hn::AllTrue(df, hn::Eq(unused_sum1, hn::Zero(df))));
|
||||
}
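The bf16 path relies on widening a pair of bf16 lanes into one f32 accumulator lane. The scalar model below illustrates that idea; converting to bf16 by mantissa truncation is a simplification for illustration and only approximates the rounding actually used.

```
// Scalar model of f32 += bf16 * bf16 with pairwise widening (illustration only).
#include <cstdint>
#include <cstdio>
#include <cstring>

static float ToBF16AndBack(float x) {
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  bits &= 0xFFFF0000u;  // keep sign, exponent and the top 7 mantissa bits
  std::memcpy(&x, &bits, sizeof(x));
  return x;
}

int main() {
  const float a[4] = {1.001f, 2.002f, 3.003f, 4.004f};
  const float b[4] = {0.5f, 0.25f, 0.125f, 0.0625f};
  float sum = 0.0f;  // one f32 accumulator lane covers two bf16 lanes
  for (int i = 0; i < 4; i += 2) {
    sum += ToBF16AndBack(a[i]) * ToBF16AndBack(b[i]) +
           ToBF16AndBack(a[i + 1]) * ToBF16AndBack(b[i + 1]);
  }
  std::printf("accumulated %.6f (f32 reference %.6f)\n", sum,
              a[0] * b[0] + a[1] * b[1] + a[2] * b[2] + a[3] * b[3]);
  return 0;
}
```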
|
||||
|
||||
// Special case for the first iteration: c## are zero, so skip the first add.
|
||||
template <class DF, class VF = hn::Vec<DF>,
|
||||
class VBF16 = hn::Vec<hn::Repartition<BF16, DF>>>
|
||||
HWY_INLINE void FirstTileRow(DF df, const VBF16& a, const VBF16& b0,
|
||||
const VBF16& b1, const VBF16& b2, const VBF16& b3,
|
||||
VF& c0, VF& c1, VF& c2, VF& c3) {
|
||||
c0 = hn::WidenMulPairwiseAdd(df, a, b0);
|
||||
c1 = hn::WidenMulPairwiseAdd(df, a, b1);
|
||||
c2 = hn::WidenMulPairwiseAdd(df, a, b2);
|
||||
c3 = hn::WidenMulPairwiseAdd(df, a, b3);
|
||||
}
|
||||
|
||||
template <size_t kNumRows, bool kAdd>
|
||||
HWY_INLINE void MatMulTile(const Mat<const BF16>& A, const Mat<const BF16>& B,
|
||||
const size_t row_ac, const size_t row_b_col_c,
|
||||
const float scale, const float* HWY_RESTRICT add,
|
||||
const Mat<float>& C) {
|
||||
const hn::ScalableTag<float> df;
|
||||
using VF = hn::Vec<decltype(df)>;
|
||||
// ReorderWidenMulAccumulate does not use its sum1 arg and we can use full
|
||||
// bf16 vectors.
|
||||
const hn::Repartition<BF16, decltype(df)> d;
|
||||
const size_t N = hn::Lanes(d);
|
||||
using V = hn::Vec<decltype(d)>;
|
||||
V b0, b1, b2, b3; // one from each row
|
||||
VF c00, c01, c02, c03;
|
||||
VF c10, c11, c12, c13;
|
||||
VF c20, c21, c22, c23;
|
||||
VF c30, c31, c32, c33;
|
||||
|
||||
const BF16* HWY_RESTRICT A_tile = A.ptr + A.Row(row_ac);
|
||||
const BF16* HWY_RESTRICT B_tile = B.ptr + B.Row(row_b_col_c);
|
||||
float* HWY_RESTRICT C_tile = C.ptr + C.Row(row_ac) + row_b_col_c;
|
||||
InitC<kNumRows, kAdd>(add, row_b_col_c, C_tile, C.stride);
|
||||
|
||||
size_t col_ab = 0;
|
||||
|
||||
// First iteration initializes the c## vectors.
|
||||
{
|
||||
b0 = hn::LoadU(d, B_tile + B.stride * 0 + col_ab);
|
||||
b1 = hn::LoadU(d, B_tile + B.stride * 1 + col_ab);
|
||||
b2 = hn::LoadU(d, B_tile + B.stride * 2 + col_ab);
|
||||
b3 = hn::LoadU(d, B_tile + B.stride * 3 + col_ab);
|
||||
|
||||
{
|
||||
const V a = hn::LoadU(d, A_tile + A.stride * 0 + col_ab);
|
||||
FirstTileRow(df, a, b0, b1, b2, b3, c00, c01, c02, c03);
|
||||
}
|
||||
if constexpr (kNumRows > 1) {
|
||||
const V a = hn::LoadU(d, A_tile + A.stride * 1 + col_ab);
|
||||
FirstTileRow(df, a, b0, b1, b2, b3, c10, c11, c12, c13);
|
||||
}
|
||||
if constexpr (kNumRows > 2) {
|
||||
const V a = hn::LoadU(d, A_tile + A.stride * 2 + col_ab);
|
||||
FirstTileRow(df, a, b0, b1, b2, b3, c20, c21, c22, c23);
|
||||
}
|
||||
if constexpr (kNumRows > 3) {
|
||||
const V a = hn::LoadU(d, A_tile + A.stride * 3 + col_ab);
|
||||
FirstTileRow(df, a, b0, b1, b2, b3, c30, c31, c32, c33);
|
||||
}
|
||||
}
|
||||
|
||||
// Loop over columns of A and columns of the transposed B, in steps of N.
|
||||
// Accumulates into the c## vectors.
|
||||
HWY_UNROLL(1)
|
||||
for (col_ab += N; col_ab < A.cols; col_ab += N) {
|
||||
b0 = hn::LoadU(d, B_tile + B.stride * 0 + col_ab);
|
||||
b1 = hn::LoadU(d, B_tile + B.stride * 1 + col_ab);
|
||||
b2 = hn::LoadU(d, B_tile + B.stride * 2 + col_ab);
|
||||
b3 = hn::LoadU(d, B_tile + B.stride * 3 + col_ab);
|
||||
|
||||
{
|
||||
const V a = hn::LoadU(d, A_tile + A.stride * 0 + col_ab);
|
||||
UpdateTileRow(df, a, b0, b1, b2, b3, c00, c01, c02, c03);
|
||||
}
|
||||
if constexpr (kNumRows > 1) {
|
||||
const V a = hn::LoadU(d, A_tile + A.stride * 1 + col_ab);
|
||||
UpdateTileRow(df, a, b0, b1, b2, b3, c10, c11, c12, c13);
|
||||
}
|
||||
if constexpr (kNumRows > 2) {
|
||||
const V a = hn::LoadU(d, A_tile + A.stride * 2 + col_ab);
|
||||
UpdateTileRow(df, a, b0, b1, b2, b3, c20, c21, c22, c23);
|
||||
}
|
||||
if constexpr (kNumRows > 3) {
|
||||
const V a = hn::LoadU(d, A_tile + A.stride * 3 + col_ab);
|
||||
UpdateTileRow(df, a, b0, b1, b2, b3, c30, c31, c32, c33);
|
||||
}
|
||||
}
|
||||
|
||||
AddHorizontalSums<kNumRows>(df, scale, c00, c01, c02, c03, c10, c11, c12, c13,
|
||||
c20, c21, c22, c23, c30, c31, c32, c33, C_tile,
|
||||
C.stride);
|
||||
}
|
||||
|
||||
#endif // GEMMA_NATIVE_BF16
|
||||
|
||||
// Streams a `(kNumRows, 4)` strip of `A` and the transposed `B`, then writes a
|
||||
// finished tile of `C`.
|
||||
// General case: uses CompressTraits to load from A and B.
|
||||
// Streams a `kNumRows` high strip of `A` and the transposed `B`, then writes a
|
||||
// *finished* tile of f32 `C` whose top left is (row_ac, row_b_col_c).
|
||||
// TODO: loop over sections instead of full rows and accumulate into `tile_c`.
|
||||
template <size_t kNumRows, bool kAdd, typename MatTA, typename MatTB>
|
||||
HWY_INLINE void MatMulTile(const Mat<MatTA>& A, const Mat<MatTB>& B,
|
||||
HWY_INLINE void MatMulTile(const Mat<const MatTA>& A, const Mat<const MatTB>& B,
|
||||
const size_t row_ac, const size_t row_b_col_c,
|
||||
const float scale, const float* HWY_RESTRICT add,
|
||||
const Mat<float>& C) {
|
||||
using TraitsA = CompressTraits<hwy::RemoveConst<MatTA>>;
|
||||
using TraitsB = CompressTraits<hwy::RemoveConst<MatTB>>;
|
||||
float* HWY_RESTRICT buf, const Mat<float>& C) {
|
||||
// For 'decompressing' A and B into BF16 or float.
|
||||
const hn::ScalableTag<MulT> dm;
|
||||
using VM = hn::Vec<decltype(dm)>;
|
||||
const size_t NM = hn::Lanes(dm);
|
||||
|
||||
const hn::ScalableTag<float> d32;
|
||||
const size_t N = hn::Lanes(d32);
|
||||
using V = hn::Vec<decltype(d32)>;
|
||||
V b00, b01, b10, b11, b20, b21, b30, b31; // two from each row
|
||||
V c00, c01, c02, c03;
|
||||
V c10, c11, c12, c13;
|
||||
V c20, c21, c22, c23;
|
||||
V c30, c31, c32, c33;
|
||||
static_assert(kRegRows == 4);
|
||||
const BRow<0, MatTB> b_row0(B, row_b_col_c);
|
||||
const BRow<1, MatTB> b_row1(B, row_b_col_c);
|
||||
const BRow<2, MatTB> b_row2(B, row_b_col_c);
|
||||
const BRow<3, MatTB> b_row3(B, row_b_col_c);
|
||||
|
||||
const size_t A_ofs = A.Row(row_ac);
|
||||
const size_t B_ofs = B.Row(row_b_col_c);
|
||||
const ALoadAccumulate<0, MatTA> a_row0(A, row_ac);
|
||||
const ALoadAccumulate<1, MatTA> a_row1(A, row_ac);
|
||||
const ALoadAccumulate<2, MatTA> a_row2(A, row_ac);
|
||||
const ALoadAccumulate<3, MatTA> a_row3(A, row_ac);
|
||||
|
||||
const hn::Repartition<float, decltype(dm)> df;
|
||||
using VF = hn::Vec<decltype(df)>;
|
||||
VF C00, C01, C02, C03;
|
||||
VF C10, C11, C12, C13;
|
||||
VF C20, C21, C22, C23;
|
||||
VF C30, C31, C32, C33;
|
||||
|
||||
{ // First iteration initializes the `Crc` vectors.
|
||||
VM b00, b01, b10, b11, b20, b21, b30, b31;
|
||||
b_row0.Load2(dm, /*col_ab=*/0, b00, b01);
|
||||
b_row1.Load2(dm, /*col_ab=*/0, b10, b11);
|
||||
b_row2.Load2(dm, /*col_ab=*/0, b20, b21);
|
||||
b_row3.Load2(dm, /*col_ab=*/0, b30, b31);
|
||||
|
||||
a_row0.template First<kNumRows>(dm, b00, b01, b10, b11, b20, b21, b30, b31,
|
||||
C00, C01, C02, C03);
|
||||
a_row1.template First<kNumRows>(dm, b00, b01, b10, b11, b20, b21, b30, b31,
|
||||
C10, C11, C12, C13);
|
||||
a_row2.template First<kNumRows>(dm, b00, b01, b10, b11, b20, b21, b30, b31,
|
||||
C20, C21, C22, C23);
|
||||
a_row3.template First<kNumRows>(dm, b00, b01, b10, b11, b20, b21, b30, b31,
|
||||
C30, C31, C32, C33);
|
||||
}
|
||||
|
||||
// `2 * NM` per iteration because `Load2` returns two vectors.
|
||||
HWY_UNROLL(1)
|
||||
for (size_t col_ab = 2 * NM; col_ab <= A.cols - 2 * NM; col_ab += 2 * NM) {
|
||||
VM b00, b01, b10, b11, b20, b21, b30, b31;
|
||||
b_row0.Load2(dm, col_ab, b00, b01);
|
||||
b_row1.Load2(dm, col_ab, b10, b11);
|
||||
b_row2.Load2(dm, col_ab, b20, b21);
|
||||
b_row3.Load2(dm, col_ab, b30, b31);
|
||||
|
||||
a_row0.template Next<kNumRows>(dm, col_ab, b00, b01, b10, b11, b20, b21,
|
||||
b30, b31, C00, C01, C02, C03);
|
||||
a_row1.template Next<kNumRows>(dm, col_ab, b00, b01, b10, b11, b20, b21,
|
||||
b30, b31, C10, C11, C12, C13);
|
||||
a_row2.template Next<kNumRows>(dm, col_ab, b00, b01, b10, b11, b20, b21,
|
||||
b30, b31, C20, C21, C22, C23);
|
||||
a_row3.template Next<kNumRows>(dm, col_ab, b00, b01, b10, b11, b20, b21,
|
||||
b30, b31, C30, C31, C32, C33);
|
||||
}
|
||||
|
||||
// TODO: hoist into outer loop.
|
||||
float* HWY_RESTRICT C_tile = C.ptr + C.Row(row_ac) + row_b_col_c;
|
||||
InitC<kNumRows, kAdd>(add, row_b_col_c, C_tile, C.stride);
|
||||
|
||||
// Loop over columns of A and columns of the transposed B, in steps of 2*N
|
||||
// (since we are decoding consecutive bytes at each iteration).
|
||||
// Top-left of tile is (row_ac, col_ab) for A, and (row_b_col_c,
|
||||
// col_ab) for B. First iteration initializes the c## vectors.
|
||||
size_t col_ab = 0;
|
||||
|
||||
{
|
||||
TraitsB::Decompress2(d32, B.ptr, B_ofs + B.stride * 0 + col_ab, b00, b01);
|
||||
TraitsB::Decompress2(d32, B.ptr, B_ofs + B.stride * 1 + col_ab, b10, b11);
|
||||
TraitsB::Decompress2(d32, B.ptr, B_ofs + B.stride * 2 + col_ab, b20, b21);
|
||||
TraitsB::Decompress2(d32, B.ptr, B_ofs + B.stride * 3 + col_ab, b30, b31);
|
||||
|
||||
{
|
||||
V a0, a1;
|
||||
TraitsA::Decompress2(d32, A.ptr, A_ofs + A.stride * 0 + col_ab, a0, a1);
|
||||
FirstTileRow(a0, a1, b00, b01, b10, b11, b20, b21, b30, b31, c00, c01,
|
||||
c02, c03);
|
||||
}
|
||||
if constexpr (kNumRows > 1) {
|
||||
V a0, a1;
|
||||
TraitsA::Decompress2(d32, A.ptr, A_ofs + A.stride * 1 + col_ab, a0, a1);
|
||||
FirstTileRow(a0, a1, b00, b01, b10, b11, b20, b21, b30, b31, c10, c11,
|
||||
c12, c13);
|
||||
}
|
||||
if constexpr (kNumRows > 2) {
|
||||
V a0, a1;
|
||||
TraitsA::Decompress2(d32, A.ptr, A_ofs + A.stride * 2 + col_ab, a0, a1);
|
||||
FirstTileRow(a0, a1, b00, b01, b10, b11, b20, b21, b30, b31, c20, c21,
|
||||
c22, c23);
|
||||
}
|
||||
if constexpr (kNumRows > 3) {
|
||||
V a0, a1;
|
||||
TraitsA::Decompress2(d32, A.ptr, A_ofs + A.stride * 3 + col_ab, a0, a1);
|
||||
FirstTileRow(a0, a1, b00, b01, b10, b11, b20, b21, b30, b31, c30, c31,
|
||||
c32, c33);
|
||||
}
|
||||
}
|
||||
|
||||
// Main loop: accumulates into the c## vectors.
|
||||
HWY_UNROLL(1)
|
||||
for (col_ab += 2 * N; col_ab <= A.cols - 2 * N; col_ab += 2 * N) {
|
||||
TraitsB::Decompress2(d32, B.ptr, B_ofs + B.stride * 0 + col_ab, b00, b01);
|
||||
TraitsB::Decompress2(d32, B.ptr, B_ofs + B.stride * 1 + col_ab, b10, b11);
|
||||
TraitsB::Decompress2(d32, B.ptr, B_ofs + B.stride * 2 + col_ab, b20, b21);
|
||||
TraitsB::Decompress2(d32, B.ptr, B_ofs + B.stride * 3 + col_ab, b30, b31);
|
||||
|
||||
{
|
||||
V a0, a1;
|
||||
TraitsA::Decompress2(d32, A.ptr, A_ofs + A.stride * 0 + col_ab, a0, a1);
|
||||
UpdateTileRow(a0, a1, b00, b01, b10, b11, b20, b21, b30, b31, c00, c01,
|
||||
c02, c03);
|
||||
}
|
||||
if constexpr (kNumRows > 1) {
|
||||
V a0, a1;
|
||||
TraitsA::Decompress2(d32, A.ptr, A_ofs + A.stride * 1 + col_ab, a0, a1);
|
||||
UpdateTileRow(a0, a1, b00, b01, b10, b11, b20, b21, b30, b31, c10, c11,
|
||||
c12, c13);
|
||||
}
|
||||
if constexpr (kNumRows > 2) {
|
||||
V a0, a1;
|
||||
TraitsA::Decompress2(d32, A.ptr, A_ofs + A.stride * 2 + col_ab, a0, a1);
|
||||
UpdateTileRow(a0, a1, b00, b01, b10, b11, b20, b21, b30, b31, c20, c21,
|
||||
c22, c23);
|
||||
}
|
||||
if constexpr (kNumRows > 3) {
|
||||
V a0, a1;
|
||||
TraitsA::Decompress2(d32, A.ptr, A_ofs + A.stride * 3 + col_ab, a0, a1);
|
||||
UpdateTileRow(a0, a1, b00, b01, b10, b11, b20, b21, b30, b31, c30, c31,
|
||||
c32, c33);
|
||||
}
|
||||
}
|
||||
|
||||
AddHorizontalSums<kNumRows>(d32, scale, c00, c01, c02, c03, c10, c11, c12,
|
||||
c13, c20, c21, c22, c23, c30, c31, c32, c33,
|
||||
C_tile, C.stride);
|
||||
AddHorizontalSums<kNumRows>()(df, scale, C00, C01, C02, C03, C10, C11, C12,
|
||||
C13, C20, C21, C22, C23, C30, C31, C32, C33,
|
||||
buf, C_tile, C.stride);
|
||||
}
|
||||
|
||||
// Computes the matrix product `A * B * scale [+ add]` and stores it in `C`.
|
||||
|
|
@ -402,28 +432,28 @@ HWY_INLINE void MatMulTile(const Mat<MatTA>& A, const Mat<MatTB>& B,
|
|||
//
|
||||
// `scale` allows expanding the smaller range of `SfpStream` to the original
|
||||
// values. When `A` and/or `B` are from CompressedArray, `scale` should be the
|
||||
// product of their `.scale()` values.
|
||||
// product of their `.scale()` values, otherwise 1.0f.
|
||||
//
|
||||
// If `kAdd` is true, the row-vector `add` is added to each row of `C`,
|
||||
// otherwise `add` is ignored and can be nullptr. A scale for `add` is not
|
||||
// supported, so make sure its scale is 1.
|
||||
//
|
||||
// `C` is a row-major matrix of size `(batch_size, C.cols)`.
|
||||
// Writes 4x4 tiles of C in parallel using a work-stealing thread pool.
|
||||
// Typically batch_size is 1..512, A.cols and C.cols are 3k or 24k.
|
||||
//
|
||||
// Updates 4x4 tiles of C in parallel using a work-stealing thread pool.
|
||||
// Typically `batch_size` is 1..512, `A.cols` and `C.cols` are 3k or 24k.
|
||||
// Must not be called concurrently with the same `env`.
|
||||
template <bool kAdd, typename MatTA, typename MatTB>
|
||||
HWY_NOINLINE void MatMul_4x4(const size_t batch_size, const Mat<MatTA>& A,
|
||||
const Mat<MatTB>& B, const float scale,
|
||||
const float* HWY_RESTRICT add, const Mat<float>& C,
|
||||
hwy::ThreadPool& pool) {
|
||||
HWY_NOINLINE void MatMul(const size_t batch_size, const Mat<const MatTA>& A,
|
||||
const Mat<const MatTB>& B, const float scale,
|
||||
const float* HWY_RESTRICT add, MatMulEnv& env,
|
||||
const Mat<float>& C) {
|
||||
// PROFILER_ZONE("Matmul");
|
||||
HWY_DASSERT(A.NotEmpty() && B.NotEmpty() && C.NotEmpty());
|
||||
HWY_DASSERT(A.cols == B.cols);
|
||||
|
||||
// Use float instead of MatTA/MatTB because we decompress to float here.
|
||||
const size_t N = hn::Lanes(hn::ScalableTag<float>());
|
||||
(void)N;
|
||||
HWY_DASSERT(A.cols % (N * 2) == 0); // For Decompress2.
|
||||
// Must be a multiple of two vectors because we Decompress2.
|
||||
HWY_DASSERT(A.cols % (2 * hn::Lanes(hn::ScalableTag<MulT>())) == 0);
|
||||
HWY_DASSERT(C.cols % kRegCols == 0);
|
||||
|
||||
// We currently write C directly, which touches more memory than fits in L3.
|
||||
|
|
@ -431,30 +461,32 @@ HWY_NOINLINE void MatMul_4x4(const size_t batch_size, const Mat<MatTA>& A,
|
|||
const size_t tilesY = hwy::DivCeil(batch_size, kRegRows);
|
||||
const size_t tilesX = C.cols / kRegCols;
|
||||
|
||||
pool.Run(0, tilesX * tilesY,
|
||||
[&](const uint64_t idx_tile, size_t /*thread*/) HWY_ATTR {
|
||||
const size_t tx = idx_tile % tilesX;
|
||||
const size_t ty = idx_tile / tilesX;
|
||||
const size_t row_ac = ty * kRegRows;
|
||||
const size_t row_b_col_c = tx * kRegCols;
|
||||
// How many rows of C are left to compute. If more than 4, this
|
||||
// tile still only computes 4 rows.
|
||||
const size_t num_rows = batch_size - row_ac;
|
||||
HWY_DASSERT(num_rows != 0);
|
||||
switch (num_rows) {
|
||||
case 1:
|
||||
MatMulTile<1, kAdd>(A, B, row_ac, row_b_col_c, scale, add, C);
|
||||
break;
|
||||
case 2:
|
||||
MatMulTile<2, kAdd>(A, B, row_ac, row_b_col_c, scale, add, C);
|
||||
break;
|
||||
case 3:
|
||||
MatMulTile<3, kAdd>(A, B, row_ac, row_b_col_c, scale, add, C);
|
||||
break;
|
||||
default:
|
||||
MatMulTile<4, kAdd>(A, B, row_ac, row_b_col_c, scale, add, C);
|
||||
}
|
||||
});
|
||||
env.Pool().Run(
|
||||
0, tilesX * tilesY, [&](const uint64_t idx_tile, size_t thread) HWY_ATTR {
|
||||
// TODO: when using PerClusterPool, compute lp from outer and inner.
|
||||
float* HWY_RESTRICT buf = env.Buf(thread);
|
||||
const size_t tx = idx_tile % tilesX;
|
||||
const size_t ty = idx_tile / tilesX;
|
||||
const size_t row_ac = ty * kRegRows;
|
||||
const size_t row_b_col_c = tx * kRegCols;
|
||||
// How many rows of C are left to compute. If more than 4, this
|
||||
// tile still only computes 4 rows.
|
||||
const size_t num_rows = batch_size - row_ac;
|
||||
HWY_DASSERT(num_rows != 0);
|
||||
switch (num_rows) {
|
||||
case 1:
|
||||
MatMulTile<1, kAdd>(A, B, row_ac, row_b_col_c, scale, add, buf, C);
|
||||
break;
|
||||
case 2:
|
||||
MatMulTile<2, kAdd>(A, B, row_ac, row_b_col_c, scale, add, buf, C);
|
||||
break;
|
||||
case 3:
|
||||
MatMulTile<3, kAdd>(A, B, row_ac, row_b_col_c, scale, add, buf, C);
|
||||
break;
|
||||
default:
|
||||
MatMulTile<4, kAdd>(A, B, row_ac, row_b_col_c, scale, add, buf, C);
|
||||
}
|
||||
});
|
||||
}
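The work decomposition above enumerates 4x4 tiles of C row-major and derives each tile's top-left coordinates from the flat index. The small standalone sketch below, with made-up sizes, illustrates the mapping and the `num_rows` clamp for the last row of tiles.

```
// Illustration of the flat tile index -> C tile mapping used by pool.Run.
#include <cstddef>
#include <cstdio>

int main() {
  const size_t kRegRows = 4, kRegCols = 4;
  const size_t batch_size = 6, C_cols = 32;  // example sizes
  const size_t tilesY = (batch_size + kRegRows - 1) / kRegRows;  // DivCeil
  const size_t tilesX = C_cols / kRegCols;
  for (size_t idx_tile = 0; idx_tile < tilesX * tilesY; ++idx_tile) {
    const size_t tx = idx_tile % tilesX;
    const size_t ty = idx_tile / tilesX;
    const size_t row_ac = ty * kRegRows;       // top row of the C tile
    const size_t row_b_col_c = tx * kRegCols;  // left column of the C tile
    const size_t num_rows = batch_size - row_ac;  // 1..4 rows actually written
    const size_t rows_here = num_rows < kRegRows ? num_rows : kRegRows;
    std::printf("tile %2zu -> C[%zu..%zu, %zu..%zu]\n", idx_tile, row_ac,
                row_ac + rows_here - 1, row_b_col_c, row_b_col_c + kRegCols - 1);
  }
  return 0;
}
```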
|
||||
|
||||
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,97 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef THIRD_PARTY_GEMMA_CPP_OPS_MATMUL_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_OPS_MATMUL_H_
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
#include "util/allocator.h" // RowVectorBatch
|
||||
#include "util/threading.h" // PerClusterPools
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
#include "hwy/per_target.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
// Bundles ptr/size/stride arguments to simplify MatMul call sites. T can be
|
||||
// const or non-const. Create via ConstMat/MutableMat.
|
||||
template <typename T>
|
||||
struct Mat {
|
||||
bool NotEmpty() const {
|
||||
return ptr != nullptr && cols != 0 && stride >= cols;
|
||||
}
|
||||
size_t Row(size_t r) const { return ofs + stride * r; }
|
||||
|
||||
T* HWY_RESTRICT ptr;
|
||||
size_t cols;
|
||||
|
||||
// elements between rows, which is typically the same as `cols`.
|
||||
size_t stride;
|
||||
|
||||
// Offset to add to `ptr`; separate because T=NuqStream does not support
|
||||
// pointer arithmetic.
|
||||
size_t ofs;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
Mat<T> MutableMat(T* HWY_RESTRICT ptr, size_t cols, size_t stride,
|
||||
size_t ofs = 0) {
|
||||
return Mat<T>{.ptr = ptr, .cols = cols, .stride = stride, .ofs = ofs};
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Mat<const T> ConstMat(const T* HWY_RESTRICT ptr, size_t cols, size_t stride,
|
||||
size_t ofs = 0) {
|
||||
return Mat<const T>{.ptr = ptr, .cols = cols, .stride = stride, .ofs = ofs};
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Mat<const T> ConstMat(Mat<T> mat) {
|
||||
return ConstMat(mat.ptr, mat.cols, mat.stride, mat.ofs);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Mat<T> MutableMat(T* HWY_RESTRICT ptr, size_t cols) {
|
||||
return MutableMat(ptr, cols, cols);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Mat<const T> ConstMat(const T* HWY_RESTRICT ptr, size_t cols) {
|
||||
return ConstMat(ptr, cols, cols);
|
||||
}
|
||||
|
||||
// Allocations and threads, shared across MatMul calls.
|
||||
class MatMulEnv {
|
||||
public:
|
||||
MatMulEnv() : pools_(nullptr) {}
|
||||
explicit MatMulEnv(PerClusterPools& pools) : pools_(&pools) {
|
||||
const size_t num_lp = pools.NumLP();
|
||||
const size_t NF = hwy::VectorBytes() / sizeof(float);
|
||||
buf_ = RowVectorBatch<float>(num_lp, 16 * NF);
|
||||
}
|
||||
|
||||
float* HWY_RESTRICT Buf(size_t lp) { return buf_.Batch(lp); }
|
||||
PerClusterPools& Pools() const { return *pools_; }
|
||||
hwy::ThreadPool& Pool() const { return pools_->Inner(0); }
|
||||
|
||||
private:
|
||||
RowVectorBatch<float> buf_;
|
||||
PerClusterPools* pools_;
|
||||
};
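A condensed call-site sketch follows, modeled on the TestMatMul changes further down. It is a fragment rather than a drop-in file: like matmul_test.cc, the enclosing translation unit must include `ops/matmul-inl.h` after `hwy/highway.h` within the per-target compilation, and the matrix sizes and contents here are placeholders.

```
// Sketch only: assumes the surrounding per-target boilerplate (foreach_target,
// HWY_NAMESPACE) that matmul_test.cc already provides, and f32 inputs.
void ExampleMatMul() {
  PerClusterPools pools(/*max_clusters=*/1, /*max_threads=*/4, /*pin=*/1);
  MatMulEnv env(pools);

  const size_t batch_size = 4, cols_a = 128, cols_c = 32;
  auto a = hwy::AllocateAligned<float>(batch_size * cols_a);
  auto b_trans = hwy::AllocateAligned<float>(cols_c * cols_a);  // B, transposed
  auto c = hwy::AllocateAligned<float>(batch_size * cols_c);
  // ... fill a and b_trans; cols_a must be a multiple of two MulT vectors ...

  // scale = 1.0f because these are not CompressedArray inputs; kAdd = false.
  MatMul</*kAdd=*/false>(batch_size, ConstMat(a.get(), cols_a),
                         ConstMat(b_trans.get(), cols_a), /*scale=*/1.0f,
                         /*add=*/nullptr, env, MutableMat(c.get(), cols_c));
}
```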
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_OPS_MATMUL_H_
|
||||
|
|
@ -24,6 +24,7 @@
|
|||
#include <memory>
|
||||
|
||||
#include "compression/compress.h"
|
||||
#include "util/threading.h"
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
|
@ -35,9 +36,10 @@
|
|||
// clang-format on
|
||||
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
||||
#include "hwy/highway.h"
|
||||
#include "hwy/tests/test_util-inl.h"
|
||||
// After highway.h
|
||||
#include "compression/compress-inl.h"
|
||||
#include "ops/matmul-inl.h"
|
||||
#include "hwy/tests/test_util-inl.h"
|
||||
|
||||
HWY_BEFORE_NAMESPACE();
|
||||
namespace gcpp {
|
||||
|
|
@ -149,7 +151,7 @@ void AssertClose(size_t rows_ac, size_t cols_ab, size_t cols_c_rows_b,
|
|||
const double norm = MaxColAbsSum(a.get(), rows_ac, cols_ab) *
|
||||
MaxColAbsSum(b_trans.get(), cols_c_rows_b, cols_ab);
|
||||
const double epsilon = hwy::ConvertScalarTo<double>(hwy::Epsilon<float>());
|
||||
const double tolerance = 50.0 * norm * epsilon;
|
||||
const double tolerance = 200.0 * norm * epsilon;
|
||||
|
||||
for (size_t idx = 0; idx < num_c; idx++) {
|
||||
const double expected_value = expected_c[idx];
|
||||
|
|
@ -157,8 +159,10 @@ void AssertClose(size_t rows_ac, size_t cols_ab, size_t cols_c_rows_b,
|
|||
|
||||
if (!(expected_value - tolerance <= actual_value &&
|
||||
actual_value <= expected_value + tolerance)) {
|
||||
fprintf(stderr, "expected[%lu]: %f, actual[%lu]: %f\n", idx,
|
||||
expected_value, idx, actual_value);
|
||||
fprintf(
|
||||
stderr,
|
||||
"expected[%lu]: %f, actual[%lu]: %f, norm %f eps %E tolerance %f\n",
|
||||
idx, expected_value, idx, actual_value, norm, epsilon, tolerance);
|
||||
HWY_ASSERT(0);
|
||||
}
|
||||
}
|
||||
|
|
@ -202,14 +206,15 @@ HWY_INLINE void MatMulSlow(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc,
|
|||
|
||||
void PrintSpeed(const char* algo, size_t rows_ac, size_t cols_a_rows_b,
|
||||
size_t cols_bc, double elapsed) {
|
||||
// 2 because of FMA.
|
||||
fprintf(stderr, "%s: %f seconds, %f GFLOPS.\n", algo, elapsed,
|
||||
2E-9 * rows_ac * cols_a_rows_b * cols_bc / elapsed);
|
||||
// 2x because of FMA.
|
||||
fprintf(stderr, " %10s: %f seconds, %.1f GFLOPS.\n", algo,
|
||||
elapsed, 2 * 1E-9 * rows_ac * cols_a_rows_b * cols_bc / elapsed);
|
||||
}
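To sanity-check the units: the largest test below multiplies a (64, 24576) matrix by a (24576, 3072) matrix, i.e. 2 * 64 * 24576 * 3072, roughly 9.66e9 floating-point operations per call, so an elapsed time of, say, 20 ms would be reported as about 483 GFLOPS.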
|
||||
|
||||
template <size_t kRowsAC, size_t kColsARowsB, size_t kColsBC, bool kAdd,
|
||||
typename MatTA, typename MatTB = MatTA>
|
||||
void TestMatMul(hwy::ThreadPool& pool) {
|
||||
void TestMatMul(MatMulEnv& env) {
|
||||
hwy::ThreadPool& pool = env.Pool();
|
||||
using TraitsA = CompressTraits<MatTA>;
|
||||
using TraitsB = CompressTraits<MatTB>;
|
||||
const bool want_bench = kColsBC > 2000; // avoid spam for small matrices
|
||||
|
|
@ -247,14 +252,14 @@ void TestMatMul(hwy::ThreadPool& pool) {
|
|||
double min_elapsed = hwy::HighestValue<double>();
|
||||
for (int rep = 0; rep < (want_bench ? 3 : 1); ++rep) {
|
||||
const double start_tiled = hwy::platform::Now();
|
||||
MatMul_4x4<kAdd>(kRowsAC, MakeMat(a->data(), kColsARowsB),
|
||||
MakeMat(b_trans->data(), kColsARowsB), scale,
|
||||
kAdd ? add->data_scale1() : nullptr,
|
||||
MakeMat(c.get(), kColsBC), pool);
|
||||
MatMul<kAdd>(kRowsAC, ConstMat(a->data(), kColsARowsB),
|
||||
ConstMat(b_trans->data(), kColsARowsB), scale,
|
||||
kAdd ? add->data_scale1() : nullptr, env,
|
||||
MutableMat(c.get(), kColsBC));
|
||||
min_elapsed = HWY_MIN(min_elapsed, hwy::platform::Now() - start_tiled);
|
||||
}
|
||||
if (want_bench) {
|
||||
PrintSpeed("MatMul_4x4", kRowsAC, kColsARowsB, kColsBC, min_elapsed);
|
||||
PrintSpeed("MatMul", kRowsAC, kColsARowsB, kColsBC, min_elapsed);
|
||||
}
|
||||
|
||||
AssertClose(kRowsAC, kColsARowsB, kColsBC, a->data(), b_trans->data(),
|
||||
|
|
@ -268,53 +273,56 @@ void TestAllMatMul() {
|
|||
return;
|
||||
}
|
||||
|
||||
hwy::ThreadPool pool(4);
|
||||
PerClusterPools pools(/*max_clusters=*/1, /*max_threads=*/4, /*pin=*/1);
|
||||
MatMulEnv env(pools);
|
||||
using F32 = float;
|
||||
using SFP = SfpStream;
|
||||
|
||||
// large-scale test
|
||||
TestMatMul<64, 24576, 3072, /*kAdd=*/false, F32, SFP>(pool);
|
||||
TestMatMul<64, 3072, 24576, /*kAdd=*/false, F32, SFP>(pool);
|
||||
TestMatMul<64, 24576, 3072, /*kAdd=*/false, BF16, SFP>(env);
|
||||
TestMatMul<64, 3072, 24576, /*kAdd=*/false, BF16, SFP>(env);
|
||||
TestMatMul<64, 24576, 3072, /*kAdd=*/false, F32, SFP>(env);
|
||||
TestMatMul<64, 3072, 24576, /*kAdd=*/false, F32, SFP>(env);
|
||||
|
||||
// medium-sized square test
|
||||
TestMatMul<512, 512, 512, /*kAdd=*/false, F32>(pool);
|
||||
TestMatMul<512, 512, 512, /*kAdd=*/true, BF16>(pool);
|
||||
TestMatMul<512, 512, 512, /*kAdd=*/false, F32, BF16>(pool);
|
||||
TestMatMul<512, 512, 512, /*kAdd=*/true, BF16, F32>(pool);
|
||||
TestMatMul<512, 512, 512, /*kAdd=*/false, F32, SFP>(pool);
|
||||
TestMatMul<512, 512, 512, /*kAdd=*/true, BF16, SFP>(pool);
|
||||
TestMatMul<512, 512, 512, /*kAdd=*/false, F32>(env);
|
||||
TestMatMul<512, 512, 512, /*kAdd=*/true, BF16>(env);
|
||||
TestMatMul<512, 512, 512, /*kAdd=*/false, F32, BF16>(env);
|
||||
TestMatMul<512, 512, 512, /*kAdd=*/true, BF16, F32>(env);
|
||||
TestMatMul<512, 512, 512, /*kAdd=*/false, F32, SFP>(env);
|
||||
TestMatMul<512, 512, 512, /*kAdd=*/true, BF16, SFP>(env);
|
||||
|
||||
// minimal non-square test. kColsARowsB must be at least 2 vectors.
|
||||
TestMatMul<35, 128, 32, /*kAdd=*/false, F32>(pool);
|
||||
TestMatMul<34, 128, 32, /*kAdd=*/true, BF16>(pool);
|
||||
TestMatMul<33, 128, 32, /*kAdd=*/false, F32, BF16>(pool);
|
||||
TestMatMul<33, 128, 32, /*kAdd=*/true, BF16, F32>(pool);
|
||||
TestMatMul<31, 128, 32, /*kAdd=*/false, F32, SFP>(pool);
|
||||
TestMatMul<29, 128, 32, /*kAdd=*/true, BF16, SFP>(pool);
|
||||
TestMatMul<4, 128, 32, /*kAdd=*/true, F32>(pool);
|
||||
TestMatMul<4, 128, 32, /*kAdd=*/false, BF16>(pool);
|
||||
TestMatMul<4, 128, 32, /*kAdd=*/true, F32, BF16>(pool);
|
||||
TestMatMul<4, 128, 32, /*kAdd=*/false, BF16, F32>(pool);
|
||||
TestMatMul<4, 128, 32, /*kAdd=*/true, F32, SFP>(pool);
|
||||
TestMatMul<4, 128, 32, /*kAdd=*/false, BF16, SFP>(pool);
|
||||
TestMatMul<3, 128, 32, /*kAdd=*/false, F32>(pool);
|
||||
TestMatMul<3, 128, 32, /*kAdd=*/true, BF16>(pool);
|
||||
TestMatMul<3, 128, 32, /*kAdd=*/false, F32, BF16>(pool);
|
||||
TestMatMul<3, 128, 32, /*kAdd=*/true, BF16, F32>(pool);
|
||||
TestMatMul<3, 128, 32, /*kAdd=*/false, F32, SFP>(pool);
|
||||
TestMatMul<3, 128, 32, /*kAdd=*/true, BF16, SFP>(pool);
|
||||
TestMatMul<2, 128, 64, /*kAdd=*/true, F32>(pool);
|
||||
TestMatMul<2, 128, 64, /*kAdd=*/false, BF16>(pool);
|
||||
TestMatMul<2, 128, 64, /*kAdd=*/true, F32, BF16>(pool);
|
||||
TestMatMul<2, 128, 64, /*kAdd=*/false, BF16, F32>(pool);
|
||||
TestMatMul<2, 128, 64, /*kAdd=*/true, F32, SFP>(pool);
|
||||
TestMatMul<2, 128, 64, /*kAdd=*/false, BF16, SFP>(pool);
|
||||
TestMatMul<1, 128, 32, /*kAdd=*/false, F32>(pool);
|
||||
TestMatMul<1, 128, 32, /*kAdd=*/true, BF16>(pool);
|
||||
TestMatMul<1, 128, 32, /*kAdd=*/false, F32, BF16>(pool);
|
||||
TestMatMul<1, 128, 32, /*kAdd=*/true, BF16, F32>(pool);
|
||||
TestMatMul<1, 128, 32, /*kAdd=*/false, F32, SFP>(pool);
|
||||
TestMatMul<1, 128, 32, /*kAdd=*/true, BF16, SFP>(pool);
|
||||
TestMatMul<35, 128, 32, /*kAdd=*/false, F32>(env);
|
||||
TestMatMul<34, 128, 32, /*kAdd=*/true, BF16>(env);
|
||||
TestMatMul<33, 128, 32, /*kAdd=*/false, F32, BF16>(env);
|
||||
TestMatMul<33, 128, 32, /*kAdd=*/true, BF16, F32>(env);
|
||||
TestMatMul<31, 128, 32, /*kAdd=*/false, F32, SFP>(env);
|
||||
TestMatMul<29, 128, 32, /*kAdd=*/true, BF16, SFP>(env);
|
||||
TestMatMul<4, 128, 32, /*kAdd=*/true, F32>(env);
|
||||
TestMatMul<4, 128, 32, /*kAdd=*/false, BF16>(env);
|
||||
TestMatMul<4, 128, 32, /*kAdd=*/true, F32, BF16>(env);
|
||||
TestMatMul<4, 128, 32, /*kAdd=*/false, BF16, F32>(env);
|
||||
TestMatMul<4, 128, 32, /*kAdd=*/true, F32, SFP>(env);
|
||||
TestMatMul<4, 128, 32, /*kAdd=*/false, BF16, SFP>(env);
|
||||
TestMatMul<3, 128, 32, /*kAdd=*/false, F32>(env);
|
||||
TestMatMul<3, 128, 32, /*kAdd=*/true, BF16>(env);
|
||||
TestMatMul<3, 128, 32, /*kAdd=*/false, F32, BF16>(env);
|
||||
TestMatMul<3, 128, 32, /*kAdd=*/true, BF16, F32>(env);
|
||||
TestMatMul<3, 128, 32, /*kAdd=*/false, F32, SFP>(env);
|
||||
TestMatMul<3, 128, 32, /*kAdd=*/true, BF16, SFP>(env);
|
||||
TestMatMul<2, 128, 64, /*kAdd=*/true, F32>(env);
|
||||
TestMatMul<2, 128, 64, /*kAdd=*/false, BF16>(env);
|
||||
TestMatMul<2, 128, 64, /*kAdd=*/true, F32, BF16>(env);
|
||||
TestMatMul<2, 128, 64, /*kAdd=*/false, BF16, F32>(env);
|
||||
TestMatMul<2, 128, 64, /*kAdd=*/true, F32, SFP>(env);
|
||||
TestMatMul<2, 128, 64, /*kAdd=*/false, BF16, SFP>(env);
|
||||
TestMatMul<1, 128, 32, /*kAdd=*/false, F32>(env);
|
||||
TestMatMul<1, 128, 32, /*kAdd=*/true, BF16>(env);
|
||||
TestMatMul<1, 128, 32, /*kAdd=*/false, F32, BF16>(env);
|
||||
TestMatMul<1, 128, 32, /*kAdd=*/true, BF16, F32>(env);
|
||||
TestMatMul<1, 128, 32, /*kAdd=*/false, F32, SFP>(env);
|
||||
TestMatMul<1, 128, 32, /*kAdd=*/true, BF16, SFP>(env);
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,75 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef THIRD_PARTY_GEMMA_CPP_UTIL_ALLOCATOR_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_UTIL_ALLOCATOR_H_
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/base.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
using ByteStorageT = hwy::AlignedFreeUniquePtr<uint8_t[]>;
|
||||
|
||||
template <typename T>
|
||||
ByteStorageT AllocateSizeof() {
|
||||
return hwy::AllocateAligned<uint8_t>(sizeof(T));
|
||||
}
|
||||
|
||||
// Owns dynamically-allocated aligned memory for a batch of row vectors.
|
||||
// This can be seen as a (batch_size x len) matrix.
|
||||
template <typename T>
|
||||
class RowVectorBatch {
|
||||
public:
|
||||
// Default ctor for Activations ctor.
|
||||
RowVectorBatch() : batch_size_(0), len_(0) {}
|
||||
// Main ctor, called from Activations::Allocate.
|
||||
RowVectorBatch(size_t batch_size, size_t len)
|
||||
: batch_size_(batch_size), len_(len) {
|
||||
mem_ = hwy::AllocateAligned<T>(batch_size * len);
|
||||
}
|
||||
|
||||
// Move-only
|
||||
RowVectorBatch(RowVectorBatch&) noexcept = delete;
|
||||
RowVectorBatch& operator=(RowVectorBatch&) noexcept = delete;
|
||||
RowVectorBatch(RowVectorBatch&&) noexcept = default;
|
||||
RowVectorBatch& operator=(RowVectorBatch&&) noexcept = default;
|
||||
|
||||
size_t BatchSize() const { return batch_size_; }
|
||||
size_t Len() const { return len_; }
|
||||
|
||||
// Returns the given row vector of length `Len()`.
|
||||
T* Batch(size_t batch_idx) {
|
||||
HWY_DASSERT(batch_idx < batch_size_);
|
||||
return mem_.get() + batch_idx * len_;
|
||||
}
|
||||
|
||||
// For MatMul or other operations that process the entire batch at once.
|
||||
T* All() { return mem_.get(); }
|
||||
const T* Const() const { return mem_.get(); }
|
||||
size_t NumBytes() const { return batch_size_ * len_ * sizeof(T); }
|
||||
|
||||
private:
|
||||
hwy::AlignedFreeUniquePtr<T[]> mem_;
|
||||
size_t batch_size_; // rows in the matrix
|
||||
size_t len_; // columns in the matrix = vector length
|
||||
};
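A minimal usage sketch of RowVectorBatch with made-up sizes; the accessors used are exactly those defined above, but the surrounding program is only an illustration.

```
// Allocate a 3 x 8 batch of row vectors and fill each row (illustration only).
#include <cstddef>
#include <cstdio>

#include "util/allocator.h"

int main() {
  gcpp::RowVectorBatch<float> batch(/*batch_size=*/3, /*len=*/8);
  for (size_t b = 0; b < batch.BatchSize(); ++b) {
    float* row = batch.Batch(b);  // row vector of length Len()
    for (size_t i = 0; i < batch.Len(); ++i) row[i] = static_cast<float>(b);
  }
  std::printf("%zu rows of %zu floats, %zu bytes total\n", batch.BatchSize(),
              batch.Len(), batch.NumBytes());
  return 0;
}
```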
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_UTIL_ALLOCATOR_H_
|
||||
|
|
@ -197,6 +197,11 @@ class PerClusterPools {
|
|||
return *inner_pools_[outer];
|
||||
}
|
||||
|
||||
// Returns number of logical processors, for allocating per-thread buffers.
|
||||
size_t NumLP() const {
|
||||
return outer_pool_.NumWorkers() * inner_pools_[0]->NumWorkers();
|
||||
}
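For example, with a single cluster whose inner pool has 4 workers and an outer pool of 1 worker, NumLP() returns 4, which matches the number of per-thread buffers MatMulEnv allocates (this assumes, as the code above does, that inner_pools_[0] is representative of all clusters).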
|
||||
|
||||
private:
|
||||
bool have_threading_support_;
|
||||
CoreBitSets cores_per_cluster_;
|
||||
|
|