ggml-cpu/vec: rewrite ggml_vec_log_soft_max_f32 as 4x-unrolled SIMD for all platforms
parent ea8264950c
commit be23f5f74a
@@ -634,10 +634,103 @@ ggml_float ggml_vec_log_soft_max_f32(const int n, float * y, const float * x, float max) {
    int i = 0;
    ggml_float sum = 0;
#if defined(__AVX512F__) && defined(__AVX512DQ__)
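    // AVX512: process 64 floats per iteration (4x 16-wide unroll) into a single
    // accumulator, then a 16-wide tail loop and a native horizontal reduce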
    __m512 sum_v = _mm512_setzero_ps();
    for (; i + 63 < n; i += 64) {
        __m512 val1 = _mm512_sub_ps(_mm512_loadu_ps(x + i + 0),  _mm512_set1_ps(max));
        __m512 val2 = _mm512_sub_ps(_mm512_loadu_ps(x + i + 16), _mm512_set1_ps(max));
        __m512 val3 = _mm512_sub_ps(_mm512_loadu_ps(x + i + 32), _mm512_set1_ps(max));
        __m512 val4 = _mm512_sub_ps(_mm512_loadu_ps(x + i + 48), _mm512_set1_ps(max));
        _mm512_storeu_ps(y + i + 0,  val1);
        _mm512_storeu_ps(y + i + 16, val2);
        _mm512_storeu_ps(y + i + 32, val3);
        _mm512_storeu_ps(y + i + 48, val4);
        sum_v = _mm512_add_ps(sum_v, ggml_v_expf(val1));
        sum_v = _mm512_add_ps(sum_v, ggml_v_expf(val2));
        sum_v = _mm512_add_ps(sum_v, ggml_v_expf(val3));
        sum_v = _mm512_add_ps(sum_v, ggml_v_expf(val4));
    }
    for (; i + 15 < n; i += 16) {
        __m512 val = _mm512_sub_ps(_mm512_loadu_ps(x + i), _mm512_set1_ps(max));
        _mm512_storeu_ps(y + i, val);
        sum_v = _mm512_add_ps(sum_v, ggml_v_expf(val));
    }
    sum = (ggml_float)_mm512_reduce_add_ps(sum_v);
#elif defined(__AVX2__) && defined(__FMA__)
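    // AVX2: 32 floats per iteration; four independent accumulators break the
    // add dependency chain, merged and reduced to a scalar through SSE at the end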
    __m256 sum_v1 = _mm256_setzero_ps();
    __m256 sum_v2 = _mm256_setzero_ps();
    __m256 sum_v3 = _mm256_setzero_ps();
    __m256 sum_v4 = _mm256_setzero_ps();
    for (; i + 31 < n; i += 32) {
        __m256 val1 = _mm256_sub_ps(_mm256_loadu_ps(x + i + 0),  _mm256_set1_ps(max));
        __m256 val2 = _mm256_sub_ps(_mm256_loadu_ps(x + i + 8),  _mm256_set1_ps(max));
        __m256 val3 = _mm256_sub_ps(_mm256_loadu_ps(x + i + 16), _mm256_set1_ps(max));
        __m256 val4 = _mm256_sub_ps(_mm256_loadu_ps(x + i + 24), _mm256_set1_ps(max));
        _mm256_storeu_ps(y + i + 0,  val1);
        _mm256_storeu_ps(y + i + 8,  val2);
        _mm256_storeu_ps(y + i + 16, val3);
        _mm256_storeu_ps(y + i + 24, val4);
        sum_v1 = _mm256_add_ps(sum_v1, ggml_v_expf(val1));
        sum_v2 = _mm256_add_ps(sum_v2, ggml_v_expf(val2));
        sum_v3 = _mm256_add_ps(sum_v3, ggml_v_expf(val3));
        sum_v4 = _mm256_add_ps(sum_v4, ggml_v_expf(val4));
    }
    for (; i + 7 < n; i += 8) {
        __m256 val = _mm256_sub_ps(_mm256_loadu_ps(x + i), _mm256_set1_ps(max));
        _mm256_storeu_ps(y + i, val);
        sum_v1 = _mm256_add_ps(sum_v1, ggml_v_expf(val));
    }
    sum_v1 = _mm256_add_ps(sum_v1, sum_v2);
    sum_v1 = _mm256_add_ps(sum_v1, sum_v3);
    sum_v1 = _mm256_add_ps(sum_v1, sum_v4);
    __m128 val2 = _mm_add_ps(_mm256_extractf128_ps(sum_v1, 1), _mm256_castps256_ps128(sum_v1));
    val2 = _mm_add_ps(val2, _mm_movehl_ps(val2, val2));
    val2 = _mm_add_ss(val2, _mm_movehdup_ps(val2));
    sum += (ggml_float)_mm_cvtss_f32(val2);
#elif defined(__SSE2__)
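    // SSE2: 4 floats per iteration with a horizontal sum each time around;
    // the movehdup variant below needs SSE3, which any AVX-capable target has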
    for (; i + 3 < n; i += 4) {
        __m128 val = _mm_sub_ps(_mm_loadu_ps(x + i), _mm_set1_ps(max));
        _mm_storeu_ps(y + i, val);
        val = ggml_v_expf(val);
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
        val = _mm_add_ps(val, _mm_movehl_ps(val, val));
        val = _mm_add_ss(val, _mm_movehdup_ps(val));
#else
        __m128 tmp = _mm_shuffle_ps(val, val, _MM_SHUFFLE(2, 3, 0, 1));
        val = _mm_add_ps(val, tmp);
        tmp = _mm_movehl_ps(tmp, val);
        val = _mm_add_ss(val, tmp);
#endif
        sum += (ggml_float)_mm_cvtss_f32(val);
    }
#elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
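    // SVE: vector-length-agnostic loop; the whilelt predicate masks off the
    // final partial vector, so this path leaves no remainder for the scalar tail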
    const int vlen = svcntw();
    for (; i < n; i += vlen) {
        const svbool_t pg = svwhilelt_b32_s32(i, n);
        svfloat32_t val = svsub_f32_x(pg, svld1_f32(pg, x + i), svdup_n_f32_x(pg, max));
        svst1_f32(pg, y + i, val);
        sum += (ggml_float)svaddv_f32(pg, ggml_v_expf(pg, val));
    }
#elif defined(__ARM_NEON) && defined(__aarch64__)
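    // NEON: 4 floats per iteration, horizontally summed with vaddvq_f32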
    for (; i + 3 < n; i += 4) {
        float32x4_t val = vsubq_f32(vld1q_f32(x + i), vdupq_n_f32(max));
        vst1q_f32(y + i, val);
        sum += (ggml_float)vaddvq_f32(ggml_v_expf(val));
    }
#elif defined(__riscv_v_intrinsic)
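    // RVV: strip-mined loop (vsetvl picks each chunk size) with a widening
    // f32 -> f64 sum reduction into a single f64 accumulator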
    vfloat64m1_t vsum = __riscv_vfmv_v_f_f64m1(0, 1);
    for (int avl; i < n; i += avl) {
        avl = __riscv_vsetvl_e32m2(n - i);
        vfloat32m2_t val = __riscv_vfsub_vf_f32m2(__riscv_vle32_v_f32m2(&x[i], avl), max, avl);
        __riscv_vse32_v_f32m2(&y[i], val, avl);
        vsum = __riscv_vfwredusum_vs_f32m2_f64m1(ggml_v_expf_m2(val, avl), vsum, avl);
    }
    sum = (ggml_float)__riscv_vfmv_f_s_f64m1_f64(vsum);
#endif
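    // scalar tail for whatever the SIMD paths above left unprocessed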
    for (; i < n; ++i) {
        float val = x[i] - max;
        y[i] = val;
        sum += (ggml_float)expf(val);
    }
    return (ggml_float)logf(sum);
}
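
As a sanity check, something along these lines can exercise whichever SIMD path the build enables against a double-precision scalar reference. The harness below is illustrative and not part of the commit: it assumes it is compiled and linked against ggml's CPU backend, and that ggml_float is ggml's usual double typedef.

#include <math.h>
#include <stdio.h>
#include <stdlib.h>

typedef double ggml_float;  // assumption: matches ggml's accumulator typedef
ggml_float ggml_vec_log_soft_max_f32(const int n, float * y, const float * x, float max);

int main(void) {
    enum { N = 1003 };  // deliberately not a multiple of the SIMD widths, to hit the tail loops
    static float x[N], y[N];
    for (int i = 0; i < N; ++i) {
        x[i] = 20.0f * (float)rand() / (float)RAND_MAX - 10.0f;
    }

    // the caller passes the running max, as ggml's soft-max ops do
    float max = x[0];
    for (int i = 1; i < N; ++i) {
        if (x[i] > max) {
            max = x[i];
        }
    }

    const ggml_float log_sum = ggml_vec_log_soft_max_f32(N, y, x, max);

    // scalar reference: log_softmax_i = (x_i - max) - log(sum_j exp(x_j - max))
    double ref_sum = 0.0;
    for (int i = 0; i < N; ++i) {
        ref_sum += exp((double)x[i] - (double)max);
    }
    double max_err = 0.0;
    for (int i = 0; i < N; ++i) {
        const double ref = ((double)x[i] - (double)max) - log(ref_sum);
        const double got = (double)y[i] - (double)log_sum;  // y[i] holds x[i] - max
        const double err = fabs(ref - got);
        if (err > max_err) {
            max_err = err;
        }
    }
    printf("log(sum) = %.9f  max |log_softmax err| = %g\n", (double)log_sum, max_err);
    return 0;
}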