Compare commits

...

17 Commits

Author SHA1 Message Date
Copybara-Service 18f6d43fcc Merge pull request #169 from xinpingwang:cmake-install
PiperOrigin-RevId: 630425203
2024-05-03 10:16:46 -07:00
Wang Xinping 2c038e1285 work with cmake install 2024-05-03 23:44:12 +08:00
Copybara-Service 8ed22e52bf Merge pull request #177 from szabadka:gemma2
PiperOrigin-RevId: 630388843
2024-05-03 07:52:27 -07:00
Zoltan Szabadka 19017fdb6d Fix expression in DASSERT() 2024-05-03 13:54:20 +00:00
Phil Culliton 28ca001d5e Matmul and test functions
PiperOrigin-RevId: 630373984
2024-05-03 06:39:36 -07:00
Zoltan Szabadka 429eb78512 Remove unused vars. 2024-05-03 13:37:17 +00:00
Zoltan Szabadka 3d72f17261 Use more parallelism in attention block in prefill mode.
Move the loop over the tokens inside the attention block and
then create kHeads * num_tokens threads.

This helps the multi-threaded speed only in case of the 2b gemma
model, but to be consistent we move the loop over the tokens inside
the griffin recurrent layer and the FFW layer as well. This is
also a preparation for using the MatMul operation later.

Benchmark results (summarization with 1600 tokens for prefill
and essay writing with 500 tokens for generation):

```
                   Prefill speed
Num threads      BEFORE       AFTER
32               61.76 t/s    65.08 t/s
64               89.46 t/s    98.62 t/s
```
2024-05-03 13:23:07 +00:00
Copybara-Service 6eeef2e2d9 Merge pull request #166 from samkaufman:deinterleave-vecs
PiperOrigin-RevId: 630360778
2024-05-03 05:23:31 -07:00
Sam Kaufman 4a6173d929 Remove unused vars. 2024-05-02 00:41:44 -07:00
Sam Kaufman 564937ede6 Merge branch 'dev' into deinterleave-vecs 2024-04-30 16:23:04 -07:00
Sam Kaufman 2829ef17ad Check for HWY_NATIVE_DOT_BF16. 2024-04-30 15:19:28 -07:00
Sam Kaufman 59ebecce22 Fix: specialized MatVecAdd was never called. 2024-04-30 15:17:27 -07:00
Sam Kaufman 6a78a23f4c Abstracted some MatVecAdd spec. dupes. 2024-04-29 16:23:38 -07:00
Sam Kaufman f608337fef Remove Bf16ToF32EO and use PromoteEvenTo and PromoteOddTo. 2024-04-29 14:13:07 -07:00
Sam Kaufman aa0b113214 (VecT*) to static_cast<VecT*>. 2024-04-29 12:53:47 -07:00
Sam Kaufman 5cb63346aa supports_eo -> kSupportsEvenOdd 2024-04-29 12:51:35 -07:00
Sam Kaufman 0816a1070d Even-odd layout MatVecs for bf16 weights. 2024-04-28 20:09:25 -07:00
5 changed files with 545 additions and 207 deletions

View File

@ -78,15 +78,17 @@ set_property(TARGET libgemma PROPERTY CXX_STANDARD 17)
set_target_properties(libgemma PROPERTIES PREFIX "")
set_property(TARGET libgemma PROPERTY POSITION_INDEPENDENT_CODE ON)
target_include_directories(libgemma PUBLIC ./)
target_link_libraries(libgemma hwy hwy_contrib sentencepiece)
target_link_libraries(libgemma hwy hwy_contrib sentencepiece-static)
target_include_directories(libgemma PUBLIC ${sentencepiece_SOURCE_DIR})
target_compile_definitions(libgemma PRIVATE $<$<PLATFORM_ID:Windows>:_CRT_SECURE_NO_WARNINGS NOMINMAX>)
target_compile_options(libgemma PRIVATE $<$<PLATFORM_ID:Windows>:-Wno-deprecated-declarations>)
install(TARGETS libgemma DESTINATION lib)
# Executable Target
add_executable(gemma gemma/run.cc)
target_link_libraries(gemma libgemma hwy hwy_contrib)
install(TARGETS gemma DESTINATION bin)
add_executable(benchmark gemma/benchmark.cc)
target_link_libraries(benchmark libgemma hwy hwy_contrib nlohmann_json::nlohmann_json)

View File

@ -58,6 +58,7 @@ struct CompressTraits {};
template <>
struct CompressTraits<float> {
using MatT = float;
static constexpr bool kSupportsEvenOdd = false;
template <class DF, HWY_IF_F32_D(DF)>
static HWY_INLINE void Compress(DF df, const float* HWY_RESTRICT in,
@ -111,6 +112,7 @@ struct CompressTraits<float> {
template <>
struct CompressTraits<hwy::bfloat16_t> {
using MatT = hwy::bfloat16_t;
static constexpr bool kSupportsEvenOdd = true;
template <class DF, HWY_IF_F32_D(DF)>
static HWY_INLINE void Compress(DF df, const float* HWY_RESTRICT in,
@ -219,11 +221,59 @@ struct CompressTraits<hwy::bfloat16_t> {
// bf16*bf16.
return hn::Dot::Compute<kAssumptions>(d_vec, vec_aligned, in + in_ofs, num);
}
// Computes the dot product of an even-odd deinterleaved, f32 `vec_aligned`
// and a column- major matrix `in`. `vec_aligned` should be aligned and
// alternate even-indexed `hn::Lanes(df32)` elements followed by odd-indexed
// `hn::Lanes(df32)` elements.
template <class DF, HWY_IF_F32_D(DF)>
static HWY_INLINE float DotEO(
const DF df32, const hwy::bfloat16_t* HWY_RESTRICT in, size_t in_ofs,
const float* HWY_RESTRICT vec_aligned, size_t num) {
HWY_DASSERT(num >= (hn::Lanes(df32) * 2) && (num % (hn::Lanes(df32) * 2)) == 0);
HWY_DASSERT((in_ofs % (hn::Lanes(df32) * 2)) == 0);
HWY_DASSERT(hn::IsAligned(df32, vec_aligned));
const hn::Repartition<hwy::bfloat16_t, DF> dbf16;
using VF32 = decltype(Zero(df32));
const size_t N = Lanes(dbf16);
VF32 sum0 = Zero(df32);
VF32 sum1 = Zero(df32);
VF32 sum2 = Zero(df32);
VF32 sum3 = Zero(df32);
const hn::RebindToUnsigned<decltype(df32)> du32;
using VU32 = hn::VFromD<decltype(du32)>;
const VU32 odd = Set(du32, 0xFFFF0000u);
for (size_t i = 0; i < num; /* i += 2 * N */) {
const auto interleaved0 = hn::LoadU(dbf16, in + in_ofs + i);
const VF32 ae0 = Load(df32, vec_aligned + i);
const VF32 ao0 = Load(df32, vec_aligned + i + (N / 2));
sum0 = hn::MulAdd(ae0, hn::PromoteEvenTo(df32, interleaved0), sum0);
sum1 = hn::MulAdd(ao0, hn::PromoteOddTo(df32, interleaved0), sum1);
i += N;
const auto interleaved1 = hn::LoadU(dbf16, in + in_ofs + i);
const VF32 ae1 = Load(df32, vec_aligned + i);
const VF32 ao1 = Load(df32, vec_aligned + i + (N / 2));
sum2 = hn::MulAdd(ae1, hn::PromoteEvenTo(df32, interleaved1), sum2);
sum3 = hn::MulAdd(ao1, hn::PromoteOddTo(df32, interleaved1), sum3);
i += N;
}
sum0 = Add(sum0, sum1);
sum2 = Add(sum2, sum3);
sum0 = Add(sum0, sum2);
return ReduceSum(df32, sum0);
}
};
template <>
struct CompressTraits<SfpStream> {
using MatT = SfpStream;
static constexpr bool kSupportsEvenOdd = false;
template <class DF, HWY_IF_F32_D(DF)>
static HWY_INLINE void Compress(DF df, const float* in, size_t num,
@ -273,6 +323,7 @@ struct CompressTraits<SfpStream> {
template <>
struct CompressTraits<NuqStream> {
using MatT = NuqStream;
static constexpr bool kSupportsEvenOdd = false;
template <class DF, HWY_IF_F32_D(DF)>
static HWY_INLINE void Compress(DF df, const float* in, size_t num,
@ -425,16 +476,22 @@ HWY_INLINE float Dot(DF df, const ArrayT& compressed, size_t compressed_ofs,
}
// Returns dot product with `vec_aligned` of length `num`.
template <class DF, typename MatT, size_t kCapacity, typename VecT>
template <bool kVecEO, class DF, typename MatT, size_t kCapacity, typename VecT>
HWY_INLINE float Dot(DF df, const CompressedArray<MatT, kCapacity>& compressed,
size_t compressed_ofs, const VecT* vec_aligned,
size_t num) {
HWY_DASSERT(compressed_ofs + num <= compressed.size());
HWY_DASSERT(hn::IsAligned(df, vec_aligned));
using Traits = CompressTraits<MatT>;
return (compressed.scale() * Traits::Dot(df, compressed.size(),
compressed.data(), compressed_ofs,
vec_aligned, num));
float dot_result;
if constexpr (kVecEO) {
dot_result = Traits::DotEO(df, compressed.data(), compressed_ofs,
vec_aligned, num);
} else {
dot_result = Traits::Dot(df, compressed.size(), compressed.data(),
compressed_ofs, vec_aligned, num);
}
return compressed.scale() * dot_result;
}
// Callback used by ForeachTensor.

View File

@ -460,7 +460,7 @@ KVCache CreateKVCacheT() {
constexpr size_t kConv1dWidth = Config::kConv1dWidth;
return CreateKVCache(
Config::kGemmaLayers * Config::kKVHeads * Config::kQKVDim,
Config::kSeqLen,
Config::kSeqLen + kPrefillBatchSize,
Config::kGriffinLayers * (kConv1dWidth == 0 ? 0 : kConv1dWidth - 1) *
Config::kModelDim,
Config::kGriffinLayers * Config::kModelDim);
@ -569,34 +569,38 @@ namespace HWY_NAMESPACE {
template <size_t kBatchSize, typename LayerT, class TConfig>
HWY_NOINLINE void GriffinRecurrent(
size_t batch_start, size_t batch_idx, size_t layer,
size_t batch_start, size_t num_tokens, size_t layer,
Activations<TConfig, kBatchSize>& activations, const LayerT* layer_weights,
KVCache& kv_cache, hwy::ThreadPool& pool) {
PROFILER_ZONE("Gen.Griffin");
namespace hn = hwy::HWY_NAMESPACE;
using D = hn::ScalableTag<float>;
HWY_DASSERT(batch_idx < kBatchSize);
HWY_DASSERT(num_tokens <= kBatchSize);
static constexpr size_t kModelDim =
gcpp::Activations<TConfig, kBatchSize>::kModelDim;
static constexpr size_t kConv1dWidth = TConfig::kConv1dWidth;
static constexpr size_t kHeads = TConfig::kHeads;
static constexpr bool kAdd = true;
const size_t batch_offset = batch_idx * kModelDim;
const size_t pos = batch_start + batch_idx;
// X / Y linear layers.
float* HWY_RESTRICT y = activations.griffin_y.data() + batch_offset;
float* HWY_RESTRICT x = activations.griffin_x.data() + batch_offset;
TwoMatVecAdd<kAdd, kModelDim, kModelDim>(
layer_weights->griffin.linear_x_w, layer_weights->griffin.linear_y_w, 0,
activations.pre_att_rms_out.data() + batch_offset,
/*add0=*/layer_weights->griffin.linear_x_biases.data(),
/*add1=*/layer_weights->griffin.linear_y_biases.data(), /*out0=*/x,
/*out1=*/y, pool);
Gelu(y, kModelDim);
for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
const size_t batch_offset = batch_idx * kModelDim;
float* HWY_RESTRICT y = activations.griffin_y.data() + batch_offset;
float* HWY_RESTRICT x = activations.griffin_x.data() + batch_offset;
TwoMatVecAdd<kAdd, kModelDim, kModelDim>(
layer_weights->griffin.linear_x_w, layer_weights->griffin.linear_y_w, 0,
activations.pre_att_rms_out.data() + batch_offset,
/*add0=*/layer_weights->griffin.linear_x_biases.data(),
/*add1=*/layer_weights->griffin.linear_y_biases.data(), /*out0=*/x,
/*out1=*/y, pool);
Gelu(y, kModelDim);
}
// Conv1D.
{
for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
const size_t batch_offset = batch_idx * kModelDim;
const size_t pos = batch_start + batch_idx;
float* HWY_RESTRICT x = activations.griffin_x.data() + batch_offset;
HWY_FULL(float) df;
HWY_DASSERT(kModelDim % Lanes(df) == 0);
const size_t layer_offset = layer * kModelDim * (kConv1dWidth - 1);
@ -611,14 +615,15 @@ HWY_NOINLINE void GriffinRecurrent(
}
for (size_t i = 0; i < kModelDim; i += Lanes(df)) {
auto xv = hn::Load(df, x + i);
auto accum0 = hn::Load(df, layer_weights->griffin.conv_biases.data() + i);
auto accum0 =
hn::Load(df, layer_weights->griffin.conv_biases.data() + i);
auto accum1 = hn::Zero(df);
static_assert(kConv1dWidth % 2 == 0, "Conv width must be even");
for (size_t l = 0; 2 * l < kConv1dWidth; l++) {
auto wv0 = hn::Load(df, layer_weights->griffin.conv_w.data() +
(kConv1dWidth - 1 - 2 * l) * kModelDim + i);
(kConv1dWidth - 1 - 2 * l) * kModelDim + i);
auto wv1 = hn::Load(df, layer_weights->griffin.conv_w.data() +
(kConv1dWidth - 2 - 2 * l) * kModelDim + i);
(kConv1dWidth - 2 - 2 * l) * kModelDim + i);
accum0 = hn::MulAdd(wv0, hn::Load(df, cache[l * 2] + i), accum0);
accum1 = hn::MulAdd(wv1, hn::Load(df, cache[l * 2 + 1] + i), accum1);
}
@ -628,68 +633,79 @@ HWY_NOINLINE void GriffinRecurrent(
}
// RGLRU
float* HWY_RESTRICT gate_x = activations.griffin_gate_x.data() + batch_offset;
float* HWY_RESTRICT a = activations.griffin_multiplier.data() + batch_offset;
float* HWY_RESTRICT rnn_state =
kv_cache.rglru_cache.get() + layer * kModelDim;
for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
const size_t batch_offset = batch_idx * kModelDim;
const size_t pos = batch_start + batch_idx;
float* HWY_RESTRICT y = activations.griffin_y.data() + batch_offset;
float* HWY_RESTRICT x = activations.griffin_x.data() + batch_offset;
float* HWY_RESTRICT gate_x =
activations.griffin_gate_x.data() + batch_offset;
float* HWY_RESTRICT a =
activations.griffin_multiplier.data() + batch_offset;
float* HWY_RESTRICT rnn_state =
kv_cache.rglru_cache.get() + layer * kModelDim;
pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
constexpr size_t kHeadDim = kModelDim / kHeads;
constexpr size_t kMatrixSize = kHeadDim * kHeadDim;
size_t head_offset = head * kHeadDim;
TwoOfsMatVecAddLoop<kAdd, kHeadDim, kHeadDim>(
layer_weights->griffin.gate_w, kMatrixSize * head,
kMatrixSize * (kHeads + head), x + head_offset,
/*add0=*/layer_weights->griffin.gate_biases.data() + head_offset,
/*add1=*/layer_weights->griffin.gate_biases.data() + kModelDim +
head_offset,
/*out0=*/gate_x + head_offset, /*out1=*/a + head_offset);
Sigmoid(gate_x + head_offset, kHeadDim);
Sigmoid(a + head_offset, kHeadDim);
const auto fn_mul = [](D d, hn::Vec<D> x, hn::Vec<D> gate_x)
HWY_ATTR { return hn::Mul(x, gate_x); };
hn::Transform1(D(), a + head_offset, kHeadDim,
layer_weights->griffin.a.data() + head_offset, fn_mul);
hn::Transform1(D(), x + head_offset, kHeadDim, gate_x + head_offset,
fn_mul);
// RNN scan
HWY_FULL(float) df;
HWY_DASSERT(kHeadDim % Lanes(df) == 0);
for (size_t i = 0; i < kHeadDim; i += Lanes(df)) {
auto log_a = hn::Load(df, a + head_offset + i);
auto gated_x = hn::Load(df, x + head_offset + i);
auto rnn = hn::Load(df, rnn_state + head_offset + i);
auto a = hn::Exp(df, log_a);
auto x_multiplier = hn::Sqrt(hn::NegMulAdd(a, a, hn::Set(df, 1.0)));
if (pos == 0) {
x_multiplier = hn::Set(df, 1.0);
pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
constexpr size_t kHeadDim = kModelDim / kHeads;
constexpr size_t kMatrixSize = kHeadDim * kHeadDim;
size_t head_offset = head * kHeadDim;
TwoOfsMatVecAddLoop<kAdd, kHeadDim, kHeadDim>(
layer_weights->griffin.gate_w, kMatrixSize * head,
kMatrixSize * (kHeads + head), x + head_offset,
/*add0=*/layer_weights->griffin.gate_biases.data() + head_offset,
/*add1=*/layer_weights->griffin.gate_biases.data() + kModelDim +
head_offset,
/*out0=*/gate_x + head_offset, /*out1=*/a + head_offset);
Sigmoid(gate_x + head_offset, kHeadDim);
Sigmoid(a + head_offset, kHeadDim);
const auto fn_mul = [](D d, hn::Vec<D> x, hn::Vec<D> gate_x)
HWY_ATTR { return hn::Mul(x, gate_x); };
hn::Transform1(D(), a + head_offset, kHeadDim,
layer_weights->griffin.a.data() + head_offset, fn_mul);
hn::Transform1(D(), x + head_offset, kHeadDim, gate_x + head_offset,
fn_mul);
// RNN scan
HWY_FULL(float) df;
HWY_DASSERT(kHeadDim % Lanes(df) == 0);
for (size_t i = 0; i < kHeadDim; i += Lanes(df)) {
auto log_a = hn::Load(df, a + head_offset + i);
auto gated_x = hn::Load(df, x + head_offset + i);
auto rnn = hn::Load(df, rnn_state + head_offset + i);
auto a = hn::Exp(df, log_a);
auto x_multiplier = hn::Sqrt(hn::NegMulAdd(a, a, hn::Set(df, 1.0)));
if (pos == 0) {
x_multiplier = hn::Set(df, 1.0);
}
auto new_x = hn::MulAdd(x_multiplier, gated_x, hn::Mul(a, rnn));
hn::Store(new_x, df, rnn_state + head_offset + i);
// Join branches.
auto yv = hn::Load(df, y + head_offset + i);
auto pre_out = hn::Mul(yv, new_x);
hn::Store(pre_out, df, x + head_offset + i);
}
auto new_x = hn::MulAdd(x_multiplier, gated_x, hn::Mul(a, rnn));
hn::Store(new_x, df, rnn_state + head_offset + i);
// Join branches.
auto yv = hn::Load(df, y + head_offset + i);
auto pre_out = hn::Mul(yv, new_x);
hn::Store(pre_out, df, x + head_offset + i);
}
});
});
}
// Final linear layer.
float* out_ptr = activations.att_post2.data() + batch_idx * kModelDim;
MatVecAdd<kAdd, kModelDim, kModelDim>(
layer_weights->griffin.linear_out_w, 0, x,
layer_weights->griffin.linear_out_biases.data(),
activations.even_odd.data(), out_ptr, pool);
for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
const size_t batch_offset = batch_idx * kModelDim;
float* HWY_RESTRICT x = activations.griffin_x.data() + batch_offset;
float* out_ptr = activations.att_post2.data() + batch_idx * kModelDim;
MatVecAdd<kAdd, kModelDim, kModelDim>(
layer_weights->griffin.linear_out_w, 0, x,
layer_weights->griffin.linear_out_biases.data(),
activations.even_odd.data(), out_ptr, pool);
}
}
template <size_t kBatchSize, typename LayerT, class TConfig>
HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
HWY_NOINLINE void Attention(size_t batch_start, size_t num_tokens, size_t layer,
Activations<TConfig, kBatchSize>& activations,
const LayerT* layer_weights, KVCache& kv_cache,
hwy::ThreadPool& pool) {
PROFILER_ZONE("Gen.Attention");
const size_t pos = batch_start + batch_idx;
HWY_DASSERT(batch_idx < kBatchSize);
HWY_DASSERT(num_tokens <= kBatchSize);
static constexpr size_t kQKVDim = gcpp::Activations<TConfig, 1>::kQKVDim;
static constexpr size_t kCachePosSize =
gcpp::Activations<TConfig, kBatchSize>::kCachePosSize;
@ -699,47 +715,43 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
gcpp::Activations<TConfig, kBatchSize>::kModelDim;
static constexpr size_t kHeads = TConfig::kHeads;
static constexpr size_t kKVHeads = TConfig::kKVHeads;
static constexpr size_t kSeqLen = TConfig::kSeqLen;
static const float kQueryScale =
static_cast<float>(1.0 / sqrt(static_cast<double>(kQKVDim)));
size_t cache_pos = pos;
size_t cache_num = pos + 1;
if constexpr (TConfig::kUseLocalAttention) {
cache_pos %= TConfig::kSeqLen;
cache_num = std::min(cache_num, static_cast<size_t>(TConfig::kSeqLen));
}
float* x = activations.pre_att_rms_out.data() + batch_idx * kModelDim;
auto Attn = [&](float* q, uint64_t head, size_t head_offset,
auto Attn = [&](float* q, uint64_t head, size_t head_offset, size_t batch_idx,
size_t thread) HWY_ATTR {
const size_t pos = batch_start + batch_idx;
// Calculate scores
float* HWY_RESTRICT head_att = activations.att.data() +
head * TConfig::kSeqLen +
batch_idx * kHeads * kQKVDim;
head * kSeqLen +
batch_idx * kHeads * kSeqLen;
Rope(q, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
MulByConst(kQueryScale, q, kQKVDim);
// Compute Q dot K scores
for (size_t pos2 = 0; pos2 < cache_num; ++pos2) {
const size_t cache_offset =
pos2 * kCachePosSize + layer * kCacheLayerSize + head_offset;
const float* HWY_RESTRICT k2 = kv_cache.kv_cache.get() + cache_offset;
const size_t start_pos = pos - std::min(kSeqLen - 1, pos);
for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
const size_t cache_pos = pos2 % (kSeqLen + kPrefillBatchSize);
const size_t kv_offset = cache_pos * kCachePosSize +
layer * kCacheLayerSize + head_offset;
const float* HWY_RESTRICT k2 = kv_cache.kv_cache.get() + kv_offset;
const float score = Dot(q, k2, kQKVDim);
head_att[pos2] = score;
head_att[pos2 % kSeqLen] = score;
}
Softmax(head_att, cache_num);
Softmax(head_att, std::min(pos + 1, kSeqLen));
// Weighted summation
float* HWY_RESTRICT att_out = activations.att_out.data() + head * kQKVDim +
batch_idx * kHeads * kQKVDim;
hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out));
for (size_t pos2 = 0; pos2 < cache_num; ++pos2) {
const size_t cache_offset =
pos2 * kCachePosSize + layer * kCacheLayerSize + head_offset;
float* HWY_RESTRICT v2 = kv_cache.kv_cache.get() + cache_offset + kQKVDim;
MulByConstAndAdd(head_att[pos2], v2, att_out, kQKVDim);
for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
const size_t cache_pos = pos2 % (kSeqLen + kPrefillBatchSize);
const size_t kv_offset = cache_pos * kCachePosSize +
layer * kCacheLayerSize + head_offset;
float* HWY_RESTRICT v2 = kv_cache.kv_cache.get() + kv_offset + kQKVDim;
MulByConstAndAdd(head_att[pos2 % kSeqLen], v2, att_out, kQKVDim);
}
};
@ -747,74 +759,99 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
// Multi-Head Attention
static_assert(TConfig::kInterleaveQKV);
float* HWY_RESTRICT qkv =
activations.q.data() + batch_idx * kHeads * kQKVDim * 3;
MatVec<kHeads * kQKVDim * 3, kModelDim>(
layer_weights->qkv_einsum_w, 0, x, activations.even_odd.data(), qkv,
pool);
pool.Run(0, kHeads, [&](const uint64_t head, size_t thread) HWY_ATTR {
float* HWY_RESTRICT q = qkv + head * kQKVDim * 3;
for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
float* x = activations.pre_att_rms_out.data() + batch_idx * kModelDim;
float* HWY_RESTRICT qkv =
activations.q.data() + batch_idx * kHeads * kQKVDim * 3;
MatVec<kHeads * kQKVDim * 3, kModelDim>(
layer_weights->qkv_einsum_w, 0, x, activations.even_odd.data(), qkv,
pool);
}
const size_t num_tasks = kHeads * num_tokens;
pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
const size_t head = task % kHeads;
const size_t batch_idx = task / kHeads;
const size_t pos = batch_start + batch_idx;
float* HWY_RESTRICT q =
activations.q.data() + (batch_idx * kHeads + head) * kQKVDim * 3;
const size_t cache_pos = pos % (kSeqLen + kPrefillBatchSize);
const size_t kv_offset = cache_pos * kCachePosSize +
layer * kCacheLayerSize + head * kQKVDim * 2;
float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
memcpy(kv, q + kQKVDim, 2 * kQKVDim * sizeof(float));
Rope(kv, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
Attn(q, head, head * kQKVDim * 2, thread);
});
pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
const size_t head = task % kHeads;
const size_t batch_idx = task / kHeads;
float* HWY_RESTRICT q =
activations.q.data() + (batch_idx * kHeads + head) * kQKVDim * 3;
Attn(q, head, head * kQKVDim * 2, batch_idx, thread);
});
} else {
// Multi-Query Attention
float* HWY_RESTRICT q = activations.q.data() + batch_idx * kHeads * kQKVDim;
MatVec<kHeads * kQKVDim, kModelDim>(layer_weights->qkv_einsum_w, 0, x,
activations.even_odd.data(), q, pool);
for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
const size_t pos = batch_start + batch_idx;
float* x = activations.pre_att_rms_out.data() + batch_idx * kModelDim;
float* HWY_RESTRICT kv = kv_cache.kv_cache.get() +
cache_pos * kCachePosSize +
layer * kCacheLayerSize;
MatVec<kQKVDim * 2, kModelDim>(layer_weights->qkv_einsum_w,
kHeads * kQKVDim * kModelDim, x,
activations.even_odd.data(), kv, pool);
float* HWY_RESTRICT q =
activations.q.data() + batch_idx * kHeads * kQKVDim;
MatVec<kHeads * kQKVDim, kModelDim>(layer_weights->qkv_einsum_w, 0, x,
activations.even_odd.data(), q, pool);
Rope(kv, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
pool.Run(0, kHeads, [&](const uint64_t head, size_t thread) HWY_ATTR {
Attn(q + head * kQKVDim, head, 0, thread);
const size_t cache_pos = pos % (kSeqLen + kPrefillBatchSize);
const size_t kv_offset = cache_pos * kCachePosSize +
layer * kCacheLayerSize;
float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
MatVec<kQKVDim * 2, kModelDim>(layer_weights->qkv_einsum_w,
kHeads * kQKVDim * kModelDim, x,
activations.even_odd.data(), kv, pool);
Rope(kv, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
}
const size_t num_tasks = kHeads * num_tokens;
pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
const size_t head = task % kHeads;
const size_t batch_idx = task / kHeads;
float* HWY_RESTRICT q =
activations.q.data() + batch_idx * kHeads * kQKVDim;
Attn(q + head * kQKVDim, head, 0, batch_idx, thread);
});
}
// TODO(szabadka) Use a single MatVecAdd like in GriffinRecurrent() after
// rearranging the weights.
float* HWY_RESTRICT att_out =
activations.att_out.data() + batch_idx * kHeads * kQKVDim;
float* HWY_RESTRICT layer_out =
activations.att_post2.data() + batch_idx * kModelDim;
MatVecAdd<TConfig::kSoftmaxAttnOutputBiases, kModelDim, kQKVDim>(
layer_weights->attn_vec_einsum_w, 0, att_out,
layer_weights->attention_output_biases.data(),
activations.even_odd.data(), layer_out, pool);
for (size_t head = 1; head < kHeads; ++head) {
float* HWY_RESTRICT head_out =
activations.att_post1.data() + head * kBatchSize * kModelDim;
MatVec<kModelDim, kQKVDim>(
layer_weights->attn_vec_einsum_w, head * kModelDim * kQKVDim,
att_out + head * kQKVDim,
activations.even_odd.data(), head_out, pool);
AddFrom(head_out, layer_out, kModelDim);
for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
// TODO(szabadka) Use a single MatVecAdd like in GriffinRecurrent() after
// rearranging the weights.
float* HWY_RESTRICT att_out =
activations.att_out.data() + batch_idx * kHeads * kQKVDim;
float* HWY_RESTRICT layer_out =
activations.att_post2.data() + batch_idx * kModelDim;
MatVecAdd<TConfig::kSoftmaxAttnOutputBiases, kModelDim, kQKVDim>(
layer_weights->attn_vec_einsum_w, 0, att_out,
layer_weights->attention_output_biases.data(),
activations.even_odd.data(), layer_out, pool);
for (size_t head = 1; head < kHeads; ++head) {
float* HWY_RESTRICT head_out =
activations.att_post1.data() + head * kBatchSize * kModelDim;
MatVec<kModelDim, kQKVDim>(
layer_weights->attn_vec_einsum_w, head * kModelDim * kQKVDim,
att_out + head * kQKVDim,
activations.even_odd.data(), head_out, pool);
AddFrom(head_out, layer_out, kModelDim);
}
}
}
template <size_t kBatchSize, typename LayerT, typename TConfig>
HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
size_t batch_idx, const LayerT* layer_weights,
size_t num_tokens, const LayerT* layer_weights,
hwy::ThreadPool& pool) {
HWY_DASSERT(batch_idx < kBatchSize);
HWY_DASSERT(num_tokens <= kBatchSize);
static constexpr size_t kModelDim = TConfig::kModelDim;
static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
const size_t hidden_offset = batch_idx * kFFHiddenDim * 2;
float* HWY_RESTRICT even_odd = activations.even_odd.data();
{
for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
const size_t hidden_offset = batch_idx * kFFHiddenDim * 2;
PROFILER_ZONE("Gen.FFW.GatedGELU");
const hwy::bfloat16_t* HWY_RESTRICT vec =
activations.bf_pre_ffw_rms_out.data() + batch_idx * kModelDim;
@ -839,11 +876,15 @@ HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
HWY_ATTR { return hn::Mul(mul, Gelu(df, v)); });
}
PROFILER_ZONE("Gen.FFW\\GatedGELU");
MatVecAdd<TConfig::kFFBiases, kModelDim, kFFHiddenDim>(
layer_weights->linear_w, 0, activations.ffw_hidden.data() + hidden_offset,
layer_weights->ffw_output_biases.data(), even_odd,
activations.ffw_out.data() + batch_idx * kModelDim, pool);
for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
PROFILER_ZONE("Gen.FFW\\GatedGELU");
const size_t hidden_offset = batch_idx * kFFHiddenDim * 2;
MatVecAdd<TConfig::kFFBiases, kModelDim, kFFHiddenDim>(
layer_weights->linear_w, 0,
activations.ffw_hidden.data() + hidden_offset,
layer_weights->ffw_output_biases.data(), even_odd,
activations.ffw_out.data() + batch_idx * kModelDim, pool);
}
}
// `EmbeddingScaling` can be constexpr only if `Sqrt` and `hwy::ConvertScalarTo`
@ -898,24 +939,26 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
layer_weights->pre_attention_norm_scale.data(),
activations.pre_att_rms_out.data() + token_idx * kModelDim,
kModelDim);
if (type == LayerAttentionType::kGemma) {
Attention<kBatchSize>(pos, token_idx, layer_of_type, activations,
layer_weights, kv_cache, pool);
} else {
GriffinRecurrent<kBatchSize>(pos, token_idx, layer_of_type, activations,
layer_weights, kv_cache, pool);
}
}
if (type == LayerAttentionType::kGemma) {
Attention<kBatchSize>(pos, num_tokens, layer_of_type, activations,
layer_weights, kv_cache, pool);
} else {
GriffinRecurrent<kBatchSize>(pos, num_tokens, layer_of_type, activations,
layer_weights, kv_cache, pool);
}
// TODO: sink the loop into these functions, i.e. make them MatMul.
for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
pool.Run(0, num_tokens, [&](const uint64_t token_idx,
size_t /*thread*/) HWY_ATTR {
AddFrom(activations.att_post2.data() + token_idx * kModelDim,
activations.x.data() + token_idx * kModelDim, kModelDim);
RMSNorm(activations.x.data() + token_idx * kModelDim,
layer_weights->pre_ffw_norm_scale.data(),
activations.bf_pre_ffw_rms_out.data() + token_idx * kModelDim,
kModelDim);
FFW<kBatchSize>(activations, token_idx, layer_weights, pool);
});
FFW<kBatchSize>(activations, num_tokens, layer_weights, pool);
for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
AddFrom(activations.ffw_out.data() + token_idx * kModelDim,
activations.x.data() + token_idx * kModelDim, kModelDim);
}
@ -957,16 +1000,16 @@ void Transformer(int token, size_t pos, const WeightArrayT& weights,
layer_weights->pre_attention_norm_scale.data(),
activations.pre_att_rms_out.data(), kModelDim);
if (type == LayerAttentionType::kGemma) {
Attention<1>(pos, 0, layer_of_type, activations, layer_weights, kv_cache,
Attention<1>(pos, 1, layer_of_type, activations, layer_weights, kv_cache,
pool);
} else {
GriffinRecurrent<1>(pos, 0, layer_of_type, activations, layer_weights,
GriffinRecurrent<1>(pos, 1, layer_of_type, activations, layer_weights,
kv_cache, pool);
}
AddFrom(activations.att_post2.data(), activations.x.data(), kModelDim);
RMSNorm(activations.x.data(), layer_weights->pre_ffw_norm_scale.data(),
activations.bf_pre_ffw_rms_out.data(), kModelDim);
FFW<1>(activations, /* batch_idx = */ 0, layer_weights, pool);
FFW<1>(activations, /* num_tokens = */ 1, layer_weights, pool);
AddFrom(activations.ffw_out.data(), activations.x.data(), kModelDim);
if (layers_output != nullptr) {
std::string block_name = "blocks." + std::to_string(layer);

View File

@ -25,6 +25,7 @@
#include <random>
#include <type_traits> // std::enable_if_t
#include "compression/compress.h" // CompressedArray
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/profiler.h"
@ -92,6 +93,111 @@ HWY_INLINE constexpr size_t RowsPerStrip() {
return kRowsPerStrip;
}
// Largely unoptimized; reordered innermost loops nets ~5-10X speedup on
// ops_test across instruction sets.
template <size_t kM, size_t kN, size_t kK>
HWY_INLINE void MatMul(const float* HWY_RESTRICT a, const float* HWY_RESTRICT b,
float* HWY_RESTRICT out) {
int i, j, k;
for (i = 0; i < kM; ++i) {
for (k = 0; k < kN; ++k) {
for (j = 0; j < kK; ++j) {
out[i * kK + j] += a[i * kN + k] * b[k * kK + j];
}
}
}
}
HWY_INLINE void ToEvenOddF32(const hwy::bfloat16_t* HWY_RESTRICT vec_aligned,
const size_t size, float* HWY_RESTRICT out) {
const hn::ScalableTag<float> df;
const hn::Repartition<hwy::bfloat16_t, decltype(df)> dbf16;
HWY_DASSERT(size % hn::Lanes(dbf16) == 0);
HWY_DASSERT(hn::IsAligned(df, vec_aligned));
for (size_t i = 0; i < size; i += hn::Lanes(dbf16)) {
const auto interleaved = hn::LoadU(dbf16, vec_aligned + i);
hn::Store(hn::PromoteEvenTo(df, interleaved), df, out + i);
hn::Store(hn::PromoteOddTo(df, interleaved), df, out + i + hn::Lanes(df));
}
}
HWY_INLINE void ToEvenOddF32(const float* HWY_RESTRICT vec_aligned,
const size_t size, float* HWY_RESTRICT out) {
const hn::ScalableTag<float> df;
using VF = hn::Vec<decltype(df)>;
HWY_DASSERT(size % (hn::Lanes(df) * 2) == 0);
HWY_DASSERT(hn::IsAligned(df, vec_aligned));
VF vec0, vec1;
for (size_t i = 0; i < size; i += hn::Lanes(df) * 2) {
hn::LoadInterleaved2(df, vec_aligned + i, vec0, vec1);
hn::Store(vec0, df, out + i);
hn::Store(vec1, df, out + i + hn::Lanes(df));
}
}
// Simple version without tiling nor threading.
// even_odd is precomputed for the current thread.
template <bool kAdd, size_t kOuter, size_t kInner, typename ArrayT,
typename VecT, typename AddT>
HWY_INLINE void MatVecAddLoop(const ArrayT& mat, const size_t mat_ofs,
const VecT* HWY_RESTRICT vec_aligned,
const AddT* HWY_RESTRICT add,
float* HWY_RESTRICT even_odd,
float* HWY_RESTRICT out) {
PROFILER_ZONE("MatVecAddLoop");
const hn::ScalableTag<float> df;
for (size_t idx_row = 0; idx_row < kOuter; ++idx_row) {
const size_t row_ofs = mat_ofs + idx_row * kInner;
if constexpr (kAdd) {
out[idx_row] = hwy::ConvertScalarTo<float>(add[idx_row]) +
Dot(df, mat, row_ofs, vec_aligned, kInner);
} else {
out[idx_row] = Dot(df, mat, row_ofs, vec_aligned, kInner);
}
}
}
#if !defined(HWY_NATIVE_DOT_BF16) || !HWY_NATIVE_DOT_BF16
template <bool kAdd, size_t kOuter, size_t kInner, typename VecT, typename AddT,
size_t kCapacity>
HWY_INLINE void MatVecAddLoop(
const CompressedArray<hwy::bfloat16_t, kCapacity>& mat,
const size_t mat_ofs, const VecT* HWY_RESTRICT vec_aligned,
const AddT* HWY_RESTRICT add, float* HWY_RESTRICT even_odd,
float* HWY_RESTRICT out) {
PROFILER_ZONE("MatVecAddLoop");
constexpr bool kVecIsEvenOdd = true;
const hn::ScalableTag<float> df;
ToEvenOddF32(vec_aligned, kInner, even_odd);
for (size_t idx_row = 0; idx_row < kOuter; ++idx_row) {
const size_t row_ofs = mat_ofs + idx_row * kInner;
if constexpr (kAdd) {
out[idx_row] = hwy::ConvertScalarTo<float>(add[idx_row]) +
Dot<kVecIsEvenOdd>(df, mat, row_ofs, even_odd, kInner);
} else {
out[idx_row] = Dot<kVecIsEvenOdd>(df, mat, row_ofs, even_odd, kInner);
}
}
}
#endif
// even_odd is precomputed for the current thread.
template <size_t kOuter, size_t kInner, typename ArrayT, typename VecT>
HWY_INLINE void MatVecLoop(const ArrayT& mat, const size_t mat_ofs,
const VecT* HWY_RESTRICT vec_aligned,
float* HWY_RESTRICT even_odd,
float* HWY_RESTRICT out) {
MatVecAddLoop</*kAdd=*/false, kOuter, kInner>(
mat, mat_ofs, vec_aligned, /*add=*/static_cast<VecT*>(nullptr), even_odd,
out);
}
// Simple version without tiling nor threading, but two offsets/outputs.
template <bool kAdd, size_t kOuter, size_t kInner, typename ArrayT,
typename VecT, typename AddT>
@ -120,25 +226,40 @@ HWY_INLINE void TwoOfsMatVecAddLoop(const ArrayT& mat, const size_t mat_ofs0,
}
}
// Simple version without tiling nor threading, but two offsets/outputs.
template <size_t kOuter, size_t kInner, typename ArrayT, typename VecT>
HWY_INLINE void TwoOfsMatVecLoop(const ArrayT& mat, const size_t mat_ofs0,
const size_t mat_ofs1,
const VecT* HWY_RESTRICT vec_aligned,
float* HWY_RESTRICT out0,
float* HWY_RESTRICT out1) {
TwoOfsMatVecAddLoop</*kAdd=*/false, kOuter, kInner, ArrayT, VecT, VecT>(
mat, mat_ofs0, mat_ofs1, vec_aligned, /*add0=*/nullptr, /*add1=*/nullptr,
out0, out1);
}
namespace detail {
// For each i = [0, num_rows), compute partial (length `num_cols`) dot product
// of row i with `vec_aligned` and add into `out[i]`. The upper-left coordinate
// of the tile is r0, c0.
template <class DF, typename ArrayT, typename VecT>
template <bool kVecEO, class DF, typename ArrayT, typename VecT>
HWY_INLINE void AccumulatePartialDotProducts(
DF df, const ArrayT& mat, size_t mat_ofs, size_t mat_stride, size_t r0,
size_t c0, size_t num_rows, size_t num_cols,
const VecT* HWY_RESTRICT vec_aligned, float* HWY_RESTRICT out) {
for (size_t idx_row = 0; idx_row < num_rows; ++idx_row) {
const size_t row_ofs = mat_ofs + (r0 + idx_row) * mat_stride;
out[idx_row] += Dot(df, mat, row_ofs + c0, vec_aligned + c0, num_cols);
out[idx_row] +=
Dot<kVecEO>(df, mat, row_ofs + c0, vec_aligned + c0, num_cols);
}
}
// Same as above, but sets out[i] to the first partial dot product +
// init (if kInit), which avoids having to zero-initialize and accumulate.
template <bool kInit, class DF, typename ArrayT, typename VecT, typename InitT>
// Same as AccumulatePartialDotProducts, but sets out[i] to the first partial
// dot product + init (if kInit), which avoids having to zero-initialize and
// accumulate.
template <bool kVecEO, bool kInit, class DF, typename ArrayT, typename VecT,
typename InitT>
HWY_INLINE void SetFirstPartialDotProducts(DF df, const ArrayT& mat,
size_t mat_ofs, size_t mat_stride,
size_t r0, size_t c0,
@ -149,10 +270,12 @@ HWY_INLINE void SetFirstPartialDotProducts(DF df, const ArrayT& mat,
for (size_t idx_row = 0; idx_row < num_rows; ++idx_row) {
const size_t row_ofs = mat_ofs + (r0 + idx_row) * mat_stride;
if constexpr (kInit) {
out[idx_row] = hwy::ConvertScalarTo<float>(init[idx_row + r0]) +
Dot(df, mat, row_ofs + c0, vec_aligned + c0, num_cols);
out[idx_row] =
hwy::ConvertScalarTo<float>(init[idx_row + r0]) +
Dot<kVecEO>(df, mat, row_ofs + c0, vec_aligned + c0, num_cols);
} else {
out[idx_row] = Dot(df, mat, row_ofs + c0, vec_aligned + c0, num_cols);
out[idx_row] =
Dot<kVecEO>(df, mat, row_ofs + c0, vec_aligned + c0, num_cols);
}
}
}
@ -161,7 +284,8 @@ HWY_INLINE void SetFirstPartialDotProducts(DF df, const ArrayT& mat,
// horizontal strip of the entire matrix); the result is the full dot product
// for rows r in [r0, r0 + num_rows) + optionally the add vector, which we store
// into in out[r - r0].
template <bool kAdd, class DF, typename ArrayT, typename VecT, typename AddT>
template <bool kVecEO, bool kAdd, class DF, typename ArrayT, typename VecT,
typename AddT>
HWY_INLINE void FullDotProductsForStrip(DF df, const ArrayT& mat,
size_t mat_ofs, size_t mat_stride,
size_t r0, size_t num_rows,
@ -170,42 +294,37 @@ HWY_INLINE void FullDotProductsForStrip(DF df, const ArrayT& mat,
float* HWY_RESTRICT out) {
// Tall and skinny: set `out` to the single dot product.
if (mat_stride < MaxCols()) {
SetFirstPartialDotProducts<kAdd>(df, mat, mat_ofs, mat_stride, r0, 0,
num_rows, mat_stride, vec_aligned, add,
out);
SetFirstPartialDotProducts<kVecEO, kAdd>(df, mat, mat_ofs, mat_stride, r0,
0, num_rows, mat_stride,
vec_aligned, add, out);
return;
}
// We have at least MaxCols, so start by setting `out` to that:
SetFirstPartialDotProducts<kAdd>(df, mat, mat_ofs, mat_stride, r0, 0,
num_rows, MaxCols(), vec_aligned, add, out);
SetFirstPartialDotProducts<kVecEO, kAdd>(df, mat, mat_ofs, mat_stride, r0, 0,
num_rows, MaxCols(), vec_aligned,
add, out);
// For further multiples of MaxCols, accumulate. Remainders handled below.
size_t c0 = MaxCols();
for (; c0 <= mat_stride - MaxCols(); c0 += MaxCols()) {
AccumulatePartialDotProducts(df, mat, mat_ofs, mat_stride, r0, c0, num_rows,
MaxCols(), vec_aligned, out);
AccumulatePartialDotProducts<kVecEO>(df, mat, mat_ofs, mat_stride, r0, c0,
num_rows, MaxCols(), vec_aligned, out);
}
if (c0 < mat_stride) { // Final cols
AccumulatePartialDotProducts(df, mat, mat_ofs, mat_stride, r0, c0, num_rows,
mat_stride - c0, vec_aligned, out);
AccumulatePartialDotProducts<kVecEO>(df, mat, mat_ofs, mat_stride, r0, c0,
num_rows, mat_stride - c0, vec_aligned,
out);
}
}
} // namespace detail
// Stores dot products of rows with `vec_aligned` + add the values from `add`
// (if kAdd), then stores them to `out`.
// `even_odd` has kInner elements for each thread.
template <bool kAdd, size_t kOuter, size_t kInner, typename ArrayT,
typename VecT, typename AddT>
HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs,
const VecT* HWY_RESTRICT const vec_aligned,
const AddT* HWY_RESTRICT const add,
float* HWY_RESTRICT even_odd, float* HWY_RESTRICT out,
hwy::ThreadPool& pool) {
PROFILER_ZONE("MatVecAdd");
template <bool kVecIsEvenOdd, bool kAdd, size_t kOuter, size_t kInner,
typename ArrayT, typename VecT, typename AddT>
HWY_INLINE void MatVecAddInner(const ArrayT& mat, const size_t mat_ofs,
const VecT* HWY_RESTRICT const vec_aligned,
const AddT* HWY_RESTRICT const add,
float* HWY_RESTRICT even_odd,
float* HWY_RESTRICT out, hwy::ThreadPool& pool) {
const hn::ScalableTag<float> df;
constexpr size_t kRowsPerStrip = RowsPerStrip<kOuter>();
constexpr size_t kNumStrips = kOuter / kRowsPerStrip;
@ -223,9 +342,9 @@ HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs,
pool.Run(0, kNumStrips, [&](const uint64_t strip, size_t thread) HWY_ATTR {
PROFILER_ZONE("MatVec.lambda");
const size_t r0 = strip * kRowsPerStrip;
detail::FullDotProductsForStrip<kAdd>(df, mat, mat_ofs, kInner, r0,
kRowsPerStrip, vec_aligned, add,
out + r0);
detail::FullDotProductsForStrip<kVecIsEvenOdd, kAdd>(
df, mat, mat_ofs, kInner, r0, kRowsPerStrip, vec_aligned, add,
out + r0);
});
// Remaining rows
@ -233,18 +352,47 @@ HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs,
if (r0 < kOuter) {
PROFILER_ZONE("MatVec remainder");
const size_t num_rows = kOuter - r0;
detail::FullDotProductsForStrip<kAdd>(df, mat, mat_ofs, kInner, r0,
num_rows, vec_aligned, add, out + r0);
detail::FullDotProductsForStrip<kVecIsEvenOdd, kAdd>(
df, mat, mat_ofs, kInner, r0, num_rows, vec_aligned, add, out + r0);
}
}
} // namespace detail
// Stores dot products of rows with `vec_aligned` + add the values from `add`
// (if kAdd), then stores them to `out`.
//
template <bool kAdd, size_t kOuter, size_t kInner, typename ArrayT,
typename VecT, typename AddT>
HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs,
const VecT* HWY_RESTRICT const vec_aligned,
const AddT* HWY_RESTRICT const add,
float* HWY_RESTRICT even_odd, float* HWY_RESTRICT out,
hwy::ThreadPool& pool) {
PROFILER_ZONE("MatVecAdd");
#if !defined(HWY_NATIVE_DOT_BF16) || !HWY_NATIVE_DOT_BF16
if constexpr (CompressTraits<typename ArrayT::value_type>::kSupportsEvenOdd &&
hwy::IsSameEither<VecT, float, hwy::bfloat16_t>()) {
ToEvenOddF32(vec_aligned, kInner, even_odd);
detail::MatVecAddInner</*kVecIsEvenOdd=*/true, kAdd, kOuter, kInner>(
mat, mat_ofs, even_odd, add, even_odd, out, pool);
return;
}
#endif
detail::MatVecAddInner</*kVecIsEvenOdd=*/false, kAdd, kOuter, kInner>(
mat, mat_ofs, vec_aligned, add, even_odd, out, pool);
}
template <size_t kOuter, size_t kInner, typename ArrayT, typename VecT>
HWY_INLINE void MatVec(const ArrayT& mat, const size_t mat_ofs,
const VecT* HWY_RESTRICT const vec_aligned,
float* HWY_RESTRICT even_odd, float* HWY_RESTRICT out,
hwy::ThreadPool& pool) {
MatVecAdd</*kAdd=*/false, kOuter, kInner, ArrayT, VecT, VecT>(
mat, mat_ofs, vec_aligned, /*add=*/nullptr, even_odd, out, pool);
MatVecAdd</*kAdd=*/false, kOuter, kInner>(mat, mat_ofs, vec_aligned,
/*add=*/static_cast<VecT*>(nullptr),
even_odd, out, pool);
}
template <class D, HWY_IF_F32_D(D)>
@ -366,17 +514,18 @@ HWY_NOINLINE void TwoMatVecAdd(
const hn::ScalableTag<float> df;
constexpr size_t kRowsPerStrip = RowsPerStrip<kOuter>();
constexpr size_t kNumStrips = kOuter / kRowsPerStrip;
constexpr bool kVecIsEvenOdd = false;
// For each entire strip.
pool.Run(0, kNumStrips, [&](const uint64_t strip, size_t thread) HWY_ATTR {
PROFILER_ZONE("TwoMatVec.lambda");
const size_t r0 = strip * kRowsPerStrip;
detail::FullDotProductsForStrip<kAdd>(df, mat0, mat_ofs, kInner, r0,
kRowsPerStrip, vec_aligned, add0,
out0 + r0);
detail::FullDotProductsForStrip<kAdd>(df, mat1, mat_ofs, kInner, r0,
kRowsPerStrip, vec_aligned, add1,
out1 + r0);
detail::FullDotProductsForStrip<kVecIsEvenOdd, kAdd>(
df, mat0, mat_ofs, kInner, r0, kRowsPerStrip, vec_aligned, add0,
out0 + r0);
detail::FullDotProductsForStrip<kVecIsEvenOdd, kAdd>(
df, mat1, mat_ofs, kInner, r0, kRowsPerStrip, vec_aligned, add1,
out1 + r0);
});
// Remaining rows
@ -384,9 +533,9 @@ HWY_NOINLINE void TwoMatVecAdd(
if (r0 < kOuter) {
PROFILER_ZONE("TwoMatVec remainder");
const size_t num_rows = kOuter - r0;
detail::FullDotProductsForStrip<kAdd>(
detail::FullDotProductsForStrip<kVecIsEvenOdd, kAdd>(
df, mat0, mat_ofs, kInner, r0, num_rows, vec_aligned, add0, out0 + r0);
detail::FullDotProductsForStrip<kAdd>(
detail::FullDotProductsForStrip<kVecIsEvenOdd, kAdd>(
df, mat1, mat_ofs, kInner, r0, num_rows, vec_aligned, add1, out1 + r0);
}
}

View File

@ -17,6 +17,8 @@
#define HWY_DISABLED_TARGETS HWY_SCALAR
#endif
#include <stddef.h>
#include <algorithm>
#include <array>
#include <random>
@ -376,6 +378,25 @@ CompressedArray<float, kOuter * kInner> GenerateMat(size_t offset) {
return mat;
}
template <size_t kOuter, size_t kInner>
CompressedArray<float, kOuter * kInner> GenerateZeroMat(size_t offset) {
hwy::ThreadPool pool(static_cast<size_t>(std::clamp(
static_cast<int>(std::thread::hardware_concurrency()) - 2, 1, 4)));
gcpp::CompressWorkingSet ws;
CompressedArray<float, kOuter * kInner> mat;
std::array<float, kOuter * kInner> content;
pool.Run(0, kOuter, [&](const size_t i, size_t thread) {
for (size_t j = 0; j < kInner; j++) {
content[i * kInner + j] = 0.0f;
}
});
Compress(content, ws, mat, pool);
mat.set_scale(1.0f);
return mat;
}
template <size_t length>
hwy::AlignedFreeUniquePtr<float[]> GenerateVec(size_t offset) {
hwy::AlignedFreeUniquePtr<float[]> vec = hwy::AllocateAligned<float>(length);
@ -386,6 +407,25 @@ hwy::AlignedFreeUniquePtr<float[]> GenerateVec(size_t offset) {
return vec;
}
// A simple matrix multiplication. No optimization / tiling.
template <size_t kM, size_t kN, size_t kK>
hwy::AlignedFreeUniquePtr<float[]> SimpleMatMul(
const hwy::AlignedFreeUniquePtr<float[]>& a,
const hwy::AlignedFreeUniquePtr<float[]>& b) {
hwy::AlignedFreeUniquePtr<float[]> out = hwy::AllocateAligned<float>(kM * kK);
hwy::ZeroBytes(out.get(), kM * kK * sizeof(float));
int i, j, k;
for (i = 0; i < kM; ++i) {
for (j = 0; j < kK; ++j) {
for (k = 0; k < kN; ++k) {
out[i * kK + j] += a[i * kN + k] * b[k * kK + j];
}
}
}
return out;
}
template <size_t kOuter, size_t kInner>
hwy::AlignedFreeUniquePtr<float[]> SimpleMatVecAdd(
const CompressedArray<float, kOuter * kInner>& mat,
@ -417,6 +457,52 @@ void AssertClose(const hwy::AlignedFreeUniquePtr<float[]>& a,
}
}
template <typename MatT>
void AssertClose(const hwy::AlignedFreeUniquePtr<MatT[]>& expected,
const hwy::AlignedFreeUniquePtr<MatT[]>& actual, size_t num) {
for (size_t idx = 0; idx < num; idx++) {
double expected_value = hwy::ConvertScalarTo<double>(expected[idx]);
double actual_value = hwy::ConvertScalarTo<double>(actual[idx]);
const double tolerance =
expected_value * 20 * 1.0 / (1ULL << hwy::MantissaBits<MatT>());
if (!(expected_value - tolerance <= actual_value &&
actual_value <= expected_value + tolerance)) {
fprintf(stderr, "expected[%lu]: %f, actual[%lu]: %f\n", idx,
expected_value, idx, actual_value);
HWY_ASSERT(0);
}
}
}
void TestMatMul() {
hwy::ThreadPool pool(0);
constexpr size_t kM = 128 * 3; // 384
constexpr size_t kK = 128 * 5; // 640
constexpr size_t kN = 128 * 6; // 768
CompressedArray<float, kM * kN> a1 = GenerateMat<kM, kN>(0);
CompressedArray<float, kN * kK> b1 = GenerateMat<kN, kK>(0);
hwy::AlignedFreeUniquePtr<float[]> a = hwy::AllocateAligned<float>(kM * kN);
Decompress(a1, 0, a.get(), kM * kN);
hwy::AlignedFreeUniquePtr<float[]> b = hwy::AllocateAligned<float>(kN * kK);
Decompress(b1, 0, b.get(), kN * kK);
hwy::AlignedFreeUniquePtr<float[]> expected_out1 =
SimpleMatMul<kM, kN, kK>(a, b);
CompressedArray<float, kM * kK> compressed_c = GenerateZeroMat<kM, kK>(0);
hwy::AlignedFreeUniquePtr<float[]> c = hwy::AllocateAligned<float>(kM * kK);
Decompress(compressed_c, 0, c.get(), kM * kK);
MatMul<kM, kN, kK>(a.get(), b.get(), c.get());
AssertClose(expected_out1, c, kM * kK);
}
void TestMatVecAdd() {
hwy::ThreadPool pool(0);
constexpr size_t kOuter = 128 * 3;
@ -518,6 +604,7 @@ HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConst);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConstAndAdd);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSoftmax);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllCreateDistribution);
HWY_EXPORT_AND_TEST_P(OpsTest, TestMatMul);
HWY_EXPORT_AND_TEST_P(OpsTest, TestMatVecAdd);
HWY_EXPORT_AND_TEST_P(OpsTest, TestTwoMatVecAdd);
HWY_EXPORT_AND_TEST_P(OpsTest, TestTwoOfsMatVecAddLoop);