mirror of https://github.com/google/gemma.cpp.git

commit c29e9752c7 (parent 07c34cb18a)

Refactor/cleanup, remove even_odd

* New compression/shared.h, remove sfp.h
* Remove unused DistortionStats b_l1_
* Move exact arithmetic functions into fp_arith
* Remove even_odd optimization for MatVec (mostly unused)
* Use the BF16 typedef more widely
* Add kMaxSFP constant

PiperOrigin-RevId: 670996386
 BUILD.bazel | 10

--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -48,6 +48,15 @@ cc_library(
     ],
 )
 
+# Avoids circular dependency: fp_arith-inl -> compress-inl -> ops-inl
+cc_library(
+    name = "fp_arith",
+    textual_hdrs = ["ops/fp_arith-inl.h"],
+    deps = [
+        "@hwy//:hwy",
+    ],
+)
+
 cc_library(
     name = "ops",
     hdrs = [
@@ -61,6 +70,7 @@ cc_library(
     ],
     deps = [
         ":allocator",
+        ":fp_arith",
         ":threading",
         "//compression:compress",
         "//compression:sfp",
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -46,8 +46,8 @@ set(SOURCES
   compression/io.h
   compression/nuq.h
   compression/nuq-inl.h
-  compression/sfp.h
   compression/sfp-inl.h
+  compression/shared.h
   compression/weights_raw.h
   backprop/activations.h
   backprop/backward.cc
@@ -155,6 +155,10 @@ set(GEMMA_TEST_FILES
   backprop/backward_test.cc
   backprop/backward_scalar_test.cc
   backprop/optimize_test.cc
+  compression/compress_test.cc
+  compression/distortion_test.cc
+  compression/sfp_test.cc
+  compression/nuq_test.cc
   ops/dot_test.cc
   ops/ops_test.cc
   ops/matmul_test.cc
@@ -28,6 +28,7 @@
 #include "backprop/activations.h"
 #include "backprop/prompt.h"
 #include "gemma/common.h"
+#include "util/allocator.h"
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
@@ -24,9 +24,9 @@
 #include <vector>
 
 #include "backprop/activations.h"
-#include "gemma/activations.h"
 #include "gemma/common.h"
 #include "gemma/configs.h"
+#include "util/allocator.h"
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
--- a/backprop/forward-inl.h
+++ b/backprop/forward-inl.h
@@ -40,9 +40,10 @@
 #define THIRD_PARTY_GEMMA_CPP_FORWARD_TOGGLE
 #endif
 
+#include "hwy/highway.h"
+// After highway.h
 #include "ops/matvec-inl.h"
 #include "ops/ops-inl.h"
-#include "hwy/highway.h"
 
 HWY_BEFORE_NAMESPACE();
 namespace gcpp {
@@ -110,7 +111,7 @@ void ApplyForwardLayer(const LayerT<TConfig>& weights,
   for (size_t pos = 0; pos < num_tokens; ++pos) {
     MatVec<(kHeads + 2) * kQKVDim, kModelDim>(
         weights.qkv_einsum_w, 0,
-        activations.pre_att_rms_out.data() + pos * kModelDim, nullptr,
+        activations.pre_att_rms_out.data() + pos * kModelDim,
         activations.qkv.data() + pos * (kHeads + 2) * kQKVDim, pool);
   }
   const size_t num_tasks = kHeads * num_tokens;
@@ -174,7 +175,7 @@ void ApplyForwardLayer(const LayerT<TConfig>& weights,
      MatVec<kModelDim, kQKVDim>(
          weights.attn_vec_einsum_w, head * kModelDim * kQKVDim,
          activations.att_out.data() + pos * kHeads * kQKVDim + head * kQKVDim,
-         nullptr, activations.att_post1.data() + pos * kModelDim, pool);
+         activations.att_post1.data() + pos * kModelDim, pool);
      AddFrom(activations.att_post1.data() + pos * kModelDim,
              activations.attention_out.data() + pos * kModelDim, kModelDim);
    }
@@ -192,7 +193,7 @@ void ApplyForwardLayer(const LayerT<TConfig>& weights,
   for (size_t pos = 0; pos < num_tokens; ++pos) {
     MatVec<kFFHiddenDim * 2, kModelDim>(
         weights.gating_einsum_w, 0,
-        activations.bf_pre_ffw_rms_out.data() + pos * kModelDim, nullptr,
+        activations.bf_pre_ffw_rms_out.data() + pos * kModelDim,
         activations.ffw_hidden.data() + pos * kFFHiddenDim * 2, pool);
   }
   for (size_t pos = 0; pos < num_tokens; ++pos) {
@@ -215,7 +216,7 @@ void ApplyForwardLayer(const LayerT<TConfig>& weights,
     MatVec<kModelDim, kFFHiddenDim>(
         weights.linear_w, 0,
         activations.ffw_hidden_gated.data() + pos * kFFHiddenDim,
-        nullptr, output + pos * kModelDim, pool);
+        output + pos * kModelDim, pool);
   }
   for (size_t pos = 0; pos < num_tokens; ++pos) {
     AddFrom(activations.attention_out.data() + pos * kModelDim,
@@ -265,7 +266,7 @@ float CrossEntropyLossForwardPass(const std::vector<int>& prompt,
   for (size_t pos = 0; pos < num_tokens; ++pos) {
     MatVec<kVocabSize, kModelDim>(
         weights.embedder_input_embedding, 0,
-        forward.final_norm_output.data() + pos * kModelDim, nullptr,
+        forward.final_norm_output.data() + pos * kModelDim,
         forward.logits.data() + pos * kVocabSize, pool);
   }
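For context, a minimal scalar model (not part of the commit) of what these MatVec calls compute once the even_odd scratch argument is gone; SimpleMatVec and its row-major layout assumption are illustrative only, since the real implementation in ops/matvec-inl.h is tiled, threaded SIMD:

#include <cstddef>

// out[r] = dot(row r of the kOuter x kInner matrix starting at mat_ofs, vec).
// This sketch only pins down the semantics of the simplified argument list.
template <size_t kOuter, size_t kInner>
void SimpleMatVec(const float* mat, size_t mat_ofs, const float* vec,
                  float* out) {
  for (size_t r = 0; r < kOuter; ++r) {
    float dot = 0.0f;
    for (size_t c = 0; c < kInner; ++c) {
      dot += mat[mat_ofs + r * kInner + c] * vec[c];
    }
    out[r] = dot;
  }
}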
--- a/compression/BUILD.bazel
+++ b/compression/BUILD.bazel
@@ -64,6 +64,7 @@ cc_test(
     srcs = ["distortion_test.cc"],
     deps = [
         ":distortion",
+        ":shared",
         "@googletest//:gtest_main",
         "//:test_util",
         "@hwy//:hwy",
@@ -72,11 +73,19 @@ cc_test(
     ],
 )
 
+cc_library(
+    name = "shared",
+    hdrs = ["shared.h"],
+    deps = [
+        "@hwy//:hwy",
+    ],
+)
+
 cc_library(
     name = "sfp",
-    hdrs = ["sfp.h"],
     textual_hdrs = ["sfp-inl.h"],
     deps = [
+        ":shared",
         "@hwy//:hwy",
     ],
 )
@@ -93,6 +102,7 @@ cc_test(
     deps = [
         ":distortion",
         ":sfp",
+        ":shared",
         "@googletest//:gtest_main",
         "//:ops",
         "//:test_util",
@@ -108,6 +118,7 @@ cc_library(
     textual_hdrs = ["nuq-inl.h"],
     deps = [
         ":sfp",
+        ":shared",
         "@hwy//:hwy",
         "@hwy//hwy/contrib/sort:vqsort",
     ],
@@ -127,6 +138,7 @@ cc_test(
         ":distortion",
         ":nuq",
         ":sfp",
+        ":shared",
         "@googletest//:gtest_main",
         "//:test_util",
         "@hwy//:hwy",
@@ -137,11 +149,7 @@ cc_test(
 
 cc_library(
     name = "compress",
-    hdrs = [
-        "compress.h",
-        "nuq.h",
-        "sfp.h",
-    ],
+    hdrs = ["compress.h"],
     textual_hdrs = [
         "compress-inl.h",
     ],
@@ -151,12 +159,35 @@ cc_library(
         ":io",
         ":nuq",
         ":sfp",
+        ":shared",
+        "//:fp_arith",
         "@hwy//:hwy",
         "@hwy//:stats",
         "@hwy//:thread_pool",
     ],
 )
 
+cc_test(
+    name = "compress_test",
+    size = "small",
+    srcs = ["compress_test.cc"],
+    features = ["fully_static_link"],
+    linkstatic = True,
+    local_defines = ["HWY_IS_TEST"],
+    # for test_suite.
+    tags = ["hwy_ops_test"],
+    deps = [
+        ":compress",
+        ":distortion",
+        "@googletest//:gtest_main",
+        "//:test_util",
+        "@hwy//:hwy",
+        "@hwy//:hwy_test_util",
+        "@hwy//:nanobenchmark",
+        "@hwy//:thread_pool",
+    ],
+)
+
 # For internal experimentation
 cc_library(
     name = "analyze",
@@ -190,6 +221,7 @@ cc_binary(
     srcs = ["compress_weights.cc"],
     deps = [
         ":compress",
+        ":shared",
         ":weights_raw",
         # Placeholder for internal dep, do not remove.,
         "//:args",
--- a/compression/compress.h
+++ b/compression/compress.h
@@ -30,10 +30,10 @@
 #include "compression/blob_store.h"
 #include "compression/io.h"
 #include "compression/nuq.h"
-#include "compression/sfp.h"
+#include "compression/shared.h"
 // IWYU pragma: end_exports
 #include "compression/distortion.h"
-#include "hwy/base.h"  // hwy::bfloat16_t
+#include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 #if COMPRESS_STATS
 #include "hwy/stats.h"
@@ -41,29 +41,20 @@
 
 namespace gcpp {
 
-using BF16 = hwy::bfloat16_t;
-
 static inline const char* TypeName(float) { return "f32"; }
 static inline const char* TypeName(BF16) { return "b16"; }
+static inline const char* TypeName(SfpStream) { return "SFP"; }
+static inline const char* TypeName(NuqStream) { return "NUQ"; }
 
-namespace detail {
-// How many MatT are required to store `capacity` weights. For all but
-// NuqStream, this is the same as `capacity`. For use by CompressedArray.
+// Returns the number of `MatT` elements required to store `capacity` values,
+// which must not be zero.
 template <typename MatT>
-constexpr size_t CompressedArrayLen(size_t capacity) {
-  return capacity;
-}
-template <>
-constexpr size_t CompressedArrayLen<NuqStream>(size_t capacity) {
-  return NuqStream::PackedEnd(capacity);
-}
-}  // namespace detail
-
-// Returns the number of bytes required to store a compressed array with the
-// given type and capacity.
-template <typename MatT>
-constexpr size_t CompressedArraySize(size_t capacity) {
-  return detail::CompressedArrayLen<MatT>(capacity) * sizeof(MatT);
-}
+constexpr size_t CompressedArrayElements(size_t capacity) {
+  if constexpr (hwy::IsSame<hwy::RemoveCvRef<MatT>, NuqStream>()) {
+    return NuqStream::PackedEnd(capacity);
+  } else {
+    return capacity;
+  }
+}
 
 // Compressed representation of floating-point elements. The array length may
@@ -71,10 +62,6 @@ constexpr size_t CompressedArraySize(size_t capacity) {
 // implemented in SIMD code and are thus non-member functions.
 template <typename MatT, size_t kCapacity>
 class CompressedArray {
-  static constexpr size_t NumCompressed() {
-    return detail::CompressedArrayLen<MatT>(kCapacity);
-  }
-
  public:
   using value_type = MatT;
@@ -100,11 +87,11 @@ class CompressedArray {
   constexpr size_t size() const { return kCapacity; }
 
   constexpr size_t CompressedSize() const {
-    return NumCompressed() * sizeof(MatT);
+    return data_.size() * sizeof(MatT);
   }
 
  private:
-  std::array<MatT, NumCompressed()> data_;
+  std::array<MatT, CompressedArrayElements<MatT>(kCapacity)> data_;
   // Blobs are at least kBlobAlign bytes anyway.
   float scale_[kBlobAlign / sizeof(float)];
 };
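A standalone sketch (not from the commit) of the dispatch that replaces the old CompressedArrayLen specialization; std::is_same stands in for hwy::IsSame, and PackedEnd here is a placeholder for the real formula shown in the nuq.h hunk further below:

#include <cstddef>
#include <cstdio>
#include <type_traits>

struct NuqStream { unsigned char byte; };

// Placeholder for NuqStream::PackedEnd (cluster table plus two 4-bit
// indices per byte); see the nuq.h hunk for the actual computation.
constexpr size_t PackedEnd(size_t capacity) { return capacity / 2 + 64; }

template <typename MatT>
constexpr size_t CompressedArrayElements(size_t capacity) {
  if constexpr (std::is_same_v<std::remove_cv_t<MatT>, NuqStream>) {
    return PackedEnd(capacity);  // packed size differs from value count
  } else {
    return capacity;  // float, BF16, SfpStream: one element per value
  }
}

int main() {
  printf("float: %zu\n", CompressedArrayElements<float>(4096));      // 4096
  printf("NUQ:   %zu\n", CompressedArrayElements<NuqStream>(4096));  // 2112
}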
@@ -0,0 +1,14 @@
+// Copyright 2023 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
--- a/compression/compress_weights.cc
+++ b/compression/compress_weights.cc
@@ -21,9 +21,9 @@
 #define HWY_TARGET_INCLUDE \
   "compression/compress_weights.cc"  // NOLINT
 #include "hwy/foreach_target.h"  // IWYU pragma: keep
-// Must come after foreach_target.h to avoid redefinition errors.
-#include "compression/compress-inl.h"
 #include "hwy/highway.h"
+// After highway.h
+#include "compression/compress-inl.h"
 
 #ifndef GEMMA_COMPRESS_WEIGHTS_ONCE
 #define GEMMA_COMPRESS_WEIGHTS_ONCE
@@ -38,9 +38,11 @@
 #include <thread>  // NOLINT
 
 #include "compression/io.h"  // Path
+#include "compression/shared.h"
 #include "compression/weights_raw.h"
 #include "gemma/common.h"  // Model
 #include "gemma/weights.h"
+#include "util/allocator.h"
 #include "util/args.h"
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
@@ -57,11 +59,10 @@ float ScaleWeights(float* data, size_t len) {
   for (size_t i = 0; i < len; ++i) {
     maxabs = std::max(maxabs, std::abs(data[i]));
   }
-  const float kMaxRange = 1.875f;
-  if (maxabs <= kMaxRange) {
+  if (maxabs <= kMaxSFP) {
     return 1.0f;
   }
-  const float scale = maxabs / kMaxRange;
+  const float scale = maxabs / kMaxSFP;
   const float inv_scale = 1.0f / scale;
   for (size_t i = 0; i < len; ++i) {
     data[i] *= inv_scale;
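The same logic as a self-contained sketch, assuming kMaxSFP as defined in compression/shared.h and assuming the function returns the scale factor for the caller to store (the diff cuts off before the final return):

#include <algorithm>
#include <cmath>
#include <cstddef>

constexpr float kMaxSFP = 1.875f;  // from compression/shared.h

// SFP encodes magnitudes up to kMaxSFP, so larger tensors are divided by a
// per-tensor scale that is stored separately and re-applied after decoding.
float ScaleWeights(float* data, size_t len) {
  float maxabs = 0.0f;
  for (size_t i = 0; i < len; ++i) {
    maxabs = std::max(maxabs, std::abs(data[i]));
  }
  if (maxabs <= kMaxSFP) return 1.0f;  // already representable
  const float scale = maxabs / kMaxSFP;
  const float inv_scale = 1.0f / scale;
  for (size_t i = 0; i < len; ++i) data[i] *= inv_scale;
  return scale;  // assumption: callers persist this alongside the tensor
}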
--- a/compression/distortion.h
+++ b/compression/distortion.h
@@ -30,9 +30,12 @@
 namespace gcpp {
 
 // Returns `sum` and `err` such that `sum + err` is exactly equal to `a + b`,
-// despite floating-point rounding. `sum` is already the best estimate, so do
-// not actually add `err` to it. Knuth98/Moller65. Unlike Fast2Sum [Dekker71],
-// this does not require any relative ordering of the exponents of a and b.
+// despite floating-point rounding. `sum` is already the best estimate for the
+// addition, so do not directly add `err` to it.
+//
+// Knuth98/Moller65. Unlike FastTwoSum, this does not require any relative
+// ordering of the exponents of a and b. 6 ops.
+// TODO: move to and use in Highway stats.h?
 template <typename T, HWY_IF_FLOAT3264(T)>
 static inline T TwoSum(T a, T b, T& err) {
   const T sum = a + b;
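A scalar demonstration (not part of the commit) of the TwoSum property described above, using the same Knuth/Moller formulation as the vector TwoSums added in ops/fp_arith-inl.h:

#include <cstdio>

// sum + err is exactly a + b, even when the addition itself rounds.
static float TwoSum(float a, float b, float& err) {
  const float sum = a + b;
  const float a2 = sum - b;
  const float b2 = sum - a2;
  err = (a - a2) + (b - b2);
  return sum;
}

int main() {
  float err;
  const float sum = TwoSum(1.0f, 1e-8f, err);
  // 1e-8f is entirely lost in `sum` (which rounds to 1.0f) but is
  // recovered exactly in `err`.
  printf("sum=%.9g err=%.9g\n", sum, err);
}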
@@ -88,7 +91,6 @@ class DistortionStats {
     const float l1f = hwy::ScalarAbs(original - distorted);
     const double l1 = static_cast<double>(l1f);
     s_l1_.Notify(l1f);
-    b_l1_.Notify(HWY_MIN(99, static_cast<int>(l1f * 1E4)));
     if (l1f != 0.0f) {
       l1_.push_back(l1f);
     }
@@ -102,7 +104,7 @@ class DistortionStats {
     // as much as an actual sign flip, so do not count them.
     n_sign_flip_ +=
         ((original < 0.0f) != (distorted < 0.0f)) && !rounded_to_zero;
-    n_exact_ += (l1f == 0.0f);
+    n_exact_ += (original == distorted);
     n_rounded_to_zero += rounded_to_zero;
   }
@@ -122,7 +124,6 @@ class DistortionStats {
   void Assimilate(const DistortionStats& other) {
     s_original_.Assimilate(other.s_original_);
     s_l1_.Assimilate(other.s_l1_);
-    b_l1_.Assimilate(other.b_l1_);
     sum_l1_.Assimilate(other.sum_l1_);
     sum_l1_rounded_.Assimilate(other.sum_l1_rounded_);
     l1_.insert(l1_.end(), other.l1_.begin(), other.l1_.end());
@@ -204,7 +205,6 @@ class DistortionStats {
  private:
   hwy::Stats s_original_;
   hwy::Stats s_l1_;
-  hwy::Bins<100> b_l1_;
   CascadedSummation<double> sum_l1_;  // all
   CascadedSummation<double> sum_l1_rounded_;  // only if rounded_to_zero
   std::vector<float> l1_;
--- a/compression/distortion_test.cc
+++ b/compression/distortion_test.cc
@@ -17,6 +17,7 @@
 
 #include <stdio.h>
 
+#include "compression/shared.h"
 #include "util/test_util.h"
 #include "hwy/nanobenchmark.h"
 #include "hwy/tests/hwy_gtest.h"
@@ -74,13 +75,13 @@ TEST(DistortionTest, TestDilution) {
   HWY_ASSERT(IsNear(0.001, stats.WeightedAverageL1()));
 
   // Now add a large difference:
-  stats.Notify(1.875f - 0.0625f, 1.875f);  // max magnitude, 3-bit mantissa
+  stats.Notify(kMaxSFP - 0.0625f, kMaxSFP);  // max magnitude, 3-bit mantissa
   // .. WeightedAverageL1 is closer to it.
   HWY_ASSERT(IsInside(0.020, 0.025, stats.WeightedAverageL1()));
 
   // Add a small and large difference:
   stats.Notify((1.75f - 0.125f) / 1024, 1.75f / 1024);  // small, 2-bit mantissa
-  stats.Notify(-1.875f + 0.0625f, -1.875f);  // larger negative
+  stats.Notify(-kMaxSFP + 0.0625f, -kMaxSFP);  // larger negative
   // .. SNR is still barely affected.
   HWY_ASSERT(IsInside(890.0, 900.0, stats.GeomeanValueDivL1()));
   // .. WeightedAverageL1 is higher after another large error.
--- a/compression/nuq-inl.h
+++ b/compression/nuq-inl.h
@@ -21,7 +21,7 @@
 #include <stdint.h>
 
 #include "compression/nuq.h"
-#include "compression/sfp.h"
+#include "compression/shared.h"
 #include "hwy/base.h"
 
 #endif  // THIRD_PARTY_GEMMA_CPP_COMPRESSION_NUQ_INL_H_
--- a/compression/nuq.h
+++ b/compression/nuq.h
@@ -40,9 +40,8 @@ static constexpr size_t kGroupSize = 256;
 // Points to the *start* of a NUQ stream. Aligning the allocation (see
 // aligned_allocator.h) may be speed up decoding but is not required.
 //
-// See go/streaming-weight-decode for background and design. Layout: first one
-// table of kClusters entries per group, in ascending order of group index,
-// then two packed indices per byte.
+// Layout: first one table of kClusters entries per group, in ascending order
+// of group index, then two packed indices per byte.
 //
 // Indices are stored in-order to enable vector-length agnostic decode, because
 // streams may be persisted to disk and used by other CPUs.
@@ -54,26 +53,23 @@ static constexpr size_t kGroupSize = 256;
 #pragma pack(push, 1)
 struct NuqStream {
   // Returns offset of packed indices from the start of the stream. This matches
-  // the (padded) total table size because table entries are bytes. `capacity`
-  // is already a multiple of `kGroupSize`.
+  // the (padded) total table size because table entries are bytes.
   static constexpr size_t PackedStart(size_t capacity) {
     // Round up to avoid cache-line splits when loading indices. No effect on
     // size as long as capacity / kGroupSize is a multiple of 4.
-    return hwy::RoundUpTo((capacity / kGroupSize) * kClusters, 64);
+    return hwy::RoundUpTo(hwy::DivCeil(capacity, kGroupSize) * kClusters, 64);
   }
 
   // Returns number of NuqStream to allocate for the stream, which matches its
   // size in bytes. `capacity` is already a multiple of `kGroupSize`.
   static constexpr size_t PackedEnd(size_t capacity) {
-    return PackedStart(capacity) + capacity / 2;  // two 4-bit indices per byte.
+    return PackedStart(capacity) + hwy::DivCeil(capacity, 2);  // 2x 4-bit/byte
   }
 
   uint8_t byte;
 };
 #pragma pack(pop)
 
-static inline const char* TypeName(NuqStream) { return "NUQ"; }
-
 // Storage for dynamic programming. There are two matrices; we use separate
 // allocations to avoid type punning.
 template <class T>
@@ -101,7 +97,7 @@ struct ClusterBuf {
     num = new_num;
     const size_t num_groups = hwy::DivCeil(num, kGroupSize);
     centers = hwy::AllocateAligned<float>(num_groups * kClusters);
-    idx = hwy::AllocateAligned<uint16_t>(num);
+    idx = hwy::AllocateAligned<uint16_t>(hwy::RoundUpTo(num, kGroupSize));
   }
 
   AlignedMatrix<float> d;
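The packing arithmetic, restated as a standalone sketch: scalar stand-ins replace the hwy helpers, and kClusters = 16 is an assumption consistent with the two-4-bit-indices-per-byte layout described above:

#include <cstddef>
#include <cstdio>

constexpr size_t kGroupSize = 256;
constexpr size_t kClusters = 16;  // assumption: 4-bit indices -> 16 clusters

constexpr size_t DivCeil(size_t a, size_t b) { return (a + b - 1) / b; }
constexpr size_t RoundUpTo(size_t x, size_t m) { return DivCeil(x, m) * m; }

// Mirrors NuqStream::PackedStart/PackedEnd: a 64-byte-rounded table of
// kClusters byte entries per group, followed by two 4-bit indices per byte.
constexpr size_t PackedStart(size_t capacity) {
  return RoundUpTo(DivCeil(capacity, kGroupSize) * kClusters, 64);
}
constexpr size_t PackedEnd(size_t capacity) {
  return PackedStart(capacity) + DivCeil(capacity, 2);
}

int main() {
  // 4096 values -> 16 groups: table 16 * 16 = 256 bytes (already a multiple
  // of 64), indices 4096 / 2 = 2048 bytes, so 2304 bytes in total.
  printf("PackedEnd(4096) = %zu\n", PackedEnd(4096));
}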
@@ -46,7 +46,7 @@ class SbsWriterImpl : public WriterInterface {
   SbsWriterImpl() : pool_(0), compressor_(pool_) {}
 
   void Insert(std::string name, absl::Span<const float> weights) override {
-    const size_t out_size = CompressedArraySize<SfpStream>(weights.size());
+    const size_t out_size = CompressedArrayElements<SfpStream>(weights.size());
     sfp_streams_.push_back(std::vector<SfpStream>(out_size));
     compressor_.Insert<SfpStream>(name.data(), weights.data(), weights.size(),
                                   working_set_, out_size,
@@ -54,7 +54,7 @@ class SbsWriterImpl : public WriterInterface {
   }
 
   void InsertNUQ(std::string name, absl::Span<const float> weights) override {
-    const size_t out_size = CompressedArraySize<NuqStream>(weights.size());
+    const size_t out_size = CompressedArrayElements<NuqStream>(weights.size());
     nuq_streams_.push_back(std::vector<NuqStream>(out_size));
     compressor_.Insert<NuqStream>(name.data(), weights.data(), weights.size(),
                                   working_set_, out_size,
@@ -64,7 +64,7 @@ class SbsWriterImpl : public WriterInterface {
   void InsertBfloat16(std::string name,
                       absl::Span<const float> weights) override {
     const size_t out_size =
-        CompressedArraySize<hwy::bfloat16_t>(weights.size());
+        CompressedArrayElements<hwy::bfloat16_t>(weights.size());
     bf16_streams_.push_back(std::vector<hwy::bfloat16_t>(out_size));
     compressor_.Insert<hwy::bfloat16_t>(name.data(), weights.data(),
                                         weights.size(), working_set_, out_size,
--- a/compression/sfp-inl.h
+++ b/compression/sfp-inl.h
@@ -20,7 +20,7 @@
 #include <stddef.h>
 #include <stdint.h>
 
-#include "compression/sfp.h"
+#include "compression/shared.h"
 #include "hwy/base.h"
 
 #endif  // THIRD_PARTY_GEMMA_CPP_COMPRESSION_SFP_INL_H_
--- a/compression/sfp.h
+++ /dev/null
@@ -1,51 +0,0 @@
-// Copyright 2023 Google LLC
-// SPDX-License-Identifier: Apache-2.0
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_SFP_H_
-#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_SFP_H_
-
-// Switching Floating Point: a hybrid 8-bit float representation of bf16/f32
-// inputs that combines the advantages of e4m3 and e5m2 into a single format.
-// It supports seeking at a granularity of 1, decoding to bf16/f32, and a
-// fused decode/dot product with bf16/f32 vectors.
-
-#include <stdint.h>
-
-namespace gcpp {
-
-// Points to the *start* of an SFP stream. Values are stored in-order to enable
-// vector-length agnostic seeking, because streams may be written to disk for
-// loading on other CPUs.
-//
-// Characteristics:
-// - 24-bit dynamic range, with max exponent 2^0.
-// - 3 bit mantissa for values >= 2^-7, otherwise 2.
-//
-// This is faster to decode than a straightforward implementation of eXmY, in
-// part because SFP does not require subnormals. Unlike OCP MX, it also does not
-// require side information (shared exponents).
-//
-// Although the representation could probably be shrunk to 6-7 bits, more
-// savings can be had by non-uniform clustering - see nuq.h.
-#pragma pack(push, 1)
-struct SfpStream {
-  uint8_t byte;
-};
-#pragma pack(pop)
-
-static inline const char* TypeName(SfpStream) { return "SFP"; }
-
-}  // namespace gcpp
-#endif  // THIRD_PARTY_GEMMA_CPP_COMPRESSION_SFP_H_
--- a/compression/sfp_test.cc
+++ b/compression/sfp_test.cc
@@ -18,8 +18,6 @@
 #define HWY_DISABLED_TARGETS HWY_SCALAR
 #endif
 
-#include "compression/sfp.h"
-
 #include <stddef.h>
 #include <stdint.h>
 #include <stdio.h>
@@ -27,20 +25,21 @@
 #include <set>
 
 #include "compression/distortion.h"
+#include "compression/shared.h"
 #include "util/test_util.h"
 #include "hwy/aligned_allocator.h"
 #include "hwy/base.h"
+#include "hwy/tests/hwy_gtest.h"
 #include "hwy/timer.h"
 // clang-format off
 #undef HWY_TARGET_INCLUDE
 #define HWY_TARGET_INCLUDE "compression/sfp_test.cc"  // NOLINT
 // clang-format on
 #include "hwy/foreach_target.h"  // IWYU pragma: keep
-// Any highway.h must come after foreach_target.h
+#include "hwy/highway.h"
+// After highway.h
 #include "compression/sfp-inl.h"
 #include "ops/dot-inl.h"
-#include "hwy/highway.h"
-#include "hwy/tests/hwy_gtest.h"
 #include "hwy/tests/test_util-inl.h"
 
 HWY_BEFORE_NAMESPACE();
--- /dev/null
+++ b/compression/shared.h
@@ -0,0 +1,94 @@
+// Copyright 2023 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Definitions shared between the public compress-inl.h interface and the
+// sfp-inl.h and nuq-inl.h implementation details.
+
+#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_SHARED_H_
+#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_SHARED_H_
+
+#include <stddef.h>
+
+#include "hwy/base.h"  // hwy::bfloat16_t
+
+namespace gcpp {
+
+using BF16 = hwy::bfloat16_t;
+
+// Switching Floating Point: a hybrid 8-bit float representation of bf16/f32
+// inputs that combines the advantages of e4m3 and e5m2 into a single format.
+// It supports seeking at a granularity of 1 and decoding to bf16/f32.
+//
+// Characteristics:
+// - 24-bit dynamic range, with max exponent 2^0.
+// - 3 bit mantissa for values >= 2^-7, otherwise 2.
+//
+// A pointer to this is the *start* of an SFP stream. Values are stored
+// in-order to enable vector-length agnostic seeking, because streams may be
+// written to disk for loading on other CPUs.
+//
+// This is faster to decode than a straightforward implementation of eXmY, in
+// part because SFP does not require subnormals. Unlike OCP MX, it also does not
+// require side information (shared exponents).
+//
+// Although the representation could probably be shrunk to 6-7 bits, more
+// savings can be had by non-uniform clustering - see nuq.h.
+#pragma pack(push, 1)
+struct SfpStream {
+  uint8_t byte;
+};
+#pragma pack(pop)
+
+// Largest possible input magnitude: 1.111 * 2^0. This could be increased by
+// shifting the value range (exponent bias).
+constexpr float kMaxSFP = 1.875f;
+
+// Non-owning view of packed elements. Shortens argument lists.
+//
+// Callers typically also pass an `ofs` starting offset. This is not folded
+// into `ptr` because NUQ consists of two separate streams. To discourage direct
+// use of `ptr` without that offset, we define a separate class instead of
+// reusing `hwy::Span`.
+template <typename Packed>
+struct PackedSpan {
+  void BoundsCheck(size_t packed_ofs, size_t num) const {
+    HWY_DASSERT(packed_ofs + num <= size);
+    (void)size;
+  }
+
+  Packed* HWY_RESTRICT ptr;
+  size_t size;  // for BoundsCheck and nuq-inl.h HWY_ASSERT.
+};
+
+// Avoids spelling out the template parameter in every call.
+template <typename Packed>
+HWY_INLINE PackedSpan<Packed> MakeSpan(Packed* ptr, size_t size) {
+  return {ptr, size};
+}
+
+template <typename Packed>
+HWY_INLINE PackedSpan<const Packed> MakeConstSpan(Packed* ptr, size_t size) {
+  return {ptr, size};
+}
+
+// "Implicit" conversion from a PackedSpan<T> to PackedSpan<const T>, used in
+// `RMSNormInplace` and compression tests.
+template <typename Packed>
+HWY_INLINE PackedSpan<const Packed> MakeConst(PackedSpan<Packed> packed) {
+  return {packed.ptr, packed.size};
+}
+
+}  // namespace gcpp
+#endif  // THIRD_PARTY_GEMMA_CPP_COMPRESSION_SHARED_H_
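A usage sketch (not part of the commit; hwy macros elided, abort() standing in for HWY_DASSERT) showing how the PackedSpan helpers above fit together:

#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <vector>

// A span is just {ptr, size}; BoundsCheck guards any (offset, count)
// access into the packed stream.
template <typename Packed>
struct PackedSpan {
  void BoundsCheck(size_t packed_ofs, size_t num) const {
    if (packed_ofs + num > size) abort();
  }
  Packed* ptr;
  size_t size;
};

template <typename Packed>
PackedSpan<Packed> MakeSpan(Packed* ptr, size_t size) { return {ptr, size}; }

template <typename Packed>
PackedSpan<const Packed> MakeConst(PackedSpan<Packed> p) {
  return {p.ptr, p.size};
}

int main() {
  std::vector<float> weights(1024);
  PackedSpan<float> span = MakeSpan(weights.data(), weights.size());
  span.BoundsCheck(/*packed_ofs=*/512, /*num=*/256);  // OK: 768 <= 1024
  PackedSpan<const float> ro = MakeConst(span);       // drop mutability
  printf("size=%zu\n", ro.size);
}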
--- a/gemma/activations.h
+++ b/gemma/activations.h
@@ -55,11 +55,6 @@ struct Activations {
   // Rope
   RowVectorBatch<float> inv_timescale;
 
-  // For bf16/f32 vectors * bf16 matrix: faster to unpack once beforehand, into
-  // per-thread storage.
-  // TODO: remove once MatVec is no longer used.
-  RowVectorBatch<float> even_odd;
-
   MatMulEnv env;
 
   // Multi-Head Attention?
@@ -123,9 +118,6 @@ struct Activations {
 
     inv_timescale = CreateInvTimescale<TConfig>();
 
-    const size_t num_lp = pools.NumLP();
-    even_odd = RowVectorBatch<float>(1, kModelDim * num_lp);
-
     env = MatMulEnv(pools);
   }
 };
--- a/gemma/gemma-inl.h
+++ b/gemma/gemma-inl.h
@@ -192,8 +192,7 @@ HWY_NOINLINE void GriffinRecurrent(
     float* out_ptr = activations.att_sums.Batch(batch_idx);
     MatVecAdd<kModelDim, kModelDim>(
         layer_weights->griffin.linear_out_w, 0, x,
-        layer_weights->griffin.linear_out_biases.data_scale1(),
-        activations.even_odd.All(), out_ptr, pool);
+        layer_weights->griffin.linear_out_biases.data_scale1(), out_ptr, pool);
   }
 }
@@ -283,8 +282,8 @@ class GemmaAttention {
       float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
       // KV structure is [k, v, k, v, ....] = kKVHeads pairs of (k, v).
       MatVec<kKVHeads * 2 * kQKVDim, kModelDim>(
-          layer_weights_.qkv_einsum_w, kHeads * kQKVDim * kModelDim, x,
-          activations_.even_odd.All(), kv, pool_);
+          layer_weights_.qkv_einsum_w, kHeads * kQKVDim * kModelDim, x, kv,
+          pool_);
     }
   }
 }
--- /dev/null
+++ b/ops/fp_arith-inl.h
@@ -0,0 +1,152 @@
+// Copyright 2024 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <stddef.h>
+
+// Building blocks for floating-point arithmetic.
+
+// Include guard for (potentially) SIMD code.
+#if defined(THIRD_PARTY_GEMMA_CPP_FP_ARITH_TOGGLE) == defined(HWY_TARGET_TOGGLE)
+#ifdef THIRD_PARTY_GEMMA_CPP_FP_ARITH_TOGGLE
+#undef THIRD_PARTY_GEMMA_CPP_FP_ARITH_TOGGLE
+#else
+#define THIRD_PARTY_GEMMA_CPP_FP_ARITH_TOGGLE
+#endif
+
+#include "hwy/highway.h"
+
+HWY_BEFORE_NAMESPACE();
+namespace gcpp {
+namespace HWY_NAMESPACE {
+namespace hn = hwy::HWY_NAMESPACE;
+
+//------------------------------------------------------------------------------
+// Exact multiplication
+
+namespace detail {
+
+// Returns non-overlapping `x` and `y` such that `x + y` = `f` and |x| >= |y|.
+// Notation from Algorithm 3.1 in Handbook of Floating-Point Arithmetic. 4 ops.
+template <class DF, HWY_IF_FLOAT3264_D(DF), class VF = hn::Vec<DF>>
+static HWY_INLINE void VeltkampSplit(DF df, VF a, VF& x, VF& y) {
+  using TF = hn::TFromD<DF>;
+  constexpr int t = hwy::MantissaBits<TF>() + 1;  // = -log2(epsilon)
+  constexpr int s = hwy::DivCeil(t, 2);
+  const VF factor = hn::Set(df, hwy::ConvertScalarTo<TF>((1ULL << s) + 1));
+  const VF c = hn::Mul(factor, a);
+  x = hn::Sub(c, hn::Sub(c, a));
+  y = hn::Sub(a, x);
+}
+
+}  // namespace detail
+
+// Returns `prod` and `err` such that `prod + err` is exactly equal to `a * b`,
+// despite floating-point rounding, assuming that `err` is not subnormal, i.e.,
+// the sum of exponents >= min exponent + mantissa bits. 2..17 ops. Useful for
+// compensated dot products and polynomial evaluation.
+template <class DF, HWY_IF_FLOAT3264_D(DF), class VF = hn::Vec<DF>>
+static HWY_INLINE VF TwoProducts(DF df, VF a, VF b, VF& err) {
+  const VF prod = hn::Mul(a, b);
+  if constexpr (HWY_NATIVE_FMA) {
+    err = hn::MulSub(a, b, prod);
+  } else {
+    // Non-FMA fallback: we assume these calculations do not overflow.
+    VF a1, a2, b1, b2;
+    detail::VeltkampSplit(df, a, a1, a2);
+    detail::VeltkampSplit(df, b, b1, b2);
+    const VF m = hn::Sub(prod, hn::Mul(a1, b1));
+    const VF n = hn::Sub(m, hn::Mul(a2, b1));
+    const VF o = hn::Sub(n, hn::Mul(a1, b2));
+    err = hn::Sub(hn::Mul(a2, b2), o);
+  }
+  return prod;
+}
+
+//------------------------------------------------------------------------------
+// Exact addition
+
+// Returns `sum` and `err` such that `sum + err` is exactly equal to `a + b`,
+// despite floating-point rounding. `sum` is already the best estimate for the
+// addition, so do not directly add `err` to it. `UpdateCascadedSums` instead
+// accumulates multiple `err`, which are then later added to the total `sum`.
+//
+// Knuth98/Moller65. Unlike FastTwoSums, this does not require any relative
+// ordering of the exponents of a and b. 6 ops.
+template <class DF, HWY_IF_FLOAT3264_D(DF), class VF = hn::Vec<DF>>
+static HWY_INLINE VF TwoSums(DF /*df*/, VF a, VF b, VF& err) {
+  const VF sum = hn::Add(a, b);
+  const VF a2 = hn::Sub(sum, b);
+  const VF b2 = hn::Sub(sum, a2);
+  const VF err_a = hn::Sub(a, a2);
+  const VF err_b = hn::Sub(b, b2);
+  err = hn::Add(err_a, err_b);
+  return sum;
+}
+
+// As above, but only exact if the exponent of `a` >= that of `b`, which is the
+// case if |a| >= |b|. Dekker71, also used in Kahan65 compensated sum. 3 ops.
+template <class DF, HWY_IF_FLOAT3264_D(DF), class VF = hn::Vec<DF>>
+static HWY_INLINE VF FastTwoSums(DF /*df*/, VF a, VF b, VF& err) {
+  const VF sum = hn::Add(a, b);
+  const VF b2 = hn::Sub(sum, a);
+  err = hn::Sub(b, b2);
+  return sum;
+}
+
+//------------------------------------------------------------------------------
+// Cascaded summation (twice working precision)
+
+// Accumulates numbers with about twice the precision of T using 7 * n FLOPS.
+// Rump/Ogita/Oishi08, Algorithm 6.11 in Handbook of Floating-Point Arithmetic.
+//
+// Because vectors generally cannot be wrapped in a class, we use functions.
+// `sum` and `sum_err` must be initially zero. Each lane is an independent sum.
+// To reduce them into a single result, use `ReduceCascadedSum`.
+template <class DF, HWY_IF_FLOAT3264_D(DF), class VF = hn::Vec<DF>>
+void UpdateCascadedSums(DF df, VF v, VF& sum, VF& sum_err) {
+  VF err;
+  sum = TwoSums(df, sum, v, err);
+  sum_err += err;
+}
+
+// Combines two cascaded sum vectors, typically from unrolling/parallelization.
+template <class DF, HWY_IF_FLOAT3264_D(DF), class VF = hn::Vec<DF>>
+void AssimilateCascadedSums(DF df, const VF& other_sum, const VF& other_sum_err,
+                            VF& sum, VF& sum_err) {
+  UpdateCascadedSums(df, other_sum, sum, sum_err);
+  sum_err += other_sum_err;
+}
+
+// Reduces cascaded sums, to a single value. Slow, call outside of loops.
+template <class DF, HWY_IF_FLOAT3264_D(DF), class VF = hn::Vec<DF>>
+hn::TFromD<DF> ReduceCascadedSums(DF df, const VF sum, VF sum_err) {
+  const size_t N = hn::Lanes(df);
+  using TF = hn::TFromD<DF>;
+  TF total = TF{0.0};
+  TF total_err = TF{0.0};
+  for (size_t i = 0; i < N; ++i) {
+    TF err;
+    total = TwoSum(total, hn::ExtractLane(sum, i), err);
+    total_err += err;
+  }
+  return total + total_err + hn::ReduceSum(df, sum_err);
+}
+
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+}  // namespace HWY_NAMESPACE
+}  // namespace gcpp
+HWY_AFTER_NAMESPACE();
+
+#endif  // NOLINT
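Scalar demonstrations (not part of the commit) of the error-free transformations above; the vector versions apply the same formulas per lane, and std::fma mirrors the HWY_NATIVE_FMA path of TwoProducts:

#include <cmath>
#include <cstdio>

// TwoProducts via FMA: prod + err is exactly a * b.
static double TwoProducts(double a, double b, double& err) {
  const double prod = a * b;
  err = std::fma(a, b, -prod);
  return prod;
}

// TwoSums (Knuth/Moller), then cascaded summation per Rump/Ogita/Oishi:
// accumulate each addition's rounding error and add it back at the end.
static float TwoSums(float a, float b, float& err) {
  const float sum = a + b;
  const float a2 = sum - b;
  const float b2 = sum - a2;
  err = (a - a2) + (b - b2);
  return sum;
}

int main() {
  double perr;
  const double prod = TwoProducts(1.0 / 3.0, 3.0, perr);
  printf("prod=%.17g err=%.17g\n", prod, perr);  // err recovers the rounding

  float sum = 0.0f, sum_err = 0.0f;
  for (int i = 0; i < 1000000; ++i) {
    float err;
    sum = TwoSums(sum, 0.1f, err);
    sum_err += err;  // as in UpdateCascadedSums
  }
  // `sum` alone drifts like a naive float sum (~100958 instead of 100000);
  // adding the accumulated `sum_err` recovers nearly all of the loss.
  printf("naive=%.1f cascaded=%.1f\n", sum, sum + sum_err);
}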
--- a/ops/matvec_test.cc
+++ b/ops/matvec_test.cc
@@ -115,15 +115,13 @@ void TestMatVecAdd() {
       GenerateMat<float, kOuter, kInner>(0, pool);
   hwy::AlignedFreeUniquePtr<float[]> vec = GenerateVec<kInner>(0);
   hwy::AlignedFreeUniquePtr<float[]> add = GenerateVec<kOuter>(0);
-  hwy::AlignedFreeUniquePtr<float[]> even_odd =
-      hwy::AllocateAligned<float>(kInner * pool.NumWorkers());
   hwy::AlignedFreeUniquePtr<float[]> expected_out =
       SimpleMatVecAdd<kOuter, kInner>(mat, vec, add);
   hwy::AlignedFreeUniquePtr<float[]> actual_out =
       hwy::AllocateAligned<float>(kOuter);
-  HWY_ASSERT(vec && add && even_odd && expected_out && actual_out);
-  MatVecAdd<kOuter, kInner>(mat, 0, vec.get(), add.get(), even_odd.get(),
-                            actual_out.get(), pool);
+  HWY_ASSERT(vec && add && expected_out && actual_out);
+  MatVecAdd<kOuter, kInner>(mat, 0, vec.get(), add.get(), actual_out.get(),
+                            pool);
   AssertClose<kOuter>(actual_out, expected_out);
 }
168
ops/matvec-inl.h
168
ops/matvec-inl.h
|
|
@ -21,8 +21,6 @@
|
||||||
#include <stdint.h>
|
#include <stdint.h>
|
||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
|
|
||||||
#include "compression/compress.h"
|
|
||||||
#include "compression/sfp.h"
|
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
#include "hwy/profiler.h"
|
#include "hwy/profiler.h"
|
||||||
|
|
@ -48,37 +46,6 @@ namespace gcpp {
|
||||||
namespace HWY_NAMESPACE {
|
namespace HWY_NAMESPACE {
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
|
|
||||||
HWY_INLINE void ToEvenOddF32(const hwy::bfloat16_t* HWY_RESTRICT vec_aligned,
|
|
||||||
const size_t size, float* HWY_RESTRICT out) {
|
|
||||||
const hn::ScalableTag<float> df;
|
|
||||||
const hn::Repartition<hwy::bfloat16_t, decltype(df)> dbf16;
|
|
||||||
|
|
||||||
HWY_DASSERT(size % hn::Lanes(dbf16) == 0);
|
|
||||||
HWY_DASSERT(hn::IsAligned(df, vec_aligned));
|
|
||||||
|
|
||||||
for (size_t i = 0; i < size; i += hn::Lanes(dbf16)) {
|
|
||||||
const auto interleaved = hn::LoadU(dbf16, vec_aligned + i);
|
|
||||||
hn::Store(hn::PromoteEvenTo(df, interleaved), df, out + i);
|
|
||||||
hn::Store(hn::PromoteOddTo(df, interleaved), df, out + i + hn::Lanes(df));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
HWY_INLINE void ToEvenOddF32(const float* HWY_RESTRICT vec_aligned,
|
|
||||||
const size_t size, float* HWY_RESTRICT out) {
|
|
||||||
const hn::ScalableTag<float> df;
|
|
||||||
using VF = hn::Vec<decltype(df)>;
|
|
||||||
|
|
||||||
HWY_DASSERT(size % (hn::Lanes(df) * 2) == 0);
|
|
||||||
HWY_DASSERT(hn::IsAligned(df, vec_aligned));
|
|
||||||
|
|
||||||
VF vec0, vec1;
|
|
||||||
for (size_t i = 0; i < size; i += hn::Lanes(df) * 2) {
|
|
||||||
hn::LoadInterleaved2(df, vec_aligned + i, vec0, vec1);
|
|
||||||
hn::Store(vec0, df, out + i);
|
|
||||||
hn::Store(vec1, df, out + i + hn::Lanes(df));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Simple version without tiling nor threading, but two offsets/outputs and
|
// Simple version without tiling nor threading, but two offsets/outputs and
|
||||||
// always with addition.
|
// always with addition.
|
||||||
template <size_t kOuter, size_t kInner, typename ArrayT, typename VecT,
|
template <size_t kOuter, size_t kInner, typename ArrayT, typename VecT,
|
||||||
|
|
@ -91,16 +58,15 @@ HWY_INLINE void TwoOfsMatVecAddLoop(const ArrayT& mat, const size_t mat_ofs0,
|
||||||
float* HWY_RESTRICT out0,
|
float* HWY_RESTRICT out0,
|
||||||
float* HWY_RESTRICT out1) {
|
float* HWY_RESTRICT out1) {
|
||||||
PROFILER_ZONE("TwoOfsMatVecAddLoop");
|
PROFILER_ZONE("TwoOfsMatVecAddLoop");
|
||||||
constexpr bool kVecEO = false;
|
|
||||||
const hn::ScalableTag<float> df;
|
const hn::ScalableTag<float> df;
|
||||||
|
|
||||||
for (size_t idx_row = 0; idx_row < kOuter; ++idx_row) {
|
for (size_t idx_row = 0; idx_row < kOuter; ++idx_row) {
|
||||||
const size_t row_ofs0 = mat_ofs0 + (idx_row)*kInner;
|
const size_t row_ofs0 = mat_ofs0 + (idx_row)*kInner;
|
||||||
const size_t row_ofs1 = mat_ofs1 + (idx_row)*kInner;
|
const size_t row_ofs1 = mat_ofs1 + (idx_row)*kInner;
|
||||||
out0[idx_row] = hwy::ConvertScalarTo<float>(add0[idx_row]) +
|
out0[idx_row] = hwy::ConvertScalarTo<float>(add0[idx_row]) +
|
||||||
Dot<kVecEO>(df, mat, row_ofs0, vec_aligned, kInner);
|
Dot<false>(df, mat, row_ofs0, vec_aligned, kInner);
|
||||||
out1[idx_row] = hwy::ConvertScalarTo<float>(add1[idx_row]) +
|
out1[idx_row] = hwy::ConvertScalarTo<float>(add1[idx_row]) +
|
||||||
Dot<kVecEO>(df, mat, row_ofs1, vec_aligned, kInner);
|
Dot<false>(df, mat, row_ofs1, vec_aligned, kInner);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -125,7 +91,7 @@ namespace detail {
|
||||||
// For each i = [0, num_rows), compute partial (length `num_cols`) dot product
|
// For each i = [0, num_rows), compute partial (length `num_cols`) dot product
|
||||||
// of row i with `vec_aligned` and add into `out[i]`. The upper-left
|
// of row i with `vec_aligned` and add into `out[i]`. The upper-left
|
||||||
// coordinate of the tile is r0, c0.
|
// coordinate of the tile is r0, c0.
|
||||||
template <bool kVecEO, class DF, typename ArrayT, typename VecT>
|
template <class DF, typename ArrayT, typename VecT>
|
||||||
HWY_INLINE void AccumulatePartialDotProducts(
|
HWY_INLINE void AccumulatePartialDotProducts(
|
||||||
DF df, const ArrayT& mat, size_t mat_ofs, size_t mat_stride, size_t r0,
|
DF df, const ArrayT& mat, size_t mat_ofs, size_t mat_stride, size_t r0,
|
||||||
size_t c0, size_t num_rows, size_t num_cols,
|
size_t c0, size_t num_rows, size_t num_cols,
|
||||||
|
|
@ -133,15 +99,14 @@ HWY_INLINE void AccumulatePartialDotProducts(
|
||||||
for (size_t idx_row = 0; idx_row < num_rows; ++idx_row) {
|
for (size_t idx_row = 0; idx_row < num_rows; ++idx_row) {
|
||||||
const size_t row_ofs = mat_ofs + (r0 + idx_row) * mat_stride;
|
const size_t row_ofs = mat_ofs + (r0 + idx_row) * mat_stride;
|
||||||
out[idx_row] +=
|
out[idx_row] +=
|
||||||
Dot<kVecEO>(df, mat, row_ofs + c0, vec_aligned + c0, num_cols);
|
Dot<false>(df, mat, row_ofs + c0, vec_aligned + c0, num_cols);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Same as AccumulatePartialDotProducts, but sets out[i] to the first partial
|
// Same as AccumulatePartialDotProducts, but sets out[i] to the first partial
|
||||||
// dot product + init (if kInit), which avoids having to zero-initialize and
|
// dot product + init (if kInit), which avoids having to zero-initialize and
|
||||||
// accumulate.
|
// accumulate.
|
||||||
template <bool kVecEO, bool kInit, class DF, typename ArrayT, typename VecT,
|
template <bool kInit, class DF, typename ArrayT, typename VecT, typename InitT>
|
||||||
typename InitT>
|
|
||||||
HWY_INLINE void SetFirstPartialDotProducts(DF df, const ArrayT& mat,
|
HWY_INLINE void SetFirstPartialDotProducts(DF df, const ArrayT& mat,
|
||||||
size_t mat_ofs, size_t mat_stride,
|
size_t mat_ofs, size_t mat_stride,
|
||||||
size_t r0, size_t c0,
|
size_t r0, size_t c0,
|
||||||
|
|
@@ -154,10 +119,10 @@ HWY_INLINE void SetFirstPartialDotProducts(DF df, const ArrayT& mat,
     if constexpr (kInit) {
       out[idx_row] =
           hwy::ConvertScalarTo<float>(init[idx_row + r0]) +
-          Dot<kVecEO>(df, mat, row_ofs + c0, vec_aligned + c0, num_cols);
+          Dot<false>(df, mat, row_ofs + c0, vec_aligned + c0, num_cols);
     } else {
       out[idx_row] =
-          Dot<kVecEO>(df, mat, row_ofs + c0, vec_aligned + c0, num_cols);
+          Dot<false>(df, mat, row_ofs + c0, vec_aligned + c0, num_cols);
     }
   }
 }
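The `if constexpr (kInit)` above picks, at compile time, between seeding out[idx_row] with an init value and writing the dot product directly, so the first column tile never needs a separate zero-fill pass. A standalone illustration of the pattern, with hypothetical names:

    #include <cstddef>

    // kInit selects at compile time whether the first partial result is seeded
    // with init[i]; either way, no zero-initialize-then-accumulate pass runs.
    template <bool kInit>
    void SetFirst(const float* partial, const float* init, float* out,
                  size_t n) {
      for (size_t i = 0; i < n; ++i) {
        if constexpr (kInit) {
          out[i] = init[i] + partial[i];
        } else {
          out[i] = partial[i];
        }
      }
    }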
@@ -166,8 +131,7 @@ HWY_INLINE void SetFirstPartialDotProducts(DF df, const ArrayT& mat,
 // horizontal strip of the entire matrix); the result is the full dot product
 // for rows r in [r0, r0 + num_rows) + optionally the add vector, which we
 // store into in out[r - r0].
-template <bool kVecEO, bool kAdd, class DF, typename ArrayT, typename VecT,
-          typename AddT>
+template <bool kAdd, class DF, typename ArrayT, typename VecT, typename AddT>
 HWY_INLINE void FullDotProductsForStrip(DF df, const ArrayT& mat,
                                         size_t mat_ofs, size_t mat_stride,
                                         size_t r0, size_t num_rows,
@@ -176,56 +140,25 @@ HWY_INLINE void FullDotProductsForStrip(DF df, const ArrayT& mat,
                                         float* HWY_RESTRICT out) {
   // Tall and skinny: set `out` to the single dot product.
   if (mat_stride < MaxCols()) {
-    SetFirstPartialDotProducts<kVecEO, kAdd>(df, mat, mat_ofs, mat_stride, r0,
-                                             0, num_rows, mat_stride,
-                                             vec_aligned, add, out);
+    SetFirstPartialDotProducts<kAdd>(df, mat, mat_ofs, mat_stride, r0, 0,
+                                     num_rows, mat_stride, vec_aligned, add,
+                                     out);
     return;
   }
 
   // We have at least MaxCols, so start by setting `out` to that:
-  SetFirstPartialDotProducts<kVecEO, kAdd>(df, mat, mat_ofs, mat_stride, r0, 0,
-                                           num_rows, MaxCols(), vec_aligned,
-                                           add, out);
+  SetFirstPartialDotProducts<kAdd>(df, mat, mat_ofs, mat_stride, r0, 0,
+                                   num_rows, MaxCols(), vec_aligned, add, out);
   // For further multiples of MaxCols, accumulate. Remainders handled below.
   size_t c0 = MaxCols();
   for (; c0 <= mat_stride - MaxCols(); c0 += MaxCols()) {
-    AccumulatePartialDotProducts<kVecEO>(df, mat, mat_ofs, mat_stride, r0, c0,
-                                         num_rows, MaxCols(), vec_aligned, out);
+    AccumulatePartialDotProducts(df, mat, mat_ofs, mat_stride, r0, c0, num_rows,
+                                 MaxCols(), vec_aligned, out);
   }
 
   if (c0 < mat_stride) {  // Final cols
-    AccumulatePartialDotProducts<kVecEO>(df, mat, mat_ofs, mat_stride, r0, c0,
-                                         num_rows, mat_stride - c0, vec_aligned,
-                                         out);
-  }
-}
-
-template <bool kVecIsEvenOdd, bool kAdd, size_t kOuter, size_t kInner,
-          typename ArrayT, typename VecT, typename AddT>
-HWY_INLINE void MatVecAddInner(const ArrayT& mat, const size_t mat_ofs,
-                               const VecT* HWY_RESTRICT const vec_aligned,
-                               const AddT* HWY_RESTRICT const add,
-                               float* HWY_RESTRICT out, hwy::ThreadPool& pool) {
-  const hn::ScalableTag<float> df;
-  constexpr size_t kRowsPerStrip = RowsPerStrip<kOuter>();
-  constexpr size_t kNumStrips = kOuter / kRowsPerStrip;
-
-  // For each entire strip.
-  pool.Run(0, kNumStrips, [&](const uint64_t strip, size_t thread) HWY_ATTR {
-    PROFILER_ZONE("MatVec.lambda");
-    const size_t r0 = strip * kRowsPerStrip;
-    detail::FullDotProductsForStrip<kVecIsEvenOdd, kAdd>(
-        df, mat, mat_ofs, kInner, r0, kRowsPerStrip, vec_aligned, add,
-        out + r0);
-  });
-
-  // Remaining rows
-  const size_t r0 = kNumStrips * kRowsPerStrip;
-  if (r0 < kOuter) {
-    PROFILER_ZONE("MatVec remainder");
-    const size_t num_rows = kOuter - r0;
-    detail::FullDotProductsForStrip<kVecIsEvenOdd, kAdd>(
-        df, mat, mat_ofs, kInner, r0, num_rows, vec_aligned, add, out + r0);
+    AccumulatePartialDotProducts(df, mat, mat_ofs, mat_stride, r0, c0, num_rows,
+                                 mat_stride - c0, vec_aligned, out);
   }
 }
 
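The control flow above walks each row strip in MaxCols()-sized column tiles: SetFirstPartialDotProducts handles the first tile (folding in the optional add vector), AccumulatePartialDotProducts the rest. A scalar sketch, reusing AccumulatePartialScalar from the earlier note; `max_cols` stands in for MaxCols(), and the separate seeding loop is a simplification of the fused first tile:

    #include <algorithm>
    #include <cstddef>

    // Scalar sketch of FullDotProductsForStrip. `add` may be null (kAdd=false).
    void FullDotProductsScalar(const float* mat, size_t mat_ofs,
                               size_t mat_stride, size_t r0, size_t num_rows,
                               const float* vec, const float* add, float* out,
                               size_t max_cols) {
      for (size_t i = 0; i < num_rows; ++i) {
        out[i] = (add != nullptr) ? add[r0 + i] : 0.0f;  // Seed (kAdd).
      }
      // Whole tiles first, then the final short tile, exactly like the loop
      // over c0 above.
      for (size_t c0 = 0; c0 < mat_stride;) {
        const size_t cols = std::min(max_cols, mat_stride - c0);
        AccumulatePartialScalar(mat, mat_ofs, mat_stride, r0, c0, num_rows,
                                cols, vec, out);
        c0 += cols;
      }
    }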
@@ -238,28 +171,30 @@ template <bool kAdd, size_t kOuter, size_t kInner, typename ArrayT,
 HWY_INLINE void MatVecT(const ArrayT& mat, const size_t mat_ofs,
                         const VecT* HWY_RESTRICT const vec_aligned,
                         const AddT* HWY_RESTRICT const add,
-                        float* HWY_RESTRICT even_odd, float* HWY_RESTRICT out,
-                        hwy::ThreadPool& pool) {
+                        float* HWY_RESTRICT out, hwy::ThreadPool& pool) {
   PROFILER_ZONE("MatVecAdd");
 
-#if !defined(HWY_NATIVE_DOT_BF16) || !HWY_NATIVE_DOT_BF16
-  using MatT = typename ArrayT::value_type;
-  // Sfp -> float does not benefit enough to recoup the cost of ToEvenOddF32.
-  if constexpr (CompressTraits<MatT>::kSupportsEvenOdd &&
-                hwy::IsSameEither<VecT, float, hwy::bfloat16_t>() &&
-                !(hwy::IsSame<MatT, SfpStream>() &&
-                  hwy::IsSame<VecT, float>())) {
-    ToEvenOddF32(vec_aligned, kInner, even_odd);
-    detail::MatVecAddInner</*kVecIsEvenOdd=*/true, kAdd, kOuter, kInner>(
-        mat, mat_ofs, even_odd, add, out, pool);
-    return;
-  }
-#else
-  (void)even_odd;
-#endif
+  const hn::ScalableTag<float> df;
+  constexpr size_t kRowsPerStrip = RowsPerStrip<kOuter>();
+  constexpr size_t kNumStrips = kOuter / kRowsPerStrip;
 
-  detail::MatVecAddInner</*kVecIsEvenOdd=*/false, kAdd, kOuter, kInner>(
-      mat, mat_ofs, vec_aligned, add, out, pool);
+  // For each entire strip.
+  pool.Run(0, kNumStrips, [&](const uint64_t strip, size_t thread) HWY_ATTR {
+    PROFILER_ZONE("MatVec.lambda");
+    const size_t r0 = strip * kRowsPerStrip;
+    detail::FullDotProductsForStrip<kAdd>(df, mat, mat_ofs, kInner, r0,
+                                          kRowsPerStrip, vec_aligned, add,
+                                          out + r0);
+  });
+
+  // Remaining rows
+  const size_t r0 = kNumStrips * kRowsPerStrip;
+  if (r0 < kOuter) {
+    PROFILER_ZONE("MatVec remainder");
+    const size_t num_rows = kOuter - r0;
+    detail::FullDotProductsForStrip<kAdd>(df, mat, mat_ofs, kInner, r0,
+                                          num_rows, vec_aligned, add, out + r0);
+  }
 }
 
 // With addition
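With the even_odd path gone, MatVecT partitions rows into kRowsPerStrip-sized strips directly and no longer takes a scratch buffer; for callers, the visible change is one fewer argument. A hypothetical call site (the surrounding function, names, and dimensions are illustrative, not from this commit):

    // y = w * x + bias for a kOuter x kInner weight array `w`.
    template <typename WeightArrayT>
    void DenseLayer(const WeightArrayT& w, const float* x, const float* bias,
                    float* y, hwy::ThreadPool& pool) {
      constexpr size_t kOuter = 2048, kInner = 512;  // Illustrative sizes.
      // Before: MatVecAdd<kOuter, kInner>(w, 0, x, bias, even_odd, y, pool);
      // After: no even_odd scratch buffer to allocate or thread through.
      MatVecAdd<kOuter, kInner>(w, 0, x, bias, y, pool);
    }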
@@ -268,21 +203,19 @@ template <size_t kOuter, size_t kInner, typename ArrayT, typename VecT,
 HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs,
                           const VecT* HWY_RESTRICT const vec_aligned,
                           const AddT* HWY_RESTRICT const add,
-                          float* HWY_RESTRICT even_odd, float* HWY_RESTRICT out,
-                          hwy::ThreadPool& pool) {
+                          float* HWY_RESTRICT out, hwy::ThreadPool& pool) {
   return MatVecT</*kAdd=*/true, kOuter, kInner>(mat, mat_ofs, vec_aligned, add,
-                                                even_odd, out, pool);
+                                                out, pool);
 }
 
 // Without addition
 template <size_t kOuter, size_t kInner, typename ArrayT, typename VecT>
 HWY_INLINE void MatVec(const ArrayT& mat, const size_t mat_ofs,
                        const VecT* HWY_RESTRICT const vec_aligned,
-                       float* HWY_RESTRICT even_odd, float* HWY_RESTRICT out,
-                       hwy::ThreadPool& pool) {
+                       float* HWY_RESTRICT out, hwy::ThreadPool& pool) {
   MatVecT</*kAdd=*/false, kOuter, kInner>(mat, mat_ofs, vec_aligned,
                                           /*add=*/static_cast<VecT*>(nullptr),
-                                          even_odd, out, pool);
+                                          out, pool);
 }
 
 // Two matrices, same vector
@@ -300,18 +233,17 @@ HWY_NOINLINE void TwoMatVecT(const ArrayT& mat0, const ArrayT& mat1,
   const hn::ScalableTag<float> df;
   constexpr size_t kRowsPerStrip = RowsPerStrip<kOuter>();
   constexpr size_t kNumStrips = kOuter / kRowsPerStrip;
-  constexpr bool kVecIsEvenOdd = false;
 
   // For each entire strip.
   pool.Run(0, kNumStrips, [&](const uint64_t strip, size_t thread) HWY_ATTR {
     PROFILER_ZONE("TwoMatVec.lambda");
     const size_t r0 = strip * kRowsPerStrip;
-    detail::FullDotProductsForStrip<kVecIsEvenOdd, kAdd>(
-        df, mat0, mat_ofs, kInner, r0, kRowsPerStrip, vec_aligned, add0,
+    detail::FullDotProductsForStrip<kAdd>(df, mat0, mat_ofs, kInner, r0,
+                                          kRowsPerStrip, vec_aligned, add0,
                                           out0 + r0);
-    detail::FullDotProductsForStrip<kVecIsEvenOdd, kAdd>(
-        df, mat1, mat_ofs, kInner, r0, kRowsPerStrip, vec_aligned, add1,
+    detail::FullDotProductsForStrip<kAdd>(df, mat1, mat_ofs, kInner, r0,
+                                          kRowsPerStrip, vec_aligned, add1,
                                           out1 + r0);
   });
 
   // Remaining rows
@@ -319,9 +251,9 @@ HWY_NOINLINE void TwoMatVecT(const ArrayT& mat0, const ArrayT& mat1,
   if (r0 < kOuter) {
     PROFILER_ZONE("TwoMatVec remainder");
     const size_t num_rows = kOuter - r0;
-    detail::FullDotProductsForStrip<kVecIsEvenOdd, kAdd>(
+    detail::FullDotProductsForStrip<kAdd>(
         df, mat0, mat_ofs, kInner, r0, num_rows, vec_aligned, add0, out0 + r0);
-    detail::FullDotProductsForStrip<kVecIsEvenOdd, kAdd>(
+    detail::FullDotProductsForStrip<kAdd>(
         df, mat1, mat_ofs, kInner, r0, num_rows, vec_aligned, add1, out1 + r0);
   }
 }
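TwoMatVecT pairs the two FullDotProductsForStrip calls per strip so both matrices share one pass over vec_aligned while it is hot in cache. The scalar sketch below fuses the two matrices per column, which is tighter than the actual per-strip pairing but shows the same data reuse; plain float matrices are assumed and threading and compression are omitted:

    #include <cstddef>

    // Both outputs are produced in one sweep; vec[c] is loaded once per column
    // and reused for mat0 and mat1.
    template <size_t kOuter, size_t kInner>
    void TwoMatVecScalar(const float* mat0, const float* mat1,
                         const float* vec, float* out0, float* out1) {
      for (size_t r = 0; r < kOuter; ++r) {
        float sum0 = 0.0f, sum1 = 0.0f;
        for (size_t c = 0; c < kInner; ++c) {
          const float v = vec[c];
          sum0 += mat0[r * kInner + c] * v;
          sum1 += mat1[r * kInner + c] * v;
        }
        out0[r] = sum0;
        out1[r] = sum1;
      }
    }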