diff --git a/ggml/src/ggml-cpu/amx/amx.cpp b/ggml/src/ggml-cpu/amx/amx.cpp index 895a571375..9baf3e025e 100644 --- a/ggml/src/ggml-cpu/amx/amx.cpp +++ b/ggml/src/ggml-cpu/amx/amx.cpp @@ -141,27 +141,50 @@ static size_t ggml_backend_amx_buffer_type_get_alignment(ggml_backend_buffer_typ namespace ggml::cpu::amx { class extra_buffer_type : ggml::cpu::extra_buffer_type { bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override { - // handle only 2d gemm for now - auto is_contiguous_2d = [](const struct ggml_tensor * t) { - return ggml_is_contiguous(t) && t->ne[3] == 1 && t->ne[2] == 1; - }; - - if (op->op == GGML_OP_MUL_MAT && is_contiguous_2d(op->src[0]) && // src0 must be contiguous - is_contiguous_2d(op->src[1]) && // src1 must be contiguous - op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_amx_buffer_type() && - op->src[0]->ne[0] % (TILE_K * 2 * 32) == 0 && // TODO: not sure if correct (https://github.com/ggml-org/llama.cpp/pull/16315) - op->ne[0] % (TILE_N * 2) == 0 && // out_features is 32x - (qtype_has_amx_kernels(op->src[0]->type) || (op->src[0]->type == GGML_TYPE_F16))) { - // src1 must be host buffer - if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) { - return false; - } - // src1 must be float32 - if (op->src[1]->type == GGML_TYPE_F32) { - return true; - } + if (op->op != GGML_OP_MUL_MAT) { + return false; } - return false; + auto * src0 = op->src[0]; + auto * src1 = op->src[1]; + + if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) { + return false; + } + if (!src0->buffer || src0->buffer->buft != ggml_backend_amx_buffer_type()) { + return false; + } + if (src1->buffer && !ggml_backend_buft_is_host(src1->buffer->buft)) { + return false; + } + if (op->ne[0] % (TILE_N * 2)) { + return false; + } + int alignment; + switch (src0->type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q8_0: + alignment = TILE_K; + break; + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ4_XS: + alignment = 256; // QK_K + break; + case GGML_TYPE_F16: + alignment = 16; + break; + default: + return false; + } + if (src0->ne[0] % alignment) { + return false; + } + if (src1->type != GGML_TYPE_F32) { + return false; + } + return true; } ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override { diff --git a/ggml/src/ggml-cpu/amx/mmq.cpp b/ggml/src/ggml-cpu/amx/mmq.cpp index 47c61b8816..b5aca76633 100644 --- a/ggml/src/ggml-cpu/amx/mmq.cpp +++ b/ggml/src/ggml-cpu/amx/mmq.cpp @@ -1,4 +1,3 @@ - #if defined(__GNUC__) #pragma GCC diagnostic ignored "-Wpedantic" #pragma GCC diagnostic ignored "-Wunused-local-typedefs" @@ -202,35 +201,27 @@ struct tile_config_t{ // advanced-matrix-extensions-intrinsics-functions.html // -#define TC_CONFIG_TILE(i, r, cb) tc.rows[i] = r; tc.colsb[i] = cb -void ggml_tile_config_init(void) { - static thread_local bool is_first_time = true; +inline void ggml_tile_config_init(void) { + static thread_local bool done = false; - if (!is_first_time) { + if (done) { return; } - static thread_local tile_config_t tc; - tile_config_t current_tc; - _tile_storeconfig(¤t_tc); + alignas(64) tile_config_t tc = {}; + tc.palette_id = 1; + tc.start_row = 0; + tc.rows[0] = 8; tc.colsb[0] = 64; + tc.rows[1] = 8; tc.colsb[1] = 64; + tc.rows[2] = 16; tc.colsb[2] = 32; + tc.rows[3] = 16; tc.colsb[3] = 32; + tc.rows[4] = 16; tc.colsb[4] = 64; + tc.rows[5] = 16; tc.colsb[5] = 64; + tc.rows[6] = 16; tc.colsb[6] = 64; + tc.rows[7] = 16; tc.colsb[7] = 64; - // load only when config changes - if (tc.palette_id == 0 || (memcmp(¤t_tc.colsb, &tc.colsb, sizeof(uint16_t) * 8) != 0 && - memcmp(¤t_tc.rows, &tc.rows, sizeof(uint8_t) * 8) != 0)) { - tc.palette_id = 1; - tc.start_row = 0; - TC_CONFIG_TILE(TMM0, 8, 64); - TC_CONFIG_TILE(TMM1, 8, 64); - TC_CONFIG_TILE(TMM2, 16, 32); - TC_CONFIG_TILE(TMM3, 16, 32); - TC_CONFIG_TILE(TMM4, 16, 64); - TC_CONFIG_TILE(TMM5, 16, 64); - TC_CONFIG_TILE(TMM6, 16, 64); - TC_CONFIG_TILE(TMM7, 16, 64); - _tile_loadconfig(&tc); - } - - is_first_time = false; + _tile_loadconfig(&tc); + done = true; } // we need an extra 16 * 4B (TILE_N * int32_t) for each NB/KB block for compensation. @@ -268,33 +259,6 @@ int get_row_size(int K) { return row_size; } -// vectorized dtype conversion -inline float FP16_TO_FP32(ggml_half val) { - __m256i v = _mm256_setr_epi16( - val, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0); - __m512 o = _mm512_cvtph_ps(v); - return _mm512_cvtss_f32(o); -} - -inline __m512 FP16_TO_FP32_VEC(ggml_half val) { - __m256i v = _mm256_set1_epi16(val); - return _mm512_cvtph_ps(v); -} - -// horizontal reduce -inline float _mm512_reduce_max_ps(const __m512 x) { - __m512 v = x; - __m512 v1 = _mm512_shuffle_f32x4(v, v, 0x4E); - v = _mm512_max_ps(v, v1); - v1 = _mm512_shuffle_f32x4(v, v, 0xB1); - v = _mm512_max_ps(v, v1); - v1 = _mm512_shuffle_ps(v, v, 0x4E); - v = _mm512_max_ps(v, v1); - v1 = _mm512_shuffle_ps(v, v, 0xB1); - v = _mm512_max_ps(v, v1); - return _mm512_cvtss_f32(v); -} - // transpose utils #define SHUFFLE_EPI32(a, b, mask) \ _mm256_castps_si256(_mm256_shuffle_ps(_mm256_castsi256_ps(a), _mm256_castsi256_ps(b), mask)) @@ -1370,9 +1334,9 @@ struct tinygemm_kernel_avx #define LAUNCH_TINYGEMM_KERNEL_AVX(MB_SIZE, NB_SIZE) \ tinygemm_kernel_avx::apply( \ - K, (const float *)src1->data + mb_start * K, \ - (const type *)src0->data + nb_start * K, \ - (float *)dst->data + mb_start * ldc + nb_start, ldc); + K, (const float *)src1->data + src1_offset + mb_start * K, \ + (const type *)src0->data + src0_offset + nb_start * K, \ + (float *)dst->data + dst_offset + mb_start * ldc + nb_start, ldc) // re-organize in the format {NB, KB, TILE_SIZE}: @@ -2019,11 +1983,11 @@ struct tinygemm_kernel_vnni::apply( \ - KB, (const char *)wdata + 0 * row_size_A, \ - (const char *)src0->data + PACKED_INDEX(nb * kTilesN, 0, KB, TILE_SIZE), \ - (float *) dst->data + 0 * N + nb_start, ldc) +#define LAUNCH_TINYGEMM_KERNEL_VNNI(NB_SIZE) \ + tinygemm_kernel_vnni::apply( \ + KB, wdata_batch, \ + (const char *)src0->data + src0_offset + PACKED_INDEX(nb * kTilesN, 0, KB, TILE_SIZE), \ + (float *) dst->data + dst_offset + nb_start, ldc) template ::value, int>::type = 0> @@ -2079,7 +2043,7 @@ void tinygemm_kernel_amx(int M, int N, int KB, const void * RESTRICT _A, const v _tile_stored(TMM5, Tile5(C_pre), TILE_N * sizeof(int32_t)); if (need_unpack) { - unpack_B(Tile1, B_blk0); + unpack_B(Tile1, B_blk1); _tile_loadd(TMM1, Tile1, TILE_N * VNNI_BLK); } else { _tile_loadd(TMM1, B_blk1, TILE_N * VNNI_BLK); @@ -2336,6 +2300,13 @@ void ggml_backend_amx_convert_weight(struct ggml_tensor * tensor, const void * d }); } +// ne2 is passed explicitly to help compiler optimize repeated calls +inline int64_t ggml_batch_offset(const ggml_tensor * t, int64_t batch_idx, int64_t ne2) { + const int64_t i2 = batch_idx % ne2; + const int64_t i3 = batch_idx / ne2; + return i3 * t->nb[3] + i2 * t->nb[2]; +} + size_t ggml_backend_amx_desired_wsize(const struct ggml_tensor * dst) { struct ggml_tensor * src0 = dst->src[0]; @@ -2348,12 +2319,13 @@ size_t ggml_backend_amx_desired_wsize(const struct ggml_tensor * dst) { const int M = dst->ne[1]; const int K = src0->ne[0]; + const int64_t n_batch = dst->ne[2] * dst->ne[3]; size_t desired_wsize = 0; GGML_DISPATCH_QTYPES(TYPE, [&] { const size_t row_size_A = K / blck_size * sizeof(vec_dot_type); - desired_wsize = M * row_size_A; + desired_wsize = n_batch * M * row_size_A; }); return desired_wsize; @@ -2365,7 +2337,7 @@ size_t ggml_backend_amx_desired_wsize(const struct ggml_tensor * dst) { // src1: input in shape of {M, K}, float32 // dst: output in shape of {M, N}, float32 // -// the function performs: dst = src1 @ src0.T +// the function performs: dst = src1 @ src0.T for each batch // void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_tensor * dst) { struct ggml_tensor * src0 = dst->src[0]; @@ -2382,17 +2354,26 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te const int K = src0->ne[0]; const int ldc = dst->nb[1] / dst->nb[0]; + const int64_t ne2 = dst->ne[2]; + const int64_t n_batch = ne2 * dst->ne[3]; + if (is_floating_type) { constexpr int BLOCK_M = 4; constexpr int BLOCK_N = 6; const int MB = div_up(M, BLOCK_M); const int NB = div_up(N, BLOCK_N); - parallel_for_ggml(params, MB * NB, [&](int begin, int end) { + parallel_for_ggml(params, n_batch * MB * NB, [&](int begin, int end) { GGML_DISPATCH_FLOATING_TYPES(TYPE, [&] { for (int i = begin; i < end; ++i) { - int mb = i / NB; - int nb = i % NB; + int batch_idx = i / (MB * NB); + int remaining = i % (MB * NB); + int mb = remaining / NB; + int nb = remaining % NB; + + int64_t src0_offset = ggml_batch_offset(src0, batch_idx, ne2); + int64_t src1_offset = ggml_batch_offset(src1, batch_idx, ne2); + int64_t dst_offset = ggml_batch_offset(dst, batch_idx, ne2); int mb_start = mb * BLOCK_M; int mb_size = std::min(BLOCK_M, M - mb_start); @@ -2424,10 +2405,10 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te void * wdata = params->wdata; //TODO: performance improvement: merge quant A - if (params->ith == 0) { + // if (params->ith == 0) { GGML_DISPATCH_QTYPES(TYPE, [&] { const size_t row_size_A = K / blck_size * sizeof(vec_dot_type); - const size_t desired_wsize = M * row_size_A; + const size_t desired_wsize = n_batch * M * row_size_A; if (params->wsize < desired_wsize) { GGML_ABORT("insufficient work space size"); } @@ -2436,12 +2417,19 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te // Q4_K, Q5_K, Q6_K, IQ4_XS handles 8 TILE_K per blck_size GGML_ASSERT(TILE_K == blck_size || TILE_K * 8 == blck_size); - const float * A_data = static_cast(src1->data); - for (int m = 0; m < M; ++m) { - from_float(A_data + m * K, (char *)wdata + m * row_size_A, K); - } + parallel_for_ggml(params, n_batch, [&](int begin, int end) { + for (int batch_idx = begin; batch_idx < end; ++batch_idx) { + int64_t src1_offset = ggml_batch_offset(src1, batch_idx, ne2); + const float * A_data = (const float *)((const char *)src1->data + src1_offset); + char * wdata_batch = (char *)wdata + batch_idx * M * row_size_A; + + for (int m = 0; m < M; ++m) { + from_float(A_data + m * K, wdata_batch + m * row_size_A, K); + } + } + }); }); - } + // } ggml_barrier(params->threadpool); @@ -2451,13 +2439,19 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te constexpr int BLOCK_N = TILE_N * kTilesN; const int NB = div_up(N, BLOCK_N); - parallel_for_ggml(params, NB, [&](int begin, int end) { + parallel_for_ggml(params, n_batch * NB, [&](int begin, int end) { GGML_DISPATCH_QTYPES(TYPE, [&] { const int KB = K / blck_size; const int TILE_SIZE = get_tile_size(); const int row_size_A = KB * sizeof(vec_dot_type); for (int i = begin; i < end; ++i) { - int nb = i; + int batch_idx = i / NB; + int nb = i % NB; + + int64_t src0_offset = ggml_batch_offset(src0, batch_idx, ne2); + int64_t dst_offset = ggml_batch_offset(dst, batch_idx, ne2); + const char * wdata_batch = (const char *)wdata + batch_idx * row_size_A; + int nb_start = nb * BLOCK_N; int nb_size = std::min(BLOCK_N, N - nb_start); // 32, 64, 96 @@ -2481,7 +2475,7 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te const int MB = div_up(M, BLOCK_M); const int NB = div_up(N, BLOCK_N); - parallel_for_ggml(params, MB * NB, [&](int begin, int end) { + parallel_for_ggml(params, n_batch * MB * NB, [&](int begin, int end) { // init tile config for each thread ggml_tile_config_init(); @@ -2491,8 +2485,14 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te const int row_size_A = KB * sizeof(vec_dot_type); for (int i = begin; i < end; ++i) { - int mb = i / NB; - int nb = i % NB; + int batch_idx = i / (MB * NB); + int remaining = i % (MB * NB); + int mb = remaining / NB; + int nb = remaining % NB; + + int64_t src0_offset = ggml_batch_offset(src0, batch_idx, ne2); + int64_t dst_offset = ggml_batch_offset(dst, batch_idx, ne2); + const char * wdata_batch = (const char *)wdata + batch_idx * M * row_size_A; int mb_start = mb * BLOCK_M; int mb_size = std::min(BLOCK_M, M - mb_start); @@ -2501,9 +2501,9 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te tinygemm_kernel_amx( mb_size, nb_size, KB, - (const char *)wdata + mb_start * row_size_A, - (const char *)src0->data + PACKED_INDEX(nb * 2, 0, KB, TILE_SIZE), - (float *) dst->data + mb_start * N + nb_start, ldc); + wdata_batch + mb_start * row_size_A, + (const char *)src0->data + src0_offset + PACKED_INDEX(nb * 2, 0, KB, TILE_SIZE), + (float *) dst->data + dst_offset + mb_start * N + nb_start, ldc); } }); });