diff --git a/BUILD.bazel b/BUILD.bazel index 3c6ec5d..64c68f2 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -367,6 +367,7 @@ cc_test( ":test_util", ":threading_context", "@googletest//:gtest_main", # buildcleaner: keep + "//compression:types", "@highway//:hwy", "@highway//:hwy_test_util", "@highway//:nanobenchmark", #buildcleaner: keep @@ -430,7 +431,7 @@ cc_test( ], deps = [ ":basics", - ":ops", + ":matmul", ":threading_context", "@googletest//:gtest_main", # buildcleaner: keep "//compression:compress", @@ -500,10 +501,10 @@ cc_library( ":matmul", ":model_store", ":ops", - ":tokenizer", ":threading", ":threading_context", ":weights", + "//compression:types", "//io:blob_store", "//io", "//paligemma:image", @@ -521,6 +522,7 @@ cc_library( deps = [ ":gemma_lib", ":ops", + "//compression:types", "@highway//:hwy", ], ) diff --git a/compression/compress_test.cc b/compression/compress_test.cc index ee2db4c..2270689 100644 --- a/compression/compress_test.cc +++ b/compression/compress_test.cc @@ -13,10 +13,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -// SFP uses ConcatEven/Odd which are not supported; skip SVE for faster tests. +#include "compression/types.h" #ifndef HWY_DISABLED_TARGETS -#define HWY_DISABLED_TARGETS (HWY_SCALAR | HWY_SVE) -#endif +#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS +#endif // HWY_DISABLED_TARGETS #include "compression/compress.h" diff --git a/compression/nuq_test.cc b/compression/nuq_test.cc index a40624b..df300f4 100644 --- a/compression/nuq_test.cc +++ b/compression/nuq_test.cc @@ -13,10 +13,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -// SFP uses ConcatEven/Odd which are not supported; skip SVE for faster tests. +#include "compression/types.h" #ifndef HWY_DISABLED_TARGETS -#define HWY_DISABLED_TARGETS (HWY_SCALAR | HWY_SVE) -#endif +#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS +#endif // HWY_DISABLED_TARGETS #include #include @@ -27,7 +27,6 @@ #include #include "compression/distortion.h" -#include "compression/types.h" #include "util/test_util.h" #include "hwy/aligned_allocator.h" #include "hwy/base.h" diff --git a/compression/python/compression_clif_aux.cc b/compression/python/compression_clif_aux.cc index 5051a87..428e9cf 100644 --- a/compression/python/compression_clif_aux.cc +++ b/compression/python/compression_clif_aux.cc @@ -15,6 +15,11 @@ #include "compression/python/compression_clif_aux.h" +#include "compression/types.h" // GEMMA_DISABLED_TARGETS +#ifndef HWY_DISABLED_TARGETS +#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS +#endif // HWY_DISABLED_TARGETS + #include #include #include diff --git a/compression/sfp_test.cc b/compression/sfp_test.cc index 5d97caa..8e49ceb 100644 --- a/compression/sfp_test.cc +++ b/compression/sfp_test.cc @@ -13,10 +13,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -// We use ConcatEven/Odd which are not supported. Use HWY_EMU128 instead. +#include "compression/types.h" #ifndef HWY_DISABLED_TARGETS -#define HWY_DISABLED_TARGETS HWY_SCALAR -#endif +#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS +#endif // HWY_DISABLED_TARGETS #include #include @@ -25,7 +25,6 @@ #include #include "compression/distortion.h" -#include "compression/types.h" #include "util/test_util.h" #include "hwy/aligned_allocator.h" #include "hwy/base.h" diff --git a/compression/types.h b/compression/types.h index 015560e..dc10676 100644 --- a/compression/types.h +++ b/compression/types.h @@ -29,6 +29,30 @@ namespace gcpp { +// EMU128 must not be disabled because we disable SCALAR. +#define HWY_BROKEN_EMU128 0 + +// Allow user override of disabled targets. +#ifndef GEMMA_DISABLED_TARGETS + +// All platforms: exclude SCALAR because we use ReorderWidenMulAccumulate. + +#if HWY_ARCH_ARM_V7 +// No NEON because we require double-precision support. +#define HWY_DISABLED_TARGETS (HWY_SCALAR | HWY_ALL_NEON) +#elif HWY_ARCH_ARM_A64 +// We do not yet use AES (e.g. for random generation), hence NEON is the same +// as NEON_WITHOUT_AES. Also skip SVE because SVE2_128 and SVE_256 cover most. +#define GEMMA_DISABLED_TARGETS (HWY_SCALAR | HWY_NEON | HWY_SVE) +#elif HWY_ARCH_X86 +// Skip anything older than Haswell (2013); also use Zen4 for recent CPUs, +// because we do not use anything added by SPR (e.g. FP16) nor AVX 10.2. +#define GEMMA_DISABLED_TARGETS \ + (HWY_SCALAR | HWY_SSE2 | HWY_SSSE3 | HWY_SSE4 | HWY_AVX3_SPR | HWY_AVX10_2) +#endif // HWY_ARCH_* + +#endif // GEMMA_DISABLED_TARGETS + // Only used in experiments, hence disable in default builds. #ifndef GEMMA_ENABLE_NUQ #define GEMMA_ENABLE_NUQ 0 diff --git a/evals/cross_entropy.cc b/evals/cross_entropy.cc index 4c64f2e..e1a6ff4 100644 --- a/evals/cross_entropy.cc +++ b/evals/cross_entropy.cc @@ -13,6 +13,11 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "compression/types.h" // GEMMA_DISABLED_TARGETS +#ifndef HWY_DISABLED_TARGETS +#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS +#endif // HWY_DISABLED_TARGETS + // Compiles this file for multiple architectures via "foreach_target.h", to // which we pass the filename via macro 'argument'. // clang-format off diff --git a/gemma/attention.cc b/gemma/attention.cc index a26db39..ceb80ad 100644 --- a/gemma/attention.cc +++ b/gemma/attention.cc @@ -18,6 +18,11 @@ #include +#include "compression/types.h" // GEMMA_DISABLED_TARGETS +#ifndef HWY_DISABLED_TARGETS +#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS +#endif // HWY_DISABLED_TARGETS + #include "gemma/activations.h" #include "gemma/gemma.h" #include "gemma/weights.h" diff --git a/gemma/gemma.cc b/gemma/gemma.cc index d635dbb..b8a625d 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -18,6 +18,11 @@ #include "gemma/gemma.h" +#include "compression/types.h" // GEMMA_DISABLED_TARGETS +#ifndef HWY_DISABLED_TARGETS +#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS +#endif // HWY_DISABLED_TARGETS + // Compiles this file for multiple architectures via "foreach_target.h", to // which we pass the filename via macro 'argument'. // clang-format off diff --git a/gemma/griffin.cc b/gemma/griffin.cc index e7a3a75..35bf29a 100644 --- a/gemma/griffin.cc +++ b/gemma/griffin.cc @@ -16,6 +16,11 @@ #include #include +#include "compression/types.h" // GEMMA_DISABLED_TARGETS +#ifndef HWY_DISABLED_TARGETS +#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS +#endif // HWY_DISABLED_TARGETS + #include "gemma/activations.h" #include "gemma/gemma.h" #include "gemma/gemma_args.h" diff --git a/gemma/vit.cc b/gemma/vit.cc index ddbd963..0231a5f 100644 --- a/gemma/vit.cc +++ b/gemma/vit.cc @@ -19,6 +19,11 @@ #include +#include "compression/types.h" // GEMMA_DISABLED_TARGETS +#ifndef HWY_DISABLED_TARGETS +#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS +#endif // HWY_DISABLED_TARGETS + #include "gemma/activations.h" #include "gemma/gemma.h" #include "gemma/gemma_args.h" diff --git a/ops/bench_matmul.cc b/ops/bench_matmul.cc index b60a258..949f445 100644 --- a/ops/bench_matmul.cc +++ b/ops/bench_matmul.cc @@ -16,24 +16,17 @@ // Benchmark of large MatMul instances for which the MatMulSlow would be too // slow. This lacks a reference and is only useful for performance measurement. -#include "hwy/base.h" -#ifndef HWY_DISABLED_TARGETS -// Exclude HWY_SCALAR due to 2x bf16 -> f32, and Armv7 NEON because we require -// double-precision support. -#if HWY_ARCH_ARM_V7 -#define HWY_DISABLED_TARGETS (HWY_SCALAR | HWY_NEON) -#else -#define HWY_DISABLED_TARGETS HWY_SCALAR -#endif -#endif - #include #include #include #include -#include "compression/types.h" +#include "compression/types.h" // GEMMA_DISABLED_TARGETS +#ifndef HWY_DISABLED_TARGETS +#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS +#endif // HWY_DISABLED_TARGETS + #include "ops/matmul.h" #include "util/basics.h" #include "util/threading_context.h" diff --git a/ops/dot_test.cc b/ops/dot_test.cc index f210de5..4f0c94d 100644 --- a/ops/dot_test.cc +++ b/ops/dot_test.cc @@ -13,10 +13,10 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "compression/types.h" #ifndef HWY_DISABLED_TARGETS -// Exclude HWY_SCALAR due to 2x bf16 -> f32. -#define HWY_DISABLED_TARGETS HWY_SCALAR -#endif +#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS +#endif // HWY_DISABLED_TARGETS #include #include @@ -27,7 +27,6 @@ #include #include "compression/compress.h" -#include "compression/types.h" #include "util/allocator.h" #include "util/test_util.h" #include "util/threading_context.h" diff --git a/ops/gemma_matvec_test.cc b/ops/gemma_matvec_test.cc index 37fc505..0ca58b7 100644 --- a/ops/gemma_matvec_test.cc +++ b/ops/gemma_matvec_test.cc @@ -13,10 +13,10 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "compression/types.h" #ifndef HWY_DISABLED_TARGETS -// Exclude HWY_SCALAR due to 2x bf16 -> f32. -#define HWY_DISABLED_TARGETS HWY_SCALAR -#endif +#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS +#endif // HWY_DISABLED_TARGETS #include #include diff --git a/ops/matmul_static_bf16.cc b/ops/matmul_static_bf16.cc index 02aa398..f84d1d7 100644 --- a/ops/matmul_static_bf16.cc +++ b/ops/matmul_static_bf16.cc @@ -13,6 +13,11 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "compression/types.h" // GEMMA_DISABLED_TARGETS +#ifndef HWY_DISABLED_TARGETS +#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS +#endif // HWY_DISABLED_TARGETS + // Compiles this file for multiple architectures via "foreach_target.h", to // which we pass the filename via macro 'argument'. // clang-format off diff --git a/ops/matmul_static_f32.cc b/ops/matmul_static_f32.cc index 625e5b5..e749f53 100644 --- a/ops/matmul_static_f32.cc +++ b/ops/matmul_static_f32.cc @@ -13,6 +13,11 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "compression/types.h" // GEMMA_DISABLED_TARGETS +#ifndef HWY_DISABLED_TARGETS +#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS +#endif // HWY_DISABLED_TARGETS + // Compiles this file for multiple architectures via "foreach_target.h", to // which we pass the filename via macro 'argument'. // clang-format off diff --git a/ops/matmul_static_nuq.cc b/ops/matmul_static_nuq.cc index 80d8481..c9e75e7 100644 --- a/ops/matmul_static_nuq.cc +++ b/ops/matmul_static_nuq.cc @@ -13,6 +13,13 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "compression/types.h" // GEMMA_ENABLE_NUQ +#if GEMMA_ENABLE_NUQ + +#ifndef HWY_DISABLED_TARGETS +#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS +#endif // HWY_DISABLED_TARGETS + // Compiles this file for multiple architectures via "foreach_target.h", to // which we pass the filename via macro 'argument'. // clang-format off @@ -22,3 +29,5 @@ #include "hwy/foreach_target.h" // IWYU pragma: keep #define GEMMA_MATMUL_TB NuqStream #include "ops/matmul_static-inl.h" + +#endif // GEMMA_ENABLE_NUQ diff --git a/ops/matmul_static_sfp.cc b/ops/matmul_static_sfp.cc index c61fcb1..a2fca74 100644 --- a/ops/matmul_static_sfp.cc +++ b/ops/matmul_static_sfp.cc @@ -13,6 +13,11 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "compression/types.h" // GEMMA_DISABLED_TARGETS +#ifndef HWY_DISABLED_TARGETS +#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS +#endif // HWY_DISABLED_TARGETS + // Compiles this file for multiple architectures via "foreach_target.h", to // which we pass the filename via macro 'argument'. // clang-format off diff --git a/ops/matmul_test.cc b/ops/matmul_test.cc index 112576a..fb33fb4 100644 --- a/ops/matmul_test.cc +++ b/ops/matmul_test.cc @@ -15,16 +15,11 @@ // End to end test of MatMul, comparing against a reference implementation. -#include "hwy/detect_compiler_arch.h" // IWYU pragma: keep +#include "compression/types.h" #ifndef HWY_DISABLED_TARGETS -// Exclude HWY_SCALAR due to 2x bf16 -> f32, and Armv7 NEON because we require -// double-precision support. -#if HWY_ARCH_ARM_V7 -#define HWY_DISABLED_TARGETS (HWY_SCALAR | HWY_NEON) -#else -#define HWY_DISABLED_TARGETS (HWY_SCALAR) -#endif // HWY_ARCH_ARM_V7 +#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS #endif // HWY_DISABLED_TARGETS + // matmul_static is not built as a test, hence does not define MatMulStatic for // worse-than-baseline targets (to speed up builds), so we skip them here, too. #ifndef HWY_SKIP_NON_BEST_BASELINE @@ -34,7 +29,6 @@ #include #include -#include "compression/types.h" #include "ops/matmul.h" #include "util/basics.h" #include "util/mat.h" diff --git a/ops/ops_test.cc b/ops/ops_test.cc index a0ff314..c424d7a 100644 --- a/ops/ops_test.cc +++ b/ops/ops_test.cc @@ -13,10 +13,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -// OrderedDemote2To is not supported by HWY_SCALAR. +#include "compression/types.h" #ifndef HWY_DISABLED_TARGETS -#define HWY_DISABLED_TARGETS HWY_SCALAR -#endif +#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS +#endif // HWY_DISABLED_TARGETS #include "ops/ops.h"