From 5cb63346aa7a328f7f0bc3505f78db56c1e5f80a Mon Sep 17 00:00:00 2001 From: Sam Kaufman Date: Mon, 29 Apr 2024 12:51:35 -0700 Subject: [PATCH] supports_eo -> kSupportsEvenOdd --- compression/compress-inl.h | 8 ++++---- gemma/ops.h | 8 ++++++-- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/compression/compress-inl.h b/compression/compress-inl.h index 631dbd4..7ae43b7 100644 --- a/compression/compress-inl.h +++ b/compression/compress-inl.h @@ -72,7 +72,7 @@ struct CompressTraits {}; template <> struct CompressTraits { using MatT = float; - static constexpr bool supports_eo = false; + static constexpr bool kSupportsEvenOdd = false; template static HWY_INLINE void Compress(DF df, const float* HWY_RESTRICT in, @@ -126,7 +126,7 @@ struct CompressTraits { template <> struct CompressTraits { using MatT = hwy::bfloat16_t; - static constexpr bool supports_eo = true; + static constexpr bool kSupportsEvenOdd = true; template static HWY_INLINE void Compress(DF df, const float* HWY_RESTRICT in, @@ -288,7 +288,7 @@ struct CompressTraits { template <> struct CompressTraits { using MatT = SfpStream; - static constexpr bool supports_eo = false; + static constexpr bool kSupportsEvenOdd = false; template static HWY_INLINE void Compress(DF df, const float* in, size_t num, @@ -338,7 +338,7 @@ struct CompressTraits { template <> struct CompressTraits { using MatT = NuqStream; - static constexpr bool supports_eo = false; + static constexpr bool kSupportsEvenOdd = false; template static HWY_INLINE void Compress(DF df, const float* in, size_t num, diff --git a/gemma/ops.h b/gemma/ops.h index 9fa79af..1c09409 100644 --- a/gemma/ops.h +++ b/gemma/ops.h @@ -341,7 +341,9 @@ HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs, // vector to even-odd layout. template ::supports_eo, bool> = true> + std::enable_if_t< + CompressTraits::kSupportsEvenOdd, bool> + = true> HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs, const float* HWY_RESTRICT const vec_aligned, const AddT* HWY_RESTRICT const add, @@ -378,7 +380,9 @@ HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs, // vector to even-odd layout. template ::supports_eo, bool> = true> + std::enable_if_t< + CompressTraits::kSupportsEvenOdd, bool> + = true> HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs, const hwy::bfloat16_t* HWY_RESTRICT const vec_aligned, const AddT* HWY_RESTRICT const add,