diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index ee02cdd91c..e1c78e3b18 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -136,19 +136,11 @@ class ModelBase: self.remote_hf_model_id = remote_hf_model_id self.sentence_transformers_dense_modules = sentence_transformers_dense_modules self.hparams = ModelBase.load_hparams(self.dir_model, self.is_mistral_format) if hparams is None else hparams - self.rope_parameters = self.hparams.get("rope_parameters", self.hparams.get("rope_scaling")) or {} self.model_tensors = self.index_tensors(remote_hf_model_id=remote_hf_model_id) self.metadata_override = metadata_override self.model_name = model_name self.dir_model_card = dir_model # overridden in convert_lora_to_gguf.py - # Ensure "rope_theta" and "rope_type" is mirrored in rope_parameters - if "full_attention" not in self.rope_parameters and "sliding_attention" not in self.rope_parameters: - if "rope_theta" not in self.rope_parameters and (rope_theta := self.find_hparam(["rope_theta", "global_rope_theta", "rotary_emb_base"], optional=True)) is not None: - self.rope_parameters["rope_theta"] = rope_theta - if "rope_type" not in self.rope_parameters and (rope_type := self.rope_parameters.get("type")) is not None: - self.rope_parameters["rope_type"] = rope_type - # Apply heuristics to figure out typical tensor encoding based on first layer tensor encoding type if self.ftype == gguf.LlamaFileType.GUESSED: # NOTE: can't use field "torch_dtype" in config.json, because some finetunes lie. @@ -765,6 +757,15 @@ class TextModel(ModelBase): self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"]) self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) + self.rope_parameters = self.hparams.get("rope_parameters", self.hparams.get("rope_scaling")) or {} + + # Ensure "rope_theta" and "rope_type" is mirrored in rope_parameters + if "full_attention" not in self.rope_parameters and "sliding_attention" not in self.rope_parameters: + if "rope_theta" not in self.rope_parameters and (rope_theta := self.find_hparam(["rope_theta", "global_rope_theta", "rotary_emb_base"], optional=True)) is not None: + self.rope_parameters["rope_theta"] = rope_theta + if "rope_type" not in self.rope_parameters and (rope_type := self.rope_parameters.get("type")) is not None: + self.rope_parameters["rope_type"] = rope_type + @classmethod def __init_subclass__(cls): # can't use an abstract property, because overriding it without type errors @@ -1203,6 +1204,9 @@ class TextModel(ModelBase): if chkhsh == "f4f37b6c8eb9ea29b3eac6bb8c8487c5ab7885f8d8022e67edc1c68ce8403e95": # ref: https://huggingface.co/MiniMaxAI/MiniMax-M2 res = "minimax-m2" + if chkhsh == "4a2e2abae11ca2b86d570fc5b44be4d5eb5e72cc8f22dd136a94b37da83ab665": + # ref: https://huggingface.co/KORMo-Team/KORMo-tokenizer + res = "kormo" if res is None: logger.warning("\n") @@ -3398,7 +3402,7 @@ class QwenModel(TextModel): self._set_vocab_qwen() -@ModelBase.register("Qwen2Model", "Qwen2ForCausalLM", "Qwen2AudioForConditionalGeneration") +@ModelBase.register("Qwen2Model", "Qwen2ForCausalLM", "Qwen2AudioForConditionalGeneration", "KORMoForCausalLM") class Qwen2Model(TextModel): model_arch = gguf.MODEL_ARCH.QWEN2 diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py index b8f694e86c..5e8456a7ea 100755 --- a/convert_hf_to_gguf_update.py +++ b/convert_hf_to_gguf_update.py @@ -143,6 +143,7 @@ models = [ {"name": "bailingmoe2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/Ling-mini-base-2.0", }, {"name": "granite-docling", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ibm-granite/granite-docling-258M", }, {"name": "minimax-m2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/MiniMaxAI/MiniMax-M2", }, + {"name": "kormo", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/KORMo-Team/KORMo-tokenizer", }, ] # some models are known to be broken upstream, so we will skip them as exceptions diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 72a82a8911..514f086f68 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -1976,9 +1976,6 @@ static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * s break; case GGML_TYPE_F16: - if (!opt_experimental) { - return false; - } break; default: diff --git a/ggml/src/ggml-hexagon/htp/matmul-ops.c b/ggml/src/ggml-hexagon/htp/matmul-ops.c index c99b6a0d18..346f0bd339 100644 --- a/ggml/src/ggml-hexagon/htp/matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/matmul-ops.c @@ -903,7 +903,7 @@ static void vec_dot_f16_f32(const int n, float * restrict s, const void * restri const float * restrict vy = (const float * restrict) y; for (uint32_t i = 0; i < n; i++) { - rsum += vx[i] * (__fp16) vy[i]; + rsum += (float)vx[i] * vy[i]; } *s = rsum; return; @@ -917,7 +917,7 @@ static void vec_dot_f16_f32(const int n, float * restrict s, const void * restri // for some reason we need volatile here so that the compiler doesn't try anything funky volatile HVX_Vector rsum = Q6_V_vsplat_R(0); - + float r_sum_scalar = 0.0f; uint32_t i = 0; for (i = 0; i < nv0; i++) { @@ -926,31 +926,42 @@ static void vec_dot_f16_f32(const int n, float * restrict s, const void * restri HVX_Vector x = vx[i]; HVX_VectorPair xp = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(x), Q6_Vh_vsplat_R(0x3C00)); // mul by 1.0 - HVX_Vector hi = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_hi_W(xp)), Q6_V_hi_W(yp)); - HVX_Vector lo = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_lo_W(xp)), Q6_V_lo_W(yp)); + //NOTE: need volatile here to prevent compiler optimization + // Seem compiler cannot guarantee read-after-write?? + volatile HVX_Vector hi = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_hi_W(xp)), Q6_V_hi_W(yp)); + volatile HVX_Vector lo = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_lo_W(xp)), Q6_V_lo_W(yp)); HVX_Vector sum = Q6_Vqf32_vadd_Vqf32Vqf32(hi, lo); rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, sum); } if (nv1) { - HVX_VectorPair yp = vy[i]; + // HVX_VectorPair yp = vy[i]; - HVX_Vector x = vx[i]; - HVX_VectorPair xp = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(x), Q6_Vh_vsplat_R(0x3C00)); // mul by 1.0 + // HVX_Vector x = vx[i]; + // HVX_VectorPair xp = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(x), Q6_Vh_vsplat_R(0x3C00)); // mul by 1.0 - if (nv1 >= 32) { - HVX_Vector hi = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_hi_W(xp)), Q6_V_hi_W(yp)); - rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, hi); - nv1 -= 32; - } + // if (nv1 >= 32) { + // volatile HVX_Vector hi = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_hi_W(xp)), Q6_V_hi_W(yp)); + // rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, hi); + // nv1 -= 32; + // } + // rsum = hvx_vec_qf32_reduce_sum(rsum); + + // if (nv1) { + // volatile HVX_Vector lo = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_lo_W(xp)), Q6_V_lo_W(yp)); + // HVX_Vector sum = hvx_vec_qf32_reduce_sum_n(lo, nv1); + // rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, sum); + // } + + //process the remainder using scalar loop rsum = hvx_vec_qf32_reduce_sum(rsum); + const __fp16 * restrict sx = (const __fp16 * restrict) x; + const float * restrict sy = (const float * restrict) y; - if (nv1) { - HVX_Vector lo = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_lo_W(xp)), Q6_V_lo_W(yp)); - HVX_Vector sum = hvx_vec_qf32_reduce_sum_n(lo, nv1); - rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, sum); + for (uint32_t i = nv0 * 64; i < n; i++) { + r_sum_scalar += (float) sx[i] * sy[i]; } // hvx_vec_dump_fp16("X", x); @@ -961,7 +972,7 @@ static void vec_dot_f16_f32(const int n, float * restrict s, const void * restri rsum = hvx_vec_qf32_reduce_sum(rsum); } - *s = hvx_vec_get_fp32(Q6_Vsf_equals_Vqf32(rsum)); + *s = hvx_vec_get_fp32(Q6_Vsf_equals_Vqf32(rsum)) + r_sum_scalar; # ifdef HTP_DEBUG { @@ -1498,9 +1509,6 @@ static void matmul_f16_f32(struct htp_tensor * restrict src0, uint64_t t1, t2; t1 = HAP_perf_get_qtimer_count(); - const size_t src0_row_size = sizeof(__fp16) * ne00; - const size_t src1_row_size = sizeof(float) * ne10; - assert(ne12 % ne02 == 0); assert(ne13 % ne03 == 0); @@ -1510,8 +1518,6 @@ static void matmul_f16_f32(struct htp_tensor * restrict src0, // This is the size of the rest of the dimensions of the result const uint32_t nr1 = ne1 * ne2 * ne3; - uint32_t chunk_size = 64; - // distribute the thread work across the inner or outer loop based on which one is larger uint32_t nchunk0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows uint32_t nchunk1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows @@ -1544,11 +1550,11 @@ static void matmul_f16_f32(struct htp_tensor * restrict src0, const uint32_t blck_0 = 64; const uint32_t blck_1 = 64; - float tmp[32]; + __attribute__((aligned(128))) float tmp[64]; for (uint32_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) { for (uint32_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) { - for (uint32_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir1_end; ir1++) { + for (uint32_t ir1 = iir1; ir1 < MIN(iir1 + blck_1, ir1_end); ir1++) { const uint32_t i13 = (ir1 / (ne12 * ne1)); const uint32_t i12 = (ir1 - i13 * ne12 * ne1) / ne1; const uint32_t i11 = (ir1 - i13 * ne12 * ne1 - i12 * ne1); @@ -1561,13 +1567,16 @@ static void matmul_f16_f32(struct htp_tensor * restrict src0, const uint32_t i2 = i12; const uint32_t i3 = i13; - const uint8_t * restrict src0_row = (const uint8_t *) src0->data + (0 + i02 * nb02 + i03 * nb03); + const uint8_t * restrict src0_base = (const uint8_t *) src0->data + (0 + i02 * nb02 + i03 * nb03); const uint8_t * restrict src1_col = - (const uint8_t *) src1->data + (i11 + i12 * ne11 + i13 * ne12 * ne11) * src1_row_size; + (const uint8_t *) src1->data + (i11 * nb11 + i12 * nb12 + i13 * nb13); float * dst_col = (float *) ((uint8_t * restrict) dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3)); - for (uint32_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0++) { - vec_dot_f16_f32(ne00, &tmp[ir0 - iir0], src0_row + ir0 * src0_row_size, src1_col); + const uint32_t ir0_block_end = MIN(iir0 + blck_0, ir0_end); + for (uint32_t ir0 = iir0; ir0 < ir0_block_end; ir0++) { + // Use nb01 stride for non-contiguous src0 support + const uint8_t * restrict src0_row = src0_base + ir0 * nb01; + vec_dot_f16_f32(ne00, &tmp[ir0 - iir0], src0_row, src1_col); } hvx_copy_fp32_ua((uint8_t *) &dst_col[iir0], (uint8_t *) tmp, MIN(iir0 + blck_0, ir0_end) - iir0); diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index d9c87da194..b320e2b4b2 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -154,7 +154,8 @@ class TensorNameMap: "model.layers.{bid}.operator_norm", # lfm2 "model.transformer.blocks.{bid}.attn_norm", # llada "layers.{bid}.input_layernorm", # qwen3-embedding - "model.layers.{bid}.attention_layernorm" # apertus + "model.layers.{bid}.attention_layernorm", # apertus + "model.layers.{bid}.pre_attention_layernorm", # kormo ), # Attention norm 2 @@ -342,6 +343,7 @@ class TensorNameMap: "model.transformer.blocks.{bid}.ff_norm", # llada "layers.{bid}.post_attention_layernorm", # qwen3-embedding "model.layers.{bid}.feedforward_layernorm", # apertus + "model.layers.{bid}.pre_mlp_layernorm", # kormo ), # Pre feed-forward norm diff --git a/scripts/snapdragon/adb/run-mtmd.sh b/scripts/snapdragon/adb/run-mtmd.sh new file mode 100755 index 0000000000..91d868278a --- /dev/null +++ b/scripts/snapdragon/adb/run-mtmd.sh @@ -0,0 +1,65 @@ +#!/bin/sh +# + +# Basedir on device +basedir=/data/local/tmp/llama.cpp + +cli_opts= + +branch=. +[ "$B" != "" ] && branch=$B + +adbserial= +[ "$S" != "" ] && adbserial="-s $S" + +model="gemma-3-4b-it-Q4_0.gguf" +[ "$M" != "" ] && model="$M" + +mmproj="mmproj-F16.gguf" +[ "$MMPROJ" != "" ] && mmproj="$MMPROJ" + +image= +[ "$IMG" != "" ] && image="$IMG" + +device="HTP0" +[ "$D" != "" ] && device="$D" + +verbose= +[ "$V" != "" ] && verbose="GGML_HEXAGON_VERBOSE=$V" + +experimental="GGML_HEXAGON_EXPERIMENTAL=1" +[ "$E" != "" ] && experimental="GGML_HEXAGON_EXPERIMENTAL=$E" + +sched= +[ "$SCHED" != "" ] && sched="GGML_SCHED_DEBUG=2" cli_opts="$cli_opts -v" + +profile= +[ "$PROF" != "" ] && profile="GGML_HEXAGON_PROFILE=$PROF GGML_HEXAGON_OPSYNC=1" + +opmask= +[ "$OPMASK" != "" ] && opmask="GGML_HEXAGON_OPMASK=$OPMASK" + +nhvx= +[ "$NHVX" != "" ] && nhvx="GGML_HEXAGON_NHVX=$NHVX" + +ndev= +[ "$NDEV" != "" ] && ndev="GGML_HEXAGON_NDEV=$NDEV" + +# MTMD backend device for vision model (defaults to CPU if not set) +mtmd_backend= +[ "$MTMD_DEVICE" != "" ] && mtmd_backend="MTMD_BACKEND_DEVICE=$MTMD_DEVICE" + +set -x + +adb $adbserial shell " \ + cd $basedir; ulimit -c unlimited; \ + LD_LIBRARY_PATH=$basedir/$branch/lib \ + ADSP_LIBRARY_PATH=$basedir/$branch/lib \ + $verbose $experimental $sched $opmask $profile $nhvx $ndev $mtmd_backend \ + ./$branch/bin/llama-mtmd-cli --no-mmap -m $basedir/../gguf/$model \ + --mmproj $basedir/../gguf/$mmproj \ + --image $basedir/../gguf/$image \ + --poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \ + --ctx-size 8192 --batch-size 128 -ctk q8_0 -ctv q8_0 -fa on \ + -ngl 99 --device $device -v $cli_opts $@ \ +" diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 4cbbca0925..3186242d60 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -1561,9 +1561,11 @@ void llama_kv_cache::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama const uint32_t strm = seq_id == -1 ? s : seq_to_stream[seq_id]; + slot_info sinfo; + bool res = true; - res = res && state_read_meta(io, strm, cell_count, seq_id); - res = res && state_read_data(io, strm, cell_count); + res = res && state_read_meta(io, strm, cell_count, sinfo, seq_id); + res = res && state_read_data(io, strm, cell_count, sinfo); if (!res) { if (seq_id == -1) { @@ -1702,7 +1704,7 @@ void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t } } -bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id) { +bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, slot_info & sinfo, llama_seq_id dest_seq_id) { auto & cells = v_cells[strm]; auto & head = v_heads[strm]; @@ -1739,7 +1741,7 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32 ubatch.seq_id[i] = &dest_seq_id; } - const auto sinfo = find_slot(ubatch, true); + sinfo = find_slot(ubatch, false); if (sinfo.empty()) { LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__); return false; @@ -1749,20 +1751,16 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32 // see: https://github.com/ggml-org/llama.cpp/pull/16825#issuecomment-3460868350 apply_ubatch(sinfo, ubatch); - const auto head_cur = sinfo.head(); + LLAMA_LOG_DEBUG("%s: cell_count = %d, dest_seq_id = %d\n", __func__, cell_count, dest_seq_id); - // keep the head at the old position because we will read the KV data into it in state_read_data() - head = head_cur; - - LLAMA_LOG_DEBUG("%s: head_cur = %d, head = %d, cell_count = %d, dest_seq_id = %d\n", __func__, head_cur, head, cell_count, dest_seq_id); - - // DEBUG CHECK: head_cur should be our first cell, head_cur + cell_count - 1 should be our last cell (verify seq_id and pos values) - // Assume that this is one contiguous block of cells - GGML_ASSERT(head_cur + cell_count <= cells.size()); - GGML_ASSERT(cells.pos_get(head_cur) == ubatch.pos[0]); - GGML_ASSERT(cells.pos_get(head_cur + cell_count - 1) == ubatch.pos[cell_count - 1]); - GGML_ASSERT(cells.seq_has(head_cur, dest_seq_id)); - GGML_ASSERT(cells.seq_has(head_cur + cell_count - 1, dest_seq_id)); + // DEBUG CHECK: verify that all cells were allocated and have correct seq_id and pos values + GGML_ASSERT(sinfo.n_stream() == 1); + GGML_ASSERT(sinfo.idxs[0].size() == cell_count); + for (uint32_t i = 0; i < cell_count; ++i) { + const uint32_t idx = sinfo.idxs[0][i]; + GGML_ASSERT(cells.pos_get(idx) == ubatch.pos[i]); + GGML_ASSERT(cells.seq_has(idx, dest_seq_id)); + } } else { // whole KV cache restore @@ -1795,15 +1793,24 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32 } } + // Create contiguous slot_info for whole cache restore + sinfo.s0 = strm; + sinfo.s1 = strm; + sinfo.resize(1); + sinfo.strm[0] = strm; + sinfo.idxs[0].resize(cell_count); + for (uint32_t i = 0; i < cell_count; ++i) { + sinfo.idxs[0][i] = i; + } + head = 0; } return true; } -bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count) { +bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, const slot_info & sinfo) { auto & cells = v_cells[strm]; - auto & head = v_heads[strm]; uint32_t v_trans; uint32_t n_layer; @@ -1853,8 +1860,17 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32 } if (cell_count) { - // Read and set the keys for the whole cell range - ggml_backend_tensor_set(k, io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row); + if (sinfo.is_contiguous()) { + // Fast path: contiguous cells, single memcpy + ggml_backend_tensor_set(k, io.read(cell_count * k_size_row), sinfo.head() * k_size_row, cell_count * k_size_row); + } else { + // Slow path: scatter to non-contiguous positions + const void * src = io.read(cell_count * k_size_row); + for (uint32_t i = 0; i < cell_count; ++i) { + const size_t dst_offset = sinfo.idxs[0][i] * k_size_row; + ggml_backend_tensor_set(k, (const char*)src + i * k_size_row, dst_offset, k_size_row); + } + } } } @@ -1885,8 +1901,17 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32 } if (cell_count) { - // Read and set the values for the whole cell range - ggml_backend_tensor_set(v, io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row); + if (sinfo.is_contiguous()) { + // Fast path: contiguous cells, single memcpy + ggml_backend_tensor_set(v, io.read(cell_count * v_size_row), sinfo.head() * v_size_row, cell_count * v_size_row); + } else { + // Slow path: scatter to non-contiguous positions + const void * src = io.read(cell_count * v_size_row); + for (uint32_t i = 0; i < cell_count; ++i) { + const size_t dst_offset = sinfo.idxs[0][i] * v_size_row; + ggml_backend_tensor_set(v, (const char*)src + i * v_size_row, dst_offset, v_size_row); + } + } } } } else { @@ -1925,10 +1950,22 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32 } if (cell_count) { - // For each row in the transposed matrix, read the values for the whole cell range - for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { - const size_t dst_offset = (head + j * cells.size()) * v_size_el; - ggml_backend_tensor_set(v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el); + if (sinfo.is_contiguous()) { + // Fast path: contiguous cells + const uint32_t h = sinfo.head(); + for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { + const size_t dst_offset = (h + j * cells.size()) * v_size_el; + ggml_backend_tensor_set(v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el); + } + } else { + // Slow path: scatter to non-contiguous positions + for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { + const void * src = io.read(cell_count * v_size_el); + for (uint32_t i = 0; i < cell_count; ++i) { + const size_t dst_offset = (sinfo.idxs[0][i] + j * cells.size()) * v_size_el; + ggml_backend_tensor_set(v, (const char*)src + i * v_size_el, dst_offset, v_size_el); + } + } } } } diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index bf7821c07c..1868f11857 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -72,6 +72,23 @@ public: void clear() { idxs.clear(); } + + // check if indices are contiguous starting from head() + bool is_contiguous() const { + if (idxs.empty() || idxs[0].empty()) { + return true; + } + if (idxs.size() > 1) { + return false; + } + const uint32_t h = idxs[0][0]; + for (size_t i = 0; i < idxs[0].size(); ++i) { + if (idxs[0][i] != h + i) { + return false; + } + } + return true; + } }; using slot_info_vec_t = std::vector; @@ -264,8 +281,8 @@ private: void state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id = -1) const; void state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const; - bool state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id = -1); - bool state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count); + bool state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, slot_info & sinfo, llama_seq_id dest_seq_id = -1); + bool state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, const slot_info & sinfo); }; class llama_kv_cache_context : public llama_memory_context_i { diff --git a/src/llama-model.cpp b/src/llama-model.cpp index e4d2138056..050735afc0 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -3388,9 +3388,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); // optional bias tensors - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0); + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index e2cca66e48..7b01a2edfe 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -1895,7 +1895,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { clean_spaces = false; } else if ( tokenizer_pre == "qwen2" || - tokenizer_pre == "deepseek-r1-qwen") { + tokenizer_pre == "deepseek-r1-qwen" || + tokenizer_pre == "kormo") { pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN2; clean_spaces = false; } else if ( diff --git a/src/models/qwen2.cpp b/src/models/qwen2.cpp index 587a932426..3da4dea3c1 100644 --- a/src/models/qwen2.cpp +++ b/src/models/qwen2.cpp @@ -31,16 +31,25 @@ llm_build_qwen2::llm_build_qwen2(const llama_model & model, const llm_graph_para { // compute Q and K and RoPE them ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 9ba559c8df..c3d9f9c324 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -222,6 +222,14 @@ llama_build_and_test(test-backend-ops.cpp) llama_build_and_test(test-model-load-cancel.cpp LABEL "model") llama_build_and_test(test-autorelease.cpp LABEL "model") +# Test for state restore with fragmented KV cache +# Requires a model, uses same args pattern as test-thread-safety +if (NOT ${CMAKE_SYSTEM_PROCESSOR} MATCHES "s390x") + llama_build_and_test(test-state-restore-fragmented.cpp LABEL "model" ARGS -hf ggml-org/models -hff tinyllamas/stories15M-q4_0.gguf) +else() + llama_build_and_test(test-state-restore-fragmented.cpp LABEL "model" ARGS -hf ggml-org/models -hff tinyllamas/stories15M-be.Q4_0.gguf) +endif() + if (NOT GGML_BACKEND_DL) # these tests use the backends directly and cannot be built with dynamic loading llama_build_and_test(test-barrier.cpp) diff --git a/tests/test-state-restore-fragmented.cpp b/tests/test-state-restore-fragmented.cpp new file mode 100644 index 0000000000..481b39d04c --- /dev/null +++ b/tests/test-state-restore-fragmented.cpp @@ -0,0 +1,122 @@ +// Test for state restore with fragmented KV cache +// This tests the fix for: https://github.com/ggml-org/llama.cpp/issues/17527 +// The issue was that state restore required contiguous KV cache slots, +// which fails when the cache is fragmented. +// +// The fix changes find_slot(ubatch, true) to find_slot(ubatch, false) +// in state_read_meta(), allowing non-contiguous slot allocation. + +#include "arg.h" +#include "common.h" +#include "llama.h" + +#include +#include +#include + +int main(int argc, char ** argv) { + common_params params; + + params.sampling.seed = 1234; + params.kv_unified = true; + params.n_parallel = 3; + params.n_ctx = 256; + + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) { + return 1; + } + + common_init(); + + // init + common_init_result_ptr llama_init = common_init_from_params(params); + + llama_model * model = llama_init->model(); + llama_context * ctx = llama_init->context(); + + if (model == nullptr || ctx == nullptr) { + fprintf(stderr, "%s : failed to init\n", __func__); + return 1; + } + + GGML_UNUSED(model); + + // tokenize prompt + std::vector tokens(70, 1); + + // interleave the 3 sequences: + // 01201230123... + llama_batch batch = llama_batch_init(params.n_parallel*tokens.size(), 0, 1); + for (size_t i = 0; i < tokens.size(); i++) { + for (int s = 0; s < params.n_parallel; ++s) { + common_batch_add(batch, tokens[i], i, {s}, false); + } + } + batch.logits[batch.n_tokens - 1] = true; + + if (llama_decode(ctx, batch)) { + fprintf(stderr, "%s : failed to decode seq 0\n", __func__); + return 1; + } + + fprintf(stderr, "%s : processed prompt on seq 0, 1, 2 (%zu tokens each)\n", __func__, tokens.size()); + + // Save state of seq 1 + std::vector seq_state(llama_state_seq_get_size(ctx, 1)); + const size_t ncopy = llama_state_seq_get_data(ctx, seq_state.data(), seq_state.size(), 1); + if (ncopy != seq_state.size()) { + fprintf(stderr, "%s : failed to save seq 1 state\n", __func__); + return 1; + } + fprintf(stderr, "%s : saved seq 1 state, %zu bytes\n", __func__, ncopy); + + // clear seq 1 to create a "hole" in the KV cache (fragmentation) + // 0.20.20.20.2.... + llama_memory_t mem = llama_get_memory(ctx); + llama_memory_seq_rm(mem, 1, -1, -1); + fprintf(stderr, "%s : cleared seq 1 to create fragmentation\n", __func__); + + // Now the cache has holes where seq 1 was + // This creates fragmentation - there's no contiguous block large enough + // for the seq 1 state if we only look for contiguous slots + + // Restore seq 1 state into seq 1 (should work with non-contiguous allocation) + // We use seq 1 since it's a valid sequence ID (0 to n_parallel-1) + // Before the fix, this would fail with "failed to find available cells in kv cache" + const size_t nset = llama_state_seq_set_data(ctx, seq_state.data(), seq_state.size(), 1); + if (nset != seq_state.size()) { + fprintf(stderr, "%s : FAILED to restore seq state into fragmented cache (got %zu, expected %zu)\n", + __func__, nset, seq_state.size()); + fprintf(stderr, "%s : This is the bug - state restore fails with fragmented KV cache\n", __func__); + llama_batch_free(batch); + return 1; + } + fprintf(stderr, "%s : restored state into seq 1, %zu bytes\n", __func__, nset); + + // Verify we can decode with the restored state + // Generate one token to verify the restored state is usable + auto sparams = llama_sampler_chain_default_params(); + llama_sampler * smpl = llama_sampler_chain_init(sparams); + llama_sampler_chain_add(smpl, llama_sampler_init_dist(params.sampling.seed)); + + auto next_token = llama_sampler_sample(smpl, ctx, -1); + auto next_token_str = common_token_to_piece(ctx, next_token); + + common_batch_clear(batch); + common_batch_add(batch, next_token, (int)tokens.size(), {1}, true); + + if (llama_decode(ctx, batch)) { + fprintf(stderr, "%s : failed to decode with restored state\n", __func__); + llama_sampler_free(smpl); + llama_batch_free(batch); + return 1; + } + + fprintf(stderr, "%s : successfully decoded with restored state, generated: '%s'\n", __func__, next_token_str.c_str()); + fprintf(stderr, "%s : SUCCESS - state restore works with fragmented KV cache\n", __func__); + + llama_sampler_free(smpl); + llama_batch_free(batch); + + return 0; +} diff --git a/tools/server/public/index.html.gz b/tools/server/public/index.html.gz index 1834d1d91f..036feff1c3 100644 Binary files a/tools/server/public/index.html.gz and b/tools/server/public/index.html.gz differ diff --git a/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentThumbnailFile.svelte b/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentThumbnailFile.svelte index 6fdd857214..908db5894b 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentThumbnailFile.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentThumbnailFile.svelte @@ -1,6 +1,6 @@