mirror of https://github.com/google/gemma.cpp.git
supports_eo -> kSupportsEvenOdd
This commit is contained in:
parent
0816a1070d
commit
5cb63346aa
|
|
@ -72,7 +72,7 @@ struct CompressTraits {};
|
||||||
template <>
|
template <>
|
||||||
struct CompressTraits<float> {
|
struct CompressTraits<float> {
|
||||||
using MatT = float;
|
using MatT = float;
|
||||||
static constexpr bool supports_eo = false;
|
static constexpr bool kSupportsEvenOdd = false;
|
||||||
|
|
||||||
template <class DF, HWY_IF_F32_D(DF)>
|
template <class DF, HWY_IF_F32_D(DF)>
|
||||||
static HWY_INLINE void Compress(DF df, const float* HWY_RESTRICT in,
|
static HWY_INLINE void Compress(DF df, const float* HWY_RESTRICT in,
|
||||||
|
|
@ -126,7 +126,7 @@ struct CompressTraits<float> {
|
||||||
template <>
|
template <>
|
||||||
struct CompressTraits<hwy::bfloat16_t> {
|
struct CompressTraits<hwy::bfloat16_t> {
|
||||||
using MatT = hwy::bfloat16_t;
|
using MatT = hwy::bfloat16_t;
|
||||||
static constexpr bool supports_eo = true;
|
static constexpr bool kSupportsEvenOdd = true;
|
||||||
|
|
||||||
template <class DF, HWY_IF_F32_D(DF)>
|
template <class DF, HWY_IF_F32_D(DF)>
|
||||||
static HWY_INLINE void Compress(DF df, const float* HWY_RESTRICT in,
|
static HWY_INLINE void Compress(DF df, const float* HWY_RESTRICT in,
|
||||||
|
|
@ -288,7 +288,7 @@ struct CompressTraits<hwy::bfloat16_t> {
|
||||||
template <>
|
template <>
|
||||||
struct CompressTraits<SfpStream> {
|
struct CompressTraits<SfpStream> {
|
||||||
using MatT = SfpStream;
|
using MatT = SfpStream;
|
||||||
static constexpr bool supports_eo = false;
|
static constexpr bool kSupportsEvenOdd = false;
|
||||||
|
|
||||||
template <class DF, HWY_IF_F32_D(DF)>
|
template <class DF, HWY_IF_F32_D(DF)>
|
||||||
static HWY_INLINE void Compress(DF df, const float* in, size_t num,
|
static HWY_INLINE void Compress(DF df, const float* in, size_t num,
|
||||||
|
|
@ -338,7 +338,7 @@ struct CompressTraits<SfpStream> {
|
||||||
template <>
|
template <>
|
||||||
struct CompressTraits<NuqStream> {
|
struct CompressTraits<NuqStream> {
|
||||||
using MatT = NuqStream;
|
using MatT = NuqStream;
|
||||||
static constexpr bool supports_eo = false;
|
static constexpr bool kSupportsEvenOdd = false;
|
||||||
|
|
||||||
template <class DF, HWY_IF_F32_D(DF)>
|
template <class DF, HWY_IF_F32_D(DF)>
|
||||||
static HWY_INLINE void Compress(DF df, const float* in, size_t num,
|
static HWY_INLINE void Compress(DF df, const float* in, size_t num,
|
||||||
|
|
|
||||||
|
|
@ -341,7 +341,9 @@ HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs,
|
||||||
// vector to even-odd layout.
|
// vector to even-odd layout.
|
||||||
template <bool kAdd, size_t kOuter, size_t kInner, typename ArrayT,
|
template <bool kAdd, size_t kOuter, size_t kInner, typename ArrayT,
|
||||||
typename AddT,
|
typename AddT,
|
||||||
std::enable_if_t<CompressTraits<typename ArrayT::value_type>::supports_eo, bool> = true>
|
std::enable_if_t<
|
||||||
|
CompressTraits<typename ArrayT::value_type>::kSupportsEvenOdd, bool>
|
||||||
|
= true>
|
||||||
HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs,
|
HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs,
|
||||||
const float* HWY_RESTRICT const vec_aligned,
|
const float* HWY_RESTRICT const vec_aligned,
|
||||||
const AddT* HWY_RESTRICT const add,
|
const AddT* HWY_RESTRICT const add,
|
||||||
|
|
@ -378,7 +380,9 @@ HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs,
|
||||||
// vector to even-odd layout.
|
// vector to even-odd layout.
|
||||||
template <bool kAdd, size_t kOuter, size_t kInner, typename ArrayT,
|
template <bool kAdd, size_t kOuter, size_t kInner, typename ArrayT,
|
||||||
typename AddT,
|
typename AddT,
|
||||||
std::enable_if_t<CompressTraits<typename ArrayT::value_type>::supports_eo, bool> = true>
|
std::enable_if_t<
|
||||||
|
CompressTraits<typename ArrayT::value_type>::kSupportsEvenOdd, bool>
|
||||||
|
= true>
|
||||||
HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs,
|
HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs,
|
||||||
const hwy::bfloat16_t* HWY_RESTRICT const vec_aligned,
|
const hwy::bfloat16_t* HWY_RESTRICT const vec_aligned,
|
||||||
const AddT* HWY_RESTRICT const add,
|
const AddT* HWY_RESTRICT const add,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue