From 267ba5a1d957158316a4fc85b1f0cf316d9a5233 Mon Sep 17 00:00:00 2001 From: abhijain1204fujitsu <139222713+abhijain1204fujitsu@users.noreply.github.com> Date: Mon, 16 Feb 2026 12:08:43 +0530 Subject: [PATCH 01/10] ggml: aarch64: Implement SVE in Gemm q4_k 8x8 q8_k Kernel (#19132) * Updated repack.cpp * Updated repack.cpp * Updated repack.cpp * Added if condition to support only vector length 256. * Changed the format removed comments and duplicate variable * If SVE 256 not present then was using generic function to compute, hence slowing the performance. So added code if SVE 256 is not present then use NEON code. * Code format change suggestion --------- Co-authored-by: Vithule, Prashant --- ggml/src/ggml-cpu/arch/arm/repack.cpp | 310 ++++++++++++++++++++++++++ 1 file changed, 310 insertions(+) diff --git a/ggml/src/ggml-cpu/arch/arm/repack.cpp b/ggml/src/ggml-cpu/arch/arm/repack.cpp index fd05c609f7..3a3b32efb2 100644 --- a/ggml/src/ggml-cpu/arch/arm/repack.cpp +++ b/ggml/src/ggml-cpu/arch/arm/repack.cpp @@ -3226,6 +3226,316 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, UNUSED(ncols_interleaved); UNUSED(blocklen); +#if defined(__aarch64__) && defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8) + if (svcntb() * 8 == 256) { + constexpr int q8_k_blocklen = 4; + const svuint8_t m4b_1 = svdup_n_u8(0x0f); + // 8 accumulators: 2 row pairs × 4 col pairs + svfloat32_t acc_f32_01, acc_f32_23, acc_f32_45, acc_f32_67; + uint32_t idx_arr[8] = { 0, 2, 4, 6, 1, 3, 5, 7 }; + svbool_t pg = svptrue_pat_b32(SV_VL8); + svuint32_t idx = svld1(pg, idx_arr); + + static const uint32_t idx_data[8] = {0, 4, 2, 6, 1, 5, 3, 7}; + svuint32_t idx1 = svld1_u32(svptrue_b32(), idx_data); + + for (int y = 0; y < nr / q8_k_blocklen; y++) { + const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb); + + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb); + + acc_f32_01 = svdup_n_f32(0); + acc_f32_23 = svdup_n_f32(0); + acc_f32_45 = svdup_n_f32(0); + acc_f32_67 = svdup_n_f32(0); + + for (int b = 0; b < nb; b++) { + // bsums pairs belongs to the same q8_k subblock + // 64 elemnts loaded and made sum of 0-7 and 8-15 sum || 16-23 and 24 - 31 sum + const int16x8_t bsums[4]{ + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)), + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)), + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)), + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)), + }; + + int32_t bsums_arr32[4][8]; + + for (int q8_row = 0; q8_row < 4; q8_row++) { + int16x8_t v16 = bsums[q8_row]; + + // low 4 + int32x4_t v32_lo = vmovl_s16(vget_low_s16(v16)); + vst1q_s32(&bsums_arr32[q8_row][0], v32_lo); + + // high 4 + int32x4_t v32_hi = vmovl_s16(vget_high_s16(v16)); + vst1q_s32(&bsums_arr32[q8_row][4], v32_hi); + } + + svint32_t sb_acc_0 = svdup_n_s32(0); + svint32_t sb_acc_2 = svdup_n_s32(0); + + svint32_t acc_00 = svdup_n_s32(0); + svint32_t acc_11 = svdup_n_s32(0); + svint32_t acc_22 = svdup_n_s32(0); + svint32_t acc_33 = svdup_n_s32(0); + svint32_t acc_44 = svdup_n_s32(0); + svint32_t acc_55 = svdup_n_s32(0); + svint32_t acc_66 = svdup_n_s32(0); + svint32_t acc_77 = svdup_n_s32(0); + + svint32_t bias_acc_00 = svdup_n_s32(0); + svint32_t bias_acc_22 = svdup_n_s32(0); + svint32_t bias_acc_44 = svdup_n_s32(0); + svint32_t bias_acc_66 = svdup_n_s32(0); + + for 
(int sb = 0; sb < QK_K / 64; sb++) { + // Need scales for the low and high nibbles + // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total + svint32_t block_scale_0, block_scale_1, block_scale_2, block_scale_3; + svint32_t q4sb_mins_0, q4sb_mins_1; + { + // 2-superblock I am working on + const int offset = sb * 24 + 0 * 12; + const uint8_t * scales_in = &q4_ptr[b].scales[offset]; + + const int offset1 = sb * 24 + 12; + const uint8_t * scales_in1 = &q4_ptr[b].scales[offset1]; + + constexpr uint32_t kmask1 = 0x3f3f3f3f; + constexpr uint32_t kmask2 = 0x0f0f0f0f; + constexpr uint32_t kmask3 = 0x03030303; + constexpr uint8_t scales_size = 12; + + uint32_t sm[3]; + memcpy(sm, scales_in, scales_size); + + uint32_t sm1[3]; + memcpy(sm1, scales_in1, scales_size); + + const uint32_t mins_0_3 = sm[1] & kmask1; + const uint32_t mins_4_7 = ((sm[2] >> 4) & kmask2) | (((sm[1] >> 6) & kmask3) << 4); + + const uint32_t mins_0_3_1 = sm1[1] & kmask1; + const uint32_t mins_4_7_1 = ((sm1[2] >> 4) & kmask2) | (((sm1[1] >> 6) & kmask3) << 4); + + svuint32_t mins_u32_temp = svzip1_u32(svdup_n_u32(mins_0_3), svdup_n_u32(mins_4_7)); + svuint32_t mins_u32_temp_1 = svzip1_u32(svdup_n_u32(mins_0_3_1), svdup_n_u32(mins_4_7_1)); + + /* reinterpret u32 → u8 */ + svuint8_t mins_u8 = svreinterpret_u8_u32(mins_u32_temp); + svuint8_t mins_u8_1 = svreinterpret_u8_u32(mins_u32_temp_1); + + /* widen u8 → u16->u32 (lower half only) */ + svuint32_t mins_u16 = svunpklo_u32(svunpklo_u16(mins_u8)); + svuint32_t mins_u16_1 = svunpklo_u32(svunpklo_u16(mins_u8_1)); + + q4sb_mins_0 = svreinterpret_s32_u32(mins_u16); + q4sb_mins_1 = svreinterpret_s32_u32(mins_u16_1); + + uint32_t scales_u32_0 = sm[0] & kmask1; + uint32_t scales_u32_1 = (sm[2] & kmask2) | (((sm[0] >> 6) & kmask3) << 4); + uint32_t scales_u32_2 = sm1[0] & kmask1; + uint32_t scales_u32_3 = (sm1[2] & kmask2) | (((sm1[0] >> 6) & kmask3) << 4); + + svuint32_t S01 = svdup_n_u32(scales_u32_0); + svuint32_t S23 = svdup_n_u32(scales_u32_1); + svuint32_t R01 = svdup_n_u32(scales_u32_2); + svuint32_t R23 = svdup_n_u32(scales_u32_3); + + svint8_t S01_b = svreinterpret_s8_u32(S01); + svint8_t S23_b = svreinterpret_s8_u32(S23); + svint8_t R01_b = svreinterpret_s8_u32(R01); + svint8_t R23_b = svreinterpret_s8_u32(R23); + + svint32_t S01_d = svunpklo_s32(svunpklo_s16(svzip1_s8(S01_b, S01_b))); + svint32_t R01_d = svunpklo_s32(svunpklo_s16(svzip1_s8(R01_b, R01_b))); + svint32_t S23_d = svunpklo_s32(svunpklo_s16(svzip1_s8(S23_b, S23_b))); + svint32_t R23_d = svunpklo_s32(svunpklo_s16(svzip1_s8(R23_b, R23_b))); + + block_scale_0 = svtbl_s32(svzip1_s32(S01_d, R01_d), idx); + block_scale_1 = svtbl_s32(svzip2_s32(S01_d, R01_d), idx); + block_scale_2 = svtbl_s32(svzip1_s32(S23_d, R23_d), idx); + block_scale_3 = svtbl_s32(svzip2_s32(S23_d, R23_d), idx); + } + + const int8_t * q8_base_1 = q8_ptr[b].qs + sb * 256; + + // Load 32-byte per row pair, 1 subblock each time + // predicate for activating higher lanes for 16 int8 elements + const svbool_t ph16 = svptrue_pat_b8(SV_VL16); + // predicate for activating lower lanes for 16 int8 elements + const svbool_t pl16 = svnot_b_z(svptrue_b8(), ph16); + + svint8_t q8_qs_0 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 0), svld1_s8(pl16, q8_base_1 + 112)); + svint8_t q8_qs_2 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 32), svld1_s8(pl16, q8_base_1 + 144)); + svint8_t q8_qs_4 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 64), svld1_s8(pl16, q8_base_1 + 176)); + svint8_t q8_qs_6 = svadd_s8_x(svptrue_b8(), 
svld1_s8(ph16, q8_base_1 + 96), svld1_s8(pl16, q8_base_1 + 208)); + + svint8_t q8_qs_1 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 16), svld1_s8(pl16, q8_base_1 + 128)); + svint8_t q8_qs_3 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 48), svld1_s8(pl16, q8_base_1 + 160)); + svint8_t q8_qs_5 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 80), svld1_s8(pl16, q8_base_1 + 192)); + svint8_t q8_qs_7 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 112), svld1_s8(pl16, q8_base_1 + 224)); + + // Q4s columns iterated in pairs (01, 23, 45, 67) + for (int cp = 0; cp < ncols_interleaved / 2; cp++) { + + sb_acc_0 = svdup_n_s32(0); + sb_acc_2 = svdup_n_s32(0); + + svuint8_t q4_qs_cp_00 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 0); + svuint8_t q4_qs_cp_01 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 64); + svuint8_t q4_qs_cp_02 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 128); + svuint8_t q4_qs_cp_03 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 192); + + svint8_t q4_nibbles_00 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_00, m4b_1), 4)); + svint8_t q4_nibbles_01 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_01, m4b_1), 4)); + svint8_t q4_nibbles_02 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_02, m4b_1), 4)); + svint8_t q4_nibbles_03 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_03, m4b_1), 4)); + + sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_00, q8_qs_0); + sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_01, q8_qs_2); + + sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_02, q8_qs_4); + sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_03, q8_qs_6); + + sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_00, q8_qs_1); + sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_01, q8_qs_3); + + sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_02, q8_qs_5); + sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_03, q8_qs_7); + + if(cp == 0) { + acc_00 = svmla_s32_m(svptrue_b32(), acc_00, sb_acc_0, block_scale_0); + acc_44 = svmla_s32_m(svptrue_b32(), acc_44, sb_acc_2, block_scale_0); + } + if(cp == 1) { + acc_11 = svmla_s32_m(svptrue_b32(), acc_11, sb_acc_0, block_scale_1); + acc_55 = svmla_s32_m(svptrue_b32(), acc_55, sb_acc_2, block_scale_1); + } + if(cp == 2) { + acc_22 = svmla_s32_m(svptrue_b32(), acc_22, sb_acc_0, block_scale_2); + acc_66 = svmla_s32_m(svptrue_b32(), acc_66, sb_acc_2, block_scale_2); + } + if(cp == 3) { + acc_33 = svmla_s32_m(svptrue_b32(), acc_33, sb_acc_0, block_scale_3); + acc_77 = svmla_s32_m(svptrue_b32(), acc_77, sb_acc_2, block_scale_3); + } + } + + bias_acc_00 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_00, svdup_n_s32(bsums_arr32[sb][0]), q4sb_mins_0); + bias_acc_00 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_00, svdup_n_s32(bsums_arr32[sb][1]), q4sb_mins_1); + + bias_acc_22 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_22, svdup_n_s32(bsums_arr32[sb][2]), q4sb_mins_0); + bias_acc_22 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_22, svdup_n_s32(bsums_arr32[sb][3]), q4sb_mins_1); + + bias_acc_44 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_44, svdup_n_s32(bsums_arr32[sb][4]), q4sb_mins_0); + bias_acc_44 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_44, svdup_n_s32(bsums_arr32[sb][5]), q4sb_mins_1); + + bias_acc_66 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_66, svdup_n_s32(bsums_arr32[sb][6]), q4sb_mins_0); + bias_acc_66 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_66, svdup_n_s32(bsums_arr32[sb][7]), q4sb_mins_1); + 
} // for sb + + + acc_00 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_00, svext_s32(acc_00, acc_00, 4)); + acc_11 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_11, svext_s32(acc_11, acc_11, 4)); + acc_22 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_22, svext_s32(acc_22, acc_22, 4)); + acc_33 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_33, svext_s32(acc_33, acc_33, 4)); + acc_44 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_44, svext_s32(acc_44, acc_44, 4)); + acc_55 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_55, svext_s32(acc_55, acc_55, 4)); + acc_66 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_66, svext_s32(acc_66, acc_66, 4)); + acc_77 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_77, svext_s32(acc_77, acc_77, 4)); + + svint32_t reorder_acc_01 = svtbl_s32( svzip1_s32( svtrn1_s32(acc_00, acc_11), svtrn1_s32(acc_22, acc_33)), idx1); + svint32_t reorder_acc_23 = svtbl_s32( svzip1_s32( svtrn2_s32(acc_00, acc_11), svtrn2_s32(acc_22, acc_33)), idx1); + + svint32_t reorder_acc_45 = svtbl_s32( svzip1_s32( svtrn1_s32(acc_44, acc_55), svtrn1_s32(acc_66, acc_77)), idx1); + svint32_t reorder_acc_67 = svtbl_s32( svzip1_s32( svtrn2_s32(acc_44, acc_55), svtrn2_s32(acc_66, acc_77)), idx1); + + // Broadcast q8 scalar + svfloat32_t q8_d = svdup_f32(q8_ptr[b].d[0]); + + svfloat32_t q4_dmin_temp = svcvt_f32_f16_x(svptrue_b32(), svzip1_f16( svld1_f16(svptrue_pat_b16(SV_VL8), (const __fp16 *)q4_ptr[b].dmin), svdup_f16(0))); + + svfloat32_t q4_d_temp = svcvt_f32_f16_x(svptrue_b32(), svzip1_f16( svld1_f16(svptrue_pat_b16(SV_VL8), (const __fp16 *)q4_ptr[b].d), svdup_f16(0))); + + svfloat32_t scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d); + svfloat32_t dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d); + + acc_f32_01 = svmls_f32_m(svptrue_b32(), acc_f32_01, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_00), dmins1); + acc_f32_01 = svmla_f32_m(svptrue_b32(), acc_f32_01, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_01), scale1); + + q8_d = svdup_f32(q8_ptr[b].d[1]); + + scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d); + dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d); + + acc_f32_23 = svmls_f32_m(svptrue_b32(), acc_f32_23, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_22), dmins1); + acc_f32_23 = svmla_f32_m(svptrue_b32(), acc_f32_23, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_23), scale1); + + q8_d = svdup_f32(q8_ptr[b].d[2]); + + + scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d); + dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d); + + acc_f32_45 = svmls_f32_m(svptrue_b32(), acc_f32_45, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_44), dmins1); + acc_f32_45 = svmla_f32_m(svptrue_b32(), acc_f32_45, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_45), scale1); + + q8_d = svdup_f32(q8_ptr[b].d[3]); + + scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d); + dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d); + + acc_f32_67 = svmls_f32_m(svptrue_b32(), acc_f32_67, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_66), dmins1); + acc_f32_67 = svmla_f32_m(svptrue_b32(), acc_f32_67, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_67), scale1); + + } // for b + + // With the previous reorder, the tile is already in the correct memory layout. 
+ // Predicate for exactly 4 lanes + svbool_t pg4 = svptrue_pat_b32(SV_VL4); + for (int i = 0; i < q8_k_blocklen; i++) { + int row = y * q8_k_blocklen + i; + for (int j = 0; j < 2; j++) { + int col = x * ncols_interleaved + j * 4; + int offset = row * bs + col; + + if (i == 0 && j == 0) { + // acc_f32_0 → lower half of acc_f32_01 + svst1_f32(pg4, s + offset, acc_f32_01); + } else if (i == 0 && j == 1) { + // acc_f32_1 → upper half of acc_f32_01 + svst1_f32(pg4, s + offset, svext_f32(acc_f32_01, acc_f32_01, 4)); + } else if (i == 1 && j == 0) { + // acc_f32_2 + svst1_f32(pg4, s + offset, acc_f32_23); + } else if (i == 1 && j == 1) { + // acc_f32_3 + svst1_f32(pg4, s + offset, svext_f32(acc_f32_23, acc_f32_23, 4)); + } else if (i == 2 && j == 0) { + // acc_f32_4 + svst1_f32(pg4, s + offset, acc_f32_45); + } else if (i == 2 && j == 1) { + // acc_f32_5 + svst1_f32(pg4, s + offset, svext_f32(acc_f32_45, acc_f32_45, 4)); + } else if (i == 3 && j == 0) { + // acc_f32_6 + svst1_f32(pg4, s + offset, acc_f32_67); + } else if (i == 3 && j == 1) { + // acc_f32_7 + svst1_f32(pg4, s + offset, svext_f32(acc_f32_67, acc_f32_67, 4)); + } + } + } + } // for x + } // for y + return; + } +#endif // SVE compile-time end + #if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) constexpr int q8_k_blocklen = 4; const uint8x16_t m4b = vdupq_n_u8(0x0f); From d5dfc330279a76fa25614db1c6413599e8974da3 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 16 Feb 2026 09:21:11 +0200 Subject: [PATCH 02/10] graph : fix KQ mask, lora, cvec reuse checks (#19644) * graph : fix KQ mask reuse condition * cont : dedup KQ mask build and can_reuse * cont : fix build * graph : fix adapter check for reuse --- src/llama-adapter.h | 3 ++ src/llama-context.cpp | 18 +++++---- src/llama-context.h | 7 ++-- src/llama-graph.cpp | 93 +++++++++++++++++++++++-------------------- 4 files changed, 67 insertions(+), 54 deletions(-) diff --git a/src/llama-adapter.h b/src/llama-adapter.h index d275d25425..aa3ab63ad7 100644 --- a/src/llama-adapter.h +++ b/src/llama-adapter.h @@ -39,6 +39,8 @@ private: std::vector tensors; // per layer }; +using llama_adapter_cvec_ptr = std::shared_ptr; + // // llama_adapter_lora // @@ -84,3 +86,4 @@ struct llama_adapter_lora { }; using llama_adapter_loras = std::unordered_map; +using llama_adapter_loras_ptr = std::unique_ptr; diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 99035b6cac..fc05989aa5 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -22,6 +22,8 @@ llama_context::llama_context( const llama_model & model, llama_context_params params) : model(model), + cvec(std::make_unique()), + loras(std::make_unique()), balloc(std::make_unique(model.hparams.n_pos_per_embd())) { // TODO warning when creating llama_context with awkward ctx size that is not a power of 2, // may need to be backend-dependent @@ -1065,11 +1067,11 @@ void llama_context::set_adapters_lora(llama_adapter_lora ** adapters, size_t n_a return; } - loras.clear(); + loras.reset(new llama_adapter_loras()); for (size_t i = 0; i < n_adapters; i ++) { if (scales[i] != 0.0f) { - loras[adapters[i]] = scales[i]; + loras->insert({adapters[i], scales[i]}); } } @@ -1079,14 +1081,14 @@ void llama_context::set_adapters_lora(llama_adapter_lora ** adapters, size_t n_a bool llama_context::adapters_lora_are_same(llama_adapter_lora ** adapters, size_t n_adapters, float * scales) { LLAMA_LOG_DEBUG("%s: adapters = %p\n", __func__, (void *) adapters); - if (n_adapters != loras.size()) { + if (n_adapters 
!= loras->size()) { return false; } for (size_t i = 0; i < n_adapters; i ++) { - auto it = loras.find(adapters[i]); + auto it = loras->find(adapters[i]); - if (it == loras.end() || it->second != scales[i]) { + if (it == loras->end() || it->second != scales[i]) { return false; } } @@ -1104,7 +1106,7 @@ bool llama_context::set_adapter_cvec( // TODO: should we reserve? - return cvec.apply(model, data, len, n_embd, il_start, il_end); + return cvec->apply(model, data, len, n_embd, il_start, il_end); } llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) { @@ -2081,8 +2083,8 @@ llm_graph_params llama_context::graph_params( /*.gtype =*/ gtype, /*.sched =*/ sched.get(), /*.backend_cpu =*/ backend_cpu, - /*.cvec =*/ &cvec, - /*.loras =*/ &loras, + /*.cvec =*/ cvec.get(), + /*.loras =*/ loras.get(), /*.mctx =*/ mctx, /*.cross =*/ &cross, /*.samplers =*/ sampling.samplers, diff --git a/src/llama-context.h b/src/llama-context.h index a8e53f335c..e0d0085c1c 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -256,9 +256,10 @@ private: const llama_model & model; - llama_cparams cparams; - llama_adapter_cvec cvec; - llama_adapter_loras loras; + llama_cparams cparams; + + llama_adapter_cvec_ptr cvec; + llama_adapter_loras_ptr loras; llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index bba747d37b..70d8ff02a9 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -17,6 +17,41 @@ #include #include +// dedup helpers + +static ggml_tensor * build_kq_mask( + ggml_context * ctx, + const llama_kv_cache_context * mctx, + const llama_ubatch & ubatch, + const llama_cparams & cparams) { + const auto n_kv = mctx->get_n_kv(); + const auto n_tokens = ubatch.n_tokens; + const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq; + + return ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); +} + +static bool can_reuse_kq_mask( + ggml_tensor * kq_mask, + const llama_kv_cache_context * mctx, + const llama_ubatch & ubatch, + const llama_cparams & cparams) { + const auto n_kv = mctx->get_n_kv(); + const auto n_tokens = ubatch.n_tokens; + const auto n_stream = cparams.kv_unified ? 
1 : ubatch.n_seqs_unq; + + bool res = true; + + res &= (kq_mask->ne[0] == n_kv); + res &= (kq_mask->ne[1] == n_tokens/n_stream); + res &= (kq_mask->ne[2] == 1); + res &= (kq_mask->ne[3] == n_stream); + + return res; +} + +// impl + void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) { if (ubatch->token) { const int64_t n_tokens = ubatch->n_tokens; @@ -403,8 +438,7 @@ bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) { res &= self_k_idxs->ne[0] == params.ubatch.n_tokens; //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there - res &= self_kq_mask->ne[0] == mctx->get_n_kv(); - res &= self_kq_mask->ne[1] == params.ubatch.n_tokens; + res &= can_reuse_kq_mask(self_kq_mask, mctx, params.ubatch, params.cparams); return res; } @@ -424,8 +458,7 @@ bool llm_graph_input_attn_k::can_reuse(const llm_graph_params & params) { res &= self_k_idxs->ne[0] == params.ubatch.n_tokens; - res &= self_kq_mask->ne[0] == mctx->get_n_kv(); - res &= self_kq_mask->ne[1] == params.ubatch.n_tokens; + res &= can_reuse_kq_mask(self_kq_mask, mctx, params.ubatch, params.cparams); return res; } @@ -455,11 +488,8 @@ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) { res &= self_k_idxs_swa->ne[0] == params.ubatch.n_tokens; //res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there - res &= self_kq_mask->ne[0] == mctx->get_base()->get_n_kv(); - res &= self_kq_mask->ne[1] == params.ubatch.n_tokens; - - res &= self_kq_mask_swa->ne[0] == mctx->get_swa()->get_n_kv(); - res &= self_kq_mask_swa->ne[1] == params.ubatch.n_tokens; + res &= can_reuse_kq_mask(self_kq_mask, mctx->get_base(), params.ubatch, params.cparams); + res &= can_reuse_kq_mask(self_kq_mask_swa, mctx->get_swa(), params.ubatch, params.cparams); return res; } @@ -521,8 +551,7 @@ bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) { res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens; //res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there - res &= inp_attn->self_kq_mask->ne[0] == mctx->get_attn()->get_n_kv(); - res &= inp_attn->self_kq_mask->ne[1] == params.ubatch.n_tokens; + res &= can_reuse_kq_mask(inp_attn->self_kq_mask, mctx->get_attn(), params.ubatch, params.cparams); res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs(); @@ -565,8 +594,7 @@ bool llm_graph_input_mem_hybrid_k::can_reuse(const llm_graph_params & params) { res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens; - res &= inp_attn->self_kq_mask->ne[0] == mctx->get_attn()->get_n_kv(); - res &= inp_attn->self_kq_mask->ne[1] == params.ubatch.n_tokens; + res &= can_reuse_kq_mask(inp_attn->self_kq_mask, mctx->get_attn(), params.ubatch, params.cparams); res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs(); @@ -625,8 +653,7 @@ bool llm_graph_input_mem_hybrid_iswa::can_reuse(const llm_graph_params & params) res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens; //res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there - res &= inp_attn->self_kq_mask->ne[0] == attn_ctx->get_base()->get_n_kv(); - res &= inp_attn->self_kq_mask->ne[1] == params.ubatch.n_tokens; + res &= can_reuse_kq_mask(inp_attn->self_kq_mask, attn_ctx->get_base(), params.ubatch, params.cparams); } // swa tensors may not be allocated if there are no SWA 
attention layers @@ -634,8 +661,7 @@ bool llm_graph_input_mem_hybrid_iswa::can_reuse(const llm_graph_params & params) res &= inp_attn->self_k_idxs_swa->ne[0] == params.ubatch.n_tokens; //res &= inp_attn->self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there - res &= inp_attn->self_kq_mask_swa->ne[0] == attn_ctx->get_swa()->get_n_kv(); - res &= inp_attn->self_kq_mask_swa->ne[1] == params.ubatch.n_tokens; + res &= can_reuse_kq_mask(inp_attn->self_kq_mask_swa, attn_ctx->get_swa(), params.ubatch, params.cparams); } res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs(); @@ -1891,14 +1917,11 @@ static std::unique_ptr build_attn_inp_kv_impl( { GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA"); - const auto n_kv = mctx_cur->get_n_kv(); - const auto n_tokens = ubatch.n_tokens; - const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq; - inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch); inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch); - inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); + inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur, ubatch, cparams); + ggml_set_input(inp->self_kq_mask); inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; @@ -1983,13 +2006,9 @@ static std::unique_ptr build_attn_inp_k_impl( { GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA"); - const auto n_kv = mctx_cur->get_n_kv(); - const auto n_tokens = ubatch.n_tokens; - const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq; - inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch); - inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); + inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur, ubatch, cparams); ggml_set_input(inp->self_kq_mask); inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; @@ -2188,15 +2207,11 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const auto inp = std::make_unique(hparams, cparams, mctx_cur); - const auto n_stream = cparams.kv_unified ? 
1 : ubatch.n_seqs_unq; - { - const auto n_kv = mctx_cur->get_base()->get_n_kv(); - inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch); inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch); - inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); + inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur->get_base(), ubatch, cparams); ggml_set_input(inp->self_kq_mask); ggml_set_name(inp->self_kq_mask, "self_kq_mask"); @@ -2207,12 +2222,10 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const { GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache for non-SWA"); - const auto n_kv = mctx_cur->get_swa()->get_n_kv(); - inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch); inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch); - inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); + inp->self_kq_mask_swa = build_kq_mask(ctx0, mctx_cur->get_swa(), ubatch, cparams); ggml_set_input(inp->self_kq_mask_swa); ggml_set_name(inp->self_kq_mask_swa, "self_kq_mask_swa"); @@ -2374,27 +2387,21 @@ llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa() auto inp_attn = std::make_unique(hparams, cparams, attn_ctx); - const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq; - { - const auto n_kv = attn_ctx->get_base()->get_n_kv(); - inp_attn->self_k_idxs = attn_ctx->get_base()->build_input_k_idxs(ctx0, ubatch); inp_attn->self_v_idxs = attn_ctx->get_base()->build_input_v_idxs(ctx0, ubatch); - inp_attn->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); + inp_attn->self_kq_mask = build_kq_mask(ctx0, attn_ctx->get_base(), ubatch, cparams); ggml_set_input(inp_attn->self_kq_mask); inp_attn->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask, GGML_TYPE_F16) : inp_attn->self_kq_mask; } { - const auto n_kv = attn_ctx->get_swa()->get_n_kv(); - inp_attn->self_k_idxs_swa = attn_ctx->get_swa()->build_input_k_idxs(ctx0, ubatch); inp_attn->self_v_idxs_swa = attn_ctx->get_swa()->build_input_v_idxs(ctx0, ubatch); - inp_attn->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); + inp_attn->self_kq_mask_swa = build_kq_mask(ctx0, attn_ctx->get_swa(), ubatch, cparams); ggml_set_input(inp_attn->self_kq_mask_swa); inp_attn->self_kq_mask_swa_cnv = cparams.flash_attn ? 
ggml_cast(ctx0, inp_attn->self_kq_mask_swa, GGML_TYPE_F16) : inp_attn->self_kq_mask_swa; From cc45f2ada695644c6697c0fb0e70a5e95563ad0f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 16 Feb 2026 14:35:04 +0200 Subject: [PATCH 03/10] models : deduplicate delta-net graphs for Qwen family (#19597) * models : add llm_build_delta_net_base * cont : keep qwen35 and qwen35moe graphs intact * cont : add comments --- src/CMakeLists.txt | 15 +- src/models/delta-net-base.cpp | 333 ++++++++++++++++++ src/models/falcon-h1.cpp | 4 +- src/models/granite-hybrid.cpp | 2 +- src/models/jamba.cpp | 2 +- src/models/kimi-linear.cpp | 4 +- ...graph-context-mamba.cpp => mamba-base.cpp} | 8 +- src/models/mamba.cpp | 3 +- src/models/models.h | 96 ++--- src/models/nemotron-h.cpp | 10 +- src/models/plamo2.cpp | 4 +- src/models/qwen35.cpp | 5 +- src/models/qwen35moe.cpp | 5 +- src/models/qwen3next.cpp | 325 +---------------- src/models/rwkv6-base.cpp | 2 + src/models/rwkv7-base.cpp | 2 + 16 files changed, 428 insertions(+), 392 deletions(-) create mode 100644 src/models/delta-net-base.cpp rename src/models/{graph-context-mamba.cpp => mamba-base.cpp} (97%) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index fdda05d3ea..daf249422a 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -57,13 +57,14 @@ add_library(llama models/deci.cpp models/deepseek.cpp models/deepseek2.cpp + models/delta-net-base.cpp models/dots1.cpp models/dream.cpp models/ernie4-5-moe.cpp models/ernie4-5.cpp + models/exaone-moe.cpp models/exaone.cpp models/exaone4.cpp - models/exaone-moe.cpp models/falcon-h1.cpp models/falcon.cpp models/gemma-embedding.cpp @@ -91,10 +92,12 @@ add_library(llama models/llama-iswa.cpp models/llama.cpp models/maincoder.cpp + models/mamba-base.cpp models/mamba.cpp models/mimo2-iswa.cpp models/minicpm3.cpp models/minimax-m2.cpp + models/mistral3.cpp models/modern-bert.cpp models/mpt.cpp models/nemotron-h.cpp @@ -118,12 +121,12 @@ add_library(llama models/qwen2moe.cpp models/qwen2vl.cpp models/qwen3.cpp - models/qwen3vl.cpp - models/qwen3vl-moe.cpp - models/qwen3moe.cpp - models/qwen3next.cpp models/qwen35.cpp models/qwen35moe.cpp + models/qwen3moe.cpp + models/qwen3next.cpp + models/qwen3vl-moe.cpp + models/qwen3vl.cpp models/refact.cpp models/rnd1.cpp models/rwkv6-base.cpp @@ -142,8 +145,6 @@ add_library(llama models/t5-enc.cpp models/wavtokenizer-dec.cpp models/xverse.cpp - models/mistral3.cpp - models/graph-context-mamba.cpp ) set_target_properties(llama PROPERTIES diff --git a/src/models/delta-net-base.cpp b/src/models/delta-net-base.cpp new file mode 100644 index 0000000000..0cdf9c324b --- /dev/null +++ b/src/models/delta-net-base.cpp @@ -0,0 +1,333 @@ +#include "models.h" + +#define CHUNK_SIZE 64 + +// utility to get one slice from the third dimension +// input dim: [x, y, c, b] +// output dim: [x, y, 1, b] +static ggml_tensor * get_slice_2d(ggml_context * ctx0, ggml_tensor * t, int64_t c) { + return ggml_view_4d(ctx0, t, t->ne[0], t->ne[1], 1, t->ne[3], + t->nb[1], t->nb[2], t->nb[3], t->nb[2] * c); +} + +llm_build_delta_net_base::llm_build_delta_net_base(const llm_graph_params & params) : llm_graph_context(params) {} + +std::pair llm_build_delta_net_base::build_delta_net_chunking( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * b, + ggml_tensor * s, + int il) { + const int64_t S_k = q->ne[0]; + const int64_t H_k = q->ne[1]; + const int64_t n_tokens = q->ne[2]; + const int64_t n_seqs = q->ne[3]; + + const int64_t S_v = v->ne[0]; + const int64_t H_v = 
v->ne[1]; + + GGML_ASSERT(S_k == S_v); + GGML_ASSERT(H_v % H_k == 0); + + GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs); + GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); + GGML_ASSERT(v->ne[0] == S_v && v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == n_seqs); + + GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs); + GGML_ASSERT(b->ne[0] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs); + GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs); + + const float scale = 1.0f / sqrtf(S_k); + + q = ggml_scale(ctx0, q, scale); + + cb(q, "q_in", il); + cb(k, "k_in", il); + cb(v, "v_in", il); + cb(b, "b_in", il); + cb(g, "g_in", il); + + q = ggml_permute(ctx0, q, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs] + k = ggml_permute(ctx0, k, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs] + v = ggml_permute(ctx0, v, 0, 2, 1, 3); // [S_v, n_tokens, H_v, n_seqs] + g = ggml_permute(ctx0, g, 2, 1, 3, 0); // [ 1, n_tokens, H_v, n_seqs] + b = ggml_permute(ctx0, b, 2, 0, 1, 3); // [ 1, n_tokens, H_v, n_seqs] + + const int CS = CHUNK_SIZE; + + const int pad = (CS - n_tokens % CS) % CS; + const int n_chunks = (n_tokens + pad) / CS; + + q = ggml_pad(ctx0, q, 0, pad, 0, 0); + k = ggml_pad(ctx0, k, 0, pad, 0, 0); + v = ggml_pad(ctx0, v, 0, pad, 0, 0); + g = ggml_pad(ctx0, g, 0, pad, 0, 0); + b = ggml_pad(ctx0, b, 0, pad, 0, 0); + + ggml_tensor * v_b = ggml_mul(ctx0, v, b); + ggml_tensor * k_b = ggml_mul(ctx0, k, b); + + cb(v_b, "v_b", il); + cb(k_b, "k_b", il); + + q = ggml_reshape_4d(ctx0, q, S_k, CS, n_chunks, H_k * n_seqs); + k = ggml_reshape_4d(ctx0, k, S_k, CS, n_chunks, H_k * n_seqs); + k_b = ggml_reshape_4d(ctx0, k_b, S_k, CS, n_chunks, H_v * n_seqs); + v = ggml_reshape_4d(ctx0, v, S_v, CS, n_chunks, H_v * n_seqs); + v_b = ggml_reshape_4d(ctx0, v_b, S_v, CS, n_chunks, H_v * n_seqs); + + g = ggml_reshape_4d(ctx0, g, CS, 1, n_chunks, H_v * n_seqs); + b = ggml_reshape_4d(ctx0, b, 1, CS, n_chunks, H_v * n_seqs); + + // [CS, 1, n_chunks, H_v * n_seqs] + ggml_tensor * g_cs = ggml_cumsum(ctx0, g); + cb(g_cs, "g_cs", il); + + ggml_tensor * g_cs_i = g_cs; + ggml_tensor * g_cs_j = ggml_reshape_4d(ctx0, g_cs, 1, CS, n_chunks, H_v * n_seqs); + + g_cs_j = ggml_repeat_4d(ctx0, g_cs_j, CS, CS, n_chunks, H_v * n_seqs); + + // [CS, CS, n_chunks, H_v * n_seqs] + ggml_tensor * decay_mask; + decay_mask = ggml_sub(ctx0, g_cs_j, g_cs_i); + decay_mask = ggml_tri(ctx0, decay_mask, GGML_TRI_TYPE_LOWER_DIAG); + decay_mask = ggml_exp(ctx0, decay_mask); + cb(decay_mask, "decay_mask", il); + + // [CS, CS, n_chunks, H_k * n_seqs] + ggml_tensor * kb; + kb = ggml_mul_mat(ctx0, k, k_b); + kb = ggml_mul (ctx0, kb, decay_mask); + + // [CS, CS, n_chunks, H_k * n_seqs] + ggml_tensor * attn; + attn = ggml_tri(ctx0, kb, GGML_TRI_TYPE_LOWER); + + ggml_tensor * identity; + identity = ggml_view_1d(ctx0, attn, CS, 0); + identity = ggml_fill (ctx0, identity, 1.0f); + identity = ggml_diag (ctx0, identity); + + ggml_tensor * lhs = ggml_add(ctx0, attn, identity); + cb(lhs, "dnet_add_ch_lhs", il); + + attn = ggml_neg(ctx0, attn); + + ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false); + attn = ggml_add(ctx0, lin_solve, identity); + cb(attn, "dnet_add_ch_attn_solved", il); // [CS, CS, n_chunks, H_k * n_seqs] + + // [S_v, CS, n_chunks, H_v * n_seqs] + v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_b)), attn); + + // [CS, 1, n_chunks, H_v * n_seqs] + 
ggml_tensor * g_exp = ggml_exp(ctx0, g_cs); + + k_b = ggml_cont(ctx0, ggml_transpose(ctx0, k_b)); + + // [CS, S_k, n_chunks, H_k * n_seqs] + ggml_tensor * kbg = ggml_mul(ctx0, k_b, g_exp); + cb(kbg, "k_beta_g_exp", il); + + // [S_k, CS, n_chunks, H_k * n_seqs] + ggml_tensor * k_cd = ggml_mul_mat(ctx0, kbg, attn); + cb(k_cd, "k_cumdecay", il); + + // [S_k, CS, n_chunks, H_k * n_seqs] + ggml_tensor * g_exp_t = ggml_transpose(ctx0, g_exp); + ggml_tensor * q_g_exp = ggml_mul(ctx0, q, g_exp_t); + + // [CS, CS, n_chunks, H_k * n_seqs] + ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); + kq = ggml_mul(ctx0, kq, decay_mask); + kq = ggml_tri(ctx0, kq, GGML_TRI_TYPE_LOWER_DIAG); + cb(kq, "kq", il); + + // vectorized calculation of key_gdiff + // improved from the chunked version: + // g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1) + // g_diff = torch.clamp(g_cum[:, :, -1:] - g_cum, max=50.0).exp() + // key_gdiff = key * g_diff.unsqueeze(-1) + // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new + // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew + + // get last element in g_cumsum along CS dimension (ne0) + // example: [[x, y, z, ..., last], ...] -> [[last], ...] + // [1, 1, n_chunks, H_v * n_seqs] + ggml_tensor * g_last = ggml_view_4d(ctx0, g_cs, 1, 1, g_cs->ne[2], g_cs->ne[3], + g_cs->nb[1], + g_cs->nb[2], + g_cs->nb[3], + ggml_row_size(g_cs->type, g_cs->ne[0] - 1)); + cb(g_last, "g_last", il); + + // TODO: remove this cont when CUDA supports non-cont unary ops + g_last = ggml_cont(ctx0, g_last); + + // [1, 1, n_chunks, H_v * n_seqs] + ggml_tensor * g_last_exp = ggml_exp(ctx0, g_last); + cb(g_last_exp, "g_last_exp", il); + + // [CS, 1, n_chunks, H_v * n_seqs] + ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cs, g_last)); + cb(g_diff, "g_diff", il); + + ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff); + ggml_tensor * g_diff_exp_t = ggml_transpose(ctx0, g_diff_exp); + + // [S_k, CS, n_chunks, H_v * n_seqs] + ggml_tensor * kg = ggml_mul(ctx0, k, g_diff_exp_t); + cb(kg, "key_gdiff", il); + + // [CS, S_k, n_chunks, H_v * n_seqs] + ggml_tensor * kg_t = ggml_cont(ctx0, ggml_transpose(ctx0, kg)); + cb(kg_t, "key_gdiff_t", il); + + ggml_tensor * s_t = ggml_transpose(ctx0, s); + s_t = ggml_cont_4d(ctx0, s_t, S_v, S_v, 1, H_v * n_seqs); + cb(s_t, "dnet_add_ch_state", il); + + // [CS, S_v, n_chunks, H_v * n_seqs] + ggml_tensor * v_t = ggml_cont(ctx0, ggml_transpose(ctx0, v)); + + for (int64_t chunk = 0; chunk < n_chunks; chunk++) { + ggml_tensor * ch_k_cd = get_slice_2d(ctx0, k_cd, chunk); // [S_k, CS, 1, H_k * n_seqs] + ggml_tensor * ch_v_t = get_slice_2d(ctx0, v_t, chunk); // [ CS, S_v, 1, H_v * n_seqs] + ggml_tensor * ch_kq = get_slice_2d(ctx0, kq, chunk); // [ CS, CS, 1, H_k * n_seqs] + ggml_tensor * ch_q_g_exp = get_slice_2d(ctx0, q_g_exp, chunk); // [S_k, CS, 1, H_k * n_seqs] + ggml_tensor * ch_kg_t = get_slice_2d(ctx0, kg_t, chunk); // [ CS, S_k, 1, H_v * n_seqs] + + // [CS, S_v, 1, H_v * n_seqs] + ggml_tensor * v_t_p = ggml_mul_mat(ctx0, ch_k_cd, s_t); + cb(v_t_p, "v_prime", il); + + // [CS, S_v, 1, H_v * n_seqs] + ggml_tensor * v_t_new = ggml_sub(ctx0, ch_v_t, v_t_p); + cb(v_t_new, "v_t_new", il); + + // [S_v, CS, 1, H_v * n_seqs] + ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_t_new, ch_kq); + cb(v_attn, "v_attn", il); + + // [S_v, CS, 1, H_v * n_seqs] + ggml_tensor * attn_inter = ggml_mul_mat(ctx0, s_t, ch_q_g_exp); + cb(attn_inter, "attn_inter", il); + + // [S_v, CS, 1, H_v * n_seqs] + ggml_tensor * o_ch = ggml_add(ctx0, attn_inter, 
v_attn); + cb(o_ch, "dnet_add_ch_attn_out", il); + + v = ggml_set_inplace(ctx0, v, o_ch, v->nb[1], v->nb[2], v->nb[3], chunk * v->nb[2]); + + // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new + // TODO: head broadcast might not work here - probably will need a transpose + ggml_tensor * kgv = ggml_mul_mat(ctx0, ch_kg_t, v_t_new); // [S_k, S_v, 1, H_k * n_seqs] + + // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew + ggml_tensor * ch_g_last_exp = get_slice_2d(ctx0, g_last_exp, chunk); + s_t = ggml_mul(ctx0, s_t, ch_g_last_exp); + s_t = ggml_add(ctx0, s_t, kgv); + cb(s_t, "dnet_add_ch_state", il); + } + + s_t = ggml_reshape_4d(ctx0, s_t, S_v, S_v, H_v, n_seqs); + + // truncate padded tokens + ggml_tensor * o = ggml_view_4d(ctx0, v, + S_v, n_tokens, H_v, n_seqs, + ggml_row_size(v->type, S_v), + ggml_row_size(v->type, S_v * CS * n_chunks), + ggml_row_size(v->type, S_v * CS * n_chunks * H_v), 0); + + o = ggml_permute (ctx0, o, 0, 2, 1, 3); // [S_v, H_v, n_tokens, n_seqs] + s = ggml_transpose(ctx0, s_t); // [S_v, S_v, H_v, n_seqs] + + return {o, s}; +} + +std::pair llm_build_delta_net_base::build_delta_net_autoregressive( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * b, // beta + ggml_tensor * s, // state + int il) { + const int64_t S_k = q->ne[0]; + const int64_t H_k = q->ne[1]; + const int64_t n_tokens = q->ne[2]; + const int64_t n_seqs = q->ne[3]; + + const int64_t S_v = v->ne[0]; + const int64_t H_v = v->ne[1]; + + GGML_ASSERT(n_tokens == 1); + + GGML_ASSERT(S_k == S_v); + GGML_ASSERT(H_v % H_k == 0); + + GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs); + GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); + GGML_ASSERT(v->ne[0] == S_v && v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == n_seqs); + + GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs); + GGML_ASSERT(b->ne[0] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs); + GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs); + + const float scale = 1.0f / sqrtf(S_k); + + q = ggml_scale(ctx0, q, scale); + + q = ggml_permute(ctx0, q, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs] + k = ggml_permute(ctx0, k, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs] + v = ggml_permute(ctx0, v, 0, 2, 1, 3); // [S_v, n_tokens, H_v, n_seqs] + + cb(q, "q_in", il); + cb(k, "k_in", il); + cb(v, "v_in", il); + cb(b, "b_in", il); + cb(g, "g_in", il); + + g = ggml_reshape_4d(ctx0, g, 1, 1, H_v, n_seqs); + b = ggml_reshape_4d(ctx0, b, 1, 1, H_v, n_seqs); + + // [S_v, S_v, H_v, n_seqs] + g = ggml_exp(ctx0, g); + s = ggml_mul(ctx0, s, g); + + ggml_tensor * s_t = ggml_cont(ctx0, ggml_transpose(ctx0, s)); + + // [1, S_v, H_v, n_seqs] + ggml_tensor * sk; + sk = ggml_mul (ctx0, s_t, k); + sk = ggml_sum_rows(ctx0, sk); + + // [S_v, 1, H_v, n_seqs] + ggml_tensor * d; + d = ggml_sub(ctx0, v, ggml_transpose(ctx0, sk)); + d = ggml_mul(ctx0, d, b); + + // [1, S_v, H_v, n_seqs] + ggml_tensor * d_t; + d_t = ggml_transpose(ctx0, d); + + // [S_v, S_v, H_v, n_seqs] + ggml_tensor * kd; + k = ggml_repeat(ctx0, k, s); + kd = ggml_mul (ctx0, k, d_t); + + s_t = ggml_add(ctx0, s_t, kd); + + cb(s_t, "dnet_add_ar_state", il); + + ggml_tensor * s_q = ggml_mul (ctx0, s_t, q); + ggml_tensor * o = ggml_sum_rows(ctx0, s_q); + + o = ggml_permute (ctx0, o, 2, 0, 1, 3); // [S_v, H_v, n_tokens, n_seqs] + s = ggml_transpose(ctx0, s_t); // [S_v, S_v, H_v, n_seqs] + + return {o, 
s}; +} diff --git a/src/models/falcon-h1.cpp b/src/models/falcon-h1.cpp index b641a09407..785a7e5e66 100644 --- a/src/models/falcon-h1.cpp +++ b/src/models/falcon-h1.cpp @@ -1,9 +1,7 @@ #include "models.h" - - llm_build_falcon_h1::llm_build_falcon_h1(const llama_model & model, const llm_graph_params & params) : - llm_graph_context_mamba(params) { + llm_build_mamba_base(params) { const int64_t n_embd_head = hparams.n_embd_head_v; ggml_tensor * cur; diff --git a/src/models/granite-hybrid.cpp b/src/models/granite-hybrid.cpp index f6ca4c17a2..726ecdcca7 100644 --- a/src/models/granite-hybrid.cpp +++ b/src/models/granite-hybrid.cpp @@ -2,7 +2,7 @@ llm_build_granite_hybrid::llm_build_granite_hybrid(const llama_model & model, const llm_graph_params & params) : - llm_graph_context_mamba(params) { + llm_build_mamba_base(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); diff --git a/src/models/jamba.cpp b/src/models/jamba.cpp index a0187772cc..ceab581740 100644 --- a/src/models/jamba.cpp +++ b/src/models/jamba.cpp @@ -1,6 +1,6 @@ #include "models.h" -llm_build_jamba::llm_build_jamba(const llama_model & model, const llm_graph_params & params) : llm_graph_context_mamba(params) { +llm_build_jamba::llm_build_jamba(const llama_model & model, const llm_graph_params & params) : llm_build_mamba_base(params) { const int64_t n_embd_head = hparams.n_embd_head_v; ggml_tensor * cur; diff --git a/src/models/kimi-linear.cpp b/src/models/kimi-linear.cpp index 942844d071..133834021d 100644 --- a/src/models/kimi-linear.cpp +++ b/src/models/kimi-linear.cpp @@ -1,6 +1,8 @@ #include "models.h" #include "ggml.h" +#include "llama-memory-recurrent.h" + #define CHUNK_SIZE 64 // Causal Conv1d function for Q,K,V @@ -65,7 +67,7 @@ static ggml_tensor * causal_conv1d(ggml_cgraph * gf, ggml_context * ctx0, ggml_t } llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const llm_graph_params & params) : - llm_graph_context_mamba(params), model(model) { + llm_build_mamba_base(params), model(model) { ggml_tensor * cur; ggml_tensor * inpL; diff --git a/src/models/graph-context-mamba.cpp b/src/models/mamba-base.cpp similarity index 97% rename from src/models/graph-context-mamba.cpp rename to src/models/mamba-base.cpp index b9a363b32b..aaac9487df 100644 --- a/src/models/graph-context-mamba.cpp +++ b/src/models/mamba-base.cpp @@ -1,8 +1,10 @@ #include "models.h" -llm_graph_context_mamba::llm_graph_context_mamba(const llm_graph_params & params) : llm_graph_context(params) {} +#include "llama-memory-recurrent.h" -ggml_tensor * llm_graph_context_mamba::build_mamba_layer(llm_graph_input_rs * inp, +llm_build_mamba_base::llm_build_mamba_base(const llm_graph_params & params) : llm_graph_context(params) {} + +ggml_tensor * llm_build_mamba_base::build_mamba_layer(llm_graph_input_rs * inp, ggml_tensor * cur, const llama_model & model, const llama_ubatch & ubatch, @@ -143,7 +145,7 @@ ggml_tensor * llm_graph_context_mamba::build_mamba_layer(llm_graph_input_rs * in return cur; } -ggml_tensor * llm_graph_context_mamba::build_mamba2_layer(llm_graph_input_rs * inp, +ggml_tensor * llm_build_mamba_base::build_mamba2_layer(llm_graph_input_rs * inp, ggml_tensor * cur, const llama_model & model, const llama_ubatch & ubatch, diff --git a/src/models/mamba.cpp b/src/models/mamba.cpp index 46819613c2..55fd2e055c 100644 --- a/src/models/mamba.cpp +++ b/src/models/mamba.cpp @@ -1,7 +1,6 @@ #include "models.h" - -llm_build_mamba::llm_build_mamba(const llama_model & model, const 
llm_graph_params & params) : llm_graph_context_mamba(params) { +llm_build_mamba::llm_build_mamba(const llama_model & model, const llm_graph_params & params) : llm_build_mamba_base(params) { ggml_tensor * cur; ggml_tensor * inpL; diff --git a/src/models/models.h b/src/models/models.h index ec6f80e526..920a8e5798 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -1,23 +1,51 @@ #pragma once -#include "../llama-model.h" -#include "../llama-graph.h" +#include "llama-model.h" +#include "llama-graph.h" -// TODO: remove in follow-up PR - move to .cpp files -#include "../llama-memory-recurrent.h" +// note: almost all graphs require atleast sqrtf, so include cmath globally #include -struct llm_graph_context_mamba : public llm_graph_context { - llm_graph_context_mamba(const llm_graph_params & params); +// +// base classes +// - virtual ~llm_graph_context_mamba() = default; +struct llm_build_mamba_base : public llm_graph_context { + llm_build_mamba_base(const llm_graph_params & params); + + virtual ~llm_build_mamba_base() = default; ggml_tensor * build_mamba_layer(llm_graph_input_rs * inp, ggml_tensor * cur, const llama_model & model, const llama_ubatch & ubatch, int il); ggml_tensor * build_mamba2_layer(llm_graph_input_rs * inp, ggml_tensor * cur, const llama_model & model, const llama_ubatch & ubatch, int il) const; }; -// Base class for RWKV-related models +struct llm_build_delta_net_base : public llm_graph_context { + llm_build_delta_net_base(const llm_graph_params & params); + + virtual ~llm_build_delta_net_base() = default; + + // returns pair of output and new state + std::pair build_delta_net_chunking( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * b, + ggml_tensor * s, + int il); + + // returns pair of output and new state + std::pair build_delta_net_autoregressive( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * b, + ggml_tensor * s, + int il); +}; + struct llm_build_rwkv6_base : public llm_graph_context { const llama_model & model; @@ -58,6 +86,10 @@ struct llm_build_rwkv7_base : public llm_graph_context { int il) const; }; +// +// models +// + struct llm_build_afmoe : public llm_graph_context { llm_build_afmoe(const llama_model & model, const llm_graph_params & params); }; @@ -175,7 +207,7 @@ struct llm_build_falcon : public llm_graph_context { llm_build_falcon(const llama_model & model, const llm_graph_params & params); }; -struct llm_build_falcon_h1 : public llm_graph_context_mamba { +struct llm_build_falcon_h1 : public llm_build_mamba_base { llm_build_falcon_h1(const llama_model & model, const llm_graph_params & params); }; @@ -253,7 +285,7 @@ private: const int il); }; -struct llm_build_granite_hybrid : public llm_graph_context_mamba { +struct llm_build_granite_hybrid : public llm_build_mamba_base { llm_build_granite_hybrid(const llama_model & model, const llm_graph_params & params); ggml_tensor * build_layer_ffn(ggml_tensor * cur, ggml_tensor * inpSA, const llama_model & model, const int il); ggml_tensor * build_attention_layer(ggml_tensor * cur, ggml_tensor * inp_pos, llm_graph_input_attn_kv * inp_attn, @@ -284,11 +316,12 @@ struct llm_build_jais : public llm_graph_context { llm_build_jais(const llama_model & model, const llm_graph_params & params); }; -struct llm_build_jamba : public llm_graph_context_mamba { +struct llm_build_jamba : public llm_build_mamba_base { llm_build_jamba(const llama_model & model, const llm_graph_params & params); }; -struct llm_build_kimi_linear : public 
llm_graph_context_mamba { +// TODO: derive llm_build_delta_net_base instead +struct llm_build_kimi_linear : public llm_build_mamba_base { llm_build_kimi_linear(const llama_model & model, const llm_graph_params & params); std::pair build_kda_autoregressive( @@ -347,7 +380,7 @@ struct llm_build_maincoder : public llm_graph_context { llm_build_maincoder(const llama_model & model, const llm_graph_params & params); }; -struct llm_build_mamba : public llm_graph_context_mamba { +struct llm_build_mamba : public llm_build_mamba_base { llm_build_mamba(const llama_model & model, const llm_graph_params & params); }; @@ -379,11 +412,11 @@ struct llm_build_nemotron : public llm_graph_context { llm_build_nemotron(const llama_model & model, const llm_graph_params & params); }; -struct llm_build_nemotron_h : public llm_graph_context_mamba { +struct llm_build_nemotron_h : public llm_build_mamba_base { llm_build_nemotron_h(const llama_model & model, const llm_graph_params & params); - ggml_tensor * build_ffn_layer(ggml_tensor * cur, const llama_model & model, const int il); + ggml_tensor * build_ffn_layer(ggml_tensor * cur, const llama_model & model, int il); ggml_tensor * build_attention_layer(ggml_tensor * cur, llm_graph_input_attn_kv * inp_attn, - const llama_model & model, const int64_t n_embd_head, const int il); + const llama_model & model, int64_t n_embd_head, int il); }; struct llm_build_neo_bert : public llm_graph_context { @@ -428,7 +461,7 @@ struct llm_build_phi3 : public llm_graph_context { llm_build_phi3(const llama_model & model, const llm_graph_params & params); }; -struct llm_build_plamo2 : public llm_graph_context_mamba { +struct llm_build_plamo2 : public llm_build_mamba_base { llm_build_plamo2(const llama_model & model, const llm_graph_params & params); private: ggml_tensor * build_plamo2_mamba_layer(llm_graph_input_rs * inp, ggml_tensor * cur, const llama_model & model, const llama_ubatch & ubatch, int il); @@ -477,7 +510,7 @@ struct llm_build_qwen3vlmoe : public llm_graph_context { llm_build_qwen3vlmoe(const llama_model & model, const llm_graph_params & params); }; -struct llm_build_qwen3next : public llm_graph_context_mamba { +struct llm_build_qwen3next : public llm_build_delta_net_base { llm_build_qwen3next(const llama_model & model, const llm_graph_params & params); private: ggml_tensor * build_layer_attn( @@ -495,26 +528,6 @@ private: ggml_tensor * cur, int il); - // returns pair of output and new state - std::pair build_delta_net_chunking( - ggml_tensor * q, - ggml_tensor * k, - ggml_tensor * v, - ggml_tensor * g, - ggml_tensor * beta, - ggml_tensor * state, - int il); - - // returns pair of output and new state - std::pair build_delta_net_autoregressive( - ggml_tensor * q, - ggml_tensor * k, - ggml_tensor * v, - ggml_tensor * g, - ggml_tensor * beta, - ggml_tensor * state, - int il); - ggml_tensor * build_norm_gated( ggml_tensor * input, ggml_tensor * weights, @@ -529,7 +542,8 @@ private: const llama_model & model; }; -struct llm_build_qwen35 : public llm_graph_context_mamba { +// TODO: derive llm_build_delta_net_base instead +struct llm_build_qwen35 : public llm_graph_context { llm_build_qwen35(const llama_model & model, const llm_graph_params & params); private: ggml_tensor * build_layer_attn( @@ -547,6 +561,7 @@ private: ggml_tensor * diag_mask, int il); + ggml_tensor * build_layer_ffn( ggml_tensor * cur, int il); @@ -588,7 +603,8 @@ private: const llama_model & model; }; -struct llm_build_qwen35moe : public llm_graph_context_mamba { +// TODO: derive llm_build_delta_net_base 
instead +struct llm_build_qwen35moe : public llm_graph_context { llm_build_qwen35moe(const llama_model & model, const llm_graph_params & params); private: ggml_tensor * build_layer_attn( diff --git a/src/models/nemotron-h.cpp b/src/models/nemotron-h.cpp index 079c730ac2..d61d62a8c9 100644 --- a/src/models/nemotron-h.cpp +++ b/src/models/nemotron-h.cpp @@ -1,9 +1,7 @@ #include "models.h" - - llm_build_nemotron_h::llm_build_nemotron_h(const llama_model & model, const llm_graph_params & params) : - llm_graph_context_mamba(params) { + llm_build_mamba_base(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -65,8 +63,8 @@ llm_build_nemotron_h::llm_build_nemotron_h(const llama_model & model, const llm_ ggml_tensor * llm_build_nemotron_h::build_attention_layer(ggml_tensor * cur, llm_graph_input_attn_kv * inp_attn, const llama_model & model, - const int64_t n_embd_head, - const int il) { + int64_t n_embd_head, + int il) { // compute Q and K ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); cb(Qcur, "Qcur", il); @@ -106,7 +104,7 @@ ggml_tensor * llm_build_nemotron_h::build_attention_layer(ggml_tensor * return cur; } -ggml_tensor * llm_build_nemotron_h::build_ffn_layer(ggml_tensor * cur, const llama_model & model, const int il) { +ggml_tensor * llm_build_nemotron_h::build_ffn_layer(ggml_tensor * cur, const llama_model & model, int il) { if (model.layers[il].ffn_gate_inp == nullptr) { cur = build_ffn(cur, model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, diff --git a/src/models/plamo2.cpp b/src/models/plamo2.cpp index 31115a08f9..3af236843b 100644 --- a/src/models/plamo2.cpp +++ b/src/models/plamo2.cpp @@ -1,7 +1,9 @@ #include "models.h" +#include "llama-memory-recurrent.h" + llm_build_plamo2::llm_build_plamo2(const llama_model & model, const llm_graph_params & params) : - llm_graph_context_mamba(params) { + llm_build_mamba_base(params) { ggml_tensor * cur; ggml_tensor * inpL; diff --git a/src/models/qwen35.cpp b/src/models/qwen35.cpp index 592c170457..94c68dbb26 100644 --- a/src/models/qwen35.cpp +++ b/src/models/qwen35.cpp @@ -1,10 +1,11 @@ -#include "ggml.h" #include "models.h" +#include "llama-memory-recurrent.h" + #define CHUNK_SIZE 64 llm_build_qwen35::llm_build_qwen35(const llama_model & model, const llm_graph_params & params) : - llm_graph_context_mamba(params), model(model) { + llm_graph_context(params), model(model) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); diff --git a/src/models/qwen35moe.cpp b/src/models/qwen35moe.cpp index 0db8f825c6..93da7ea628 100644 --- a/src/models/qwen35moe.cpp +++ b/src/models/qwen35moe.cpp @@ -1,10 +1,11 @@ -#include "ggml.h" #include "models.h" +#include "llama-memory-recurrent.h" + #define CHUNK_SIZE 64 llm_build_qwen35moe::llm_build_qwen35moe(const llama_model & model, const llm_graph_params & params) : - llm_graph_context_mamba(params), model(model) { + llm_graph_context(params), model(model) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); diff --git a/src/models/qwen3next.cpp b/src/models/qwen3next.cpp index aea8b29513..0fdf2d42c2 100644 --- a/src/models/qwen3next.cpp +++ b/src/models/qwen3next.cpp @@ -1,10 +1,9 @@ -#include "ggml.h" #include "models.h" -#define CHUNK_SIZE 64 +#include "llama-memory-recurrent.h" llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_graph_params & params) : - llm_graph_context_mamba(params), model(model) 
{ + llm_build_delta_net_base(params), model(model) { ggml_tensor * cur; ggml_tensor * inpL; @@ -83,326 +82,6 @@ static ggml_tensor * get_slice_2d(ggml_context * ctx0, ggml_tensor * t, int64_t t->nb[1], t->nb[2], t->nb[3], t->nb[2] * c); } -std::pair llm_build_qwen3next::build_delta_net_chunking( - ggml_tensor * q, - ggml_tensor * k, - ggml_tensor * v, - ggml_tensor * g, - ggml_tensor * b, - ggml_tensor * s, - int il) { - const int64_t S_k = q->ne[0]; - const int64_t H_k = q->ne[1]; - const int64_t n_tokens = q->ne[2]; - const int64_t n_seqs = q->ne[3]; - - const int64_t S_v = v->ne[0]; - const int64_t H_v = v->ne[1]; - - GGML_ASSERT(S_k == S_v); - GGML_ASSERT(H_v % H_k == 0); - - GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs); - GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); - GGML_ASSERT(v->ne[0] == S_v && v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == n_seqs); - - GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs); - GGML_ASSERT(b->ne[0] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs); - GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs); - - const float scale = 1.0f / sqrtf(S_k); - - q = ggml_scale(ctx0, q, scale); - - cb(q, "q_in", il); - cb(k, "k_in", il); - cb(v, "v_in", il); - cb(b, "b_in", il); - cb(g, "g_in", il); - - q = ggml_permute(ctx0, q, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs] - k = ggml_permute(ctx0, k, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs] - v = ggml_permute(ctx0, v, 0, 2, 1, 3); // [S_v, n_tokens, H_v, n_seqs] - g = ggml_permute(ctx0, g, 2, 1, 3, 0); // [ 1, n_tokens, H_v, n_seqs] - b = ggml_permute(ctx0, b, 2, 0, 1, 3); // [ 1, n_tokens, H_v, n_seqs] - - const int CS = CHUNK_SIZE; - - const int pad = (CS - n_tokens % CS) % CS; - const int n_chunks = (n_tokens + pad) / CS; - - q = ggml_pad(ctx0, q, 0, pad, 0, 0); - k = ggml_pad(ctx0, k, 0, pad, 0, 0); - v = ggml_pad(ctx0, v, 0, pad, 0, 0); - g = ggml_pad(ctx0, g, 0, pad, 0, 0); - b = ggml_pad(ctx0, b, 0, pad, 0, 0); - - ggml_tensor * v_b = ggml_mul(ctx0, v, b); - ggml_tensor * k_b = ggml_mul(ctx0, k, b); - - cb(v_b, "v_b", il); - cb(k_b, "k_b", il); - - q = ggml_reshape_4d(ctx0, q, S_k, CS, n_chunks, H_k * n_seqs); - k = ggml_reshape_4d(ctx0, k, S_k, CS, n_chunks, H_k * n_seqs); - k_b = ggml_reshape_4d(ctx0, k_b, S_k, CS, n_chunks, H_v * n_seqs); - v = ggml_reshape_4d(ctx0, v, S_v, CS, n_chunks, H_v * n_seqs); - v_b = ggml_reshape_4d(ctx0, v_b, S_v, CS, n_chunks, H_v * n_seqs); - - g = ggml_reshape_4d(ctx0, g, CS, 1, n_chunks, H_v * n_seqs); - b = ggml_reshape_4d(ctx0, b, 1, CS, n_chunks, H_v * n_seqs); - - // [CS, 1, n_chunks, H_v * n_seqs] - ggml_tensor * g_cs = ggml_cumsum(ctx0, g); - cb(g_cs, "g_cs", il); - - ggml_tensor * g_cs_i = g_cs; - ggml_tensor * g_cs_j = ggml_reshape_4d(ctx0, g_cs, 1, CS, n_chunks, H_v * n_seqs); - - g_cs_j = ggml_repeat_4d(ctx0, g_cs_j, CS, CS, n_chunks, H_v * n_seqs); - - // [CS, CS, n_chunks, H_v * n_seqs] - ggml_tensor * decay_mask; - decay_mask = ggml_sub(ctx0, g_cs_j, g_cs_i); - decay_mask = ggml_tri(ctx0, decay_mask, GGML_TRI_TYPE_LOWER_DIAG); - decay_mask = ggml_exp(ctx0, decay_mask); - cb(decay_mask, "decay_mask", il); - - // [CS, CS, n_chunks, H_k * n_seqs] - ggml_tensor * kb; - kb = ggml_mul_mat(ctx0, k, k_b); - kb = ggml_mul (ctx0, kb, decay_mask); - - // [CS, CS, n_chunks, H_k * n_seqs] - ggml_tensor * attn; - attn = ggml_tri(ctx0, kb, GGML_TRI_TYPE_LOWER); - - ggml_tensor * identity; - identity = 
ggml_view_1d(ctx0, attn, CS, 0); - identity = ggml_fill (ctx0, identity, 1.0f); - identity = ggml_diag (ctx0, identity); - - ggml_tensor * lhs = ggml_add(ctx0, attn, identity); - cb(lhs, "dnet_add_ch_lhs", il); - - attn = ggml_neg(ctx0, attn); - - ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false); - attn = ggml_add(ctx0, lin_solve, identity); - cb(attn, "dnet_add_ch_attn_solved", il); // [CS, CS, n_chunks, H_k * n_seqs] - - // [S_v, CS, n_chunks, H_v * n_seqs] - v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_b)), attn); - - // [CS, 1, n_chunks, H_v * n_seqs] - ggml_tensor * g_exp = ggml_exp(ctx0, g_cs); - - k_b = ggml_cont(ctx0, ggml_transpose(ctx0, k_b)); - - // [CS, S_k, n_chunks, H_k * n_seqs] - ggml_tensor * kbg = ggml_mul(ctx0, k_b, g_exp); - cb(kbg, "k_beta_g_exp", il); - - // [S_k, CS, n_chunks, H_k * n_seqs] - ggml_tensor * k_cd = ggml_mul_mat(ctx0, kbg, attn); - cb(k_cd, "k_cumdecay", il); - - // [S_k, CS, n_chunks, H_k * n_seqs] - ggml_tensor * g_exp_t = ggml_transpose(ctx0, g_exp); - ggml_tensor * q_g_exp = ggml_mul(ctx0, q, g_exp_t); - - // [CS, CS, n_chunks, H_k * n_seqs] - ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); - kq = ggml_mul(ctx0, kq, decay_mask); - kq = ggml_tri(ctx0, kq, GGML_TRI_TYPE_LOWER_DIAG); - cb(kq, "kq", il); - - // vectorized calculation of key_gdiff - // improved from the chunked version: - // g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1) - // g_diff = torch.clamp(g_cum[:, :, -1:] - g_cum, max=50.0).exp() - // key_gdiff = key * g_diff.unsqueeze(-1) - // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new - // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew - - // get last element in g_cumsum along CS dimension (ne0) - // example: [[x, y, z, ..., last], ...] -> [[last], ...] 
- // [1, 1, n_chunks, H_v * n_seqs] - ggml_tensor * g_last = ggml_view_4d(ctx0, g_cs, 1, 1, g_cs->ne[2], g_cs->ne[3], - g_cs->nb[1], - g_cs->nb[2], - g_cs->nb[3], - ggml_row_size(g_cs->type, g_cs->ne[0] - 1)); - cb(g_last, "g_last", il); - - // TODO: remove this cont when CUDA supports non-cont unary ops - g_last = ggml_cont(ctx0, g_last); - - // [1, 1, n_chunks, H_v * n_seqs] - ggml_tensor * g_last_exp = ggml_exp(ctx0, g_last); - cb(g_last_exp, "g_last_exp", il); - - // [CS, 1, n_chunks, H_v * n_seqs] - ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cs, g_last)); - cb(g_diff, "g_diff", il); - - ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff); - ggml_tensor * g_diff_exp_t = ggml_transpose(ctx0, g_diff_exp); - - // [S_k, CS, n_chunks, H_v * n_seqs] - ggml_tensor * kg = ggml_mul(ctx0, k, g_diff_exp_t); - cb(kg, "key_gdiff", il); - - // [CS, S_k, n_chunks, H_v * n_seqs] - ggml_tensor * kg_t = ggml_cont(ctx0, ggml_transpose(ctx0, kg)); - cb(kg_t, "key_gdiff_t", il); - - ggml_tensor * s_t = ggml_transpose(ctx0, s); - s_t = ggml_cont_4d(ctx0, s_t, S_v, S_v, 1, H_v * n_seqs); - cb(s_t, "dnet_add_ch_state", il); - - // [CS, S_v, n_chunks, H_v * n_seqs] - ggml_tensor * v_t = ggml_cont(ctx0, ggml_transpose(ctx0, v)); - - for (int64_t chunk = 0; chunk < n_chunks; chunk++) { - ggml_tensor * ch_k_cd = get_slice_2d(ctx0, k_cd, chunk); // [S_k, CS, 1, H_k * n_seqs] - ggml_tensor * ch_v_t = get_slice_2d(ctx0, v_t, chunk); // [ CS, S_v, 1, H_v * n_seqs] - ggml_tensor * ch_kq = get_slice_2d(ctx0, kq, chunk); // [ CS, CS, 1, H_k * n_seqs] - ggml_tensor * ch_q_g_exp = get_slice_2d(ctx0, q_g_exp, chunk); // [S_k, CS, 1, H_k * n_seqs] - ggml_tensor * ch_kg_t = get_slice_2d(ctx0, kg_t, chunk); // [ CS, S_k, 1, H_v * n_seqs] - - // [CS, S_v, 1, H_v * n_seqs] - ggml_tensor * v_t_p = ggml_mul_mat(ctx0, ch_k_cd, s_t); - cb(v_t_p, "v_prime", il); - - // [CS, S_v, 1, H_v * n_seqs] - ggml_tensor * v_t_new = ggml_sub(ctx0, ch_v_t, v_t_p); - cb(v_t_new, "v_t_new", il); - - // [S_v, CS, 1, H_v * n_seqs] - ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_t_new, ch_kq); - cb(v_attn, "v_attn", il); - - // [S_v, CS, 1, H_v * n_seqs] - ggml_tensor * attn_inter = ggml_mul_mat(ctx0, s_t, ch_q_g_exp); - cb(attn_inter, "attn_inter", il); - - // [S_v, CS, 1, H_v * n_seqs] - ggml_tensor * o_ch = ggml_add(ctx0, attn_inter, v_attn); - cb(o_ch, "dnet_add_ch_attn_out", il); - - v = ggml_set_inplace(ctx0, v, o_ch, v->nb[1], v->nb[2], v->nb[3], chunk * v->nb[2]); - - // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new - // TODO: head broadcast might not work here - probably will need a transpose - ggml_tensor * kgv = ggml_mul_mat(ctx0, ch_kg_t, v_t_new); // [S_k, S_v, 1, H_k * n_seqs] - - // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew - ggml_tensor * ch_g_last_exp = get_slice_2d(ctx0, g_last_exp, chunk); - s_t = ggml_mul(ctx0, s_t, ch_g_last_exp); - s_t = ggml_add(ctx0, s_t, kgv); - cb(s_t, "dnet_add_ch_state", il); - } - - s_t = ggml_reshape_4d(ctx0, s_t, S_v, S_v, H_v, n_seqs); - - // truncate padded tokens - ggml_tensor * o = ggml_view_4d(ctx0, v, - S_v, n_tokens, H_v, n_seqs, - ggml_row_size(v->type, S_v), - ggml_row_size(v->type, S_v * CS * n_chunks), - ggml_row_size(v->type, S_v * CS * n_chunks * H_v), 0); - - o = ggml_permute (ctx0, o, 0, 2, 1, 3); // [S_v, H_v, n_tokens, n_seqs] - s = ggml_transpose(ctx0, s_t); // [S_v, S_v, H_v, n_seqs] - - return {o, s}; -} - -std::pair llm_build_qwen3next::build_delta_net_autoregressive( - ggml_tensor * q, - ggml_tensor * k, - ggml_tensor * v, - ggml_tensor * g, - 
ggml_tensor * b, // beta - ggml_tensor * s, // state - int il) { - const int64_t S_k = q->ne[0]; - const int64_t H_k = q->ne[1]; - const int64_t n_tokens = q->ne[2]; - const int64_t n_seqs = q->ne[3]; - - const int64_t S_v = v->ne[0]; - const int64_t H_v = v->ne[1]; - - GGML_ASSERT(n_tokens == 1); - - GGML_ASSERT(S_k == S_v); - GGML_ASSERT(H_v % H_k == 0); - - GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs); - GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); - GGML_ASSERT(v->ne[0] == S_v && v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == n_seqs); - - GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs); - GGML_ASSERT(b->ne[0] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs); - GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs); - - const float scale = 1.0f / sqrtf(S_k); - - q = ggml_scale(ctx0, q, scale); - - q = ggml_permute(ctx0, q, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs] - k = ggml_permute(ctx0, k, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs] - v = ggml_permute(ctx0, v, 0, 2, 1, 3); // [S_v, n_tokens, H_v, n_seqs] - - cb(q, "q_in", il); - cb(k, "k_in", il); - cb(v, "v_in", il); - cb(b, "b_in", il); - cb(g, "g_in", il); - - g = ggml_reshape_4d(ctx0, g, 1, 1, H_v, n_seqs); - b = ggml_reshape_4d(ctx0, b, 1, 1, H_v, n_seqs); - - // [S_v, S_v, H_v, n_seqs] - g = ggml_exp(ctx0, g); - s = ggml_mul(ctx0, s, g); - - ggml_tensor * s_t = ggml_cont(ctx0, ggml_transpose(ctx0, s)); - - // [1, S_v, H_v, n_seqs] - ggml_tensor * sk; - sk = ggml_mul (ctx0, s_t, k); - sk = ggml_sum_rows(ctx0, sk); - - // [S_v, 1, H_v, n_seqs] - ggml_tensor * d; - d = ggml_sub(ctx0, v, ggml_transpose(ctx0, sk)); - d = ggml_mul(ctx0, d, b); - - // [1, S_v, H_v, n_seqs] - ggml_tensor * d_t; - d_t = ggml_transpose(ctx0, d); - - // [S_v, S_v, H_v, n_seqs] - ggml_tensor * kd; - k = ggml_repeat(ctx0, k, s); - kd = ggml_mul (ctx0, k, d_t); - - s_t = ggml_add(ctx0, s_t, kd); - - cb(s_t, "dnet_add_ar_state", il); - - ggml_tensor * s_q = ggml_mul (ctx0, s_t, q); - ggml_tensor * o = ggml_sum_rows(ctx0, s_q); - - o = ggml_permute (ctx0, o, 2, 0, 1, 3); // [S_v, H_v, n_tokens, n_seqs] - s = ggml_transpose(ctx0, s_t); // [S_v, S_v, H_v, n_seqs] - - return {o, s}; -} - ggml_tensor * llm_build_qwen3next::build_norm_gated( ggml_tensor * input, ggml_tensor * weights, diff --git a/src/models/rwkv6-base.cpp b/src/models/rwkv6-base.cpp index 7beed2daff..83aeab7280 100644 --- a/src/models/rwkv6-base.cpp +++ b/src/models/rwkv6-base.cpp @@ -1,5 +1,7 @@ #include "models.h" +#include "llama-memory-recurrent.h" + llm_build_rwkv6_base::llm_build_rwkv6_base(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params), model(model) {} diff --git a/src/models/rwkv7-base.cpp b/src/models/rwkv7-base.cpp index cda4465384..7fcab77745 100644 --- a/src/models/rwkv7-base.cpp +++ b/src/models/rwkv7-base.cpp @@ -1,5 +1,7 @@ #include "models.h" +#include "llama-memory-recurrent.h" + llm_build_rwkv7_base::llm_build_rwkv7_base(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params), model(model) {} From 2ba9adc093127016a48cd0c5d6bf1420dafe17a6 Mon Sep 17 00:00:00 2001 From: Mario Limonciello Date: Mon, 16 Feb 2026 07:46:08 -0600 Subject: [PATCH 04/10] Adjust workaround for ROCWMMA_FATTN/GFX9 to only newer ROCm versions (#19591) Avoids issues with ROCm 6.4.4.
Closes: https://github.com/ggml-org/llama.cpp/issues/19580 Fixes: 6845f7f87 ("Add a workaround for compilation with ROCWMMA_FATTN and gfx9 (#19461)") Signed-off-by: Mario Limonciello (AMD) --- ggml/src/ggml-cuda/fattn-wmma-f16.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu index 35735d48b2..f19defbff9 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu @@ -63,7 +63,7 @@ static __global__ void flash_attn_ext_f16( constexpr int frag_m = ncols == 8 ? 32 : 16; constexpr int frag_n = ncols == 8 ? 8 : 16; static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0."); -#if defined(GGML_USE_HIP) +#if defined(GGML_USE_HIP) && HIP_VERSION >= 60500000 typedef wmma::fragment frag_a_K; typedef wmma::fragment frag_a_V; typedef wmma::fragment frag_b; @@ -135,7 +135,7 @@ static __global__ void flash_attn_ext_f16( __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice. half2 * VKQ2 = (half2 *) VKQ; -#if defined(GGML_USE_HIP) +#if defined(GGML_USE_HIP) && HIP_VERSION >= 60500000 const _Float16 * K_h_f16 = reinterpret_cast(K_h); const _Float16 * V_h_f16 = reinterpret_cast(V_h); _Float16 * KQ_f16 = reinterpret_cast<_Float16 *>(KQ); From 44084941448d841c9b12ad250da5619cddb58bd6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrien=20Gallou=C3=ABt?= Date: Mon, 16 Feb 2026 16:06:48 +0100 Subject: [PATCH 05/10] build : rework llama_option_depr to handle LLAMA_CURL (#19658) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Adrien Gallouët --- CMakeLists.txt | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index d10ab6da96..32542ecd27 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -115,11 +115,6 @@ option(LLAMA_TESTS_INSTALL "llama: install tests" ON) option(LLAMA_OPENSSL "llama: use openssl to support HTTPS" ON) option(LLAMA_LLGUIDANCE "llama-common: include LLGuidance library for structured output in common utils" OFF) -# deprecated -option(LLAMA_CURL "llama: use libcurl to download model from an URL" OFF) -if (LLAMA_CURL) - message(WARNING "LLAMA_CURL option is deprecated and will be ignored") -endif() # Required for relocatable CMake package include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/build-info.cmake) @@ -147,10 +142,15 @@ if (NOT DEFINED GGML_CUDA_GRAPHS) endif() # transition helpers -function (llama_option_depr TYPE OLD NEW) +function (llama_option_depr TYPE OLD) if (${OLD}) - message(${TYPE} "${OLD} is deprecated and will be removed in the future.\nUse ${NEW} instead\n") - set(${NEW} ON PARENT_SCOPE) + set(NEW "${ARGV2}") + if(NEW) + message(${TYPE} "${OLD} is deprecated, use ${NEW} instead") + set(${NEW} ON PARENT_SCOPE) + else() + message(${TYPE} "${OLD} is deprecated and will be ignored") + endif() endif() endfunction() @@ -163,6 +163,7 @@ llama_option_depr(WARNING LLAMA_RPC GGML_RPC) llama_option_depr(WARNING LLAMA_SYCL GGML_SYCL) llama_option_depr(WARNING LLAMA_SYCL_F16 GGML_SYCL_F16) llama_option_depr(WARNING LLAMA_CANN GGML_CANN) +llama_option_depr(WARNING LLAMA_CURL) include("cmake/license.cmake") license_add_file("llama.cpp" "LICENSE") From 5f28c53d11210f3521328d6dac620c4b6ae0044b Mon Sep 17 00:00:00 2001 From: Saurabh Dash <111897126+saurabhdash2512@users.noreply.github.com> Date: Mon, 16 Feb 2026 10:28:46 -0500 Subject: [PATCH 06/10] model: Add support for Tiny Aya Models (#19611) * changes for tiny aya * 
changes to hash * changes to vocab * fix some tokenizer regex edge cases * update comment * add some comments for regex * Apply suggestion from @ngxson --------- Co-authored-by: Xuan-Son Nguyen --- convert_hf_to_gguf.py | 14 ++++++++++++++ convert_hf_to_gguf_update.py | 1 + src/llama-vocab.cpp | 16 ++++++++++++++-- src/llama-vocab.h | 1 + src/unicode.cpp | 6 ++++++ 5 files changed, 36 insertions(+), 2 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 0f614e4df3..d7141f01cf 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -1124,6 +1124,9 @@ class TextModel(ModelBase): if chkhsh == "9c2227e4dd922002fb81bde4fc02b0483ca4f12911410dee2255e4987644e3f8": # ref: https://huggingface.co/CohereForAI/c4ai-command-r-v01 res = "command-r" + if chkhsh == "d772b220ace2baec124bed8cfafce0ead7d6c38a4b65ef11261cf9d5d62246d1": + # ref: https://huggingface.co/CohereLabs/tiny-aya-base + res = "tiny_aya" if chkhsh == "e636dc30a262dcc0d8c323492e32ae2b70728f4df7dfe9737d9f920a282b8aea": # ref: https://huggingface.co/Qwen/Qwen1.5-7B res = "qwen2" @@ -7360,6 +7363,17 @@ class Cohere2Model(TextModel): self.gguf_writer.add_rope_dimension_count(int(rotary_pct * (hidden_size // num_attention_heads))) self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE) + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # Cohere2 runtime in llama.cpp expects no bias tensors; + # the actual weight only contains 0-value tensors as bias, we can skip them + if name.endswith(".bias"): + if torch.any(data_torch != 0): + raise ValueError(f"Bias tensor {name!r} is not zero.") + logger.debug(f"Skipping bias tensor {name!r} for Cohere2 conversion.") + return + + yield from super().modify_tensors(data_torch, name, bid) + @ModelBase.register("OlmoForCausalLM") @ModelBase.register("OLMoForCausalLM") diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py index a683451508..8bd24dbe91 100755 --- a/convert_hf_to_gguf_update.py +++ b/convert_hf_to_gguf_update.py @@ -99,6 +99,7 @@ models = [ {"name": "stablelm2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/stabilityai/stablelm-2-zephyr-1_6b", }, {"name": "refact", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/smallcloudai/Refact-1_6-base", }, {"name": "command-r", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/CohereForAI/c4ai-command-r-v01", }, + {"name": "tiny_aya", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/CohereLabs/tiny-aya-base", }, {"name": "qwen2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen1.5-7B", }, {"name": "olmo", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/allenai/OLMo-1.7-7B-hf", }, {"name": "dbrx", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/databricks/dbrx-base", }, diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index 62e137fb84..b35cb02ce4 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -422,6 +422,14 @@ struct llm_tokenizer_bpe : llm_tokenizer { "[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", }; break; + case LLAMA_VOCAB_PRE_TYPE_TINY_AYA: + regex_exprs = { + // original regex from tokenizer.json: "\\d{1,3}(?=(?:\\d{3})*\\b)" + 
"\\d{1,3}(?=(?:\\d{3})*\\b)", + // original regex from tokenizer.json: "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" + "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + }; + break; case LLAMA_VOCAB_PRE_TYPE_KIMI_K2: regex_exprs = { // K2 trigger pattern - this will activate the custom K2 handler in unicode.cpp @@ -2005,10 +2013,14 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "megrez") { pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN2; } else if ( - tokenizer_pre == "gpt-4o" || - tokenizer_pre == "llama4") { + tokenizer_pre == "gpt-4o" || + tokenizer_pre == "llama4") { pre_type = LLAMA_VOCAB_PRE_TYPE_GPT4O; clean_spaces = false; + } else if ( + tokenizer_pre == "tiny_aya") { + pre_type = LLAMA_VOCAB_PRE_TYPE_TINY_AYA; + clean_spaces = false; } else if ( tokenizer_pre == "superbpe") { pre_type = LLAMA_VOCAB_PRE_TYPE_SUPERBPE; diff --git a/src/llama-vocab.h b/src/llama-vocab.h index 718238fb86..1312a877ab 100644 --- a/src/llama-vocab.h +++ b/src/llama-vocab.h @@ -55,6 +55,7 @@ enum llama_vocab_pre_type { LLAMA_VOCAB_PRE_TYPE_YOUTU = 44, LLAMA_VOCAB_PRE_TYPE_EXAONE_MOE = 45, LLAMA_VOCAB_PRE_TYPE_QWEN35 = 46, + LLAMA_VOCAB_PRE_TYPE_TINY_AYA = 47, }; struct LLM_KV; diff --git a/src/unicode.cpp b/src/unicode.cpp index b88d953bd2..1475b53b65 100644 --- a/src/unicode.cpp +++ b/src/unicode.cpp @@ -769,6 +769,12 @@ static std::vector unicode_regex_split_custom(const std::string & text, } else if (regex_expr == "\\p{AFMoE_digits}") { // AFMOE digit pattern - use custom implementation for proper splitting bpe_offsets = unicode_regex_split_custom_afmoe(text, offsets); + } else if (regex_expr == "\\d{1,3}(?=(?:\\d{3})*\\b)") { + // tiny_aya digit grouping pattern from tokenizer.json: + // {"type": "Split", "pattern": {"Regex": "\\d{1,3}(?=(?:\\d{3})*\\b)"}, "behavior": "Isolated"} + // Splits digits into groups of 3 from the right (e.g., 1234567 -> 1, 234, 567) + // TODO: Revisit this regex, incase there are any subtle tokenization differences with the original regex. + bpe_offsets = unicode_regex_split_custom_afmoe(text, offsets); } return bpe_offsets; From d23a55997de9f42754e02f9022fefd2e4d41f06f Mon Sep 17 00:00:00 2001 From: Judd <4046440+foldl@users.noreply.github.com> Date: Mon, 16 Feb 2026 23:43:34 +0800 Subject: [PATCH 07/10] ggml : make `ggml_is_view` as API (#19539) * make `ggml_is_view` as API * introduce `ggml_aux_is_view` as inline version for internal use. 
* change `ggml_aux_is_view` to `ggml_impl_is_view` --- ggml/include/ggml.h | 1 + ggml/src/ggml-alloc.c | 13 ++++--------- ggml/src/ggml-impl.h | 4 ++++ ggml/src/ggml.c | 4 ++++ 4 files changed, 13 insertions(+), 9 deletions(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index f759e2d588..77af0e7fb6 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -752,6 +752,7 @@ extern "C" { GGML_API bool ggml_is_transposed(const struct ggml_tensor * tensor); GGML_API bool ggml_is_permuted (const struct ggml_tensor * tensor); GGML_API bool ggml_is_empty (const struct ggml_tensor * tensor); + GGML_API bool ggml_is_view (const struct ggml_tensor * tensor); GGML_API bool ggml_is_scalar (const struct ggml_tensor * tensor); GGML_API bool ggml_is_vector (const struct ggml_tensor * tensor); GGML_API bool ggml_is_matrix (const struct ggml_tensor * tensor); diff --git a/ggml/src/ggml-alloc.c b/ggml/src/ggml-alloc.c index 41419b617b..7f414b2311 100644 --- a/ggml/src/ggml-alloc.c +++ b/ggml/src/ggml-alloc.c @@ -17,11 +17,6 @@ //#define AT_PRINTF(...) GGML_LOG_DEBUG(__VA_ARGS__) #define AT_PRINTF(...) - -static bool ggml_is_view(const struct ggml_tensor * t) { - return t->view_src != NULL; -} - // ops that return true for this function must not use restrict pointers for their backend implementations bool ggml_op_can_inplace(enum ggml_op op) { switch (op) { @@ -627,7 +622,7 @@ static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor GGML_ASSERT(buffer_id >= 0); struct hash_node * hn = ggml_gallocr_hash_get(galloc, node); - if (!ggml_gallocr_is_allocated(galloc, node) && !ggml_is_view(node)) { + if (!ggml_gallocr_is_allocated(galloc, node) && !ggml_impl_is_view(node)) { hn->allocated = true; assert(hn->addr.offset == 0); @@ -658,7 +653,7 @@ static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor struct hash_node * p_hn = ggml_gallocr_hash_get(galloc, parent); if (p_hn->n_children == 1 && p_hn->n_views == 0) { - if (ggml_is_view(parent)) { + if (ggml_impl_is_view(parent)) { struct ggml_tensor * view_src = parent->view_src; struct hash_node * view_src_hn = ggml_gallocr_hash_get(galloc, view_src); if (view_src_hn->n_views == 1 && view_src_hn->n_children == 0 && view_src->data == parent->data) { @@ -739,7 +734,7 @@ static void ggml_gallocr_alloc_graph_impl(ggml_gallocr_t galloc, struct ggml_cgr // GGML_OP_NONE does not appear normally in the graph nodes, but is used by ggml-backend to add dependencies to // control when some tensors are allocated and freed. 
in this case, the dependencies are in `src`, but the node // itself is never used and should not be considered a dependency - if (ggml_is_view(node) && node->op != GGML_OP_NONE) { + if (ggml_impl_is_view(node) && node->op != GGML_OP_NONE) { struct ggml_tensor * view_src = node->view_src; ggml_gallocr_hash_get(galloc, view_src)->n_views += 1; } @@ -806,7 +801,7 @@ static void ggml_gallocr_alloc_graph_impl(ggml_gallocr_t galloc, struct ggml_cgr parent->name, p_hn->n_children, p_hn->n_views, p_hn->allocated); if (p_hn->n_children == 0 && p_hn->n_views == 0) { - if (ggml_is_view(parent)) { + if (ggml_impl_is_view(parent)) { struct ggml_tensor * view_src = parent->view_src; struct hash_node * view_src_hn = ggml_gallocr_hash_get(galloc, view_src); view_src_hn->n_views -= 1; diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h index baadfe9a7b..e3714b38a6 100644 --- a/ggml/src/ggml-impl.h +++ b/ggml/src/ggml-impl.h @@ -98,6 +98,10 @@ static bool ggml_op_is_empty(enum ggml_op op) { } } +static inline bool ggml_impl_is_view(const struct ggml_tensor * t) { + return t->view_src != NULL; +} + static inline float ggml_compute_softplus_f32(float input) { return (input > 20.0f) ? input : logf(1 + expf(input)); } diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index e2a6ff67be..ed819eaa4c 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1496,6 +1496,10 @@ bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tenso (t0->nb[3] == t1->nb[3]); } +bool ggml_is_view(const struct ggml_tensor * t) { + return ggml_impl_is_view(t); +} + // check if t1 can be represented as a repetition of t0 bool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) { static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); From cceb1b4e33cfd9595b4ac1949f2c0857e43af427 Mon Sep 17 00:00:00 2001 From: Ivan Chikish Date: Mon, 16 Feb 2026 18:52:24 +0300 Subject: [PATCH 08/10] common : inline functions (#18639) --- common/common.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/common/common.h b/common/common.h index 804485fb19..6410248377 100644 --- a/common/common.h +++ b/common/common.h @@ -670,7 +670,7 @@ static std::vector string_split(const std::string & str, char delim) { } template<> -std::vector string_split(const std::string & input, char separator) +inline std::vector string_split(const std::string & input, char separator) { std::vector parts; size_t begin_pos = 0; @@ -685,7 +685,7 @@ std::vector string_split(const std::string & input, ch return parts; } -static bool string_starts_with(const std::string & str, +inline bool string_starts_with(const std::string & str, const std::string & prefix) { // While we wait for C++20's std::string::starts_with... 
return str.rfind(prefix, 0) == 0; } @@ -870,11 +870,11 @@ const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count"; const char * const LLM_FFN_EXPS_REGEX = "\\.ffn_(up|down|gate)_(ch|)exps"; -static std::string llm_ffn_exps_block_regex(int idx) { +inline std::string llm_ffn_exps_block_regex(int idx) { return string_format("blk\\.%d%s", idx, LLM_FFN_EXPS_REGEX); } -static llama_model_tensor_buft_override llm_ffn_exps_cpu_override() { +inline llama_model_tensor_buft_override llm_ffn_exps_cpu_override() { return { LLM_FFN_EXPS_REGEX, ggml_backend_cpu_buffer_type() }; } From d612901116ab2066c7923372d4827032ff296bc4 Mon Sep 17 00:00:00 2001 From: AesSedai <7980540+AesSedai@users.noreply.github.com> Date: Mon, 16 Feb 2026 08:44:44 -0800 Subject: [PATCH 09/10] perplexity: add proper batching (#19661) --- tools/perplexity/perplexity.cpp | 154 ++++++++++++++++++-------------- 1 file changed, 89 insertions(+), 65 deletions(-) diff --git a/tools/perplexity/perplexity.cpp b/tools/perplexity/perplexity.cpp index 1ead9c871e..433b747f0d 100644 --- a/tools/perplexity/perplexity.cpp +++ b/tools/perplexity/perplexity.cpp @@ -347,7 +347,8 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params int count = 0; double nll = 0.0; - LOG_INF("%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch); + const int n_seq = std::max(1, n_batch / n_ctx); + LOG_INF("%s: computing over %d chunks, n_ctx=%d, batch_size=%d, n_seq=%d\n", __func__, n_chunk, n_ctx, n_batch, n_seq); for (int i = 0; i < n_chunk; ++i) { const int start = i * params.ppl_stride; @@ -1737,11 +1738,21 @@ static void kl_divergence(llama_context * ctx, const common_params & params) { } const int n_batch = params.n_batch; - const int num_batches = (n_ctx + n_batch - 1)/n_batch; + const int num_batches = (static_cast(n_ctx) + n_batch - 1) / n_batch; + // Calculate n_seq based on the logits file's n_ctx, but cap it at what the context supports + const int n_seq_max = llama_n_seq_max(ctx); + int n_seq = std::max(1, n_batch / static_cast(n_ctx)); + if (n_seq > n_seq_max) { + LOG_WRN("%s: calculated n_seq=%d exceeds context's n_seq_max=%d, capping at %d\n", + __func__, n_seq, n_seq_max, n_seq_max); + n_seq = n_seq_max; + } const int nv = 2*((n_vocab + 1)/2) + 4; const bool add_bos = llama_vocab_get_add_bos(vocab); GGML_ASSERT(!llama_vocab_get_add_eos(vocab)); + llama_batch batch = llama_batch_init(std::min(n_batch, static_cast(n_ctx)*n_seq), 0, 1); + std::vector log_probs_uint16(size_t(n_ctx - 1 - n_ctx/2) * nv); std::vector kld_values(size_t(n_ctx - 1 - n_ctx/2)*n_chunk); std::vector p_diff_values(size_t(n_ctx - 1 - n_ctx/2)*n_chunk); @@ -1750,6 +1761,8 @@ static void kl_divergence(llama_context * ctx, const common_params & params) { logits.reserve(size_t(n_ctx) * n_vocab); } + LOG_INF("%s: computing over %d chunks, n_ctx=%u, batch_size=%d, n_seq=%d\n", __func__, n_chunk, n_ctx, n_batch, n_seq); + std::vector workers(std::thread::hardware_concurrency() - 1); auto mean_and_uncertainty = [] (double sum, double sum2, size_t count) { @@ -1774,107 +1787,122 @@ static void kl_divergence(llama_context * ctx, const common_params & params) { auto kld_ptr = kld_values.data(); auto p_diff_ptr = p_diff_values.data(); - for (int i = 0; i < n_chunk; ++i) { + const int first = n_ctx/2; + + for (int i = 0; i < n_chunk; i += n_seq) { const int start = i * n_ctx; const int end = start + n_ctx; - const auto t_start = std::chrono::high_resolution_clock::now(); + const int n_seq_batch = std::min(n_seq, 
n_chunk - i); - if (in.read((char *)log_probs_uint16.data(), log_probs_uint16.size()*sizeof(uint16_t)).fail()) { - LOG_ERR("%s: failed reading log-probs for chunk %d\n", __func__, i); - return; - } + const auto t_start = std::chrono::high_resolution_clock::now(); // clear the KV cache llama_memory_clear(llama_get_memory(ctx), true); - llama_batch batch = llama_batch_init(n_batch, 0, 1); - for (int j = 0; j < num_batches; ++j) { const int batch_start = start + j * n_batch; const int batch_size = std::min(end - batch_start, n_batch); - // save original token and restore it after eval - const auto token_org = tokens[batch_start]; - - // add BOS token for the first batch of each chunk - if (add_bos && j == 0) { - tokens[batch_start] = llama_vocab_bos(vocab); - } + int n_outputs = 0; common_batch_clear(batch); - for (int i = 0; i < batch_size; i++) { - common_batch_add(batch, tokens[batch_start + i], j*n_batch + i, {0}, true); + for (int seq = 0; seq < n_seq_batch; seq++) { + int seq_start = batch_start + seq*n_ctx; + + // save original token and restore it after eval + const auto token_org = tokens[seq_start]; + + // add BOS token for the first batch of each chunk + if (add_bos && j == 0) { + tokens[seq_start] = llama_vocab_bos(vocab); + } + + for (int k = 0; k < batch_size; ++k) { + const int pos = j*n_batch + k; + const bool need_logits = pos >= first; + common_batch_add(batch, tokens[seq_start + k], pos, { seq }, need_logits); + n_outputs += need_logits; + } + + // restore the original token in case it was set to BOS + tokens[seq_start] = token_org; } if (llama_decode(ctx, batch)) { - LOG_ERR("%s : failed to eval\n", __func__); + LOG_ERR("%s : failed to decode\n", __func__); llama_batch_free(batch); return; } - // restore the original token in case it was set to BOS - tokens[batch_start] = token_org; - - if (num_batches > 1) { + if (num_batches > 1 && n_outputs > 0) { const auto * batch_logits = llama_get_logits(ctx); - logits.insert(logits.end(), batch_logits, batch_logits + size_t(batch_size) * n_vocab); + logits.insert(logits.end(), batch_logits, batch_logits + size_t(n_outputs) * n_vocab); } } - llama_batch_free(batch); - - const auto t_end = std::chrono::high_resolution_clock::now(); - if (i == 0) { + llama_synchronize(ctx); + const auto t_end = std::chrono::high_resolution_clock::now(); const float t_total = std::chrono::duration(t_end - t_start).count(); LOG_INF("%s: %.2f seconds per pass - ETA ", __func__, t_total); - int total_seconds = (int)(t_total * n_chunk); + int total_seconds = (int)(t_total * n_chunk / n_seq); if (total_seconds >= 60*60) { LOG("%d hours ", total_seconds / (60*60)); total_seconds = total_seconds % (60*60); } LOG("%.2f minutes\n", total_seconds / 60.0); + LOG("\n"); + LOG("chunk PPL ln(PPL(Q)/PPL(base)) KL Divergence Δp RMS Same top p\n"); } - LOG("\n"); - LOG("chunk PPL ln(PPL(Q)/PPL(base)) KL Divergence Δp RMS Same top p\n"); - const int first = n_ctx/2; - const float * all_logits = num_batches > 1 ? 
logits.data() : llama_get_logits(ctx); - process_logits(n_vocab, all_logits + size_t(first)*n_vocab, tokens.data() + start + first, n_ctx - 1 - first, - workers, log_probs_uint16, kld, kld_ptr, p_diff_ptr); - p_diff_ptr += n_ctx - 1 - first; - kld_ptr += n_ctx - 1 - first; + // Read log probs for each sequence in the batch + for (int seq = 0; seq < n_seq_batch; seq++) { + if (in.read((char *)log_probs_uint16.data(), log_probs_uint16.size()*sizeof(uint16_t)).fail()) { + LOG_ERR("%s: failed reading log-probs for chunk %d\n", __func__, i + seq); + llama_batch_free(batch); + return; + } - LOG("%4d", i+1); + const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits_ith(ctx, seq*n_ctx + first); - auto log_ppl = mean_and_uncertainty(kld.sum_nll, kld.sum_nll2, kld.count); - const double ppl_val = exp(log_ppl.first); - const double ppl_unc = ppl_val * log_ppl.second; // ppl_unc = sqrt( (dexp(x) / dx) ** 2 * log_ppl.second ** 2 ) - LOG(" %9.4lf ± %9.4lf", ppl_val, ppl_unc); + process_logits(n_vocab, all_logits, tokens.data() + start + seq*n_ctx + first, n_ctx - 1 - first, + workers, log_probs_uint16, kld, kld_ptr, p_diff_ptr); + p_diff_ptr += n_ctx - 1 - first; + kld_ptr += n_ctx - 1 - first; - auto log_ppl_base = mean_and_uncertainty(kld.sum_nll_base, kld.sum_nll_base2, kld.count); - const double log_ppl_cov = covariance(kld.sum_nll, kld.sum_nll_base, kld.sum_nll_nll_base, kld.count); - const double log_ppl_ratio_val = log_ppl.first - log_ppl_base.first; - const double log_ppl_ratio_unc = sqrt(log_ppl.second*log_ppl.second + log_ppl_base.second*log_ppl_base.second - 2.0*log_ppl_cov); - LOG(" %10.5lf ± %10.5lf", log_ppl_ratio_val, log_ppl_ratio_unc); + LOG("%4d", i + seq + 1); - auto kl_div = mean_and_uncertainty(kld.sum_kld, kld.sum_kld2, kld.count); - LOG(" %10.5lf ± %10.5lf", kl_div.first, kl_div.second); + auto log_ppl = mean_and_uncertainty(kld.sum_nll, kld.sum_nll2, kld.count); + const double ppl_val = exp(log_ppl.first); + const double ppl_unc = ppl_val * log_ppl.second; + LOG(" %9.4lf ± %9.4lf", ppl_val, ppl_unc); - auto p_diff_mse = mean_and_uncertainty(kld.sum_p_diff2, kld.sum_p_diff4, kld.count); - const double p_diff_rms_val = sqrt(p_diff_mse.first); - const double p_diff_rms_unc = 0.5/p_diff_rms_val * p_diff_mse.second; - LOG(" %6.3lf ± %6.3lf %%", 100.0*p_diff_rms_val, 100.0*p_diff_rms_unc); + auto log_ppl_base = mean_and_uncertainty(kld.sum_nll_base, kld.sum_nll_base2, kld.count); + const double log_ppl_cov = covariance(kld.sum_nll, kld.sum_nll_base, kld.sum_nll_nll_base, kld.count); + const double log_ppl_ratio_val = log_ppl.first - log_ppl_base.first; + const double log_ppl_ratio_unc = sqrt(log_ppl.second*log_ppl.second + log_ppl_base.second*log_ppl_base.second - 2.0*log_ppl_cov); + LOG(" %10.5lf ± %10.5lf", log_ppl_ratio_val, log_ppl_ratio_unc); - double p_top_val = 1.*kld.n_same_top/kld.count; - double p_top_unc = sqrt(p_top_val*(1 - p_top_val)/(kld.count - 1)); - LOG(" %6.3lf ± %6.3lf %%", 100.0*p_top_val, 100.0*p_top_unc); + auto kl_div = mean_and_uncertainty(kld.sum_kld, kld.sum_kld2, kld.count); + LOG(" %10.5lf ± %10.5lf", kl_div.first, kl_div.second); - LOG("\n"); + auto p_diff_mse = mean_and_uncertainty(kld.sum_p_diff2, kld.sum_p_diff4, kld.count); + const double p_diff_rms_val = sqrt(p_diff_mse.first); + const double p_diff_rms_unc = 0.5/p_diff_rms_val * p_diff_mse.second; + LOG(" %6.3lf ± %6.3lf %%", 100.0*p_diff_rms_val, 100.0*p_diff_rms_unc); + + double p_top_val = 1.*kld.n_same_top/kld.count; + double p_top_unc = sqrt(p_top_val*(1 - 
p_top_val)/(kld.count - 1)); + LOG(" %6.3lf ± %6.3lf %%", 100.0*p_top_val, 100.0*p_top_unc); + + LOG("\n"); + } logits.clear(); } + + llama_batch_free(batch); LOG("\n"); if (kld.count < 100) return; // we do not wish to do statistics on so few values @@ -1996,7 +2024,7 @@ int main(int argc, char ** argv) { const bool ppl = !params.hellaswag && !params.winogrande && !params.multiple_choice && !params.kl_divergence; - if (ppl) { + if (ppl || params.kl_divergence) { const int32_t n_seq = std::max(1, params.n_batch / n_ctx); const int32_t n_kv = n_seq * n_ctx; @@ -2006,12 +2034,8 @@ int main(int argc, char ** argv) { params.n_batch = std::min(params.n_batch, n_kv); } else { params.n_batch = std::min(params.n_batch, params.n_ctx); - if (params.kl_divergence) { - params.n_parallel = 1; - } else { - // ensure there's at least enough seq_ids for HellaSwag - params.n_parallel = std::max(4, params.n_parallel); - } + // ensure there's at least enough seq_ids for HellaSwag + params.n_parallel = std::max(4, params.n_parallel); } if (params.ppl_stride > 0) { From 05fa625eac5bbdbe88b43f857156c35501421d6e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?DAN=E2=84=A2?= Date: Mon, 16 Feb 2026 16:49:57 -0500 Subject: [PATCH 10/10] convert : add JoyAI-LLM-Flash (#19651) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * convert_hf_to_gguf: add JoyAI-LLM-Flash tokenizer hash mapping to deepseek-v3 * llama-vocab: create a new pre-tokenizer name for joyai-llm. * add missing vocab type section * Update convert_hf_to_gguf_update.py Co-authored-by: Sigbjørn Skjæret * Update convert_hf_to_gguf.py Co-authored-by: Sigbjørn Skjæret --------- Co-authored-by: Sigbjørn Skjæret --- convert_hf_to_gguf.py | 9 ++++++--- convert_hf_to_gguf_update.py | 5 +++-- src/llama-vocab.cpp | 5 +++++ src/llama-vocab.h | 1 + 4 files changed, 15 insertions(+), 5 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index d7141f01cf..0e5d0f8589 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -1049,6 +1049,9 @@ class TextModel(ModelBase): if chkhsh == "9ca2dd618e8afaf09731a7cf6e2105b373ba6a1821559f258b272fe83e6eb902": # ref: https://huggingface.co/zai-org/GLM-4.5-Air res = "glm4" + if chkhsh == "cdf5f35325780597efd76153d4d1c16778f766173908894c04afc20108536267": + # ref: https://huggingface.co/zai-org/GLM-4.7-Flash + res = "glm4" if chkhsh == "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35": # ref: https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0 res = "minerva-7b" @@ -1082,9 +1085,6 @@ class TextModel(ModelBase): if chkhsh == "b3d1dd861f1d4c5c0d2569ce36baf3f90fe8a102db3de50dd71ff860d91be3df": # ref: https://huggingface.co/aari1995/German_Semantic_V3 res = "jina-v2-de" - if chkhsh == "cdf5f35325780597efd76153d4d1c16778f766173908894c04afc20108536267": - # ref: https://huggingface.co/zai-org/GLM-4.7-Flash - res = "glm4" if chkhsh == "0ef9807a4087ebef797fc749390439009c3b9eda9ad1a097abbe738f486c01e5": # ref: https://huggingface.co/meta-llama/Meta-Llama-3-8B res = "llama-bpe" @@ -1268,6 +1268,9 @@ class TextModel(ModelBase): if chkhsh == "d30d75d9059f1aa2c19359de71047b3ae408c70875e8a3ccf8c5fba56c9d8af4": # ref: https://huggingface.co/Qwen/Qwen3.5-9B-Instruct res = "qwen35" + if chkhsh == "b4b8ca1f9769494fbd956ebc4c249de6131fb277a4a3345a7a92c7dd7a55808d": + # ref: https://huggingface.co/jdopensource/JoyAI-LLM-Flash + res = "joyai-llm" if res is None: logger.warning("\n") diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py index 
8bd24dbe91..f871b4cdb7 100755 --- a/convert_hf_to_gguf_update.py +++ b/convert_hf_to_gguf_update.py @@ -149,7 +149,8 @@ models = [ {"name": "youtu", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Youtu-LLM-2B", }, {"name": "solar-open", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/upstage/Solar-Open-100B", }, {"name": "exaone-moe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/K-EXAONE-236B-A23B", }, - {"name": "qwen35", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen3.5-9B-Instruct", } + {"name": "qwen35", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen3.5-9B-Instruct", }, + {"name": "joyai-llm", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jdopensource/JoyAI-LLM-Flash", }, ] # some models are known to be broken upstream, so we will skip them as exceptions @@ -159,6 +160,7 @@ pre_computed_hashes = [ {"name": "chatglm-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-chat", "chkhsh": "81d72c7348a9f0ebe86f23298d37debe0a5e71149e29bd283904c02262b27516"}, {"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", "chkhsh": "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2"}, {"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/zai-org/GLM-4.5-Air", "chkhsh": "9ca2dd618e8afaf09731a7cf6e2105b373ba6a1821559f258b272fe83e6eb902"}, + {"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/zai-org/GLM-4.7-Flash", "chkhsh": "cdf5f35325780597efd76153d4d1c16778f766173908894c04afc20108536267"}, {"name": "minerva-7b", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0", "chkhsh": "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35"}, {"name": "hunyuan", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Hunyuan-A13B-Instruct", "chkhsh": "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664"}, {"name": "hunyuan-dense", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Hunyuan-4B-Instruct", "chkhsh": "bba3b3366b646dbdded5dbc42d59598b849371afc42f7beafa914afaa5b70aa6"}, @@ -172,7 +174,6 @@ pre_computed_hashes = [ {"name": "grok-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/alvarobartt/grok-2-tokenizer", "chkhsh": "66b8d4e19ab16c3bfd89bce5d785fb7e0155e8648708a1f42077cb9fe002c273"}, # jina-v2-de variants {"name": "jina-v2-de", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/aari1995/German_Semantic_V3", "chkhsh": "b3d1dd861f1d4c5c0d2569ce36baf3f90fe8a102db3de50dd71ff860d91be3df"}, - {"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/zai-org/GLM-4.7-Flash", "chkhsh": "cdf5f35325780597efd76153d4d1c16778f766173908894c04afc20108536267"}, ] diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index b35cb02ce4..80af181c52 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -308,6 +308,7 @@ struct llm_tokenizer_bpe : llm_tokenizer { break; case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM: case LLAMA_VOCAB_PRE_TYPE_HUNYUAN_DENSE: + case LLAMA_VOCAB_PRE_TYPE_JOYAI_LLM: regex_exprs = { "\\p{N}{1,3}", "[一-龥぀-ゟ゠-ヿ]+", @@ -2051,6 +2052,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "hunyuan-dense") { pre_type = LLAMA_VOCAB_PRE_TYPE_HUNYUAN_DENSE; clean_spaces = false; + } else if ( + tokenizer_pre == "joyai-llm") { + pre_type = LLAMA_VOCAB_PRE_TYPE_JOYAI_LLM; + 
clean_spaces = false; } else if ( tokenizer_pre == "kimi-k2") { pre_type = LLAMA_VOCAB_PRE_TYPE_KIMI_K2; diff --git a/src/llama-vocab.h b/src/llama-vocab.h index 1312a877ab..2df25fe620 100644 --- a/src/llama-vocab.h +++ b/src/llama-vocab.h @@ -56,6 +56,7 @@ enum llama_vocab_pre_type { LLAMA_VOCAB_PRE_TYPE_EXAONE_MOE = 45, LLAMA_VOCAB_PRE_TYPE_QWEN35 = 46, LLAMA_VOCAB_PRE_TYPE_TINY_AYA = 47, + LLAMA_VOCAB_PRE_TYPE_JOYAI_LLM = 48, }; struct LLM_KV;
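
Note on the tiny_aya digit-grouping pattern added in PATCH 06/10: the regex \d{1,3}(?=(?:\d{3})*\b) splits a digit run into groups of three counted from the right, which is the behaviour the custom splitter in unicode.cpp is expected to reproduce. The snippet below is a standalone sketch that only demonstrates that grouping with std::regex; it is illustrative and is not how llama.cpp evaluates the pattern (the patch routes it through unicode_regex_split_custom_afmoe).

// Standalone sketch: show the grouping behaviour of the tiny_aya digit pattern
// on a plain ASCII string. Illustrative only; llama.cpp uses its own splitter
// rather than std::regex for this pattern.
#include <iostream>
#include <regex>
#include <string>

int main() {
    const std::regex digit_groups(R"(\d{1,3}(?=(?:\d{3})*\b))");
    const std::string text = "total 1234567 tokens";

    // Expected matches: "1", "234", "567" - groups of three digits from the right.
    for (auto it = std::sregex_iterator(text.begin(), text.end(), digit_groups);
         it != std::sregex_iterator(); ++it) {
        std::cout << it->str() << '\n';
    }
    return 0;
}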
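
Note on PATCH 09/10 (perplexity batching): the KL-divergence path now packs several n_ctx-sized chunks into a single llama_decode call, one sequence per chunk, and requests logits only for positions in the second half of each chunk. The sketch below reproduces just the indexing arithmetic with made-up values (n_ctx, n_batch, n_seq_max and n_chunk are assumptions, not numbers from the patch) and does not call the llama.cpp API.

// Sketch of the chunk-to-sequence packing used by the batched KL-divergence path:
// n_seq chunks share one decode call, each as its own sequence, and logits are
// produced only for positions >= n_ctx/2 of every chunk.
#include <algorithm>
#include <cstdio>

int main() {
    const int n_ctx     = 512;   // assumed context size per chunk
    const int n_batch   = 2048;  // assumed logical batch size
    const int n_seq_max = 4;     // assumed sequence limit reported by the context
    const int n_chunk   = 10;    // assumed number of chunks in the logits file

    // chunks per decode: how many n_ctx windows fit in a batch, capped by n_seq_max
    const int n_seq = std::min(std::max(1, n_batch / n_ctx), n_seq_max);

    const int first = n_ctx / 2; // logits are only needed from this position onward

    for (int i = 0; i < n_chunk; i += n_seq) {
        const int n_seq_batch = std::min(n_seq, n_chunk - i);
        std::printf("decode: chunks %d..%d as %d sequences, %d logit rows per chunk\n",
                    i, i + n_seq_batch - 1, n_seq_batch, n_ctx - first);
    }
    return 0;
}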