diff --git a/common/common.cpp b/common/common.cpp
index 3aa396127c..4660c131a3 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1369,6 +1369,7 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {
     mparams.check_tensors    = params.check_tensors;
     mparams.use_extra_bufts  = !params.no_extra_bufts;
     mparams.no_host          = params.no_host;
+    mparams.repack_n_threads = params.cpuparams.n_threads;
 
     if (params.kv_overrides.empty()) {
         mparams.kv_overrides = NULL;
diff --git a/ggml/include/ggml-cpu.h b/ggml/include/ggml-cpu.h
index 4f3b99c8d0..5d9f7b4d82 100644
--- a/ggml/include/ggml-cpu.h
+++ b/ggml/include/ggml-cpu.h
@@ -52,6 +52,10 @@ extern "C" {
     GGML_BACKEND_API float ggml_get_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3);
     GGML_BACKEND_API void  ggml_set_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, float value);
 
+    // parallel repack threads
+    GGML_BACKEND_API void ggml_cpu_set_repack_n_threads(int n_threads);
+    GGML_BACKEND_API int  ggml_cpu_get_repack_n_threads(void);
+
     GGML_BACKEND_API struct ggml_threadpool * ggml_threadpool_new           (struct ggml_threadpool_params * params);
     GGML_BACKEND_API void                      ggml_threadpool_free          (struct ggml_threadpool * threadpool);
     GGML_BACKEND_API int                       ggml_threadpool_get_n_threads (struct ggml_threadpool * threadpool);
diff --git a/ggml/src/ggml-cpu/ggml-cpu-impl.h b/ggml/src/ggml-cpu/ggml-cpu-impl.h
index 0e8dd0ae05..80f083b003 100644
--- a/ggml/src/ggml-cpu/ggml-cpu-impl.h
+++ b/ggml/src/ggml-cpu/ggml-cpu-impl.h
@@ -517,6 +517,7 @@ static __m256 __lasx_xvreplfr2vr_s(const float val) {
 
 // TODO: move to ggml-threading
 void ggml_barrier(struct ggml_threadpool * tp);
+void ggml_cpu_set_numa_thread_affinity(int thread_n);
 
 void ggml_threadpool_chunk_set(struct ggml_threadpool * tp, int value);
 int  ggml_threadpool_chunk_add(struct ggml_threadpool * tp, int value);
diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
index b1de2ae871..1bb55f3001 100644
--- a/ggml/src/ggml-cpu/ggml-cpu.c
+++ b/ggml/src/ggml-cpu/ggml-cpu.c
@@ -2088,7 +2088,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
 
 // Android's libc implementation "bionic" does not support setting affinity
 #if defined(__gnu_linux__)
-static void set_numa_thread_affinity(int thread_n) {
+void ggml_cpu_set_numa_thread_affinity(int thread_n) {
     if (!ggml_is_numa()) {
         return;
     }
@@ -2156,7 +2156,7 @@ static void clear_numa_thread_affinity(void) {
 #else
 // TODO: Windows etc.
 // (the linux implementation may also work on BSD, someone should test)
-static void set_numa_thread_affinity(int thread_n) { UNUSED(thread_n); }
+void ggml_cpu_set_numa_thread_affinity(int thread_n) { UNUSED(thread_n); }
 static void clear_numa_thread_affinity(void) {}
 #endif
 
@@ -2926,7 +2926,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
     const struct ggml_cgraph * cgraph = tp->cgraph;
     const struct ggml_cplan  * cplan  = tp->cplan;
 
-    set_numa_thread_affinity(state->ith);
+    ggml_cpu_set_numa_thread_affinity(state->ith);
 
     struct ggml_compute_params params = {
         /*.ith   =*/ state->ith,
diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp
index 24e8ab4618..7b0c22c1a4 100644
--- a/ggml/src/ggml-cpu/repack.cpp
+++ b/ggml/src/ggml-cpu/repack.cpp
@@ -16,6 +16,12 @@
 #include <cassert>
 #include <cstdio>  // for GGML_ASSERT
 
+static int g_repack_n_threads = 1;
+
+#if defined(GGML_USE_OPENMP)
+#include <omp.h>
+#endif
+
 #include "repack.h"
 
 #if defined(__GNUC__)
@@ -48,6 +54,19 @@ static inline int nearest_int(float fval) {
 
 extern "C" {
 
+#if defined(GGML_USE_OPENMP)
+void ggml_cpu_set_repack_n_threads(int n_threads) {
+    g_repack_n_threads = n_threads;
+}
+
+int ggml_cpu_get_repack_n_threads(void) {
+    return g_repack_n_threads;
+}
+#else
+void ggml_cpu_set_repack_n_threads(int n_threads) { GGML_UNUSED(n_threads); }
+int ggml_cpu_get_repack_n_threads(void) { return 0; }
+#endif
+
 void ggml_quantize_mat_q8_0_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
     assert(QK8_0 == 32);
     assert(k % QK8_0 == 0);
@@ -2140,11 +2159,10 @@ static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block
     GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
     constexpr int nrows_interleaved = 4;
 
-    block_q4_0x4 * dst = (block_q4_0x4 *)t->data;
-    const block_q4_0 * src = (const block_q4_0 *)data;
-    block_q4_0 dst_tmp[4];
-    int nrow = ggml_nrows(t);
-    int nblocks = t->ne[0] / QK4_0;
+    block_q4_0x4 * dst_base = (block_q4_0x4 *)t->data;
+    const block_q4_0 * src_base = (const block_q4_0 *)data;
+    const int nrow = ggml_nrows(t);
+    const int nblocks = t->ne[0] / QK4_0;
 
     GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0));
 
@@ -2152,14 +2170,23 @@ static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block
         return -1;
     }
 
-    for (int b = 0; b < nrow; b += nrows_interleaved) {
+    const int n_row_groups = nrow / nrows_interleaved;
+
+#ifdef GGML_USE_OPENMP
+#pragma omp for schedule(static)
+#endif
+    for (int bg = 0; bg < n_row_groups; bg++) {
+        const int b = bg * nrows_interleaved;
+        const block_q4_0 * src = src_base + b * nblocks;
+        block_q4_0x4 * dst = dst_base + bg * nblocks;
+        block_q4_0 dst_tmp[4];
+
         for (int64_t x = 0; x < nblocks; x++) {
             for (int i = 0; i < nrows_interleaved; i++) {
                 dst_tmp[i] = src[x + i * nblocks];
             }
-            *dst++ = make_block_q4_0x4(dst_tmp, interleave_block);
+            dst[x] = make_block_q4_0x4(dst_tmp, interleave_block);
         }
-        src += nrows_interleaved * nblocks;
     }
 
     return 0;
@@ -2171,11 +2198,10 @@ static int repack_q4_K_to_q4_K_8_bl(struct ggml_tensor * t, int interleave_block
     GGML_ASSERT(interleave_block == 8 || interleave_block == 4);
     constexpr int nrows_interleaved = 8;
 
-    block_q4_Kx8 * dst = (block_q4_Kx8*)t->data;
-    const block_q4_K * src = (const block_q4_K*) data;
-    block_q4_K dst_tmp[8];
-    int nrow = ggml_nrows(t);
-    int nblocks = t->ne[0] / QK_K;
+    block_q4_Kx8 * dst_base = (block_q4_Kx8*)t->data;
+    const block_q4_K * src_base = (const block_q4_K*) data;
+    const int nrow = ggml_nrows(t);
+    const int nblocks = t->ne[0] / QK_K;
 
     GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_K));
 
@@ -2183,14 +2209,23 @@ static int repack_q4_K_to_q4_K_8_bl(struct ggml_tensor * t, int interleave_block
         return -1;
     }
 
-    for (int b = 0; b < nrow; b += nrows_interleaved) {
+    const int n_row_groups = nrow / nrows_interleaved;
+
+#ifdef GGML_USE_OPENMP
+#pragma omp for schedule(static)
+#endif
+    for (int bg = 0; bg < n_row_groups; bg++) {
+        const int b = bg * nrows_interleaved;
+        const block_q4_K * src = src_base + b * nblocks;
+        block_q4_Kx8 * dst = dst_base + bg * nblocks;
+        block_q4_K dst_tmp[8];
+
         for (int64_t x = 0; x < nblocks; x++) {
-            for (int i = 0; i < nrows_interleaved; i++ ) {
+            for (int i = 0; i < nrows_interleaved; i++) {
                 dst_tmp[i] = src[x + i * nblocks];
             }
-            *dst++ = make_block_q4_Kx8(dst_tmp, interleave_block);
+            dst[x] = make_block_q4_Kx8(dst_tmp, interleave_block);
         }
-        src += nrows_interleaved * nblocks;
     }
 
     return 0;
@@ -2202,11 +2237,10 @@ static int repack_q2_K_to_q2_K_8_bl(struct ggml_tensor * t, int interleave_block
     GGML_ASSERT(interleave_block == 8);
     constexpr int nrows_interleaved = 8;
 
-    block_q2_Kx8 * dst = (block_q2_Kx8*)t->data;
-    const block_q2_K * src = (const block_q2_K*) data;
-    block_q2_K dst_tmp[8];
-    int nrow = ggml_nrows(t);
-    int nblocks = t->ne[0] / QK_K;
+    block_q2_Kx8 * dst_base = (block_q2_Kx8*)t->data;
+    const block_q2_K * src_base = (const block_q2_K*) data;
+    const int nrow = ggml_nrows(t);
+    const int nblocks = t->ne[0] / QK_K;
 
     GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q2_K));
 
@@ -2214,14 +2248,23 @@ static int repack_q2_K_to_q2_K_8_bl(struct ggml_tensor * t, int interleave_block
         return -1;
     }
 
-    for (int b = 0; b < nrow; b += nrows_interleaved) {
+    const int n_row_groups = nrow / nrows_interleaved;
+
+#ifdef GGML_USE_OPENMP
+#pragma omp for schedule(static)
+#endif
+    for (int bg = 0; bg < n_row_groups; bg++) {
+        const int b = bg * nrows_interleaved;
+        const block_q2_K * src = src_base + b * nblocks;
+        block_q2_Kx8 * dst = dst_base + bg * nblocks;
+        block_q2_K dst_tmp[8];
+
         for (int64_t x = 0; x < nblocks; x++) {
             for (int i = 0; i < nrows_interleaved; i++) {
                 dst_tmp[i] = src[x + i * nblocks];
             }
-            *dst++ = make_block_q2_Kx8(dst_tmp, interleave_block);
+            dst[x] = make_block_q2_Kx8(dst_tmp, interleave_block);
         }
-        src += nrows_interleaved * nblocks;
     }
 
     return 0;
@@ -2294,11 +2337,10 @@ static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor * t, int interleave_block
     GGML_ASSERT(interleave_block == 8);
     constexpr int nrows_interleaved = 8;
 
-    block_q4_0x8 * dst = (block_q4_0x8*)t->data;
-    const block_q4_0 * src = (const block_q4_0*) data;
-    block_q4_0 dst_tmp[8];
-    int nrow = ggml_nrows(t);
-    int nblocks = t->ne[0] / QK4_0;
+    block_q4_0x8 * dst_base = (block_q4_0x8*)t->data;
+    const block_q4_0 * src_base = (const block_q4_0*) data;
+    const int nrow = ggml_nrows(t);
+    const int nblocks = t->ne[0] / QK4_0;
 
     GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0));
 
@@ -2306,14 +2348,23 @@ static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor * t, int interleave_block
         return -1;
     }
 
-    for (int b = 0; b < nrow; b += nrows_interleaved) {
+    const int n_row_groups = nrow / nrows_interleaved;
+
+#ifdef GGML_USE_OPENMP
+#pragma omp for schedule(static)
+#endif
+    for (int bg = 0; bg < n_row_groups; bg++) {
+        const int b = bg * nrows_interleaved;
+        const block_q4_0 * src = src_base + b * nblocks;
+        block_q4_0x8 * dst = dst_base + bg * nblocks;
+        block_q4_0 dst_tmp[8];
+
         for (int64_t x = 0; x < nblocks; x++) {
-            for (int i = 0; i < nrows_interleaved; i++ ) {
+            for (int i = 0; i < nrows_interleaved; i++) {
                 dst_tmp[i] = src[x + i * nblocks];
             }
-            *dst++ = make_block_q4_0x8(dst_tmp, interleave_block);
+            dst[x] = make_block_q4_0x8(dst_tmp, interleave_block);
         }
-        src += nrows_interleaved * nblocks;
     }
 
     return 0;
@@ -2391,14 +2442,12 @@ static int repack_iq4_nl_to_iq4_nl_4_bl(struct ggml_tensor * t, int interleave_block
     GGML_ASSERT(t->type == GGML_TYPE_IQ4_NL);
     GGML_ASSERT(interleave_block == 4);
 
-    const block_iq4_nl * src = (const block_iq4_nl *)data;
-    block_iq4_nlx4 * dst = ( block_iq4_nlx4 *)t->data;
+    const block_iq4_nl * src_base = (const block_iq4_nl *)data;
+    block_iq4_nlx4 * dst_base = (block_iq4_nlx4 *)t->data;
 
-    block_iq4_nl dst_tmp[4];
-
-    int nrow = ggml_nrows(t);
-    int nrows_interleaved = 4;
-    int nblocks = t->ne[0] / QK4_NL;
+    const int nrow = ggml_nrows(t);
+    const int nrows_interleaved = 4;
+    const int nblocks = t->ne[0] / QK4_NL;
 
     GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_iq4_nl));
 
@@ -2406,14 +2455,23 @@ static int repack_iq4_nl_to_iq4_nl_4_bl(struct ggml_tensor * t, int interleave_block
         return -1;
     }
 
-    for (int b = 0; b < nrow; b += nrows_interleaved) {
+    const int n_row_groups = nrow / nrows_interleaved;
+
+#ifdef GGML_USE_OPENMP
+#pragma omp for schedule(static)
+#endif
+    for (int bg = 0; bg < n_row_groups; bg++) {
+        const int b = bg * nrows_interleaved;
+        const block_iq4_nl * src = src_base + b * nblocks;
+        block_iq4_nlx4 * dst = dst_base + bg * nblocks;
+        block_iq4_nl dst_tmp[4];
+
         for (int64_t x = 0; x < nblocks; x++) {
             for (int i = 0; i < nrows_interleaved; i++) {
                 dst_tmp[i] = src[x + i * nblocks];
             }
-            *dst++ = make_block_iq4_nlx4(dst_tmp, interleave_block);
+            dst[x] = make_block_iq4_nlx4(dst_tmp, interleave_block);
         }
-        src += nrows_interleaved * nblocks;
     }
 
     return 0;
@@ -2448,14 +2506,12 @@ static int repack_iq4_nl_to_iq4_nl_8_bl(struct ggml_tensor * t, int interleave_block
     GGML_ASSERT(t->type == GGML_TYPE_IQ4_NL);
     GGML_ASSERT(interleave_block == 8);
 
-    const block_iq4_nl * src = (const block_iq4_nl *)data;
-    block_iq4_nlx8 * dst = ( block_iq4_nlx8 *)t->data;
+    const block_iq4_nl * src_base = (const block_iq4_nl *)data;
+    block_iq4_nlx8 * dst_base = (block_iq4_nlx8 *)t->data;
 
-    block_iq4_nl dst_tmp[8];
-
-    int nrow = ggml_nrows(t);
-    int nrows_interleaved = 8;
-    int nblocks = t->ne[0] / QK4_NL;
+    const int nrow = ggml_nrows(t);
+    const int nrows_interleaved = 8;
+    const int nblocks = t->ne[0] / QK4_NL;
 
     GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_iq4_nl));
 
@@ -2463,14 +2519,23 @@ static int repack_iq4_nl_to_iq4_nl_8_bl(struct ggml_tensor * t, int interleave_block
         return -1;
     }
 
-    for (int b = 0; b < nrow; b += nrows_interleaved) {
+    const int n_row_groups = nrow / nrows_interleaved;
+
+#ifdef GGML_USE_OPENMP
+#pragma omp for schedule(static)
+#endif
+    for (int bg = 0; bg < n_row_groups; bg++) {
+        const int b = bg * nrows_interleaved;
+        const block_iq4_nl * src = src_base + b * nblocks;
+        block_iq4_nlx8 * dst = dst_base + bg * nblocks;
+        block_iq4_nl dst_tmp[8];
+
         for (int64_t x = 0; x < nblocks; x++) {
             for (int i = 0; i < nrows_interleaved; i++) {
                 dst_tmp[i] = src[x + i * nblocks];
             }
-            *dst++ = make_block_iq4_nlx8(dst_tmp, interleave_block);
+            dst[x] = make_block_iq4_nlx8(dst_tmp, interleave_block);
         }
-        src += nrows_interleaved * nblocks;
     }
 
     return 0;
@@ -3021,9 +3086,29 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PARAM_TYPE> class tensor_traits : public tensor_traits_base {
     int repack(struct ggml_tensor * t, const void * data, size_t data_size) override {
         GGML_LOG_DEBUG("%s: repack tensor %s with %s_%dx%d\n", __func__, t->name, ggml_type_name(t->type),
                        (int) NB_COLS, (int) INTER_SIZE);
-        return ggml::cpu::repack::repack<BLOC_TYPE, INTER_SIZE, NB_COLS>(t, data, data_size);
+        int ret = -1;
+#ifdef GGML_USE_OPENMP
+        int n_threads = ggml_cpu_get_repack_n_threads();
+        GGML_ASSERT(n_threads >= 0);
+        if (n_threads == 0) {
+            n_threads = omp_get_max_threads();
+        }
+        if (n_threads > 1) {
+            #pragma omp parallel num_threads(n_threads)
+            {
+                ggml_cpu_set_numa_thread_affinity(omp_get_thread_num());
+                int r = ggml::cpu::repack::repack<BLOC_TYPE, INTER_SIZE, NB_COLS>(t, data, data_size);
+                #pragma omp master
+                ret = r;
+            }
+        }
+#endif
+        if (ret == -1) {
+            ret = ggml::cpu::repack::repack<BLOC_TYPE, INTER_SIZE, NB_COLS>(t, data, data_size);
+        }
+        return ret;
     }
 };
 
diff --git a/include/llama.h b/include/llama.h
index bf4e28a8be..26afeb44aa 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -314,6 +314,7 @@ extern "C" {
         bool check_tensors;   // validate model tensor data
         bool use_extra_bufts; // use extra buffer types (used for weight repacking)
         bool no_host;         // bypass host buffer allowing extra buffers to be used
+        int32_t repack_n_threads; // number of threads to use for repacking (0 = use all available threads)
         bool no_alloc;        // only load metadata and simulate memory allocations
     };
 
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 72490a89b5..d6c8629a96 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -8130,6 +8130,7 @@ llama_model_params llama_model_default_params() {
         /*.check_tensors    =*/ false,
         /*.use_extra_bufts  =*/ true,
         /*.no_host          =*/ false,
+        /*.repack_n_threads =*/ 0,
         /*.no_alloc         =*/ false,
     };
 
diff --git a/src/llama.cpp b/src/llama.cpp
index 6da90d6f1f..2118ca1387 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -826,6 +826,7 @@ int64_t llama_time_us(void) {
 
 // Returns 0 on success, -1 on error, and -2 on cancellation via llama_progress_callback
 static int llama_model_load(const std::string & fname, std::vector<std::string> & splits, llama_model & model, llama_model_params & params) {
+    ggml_cpu_set_repack_n_threads(params.repack_n_threads);
     // loading time will be recalculated after the first eval, so
     // we take page faults deferred by mmap() into consideration
     model.t_load_us = 0;
diff --git a/tools/llama-bench/llama-bench.cpp b/tools/llama-bench/llama-bench.cpp
index 7da6c3957c..0176103848 100644
--- a/tools/llama-bench/llama-bench.cpp
+++ b/tools/llama-bench/llama-bench.cpp
@@ -1082,12 +1082,13 @@ struct cmd_params_instance {
         if (!devices.empty()) {
             mparams.devices = const_cast<ggml_backend_dev_t *>(devices.data());
         }
-        mparams.split_mode    = split_mode;
-        mparams.main_gpu      = main_gpu;
-        mparams.tensor_split  = tensor_split.data();
-        mparams.use_mmap      = use_mmap;
+        mparams.split_mode       = split_mode;
+        mparams.main_gpu         = main_gpu;
+        mparams.tensor_split     = tensor_split.data();
+        mparams.use_mmap         = use_mmap;
         mparams.use_direct_io    = use_direct_io;
-        mparams.no_host       = no_host;
+        mparams.no_host          = no_host;
+        mparams.repack_n_threads = n_threads;
 
         if (n_cpu_moe <= 0) {
             if (tensor_buft_overrides.empty()) {
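
Note for reviewers: the sketch below is not part of the patch; it is a minimal example of how a caller could opt into parallel repacking through the new `repack_n_threads` field. The model path and the thread count of 8 are placeholders; with the default of 0 the CPU backend falls back to `omp_get_max_threads()`, and in non-OpenMP builds the setting is ignored.

```cpp
// Hypothetical caller of the new repack_n_threads model parameter.
// Assumes linking against llama.cpp built with an OpenMP-enabled CPU backend.
#include "llama.h"

int main() {
    llama_backend_init();

    llama_model_params mparams = llama_model_default_params();
    mparams.repack_n_threads = 8;   // arbitrary choice for this sketch; 0 = use all available threads

    // "model.gguf" is a placeholder path, not a file shipped with this patch
    llama_model * model = llama_model_load_from_file("model.gguf", mparams);
    if (model != nullptr) {
        // weight repacking has already run (in parallel) during model load
        llama_model_free(model);
    }

    llama_backend_free();
    return 0;
}
```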
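A second sketch, also not part of the patch, illustrates the orphaned-worksharing pattern the repack loops rely on: each `repack_*_bl` function contains only a `#pragma omp for`, which binds to whatever `#pragma omp parallel` region the caller provides (the `tensor_traits::repack` wrapper above) and degrades to a plain single-threaded loop when called serially. `repack_groups` is a hypothetical stand-in for those functions.

```cpp
// Minimal sketch of an orphaned `#pragma omp for` inside a helper function.
// Compile with -fopenmp; without OpenMP the pattern simply runs serially.
#include <cstdio>
#include <vector>
#include <omp.h>

// Stand-in for a repack_*_to_*_bl function: each iteration handles one
// independent group of rows, so schedule(static) needs no synchronization.
static void repack_groups(std::vector<int> & out) {
    #pragma omp for schedule(static)
    for (int bg = 0; bg < (int) out.size(); bg++) {
        out[bg] = omp_get_thread_num();   // record which thread handled the group
    }
}

int main() {
    std::vector<int> serial(8, -1);
    std::vector<int> shared(8, -1);

    // Serial call: the orphaned `omp for` executes all iterations on one thread.
    repack_groups(serial);

    // Parallel call: the same loop is work-shared across the team, mirroring
    // the `#pragma omp parallel num_threads(...)` wrapper in the patch.
    #pragma omp parallel num_threads(4)
    repack_groups(shared);

    for (int bg = 0; bg < 8; bg++) {
        printf("group %d: serial thread %d, parallel thread %d\n", bg, serial[bg], shared[bg]);
    }
    return 0;
}
```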