ggml-cpu/vec: rewrite to SIMD with 4 element calc ggml_vec_log_soft_max_f32 for all platforms

This commit is contained in:
Herman Semenoff 2025-12-18 00:52:14 +03:00
parent ea8264950c
commit be23f5f74a
No known key found for this signature in database
GPG Key ID: 1D2DC7BDC7225EF7
1 changed files with 94 additions and 1 deletions

View File

@ -634,10 +634,103 @@ ggml_float ggml_vec_log_soft_max_f32(const int n, float * y, const float * x, fl
int i = 0;
ggml_float sum = 0;
#if defined(__AVX512F__) && defined(__AVX512DQ__)
__m512 sum_v = _mm512_setzero_ps();
for (; i + 63 < n; i += 64) {
__m512 val1 = _mm512_sub_ps(_mm512_loadu_ps(x + i + 0), _mm512_set1_ps(max));
__m512 val2 = _mm512_sub_ps(_mm512_loadu_ps(x + i + 16), _mm512_set1_ps(max));
__m512 val3 = _mm512_sub_ps(_mm512_loadu_ps(x + i + 32), _mm512_set1_ps(max));
__m512 val4 = _mm512_sub_ps(_mm512_loadu_ps(x + i + 48), _mm512_set1_ps(max));
_mm512_storeu_ps(y + i + 0, val1);
_mm512_storeu_ps(y + i + 16, val2);
_mm512_storeu_ps(y + i + 32, val3);
_mm512_storeu_ps(y + i + 48, val4);
sum_v = _mm512_add_ps(sum_v, ggml_v_expf(val1));
sum_v = _mm512_add_ps(sum_v, ggml_v_expf(val2));
sum_v = _mm512_add_ps(sum_v, ggml_v_expf(val3));
sum_v = _mm512_add_ps(sum_v, ggml_v_expf(val4));
}
for (; i + 15 < n; i += 16) {
__m512 val = _mm512_sub_ps(_mm512_loadu_ps(x + i), _mm512_set1_ps(max));
_mm512_storeu_ps(y + i, val);
sum_v = _mm512_add_ps(sum_v, ggml_v_expf(val));
}
sum = (ggml_float)_mm512_reduce_add_ps(sum_v);
#elif defined(__AVX2__) && defined(__FMA__)
__m256 sum_v1 = _mm256_setzero_ps();
__m256 sum_v2 = _mm256_setzero_ps();
__m256 sum_v3 = _mm256_setzero_ps();
__m256 sum_v4 = _mm256_setzero_ps();
for (; i + 31 < n; i += 32) {
__m256 val1 = _mm256_sub_ps(_mm256_loadu_ps(x + i + 0), _mm256_set1_ps(max));
__m256 val2 = _mm256_sub_ps(_mm256_loadu_ps(x + i + 8), _mm256_set1_ps(max));
__m256 val3 = _mm256_sub_ps(_mm256_loadu_ps(x + i + 16), _mm256_set1_ps(max));
__m256 val4 = _mm256_sub_ps(_mm256_loadu_ps(x + i + 24), _mm256_set1_ps(max));
_mm256_storeu_ps(y + i + 0, val1);
_mm256_storeu_ps(y + i + 8, val2);
_mm256_storeu_ps(y + i + 16, val3);
_mm256_storeu_ps(y + i + 24, val4);
sum_v1 = _mm256_add_ps(sum_v1, ggml_v_expf(val1));
sum_v2 = _mm256_add_ps(sum_v2, ggml_v_expf(val2));
sum_v3 = _mm256_add_ps(sum_v3, ggml_v_expf(val3));
sum_v4 = _mm256_add_ps(sum_v4, ggml_v_expf(val4));
}
for (; i + 7 < n; i += 8) {
__m256 val = _mm256_sub_ps(_mm256_loadu_ps(x + i), _mm256_set1_ps(max));
_mm256_storeu_ps(y + i, val);
sum_v1 = _mm256_add_ps(sum_v1, ggml_v_expf(val));
}
sum_v1 = _mm256_add_ps(sum_v1, sum_v2);
sum_v1 = _mm256_add_ps(sum_v1, sum_v3);
sum_v1 = _mm256_add_ps(sum_v1, sum_v4);
__m128 val2 = _mm_add_ps(_mm256_extractf128_ps(sum_v1, 1), _mm256_castps256_ps128(sum_v1));
val2 = _mm_add_ps(val2, _mm_movehl_ps(val2, val2));
val2 = _mm_add_ss(val2, _mm_movehdup_ps(val2));
sum += (ggml_float)_mm_cvtss_f32(val2);
#elif defined(__SSE2__)
for (; i + 3 < n; i += 4) {
__m128 val = _mm_sub_ps(_mm_loadu_ps(x + i), _mm_set1_ps(max));
_mm_storeu_ps(y + i, val);
val = ggml_v_expf(val);
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
val = _mm_add_ps(val, _mm_movehl_ps(val, val));
val = _mm_add_ss(val, _mm_movehdup_ps(val));
#else
__m128 tmp = _mm_shuffle_ps(val, val, _MM_SHUFFLE(2, 3, 0, 1));
val = _mm_add_ps(val, tmp);
tmp = _mm_movehl_ps(tmp, val);
val = _mm_add_ss(val, tmp);
#endif
sum += (ggml_float)_mm_cvtss_f32(val);
}
#elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
const int vlen = svcntw();
for (; i < n; i += vlen) {
const svbool_t pg = svwhilelt_b32_s32(i, n);
svfloat32_t val = svsub_f32_x(pg, svld1_f32(pg, x + i), svdup_n_f32_x(pg, max));
svst1_f32(pg, y + i, val);
sum += (ggml_float)svaddv_f32(pg, ggml_v_expf(pg, val));
}
#elif defined(__ARM_NEON) && defined(__aarch64__)
for (; i + 3 < n; i += 4) {
float32x4_t val = vsubq_f32(vld1q_f32(x + i), vdupq_n_f32(max));
vst1q_f32(y + i, val);
sum += (ggml_float)vaddvq_f32(ggml_v_expf(val));
}
#elif defined(__riscv_v_intrinsic)
vfloat64m1_t vsum = __riscv_vfmv_v_f_f64m1(0, 1);
for (int avl; i < n; i += avl) {
avl = __riscv_vsetvl_e32m2(n - i);
vfloat32m2_t val = __riscv_vfsub_vf_f32m2(__riscv_vle32_v_f32m2(&x[i], avl), max, avl);
__riscv_vse32_v_f32m2(&y[i], val, avl);
vsum = __riscv_vfwredusum_vs_f32m2_f64m1(ggml_v_expf_m2(val, avl), vsum, avl);
}
sum = (ggml_float)__riscv_vfmv_f_s_f64m1_f64(vsum);
#endif
for (; i < n; ++i) {
float val = x[i] - max;
y[i] = val;
sum += (ggml_float)expf(val);
}
return sum = (ggml_float)logf(sum);
return (ggml_float)logf(sum);
}