Herman Semenoff 2026-01-03 00:19:50 +08:00 committed by GitHub
commit 2c8007993f
1 changed file with 180 additions and 16 deletions

@@ -454,27 +454,61 @@ void ggml_vec_swiglu_f32(const int n, float * y, const float * x, const float *
ggml_float ggml_vec_cvar_f32(const int n, float * y, const float * x, const float mean) {
int i = 0;
ggml_float sum = 0;
// TODO: optimize to process the remaining elements in groups using the smaller vector sizes from AVX2 and SSE
// ref: https://github.com/ggml-org/llama.cpp/pull/15953#pullrequestreview-3310928344
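// a possible shape for that TODO (sketch, not in this change): fall through to
// progressively narrower tail loops before the scalar remainder, e.g.
//     for (; i + 7 < n; i += 8) { /* __m256 tail */ }
//     for (; i + 3 < n; i += 4) { /* __m128 tail */ }
// so the scalar loop handles at most 3 elements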
#if defined(__AVX512F__) && defined(__AVX512DQ__)
__m512 sum_v = _mm512_setzero_ps();
for (; i + 63 < n; i += 64) {
__m512 val1 = _mm512_sub_ps(_mm512_loadu_ps(x + i + 0), _mm512_set1_ps(mean));
__m512 val2 = _mm512_sub_ps(_mm512_loadu_ps(x + i + 16), _mm512_set1_ps(mean));
__m512 val3 = _mm512_sub_ps(_mm512_loadu_ps(x + i + 32), _mm512_set1_ps(mean));
__m512 val4 = _mm512_sub_ps(_mm512_loadu_ps(x + i + 48), _mm512_set1_ps(mean));
_mm512_storeu_ps(y + i + 0, val1);
_mm512_storeu_ps(y + i + 16, val2);
_mm512_storeu_ps(y + i + 32, val3);
_mm512_storeu_ps(y + i + 48, val4);
sum_v = _mm512_fmadd_ps(val1, val1, sum_v);
sum_v = _mm512_fmadd_ps(val2, val2, sum_v);
sum_v = _mm512_fmadd_ps(val3, val3, sum_v);
sum_v = _mm512_fmadd_ps(val4, val4, sum_v);
}
for (; i + 15 < n; i += 16) {
__m512 val = _mm512_sub_ps(_mm512_loadu_ps(x + i),
_mm512_set1_ps(mean));
_mm512_storeu_ps(y + i, val);
- sum += (ggml_float)_mm512_reduce_add_ps(_mm512_mul_ps(val, val));
+ sum_v = _mm512_fmadd_ps(val, val, sum_v);
}
sum += (ggml_float)_mm512_reduce_add_ps(sum_v);
#elif defined(__AVX2__) && defined(__FMA__)
__m256 sum_v1 = _mm256_setzero_ps();
__m256 sum_v2 = _mm256_setzero_ps();
__m256 sum_v3 = _mm256_setzero_ps();
__m256 sum_v4 = _mm256_setzero_ps();
for (; i + 31 < n; i += 32) {
__m256 val1 = _mm256_sub_ps(_mm256_loadu_ps(x + i + 0), _mm256_set1_ps(mean));
__m256 val2 = _mm256_sub_ps(_mm256_loadu_ps(x + i + 8), _mm256_set1_ps(mean));
__m256 val3 = _mm256_sub_ps(_mm256_loadu_ps(x + i + 16), _mm256_set1_ps(mean));
__m256 val4 = _mm256_sub_ps(_mm256_loadu_ps(x + i + 24), _mm256_set1_ps(mean));
_mm256_storeu_ps(y + i + 0, val1);
_mm256_storeu_ps(y + i + 8, val2);
_mm256_storeu_ps(y + i + 16, val3);
_mm256_storeu_ps(y + i + 24, val4);
sum_v1 = _mm256_fmadd_ps(val1, val1, sum_v1);
sum_v2 = _mm256_fmadd_ps(val2, val2, sum_v2);
sum_v3 = _mm256_fmadd_ps(val3, val3, sum_v3);
sum_v4 = _mm256_fmadd_ps(val4, val4, sum_v4);
}
for (; i + 7 < n; i += 8) {
__m256 val = _mm256_sub_ps(_mm256_loadu_ps(x + i),
_mm256_set1_ps(mean));
_mm256_storeu_ps(y + i, val);
- val = _mm256_mul_ps(val,val);
- __m128 val2 = _mm_add_ps(_mm256_extractf128_ps(val, 1),
- _mm256_castps256_ps128(val));
- val2 = _mm_add_ps(val2, _mm_movehl_ps(val2, val2));
- val2 = _mm_add_ss(val2, _mm_movehdup_ps(val2));
- sum += (ggml_float)_mm_cvtss_f32(val2);
+ sum_v1 = _mm256_fmadd_ps(val, val, sum_v1);
}
sum_v1 = _mm256_add_ps(sum_v1, sum_v2);
sum_v1 = _mm256_add_ps(sum_v1, sum_v3);
sum_v1 = _mm256_add_ps(sum_v1, sum_v4);
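// horizontal add of the combined accumulator: 8 lanes -> 4 -> 2 -> 1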
__m128 val2 = _mm_add_ps(_mm256_extractf128_ps(sum_v1, 1), _mm256_castps256_ps128(sum_v1));
val2 = _mm_add_ps(val2, _mm_movehl_ps(val2, val2));
val2 = _mm_add_ss(val2, _mm_movehdup_ps(val2));
sum += (ggml_float)_mm_cvtss_f32(val2);
#elif defined(__SSE2__)
for (; i + 3 < n; i += 4) {
__m128 val = _mm_sub_ps(_mm_loadu_ps(x + i),
@@ -531,23 +565,60 @@ ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float
int i = 0;
ggml_float sum = 0;
#if defined(__AVX512F__) && defined(__AVX512DQ__)
__m512 sum_v = _mm512_setzero_ps();
for (; i + 63 < n; i += 64) {
__m512 val1 = ggml_v_expf(_mm512_sub_ps(_mm512_loadu_ps(x + i + 0), _mm512_set1_ps(max)));
__m512 val2 = ggml_v_expf(_mm512_sub_ps(_mm512_loadu_ps(x + i + 16), _mm512_set1_ps(max)));
__m512 val3 = ggml_v_expf(_mm512_sub_ps(_mm512_loadu_ps(x + i + 32), _mm512_set1_ps(max)));
__m512 val4 = ggml_v_expf(_mm512_sub_ps(_mm512_loadu_ps(x + i + 48), _mm512_set1_ps(max)));
_mm512_storeu_ps(y + i + 0, val1);
_mm512_storeu_ps(y + i + 16, val2);
_mm512_storeu_ps(y + i + 32, val3);
_mm512_storeu_ps(y + i + 48, val4);
sum_v = _mm512_add_ps(sum_v, val1);
sum_v = _mm512_add_ps(sum_v, val2);
sum_v = _mm512_add_ps(sum_v, val3);
sum_v = _mm512_add_ps(sum_v, val4);
}
for (; i + 15 < n; i += 16) {
__m512 val = ggml_v_expf(_mm512_sub_ps(_mm512_loadu_ps(x + i),
_mm512_set1_ps(max)));
_mm512_storeu_ps(y + i, val);
- sum += (ggml_float)_mm512_reduce_add_ps(val);
+ sum_v = _mm512_add_ps(sum_v, val);
}
sum += (ggml_float)_mm512_reduce_add_ps(sum_v);
#elif defined(__AVX2__) && defined(__FMA__)
__m256 sum_v1 = _mm256_setzero_ps();
__m256 sum_v2 = _mm256_setzero_ps();
__m256 sum_v3 = _mm256_setzero_ps();
__m256 sum_v4 = _mm256_setzero_ps();
for (; i + 31 < n; i += 32) {
__m256 val1 = ggml_v_expf(_mm256_sub_ps(_mm256_loadu_ps(x + i + 0), _mm256_set1_ps(max)));
__m256 val2 = ggml_v_expf(_mm256_sub_ps(_mm256_loadu_ps(x + i + 8), _mm256_set1_ps(max)));
__m256 val3 = ggml_v_expf(_mm256_sub_ps(_mm256_loadu_ps(x + i + 16), _mm256_set1_ps(max)));
__m256 val4 = ggml_v_expf(_mm256_sub_ps(_mm256_loadu_ps(x + i + 24), _mm256_set1_ps(max)));
_mm256_storeu_ps(y + i + 0, val1);
_mm256_storeu_ps(y + i + 8, val2);
_mm256_storeu_ps(y + i + 16, val3);
_mm256_storeu_ps(y + i + 24, val4);
sum_v1 = _mm256_add_ps(sum_v1, val1);
sum_v2 = _mm256_add_ps(sum_v2, val2);
sum_v3 = _mm256_add_ps(sum_v3, val3);
sum_v4 = _mm256_add_ps(sum_v4, val4);
}
for (; i + 7 < n; i += 8) {
__m256 val = ggml_v_expf(_mm256_sub_ps(_mm256_loadu_ps(x + i),
_mm256_set1_ps(max)));
_mm256_storeu_ps(y + i, val);
- __m128 val2 = _mm_add_ps(_mm256_extractf128_ps(val, 1),
- _mm256_castps256_ps128(val));
- val2 = _mm_add_ps(val2, _mm_movehl_ps(val2, val2));
- val2 = _mm_add_ss(val2, _mm_movehdup_ps(val2));
- sum += (ggml_float)_mm_cvtss_f32(val2);
+ sum_v1 = _mm256_add_ps(sum_v1, val);
}
sum_v1 = _mm256_add_ps(sum_v1, sum_v2);
sum_v1 = _mm256_add_ps(sum_v1, sum_v3);
sum_v1 = _mm256_add_ps(sum_v1, sum_v4);
__m128 val2 = _mm_add_ps(_mm256_extractf128_ps(sum_v1, 1), _mm256_castps256_ps128(sum_v1));
val2 = _mm_add_ps(val2, _mm_movehl_ps(val2, val2));
val2 = _mm_add_ss(val2, _mm_movehdup_ps(val2));
sum += (ggml_float)_mm_cvtss_f32(val2);
#elif defined(__SSE2__)
for (; i + 3 < n; i += 4) {
__m128 val = ggml_v_expf(_mm_sub_ps(_mm_loadu_ps(x + i),
@@ -603,10 +674,103 @@ ggml_float ggml_vec_log_soft_max_f32(const int n, float * y, const float * x, fl
int i = 0;
ggml_float sum = 0;
#if defined(__AVX512F__) && defined(__AVX512DQ__)
__m512 sum_v = _mm512_setzero_ps();
for (; i + 63 < n; i += 64) {
__m512 val1 = _mm512_sub_ps(_mm512_loadu_ps(x + i + 0), _mm512_set1_ps(max));
__m512 val2 = _mm512_sub_ps(_mm512_loadu_ps(x + i + 16), _mm512_set1_ps(max));
__m512 val3 = _mm512_sub_ps(_mm512_loadu_ps(x + i + 32), _mm512_set1_ps(max));
__m512 val4 = _mm512_sub_ps(_mm512_loadu_ps(x + i + 48), _mm512_set1_ps(max));
_mm512_storeu_ps(y + i + 0, val1);
_mm512_storeu_ps(y + i + 16, val2);
_mm512_storeu_ps(y + i + 32, val3);
_mm512_storeu_ps(y + i + 48, val4);
sum_v = _mm512_add_ps(sum_v, ggml_v_expf(val1));
sum_v = _mm512_add_ps(sum_v, ggml_v_expf(val2));
sum_v = _mm512_add_ps(sum_v, ggml_v_expf(val3));
sum_v = _mm512_add_ps(sum_v, ggml_v_expf(val4));
}
for (; i + 15 < n; i += 16) {
__m512 val = _mm512_sub_ps(_mm512_loadu_ps(x + i), _mm512_set1_ps(max));
_mm512_storeu_ps(y + i, val);
sum_v = _mm512_add_ps(sum_v, ggml_v_expf(val));
}
sum = (ggml_float)_mm512_reduce_add_ps(sum_v);
#elif defined(__AVX2__) && defined(__FMA__)
__m256 sum_v1 = _mm256_setzero_ps();
__m256 sum_v2 = _mm256_setzero_ps();
__m256 sum_v3 = _mm256_setzero_ps();
__m256 sum_v4 = _mm256_setzero_ps();
for (; i + 31 < n; i += 32) {
__m256 val1 = _mm256_sub_ps(_mm256_loadu_ps(x + i + 0), _mm256_set1_ps(max));
__m256 val2 = _mm256_sub_ps(_mm256_loadu_ps(x + i + 8), _mm256_set1_ps(max));
__m256 val3 = _mm256_sub_ps(_mm256_loadu_ps(x + i + 16), _mm256_set1_ps(max));
__m256 val4 = _mm256_sub_ps(_mm256_loadu_ps(x + i + 24), _mm256_set1_ps(max));
_mm256_storeu_ps(y + i + 0, val1);
_mm256_storeu_ps(y + i + 8, val2);
_mm256_storeu_ps(y + i + 16, val3);
_mm256_storeu_ps(y + i + 24, val4);
sum_v1 = _mm256_add_ps(sum_v1, ggml_v_expf(val1));
sum_v2 = _mm256_add_ps(sum_v2, ggml_v_expf(val2));
sum_v3 = _mm256_add_ps(sum_v3, ggml_v_expf(val3));
sum_v4 = _mm256_add_ps(sum_v4, ggml_v_expf(val4));
}
for (; i + 7 < n; i += 8) {
__m256 val = _mm256_sub_ps(_mm256_loadu_ps(x + i), _mm256_set1_ps(max));
_mm256_storeu_ps(y + i, val);
sum_v1 = _mm256_add_ps(sum_v1, ggml_v_expf(val));
}
sum_v1 = _mm256_add_ps(sum_v1, sum_v2);
sum_v1 = _mm256_add_ps(sum_v1, sum_v3);
sum_v1 = _mm256_add_ps(sum_v1, sum_v4);
__m128 val2 = _mm_add_ps(_mm256_extractf128_ps(sum_v1, 1), _mm256_castps256_ps128(sum_v1));
val2 = _mm_add_ps(val2, _mm_movehl_ps(val2, val2));
val2 = _mm_add_ss(val2, _mm_movehdup_ps(val2));
sum += (ggml_float)_mm_cvtss_f32(val2);
#elif defined(__SSE2__)
for (; i + 3 < n; i += 4) {
__m128 val = _mm_sub_ps(_mm_loadu_ps(x + i), _mm_set1_ps(max));
_mm_storeu_ps(y + i, val);
val = ggml_v_expf(val);
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
val = _mm_add_ps(val, _mm_movehl_ps(val, val));
val = _mm_add_ss(val, _mm_movehdup_ps(val));
#else
__m128 tmp = _mm_shuffle_ps(val, val, _MM_SHUFFLE(2, 3, 0, 1));
val = _mm_add_ps(val, tmp);
tmp = _mm_movehl_ps(tmp, val);
val = _mm_add_ss(val, tmp);
#endif
sum += (ggml_float)_mm_cvtss_f32(val);
}
#elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
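// predicated loop: svwhilelt_b32 keeps only lanes with index < n active, so the
// last partial vector is handled in-loop without a scalar tail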
const int vlen = svcntw();
for (; i < n; i += vlen) {
const svbool_t pg = svwhilelt_b32_s32(i, n);
svfloat32_t val = svsub_f32_x(pg, svld1_f32(pg, x + i), svdup_n_f32_x(pg, max));
svst1_f32(pg, y + i, val);
sum += (ggml_float)svaddv_f32(pg, ggml_v_expf(pg, val));
}
#elif defined(__ARM_NEON) && defined(__aarch64__)
for (; i + 3 < n; i += 4) {
float32x4_t val = vsubq_f32(vld1q_f32(x + i), vdupq_n_f32(max));
vst1q_f32(y + i, val);
sum += (ggml_float)vaddvq_f32(ggml_v_expf(val));
}
#elif defined(__riscv_v_intrinsic)
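// vfwredusum widens the f32 exp values to f64 as it reduces, so the running
// sum accumulates in double precision across iterations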
vfloat64m1_t vsum = __riscv_vfmv_v_f_f64m1(0, 1);
for (int avl; i < n; i += avl) {
avl = __riscv_vsetvl_e32m2(n - i);
vfloat32m2_t val = __riscv_vfsub_vf_f32m2(__riscv_vle32_v_f32m2(&x[i], avl), max, avl);
__riscv_vse32_v_f32m2(&y[i], val, avl);
vsum = __riscv_vfwredusum_vs_f32m2_f64m1(ggml_v_expf_m2(val, avl), vsum, avl);
}
sum = (ggml_float)__riscv_vfmv_f_s_f64m1_f64(vsum);
#endif
for (; i < n; ++i) {
float val = x[i] - max;
y[i] = val;
sum += (ggml_float)expf(val);
}
- return sum = (ggml_float)logf(sum);
+ return (ggml_float)logf(sum);
}
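
For context, a minimal standalone sketch (not from this commit) of the multi-accumulator reduction pattern the new loops use, shown here for a sum of squares with AVX2+FMA; reduce_sum_squares_avx2 is a hypothetical name. Independent accumulators break the fmadd dependency chain so the core can overlap iterations, and the horizontal reduction runs once, after the main loop:

#include <immintrin.h>
#include <stddef.h>

// Hypothetical helper illustrating the pattern: four independent accumulators
// hide fmadd latency; a single horizontal reduction follows the main loop.
static float reduce_sum_squares_avx2(const float * x, size_t n) {
    __m256 acc0 = _mm256_setzero_ps();
    __m256 acc1 = _mm256_setzero_ps();
    __m256 acc2 = _mm256_setzero_ps();
    __m256 acc3 = _mm256_setzero_ps();
    size_t i = 0;
    for (; i + 32 <= n; i += 32) {
        __m256 v0 = _mm256_loadu_ps(x + i +  0);
        __m256 v1 = _mm256_loadu_ps(x + i +  8);
        __m256 v2 = _mm256_loadu_ps(x + i + 16);
        __m256 v3 = _mm256_loadu_ps(x + i + 24);
        acc0 = _mm256_fmadd_ps(v0, v0, acc0);
        acc1 = _mm256_fmadd_ps(v1, v1, acc1);
        acc2 = _mm256_fmadd_ps(v2, v2, acc2);
        acc3 = _mm256_fmadd_ps(v3, v3, acc3);
    }
    // fold the four accumulators into one, then reduce 8 lanes -> 4 -> 2 -> 1
    acc0 = _mm256_add_ps(_mm256_add_ps(acc0, acc1), _mm256_add_ps(acc2, acc3));
    __m128 r = _mm_add_ps(_mm256_extractf128_ps(acc0, 1), _mm256_castps256_ps128(acc0));
    r = _mm_add_ps(r, _mm_movehl_ps(r, r));
    r = _mm_add_ss(r, _mm_movehdup_ps(r));
    float sum = _mm_cvtss_f32(r);
    for (; i < n; ++i) {
        sum += x[i] * x[i];  // scalar remainder
    }
    return sum;
}

Compile with -mavx2 -mfma (or equivalent); the same shape applies to the AVX-512 and exp-sum variants above.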