diff --git a/ggml/src/ggml-cpu/simd-mappings.h b/ggml/src/ggml-cpu/simd-mappings.h
index 630e506542..22de55700d 100644
--- a/ggml/src/ggml-cpu/simd-mappings.h
+++ b/ggml/src/ggml-cpu/simd-mappings.h
@@ -1160,6 +1160,14 @@ static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) {
     float32x4_t tmp = x[0] + vec_reve(x[0]); \
     res = tmp[0] + tmp[1]; \
 }
+#define GGML_F32x4_REDUCE_4(res, s0, s1, s2, s3) /* adds the sum of all lanes of s0..s3 to res */ \
+{ \
+    float32x4_t v = vec_add(vec_add(s0, s1), /* lane-wise s0 + s1 + s2 + s3 */ \
+                            vec_add(s2, s3)); \
+    v = vec_add(v, vec_sld(v, v, 8)); /* add v to itself rotated by 8 bytes (two f32 lanes) */ \
+    v = vec_add(v, vec_sld(v, v, 4)); /* and again by 4 bytes: lane 0 now holds the sum of all four lanes */ \
+    res += (ggml_float)vec_extract(v, 0); /* accumulates into res; res is not reset here */ \
+}
 
 #define GGML_F32_VEC GGML_F32x4
 #define GGML_F32_VEC_ZERO GGML_F32x4_ZERO
@@ -1209,6 +1217,24 @@ static inline void __lzs_f16cx4_store(ggml_fp16_t * x, float32x4_t v_y) {
 #define GGML_F16_VEC_MUL GGML_F32x4_MUL
 #define GGML_F16_VEC_REDUCE GGML_F32x4_REDUCE
 
+// BF16 s390x: bf16 data is loaded as raw 16-bit lanes and widened to f32 by interleaving with zero lanes
+#define GGML_BF16_STEP 16 // ggml_vec_dot_bf16 computes np = n & ~(GGML_BF16_STEP - 1), so this must stay a power of two
+#define GGML_BF16_EPR 8 // presumably "elements per register" (8 x 16-bit lanes) - TODO confirm
+
+#define GGML_BF16x8 __vector unsigned short // bf16 bit patterns held as unsigned 16-bit lanes
+#define GGML_BF16x8_ZERO vec_splats((unsigned short)0) // every lane set to 0
+#define GGML_BF16x8_LOAD(p) vec_xl(0, (const unsigned short *)(p)) // load from p at byte offset 0, viewing p as unsigned short data
+
+#define GGML_BF16_VEC GGML_BF16x8 // architecture-neutral aliases for the BF16x8 macros above
+#define GGML_BF16_VEC_ZERO GGML_BF16x8_ZERO
+#define GGML_BF16_VEC_LOAD GGML_BF16x8_LOAD
+#define GGML_BF16_TO_F32_LO(v) ((float32x4_t) vec_mergel((v), GGML_BF16_VEC_ZERO)) // interleave v's mergel half with zero lanes (v lane first), then reinterpret each 32-bit pair as f32
+#define GGML_BF16_TO_F32_HI(v) ((float32x4_t) vec_mergeh((v), GGML_BF16_VEC_ZERO)) // same for the mergeh half; NOTE(review): operand order assumes the first 16-bit lane is the high half of the f32 (big-endian) - confirm
+#define GGML_BF16_FMA_LO(acc, x, y) /* statement macro: assigns to acc, so acc must be an lvalue */ \
+    (acc) = GGML_F32x4_FMA((acc), GGML_BF16_TO_F32_LO(x), GGML_BF16_TO_F32_LO(y)) // GGML_F32x4_FMA is defined outside this hunk; presumably acc + x*y - verify
+#define GGML_BF16_FMA_HI(acc, x, y) /* as GGML_BF16_FMA_LO, but on the mergeh half of x and y */ \
+    (acc) = GGML_F32x4_FMA((acc), GGML_BF16_TO_F32_HI(x), GGML_BF16_TO_F32_HI(y))
+
 #elif defined(__riscv_v_intrinsic)
 
 // compatible with vlen >= 128
diff --git a/ggml/src/ggml-cpu/vec.cpp b/ggml/src/ggml-cpu/vec.cpp
index 8708cd4e92..d0e4001338 100644
--- a/ggml/src/ggml-cpu/vec.cpp
+++ b/ggml/src/ggml-cpu/vec.cpp
@@ -236,8 +236,7 @@ void ggml_vec_dot_bf16(int n, float * GGML_RESTRICT s, size_t bs, ggml_bf16_t *
     vfloat32m1_t redsum =
     __riscv_vfredusum_vs_f32m4_f32m1(vsum0, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl);
     sumf += __riscv_vfmv_f_s_f32m1_f32(redsum);
-#endif
-#if defined(__POWER9_VECTOR__)
+#elif defined(__POWER9_VECTOR__) || defined(__VXE__) || defined(__VXE2__) // now a branch of the preceding #if chain; builds defining __VXE__ or __VXE2__ take this path too
     const int np = (n & ~(GGML_BF16_STEP - 1));
     if (np > 0) {
         GGML_F32_VEC sum[4] = {GGML_F32_VEC_ZERO};