ggml : use a simple std::thread in AMX without OpenMP (#20074)

Disabling OpenMP generally provides better inference performance (at
least in my testing) but the loading becomes slightly slower.

Benchmark results for `convert_B_packed_format()`:

Before this commit:

         N      K |  No OpenMP     OpenMP |    Diff |  Speedup
    ------------------------------------------------------------
       512   2880 |    640.9us    263.5us |  -58.9% |    0.41x
      2880   4096 |     2.55ms    261.7us |  -89.8% |    0.10x
    201088   2880 |   256.44ms    21.61ms |  -91.6% |    0.08x
    ------------------------------------------------------------

    Total: 325.43ms vs 31.05ms

After:

         N      K |  No OpenMP     OpenMP |    Diff |  Speedup
    ------------------------------------------------------------
       512   2880 |     1.49ms    263.5us |  -82.3% |    0.18x
      2880   4096 |     1.55ms    261.7us |  -83.1% |    0.17x
    201088   2880 |    24.03ms    21.61ms |  -10.1% |    0.90x
    ------------------------------------------------------------

    Total: 78.97ms vs 31.05ms

Tested with unsloth/gpt-oss-20b-GGUF:Q4_K_M.

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
This commit is contained in:
Adrien Gallouët 2026-03-04 11:57:09 +01:00 committed by GitHub
parent c99909dd0b
commit 66199c9f03
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 34 additions and 10 deletions

View File

@ -9,6 +9,8 @@
#if defined(GGML_USE_OPENMP)
#include <omp.h>
#else
#include <thread>
#endif
#define TILE_M 16
@ -56,18 +58,40 @@ inline void balance211(T n, T nth, T ith, T& n_start, T& n_end) {
}
template <typename func_t>
inline void parallel_for(int n, const func_t& f) {
inline void parallel_for(int n, const func_t & f) {
if (n <= 0) {
return;
}
#if defined(GGML_USE_OPENMP)
#pragma omp parallel
{
int nth = omp_get_num_threads();
int ith = omp_get_thread_num();
int tbegin, tend;
balance211(n, nth, ith, tbegin, tend);
f(tbegin, tend);
}
#pragma omp parallel
{
int nth = omp_get_num_threads();
int ith = omp_get_thread_num();
int tbegin, tend;
balance211(n, nth, ith, tbegin, tend);
f(tbegin, tend);
}
#else
f(0, n);
int nth = std::thread::hardware_concurrency();
if (nth <= 1) {
f(0, n);
return;
}
if (nth > n) {
nth = n;
}
std::vector<std::thread> threads;
threads.reserve(nth);
for (int ith = 0; ith < nth; ++ith) {
threads.emplace_back([&f, n, ith, nth] {
int tbegin, tend;
balance211(n, nth, ith, tbegin, tend);
f(tbegin, tend);
});
}
for (auto & t : threads) {
t.join();
}
#endif
}