Decouple MatMul from gemma-inl: precompile for all input types

Call MatMulStatic instead of MatMul.

Also fix build error due to Highway's Lanes not being constexpr.

PiperOrigin-RevId: 763777269
This commit is contained in:
Jan Wassenberg 2025-05-27 07:08:23 -07:00 committed by Copybara-Service
parent 421a2ab8ac
commit 627cc04db9
15 changed files with 311 additions and 49 deletions

View File

@ -295,6 +295,35 @@ cc_library(
],
)
cc_library(
name = "matmul",
srcs = [
# single-file build time is ~30sec for msan, hence shard.
"ops/matmul_static_bf16.cc",
"ops/matmul_static_f32.cc",
"ops/matmul_static_nuq.cc",
"ops/matmul_static_sfp.cc",
],
hdrs = [
"ops/matmul_static.h",
],
textual_hdrs = [
"ops/matmul_static-inl.h",
"ops/matmul-inl.h",
],
deps = [
":allocator",
":basics",
":ops",
":threading_context",
"//compression:compress",
"//compression:types",
"@highway//:hwy",
"@highway//:profiler",
"@highway//:timer",
],
)
cc_test(
name = "dot_test",
size = "small",
@ -373,6 +402,7 @@ cc_test(
deps = [
":basics",
":mat",
":matmul",
":ops",
":threading_context",
"@googletest//:gtest_main", # buildcleaner: keep
@ -380,6 +410,7 @@ cc_test(
"//compression:test_util",
"@highway//:hwy",
"@highway//:hwy_test_util",
"@highway//:nanobenchmark",
"@highway//:thread_pool",
],
)
@ -463,6 +494,7 @@ cc_library(
":gemma_args",
":kv_cache",
":mat",
":matmul",
":model_store",
":ops",
":tokenizer",
@ -578,6 +610,7 @@ cc_test(
"@googletest//:gtest_main", # buildcleaner: keep
"@highway//:hwy",
"@highway//:hwy_test_util",
"@highway//:profiler",
],
)

View File

@ -22,7 +22,7 @@ set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG c5bebf84ad01edec97e336f5c97ca4e0df6b4d06 EXCLUDE_FROM_ALL)
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 12d9fa908e0c1d3346c298d472584687a24e4ce6 EXCLUDE_FROM_ALL)
FetchContent_MakeAvailable(highway)
## Note: absl needs to be installed by sentencepiece. This will only happen if
@ -95,6 +95,10 @@ set(SOURCES
io/io.cc
io/io.h
ops/dot-inl.h
ops/matmul_static_bf16.cc
ops/matmul_static_f32.cc
ops/matmul_static_nuq.cc
ops/matmul_static_sfp.cc
ops/matmul-inl.h
ops/matmul.cc
ops/matmul.h

View File

@ -18,7 +18,7 @@ bazel_dep(name = "google_benchmark", version = "1.8.5")
# Require a more recent version.
git_override(
module_name = "highway",
commit = "c5bebf84ad01edec97e336f5c97ca4e0df6b4d06",
commit = "12d9fa908e0c1d3346c298d472584687a24e4ce6",
remote = "https://github.com/google/highway",
)

View File

@ -17,7 +17,7 @@ project(hello_world)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
include(FetchContent)
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG c5bebf84ad01edec97e336f5c97ca4e0df6b4d06)
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 12d9fa908e0c1d3346c298d472584687a24e4ce6)
FetchContent_MakeAvailable(highway)
FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c)
FetchContent_MakeAvailable(sentencepiece)

View File

@ -17,7 +17,7 @@ project(simplified_gemma)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
include(FetchContent)
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG c5bebf84ad01edec97e336f5c97ca4e0df6b4d06)
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 12d9fa908e0c1d3346c298d472584687a24e4ce6)
FetchContent_MakeAvailable(highway)
FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c)
FetchContent_MakeAvailable(sentencepiece)

View File

@ -29,6 +29,7 @@
#include "gemma/gemma.h"
#include "gemma/kv_cache.h"
#include "gemma/weights.h"
#include "ops/matmul_static.h"
#include "paligemma/image.h"
#include "util/mat.h"
#include "util/threading_context.h"
@ -48,7 +49,6 @@
#include "hwy/highway.h"
// After highway.h
#include "ops/matmul-inl.h"
#include "ops/matvec-inl.h"
#include "ops/ops-inl.h"
@ -266,8 +266,9 @@ class GemmaAttention {
// computed in the second MatMul.
const size_t w1_rows = heads * layer_config_.QStride();
HWY_DASSERT(layer_weights_.qkv_einsum_w1.Rows() == w1_rows);
MatMul(activations_.pre_att_rms_out, layer_weights_.qkv_einsum_w1,
/*add=*/nullptr, *activations_.env, RowPtrFromMat(activations_.q));
MatMulStatic(activations_.pre_att_rms_out, layer_weights_.qkv_einsum_w1,
/*add=*/nullptr, *activations_.env,
RowPtrFromMat(activations_.q));
if (is_mha_) {
// Multi-Head Attention a.k.a. "use_qkv_einsum" computed QKV already.
@ -285,8 +286,8 @@ class GemmaAttention {
float* HWY_RESTRICT kv = kv_caches_[0].kv_cache.get() + kv_ofs;
RowPtrF kv_rows(kv, w_rows_kv_cols);
kv_rows.SetStride(cache_pos_size_);
MatMul(activations_.pre_att_rms_out, layer_weights_.qkv_einsum_w2,
/*add=*/nullptr, *activations_.env, kv_rows);
MatMulStatic(activations_.pre_att_rms_out, layer_weights_.qkv_einsum_w2,
/*add=*/nullptr, *activations_.env, kv_rows);
} else {
// Proceed row by row because there will be wraparound.
for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
@ -489,8 +490,8 @@ class GemmaAttention {
layer_weights_.layer_config.softmax_attn_output_biases
? layer_weights_.attention_output_biases.PackedScale1()
: nullptr;
MatMul(activations_.att_out, layer_weights_.att_weights, add,
*activations_.env, RowPtrFromMat(activations_.att_sums));
MatMulStatic(activations_.att_out, layer_weights_.att_weights, add,
*activations_.env, RowPtrFromMat(activations_.att_sums));
}
public:
@ -625,9 +626,9 @@ class VitAttention {
auto& qkv = activations_.q;
HWY_ASSERT(qkv.Rows() == num_tokens_);
HWY_ASSERT(qkv.Cols() == layer_config_.heads * 3 * layer_config_.qkv_dim);
MatMul(activations_.pre_att_rms_out, layer_weights_.vit.qkv_einsum_w,
layer_weights_.vit.qkv_einsum_b.PackedScale1(), *activations_.env,
RowPtrFromMat(qkv));
MatMulStatic(activations_.pre_att_rms_out, layer_weights_.vit.qkv_einsum_w,
layer_weights_.vit.qkv_einsum_b.PackedScale1(),
*activations_.env, RowPtrFromMat(qkv));
}
// TODO(philculliton): transition fully to MatMul.
@ -667,7 +668,7 @@ class VitAttention {
});
// this produces C, a (num_tokens_, seq_len) matrix of dot products
MatMul(Q, K, nullptr, *activations_.env, RowPtrFromMat(C));
MatMulStatic(Q, K, nullptr, *activations_.env, RowPtrFromMat(C));
pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
float* HWY_RESTRICT c = C.Row(task);
@ -734,8 +735,8 @@ class VitAttention {
// qkv_dim. Thus the [num_tokens_, layer_config_.model_dim]
// matmul output is the sum over heads.
auto att_sums = RowPtrFromMat(activations_.att_sums);
MatMul(activations_.att_out, layer_weights_.vit.attn_out_w, bias,
*activations_.env, att_sums);
MatMulStatic(activations_.att_out, layer_weights_.vit.attn_out_w, bias,
*activations_.env, att_sums);
}
public:
@ -826,18 +827,18 @@ HWY_NOINLINE void FFWNoVit(Activations& activations,
add_bias ? layer_weights->ffw_output_biases.PackedScale1() : nullptr;
// Compute the hidden layer activations.
MatMul(activations.pre_ffw_rms_out, layer_weights->gating_einsum_w1, bias1,
*activations.env, RowPtrFromMat(activations.C1));
MatMul(activations.pre_ffw_rms_out, layer_weights->gating_einsum_w2, bias2,
*activations.env, RowPtrFromMat(activations.C2));
MatMulStatic(activations.pre_ffw_rms_out, layer_weights->gating_einsum_w1,
bias1, *activations.env, RowPtrFromMat(activations.C1));
MatMulStatic(activations.pre_ffw_rms_out, layer_weights->gating_einsum_w2,
bias2, *activations.env, RowPtrFromMat(activations.C2));
// Activation (Gelu) and maybe multiply by gate. Store activations in act.
ActivationBatched(layer_weights->layer_config.activation, activations.C1,
&activations.C2);
// Hidden layer -> output layer.
MatMul(activations.C1, layer_weights->linear_w, output_bias, *activations.env,
RowPtrFromMat(activations.ffw_out));
MatMulStatic(activations.C1, layer_weights->linear_w, output_bias,
*activations.env, RowPtrFromMat(activations.ffw_out));
}
// Same as FFWNoVit, but with different layer_weights members and no second
@ -854,15 +855,15 @@ HWY_NOINLINE void FFWVit(Activations& activations,
add_bias ? layer_weights->vit.linear_1_b.PackedScale1() : nullptr;
// Compute the hidden layer activations.
MatMul(activations.pre_ffw_rms_out, layer_weights->vit.linear_0_w, bias1,
*activations.env, RowPtrFromMat(activations.C1));
MatMulStatic(activations.pre_ffw_rms_out, layer_weights->vit.linear_0_w,
bias1, *activations.env, RowPtrFromMat(activations.C1));
// Activation (Gelu), store in C1.
ActivationBatched(layer_weights->layer_config.activation, activations.C1);
// Hidden layer -> output layer.
MatMul(activations.C1, layer_weights->vit.linear_1_w, output_bias,
*activations.env, RowPtrFromMat(activations.ffw_out));
MatMulStatic(activations.C1, layer_weights->vit.linear_1_w, output_bias,
*activations.env, RowPtrFromMat(activations.ffw_out));
}
// `batch_idx` indicates which row of `x` to write to.
@ -1175,7 +1176,7 @@ HWY_NOINLINE void EmbedImagePatches(const Image& image,
// MatStorageT<float> image_patches("patches", Extents2D(kSeqLen,
// kPatchSize), MatPadding::kPacked);
// [Get patches]
// MatMul(
// MatMulStatic(
// MatFromBatch(kVitSeqLen, image_patches),
// MatFromWeights(weights.vit_img_embedding_kernel),
// weights.vit_img_embedding_bias.PackedScale1(), *activations.env,
@ -1226,9 +1227,9 @@ HWY_NOINLINE void PrefillVit(const ModelWeightsPtrs<T>& weights,
}
// Apply head embedding into image_tokens of size of the LLM kModelDim.
MatMul(activations.x, weights.vit_img_head_kernel,
weights.vit_img_head_bias.PackedScale1(), *activations.env,
RowPtrFromMat(image_tokens));
MatMulStatic(activations.x, weights.vit_img_head_kernel,
weights.vit_img_head_bias.PackedScale1(), *activations.env,
RowPtrFromMat(image_tokens));
}
// Generates one token for each query. `queries_token` is the previous token
@ -1366,9 +1367,9 @@ bool DecodeStepT(const ModelConfig& config, const ModelWeightsPtrs<T>& weights,
{
PROFILER_ZONE("Gen.EmbeddingMatmul");
// Compute logits from last layer activations.
MatMul(activations.x, weights.embedder_input_embedding,
/*add=*/nullptr, *activations.env,
RowPtrFromMat(activations.logits));
MatMulStatic(activations.x, weights.embedder_input_embedding,
/*add=*/nullptr, *activations.env,
RowPtrFromMat(activations.logits));
}
PROFILER_ZONE("Gen.Softcap+Sample+Stream");
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {

View File

@ -15,6 +15,7 @@
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "gemma/instantiations/bf16.cc"
#include "hwy/foreach_target.h" // IWYU pragma: keep
#define GEMMA_TYPE hwy::bfloat16_t
#include "gemma/gemma-inl.h"

View File

@ -136,7 +136,7 @@ hn::TFromD<DF> ReduceCascadedSums(DF df, const VF sum, VF sum_err) {
using TF = hn::TFromD<DF>;
// For non-scalable wide vectors, reduce loop iterations below by recursing
// once or twice for halves of 256-bit or 512-bit vectors.
if constexpr (!HWY_HAVE_SCALABLE) {
if constexpr (HWY_HAVE_CONSTEXPR_LANES) {
if constexpr (hn::Lanes(df) > 16 / sizeof(TF)) {
const hn::Half<DF> dfh;
using VFH = hn::Vec<decltype(dfh)>;

60
ops/matmul_static-inl.h Normal file
View File

@ -0,0 +1,60 @@
// Copyright 2025 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "hwy/detect_compiler_arch.h" // HWY_IDE
#ifndef GEMMA_MATMUL_TB
#if HWY_IDE
// Provide a definition so the IDE does not complain.
#define GEMMA_MATMUL_TB float
#else
#error "Only include from matmul_static_*.cc, which define GEMMA_MATMUL_TB"
#endif // HWY_IDE
#endif // GEMMA_MATMUL_TB
// Passed to GEMMA_MATMUL_FOREACH_AC; defines one overload for one target.
#define GEMMA_MATMUL_DEFINE_ONE(TA, TB, TC) \
MMPerKey* MatMulStatic(const MatPtrT<TA>& A, const MatPtrT<TB>& B, \
const float* HWY_RESTRICT add, MatMulEnv& env, \
const RowPtr<TC>& C) { \
return MatMul(A, B, add, env, C); \
}
#if defined(THIRD_PARTY_GEMMA_CPP_GEMMA_OPS_MATMUL_STATIC_INL_H_) == \
defined(HWY_TARGET_TOGGLE)
#ifdef THIRD_PARTY_GEMMA_CPP_GEMMA_OPS_MATMUL_STATIC_INL_H_
#undef THIRD_PARTY_GEMMA_CPP_GEMMA_OPS_MATMUL_STATIC_INL_H_
#else
#define THIRD_PARTY_GEMMA_CPP_GEMMA_OPS_MATMUL_STATIC_INL_H_
#endif
#include "hwy/highway.h"
// After highway.h
#include "ops/matmul-inl.h"
#include "ops/matmul_static.h" // includes highway.h!
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
// Ignore warning that we are defining a function in a header; this is only
// included from matmul_static_*.cc.
GEMMA_MATMUL_FOREACH_AC(GEMMA_MATMUL_DEFINE_ONE, GEMMA_MATMUL_TB) // NOLINT
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_OPS_MATMUL_STATIC_INL_H_

59
ops/matmul_static.h Normal file
View File

@ -0,0 +1,59 @@
// Copyright 2025 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_OPS_MATMUL_STATIC_H_
#define THIRD_PARTY_GEMMA_CPP_OPS_MATMUL_STATIC_H_
// Declares overloads of MatMulStatic for all SIMD targets and input types.
#include <stddef.h>
#include <stdint.h>
#include "ops/matmul.h" // IWYU pragma: keep, b/420428845
#include "hwy/highway.h"
// Invokes GEMMA_X(TA, TB, TC) for all combinations of F32 or BF16.
#define GEMMA_MATMUL_FOREACH_AC(GEMMA_X, TB) \
GEMMA_X(float, TB, float) \
GEMMA_X(float, TB, BF16) \
GEMMA_X(BF16, TB, float) \
GEMMA_X(BF16, TB, BF16)
// Passed to GEMMA_MATMUL_FOREACH_AC; declares one overload for one target.
#define GEMMA_MATMUL_DECL_ONE(TA, TB, TC) \
MMPerKey* MatMulStatic(const MatPtrT<TA>& A, const MatPtrT<TB>& B, \
const float* HWY_RESTRICT add, MatMulEnv& env, \
const RowPtr<TC>& C);
// Passed to HWY_VISIT_TARGETS; declares all overloads for all targets.
#define GEMMA_MATMUL_DECL(TARGET, NAMESPACE) \
namespace NAMESPACE { \
GEMMA_MATMUL_FOREACH_AC(GEMMA_MATMUL_DECL_ONE, BF16) \
GEMMA_MATMUL_FOREACH_AC(GEMMA_MATMUL_DECL_ONE, float) \
GEMMA_MATMUL_FOREACH_AC(GEMMA_MATMUL_DECL_ONE, NuqStream) \
GEMMA_MATMUL_FOREACH_AC(GEMMA_MATMUL_DECL_ONE, SfpStream) \
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
} // namespace NAMESPACE
namespace gcpp {
// MatMul function declarations for each SIMD target. Allows direct call from
// the per-target namespace. We may later replace this with dynamic dispatch if
// the overhead is acceptable.
HWY_VISIT_TARGETS(GEMMA_MATMUL_DECL)
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_OPS_MATMUL_STATIC_H_

24
ops/matmul_static_bf16.cc Normal file
View File

@ -0,0 +1,24 @@
// Copyright 2025 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Compiles this file for multiple architectures via "foreach_target.h", to
// which we pass the filename via macro 'argument'.
// clang-format off
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "ops/matmul_static_bf16.cc" // NOLINT
// clang-format on
#include "hwy/foreach_target.h" // IWYU pragma: keep
#define GEMMA_MATMUL_TB BF16
#include "ops/matmul_static-inl.h"

24
ops/matmul_static_f32.cc Normal file
View File

@ -0,0 +1,24 @@
// Copyright 2025 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Compiles this file for multiple architectures via "foreach_target.h", to
// which we pass the filename via macro 'argument'.
// clang-format off
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "ops/matmul_static_f32.cc" // NOLINT
// clang-format on
#include "hwy/foreach_target.h" // IWYU pragma: keep
#define GEMMA_MATMUL_TB float
#include "ops/matmul_static-inl.h"

24
ops/matmul_static_nuq.cc Normal file
View File

@ -0,0 +1,24 @@
// Copyright 2025 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Compiles this file for multiple architectures via "foreach_target.h", to
// which we pass the filename via macro 'argument'.
// clang-format off
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "ops/matmul_static_nuq.cc" // NOLINT
// clang-format on
#include "hwy/foreach_target.h" // IWYU pragma: keep
#define GEMMA_MATMUL_TB NuqStream
#include "ops/matmul_static-inl.h"

24
ops/matmul_static_sfp.cc Normal file
View File

@ -0,0 +1,24 @@
// Copyright 2025 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Compiles this file for multiple architectures via "foreach_target.h", to
// which we pass the filename via macro 'argument'.
// clang-format off
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "ops/matmul_static_sfp.cc" // NOLINT
// clang-format on
#include "hwy/foreach_target.h" // IWYU pragma: keep
#define GEMMA_MATMUL_TB SfpStream
#include "ops/matmul_static-inl.h"

View File

@ -18,13 +18,18 @@
#include "hwy/detect_compiler_arch.h" // IWYU pragma: keep
#ifndef HWY_DISABLED_TARGETS
// Exclude HWY_SCALAR due to 2x bf16 -> f32, and Armv7 NEON because we require
// double-precision support, and older x86 to speed up builds.
// double-precision support.
#if HWY_ARCH_ARM_V7
#define HWY_DISABLED_TARGETS (HWY_SCALAR | HWY_NEON)
#else
#define HWY_DISABLED_TARGETS (HWY_SCALAR | HWY_SSSE3 | HWY_SSE4)
#endif
#endif
#define HWY_DISABLED_TARGETS (HWY_SCALAR)
#endif // HWY_ARCH_ARM_V7
#endif // HWY_DISABLED_TARGETS
// matmul_static is not built as a test, hence does not define MatMulStatic for
// worse-than-baseline targets (to speed up builds), so we skip them here, too.
#ifndef HWY_SKIP_NON_BEST_BASELINE
#define HWY_SKIP_NON_BEST_BASELINE
#endif // HWY_SKIP_NON_BEST_BASELINE
#include <stddef.h>
#include <stdio.h>
@ -35,6 +40,7 @@
#include "util/mat.h"
#include "util/threading_context.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/nanobenchmark.h" // Unpredictable1
// clang-format off
#undef HWY_TARGET_INCLUDE
@ -46,7 +52,7 @@
#include "compression/compress-inl.h"
#include "compression/test_util-inl.h"
#include "ops/dot-inl.h"
#include "ops/matmul-inl.h"
#include "ops/matmul_static.h" // also textual
HWY_BEFORE_NAMESPACE();
namespace gcpp {
@ -234,7 +240,7 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
MatMulSlow(a, b_trans, add_row, env, C_slow);
// A few reps to get coverage of the various autotuned code paths.
for (size_t rep = 0; rep < 16; ++rep) {
MMPerKey* per_key = MatMul(a, b_trans, add_row, env, C);
MMPerKey* per_key = MatMulStatic(a, b_trans, add_row, env, C);
AssertClose(a, b_trans, C_slow, C, line);
if (per_key->autotune.Best()) break;
}
@ -258,12 +264,12 @@ void TestTiny() {
MatMulEnv env(ThreadingContext::Get());
NestedPools& pools = env.ctx.pools;
#if GEMMA_DISABLE_TOPOLOGY
if (max_packages == 2) break; // we only have one package
#else
// If less than the limit, we have already tested all num_packages.
if (env.ctx.topology.FullTopology().packages.size() < max_packages) break;
#endif
if constexpr (GEMMA_DISABLE_TOPOLOGY) {
if (max_packages == 2) break; // we only have one package
} else {
// If less than the limit, we have already tested all num_packages.
if (env.ctx.topology.FullTopology().packages.size() < max_packages) break;
}
fprintf(stderr, "TestTiny %zu: %s %s\n", max_packages,
env.ctx.topology.TopologyString(), pools.PinString());
@ -282,8 +288,10 @@ void TestTiny() {
void TestAllMatMul() {
// Skip EMU128 (10x slower than SSE4 for SFP) and older x86.
if (HWY_TARGET == HWY_EMU128 || HWY_TARGET == HWY_SSE4 ||
HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE2) {
// Add Unpredictable1 to prevent erroneous "unreachable code" warning.
if (hwy::Unpredictable1() == 1 &&
(HWY_TARGET == HWY_EMU128 || HWY_TARGET == HWY_SSE4 ||
HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE2)) {
return;
}