// ggml/src/ggml-cpu/amx/common.h

#if defined(GGML_USE_OPENMP)
#include <omp.h>
#else
#include <cstddef>
#include <system_error>
#include <thread>
#include <vector>
#endif

#define TILE_M 16

// balance211 is defined earlier in this header; it writes the half-open range
// [n_start, n_end) that participant `ith` out of `nth` should process.
template <typename T>
inline void balance211(T n, T nth, T ith, T & n_start, T & n_end);

/**
 * Splits the index range [0, n) into chunks with balance211() and invokes
 * f(begin, end) once per chunk, potentially from several threads at once.
 *
 * @param n  Number of work items. Nothing is executed when n <= 0.
 * @param f  Callable taking (int begin, int end). It is invoked concurrently
 *           from different threads, so it must be safe to call that way.
 *
 * With GGML_USE_OPENMP the chunks are distributed over an OpenMP parallel
 * region. Otherwise std::thread workers are used and the calling thread
 * processes chunk 0 itself. The function returns only after every chunk has
 * been processed. An exception escaping f on a worker thread terminates the
 * program (std::thread semantics); one escaping on the calling thread is
 * propagated after all workers have been joined.
 */
template <typename func_t>
inline void parallel_for(int n, const func_t & f) {
    if (n <= 0) {
        return;
    }
#if defined(GGML_USE_OPENMP)
    #pragma omp parallel
    {
        int nth = omp_get_num_threads();
        int ith = omp_get_thread_num();
        int tbegin, tend;
        balance211(n, nth, ith, tbegin, tend);
        f(tbegin, tend);
    }
#else
    // hardware_concurrency() may return 0 when the value is not computable.
    // n > 0 here, so the unsigned comparison and the narrowing cast are safe.
    const unsigned hw  = std::thread::hardware_concurrency();
    const int      nth = hw < static_cast<unsigned>(n) ? static_cast<int>(hw) : n;
    if (nth <= 1) {
        // Unknown/single core, or a single work item: no point in spawning.
        f(0, n);
        return;
    }

    const auto run_chunk = [&f, n, nth](int ith) {
        int tbegin = 0;
        int tend   = 0;
        balance211(n, nth, ith, tbegin, tend);
        f(tbegin, tend);
    };

    // Joins every started worker on scope exit. Destroying a joinable
    // std::thread calls std::terminate, so this must also run on the
    // exceptional paths below.
    struct join_guard {
        std::vector<std::thread> & pool;
        ~join_guard() {
            for (auto & t : pool) {
                if (t.joinable()) {
                    t.join();
                }
            }
        }
    };

    std::vector<std::thread> workers;
    workers.reserve(static_cast<std::size_t>(nth - 1));
    join_guard guard{workers};

    // Chunk 0 is reserved for the calling thread; workers take chunks 1..nth-1.
    int ith = 1;
    try {
        for (; ith < nth; ++ith) {
            workers.emplace_back(run_chunk, ith);
        }
    } catch (const std::system_error &) {
        // The system refused to start another thread. Chunks [ith, nth) have
        // no worker; they are processed on the calling thread below.
    }

    run_chunk(0);
    for (int rest = ith; rest < nth; ++rest) {
        run_chunk(rest);
    }
    // `guard` joins the workers here, before `workers` itself is destroyed.
#endif
}