ggml-cpu/vec: fix #15953 using multiple accum vectors for AVX2/AVX512

Reference: https://github.com/ggml-org/llama.cpp/pull/15953#pullrequestreview-3310928344
Herman Semenoff 2025-12-18 00:38:58 +03:00
parent d6742125c3
commit 61ca2e4d35
1 changed file with 43 additions and 9 deletions
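The change targets instruction-level parallelism: with a single accumulator register, every fused multiply-add must wait on the previous one, so the loop runs at FMA latency instead of FMA throughput. Splitting the sum across independent accumulators breaks that dependency chain. A minimal scalar sketch of the same idea follows (illustrative only, not part of the commit; the helper names are made up, and note that reassociating the additions can change floating-point rounding slightly):

#include <stddef.h>

/* One accumulator: each += depends on the previous result,
   so the additions serialize at the latency of a single FMA. */
float sumsq_single(const float * x, size_t n) {
    float s = 0.0f;
    for (size_t i = 0; i < n; i++) {
        s += x[i] * x[i];
    }
    return s;
}

/* Four accumulators: the chains carry no dependency on each other,
   so a pipelined FMA unit can keep several in flight at once. */
float sumsq_multi(const float * x, size_t n) {
    float s0 = 0.0f, s1 = 0.0f, s2 = 0.0f, s3 = 0.0f;
    size_t i = 0;
    for (; i + 3 < n; i += 4) {
        s0 += x[i + 0] * x[i + 0];
        s1 += x[i + 1] * x[i + 1];
        s2 += x[i + 2] * x[i + 2];
        s3 += x[i + 3] * x[i + 3];
    }
    for (; i < n; i++) { // scalar remainder
        s0 += x[i] * x[i];
    }
    return (s0 + s1) + (s2 + s3);
}

The committed code applies the same pattern with 256-bit vectors: four __m256 accumulators in the main loop, combined once at the end.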


@@ -414,27 +414,61 @@ void ggml_vec_swiglu_f32(const int n, float * y, const float * x, const float *
ggml_float ggml_vec_cvar_f32(const int n, float * y, const float * x, const float mean) {
    int i = 0;
    ggml_float sum = 0;
    // TODO: optimize to process the remaining elements in groups using the smaller vector sizes from AVX2 and SSE
    // ref: https://github.com/ggml-org/llama.cpp/pull/15953#pullrequestreview-3310928344
#if defined(__AVX512F__) && defined(__AVX512DQ__)
    __m512 sum_v = _mm512_setzero_ps();
    for (; i + 63 < n; i += 64) {
        __m512 val1 = _mm512_sub_ps(_mm512_loadu_ps(x + i + 0),  _mm512_set1_ps(mean));
        __m512 val2 = _mm512_sub_ps(_mm512_loadu_ps(x + i + 16), _mm512_set1_ps(mean));
        __m512 val3 = _mm512_sub_ps(_mm512_loadu_ps(x + i + 32), _mm512_set1_ps(mean));
        __m512 val4 = _mm512_sub_ps(_mm512_loadu_ps(x + i + 48), _mm512_set1_ps(mean));
        _mm512_storeu_ps(y + i + 0,  val1);
        _mm512_storeu_ps(y + i + 16, val2);
        _mm512_storeu_ps(y + i + 32, val3);
        _mm512_storeu_ps(y + i + 48, val4);
        sum_v = _mm512_fmadd_ps(val1, val1, sum_v);
        sum_v = _mm512_fmadd_ps(val2, val2, sum_v);
        sum_v = _mm512_fmadd_ps(val3, val3, sum_v);
        sum_v = _mm512_fmadd_ps(val4, val4, sum_v);
    }
    for (; i + 15 < n; i += 16) {
        __m512 val = _mm512_sub_ps(_mm512_loadu_ps(x + i),
                                   _mm512_set1_ps(mean));
        _mm512_storeu_ps(y + i, val);
        sum_v = _mm512_fmadd_ps(val, val, sum_v);
    }
    // single horizontal reduction after both loops
    sum += (ggml_float)_mm512_reduce_add_ps(sum_v);
#elif defined(__AVX2__) && defined(__FMA__)
    __m256 sum_v1 = _mm256_setzero_ps();
    __m256 sum_v2 = _mm256_setzero_ps();
    __m256 sum_v3 = _mm256_setzero_ps();
    __m256 sum_v4 = _mm256_setzero_ps();
    for (; i + 31 < n; i += 32) {
        __m256 val1 = _mm256_sub_ps(_mm256_loadu_ps(x + i + 0),  _mm256_set1_ps(mean));
        __m256 val2 = _mm256_sub_ps(_mm256_loadu_ps(x + i + 8),  _mm256_set1_ps(mean));
        __m256 val3 = _mm256_sub_ps(_mm256_loadu_ps(x + i + 16), _mm256_set1_ps(mean));
        __m256 val4 = _mm256_sub_ps(_mm256_loadu_ps(x + i + 24), _mm256_set1_ps(mean));
        _mm256_storeu_ps(y + i + 0,  val1);
        _mm256_storeu_ps(y + i + 8,  val2);
        _mm256_storeu_ps(y + i + 16, val3);
        _mm256_storeu_ps(y + i + 24, val4);
        // four independent accumulators break the FMA dependency chain
        sum_v1 = _mm256_fmadd_ps(val1, val1, sum_v1);
        sum_v2 = _mm256_fmadd_ps(val2, val2, sum_v2);
        sum_v3 = _mm256_fmadd_ps(val3, val3, sum_v3);
        sum_v4 = _mm256_fmadd_ps(val4, val4, sum_v4);
    }
    for (; i + 7 < n; i += 8) {
        __m256 val = _mm256_sub_ps(_mm256_loadu_ps(x + i),
                                   _mm256_set1_ps(mean));
        _mm256_storeu_ps(y + i, val);
        sum_v1 = _mm256_fmadd_ps(val, val, sum_v1);
    }
    // combine the four partial sums, then reduce the 8 lanes to a scalar
    sum_v1 = _mm256_add_ps(sum_v1, sum_v2);
    sum_v1 = _mm256_add_ps(sum_v1, sum_v3);
    sum_v1 = _mm256_add_ps(sum_v1, sum_v4);
    __m128 val2 = _mm_add_ps(_mm256_extractf128_ps(sum_v1, 1), _mm256_castps256_ps128(sum_v1));
    val2 = _mm_add_ps(val2, _mm_movehl_ps(val2, val2));
    val2 = _mm_add_ss(val2, _mm_movehdup_ps(val2));
    sum += (ggml_float)_mm_cvtss_f32(val2);
#elif defined(__SSE2__)
    for (; i + 3 < n; i += 4) {
        __m128 val = _mm_sub_ps(_mm_loadu_ps(x + i),