Refactor/cleanup, remove even_odd

* New compression/shared.h, remove sfp.h
* Remove unused DistortionStats b_l1_
* Move exact arithmetic functions into fp_arith
* Remove even_odd optimization for MatVec (mostly unused)
* Use BF16 typedef more widely
* Add kMaxSFP constant

PiperOrigin-RevId: 670996386
Jan Wassenberg, 2024-09-04 09:24:39 -07:00; committed by Copybara-Service
parent 07c34cb18a
commit c29e9752c7
22 changed files with 423 additions and 261 deletions
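
The headline API change: the MatVec / MatVecAdd entry points in ops/matvec-inl.h no longer take the per-thread even_odd scratch buffer. A sketch of the new call shape (argument names are illustrative, taken from the diffs below):

  // Before: MatVecAdd<kOuter, kInner>(mat, ofs, vec, add, even_odd, out, pool);
  MatVecAdd<kOuter, kInner>(mat, /*mat_ofs=*/0, vec, add, out, pool);
  MatVec<kOuter, kInner>(mat, /*mat_ofs=*/0, vec, out, pool);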

View File

@ -48,6 +48,15 @@ cc_library(
],
)
# Avoids circular dependency: fp_arith-inl -> compress-inl -> ops-inl
cc_library(
name = "fp_arith",
textual_hdrs = ["ops/fp_arith-inl.h"],
deps = [
"@hwy//:hwy",
],
)
cc_library(
name = "ops",
hdrs = [
@ -61,6 +70,7 @@ cc_library(
],
deps = [
":allocator",
":fp_arith",
":threading",
"//compression:compress",
"//compression:sfp",

View File

@ -46,8 +46,8 @@ set(SOURCES
compression/io.h
compression/nuq.h
compression/nuq-inl.h
compression/sfp.h
compression/sfp-inl.h
compression/shared.h
compression/weights_raw.h
backprop/activations.h
backprop/backward.cc
@ -155,6 +155,10 @@ set(GEMMA_TEST_FILES
backprop/backward_test.cc
backprop/backward_scalar_test.cc
backprop/optimize_test.cc
compression/compress_test.cc
compression/distortion_test.cc
compression/sfp_test.cc
compression/nuq_test.cc
ops/dot_test.cc
ops/ops_test.cc
ops/matmul_test.cc

View File

@ -28,6 +28,7 @@
#include "backprop/activations.h"
#include "backprop/prompt.h"
#include "gemma/common.h"
#include "util/allocator.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"

View File

@ -24,9 +24,9 @@
#include <vector>
#include "backprop/activations.h"
#include "gemma/activations.h"
#include "gemma/common.h"
#include "gemma/configs.h"
#include "util/allocator.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
@ -40,9 +40,10 @@
#define THIRD_PARTY_GEMMA_CPP_FORWARD_TOGGLE
#endif
#include "hwy/highway.h"
// After highway.h
#include "ops/matvec-inl.h"
#include "ops/ops-inl.h"
#include "hwy/highway.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {
@ -110,7 +111,7 @@ void ApplyForwardLayer(const LayerT<TConfig>& weights,
for (size_t pos = 0; pos < num_tokens; ++pos) {
MatVec<(kHeads + 2) * kQKVDim, kModelDim>(
weights.qkv_einsum_w, 0,
activations.pre_att_rms_out.data() + pos * kModelDim, nullptr,
activations.pre_att_rms_out.data() + pos * kModelDim,
activations.qkv.data() + pos * (kHeads + 2) * kQKVDim, pool);
}
const size_t num_tasks = kHeads * num_tokens;
@ -174,7 +175,7 @@ void ApplyForwardLayer(const LayerT<TConfig>& weights,
MatVec<kModelDim, kQKVDim>(
weights.attn_vec_einsum_w, head * kModelDim * kQKVDim,
activations.att_out.data() + pos * kHeads * kQKVDim + head * kQKVDim,
nullptr, activations.att_post1.data() + pos * kModelDim, pool);
activations.att_post1.data() + pos * kModelDim, pool);
AddFrom(activations.att_post1.data() + pos * kModelDim,
activations.attention_out.data() + pos * kModelDim, kModelDim);
}
@ -192,7 +193,7 @@ void ApplyForwardLayer(const LayerT<TConfig>& weights,
for (size_t pos = 0; pos < num_tokens; ++pos) {
MatVec<kFFHiddenDim * 2, kModelDim>(
weights.gating_einsum_w, 0,
activations.bf_pre_ffw_rms_out.data() + pos * kModelDim, nullptr,
activations.bf_pre_ffw_rms_out.data() + pos * kModelDim,
activations.ffw_hidden.data() + pos * kFFHiddenDim * 2, pool);
}
for (size_t pos = 0; pos < num_tokens; ++pos) {
@ -215,7 +216,7 @@ void ApplyForwardLayer(const LayerT<TConfig>& weights,
MatVec<kModelDim, kFFHiddenDim>(
weights.linear_w, 0,
activations.ffw_hidden_gated.data() + pos * kFFHiddenDim,
nullptr, output + pos * kModelDim, pool);
output + pos * kModelDim, pool);
}
for (size_t pos = 0; pos < num_tokens; ++pos) {
AddFrom(activations.attention_out.data() + pos * kModelDim,
@ -265,7 +266,7 @@ float CrossEntropyLossForwardPass(const std::vector<int>& prompt,
for (size_t pos = 0; pos < num_tokens; ++pos) {
MatVec<kVocabSize, kModelDim>(
weights.embedder_input_embedding, 0,
forward.final_norm_output.data() + pos * kModelDim, nullptr,
forward.final_norm_output.data() + pos * kModelDim,
forward.logits.data() + pos * kVocabSize, pool);
}

View File

@ -64,6 +64,7 @@ cc_test(
srcs = ["distortion_test.cc"],
deps = [
":distortion",
":shared",
"@googletest//:gtest_main",
"//:test_util",
"@hwy//:hwy",
@ -72,11 +73,19 @@ cc_test(
],
)
cc_library(
name = "shared",
hdrs = ["shared.h"],
deps = [
"@hwy//:hwy",
],
)
cc_library(
name = "sfp",
hdrs = ["sfp.h"],
textual_hdrs = ["sfp-inl.h"],
deps = [
":shared",
"@hwy//:hwy",
],
)
@ -93,6 +102,7 @@ cc_test(
deps = [
":distortion",
":sfp",
":shared",
"@googletest//:gtest_main",
"//:ops",
"//:test_util",
@ -108,6 +118,7 @@ cc_library(
textual_hdrs = ["nuq-inl.h"],
deps = [
":sfp",
":shared",
"@hwy//:hwy",
"@hwy//hwy/contrib/sort:vqsort",
],
@ -127,6 +138,7 @@ cc_test(
":distortion",
":nuq",
":sfp",
":shared",
"@googletest//:gtest_main",
"//:test_util",
"@hwy//:hwy",
@ -137,11 +149,7 @@ cc_test(
cc_library(
name = "compress",
hdrs = [
"compress.h",
"nuq.h",
"sfp.h",
],
hdrs = ["compress.h"],
textual_hdrs = [
"compress-inl.h",
],
@ -151,12 +159,35 @@ cc_library(
":io",
":nuq",
":sfp",
":shared",
"//:fp_arith",
"@hwy//:hwy",
"@hwy//:stats",
"@hwy//:thread_pool",
],
)
cc_test(
name = "compress_test",
size = "small",
srcs = ["compress_test.cc"],
features = ["fully_static_link"],
linkstatic = True,
local_defines = ["HWY_IS_TEST"],
# for test_suite.
tags = ["hwy_ops_test"],
deps = [
":compress",
":distortion",
"@googletest//:gtest_main",
"//:test_util",
"@hwy//:hwy",
"@hwy//:hwy_test_util",
"@hwy//:nanobenchmark",
"@hwy//:thread_pool",
],
)
# For internal experimentation
cc_library(
name = "analyze",
@ -190,6 +221,7 @@ cc_binary(
srcs = ["compress_weights.cc"],
deps = [
":compress",
":shared",
":weights_raw",
# Placeholder for internal dep, do not remove.,
"//:args",

View File

@ -30,10 +30,10 @@
#include "compression/blob_store.h"
#include "compression/io.h"
#include "compression/nuq.h"
#include "compression/sfp.h"
#include "compression/shared.h"
// IWYU pragma: end_exports
#include "compression/distortion.h"
#include "hwy/base.h" // hwy::bfloat16_t
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#if COMPRESS_STATS
#include "hwy/stats.h"
@ -41,29 +41,20 @@
namespace gcpp {
using BF16 = hwy::bfloat16_t;
static inline const char* TypeName(float) { return "f32"; }
static inline const char* TypeName(BF16) { return "b16"; }
static inline const char* TypeName(SfpStream) { return "SFP"; }
static inline const char* TypeName(NuqStream) { return "NUQ"; }
namespace detail {
// How many MatT are required to store `capacity` weights. For all but
// NuqStream, this is the same as `capacity`. For use by CompressedArray.
// Returns the number of `MatT` elements required to store `capacity` values,
// which must not be zero.
template <typename MatT>
constexpr size_t CompressedArrayLen(size_t capacity) {
return capacity;
}
template <>
constexpr size_t CompressedArrayLen<NuqStream>(size_t capacity) {
constexpr size_t CompressedArrayElements(size_t capacity) {
if constexpr (hwy::IsSame<hwy::RemoveCvRef<MatT>, NuqStream>()) {
return NuqStream::PackedEnd(capacity);
}
} // namespace detail
// Returns the number of bytes required to store a compressed array with the
// given type and capacity.
template <typename MatT>
constexpr size_t CompressedArraySize(size_t capacity) {
return detail::CompressedArrayLen<MatT>(capacity) * sizeof(MatT);
} else {
return capacity;
}
}
// Compressed representation of floating-point elements. The array length may
@ -71,10 +62,6 @@ constexpr size_t CompressedArraySize(size_t capacity) {
// implemented in SIMD code and are thus non-member functions.
template <typename MatT, size_t kCapacity>
class CompressedArray {
static constexpr size_t NumCompressed() {
return detail::CompressedArrayLen<MatT>(kCapacity);
}
public:
using value_type = MatT;
@ -100,11 +87,11 @@ class CompressedArray {
constexpr size_t size() const { return kCapacity; }
constexpr size_t CompressedSize() const {
return NumCompressed() * sizeof(MatT);
return data_.size() * sizeof(MatT);
}
private:
std::array<MatT, NumCompressed()> data_;
std::array<MatT, CompressedArrayElements<MatT>(kCapacity)> data_;
// Blobs are at least kBlobAlign bytes anyway.
float scale_[kBlobAlign / sizeof(float)];
};
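
For context, a minimal sketch (not part of the diff) of how callers size backing storage with the unified CompressedArrayElements<MatT>() helper (cf. the sbs_writer changes later in this commit); kCapacity is an arbitrary illustration value:

  constexpr size_t kCapacity = 4096;
  // For every type except NuqStream, the element count equals the capacity.
  static_assert(CompressedArrayElements<SfpStream>(kCapacity) == kCapacity);
  // NuqStream also stores per-group cluster tables, so its element count
  // (= bytes, since NuqStream is one byte) exceeds kCapacity / 2 alone.
  std::vector<NuqStream> nuq(CompressedArrayElements<NuqStream>(kCapacity));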

View File

@ -0,0 +1,14 @@
// Copyright 2023 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

View File

@ -21,9 +21,9 @@
#define HWY_TARGET_INCLUDE \
"compression/compress_weights.cc" // NOLINT
#include "hwy/foreach_target.h" // IWYU pragma: keep
// Must come after foreach_target.h to avoid redefinition errors.
#include "compression/compress-inl.h"
#include "hwy/highway.h"
// After highway.h
#include "compression/compress-inl.h"
#ifndef GEMMA_COMPRESS_WEIGHTS_ONCE
#define GEMMA_COMPRESS_WEIGHTS_ONCE
@ -38,9 +38,11 @@
#include <thread> // NOLINT
#include "compression/io.h" // Path
#include "compression/shared.h"
#include "compression/weights_raw.h"
#include "gemma/common.h" // Model
#include "gemma/weights.h"
#include "util/allocator.h"
#include "util/args.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
@ -57,11 +59,10 @@ float ScaleWeights(float* data, size_t len) {
for (size_t i = 0; i < len; ++i) {
maxabs = std::max(maxabs, std::abs(data[i]));
}
const float kMaxRange = 1.875f;
if (maxabs <= kMaxRange) {
if (maxabs <= kMaxSFP) {
return 1.0f;
}
const float scale = maxabs / kMaxRange;
const float scale = maxabs / kMaxSFP;
const float inv_scale = 1.0f / scale;
for (size_t i = 0; i < len; ++i) {
data[i] *= inv_scale;

View File

@ -30,9 +30,12 @@
namespace gcpp {
// Returns `sum` and `err` such that `sum + err` is exactly equal to `a + b`,
// despite floating-point rounding. `sum` is already the best estimate, so do
// not actually add `err` to it. Knuth98/Moller65. Unlike Fast2Sum [Dekker71],
// this does not require any relative ordering of the exponents of a and b.
// despite floating-point rounding. `sum` is already the best estimate for the
// addition, so do not directly add `err` to it.
//
// Knuth98/Moller65. Unlike FastTwoSum, this does not require any relative
// ordering of the exponents of a and b. 6 ops.
// TODO: move to and use in Highway stats.h?
template <typename T, HWY_IF_FLOAT3264(T)>
static inline T TwoSum(T a, T b, T& err) {
const T sum = a + b;
@ -88,7 +91,6 @@ class DistortionStats {
const float l1f = hwy::ScalarAbs(original - distorted);
const double l1 = static_cast<double>(l1f);
s_l1_.Notify(l1f);
b_l1_.Notify(HWY_MIN(99, static_cast<int>(l1f * 1E4)));
if (l1f != 0.0f) {
l1_.push_back(l1f);
}
@ -102,7 +104,7 @@ class DistortionStats {
// as much as an actual sign flip, so do not count them.
n_sign_flip_ +=
((original < 0.0f) != (distorted < 0.0f)) && !rounded_to_zero;
n_exact_ += (l1f == 0.0f);
n_exact_ += (original == distorted);
n_rounded_to_zero += rounded_to_zero;
}
@ -122,7 +124,6 @@ class DistortionStats {
void Assimilate(const DistortionStats& other) {
s_original_.Assimilate(other.s_original_);
s_l1_.Assimilate(other.s_l1_);
b_l1_.Assimilate(other.b_l1_);
sum_l1_.Assimilate(other.sum_l1_);
sum_l1_rounded_.Assimilate(other.sum_l1_rounded_);
l1_.insert(l1_.end(), other.l1_.begin(), other.l1_.end());
@ -204,7 +205,6 @@ class DistortionStats {
private:
hwy::Stats s_original_;
hwy::Stats s_l1_;
hwy::Bins<100> b_l1_;
CascadedSummation<double> sum_l1_; // all
CascadedSummation<double> sum_l1_rounded_; // only if rounded_to_zero
std::vector<float> l1_;
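
A concrete scalar illustration of TwoSum (not part of the diff): with float inputs a = 1e8f and b = 1.0f, the rounded sum is 1e8f and the recovered error is exactly 1.0f, so sum + err reproduces a + b despite the rounding.

  float err;
  const float sum = TwoSum(1e8f, 1.0f, err);  // sum == 1e8f, err == 1.0f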

View File

@ -17,6 +17,7 @@
#include <stdio.h>
#include "compression/shared.h"
#include "util/test_util.h"
#include "hwy/nanobenchmark.h"
#include "hwy/tests/hwy_gtest.h"
@ -74,13 +75,13 @@ TEST(DistortionTest, TestDilution) {
HWY_ASSERT(IsNear(0.001, stats.WeightedAverageL1()));
// Now add a large difference:
stats.Notify(1.875f - 0.0625f, 1.875f); // max magnitude, 3-bit mantissa
stats.Notify(kMaxSFP - 0.0625f, kMaxSFP); // max magnitude, 3-bit mantissa
// .. WeightedAverageL1 is closer to it.
HWY_ASSERT(IsInside(0.020, 0.025, stats.WeightedAverageL1()));
// Add a small and large difference:
stats.Notify((1.75f - 0.125f) / 1024, 1.75f / 1024); // small, 2-bit mantissa
stats.Notify(-1.875f + 0.0625f, -1.875f); // larger negative
stats.Notify(-kMaxSFP + 0.0625f, -kMaxSFP); // larger negative
// .. SNR is still barely affected.
HWY_ASSERT(IsInside(890.0, 900.0, stats.GeomeanValueDivL1()));
// .. WeightedAverageL1 is higher after another large error.

View File

@ -21,7 +21,7 @@
#include <stdint.h>
#include "compression/nuq.h"
#include "compression/sfp.h"
#include "compression/shared.h"
#include "hwy/base.h"
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_NUQ_INL_H_

View File

@ -40,9 +40,8 @@ static constexpr size_t kGroupSize = 256;
// Points to the *start* of a NUQ stream. Aligning the allocation (see
// aligned_allocator.h) may speed up decoding but is not required.
//
// See go/streaming-weight-decode for background and design. Layout: first one
// table of kClusters entries per group, in ascending order of group index,
// then two packed indices per byte.
// Layout: first one table of kClusters entries per group, in ascending order
// of group index, then two packed indices per byte.
//
// Indices are stored in-order to enable vector-length agnostic decode, because
// streams may be persisted to disk and used by other CPUs.
@ -54,26 +53,23 @@ static constexpr size_t kGroupSize = 256;
#pragma pack(push, 1)
struct NuqStream {
// Returns offset of packed indices from the start of the stream. This matches
// the (padded) total table size because table entries are bytes. `capacity`
// is already a multiple of `kGroupSize`.
// the (padded) total table size because table entries are bytes.
static constexpr size_t PackedStart(size_t capacity) {
// Round up to avoid cache-line splits when loading indices. No effect on
// size as long as capacity / kGroupSize is a multiple of 4.
return hwy::RoundUpTo((capacity / kGroupSize) * kClusters, 64);
return hwy::RoundUpTo(hwy::DivCeil(capacity, kGroupSize) * kClusters, 64);
}
// Returns number of NuqStream to allocate for the stream, which matches its
// size in bytes. `capacity` is already a multiple of `kGroupSize`.
static constexpr size_t PackedEnd(size_t capacity) {
return PackedStart(capacity) + capacity / 2; // two 4-bit indices per byte.
return PackedStart(capacity) + hwy::DivCeil(capacity, 2); // 2x 4-bit/byte
}
uint8_t byte;
};
#pragma pack(pop)
static inline const char* TypeName(NuqStream) { return "NUQ"; }
// Storage for dynamic programming. There are two matrices; we use separate
// allocations to avoid type punning.
template <class T>
@ -101,7 +97,7 @@ struct ClusterBuf {
num = new_num;
const size_t num_groups = hwy::DivCeil(num, kGroupSize);
centers = hwy::AllocateAligned<float>(num_groups * kClusters);
idx = hwy::AllocateAligned<uint16_t>(num);
idx = hwy::AllocateAligned<uint16_t>(hwy::RoundUpTo(num, kGroupSize));
}
AlignedMatrix<float> d;
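
A rough worked example of the new DivCeil-based layout math, assuming kClusters = 16 (its value is not shown in this diff, but 4-bit indices imply 16 clusters). With capacity = 300, which is no longer required to be a multiple of kGroupSize = 256:

  // 2 groups * 16 table bytes = 32, rounded up to 64 (assumes kClusters == 16).
  static_assert(NuqStream::PackedStart(300) == 64);
  // Table (64) + DivCeil(300, 2) = 150 packed index bytes (assumes kClusters == 16).
  static_assert(NuqStream::PackedEnd(300) == 214);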

View File

@ -46,7 +46,7 @@ class SbsWriterImpl : public WriterInterface {
SbsWriterImpl() : pool_(0), compressor_(pool_) {}
void Insert(std::string name, absl::Span<const float> weights) override {
const size_t out_size = CompressedArraySize<SfpStream>(weights.size());
const size_t out_size = CompressedArrayElements<SfpStream>(weights.size());
sfp_streams_.push_back(std::vector<SfpStream>(out_size));
compressor_.Insert<SfpStream>(name.data(), weights.data(), weights.size(),
working_set_, out_size,
@ -54,7 +54,7 @@ class SbsWriterImpl : public WriterInterface {
}
void InsertNUQ(std::string name, absl::Span<const float> weights) override {
const size_t out_size = CompressedArraySize<NuqStream>(weights.size());
const size_t out_size = CompressedArrayElements<NuqStream>(weights.size());
nuq_streams_.push_back(std::vector<NuqStream>(out_size));
compressor_.Insert<NuqStream>(name.data(), weights.data(), weights.size(),
working_set_, out_size,
@ -64,7 +64,7 @@ class SbsWriterImpl : public WriterInterface {
void InsertBfloat16(std::string name,
absl::Span<const float> weights) override {
const size_t out_size =
CompressedArraySize<hwy::bfloat16_t>(weights.size());
CompressedArrayElements<hwy::bfloat16_t>(weights.size());
bf16_streams_.push_back(std::vector<hwy::bfloat16_t>(out_size));
compressor_.Insert<hwy::bfloat16_t>(name.data(), weights.data(),
weights.size(), working_set_, out_size,

View File

@ -20,7 +20,7 @@
#include <stddef.h>
#include <stdint.h>
#include "compression/sfp.h"
#include "compression/shared.h"
#include "hwy/base.h"
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_SFP_INL_H_

View File

@ -1,51 +0,0 @@
// Copyright 2023 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_SFP_H_
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_SFP_H_
// Switching Floating Point: a hybrid 8-bit float representation of bf16/f32
// inputs that combines the advantages of e4m3 and e5m2 into a single format.
// It supports seeking at a granularity of 1, decoding to bf16/f32, and a
// fused decode/dot product with bf16/f32 vectors.
#include <stdint.h>
namespace gcpp {
// Points to the *start* of an SFP stream. Values are stored in-order to enable
// vector-length agnostic seeking, because streams may be written to disk for
// loading on other CPUs.
//
// Characteristics:
// - 24-bit dynamic range, with max exponent 2^0.
// - 3 bit mantissa for values >= 2^-7, otherwise 2.
//
// This is faster to decode than a straightforward implementation of eXmY, in
// part because SFP does not require subnormals. Unlike OCP MX, it also does not
// require side information (shared exponents).
//
// Although the representation could probably be shrunk to 6-7 bits, more
// savings can be had by non-uniform clustering - see nuq.h.
#pragma pack(push, 1)
struct SfpStream {
uint8_t byte;
};
#pragma pack(pop)
static inline const char* TypeName(SfpStream) { return "SFP"; }
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_SFP_H_

View File

@ -18,8 +18,6 @@
#define HWY_DISABLED_TARGETS HWY_SCALAR
#endif
#include "compression/sfp.h"
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
@ -27,20 +25,21 @@
#include <set>
#include "compression/distortion.h"
#include "compression/shared.h"
#include "util/test_util.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
#include "hwy/tests/hwy_gtest.h"
#include "hwy/timer.h"
// clang-format off
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "compression/sfp_test.cc" // NOLINT
// clang-format on
#include "hwy/foreach_target.h" // IWYU pragma: keep
// Any highway.h must come after foreach_target.h
#include "hwy/highway.h"
// After highway.h
#include "compression/sfp-inl.h"
#include "ops/dot-inl.h"
#include "hwy/highway.h"
#include "hwy/tests/hwy_gtest.h"
#include "hwy/tests/test_util-inl.h"
HWY_BEFORE_NAMESPACE();

compression/shared.h (new file, 94 lines)
View File

@ -0,0 +1,94 @@
// Copyright 2023 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Definitions shared between the public compress-inl.h interface and the
// sfp-inl.h and nuq-inl.h implementation details.
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_SHARED_H_
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_SHARED_H_
#include <stddef.h>
#include "hwy/base.h" // hwy::bfloat16_t
namespace gcpp {
using BF16 = hwy::bfloat16_t;
// Switching Floating Point: a hybrid 8-bit float representation of bf16/f32
// inputs that combines the advantages of e4m3 and e5m2 into a single format.
// It supports seeking at a granularity of 1 and decoding to bf16/f32.
//
// Characteristics:
// - 24-bit dynamic range, with max exponent 2^0.
// - 3 bit mantissa for values >= 2^-7, otherwise 2.
//
// A pointer to this is the *start* of an SFP stream. Values are stored
// in-order to enable vector-length agnostic seeking, because streams may be
// written to disk for loading on other CPUs.
//
// This is faster to decode than a straightforward implementation of eXmY, in
// part because SFP does not require subnormals. Unlike OCP MX, it also does not
// require side information (shared exponents).
//
// Although the representation could probably be shrunk to 6-7 bits, more
// savings can be had by non-uniform clustering - see nuq.h.
#pragma pack(push, 1)
struct SfpStream {
uint8_t byte;
};
#pragma pack(pop)
// Largest possible input magnitude: 1.111 * 2^0. This could be increased by
// shifting the value range (exponent bias).
constexpr float kMaxSFP = 1.875f;
// Non-owning view of packed elements. Shortens argument lists.
//
// Callers typically also pass an `ofs` starting offset. This is not folded
// into `ptr` because NUQ consists of two separate streams. To discourage direct
// use of `ptr` without that offset, we define a separate class instead of
// reusing `hwy::Span`.
template <typename Packed>
struct PackedSpan {
void BoundsCheck(size_t packed_ofs, size_t num) const {
HWY_DASSERT(packed_ofs + num <= size);
(void)size;
}
Packed* HWY_RESTRICT ptr;
size_t size; // for BoundsCheck and nuq-inl.h HWY_ASSERT.
};
// Avoids spelling out the template parameter in every call.
template <typename Packed>
HWY_INLINE PackedSpan<Packed> MakeSpan(Packed* ptr, size_t size) {
return {ptr, size};
}
template <typename Packed>
HWY_INLINE PackedSpan<const Packed> MakeConstSpan(Packed* ptr, size_t size) {
return {ptr, size};
}
// "Implicit" conversion from a PackedSpan<T> to PackedSpan<const T>, used in
// `RMSNormInplace` and compression tests.
template <typename Packed>
HWY_INLINE PackedSpan<const Packed> MakeConst(PackedSpan<Packed> packed) {
return {packed.ptr, packed.size};
}
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_SHARED_H_
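
A minimal usage sketch for PackedSpan (not part of the diff; sizes are illustrative): callers wrap a pointer plus element count once, then pass the span together with a starting offset.

  std::vector<SfpStream> packed(1024);
  const PackedSpan<SfpStream> span = MakeSpan(packed.data(), packed.size());
  span.BoundsCheck(/*packed_ofs=*/256, /*num=*/128);  // debug-only HWY_DASSERT
  const PackedSpan<const SfpStream> read_only = MakeConst(span);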

View File

@ -55,11 +55,6 @@ struct Activations {
// Rope
RowVectorBatch<float> inv_timescale;
// For bf16/f32 vectors * bf16 matrix: faster to unpack once beforehand, into
// per-thread storage.
// TODO: remove once MatVec is no longer used.
RowVectorBatch<float> even_odd;
MatMulEnv env;
// Multi-Head Attention?
@ -123,9 +118,6 @@ struct Activations {
inv_timescale = CreateInvTimescale<TConfig>();
const size_t num_lp = pools.NumLP();
even_odd = RowVectorBatch<float>(1, kModelDim * num_lp);
env = MatMulEnv(pools);
}
};

View File

@ -192,8 +192,7 @@ HWY_NOINLINE void GriffinRecurrent(
float* out_ptr = activations.att_sums.Batch(batch_idx);
MatVecAdd<kModelDim, kModelDim>(
layer_weights->griffin.linear_out_w, 0, x,
layer_weights->griffin.linear_out_biases.data_scale1(),
activations.even_odd.All(), out_ptr, pool);
layer_weights->griffin.linear_out_biases.data_scale1(), out_ptr, pool);
}
}
@ -283,8 +282,8 @@ class GemmaAttention {
float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
// KV structure is [k, v, k, v, ....] = kKVHeads pairs of (k, v).
MatVec<kKVHeads * 2 * kQKVDim, kModelDim>(
layer_weights_.qkv_einsum_w, kHeads * kQKVDim * kModelDim, x,
activations_.even_odd.All(), kv, pool_);
layer_weights_.qkv_einsum_w, kHeads * kQKVDim * kModelDim, x, kv,
pool_);
}
}
}

ops/fp_arith-inl.h (new file, 152 lines)
View File

@ -0,0 +1,152 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <stddef.h>
// Building blocks for floating-point arithmetic.
// Include guard for (potentially) SIMD code.
#if defined(THIRD_PARTY_GEMMA_CPP_FP_ARITH_TOGGLE) == defined(HWY_TARGET_TOGGLE)
#ifdef THIRD_PARTY_GEMMA_CPP_FP_ARITH_TOGGLE
#undef THIRD_PARTY_GEMMA_CPP_FP_ARITH_TOGGLE
#else
#define THIRD_PARTY_GEMMA_CPP_FP_ARITH_TOGGLE
#endif
#include "hwy/highway.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;
//------------------------------------------------------------------------------
// Exact multiplication
namespace detail {
// Returns non-overlapping `x` and `y` such that `x + y` = `f` and |x| >= |y|.
// Notation from Algorithm 3.1 in Handbook of Floating-Point Arithmetic. 4 ops.
template <class DF, HWY_IF_FLOAT3264_D(DF), class VF = hn::Vec<DF>>
static HWY_INLINE void VeltkampSplit(DF df, VF a, VF& x, VF& y) {
using TF = hn::TFromD<DF>;
constexpr int t = hwy::MantissaBits<TF>() + 1; // = -log2(epsilon)
constexpr int s = hwy::DivCeil(t, 2);
const VF factor = hn::Set(df, hwy::ConvertScalarTo<TF>((1ULL << s) + 1));
const VF c = hn::Mul(factor, a);
x = hn::Sub(c, hn::Sub(c, a));
y = hn::Sub(a, x);
}
} // namespace detail
// Returns `prod` and `err` such that `prod + err` is exactly equal to `a * b`,
// despite floating-point rounding, assuming that `err` is not subnormal, i.e.,
// the sum of exponents >= min exponent + mantissa bits. 2..17 ops. Useful for
// compensated dot products and polynomial evaluation.
template <class DF, HWY_IF_FLOAT3264_D(DF), class VF = hn::Vec<DF>>
static HWY_INLINE VF TwoProducts(DF df, VF a, VF b, VF& err) {
const VF prod = hn::Mul(a, b);
if constexpr (HWY_NATIVE_FMA) {
err = hn::MulSub(a, b, prod);
} else {
// Non-FMA fallback: we assume these calculations do not overflow.
VF a1, a2, b1, b2;
detail::VeltkampSplit(df, a, a1, a2);
detail::VeltkampSplit(df, b, b1, b2);
const VF m = hn::Sub(prod, hn::Mul(a1, b1));
const VF n = hn::Sub(m, hn::Mul(a2, b1));
const VF o = hn::Sub(n, hn::Mul(a1, b2));
err = hn::Sub(hn::Mul(a2, b2), o);
}
return prod;
}
//------------------------------------------------------------------------------
// Exact addition
// Returns `sum` and `err` such that `sum + err` is exactly equal to `a + b`,
// despite floating-point rounding. `sum` is already the best estimate for the
// addition, so do not directly add `err` to it. `UpdateCascadedSums` instead
// accumulates multiple `err`, which are then later added to the total `sum`.
//
// Knuth98/Moller65. Unlike FastTwoSums, this does not require any relative
// ordering of the exponents of a and b. 6 ops.
template <class DF, HWY_IF_FLOAT3264_D(DF), class VF = hn::Vec<DF>>
static HWY_INLINE VF TwoSums(DF /*df*/, VF a, VF b, VF& err) {
const VF sum = hn::Add(a, b);
const VF a2 = hn::Sub(sum, b);
const VF b2 = hn::Sub(sum, a2);
const VF err_a = hn::Sub(a, a2);
const VF err_b = hn::Sub(b, b2);
err = hn::Add(err_a, err_b);
return sum;
}
// As above, but only exact if the exponent of `a` >= that of `b`, which is the
// case if |a| >= |b|. Dekker71, also used in Kahan65 compensated sum. 3 ops.
template <class DF, HWY_IF_FLOAT3264_D(DF), class VF = hn::Vec<DF>>
static HWY_INLINE VF FastTwoSums(DF /*df*/, VF a, VF b, VF& err) {
const VF sum = hn::Add(a, b);
const VF b2 = hn::Sub(sum, a);
err = hn::Sub(b, b2);
return sum;
}
//------------------------------------------------------------------------------
// Cascaded summation (twice working precision)
// Accumulates numbers with about twice the precision of T using 7 * n FLOPS.
// Rump/Ogita/Oishi08, Algorithm 6.11 in Handbook of Floating-Point Arithmetic.
//
// Because vectors generally cannot be wrapped in a class, we use functions.
// `sum` and `sum_err` must be initially zero. Each lane is an independent sum.
// To reduce them into a single result, use `ReduceCascadedSum`.
template <class DF, HWY_IF_FLOAT3264_D(DF), class VF = hn::Vec<DF>>
void UpdateCascadedSums(DF df, VF v, VF& sum, VF& sum_err) {
VF err;
sum = TwoSums(df, sum, v, err);
sum_err += err;
}
// Combines two cascaded sum vectors, typically from unrolling/parallelization.
template <class DF, HWY_IF_FLOAT3264_D(DF), class VF = hn::Vec<DF>>
void AssimilateCascadedSums(DF df, const VF& other_sum, const VF& other_sum_err,
VF& sum, VF& sum_err) {
UpdateCascadedSums(df, other_sum, sum, sum_err);
sum_err += other_sum_err;
}
// Reduces cascaded sums to a single value. Slow; call outside of loops.
template <class DF, HWY_IF_FLOAT3264_D(DF), class VF = hn::Vec<DF>>
hn::TFromD<DF> ReduceCascadedSums(DF df, const VF sum, VF sum_err) {
const size_t N = hn::Lanes(df);
using TF = hn::TFromD<DF>;
TF total = TF{0.0};
TF total_err = TF{0.0};
for (size_t i = 0; i < N; ++i) {
TF err;
total = TwoSum(total, hn::ExtractLane(sum, i), err);
total_err += err;
}
return total + total_err + hn::ReduceSum(df, sum_err);
}
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
#endif // NOLINT
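
A minimal usage sketch of the cascaded-summation helpers (not from the diff; assumes it is compiled inside the same HWY_NAMESPACE region that includes ops/fp_arith-inl.h, and that num is a multiple of the vector length for brevity):

  float CascadedSumExample(const float* HWY_RESTRICT in, size_t num) {
    const hn::ScalableTag<float> df;
    using VF = hn::Vec<decltype(df)>;
    VF sum = hn::Zero(df), sum_err = hn::Zero(df);
    for (size_t i = 0; i < num; i += hn::Lanes(df)) {
      // Adds one vector and accumulates its rounding error into sum_err.
      UpdateCascadedSums(df, hn::LoadU(df, in + i), sum, sum_err);
    }
    // Slow horizontal reduction; call once, outside the loop.
    return ReduceCascadedSums(df, sum, sum_err);
  }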

View File

@ -115,15 +115,13 @@ void TestMatVecAdd() {
GenerateMat<float, kOuter, kInner>(0, pool);
hwy::AlignedFreeUniquePtr<float[]> vec = GenerateVec<kInner>(0);
hwy::AlignedFreeUniquePtr<float[]> add = GenerateVec<kOuter>(0);
hwy::AlignedFreeUniquePtr<float[]> even_odd =
hwy::AllocateAligned<float>(kInner * pool.NumWorkers());
hwy::AlignedFreeUniquePtr<float[]> expected_out =
SimpleMatVecAdd<kOuter, kInner>(mat, vec, add);
hwy::AlignedFreeUniquePtr<float[]> actual_out =
hwy::AllocateAligned<float>(kOuter);
HWY_ASSERT(vec && add && even_odd && expected_out && actual_out);
MatVecAdd<kOuter, kInner>(mat, 0, vec.get(), add.get(), even_odd.get(),
actual_out.get(), pool);
HWY_ASSERT(vec && add && expected_out && actual_out);
MatVecAdd<kOuter, kInner>(mat, 0, vec.get(), add.get(), actual_out.get(),
pool);
AssertClose<kOuter>(actual_out, expected_out);
}

View File

@ -21,8 +21,6 @@
#include <stdint.h>
#include <stdio.h>
#include "compression/compress.h"
#include "compression/sfp.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/profiler.h"
@ -48,37 +46,6 @@ namespace gcpp {
namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;
HWY_INLINE void ToEvenOddF32(const hwy::bfloat16_t* HWY_RESTRICT vec_aligned,
const size_t size, float* HWY_RESTRICT out) {
const hn::ScalableTag<float> df;
const hn::Repartition<hwy::bfloat16_t, decltype(df)> dbf16;
HWY_DASSERT(size % hn::Lanes(dbf16) == 0);
HWY_DASSERT(hn::IsAligned(df, vec_aligned));
for (size_t i = 0; i < size; i += hn::Lanes(dbf16)) {
const auto interleaved = hn::LoadU(dbf16, vec_aligned + i);
hn::Store(hn::PromoteEvenTo(df, interleaved), df, out + i);
hn::Store(hn::PromoteOddTo(df, interleaved), df, out + i + hn::Lanes(df));
}
}
HWY_INLINE void ToEvenOddF32(const float* HWY_RESTRICT vec_aligned,
const size_t size, float* HWY_RESTRICT out) {
const hn::ScalableTag<float> df;
using VF = hn::Vec<decltype(df)>;
HWY_DASSERT(size % (hn::Lanes(df) * 2) == 0);
HWY_DASSERT(hn::IsAligned(df, vec_aligned));
VF vec0, vec1;
for (size_t i = 0; i < size; i += hn::Lanes(df) * 2) {
hn::LoadInterleaved2(df, vec_aligned + i, vec0, vec1);
hn::Store(vec0, df, out + i);
hn::Store(vec1, df, out + i + hn::Lanes(df));
}
}
// Simple version without tiling or threading, but two offsets/outputs and
// always with addition.
template <size_t kOuter, size_t kInner, typename ArrayT, typename VecT,
@ -91,16 +58,15 @@ HWY_INLINE void TwoOfsMatVecAddLoop(const ArrayT& mat, const size_t mat_ofs0,
float* HWY_RESTRICT out0,
float* HWY_RESTRICT out1) {
PROFILER_ZONE("TwoOfsMatVecAddLoop");
constexpr bool kVecEO = false;
const hn::ScalableTag<float> df;
for (size_t idx_row = 0; idx_row < kOuter; ++idx_row) {
const size_t row_ofs0 = mat_ofs0 + (idx_row)*kInner;
const size_t row_ofs1 = mat_ofs1 + (idx_row)*kInner;
out0[idx_row] = hwy::ConvertScalarTo<float>(add0[idx_row]) +
Dot<kVecEO>(df, mat, row_ofs0, vec_aligned, kInner);
Dot<false>(df, mat, row_ofs0, vec_aligned, kInner);
out1[idx_row] = hwy::ConvertScalarTo<float>(add1[idx_row]) +
Dot<kVecEO>(df, mat, row_ofs1, vec_aligned, kInner);
Dot<false>(df, mat, row_ofs1, vec_aligned, kInner);
}
}
@ -125,7 +91,7 @@ namespace detail {
// For each i = [0, num_rows), compute partial (length `num_cols`) dot product
// of row i with `vec_aligned` and add into `out[i]`. The upper-left
// coordinate of the tile is r0, c0.
template <bool kVecEO, class DF, typename ArrayT, typename VecT>
template <class DF, typename ArrayT, typename VecT>
HWY_INLINE void AccumulatePartialDotProducts(
DF df, const ArrayT& mat, size_t mat_ofs, size_t mat_stride, size_t r0,
size_t c0, size_t num_rows, size_t num_cols,
@ -133,15 +99,14 @@ HWY_INLINE void AccumulatePartialDotProducts(
for (size_t idx_row = 0; idx_row < num_rows; ++idx_row) {
const size_t row_ofs = mat_ofs + (r0 + idx_row) * mat_stride;
out[idx_row] +=
Dot<kVecEO>(df, mat, row_ofs + c0, vec_aligned + c0, num_cols);
Dot<false>(df, mat, row_ofs + c0, vec_aligned + c0, num_cols);
}
}
// Same as AccumulatePartialDotProducts, but sets out[i] to the first partial
// dot product + init (if kInit), which avoids having to zero-initialize and
// accumulate.
template <bool kVecEO, bool kInit, class DF, typename ArrayT, typename VecT,
typename InitT>
template <bool kInit, class DF, typename ArrayT, typename VecT, typename InitT>
HWY_INLINE void SetFirstPartialDotProducts(DF df, const ArrayT& mat,
size_t mat_ofs, size_t mat_stride,
size_t r0, size_t c0,
@ -154,10 +119,10 @@ HWY_INLINE void SetFirstPartialDotProducts(DF df, const ArrayT& mat,
if constexpr (kInit) {
out[idx_row] =
hwy::ConvertScalarTo<float>(init[idx_row + r0]) +
Dot<kVecEO>(df, mat, row_ofs + c0, vec_aligned + c0, num_cols);
Dot<false>(df, mat, row_ofs + c0, vec_aligned + c0, num_cols);
} else {
out[idx_row] =
Dot<kVecEO>(df, mat, row_ofs + c0, vec_aligned + c0, num_cols);
Dot<false>(df, mat, row_ofs + c0, vec_aligned + c0, num_cols);
}
}
}
@ -166,8 +131,7 @@ HWY_INLINE void SetFirstPartialDotProducts(DF df, const ArrayT& mat,
// horizontal strip of the entire matrix); the result is the full dot product
// for rows r in [r0, r0 + num_rows) + optionally the add vector, which we
// store into in out[r - r0].
template <bool kVecEO, bool kAdd, class DF, typename ArrayT, typename VecT,
typename AddT>
template <bool kAdd, class DF, typename ArrayT, typename VecT, typename AddT>
HWY_INLINE void FullDotProductsForStrip(DF df, const ArrayT& mat,
size_t mat_ofs, size_t mat_stride,
size_t r0, size_t num_rows,
@ -176,56 +140,25 @@ HWY_INLINE void FullDotProductsForStrip(DF df, const ArrayT& mat,
float* HWY_RESTRICT out) {
// Tall and skinny: set `out` to the single dot product.
if (mat_stride < MaxCols()) {
SetFirstPartialDotProducts<kVecEO, kAdd>(df, mat, mat_ofs, mat_stride, r0,
0, num_rows, mat_stride,
vec_aligned, add, out);
SetFirstPartialDotProducts<kAdd>(df, mat, mat_ofs, mat_stride, r0, 0,
num_rows, mat_stride, vec_aligned, add,
out);
return;
}
// We have at least MaxCols, so start by setting `out` to that:
SetFirstPartialDotProducts<kVecEO, kAdd>(df, mat, mat_ofs, mat_stride, r0, 0,
num_rows, MaxCols(), vec_aligned,
add, out);
SetFirstPartialDotProducts<kAdd>(df, mat, mat_ofs, mat_stride, r0, 0,
num_rows, MaxCols(), vec_aligned, add, out);
// For further multiples of MaxCols, accumulate. Remainders handled below.
size_t c0 = MaxCols();
for (; c0 <= mat_stride - MaxCols(); c0 += MaxCols()) {
AccumulatePartialDotProducts<kVecEO>(df, mat, mat_ofs, mat_stride, r0, c0,
num_rows, MaxCols(), vec_aligned, out);
AccumulatePartialDotProducts(df, mat, mat_ofs, mat_stride, r0, c0, num_rows,
MaxCols(), vec_aligned, out);
}
if (c0 < mat_stride) { // Final cols
AccumulatePartialDotProducts<kVecEO>(df, mat, mat_ofs, mat_stride, r0, c0,
num_rows, mat_stride - c0, vec_aligned,
out);
}
}
template <bool kVecIsEvenOdd, bool kAdd, size_t kOuter, size_t kInner,
typename ArrayT, typename VecT, typename AddT>
HWY_INLINE void MatVecAddInner(const ArrayT& mat, const size_t mat_ofs,
const VecT* HWY_RESTRICT const vec_aligned,
const AddT* HWY_RESTRICT const add,
float* HWY_RESTRICT out, hwy::ThreadPool& pool) {
const hn::ScalableTag<float> df;
constexpr size_t kRowsPerStrip = RowsPerStrip<kOuter>();
constexpr size_t kNumStrips = kOuter / kRowsPerStrip;
// For each entire strip.
pool.Run(0, kNumStrips, [&](const uint64_t strip, size_t thread) HWY_ATTR {
PROFILER_ZONE("MatVec.lambda");
const size_t r0 = strip * kRowsPerStrip;
detail::FullDotProductsForStrip<kVecIsEvenOdd, kAdd>(
df, mat, mat_ofs, kInner, r0, kRowsPerStrip, vec_aligned, add,
out + r0);
});
// Remaining rows
const size_t r0 = kNumStrips * kRowsPerStrip;
if (r0 < kOuter) {
PROFILER_ZONE("MatVec remainder");
const size_t num_rows = kOuter - r0;
detail::FullDotProductsForStrip<kVecIsEvenOdd, kAdd>(
df, mat, mat_ofs, kInner, r0, num_rows, vec_aligned, add, out + r0);
AccumulatePartialDotProducts(df, mat, mat_ofs, mat_stride, r0, c0, num_rows,
mat_stride - c0, vec_aligned, out);
}
}
@ -238,28 +171,30 @@ template <bool kAdd, size_t kOuter, size_t kInner, typename ArrayT,
HWY_INLINE void MatVecT(const ArrayT& mat, const size_t mat_ofs,
const VecT* HWY_RESTRICT const vec_aligned,
const AddT* HWY_RESTRICT const add,
float* HWY_RESTRICT even_odd, float* HWY_RESTRICT out,
hwy::ThreadPool& pool) {
float* HWY_RESTRICT out, hwy::ThreadPool& pool) {
PROFILER_ZONE("MatVecAdd");
#if !defined(HWY_NATIVE_DOT_BF16) || !HWY_NATIVE_DOT_BF16
using MatT = typename ArrayT::value_type;
// Sfp -> float does not benefit enough to recoup the cost of ToEvenOddF32.
if constexpr (CompressTraits<MatT>::kSupportsEvenOdd &&
hwy::IsSameEither<VecT, float, hwy::bfloat16_t>() &&
!(hwy::IsSame<MatT, SfpStream>() &&
hwy::IsSame<VecT, float>())) {
ToEvenOddF32(vec_aligned, kInner, even_odd);
detail::MatVecAddInner</*kVecIsEvenOdd=*/true, kAdd, kOuter, kInner>(
mat, mat_ofs, even_odd, add, out, pool);
return;
}
#else
(void)even_odd;
#endif
const hn::ScalableTag<float> df;
constexpr size_t kRowsPerStrip = RowsPerStrip<kOuter>();
constexpr size_t kNumStrips = kOuter / kRowsPerStrip;
detail::MatVecAddInner</*kVecIsEvenOdd=*/false, kAdd, kOuter, kInner>(
mat, mat_ofs, vec_aligned, add, out, pool);
// For each entire strip.
pool.Run(0, kNumStrips, [&](const uint64_t strip, size_t thread) HWY_ATTR {
PROFILER_ZONE("MatVec.lambda");
const size_t r0 = strip * kRowsPerStrip;
detail::FullDotProductsForStrip<kAdd>(df, mat, mat_ofs, kInner, r0,
kRowsPerStrip, vec_aligned, add,
out + r0);
});
// Remaining rows
const size_t r0 = kNumStrips * kRowsPerStrip;
if (r0 < kOuter) {
PROFILER_ZONE("MatVec remainder");
const size_t num_rows = kOuter - r0;
detail::FullDotProductsForStrip<kAdd>(df, mat, mat_ofs, kInner, r0,
num_rows, vec_aligned, add, out + r0);
}
}
// With addition
@ -268,21 +203,19 @@ template <size_t kOuter, size_t kInner, typename ArrayT, typename VecT,
HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs,
const VecT* HWY_RESTRICT const vec_aligned,
const AddT* HWY_RESTRICT const add,
float* HWY_RESTRICT even_odd, float* HWY_RESTRICT out,
hwy::ThreadPool& pool) {
float* HWY_RESTRICT out, hwy::ThreadPool& pool) {
return MatVecT</*kAdd=*/true, kOuter, kInner>(mat, mat_ofs, vec_aligned, add,
even_odd, out, pool);
out, pool);
}
// Without addition
template <size_t kOuter, size_t kInner, typename ArrayT, typename VecT>
HWY_INLINE void MatVec(const ArrayT& mat, const size_t mat_ofs,
const VecT* HWY_RESTRICT const vec_aligned,
float* HWY_RESTRICT even_odd, float* HWY_RESTRICT out,
hwy::ThreadPool& pool) {
float* HWY_RESTRICT out, hwy::ThreadPool& pool) {
MatVecT</*kAdd=*/false, kOuter, kInner>(mat, mat_ofs, vec_aligned,
/*add=*/static_cast<VecT*>(nullptr),
even_odd, out, pool);
out, pool);
}
// Two matrices, same vector
@ -300,17 +233,16 @@ HWY_NOINLINE void TwoMatVecT(const ArrayT& mat0, const ArrayT& mat1,
const hn::ScalableTag<float> df;
constexpr size_t kRowsPerStrip = RowsPerStrip<kOuter>();
constexpr size_t kNumStrips = kOuter / kRowsPerStrip;
constexpr bool kVecIsEvenOdd = false;
// For each entire strip.
pool.Run(0, kNumStrips, [&](const uint64_t strip, size_t thread) HWY_ATTR {
PROFILER_ZONE("TwoMatVec.lambda");
const size_t r0 = strip * kRowsPerStrip;
detail::FullDotProductsForStrip<kVecIsEvenOdd, kAdd>(
df, mat0, mat_ofs, kInner, r0, kRowsPerStrip, vec_aligned, add0,
detail::FullDotProductsForStrip<kAdd>(df, mat0, mat_ofs, kInner, r0,
kRowsPerStrip, vec_aligned, add0,
out0 + r0);
detail::FullDotProductsForStrip<kVecIsEvenOdd, kAdd>(
df, mat1, mat_ofs, kInner, r0, kRowsPerStrip, vec_aligned, add1,
detail::FullDotProductsForStrip<kAdd>(df, mat1, mat_ofs, kInner, r0,
kRowsPerStrip, vec_aligned, add1,
out1 + r0);
});
@ -319,9 +251,9 @@ HWY_NOINLINE void TwoMatVecT(const ArrayT& mat0, const ArrayT& mat1,
if (r0 < kOuter) {
PROFILER_ZONE("TwoMatVec remainder");
const size_t num_rows = kOuter - r0;
detail::FullDotProductsForStrip<kVecIsEvenOdd, kAdd>(
detail::FullDotProductsForStrip<kAdd>(
df, mat0, mat_ofs, kInner, r0, num_rows, vec_aligned, add0, out0 + r0);
detail::FullDotProductsForStrip<kVecIsEvenOdd, kAdd>(
detail::FullDotProductsForStrip<kAdd>(
df, mat1, mat_ofs, kInner, r0, num_rows, vec_aligned, add1, out1 + r0);
}
}