ggml : fix AMX and add batched support (#19925)

llama-perplexity -hf ggml-org/Qwen3-0.6B-GGUF:Q4_0 -f wikitext-2-raw/wiki.test.raw -c 2048 -b 2048 --chunks 2

before this commit:

```
perplexity: calculating perplexity over 2 chunks, n_ctx=2048, batch_size=2048, n_seq=1
perplexity: 2.31 seconds per pass - ETA 0.07 minutes
[1]17.3868,[2]22.2199,
Final estimate: PPL = 22.2199 +/- 1.59692

llama_perf_context_print:        load time =     878.56 ms
llama_perf_context_print: prompt eval time =    2037.82 ms /  4096 tokens (    0.50 ms per token,  2009.99 tokens per second)
llama_perf_context_print:        eval time =       0.00 ms /     1 runs   (    0.00 ms per token,      inf tokens per second)
llama_perf_context_print:       total time =    6403.17 ms /  4097 tokens
llama_perf_context_print:    graphs reused =          0
llama_memory_breakdown_print: | memory breakdown [MiB] | total   free    self   model   context   compute    unaccounted |
llama_memory_breakdown_print: |   - Host               |                  845 =   318 +     224 +     302                |
llama_memory_breakdown_print: |   - CPU_REPACK         |                  288 =   288 +       0 +       0                |
llama_memory_breakdown_print: |   - AMX                |                   31 =    31 +       0 +       0                |
```

after this commit:

```
perplexity: calculating perplexity over 2 chunks, n_ctx=2048, batch_size=2048, n_seq=1
perplexity: 1.98 seconds per pass - ETA 0.05 minutes
[1]17.2005,[2]21.8220,
Final estimate: PPL = 21.8220 +/- 1.56485

llama_perf_context_print:        load time =     719.23 ms
llama_perf_context_print: prompt eval time =    1676.23 ms /  4096 tokens (    0.41 ms per token,  2443.58 tokens per second)
llama_perf_context_print:        eval time =       0.00 ms /     1 runs   (    0.00 ms per token,      inf tokens per second)
llama_perf_context_print:       total time =    4258.74 ms /  4097 tokens
llama_perf_context_print:    graphs reused =          0
llama_memory_breakdown_print: | memory breakdown [MiB] | total   free    self   model   context   compute    unaccounted |
llama_memory_breakdown_print: |   - Host               |                  845 =   318 +     224 +     302                |
llama_memory_breakdown_print: |   - AMX                |                  319 =   319 +       0 +       0                |
```
(no more CPU_REPACK)

after this commit, disabling amx:

```
perplexity: calculating perplexity over 2 chunks, n_ctx=2048, batch_size=2048, n_seq=1
perplexity: 2.34 seconds per pass - ETA 0.07 minutes
[1]17.2005,[2]21.8220,
Final estimate: PPL = 21.8220 +/- 1.56485

llama_perf_context_print:        load time =     841.91 ms
llama_perf_context_print: prompt eval time =    2057.28 ms /  4096 tokens (    0.50 ms per token,  1990.98 tokens per second)
llama_perf_context_print:        eval time =       0.00 ms /     1 runs   (    0.00 ms per token,      inf tokens per second)
llama_perf_context_print:       total time =    6454.51 ms /  4097 tokens
llama_perf_context_print:    graphs reused =          0
llama_memory_breakdown_print: | memory breakdown [MiB] | total   free    self   model   context   compute    unaccounted |
llama_memory_breakdown_print: |   - Host               |                  845 =   318 +     224 +     302                |
llama_memory_breakdown_print: |   - CPU_REPACK         |                  319 =   319 +       0 +       0                |
```
=> same perplexity.

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
This commit is contained in:
Adrien Gallouët 2026-02-26 21:39:11 +01:00 committed by GitHub
parent 723c71064d
commit 4e76d24f28
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 124 additions and 101 deletions

View File

@ -141,27 +141,50 @@ static size_t ggml_backend_amx_buffer_type_get_alignment(ggml_backend_buffer_typ
namespace ggml::cpu::amx {
class extra_buffer_type : ggml::cpu::extra_buffer_type {
bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
// handle only 2d gemm for now
auto is_contiguous_2d = [](const struct ggml_tensor * t) {
return ggml_is_contiguous(t) && t->ne[3] == 1 && t->ne[2] == 1;
};
if (op->op == GGML_OP_MUL_MAT && is_contiguous_2d(op->src[0]) && // src0 must be contiguous
is_contiguous_2d(op->src[1]) && // src1 must be contiguous
op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_amx_buffer_type() &&
op->src[0]->ne[0] % (TILE_K * 2 * 32) == 0 && // TODO: not sure if correct (https://github.com/ggml-org/llama.cpp/pull/16315)
op->ne[0] % (TILE_N * 2) == 0 && // out_features is 32x
(qtype_has_amx_kernels(op->src[0]->type) || (op->src[0]->type == GGML_TYPE_F16))) {
// src1 must be host buffer
if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
return false;
}
// src1 must be float32
if (op->src[1]->type == GGML_TYPE_F32) {
return true;
}
if (op->op != GGML_OP_MUL_MAT) {
return false;
}
return false;
auto * src0 = op->src[0];
auto * src1 = op->src[1];
if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) {
return false;
}
if (!src0->buffer || src0->buffer->buft != ggml_backend_amx_buffer_type()) {
return false;
}
if (src1->buffer && !ggml_backend_buft_is_host(src1->buffer->buft)) {
return false;
}
if (op->ne[0] % (TILE_N * 2)) {
return false;
}
int alignment;
switch (src0->type) {
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q8_0:
alignment = TILE_K;
break;
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
case GGML_TYPE_IQ4_XS:
alignment = 256; // QK_K
break;
case GGML_TYPE_F16:
alignment = 16;
break;
default:
return false;
}
if (src0->ne[0] % alignment) {
return false;
}
if (src1->type != GGML_TYPE_F32) {
return false;
}
return true;
}
ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {

View File

@ -1,4 +1,3 @@
#if defined(__GNUC__)
#pragma GCC diagnostic ignored "-Wpedantic"
#pragma GCC diagnostic ignored "-Wunused-local-typedefs"
@ -202,35 +201,27 @@ struct tile_config_t{
// advanced-matrix-extensions-intrinsics-functions.html
//
#define TC_CONFIG_TILE(i, r, cb) tc.rows[i] = r; tc.colsb[i] = cb
void ggml_tile_config_init(void) {
static thread_local bool is_first_time = true;
inline void ggml_tile_config_init(void) {
static thread_local bool done = false;
if (!is_first_time) {
if (done) {
return;
}
static thread_local tile_config_t tc;
tile_config_t current_tc;
_tile_storeconfig(&current_tc);
alignas(64) tile_config_t tc = {};
tc.palette_id = 1;
tc.start_row = 0;
tc.rows[0] = 8; tc.colsb[0] = 64;
tc.rows[1] = 8; tc.colsb[1] = 64;
tc.rows[2] = 16; tc.colsb[2] = 32;
tc.rows[3] = 16; tc.colsb[3] = 32;
tc.rows[4] = 16; tc.colsb[4] = 64;
tc.rows[5] = 16; tc.colsb[5] = 64;
tc.rows[6] = 16; tc.colsb[6] = 64;
tc.rows[7] = 16; tc.colsb[7] = 64;
// load only when config changes
if (tc.palette_id == 0 || (memcmp(&current_tc.colsb, &tc.colsb, sizeof(uint16_t) * 8) != 0 &&
memcmp(&current_tc.rows, &tc.rows, sizeof(uint8_t) * 8) != 0)) {
tc.palette_id = 1;
tc.start_row = 0;
TC_CONFIG_TILE(TMM0, 8, 64);
TC_CONFIG_TILE(TMM1, 8, 64);
TC_CONFIG_TILE(TMM2, 16, 32);
TC_CONFIG_TILE(TMM3, 16, 32);
TC_CONFIG_TILE(TMM4, 16, 64);
TC_CONFIG_TILE(TMM5, 16, 64);
TC_CONFIG_TILE(TMM6, 16, 64);
TC_CONFIG_TILE(TMM7, 16, 64);
_tile_loadconfig(&tc);
}
is_first_time = false;
_tile_loadconfig(&tc);
done = true;
}
// we need an extra 16 * 4B (TILE_N * int32_t) for each NB/KB block for compensation.
@ -268,33 +259,6 @@ int get_row_size(int K) {
return row_size;
}
// vectorized dtype conversion
inline float FP16_TO_FP32(ggml_half val) {
__m256i v = _mm256_setr_epi16(
val, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0);
__m512 o = _mm512_cvtph_ps(v);
return _mm512_cvtss_f32(o);
}
inline __m512 FP16_TO_FP32_VEC(ggml_half val) {
__m256i v = _mm256_set1_epi16(val);
return _mm512_cvtph_ps(v);
}
// horizontal reduce
inline float _mm512_reduce_max_ps(const __m512 x) {
__m512 v = x;
__m512 v1 = _mm512_shuffle_f32x4(v, v, 0x4E);
v = _mm512_max_ps(v, v1);
v1 = _mm512_shuffle_f32x4(v, v, 0xB1);
v = _mm512_max_ps(v, v1);
v1 = _mm512_shuffle_ps(v, v, 0x4E);
v = _mm512_max_ps(v, v1);
v1 = _mm512_shuffle_ps(v, v, 0xB1);
v = _mm512_max_ps(v, v1);
return _mm512_cvtss_f32(v);
}
// transpose utils
#define SHUFFLE_EPI32(a, b, mask) \
_mm256_castps_si256(_mm256_shuffle_ps(_mm256_castsi256_ps(a), _mm256_castsi256_ps(b), mask))
@ -1370,9 +1334,9 @@ struct tinygemm_kernel_avx<float, ggml_fp16_t, float, BLOCK_M, BLOCK_N, BLOCK_K>
#define LAUNCH_TINYGEMM_KERNEL_AVX(MB_SIZE, NB_SIZE) \
tinygemm_kernel_avx<float, type, float, MB_SIZE, NB_SIZE, blck_size>::apply( \
K, (const float *)src1->data + mb_start * K, \
(const type *)src0->data + nb_start * K, \
(float *)dst->data + mb_start * ldc + nb_start, ldc);
K, (const float *)src1->data + src1_offset + mb_start * K, \
(const type *)src0->data + src0_offset + nb_start * K, \
(float *)dst->data + dst_offset + mb_start * ldc + nb_start, ldc)
// re-organize in the format {NB, KB, TILE_SIZE}:
@ -2019,11 +1983,11 @@ struct tinygemm_kernel_vnni<block_q8_K, block_iq4_xs, float, BLOCK_M, BLOCK_N, B
}
};
#define LAUNCH_TINYGEMM_KERNEL_VNNI(NB_SIZE) \
tinygemm_kernel_vnni<vec_dot_type, type, float, 1, NB_SIZE, blck_size>::apply( \
KB, (const char *)wdata + 0 * row_size_A, \
(const char *)src0->data + PACKED_INDEX(nb * kTilesN, 0, KB, TILE_SIZE), \
(float *) dst->data + 0 * N + nb_start, ldc)
#define LAUNCH_TINYGEMM_KERNEL_VNNI(NB_SIZE) \
tinygemm_kernel_vnni<vec_dot_type, type, float, 1, NB_SIZE, blck_size>::apply( \
KB, wdata_batch, \
(const char *)src0->data + src0_offset + PACKED_INDEX(nb * kTilesN, 0, KB, TILE_SIZE), \
(float *) dst->data + dst_offset + nb_start, ldc)
template <typename TA, typename TB, typename TC, int BLOCK_K,
typename std::enable_if<!is_type_qkk<TB>::value, int>::type = 0>
@ -2079,7 +2043,7 @@ void tinygemm_kernel_amx(int M, int N, int KB, const void * RESTRICT _A, const v
_tile_stored(TMM5, Tile5(C_pre), TILE_N * sizeof(int32_t));
if (need_unpack) {
unpack_B<TB>(Tile1, B_blk0);
unpack_B<TB>(Tile1, B_blk1);
_tile_loadd(TMM1, Tile1, TILE_N * VNNI_BLK);
} else {
_tile_loadd(TMM1, B_blk1, TILE_N * VNNI_BLK);
@ -2336,6 +2300,13 @@ void ggml_backend_amx_convert_weight(struct ggml_tensor * tensor, const void * d
});
}
// ne2 is passed explicitly to help compiler optimize repeated calls
inline int64_t ggml_batch_offset(const ggml_tensor * t, int64_t batch_idx, int64_t ne2) {
const int64_t i2 = batch_idx % ne2;
const int64_t i3 = batch_idx / ne2;
return i3 * t->nb[3] + i2 * t->nb[2];
}
size_t ggml_backend_amx_desired_wsize(const struct ggml_tensor * dst) {
struct ggml_tensor * src0 = dst->src[0];
@ -2348,12 +2319,13 @@ size_t ggml_backend_amx_desired_wsize(const struct ggml_tensor * dst) {
const int M = dst->ne[1];
const int K = src0->ne[0];
const int64_t n_batch = dst->ne[2] * dst->ne[3];
size_t desired_wsize = 0;
GGML_DISPATCH_QTYPES(TYPE, [&] {
const size_t row_size_A = K / blck_size * sizeof(vec_dot_type);
desired_wsize = M * row_size_A;
desired_wsize = n_batch * M * row_size_A;
});
return desired_wsize;
@ -2365,7 +2337,7 @@ size_t ggml_backend_amx_desired_wsize(const struct ggml_tensor * dst) {
// src1: input in shape of {M, K}, float32
// dst: output in shape of {M, N}, float32
//
// the function performs: dst = src1 @ src0.T
// the function performs: dst = src1 @ src0.T for each batch
//
void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_tensor * dst) {
struct ggml_tensor * src0 = dst->src[0];
@ -2382,17 +2354,26 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te
const int K = src0->ne[0];
const int ldc = dst->nb[1] / dst->nb[0];
const int64_t ne2 = dst->ne[2];
const int64_t n_batch = ne2 * dst->ne[3];
if (is_floating_type) {
constexpr int BLOCK_M = 4;
constexpr int BLOCK_N = 6;
const int MB = div_up(M, BLOCK_M);
const int NB = div_up(N, BLOCK_N);
parallel_for_ggml(params, MB * NB, [&](int begin, int end) {
parallel_for_ggml(params, n_batch * MB * NB, [&](int begin, int end) {
GGML_DISPATCH_FLOATING_TYPES(TYPE, [&] {
for (int i = begin; i < end; ++i) {
int mb = i / NB;
int nb = i % NB;
int batch_idx = i / (MB * NB);
int remaining = i % (MB * NB);
int mb = remaining / NB;
int nb = remaining % NB;
int64_t src0_offset = ggml_batch_offset(src0, batch_idx, ne2);
int64_t src1_offset = ggml_batch_offset(src1, batch_idx, ne2);
int64_t dst_offset = ggml_batch_offset(dst, batch_idx, ne2);
int mb_start = mb * BLOCK_M;
int mb_size = std::min(BLOCK_M, M - mb_start);
@ -2424,10 +2405,10 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te
void * wdata = params->wdata;
//TODO: performance improvement: merge quant A
if (params->ith == 0) {
// if (params->ith == 0) {
GGML_DISPATCH_QTYPES(TYPE, [&] {
const size_t row_size_A = K / blck_size * sizeof(vec_dot_type);
const size_t desired_wsize = M * row_size_A;
const size_t desired_wsize = n_batch * M * row_size_A;
if (params->wsize < desired_wsize) {
GGML_ABORT("insufficient work space size");
}
@ -2436,12 +2417,19 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te
// Q4_K, Q5_K, Q6_K, IQ4_XS handles 8 TILE_K per blck_size
GGML_ASSERT(TILE_K == blck_size || TILE_K * 8 == blck_size);
const float * A_data = static_cast<const float *>(src1->data);
for (int m = 0; m < M; ++m) {
from_float<vec_dot_type>(A_data + m * K, (char *)wdata + m * row_size_A, K);
}
parallel_for_ggml(params, n_batch, [&](int begin, int end) {
for (int batch_idx = begin; batch_idx < end; ++batch_idx) {
int64_t src1_offset = ggml_batch_offset(src1, batch_idx, ne2);
const float * A_data = (const float *)((const char *)src1->data + src1_offset);
char * wdata_batch = (char *)wdata + batch_idx * M * row_size_A;
for (int m = 0; m < M; ++m) {
from_float<vec_dot_type>(A_data + m * K, wdata_batch + m * row_size_A, K);
}
}
});
});
}
// }
ggml_barrier(params->threadpool);
@ -2451,13 +2439,19 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te
constexpr int BLOCK_N = TILE_N * kTilesN;
const int NB = div_up(N, BLOCK_N);
parallel_for_ggml(params, NB, [&](int begin, int end) {
parallel_for_ggml(params, n_batch * NB, [&](int begin, int end) {
GGML_DISPATCH_QTYPES(TYPE, [&] {
const int KB = K / blck_size;
const int TILE_SIZE = get_tile_size<type>();
const int row_size_A = KB * sizeof(vec_dot_type);
for (int i = begin; i < end; ++i) {
int nb = i;
int batch_idx = i / NB;
int nb = i % NB;
int64_t src0_offset = ggml_batch_offset(src0, batch_idx, ne2);
int64_t dst_offset = ggml_batch_offset(dst, batch_idx, ne2);
const char * wdata_batch = (const char *)wdata + batch_idx * row_size_A;
int nb_start = nb * BLOCK_N;
int nb_size = std::min(BLOCK_N, N - nb_start); // 32, 64, 96
@ -2481,7 +2475,7 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te
const int MB = div_up(M, BLOCK_M);
const int NB = div_up(N, BLOCK_N);
parallel_for_ggml(params, MB * NB, [&](int begin, int end) {
parallel_for_ggml(params, n_batch * MB * NB, [&](int begin, int end) {
// init tile config for each thread
ggml_tile_config_init();
@ -2491,8 +2485,14 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te
const int row_size_A = KB * sizeof(vec_dot_type);
for (int i = begin; i < end; ++i) {
int mb = i / NB;
int nb = i % NB;
int batch_idx = i / (MB * NB);
int remaining = i % (MB * NB);
int mb = remaining / NB;
int nb = remaining % NB;
int64_t src0_offset = ggml_batch_offset(src0, batch_idx, ne2);
int64_t dst_offset = ggml_batch_offset(dst, batch_idx, ne2);
const char * wdata_batch = (const char *)wdata + batch_idx * M * row_size_A;
int mb_start = mb * BLOCK_M;
int mb_size = std::min(BLOCK_M, M - mb_start);
@ -2501,9 +2501,9 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te
tinygemm_kernel_amx<vec_dot_type, type, float, blck_size>(
mb_size, nb_size, KB,
(const char *)wdata + mb_start * row_size_A,
(const char *)src0->data + PACKED_INDEX(nb * 2, 0, KB, TILE_SIZE),
(float *) dst->data + mb_start * N + nb_start, ldc);
wdata_batch + mb_start * row_size_A,
(const char *)src0->data + src0_offset + PACKED_INDEX(nb * 2, 0, KB, TILE_SIZE),
(float *) dst->data + dst_offset + mb_start * N + nb_start, ldc);
}
});
});