diff --git a/ggml/src/ggml-cpu/vec.cpp b/ggml/src/ggml-cpu/vec.cpp
index be1e3a326e..5aad3305e9 100644
--- a/ggml/src/ggml-cpu/vec.cpp
+++ b/ggml/src/ggml-cpu/vec.cpp
@@ -525,23 +525,60 @@ ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float
     int i = 0;
     ggml_float sum = 0;
 #if defined(__AVX512F__) && defined(__AVX512DQ__)
+    __m512 sum_v = _mm512_setzero_ps();
+    for (; i + 63 < n; i += 64) {
+        __m512 val1 = ggml_v_expf(_mm512_sub_ps(_mm512_loadu_ps(x + i + 0), _mm512_set1_ps(max)));
+        __m512 val2 = ggml_v_expf(_mm512_sub_ps(_mm512_loadu_ps(x + i + 16), _mm512_set1_ps(max)));
+        __m512 val3 = ggml_v_expf(_mm512_sub_ps(_mm512_loadu_ps(x + i + 32), _mm512_set1_ps(max)));
+        __m512 val4 = ggml_v_expf(_mm512_sub_ps(_mm512_loadu_ps(x + i + 48), _mm512_set1_ps(max)));
+        _mm512_storeu_ps(y + i + 0, val1);
+        _mm512_storeu_ps(y + i + 16, val2);
+        _mm512_storeu_ps(y + i + 32, val3);
+        _mm512_storeu_ps(y + i + 48, val4);
+        sum_v = _mm512_add_ps(sum_v, val1);
+        sum_v = _mm512_add_ps(sum_v, val2);
+        sum_v = _mm512_add_ps(sum_v, val3);
+        sum_v = _mm512_add_ps(sum_v, val4);
+    }
     for (; i + 15 < n; i += 16) {
         __m512 val = ggml_v_expf(_mm512_sub_ps(_mm512_loadu_ps(x + i),
                                                _mm512_set1_ps(max)));
         _mm512_storeu_ps(y + i, val);
-        sum += (ggml_float)_mm512_reduce_add_ps(val);
+        sum_v = _mm512_add_ps(sum_v, val);
     }
+    sum += (ggml_float)_mm512_reduce_add_ps(sum_v);
 #elif defined(__AVX2__) && defined(__FMA__)
+    __m256 sum_v1 = _mm256_setzero_ps();
+    __m256 sum_v2 = _mm256_setzero_ps();
+    __m256 sum_v3 = _mm256_setzero_ps();
+    __m256 sum_v4 = _mm256_setzero_ps();
+    for (; i + 31 < n; i += 32) {
+        __m256 val1 = ggml_v_expf(_mm256_sub_ps(_mm256_loadu_ps(x + i + 0), _mm256_set1_ps(max)));
+        __m256 val2 = ggml_v_expf(_mm256_sub_ps(_mm256_loadu_ps(x + i + 8), _mm256_set1_ps(max)));
+        __m256 val3 = ggml_v_expf(_mm256_sub_ps(_mm256_loadu_ps(x + i + 16), _mm256_set1_ps(max)));
+        __m256 val4 = ggml_v_expf(_mm256_sub_ps(_mm256_loadu_ps(x + i + 24), _mm256_set1_ps(max)));
+        _mm256_storeu_ps(y + i + 0, val1);
+        _mm256_storeu_ps(y + i + 8, val2);
+        _mm256_storeu_ps(y + i + 16, val3);
+        _mm256_storeu_ps(y + i + 24, val4);
+        sum_v1 = _mm256_add_ps(sum_v1, val1);
+        sum_v2 = _mm256_add_ps(sum_v2, val2);
+        sum_v3 = _mm256_add_ps(sum_v3, val3);
+        sum_v4 = _mm256_add_ps(sum_v4, val4);
+    }
     for (; i + 7 < n; i += 8) {
         __m256 val = ggml_v_expf(_mm256_sub_ps(_mm256_loadu_ps(x + i),
                                                _mm256_set1_ps(max)));
         _mm256_storeu_ps(y + i, val);
-        __m128 val2 = _mm_add_ps(_mm256_extractf128_ps(val, 1),
-                                 _mm256_castps256_ps128(val));
-        val2 = _mm_add_ps(val2, _mm_movehl_ps(val2, val2));
-        val2 = _mm_add_ss(val2, _mm_movehdup_ps(val2));
-        sum += (ggml_float)_mm_cvtss_f32(val2);
+        sum_v1 = _mm256_add_ps(sum_v1, val);
     }
+    sum_v1 = _mm256_add_ps(sum_v1, sum_v2);
+    sum_v1 = _mm256_add_ps(sum_v1, sum_v3);
+    sum_v1 = _mm256_add_ps(sum_v1, sum_v4);
+    __m128 val2 = _mm_add_ps(_mm256_extractf128_ps(sum_v1, 1), _mm256_castps256_ps128(sum_v1));
+    val2 = _mm_add_ps(val2, _mm_movehl_ps(val2, val2));
+    val2 = _mm_add_ss(val2, _mm_movehdup_ps(val2));
+    sum += (ggml_float)_mm_cvtss_f32(val2);
 #elif defined(__SSE2__)
     for (; i + 3 < n; i += 4) {
         __m128 val = ggml_v_expf(_mm_sub_ps(_mm_loadu_ps(x + i),