From 627cc04db99d4281bdfc8cd5fa3aa79395584982 Mon Sep 17 00:00:00 2001
From: Jan Wassenberg
Date: Tue, 27 May 2025 07:08:23 -0700
Subject: [PATCH] Decouple MatMul from gemma-inl: precompile for all input types

Call MatMulStatic instead of MatMul. Also fix build error due to Highway's
Lanes not being constexpr.

PiperOrigin-RevId: 763777269
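Stated as code, the call-site pattern this change enables: per-target code (already inside HWY_NAMESPACE) calls the precompiled MatMulStatic overloads declared by ops/matmul_static.h, rather than textually including ops/matmul-inl.h. A minimal sketch; ExampleCall and its arguments are hypothetical, while the types and the MatMulStatic signature are from this patch:

#include "ops/matmul_static.h"  // declares per-target MatMulStatic overloads

namespace gcpp {
namespace HWY_NAMESPACE {  // e.g. N_AVX3 during the AVX3 compilation pass

// Hypothetical caller; the gemma-inl.h changes below all follow this shape.
void ExampleCall(const MatPtrT<BF16>& x, const MatPtrT<SfpStream>& w,
                 MatMulEnv& env, const RowPtr<float>& out) {
  // Overload resolution picks the variant precompiled in matmul_static_sfp.cc
  // for this target; no textual inclusion of the MatMul kernel is needed.
  MatMulStatic(x, w, /*add=*/nullptr, env, out);
}

}  // namespace HWY_NAMESPACE
}  // namespace gcpp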
---
 BUILD.bazel                              | 33 +++++++++++++
 CMakeLists.txt                           |  6 ++-
 MODULE.bazel                             |  2 +-
 examples/hello_world/CMakeLists.txt      |  2 +-
 examples/simplified_gemma/CMakeLists.txt |  2 +-
 gemma/gemma-inl.h                        | 61 ++++++++++++------------
 gemma/instantiations/bf16.cc             |  1 +
 ops/fp_arith-inl.h                       |  2 +-
 ops/matmul_static-inl.h                  | 60 +++++++++++++++++++++++
 ops/matmul_static.h                      | 59 +++++++++++++++++++++++
 ops/matmul_static_bf16.cc                | 24 ++++++++++
 ops/matmul_static_f32.cc                 | 24 ++++++++++
 ops/matmul_static_nuq.cc                 | 24 ++++++++++
 ops/matmul_static_sfp.cc                 | 24 ++++++++++
 ops/matmul_test.cc                       | 36 ++++++++------
 15 files changed, 311 insertions(+), 49 deletions(-)
 create mode 100644 ops/matmul_static-inl.h
 create mode 100644 ops/matmul_static.h
 create mode 100644 ops/matmul_static_bf16.cc
 create mode 100644 ops/matmul_static_f32.cc
 create mode 100644 ops/matmul_static_nuq.cc
 create mode 100644 ops/matmul_static_sfp.cc

diff --git a/BUILD.bazel b/BUILD.bazel
index b673341..5502169 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -295,6 +295,35 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "matmul",
+    srcs = [
+        # single-file build time is ~30sec for msan, hence shard.
+        "ops/matmul_static_bf16.cc",
+        "ops/matmul_static_f32.cc",
+        "ops/matmul_static_nuq.cc",
+        "ops/matmul_static_sfp.cc",
+    ],
+    hdrs = [
+        "ops/matmul_static.h",
+    ],
+    textual_hdrs = [
+        "ops/matmul_static-inl.h",
+        "ops/matmul-inl.h",
+    ],
+    deps = [
+        ":allocator",
+        ":basics",
+        ":ops",
+        ":threading_context",
+        "//compression:compress",
+        "//compression:types",
+        "@highway//:hwy",
+        "@highway//:profiler",
+        "@highway//:timer",
+    ],
+)
+
 cc_test(
     name = "dot_test",
     size = "small",
@@ -373,6 +402,7 @@ cc_test(
     deps = [
         ":basics",
         ":mat",
+        ":matmul",
         ":ops",
         ":threading_context",
         "@googletest//:gtest_main",  # buildcleaner: keep
@@ -380,6 +410,7 @@ cc_test(
         "//compression:test_util",
         "@highway//:hwy",
         "@highway//:hwy_test_util",
+        "@highway//:nanobenchmark",
         "@highway//:thread_pool",
     ],
 )
@@ -463,6 +494,7 @@ cc_library(
         ":gemma_args",
         ":kv_cache",
         ":mat",
+        ":matmul",
         ":model_store",
         ":ops",
         ":tokenizer",
@@ -578,6 +610,7 @@ cc_test(
         "@googletest//:gtest_main",  # buildcleaner: keep
         "@highway//:hwy",
         "@highway//:hwy_test_util",
+        "@highway//:profiler",
     ],
 )

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 67ca93e..97f4ccb 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -22,7 +22,7 @@ set(CMAKE_CXX_STANDARD 17)
 set(CMAKE_CXX_STANDARD_REQUIRED ON)
 set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
 
-FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG c5bebf84ad01edec97e336f5c97ca4e0df6b4d06 EXCLUDE_FROM_ALL)
+FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 12d9fa908e0c1d3346c298d472584687a24e4ce6 EXCLUDE_FROM_ALL)
 FetchContent_MakeAvailable(highway)
 
 ## Note: absl needs to be installed by sentencepiece. This will only happen if
@@ -95,6 +95,10 @@ set(SOURCES
   io/io.cc
   io/io.h
   ops/dot-inl.h
+  ops/matmul_static_bf16.cc
+  ops/matmul_static_f32.cc
+  ops/matmul_static_nuq.cc
+  ops/matmul_static_sfp.cc
   ops/matmul-inl.h
   ops/matmul.cc
   ops/matmul.h

diff --git a/MODULE.bazel b/MODULE.bazel
index 77690fa..e27cfa0 100644
--- a/MODULE.bazel
+++ b/MODULE.bazel
@@ -18,7 +18,7 @@ bazel_dep(name = "google_benchmark", version = "1.8.5")
 
 # Require a more recent version.
 git_override(
     module_name = "highway",
-    commit = "c5bebf84ad01edec97e336f5c97ca4e0df6b4d06",
+    commit = "12d9fa908e0c1d3346c298d472584687a24e4ce6",
     remote = "https://github.com/google/highway",
 )

diff --git a/examples/hello_world/CMakeLists.txt b/examples/hello_world/CMakeLists.txt
index 2eb39c6..2e1c36e 100644
--- a/examples/hello_world/CMakeLists.txt
+++ b/examples/hello_world/CMakeLists.txt
@@ -17,7 +17,7 @@ project(hello_world)
 set(CMAKE_CXX_STANDARD_REQUIRED ON)
 
 include(FetchContent)
-FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG c5bebf84ad01edec97e336f5c97ca4e0df6b4d06)
+FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 12d9fa908e0c1d3346c298d472584687a24e4ce6)
 FetchContent_MakeAvailable(highway)
 FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c)
 FetchContent_MakeAvailable(sentencepiece)

diff --git a/examples/simplified_gemma/CMakeLists.txt b/examples/simplified_gemma/CMakeLists.txt
index e7e6653..b6af3e8 100644
--- a/examples/simplified_gemma/CMakeLists.txt
+++ b/examples/simplified_gemma/CMakeLists.txt
@@ -17,7 +17,7 @@ project(simplified_gemma)
 set(CMAKE_CXX_STANDARD_REQUIRED ON)
 
 include(FetchContent)
-FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG c5bebf84ad01edec97e336f5c97ca4e0df6b4d06)
+FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 12d9fa908e0c1d3346c298d472584687a24e4ce6)
 FetchContent_MakeAvailable(highway)
 FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c)
 FetchContent_MakeAvailable(sentencepiece)

diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h
index a1ed465..4e7912f 100644
--- a/gemma/gemma-inl.h
+++ b/gemma/gemma-inl.h
@@ -29,6 +29,7 @@
 #include "gemma/gemma.h"
 #include "gemma/kv_cache.h"
 #include "gemma/weights.h"
+#include "ops/matmul_static.h"
 #include "paligemma/image.h"
 #include "util/mat.h"
 #include "util/threading_context.h"
@@ -48,7 +49,6 @@
 #include "hwy/highway.h"
 
 // After highway.h
-#include "ops/matmul-inl.h"
 #include "ops/matvec-inl.h"
 #include "ops/ops-inl.h"
 
@@ -266,8 +266,9 @@ class GemmaAttention {
     // computed in the second MatMul.
     const size_t w1_rows = heads * layer_config_.QStride();
     HWY_DASSERT(layer_weights_.qkv_einsum_w1.Rows() == w1_rows);
-    MatMul(activations_.pre_att_rms_out, layer_weights_.qkv_einsum_w1,
-           /*add=*/nullptr, *activations_.env, RowPtrFromMat(activations_.q));
+    MatMulStatic(activations_.pre_att_rms_out, layer_weights_.qkv_einsum_w1,
+                 /*add=*/nullptr, *activations_.env,
+                 RowPtrFromMat(activations_.q));
 
     if (is_mha_) {
       // Multi-Head Attention a.k.a. "use_qkv_einsum" computed QKV already.
@@ -285,8 +286,8 @@ class GemmaAttention {
         float* HWY_RESTRICT kv = kv_caches_[0].kv_cache.get() + kv_ofs;
         RowPtrF kv_rows(kv, w_rows_kv_cols);
         kv_rows.SetStride(cache_pos_size_);
-        MatMul(activations_.pre_att_rms_out, layer_weights_.qkv_einsum_w2,
-               /*add=*/nullptr, *activations_.env, kv_rows);
+        MatMulStatic(activations_.pre_att_rms_out, layer_weights_.qkv_einsum_w2,
+                     /*add=*/nullptr, *activations_.env, kv_rows);
       } else {
         // Proceed row by row because there will be wraparound.
         for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
@@ -489,8 +490,8 @@ class GemmaAttention {
         layer_weights_.layer_config.softmax_attn_output_biases
             ? layer_weights_.attention_output_biases.PackedScale1()
             : nullptr;
-    MatMul(activations_.att_out, layer_weights_.att_weights, add,
-           *activations_.env, RowPtrFromMat(activations_.att_sums));
+    MatMulStatic(activations_.att_out, layer_weights_.att_weights, add,
+                 *activations_.env, RowPtrFromMat(activations_.att_sums));
   }
 
  public:
@@ -625,9 +626,9 @@ class VitAttention {
     auto& qkv = activations_.q;
     HWY_ASSERT(qkv.Rows() == num_tokens_);
     HWY_ASSERT(qkv.Cols() == layer_config_.heads * 3 * layer_config_.qkv_dim);
-    MatMul(activations_.pre_att_rms_out, layer_weights_.vit.qkv_einsum_w,
-           layer_weights_.vit.qkv_einsum_b.PackedScale1(), *activations_.env,
-           RowPtrFromMat(qkv));
+    MatMulStatic(activations_.pre_att_rms_out, layer_weights_.vit.qkv_einsum_w,
+                 layer_weights_.vit.qkv_einsum_b.PackedScale1(),
+                 *activations_.env, RowPtrFromMat(qkv));
   }
 
   // TODO(philculliton): transition fully to MatMul.
@@ -667,7 +668,7 @@ class VitAttention {
       });
 
     // this produces C, a (num_tokens_, seq_len) matrix of dot products
-    MatMul(Q, K, nullptr, *activations_.env, RowPtrFromMat(C));
+    MatMulStatic(Q, K, nullptr, *activations_.env, RowPtrFromMat(C));
 
     pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
       float* HWY_RESTRICT c = C.Row(task);
@@ -734,8 +735,8 @@ class VitAttention {
     // qkv_dim. Thus the [num_tokens_, layer_config_.model_dim]
     // matmul output is the sum over heads.
     auto att_sums = RowPtrFromMat(activations_.att_sums);
-    MatMul(activations_.att_out, layer_weights_.vit.attn_out_w, bias,
-           *activations_.env, att_sums);
+    MatMulStatic(activations_.att_out, layer_weights_.vit.attn_out_w, bias,
+                 *activations_.env, att_sums);
   }
 
  public:
@@ -826,18 +827,18 @@ HWY_NOINLINE void FFWNoVit(Activations& activations,
       add_bias ? layer_weights->ffw_output_biases.PackedScale1() : nullptr;
 
   // Compute the hidden layer activations.
-  MatMul(activations.pre_ffw_rms_out, layer_weights->gating_einsum_w1, bias1,
-         *activations.env, RowPtrFromMat(activations.C1));
-  MatMul(activations.pre_ffw_rms_out, layer_weights->gating_einsum_w2, bias2,
-         *activations.env, RowPtrFromMat(activations.C2));
+  MatMulStatic(activations.pre_ffw_rms_out, layer_weights->gating_einsum_w1,
+               bias1, *activations.env, RowPtrFromMat(activations.C1));
+  MatMulStatic(activations.pre_ffw_rms_out, layer_weights->gating_einsum_w2,
+               bias2, *activations.env, RowPtrFromMat(activations.C2));
 
   // Activation (Gelu) and maybe multiply by gate. Store activations in act.
   ActivationBatched(layer_weights->layer_config.activation, activations.C1,
                     &activations.C2);
 
   // Hidden layer -> output layer.
-  MatMul(activations.C1, layer_weights->linear_w, output_bias, *activations.env,
-         RowPtrFromMat(activations.ffw_out));
+  MatMulStatic(activations.C1, layer_weights->linear_w, output_bias,
+               *activations.env, RowPtrFromMat(activations.ffw_out));
 }
 
 // Same as FFWNoVit, but with different layer_weights members and no second
@@ -854,15 +855,15 @@ HWY_NOINLINE void FFWVit(Activations& activations,
       add_bias ? layer_weights->vit.linear_1_b.PackedScale1() : nullptr;
 
   // Compute the hidden layer activations.
-  MatMul(activations.pre_ffw_rms_out, layer_weights->vit.linear_0_w, bias1,
-         *activations.env, RowPtrFromMat(activations.C1));
+  MatMulStatic(activations.pre_ffw_rms_out, layer_weights->vit.linear_0_w,
+               bias1, *activations.env, RowPtrFromMat(activations.C1));
 
   // Activation (Gelu), store in C1.
   ActivationBatched(layer_weights->layer_config.activation, activations.C1);
 
   // Hidden layer -> output layer.
-  MatMul(activations.C1, layer_weights->vit.linear_1_w, output_bias,
-         *activations.env, RowPtrFromMat(activations.ffw_out));
+  MatMulStatic(activations.C1, layer_weights->vit.linear_1_w, output_bias,
+               *activations.env, RowPtrFromMat(activations.ffw_out));
 }
 
 // `batch_idx` indicates which row of `x` to write to.
@@ -1175,7 +1176,7 @@ HWY_NOINLINE void EmbedImagePatches(const Image& image,
   //   MatStorageT<float> image_patches("patches", Extents2D(kSeqLen,
   //   kPatchSize), MatPadding::kPacked);
   //   [Get patches]
-  //   MatMul(
+  //   MatMulStatic(
   //       MatFromBatch(kVitSeqLen, image_patches),
   //       MatFromWeights(weights.vit_img_embedding_kernel),
   //       weights.vit_img_embedding_bias.PackedScale1(), *activations.env,
@@ -1226,9 +1227,9 @@ HWY_NOINLINE void PrefillVit(const ModelWeightsPtrs& weights,
   }
 
   // Apply head embedding into image_tokens of size of the LLM kModelDim.
-  MatMul(activations.x, weights.vit_img_head_kernel,
-         weights.vit_img_head_bias.PackedScale1(), *activations.env,
-         RowPtrFromMat(image_tokens));
+  MatMulStatic(activations.x, weights.vit_img_head_kernel,
+               weights.vit_img_head_bias.PackedScale1(), *activations.env,
+               RowPtrFromMat(image_tokens));
 }
 
 // Generates one token for each query. `queries_token` is the previous token
@@ -1366,9 +1367,9 @@ bool DecodeStepT(const ModelConfig& config, const ModelWeightsPtrs& weights,
   {
     PROFILER_ZONE("Gen.EmbeddingMatmul");
     // Compute logits from last layer activations.
-    MatMul(activations.x, weights.embedder_input_embedding,
-           /*add=*/nullptr, *activations.env,
-           RowPtrFromMat(activations.logits));
+    MatMulStatic(activations.x, weights.embedder_input_embedding,
+                 /*add=*/nullptr, *activations.env,
+                 RowPtrFromMat(activations.logits));
   }
   PROFILER_ZONE("Gen.Softcap+Sample+Stream");
   for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {

diff --git a/gemma/instantiations/bf16.cc b/gemma/instantiations/bf16.cc
index 19ae585..2f001fb 100644
--- a/gemma/instantiations/bf16.cc
+++ b/gemma/instantiations/bf16.cc
@@ -15,6 +15,7 @@
 
 #undef HWY_TARGET_INCLUDE
 #define HWY_TARGET_INCLUDE "gemma/instantiations/bf16.cc"
+
 #include "hwy/foreach_target.h"  // IWYU pragma: keep
 #define GEMMA_TYPE hwy::bfloat16_t
 #include "gemma/gemma-inl.h"

diff --git a/ops/fp_arith-inl.h b/ops/fp_arith-inl.h
index 423b0fb..2abae43 100644
--- a/ops/fp_arith-inl.h
+++ b/ops/fp_arith-inl.h
@@ -136,7 +136,7 @@ hn::TFromD<DF> ReduceCascadedSums(DF df, const VF sum, VF sum_err) {
   using TF = hn::TFromD<DF>;
   // For non-scalable wide vectors, reduce loop iterations below by recursing
   // once or twice for halves of 256-bit or 512-bit vectors.
-  if constexpr (!HWY_HAVE_SCALABLE) {
+  if constexpr (HWY_HAVE_CONSTEXPR_LANES) {
     if constexpr (hn::Lanes(df) > 16 / sizeof(TF)) {
       const hn::Half<DF> dfh;
       using VFH = hn::Vec<decltype(dfh)>;
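Background on the fp_arith-inl.h fix: hn::Lanes(df) is a constant expression only when the target's vector length is known at compile time; on scalable targets (SVE, RISC-V V) it is a runtime query. Highway's HWY_HAVE_CONSTEXPR_LANES expresses exactly that guarantee, whereas !HWY_HAVE_SCALABLE evidently no longer implies it (the "Lanes not being constexpr" build error from the commit message). A sketch of the guard pattern with a hypothetical helper; the macros and ops are Highway's:

#include "hwy/highway.h"

namespace hn = hwy::HWY_NAMESPACE;

// Hypothetical function, shown only to illustrate the guard.
template <class DF>
hn::TFromD<DF> SumWithOptionalHalving(DF df, hn::Vec<DF> v) {
  if constexpr (HWY_HAVE_CONSTEXPR_LANES) {
    // Dependent condition: only instantiated when the outer branch is taken,
    // i.e. when Lanes() really is a constant expression on this target.
    if constexpr (hn::Lanes(df) > 16 / sizeof(hn::TFromD<DF>)) {
      const hn::Half<DF> dfh;
      return hn::ReduceSum(
          dfh, hn::Add(hn::LowerHalf(dfh, v), hn::UpperHalf(dfh, v)));
    }
  }
  return hn::ReduceSum(df, v);  // generic path, incl. scalable targets
}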
diff --git a/ops/matmul_static-inl.h b/ops/matmul_static-inl.h
new file mode 100644
index 0000000..da17c51
--- /dev/null
+++ b/ops/matmul_static-inl.h
@@ -0,0 +1,60 @@
+// Copyright 2025 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "hwy/detect_compiler_arch.h"  // HWY_IDE
+
+#ifndef GEMMA_MATMUL_TB
+#if HWY_IDE
+// Provide a definition so the IDE does not complain.
+#define GEMMA_MATMUL_TB float
+#else
+#error "Only include from matmul_static_*.cc, which define GEMMA_MATMUL_TB"
+#endif  // HWY_IDE
+#endif  // GEMMA_MATMUL_TB
+
+// Passed to GEMMA_MATMUL_FOREACH_AC; defines one overload for one target.
+#define GEMMA_MATMUL_DEFINE_ONE(TA, TB, TC)                              \
+  MMPerKey* MatMulStatic(const MatPtrT<TA>& A, const MatPtrT<TB>& B,     \
+                         const float* HWY_RESTRICT add, MatMulEnv& env,  \
+                         const RowPtr<TC>& C) {                          \
+    return MatMul(A, B, add, env, C);                                    \
+  }
+
+#if defined(THIRD_PARTY_GEMMA_CPP_GEMMA_OPS_MATMUL_STATIC_INL_H_) == \
+    defined(HWY_TARGET_TOGGLE)
+#ifdef THIRD_PARTY_GEMMA_CPP_GEMMA_OPS_MATMUL_STATIC_INL_H_
+#undef THIRD_PARTY_GEMMA_CPP_GEMMA_OPS_MATMUL_STATIC_INL_H_
+#else
+#define THIRD_PARTY_GEMMA_CPP_GEMMA_OPS_MATMUL_STATIC_INL_H_
+#endif
+
+#include "hwy/highway.h"
+// After highway.h
+#include "ops/matmul-inl.h"
+#include "ops/matmul_static.h"  // includes highway.h!
+
+HWY_BEFORE_NAMESPACE();
+namespace gcpp {
+namespace HWY_NAMESPACE {
+
+// Ignore warning that we are defining a function in a header; this is only
+// included from matmul_static_*.cc.
+GEMMA_MATMUL_FOREACH_AC(GEMMA_MATMUL_DEFINE_ONE, GEMMA_MATMUL_TB)  // NOLINT
+
+}  // namespace HWY_NAMESPACE
+}  // namespace gcpp
+HWY_AFTER_NAMESPACE();
+
+#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_OPS_MATMUL_STATIC_INL_H_
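To make the macro machinery concrete: when matmul_static_sfp.cc defines GEMMA_MATMUL_TB to SfpStream, GEMMA_MATMUL_FOREACH_AC (defined in matmul_static.h below) invokes GEMMA_MATMUL_DEFINE_ONE four times per target. One such expansion looks approximately like this, assuming the AVX3 compilation pass (so HWY_NAMESPACE is N_AVX3):

namespace gcpp {
namespace N_AVX3 {  // HWY_NAMESPACE during the AVX3 pass

MMPerKey* MatMulStatic(const MatPtrT<float>& A, const MatPtrT<SfpStream>& B,
                       const float* HWY_RESTRICT add, MatMulEnv& env,
                       const RowPtr<float>& C) {
  return MatMul(A, B, add, env, C);  // the kernel from ops/matmul-inl.h
}

}  // namespace N_AVX3
}  // namespace gcpp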
diff --git a/ops/matmul_static.h b/ops/matmul_static.h
new file mode 100644
index 0000000..e16d340
--- /dev/null
+++ b/ops/matmul_static.h
@@ -0,0 +1,59 @@
+// Copyright 2025 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef THIRD_PARTY_GEMMA_CPP_OPS_MATMUL_STATIC_H_
+#define THIRD_PARTY_GEMMA_CPP_OPS_MATMUL_STATIC_H_
+
+// Declares overloads of MatMulStatic for all SIMD targets and input types.
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include "ops/matmul.h"  // IWYU pragma: keep, b/420428845
+#include "hwy/highway.h"
+
+// Invokes GEMMA_X(TA, TB, TC) for all combinations of F32 or BF16.
+#define GEMMA_MATMUL_FOREACH_AC(GEMMA_X, TB) \
+  GEMMA_X(float, TB, float)                  \
+  GEMMA_X(float, TB, BF16)                   \
+  GEMMA_X(BF16, TB, float)                   \
+  GEMMA_X(BF16, TB, BF16)
+
+// Passed to GEMMA_MATMUL_FOREACH_AC; declares one overload for one target.
+#define GEMMA_MATMUL_DECL_ONE(TA, TB, TC)                                \
+  MMPerKey* MatMulStatic(const MatPtrT<TA>& A, const MatPtrT<TB>& B,     \
+                         const float* HWY_RESTRICT add, MatMulEnv& env,  \
+                         const RowPtr<TC>& C);
+
+// Passed to HWY_VISIT_TARGETS; declares all overloads for all targets.
+#define GEMMA_MATMUL_DECL(TARGET, NAMESPACE)                    \
+  namespace NAMESPACE {                                         \
+  GEMMA_MATMUL_FOREACH_AC(GEMMA_MATMUL_DECL_ONE, BF16)          \
+  GEMMA_MATMUL_FOREACH_AC(GEMMA_MATMUL_DECL_ONE, float)         \
+  GEMMA_MATMUL_FOREACH_AC(GEMMA_MATMUL_DECL_ONE, NuqStream)     \
+  GEMMA_MATMUL_FOREACH_AC(GEMMA_MATMUL_DECL_ONE, SfpStream)     \
+  /* NOLINTNEXTLINE(google-readability-namespace-comments) */   \
+  }  // namespace NAMESPACE
+
+namespace gcpp {
+
+// MatMul function declarations for each SIMD target. Allows direct call from
+// the per-target namespace. We may later replace this with dynamic dispatch if
+// the overhead is acceptable.
+HWY_VISIT_TARGETS(GEMMA_MATMUL_DECL)
+
+}  // namespace gcpp
+
+#endif  // THIRD_PARTY_GEMMA_CPP_OPS_MATMUL_STATIC_H_
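On the consumer side, HWY_VISIT_TARGETS stamps out the matching declarations once per enabled target, which is what lets code already compiled inside a target namespace call MatMulStatic with no dynamic-dispatch overhead. Approximately, for one of the 16 (TA, TB, TC) combinations and two assumed targets:

namespace gcpp {
namespace N_AVX3 {
MMPerKey* MatMulStatic(const MatPtrT<float>& A, const MatPtrT<SfpStream>& B,
                       const float* HWY_RESTRICT add, MatMulEnv& env,
                       const RowPtr<float>& C);
}  // namespace N_AVX3
namespace N_AVX2 {
MMPerKey* MatMulStatic(const MatPtrT<float>& A, const MatPtrT<SfpStream>& B,
                       const float* HWY_RESTRICT add, MatMulEnv& env,
                       const RowPtr<float>& C);
}  // namespace N_AVX2
// ... one namespace block per remaining enabled target ...
}  // namespace gcpp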
diff --git a/ops/matmul_static_bf16.cc b/ops/matmul_static_bf16.cc
new file mode 100644
index 0000000..02aa398
--- /dev/null
+++ b/ops/matmul_static_bf16.cc
@@ -0,0 +1,24 @@
+// Copyright 2025 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Compiles this file for multiple architectures via "foreach_target.h", to
+// which we pass the filename via macro 'argument'.
+// clang-format off
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "ops/matmul_static_bf16.cc"  // NOLINT
+// clang-format on
+#include "hwy/foreach_target.h"  // IWYU pragma: keep
+#define GEMMA_MATMUL_TB BF16
+#include "ops/matmul_static-inl.h"

diff --git a/ops/matmul_static_f32.cc b/ops/matmul_static_f32.cc
new file mode 100644
index 0000000..625e5b5
--- /dev/null
+++ b/ops/matmul_static_f32.cc
@@ -0,0 +1,24 @@
+// Copyright 2025 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Compiles this file for multiple architectures via "foreach_target.h", to
+// which we pass the filename via macro 'argument'.
+// clang-format off
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "ops/matmul_static_f32.cc"  // NOLINT
+// clang-format on
+#include "hwy/foreach_target.h"  // IWYU pragma: keep
+#define GEMMA_MATMUL_TB float
+#include "ops/matmul_static-inl.h"

diff --git a/ops/matmul_static_nuq.cc b/ops/matmul_static_nuq.cc
new file mode 100644
index 0000000..80d8481
--- /dev/null
+++ b/ops/matmul_static_nuq.cc
@@ -0,0 +1,24 @@
+// Copyright 2025 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Compiles this file for multiple architectures via "foreach_target.h", to
+// which we pass the filename via macro 'argument'.
+// clang-format off
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "ops/matmul_static_nuq.cc"  // NOLINT
+// clang-format on
+#include "hwy/foreach_target.h"  // IWYU pragma: keep
+#define GEMMA_MATMUL_TB NuqStream
+#include "ops/matmul_static-inl.h"

diff --git a/ops/matmul_static_sfp.cc b/ops/matmul_static_sfp.cc
new file mode 100644
index 0000000..c61fcb1
--- /dev/null
+++ b/ops/matmul_static_sfp.cc
@@ -0,0 +1,24 @@
+// Copyright 2025 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Compiles this file for multiple architectures via "foreach_target.h", to
+// which we pass the filename via macro 'argument'.
+// clang-format off
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "ops/matmul_static_sfp.cc"  // NOLINT
+// clang-format on
+#include "hwy/foreach_target.h"  // IWYU pragma: keep
+#define GEMMA_MATMUL_TB SfpStream
+#include "ops/matmul_static-inl.h"
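All four shards rely on Highway's re-inclusion trick, and sharding by the B (weight) type lets the four translation units compile in parallel; the BUILD comment above cites ~30 sec per file under msan as the motivation. A simplified view of the mechanism (illustrative; not Highway's literal implementation):

// In matmul_static_sfp.cc:
//   #define HWY_TARGET_INCLUDE "ops/matmul_static_sfp.cc"
//   #include "hwy/foreach_target.h"
// foreach_target.h then, for each target enabled at build time, roughly:
//   #undef HWY_TARGET
//   #define HWY_TARGET HWY_AVX3   // then HWY_AVX2, HWY_SSE4, ...
//   #include HWY_TARGET_INCLUDE   // re-includes the shard itself
// Each pass reaches matmul_static-inl.h, whose HWY_TARGET_TOGGLE guard
// re-opens the body so that GEMMA_MATMUL_FOREACH_AC defines the MatMulStatic
// overloads again, this time inside the new target's namespace.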
diff --git a/ops/matmul_test.cc b/ops/matmul_test.cc
index 0f5974b..69ecc6e 100644
--- a/ops/matmul_test.cc
+++ b/ops/matmul_test.cc
@@ -18,13 +18,18 @@
 #include "hwy/detect_compiler_arch.h"  // IWYU pragma: keep
 #ifndef HWY_DISABLED_TARGETS
 // Exclude HWY_SCALAR due to 2x bf16 -> f32, and Armv7 NEON because we require
-// double-precision support, and older x86 to speed up builds.
+// double-precision support.
 #if HWY_ARCH_ARM_V7
 #define HWY_DISABLED_TARGETS (HWY_SCALAR | HWY_NEON)
 #else
-#define HWY_DISABLED_TARGETS (HWY_SCALAR | HWY_SSSE3 | HWY_SSE4)
-#endif
-#endif
+#define HWY_DISABLED_TARGETS (HWY_SCALAR)
+#endif  // HWY_ARCH_ARM_V7
+#endif  // HWY_DISABLED_TARGETS
+// matmul_static is not built as a test, hence does not define MatMulStatic for
+// worse-than-baseline targets (to speed up builds), so we skip them here, too.
+#ifndef HWY_SKIP_NON_BEST_BASELINE
+#define HWY_SKIP_NON_BEST_BASELINE
+#endif  // HWY_SKIP_NON_BEST_BASELINE
 
 #include <stddef.h>
 #include <stdint.h>
@@ -35,6 +40,7 @@
 #include "util/mat.h"
 #include "util/threading_context.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
+#include "hwy/nanobenchmark.h"  // Unpredictable1
 
 // clang-format off
 #undef HWY_TARGET_INCLUDE
@@ -46,7 +52,7 @@
 #include "compression/compress-inl.h"
 #include "compression/test_util-inl.h"
 #include "ops/dot-inl.h"
-#include "ops/matmul-inl.h"
+#include "ops/matmul_static.h"  // also textual
 
 HWY_BEFORE_NAMESPACE();
 namespace gcpp {
@@ -234,7 +240,7 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
   MatMulSlow(a, b_trans, add_row, env, C_slow);
   // A few reps to get coverage of the various autotuned code paths.
   for (size_t rep = 0; rep < 16; ++rep) {
-    MMPerKey* per_key = MatMul(a, b_trans, add_row, env, C);
+    MMPerKey* per_key = MatMulStatic(a, b_trans, add_row, env, C);
     AssertClose(a, b_trans, C_slow, C, line);
     if (per_key->autotune.Best()) break;
   }
@@ -258,12 +264,12 @@ void TestTiny() {
     MatMulEnv env(ThreadingContext::Get());
     NestedPools& pools = env.ctx.pools;
 
-#if GEMMA_DISABLE_TOPOLOGY
-    if (max_packages == 2) break;  // we only have one package
-#else
-    // If less than the limit, we have already tested all num_packages.
-    if (env.ctx.topology.FullTopology().packages.size() < max_packages) break;
-#endif
+    if constexpr (GEMMA_DISABLE_TOPOLOGY) {
+      if (max_packages == 2) break;  // we only have one package
+    } else {
+      // If less than the limit, we have already tested all num_packages.
+      if (env.ctx.topology.FullTopology().packages.size() < max_packages) break;
+    }
 
     fprintf(stderr, "TestTiny %zu: %s %s\n", max_packages,
             env.ctx.topology.TopologyString(), pools.PinString());
@@ -282,8 +288,10 @@ void TestTiny() {
 
 void TestAllMatMul() {
   // Skip EMU128 (10x slower than SSE4 for SFP) and older x86.
-  if (HWY_TARGET == HWY_EMU128 || HWY_TARGET == HWY_SSE4 ||
-      HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE2) {
+  // Add Unpredictable1 to prevent erroneous "unreachable code" warning.
+  if (hwy::Unpredictable1() == 1 &&
+      (HWY_TARGET == HWY_EMU128 || HWY_TARGET == HWY_SSE4 ||
+       HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE2)) {
     return;
   }
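Regarding the Unpredictable1 change at the end: for any given target pass, the HWY_TARGET comparisons fold to a compile-time constant, and some compilers then emit an "unreachable code" warning for whichever half of the function is dead. hwy::Unpredictable1() returns 1 at runtime but is opaque to the optimizer, so the branch stays formally dynamic. A sketch with a hypothetical helper:

#include "hwy/highway.h"
#include "hwy/nanobenchmark.h"  // hwy::Unpredictable1

// Hypothetical; mirrors the guard in TestAllMatMul above. Without
// Unpredictable1(), the condition folds to a constant true or false,
// which can trigger the unreachable-code warning either way.
static bool SkipOldTarget() {
  return hwy::Unpredictable1() == 1 &&
         (HWY_TARGET == HWY_EMU128 || HWY_TARGET == HWY_SSE2);
}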