Merge branch 'ggml-org:master' into power-law-sampler
This commit is contained in:
commit
85b6e52e39
|
|
@ -136,19 +136,11 @@ class ModelBase:
|
|||
self.remote_hf_model_id = remote_hf_model_id
|
||||
self.sentence_transformers_dense_modules = sentence_transformers_dense_modules
|
||||
self.hparams = ModelBase.load_hparams(self.dir_model, self.is_mistral_format) if hparams is None else hparams
|
||||
self.rope_parameters = self.hparams.get("rope_parameters", self.hparams.get("rope_scaling")) or {}
|
||||
self.model_tensors = self.index_tensors(remote_hf_model_id=remote_hf_model_id)
|
||||
self.metadata_override = metadata_override
|
||||
self.model_name = model_name
|
||||
self.dir_model_card = dir_model # overridden in convert_lora_to_gguf.py
|
||||
|
||||
# Ensure "rope_theta" and "rope_type" is mirrored in rope_parameters
|
||||
if "full_attention" not in self.rope_parameters and "sliding_attention" not in self.rope_parameters:
|
||||
if "rope_theta" not in self.rope_parameters and (rope_theta := self.find_hparam(["rope_theta", "global_rope_theta", "rotary_emb_base"], optional=True)) is not None:
|
||||
self.rope_parameters["rope_theta"] = rope_theta
|
||||
if "rope_type" not in self.rope_parameters and (rope_type := self.rope_parameters.get("type")) is not None:
|
||||
self.rope_parameters["rope_type"] = rope_type
|
||||
|
||||
# Apply heuristics to figure out typical tensor encoding based on first layer tensor encoding type
|
||||
if self.ftype == gguf.LlamaFileType.GUESSED:
|
||||
# NOTE: can't use field "torch_dtype" in config.json, because some finetunes lie.
|
||||
|
|
@ -765,6 +757,15 @@ class TextModel(ModelBase):
|
|||
self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"])
|
||||
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
|
||||
|
||||
self.rope_parameters = self.hparams.get("rope_parameters", self.hparams.get("rope_scaling")) or {}
|
||||
|
||||
# Ensure "rope_theta" and "rope_type" is mirrored in rope_parameters
|
||||
if "full_attention" not in self.rope_parameters and "sliding_attention" not in self.rope_parameters:
|
||||
if "rope_theta" not in self.rope_parameters and (rope_theta := self.find_hparam(["rope_theta", "global_rope_theta", "rotary_emb_base"], optional=True)) is not None:
|
||||
self.rope_parameters["rope_theta"] = rope_theta
|
||||
if "rope_type" not in self.rope_parameters and (rope_type := self.rope_parameters.get("type")) is not None:
|
||||
self.rope_parameters["rope_type"] = rope_type
|
||||
|
||||
@classmethod
|
||||
def __init_subclass__(cls):
|
||||
# can't use an abstract property, because overriding it without type errors
|
||||
|
|
@ -1203,6 +1204,9 @@ class TextModel(ModelBase):
|
|||
if chkhsh == "f4f37b6c8eb9ea29b3eac6bb8c8487c5ab7885f8d8022e67edc1c68ce8403e95":
|
||||
# ref: https://huggingface.co/MiniMaxAI/MiniMax-M2
|
||||
res = "minimax-m2"
|
||||
if chkhsh == "4a2e2abae11ca2b86d570fc5b44be4d5eb5e72cc8f22dd136a94b37da83ab665":
|
||||
# ref: https://huggingface.co/KORMo-Team/KORMo-tokenizer
|
||||
res = "kormo"
|
||||
|
||||
if res is None:
|
||||
logger.warning("\n")
|
||||
|
|
@ -3398,7 +3402,7 @@ class QwenModel(TextModel):
|
|||
self._set_vocab_qwen()
|
||||
|
||||
|
||||
@ModelBase.register("Qwen2Model", "Qwen2ForCausalLM", "Qwen2AudioForConditionalGeneration")
|
||||
@ModelBase.register("Qwen2Model", "Qwen2ForCausalLM", "Qwen2AudioForConditionalGeneration", "KORMoForCausalLM")
|
||||
class Qwen2Model(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.QWEN2
|
||||
|
||||
|
|
|
|||
|
|
@ -143,6 +143,7 @@ models = [
|
|||
{"name": "bailingmoe2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/Ling-mini-base-2.0", },
|
||||
{"name": "granite-docling", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ibm-granite/granite-docling-258M", },
|
||||
{"name": "minimax-m2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/MiniMaxAI/MiniMax-M2", },
|
||||
{"name": "kormo", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/KORMo-Team/KORMo-tokenizer", },
|
||||
]
|
||||
|
||||
# some models are known to be broken upstream, so we will skip them as exceptions
|
||||
|
|
|
|||
|
|
@ -1976,9 +1976,6 @@ static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * s
|
|||
break;
|
||||
|
||||
case GGML_TYPE_F16:
|
||||
if (!opt_experimental) {
|
||||
return false;
|
||||
}
|
||||
break;
|
||||
|
||||
default:
|
||||
|
|
|
|||
|
|
@ -903,7 +903,7 @@ static void vec_dot_f16_f32(const int n, float * restrict s, const void * restri
|
|||
const float * restrict vy = (const float * restrict) y;
|
||||
|
||||
for (uint32_t i = 0; i < n; i++) {
|
||||
rsum += vx[i] * (__fp16) vy[i];
|
||||
rsum += (float)vx[i] * vy[i];
|
||||
}
|
||||
*s = rsum;
|
||||
return;
|
||||
|
|
@ -917,7 +917,7 @@ static void vec_dot_f16_f32(const int n, float * restrict s, const void * restri
|
|||
|
||||
// for some reason we need volatile here so that the compiler doesn't try anything funky
|
||||
volatile HVX_Vector rsum = Q6_V_vsplat_R(0);
|
||||
|
||||
float r_sum_scalar = 0.0f;
|
||||
uint32_t i = 0;
|
||||
|
||||
for (i = 0; i < nv0; i++) {
|
||||
|
|
@ -926,31 +926,42 @@ static void vec_dot_f16_f32(const int n, float * restrict s, const void * restri
|
|||
HVX_Vector x = vx[i];
|
||||
HVX_VectorPair xp = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(x), Q6_Vh_vsplat_R(0x3C00)); // mul by 1.0
|
||||
|
||||
HVX_Vector hi = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_hi_W(xp)), Q6_V_hi_W(yp));
|
||||
HVX_Vector lo = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_lo_W(xp)), Q6_V_lo_W(yp));
|
||||
//NOTE: need volatile here to prevent compiler optimization
|
||||
// Seem compiler cannot guarantee read-after-write??
|
||||
volatile HVX_Vector hi = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_hi_W(xp)), Q6_V_hi_W(yp));
|
||||
volatile HVX_Vector lo = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_lo_W(xp)), Q6_V_lo_W(yp));
|
||||
|
||||
HVX_Vector sum = Q6_Vqf32_vadd_Vqf32Vqf32(hi, lo);
|
||||
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, sum);
|
||||
}
|
||||
|
||||
if (nv1) {
|
||||
HVX_VectorPair yp = vy[i];
|
||||
// HVX_VectorPair yp = vy[i];
|
||||
|
||||
HVX_Vector x = vx[i];
|
||||
HVX_VectorPair xp = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(x), Q6_Vh_vsplat_R(0x3C00)); // mul by 1.0
|
||||
// HVX_Vector x = vx[i];
|
||||
// HVX_VectorPair xp = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(x), Q6_Vh_vsplat_R(0x3C00)); // mul by 1.0
|
||||
|
||||
if (nv1 >= 32) {
|
||||
HVX_Vector hi = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_hi_W(xp)), Q6_V_hi_W(yp));
|
||||
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, hi);
|
||||
nv1 -= 32;
|
||||
}
|
||||
// if (nv1 >= 32) {
|
||||
// volatile HVX_Vector hi = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_hi_W(xp)), Q6_V_hi_W(yp));
|
||||
// rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, hi);
|
||||
// nv1 -= 32;
|
||||
// }
|
||||
|
||||
// rsum = hvx_vec_qf32_reduce_sum(rsum);
|
||||
|
||||
// if (nv1) {
|
||||
// volatile HVX_Vector lo = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_lo_W(xp)), Q6_V_lo_W(yp));
|
||||
// HVX_Vector sum = hvx_vec_qf32_reduce_sum_n(lo, nv1);
|
||||
// rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, sum);
|
||||
// }
|
||||
|
||||
//process the remainder using scalar loop
|
||||
rsum = hvx_vec_qf32_reduce_sum(rsum);
|
||||
const __fp16 * restrict sx = (const __fp16 * restrict) x;
|
||||
const float * restrict sy = (const float * restrict) y;
|
||||
|
||||
if (nv1) {
|
||||
HVX_Vector lo = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_lo_W(xp)), Q6_V_lo_W(yp));
|
||||
HVX_Vector sum = hvx_vec_qf32_reduce_sum_n(lo, nv1);
|
||||
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, sum);
|
||||
for (uint32_t i = nv0 * 64; i < n; i++) {
|
||||
r_sum_scalar += (float) sx[i] * sy[i];
|
||||
}
|
||||
|
||||
// hvx_vec_dump_fp16("X", x);
|
||||
|
|
@ -961,7 +972,7 @@ static void vec_dot_f16_f32(const int n, float * restrict s, const void * restri
|
|||
rsum = hvx_vec_qf32_reduce_sum(rsum);
|
||||
}
|
||||
|
||||
*s = hvx_vec_get_fp32(Q6_Vsf_equals_Vqf32(rsum));
|
||||
*s = hvx_vec_get_fp32(Q6_Vsf_equals_Vqf32(rsum)) + r_sum_scalar;
|
||||
|
||||
# ifdef HTP_DEBUG
|
||||
{
|
||||
|
|
@ -1498,9 +1509,6 @@ static void matmul_f16_f32(struct htp_tensor * restrict src0,
|
|||
uint64_t t1, t2;
|
||||
t1 = HAP_perf_get_qtimer_count();
|
||||
|
||||
const size_t src0_row_size = sizeof(__fp16) * ne00;
|
||||
const size_t src1_row_size = sizeof(float) * ne10;
|
||||
|
||||
assert(ne12 % ne02 == 0);
|
||||
assert(ne13 % ne03 == 0);
|
||||
|
||||
|
|
@ -1510,8 +1518,6 @@ static void matmul_f16_f32(struct htp_tensor * restrict src0,
|
|||
// This is the size of the rest of the dimensions of the result
|
||||
const uint32_t nr1 = ne1 * ne2 * ne3;
|
||||
|
||||
uint32_t chunk_size = 64;
|
||||
|
||||
// distribute the thread work across the inner or outer loop based on which one is larger
|
||||
uint32_t nchunk0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows
|
||||
uint32_t nchunk1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows
|
||||
|
|
@ -1544,11 +1550,11 @@ static void matmul_f16_f32(struct htp_tensor * restrict src0,
|
|||
const uint32_t blck_0 = 64;
|
||||
const uint32_t blck_1 = 64;
|
||||
|
||||
float tmp[32];
|
||||
__attribute__((aligned(128))) float tmp[64];
|
||||
|
||||
for (uint32_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
|
||||
for (uint32_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
|
||||
for (uint32_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir1_end; ir1++) {
|
||||
for (uint32_t ir1 = iir1; ir1 < MIN(iir1 + blck_1, ir1_end); ir1++) {
|
||||
const uint32_t i13 = (ir1 / (ne12 * ne1));
|
||||
const uint32_t i12 = (ir1 - i13 * ne12 * ne1) / ne1;
|
||||
const uint32_t i11 = (ir1 - i13 * ne12 * ne1 - i12 * ne1);
|
||||
|
|
@ -1561,13 +1567,16 @@ static void matmul_f16_f32(struct htp_tensor * restrict src0,
|
|||
const uint32_t i2 = i12;
|
||||
const uint32_t i3 = i13;
|
||||
|
||||
const uint8_t * restrict src0_row = (const uint8_t *) src0->data + (0 + i02 * nb02 + i03 * nb03);
|
||||
const uint8_t * restrict src0_base = (const uint8_t *) src0->data + (0 + i02 * nb02 + i03 * nb03);
|
||||
const uint8_t * restrict src1_col =
|
||||
(const uint8_t *) src1->data + (i11 + i12 * ne11 + i13 * ne12 * ne11) * src1_row_size;
|
||||
(const uint8_t *) src1->data + (i11 * nb11 + i12 * nb12 + i13 * nb13);
|
||||
float * dst_col = (float *) ((uint8_t * restrict) dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3));
|
||||
|
||||
for (uint32_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0++) {
|
||||
vec_dot_f16_f32(ne00, &tmp[ir0 - iir0], src0_row + ir0 * src0_row_size, src1_col);
|
||||
const uint32_t ir0_block_end = MIN(iir0 + blck_0, ir0_end);
|
||||
for (uint32_t ir0 = iir0; ir0 < ir0_block_end; ir0++) {
|
||||
// Use nb01 stride for non-contiguous src0 support
|
||||
const uint8_t * restrict src0_row = src0_base + ir0 * nb01;
|
||||
vec_dot_f16_f32(ne00, &tmp[ir0 - iir0], src0_row, src1_col);
|
||||
}
|
||||
|
||||
hvx_copy_fp32_ua((uint8_t *) &dst_col[iir0], (uint8_t *) tmp, MIN(iir0 + blck_0, ir0_end) - iir0);
|
||||
|
|
|
|||
|
|
@ -154,7 +154,8 @@ class TensorNameMap:
|
|||
"model.layers.{bid}.operator_norm", # lfm2
|
||||
"model.transformer.blocks.{bid}.attn_norm", # llada
|
||||
"layers.{bid}.input_layernorm", # qwen3-embedding
|
||||
"model.layers.{bid}.attention_layernorm" # apertus
|
||||
"model.layers.{bid}.attention_layernorm", # apertus
|
||||
"model.layers.{bid}.pre_attention_layernorm", # kormo
|
||||
),
|
||||
|
||||
# Attention norm 2
|
||||
|
|
@ -342,6 +343,7 @@ class TensorNameMap:
|
|||
"model.transformer.blocks.{bid}.ff_norm", # llada
|
||||
"layers.{bid}.post_attention_layernorm", # qwen3-embedding
|
||||
"model.layers.{bid}.feedforward_layernorm", # apertus
|
||||
"model.layers.{bid}.pre_mlp_layernorm", # kormo
|
||||
),
|
||||
|
||||
# Pre feed-forward norm
|
||||
|
|
|
|||
|
|
@ -0,0 +1,65 @@
|
|||
#!/bin/sh
|
||||
#
|
||||
|
||||
# Basedir on device
|
||||
basedir=/data/local/tmp/llama.cpp
|
||||
|
||||
cli_opts=
|
||||
|
||||
branch=.
|
||||
[ "$B" != "" ] && branch=$B
|
||||
|
||||
adbserial=
|
||||
[ "$S" != "" ] && adbserial="-s $S"
|
||||
|
||||
model="gemma-3-4b-it-Q4_0.gguf"
|
||||
[ "$M" != "" ] && model="$M"
|
||||
|
||||
mmproj="mmproj-F16.gguf"
|
||||
[ "$MMPROJ" != "" ] && mmproj="$MMPROJ"
|
||||
|
||||
image=
|
||||
[ "$IMG" != "" ] && image="$IMG"
|
||||
|
||||
device="HTP0"
|
||||
[ "$D" != "" ] && device="$D"
|
||||
|
||||
verbose=
|
||||
[ "$V" != "" ] && verbose="GGML_HEXAGON_VERBOSE=$V"
|
||||
|
||||
experimental="GGML_HEXAGON_EXPERIMENTAL=1"
|
||||
[ "$E" != "" ] && experimental="GGML_HEXAGON_EXPERIMENTAL=$E"
|
||||
|
||||
sched=
|
||||
[ "$SCHED" != "" ] && sched="GGML_SCHED_DEBUG=2" cli_opts="$cli_opts -v"
|
||||
|
||||
profile=
|
||||
[ "$PROF" != "" ] && profile="GGML_HEXAGON_PROFILE=$PROF GGML_HEXAGON_OPSYNC=1"
|
||||
|
||||
opmask=
|
||||
[ "$OPMASK" != "" ] && opmask="GGML_HEXAGON_OPMASK=$OPMASK"
|
||||
|
||||
nhvx=
|
||||
[ "$NHVX" != "" ] && nhvx="GGML_HEXAGON_NHVX=$NHVX"
|
||||
|
||||
ndev=
|
||||
[ "$NDEV" != "" ] && ndev="GGML_HEXAGON_NDEV=$NDEV"
|
||||
|
||||
# MTMD backend device for vision model (defaults to CPU if not set)
|
||||
mtmd_backend=
|
||||
[ "$MTMD_DEVICE" != "" ] && mtmd_backend="MTMD_BACKEND_DEVICE=$MTMD_DEVICE"
|
||||
|
||||
set -x
|
||||
|
||||
adb $adbserial shell " \
|
||||
cd $basedir; ulimit -c unlimited; \
|
||||
LD_LIBRARY_PATH=$basedir/$branch/lib \
|
||||
ADSP_LIBRARY_PATH=$basedir/$branch/lib \
|
||||
$verbose $experimental $sched $opmask $profile $nhvx $ndev $mtmd_backend \
|
||||
./$branch/bin/llama-mtmd-cli --no-mmap -m $basedir/../gguf/$model \
|
||||
--mmproj $basedir/../gguf/$mmproj \
|
||||
--image $basedir/../gguf/$image \
|
||||
--poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \
|
||||
--ctx-size 8192 --batch-size 128 -ctk q8_0 -ctv q8_0 -fa on \
|
||||
-ngl 99 --device $device -v $cli_opts $@ \
|
||||
"
|
||||
|
|
@ -1561,9 +1561,11 @@ void llama_kv_cache::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama
|
|||
|
||||
const uint32_t strm = seq_id == -1 ? s : seq_to_stream[seq_id];
|
||||
|
||||
slot_info sinfo;
|
||||
|
||||
bool res = true;
|
||||
res = res && state_read_meta(io, strm, cell_count, seq_id);
|
||||
res = res && state_read_data(io, strm, cell_count);
|
||||
res = res && state_read_meta(io, strm, cell_count, sinfo, seq_id);
|
||||
res = res && state_read_data(io, strm, cell_count, sinfo);
|
||||
|
||||
if (!res) {
|
||||
if (seq_id == -1) {
|
||||
|
|
@ -1702,7 +1704,7 @@ void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t
|
|||
}
|
||||
}
|
||||
|
||||
bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id) {
|
||||
bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, slot_info & sinfo, llama_seq_id dest_seq_id) {
|
||||
auto & cells = v_cells[strm];
|
||||
auto & head = v_heads[strm];
|
||||
|
||||
|
|
@ -1739,7 +1741,7 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32
|
|||
ubatch.seq_id[i] = &dest_seq_id;
|
||||
}
|
||||
|
||||
const auto sinfo = find_slot(ubatch, true);
|
||||
sinfo = find_slot(ubatch, false);
|
||||
if (sinfo.empty()) {
|
||||
LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
|
||||
return false;
|
||||
|
|
@ -1749,20 +1751,16 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32
|
|||
// see: https://github.com/ggml-org/llama.cpp/pull/16825#issuecomment-3460868350
|
||||
apply_ubatch(sinfo, ubatch);
|
||||
|
||||
const auto head_cur = sinfo.head();
|
||||
LLAMA_LOG_DEBUG("%s: cell_count = %d, dest_seq_id = %d\n", __func__, cell_count, dest_seq_id);
|
||||
|
||||
// keep the head at the old position because we will read the KV data into it in state_read_data()
|
||||
head = head_cur;
|
||||
|
||||
LLAMA_LOG_DEBUG("%s: head_cur = %d, head = %d, cell_count = %d, dest_seq_id = %d\n", __func__, head_cur, head, cell_count, dest_seq_id);
|
||||
|
||||
// DEBUG CHECK: head_cur should be our first cell, head_cur + cell_count - 1 should be our last cell (verify seq_id and pos values)
|
||||
// Assume that this is one contiguous block of cells
|
||||
GGML_ASSERT(head_cur + cell_count <= cells.size());
|
||||
GGML_ASSERT(cells.pos_get(head_cur) == ubatch.pos[0]);
|
||||
GGML_ASSERT(cells.pos_get(head_cur + cell_count - 1) == ubatch.pos[cell_count - 1]);
|
||||
GGML_ASSERT(cells.seq_has(head_cur, dest_seq_id));
|
||||
GGML_ASSERT(cells.seq_has(head_cur + cell_count - 1, dest_seq_id));
|
||||
// DEBUG CHECK: verify that all cells were allocated and have correct seq_id and pos values
|
||||
GGML_ASSERT(sinfo.n_stream() == 1);
|
||||
GGML_ASSERT(sinfo.idxs[0].size() == cell_count);
|
||||
for (uint32_t i = 0; i < cell_count; ++i) {
|
||||
const uint32_t idx = sinfo.idxs[0][i];
|
||||
GGML_ASSERT(cells.pos_get(idx) == ubatch.pos[i]);
|
||||
GGML_ASSERT(cells.seq_has(idx, dest_seq_id));
|
||||
}
|
||||
} else {
|
||||
// whole KV cache restore
|
||||
|
||||
|
|
@ -1795,15 +1793,24 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32
|
|||
}
|
||||
}
|
||||
|
||||
// Create contiguous slot_info for whole cache restore
|
||||
sinfo.s0 = strm;
|
||||
sinfo.s1 = strm;
|
||||
sinfo.resize(1);
|
||||
sinfo.strm[0] = strm;
|
||||
sinfo.idxs[0].resize(cell_count);
|
||||
for (uint32_t i = 0; i < cell_count; ++i) {
|
||||
sinfo.idxs[0][i] = i;
|
||||
}
|
||||
|
||||
head = 0;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count) {
|
||||
bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, const slot_info & sinfo) {
|
||||
auto & cells = v_cells[strm];
|
||||
auto & head = v_heads[strm];
|
||||
|
||||
uint32_t v_trans;
|
||||
uint32_t n_layer;
|
||||
|
|
@ -1853,8 +1860,17 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
|
|||
}
|
||||
|
||||
if (cell_count) {
|
||||
// Read and set the keys for the whole cell range
|
||||
ggml_backend_tensor_set(k, io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row);
|
||||
if (sinfo.is_contiguous()) {
|
||||
// Fast path: contiguous cells, single memcpy
|
||||
ggml_backend_tensor_set(k, io.read(cell_count * k_size_row), sinfo.head() * k_size_row, cell_count * k_size_row);
|
||||
} else {
|
||||
// Slow path: scatter to non-contiguous positions
|
||||
const void * src = io.read(cell_count * k_size_row);
|
||||
for (uint32_t i = 0; i < cell_count; ++i) {
|
||||
const size_t dst_offset = sinfo.idxs[0][i] * k_size_row;
|
||||
ggml_backend_tensor_set(k, (const char*)src + i * k_size_row, dst_offset, k_size_row);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1885,8 +1901,17 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
|
|||
}
|
||||
|
||||
if (cell_count) {
|
||||
// Read and set the values for the whole cell range
|
||||
ggml_backend_tensor_set(v, io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row);
|
||||
if (sinfo.is_contiguous()) {
|
||||
// Fast path: contiguous cells, single memcpy
|
||||
ggml_backend_tensor_set(v, io.read(cell_count * v_size_row), sinfo.head() * v_size_row, cell_count * v_size_row);
|
||||
} else {
|
||||
// Slow path: scatter to non-contiguous positions
|
||||
const void * src = io.read(cell_count * v_size_row);
|
||||
for (uint32_t i = 0; i < cell_count; ++i) {
|
||||
const size_t dst_offset = sinfo.idxs[0][i] * v_size_row;
|
||||
ggml_backend_tensor_set(v, (const char*)src + i * v_size_row, dst_offset, v_size_row);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
|
|
@ -1925,10 +1950,22 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
|
|||
}
|
||||
|
||||
if (cell_count) {
|
||||
// For each row in the transposed matrix, read the values for the whole cell range
|
||||
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
|
||||
const size_t dst_offset = (head + j * cells.size()) * v_size_el;
|
||||
ggml_backend_tensor_set(v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
|
||||
if (sinfo.is_contiguous()) {
|
||||
// Fast path: contiguous cells
|
||||
const uint32_t h = sinfo.head();
|
||||
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
|
||||
const size_t dst_offset = (h + j * cells.size()) * v_size_el;
|
||||
ggml_backend_tensor_set(v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
|
||||
}
|
||||
} else {
|
||||
// Slow path: scatter to non-contiguous positions
|
||||
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
|
||||
const void * src = io.read(cell_count * v_size_el);
|
||||
for (uint32_t i = 0; i < cell_count; ++i) {
|
||||
const size_t dst_offset = (sinfo.idxs[0][i] + j * cells.size()) * v_size_el;
|
||||
ggml_backend_tensor_set(v, (const char*)src + i * v_size_el, dst_offset, v_size_el);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -72,6 +72,23 @@ public:
|
|||
void clear() {
|
||||
idxs.clear();
|
||||
}
|
||||
|
||||
// check if indices are contiguous starting from head()
|
||||
bool is_contiguous() const {
|
||||
if (idxs.empty() || idxs[0].empty()) {
|
||||
return true;
|
||||
}
|
||||
if (idxs.size() > 1) {
|
||||
return false;
|
||||
}
|
||||
const uint32_t h = idxs[0][0];
|
||||
for (size_t i = 0; i < idxs[0].size(); ++i) {
|
||||
if (idxs[0][i] != h + i) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
using slot_info_vec_t = std::vector<slot_info>;
|
||||
|
|
@ -264,8 +281,8 @@ private:
|
|||
void state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id = -1) const;
|
||||
void state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const;
|
||||
|
||||
bool state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
|
||||
bool state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count);
|
||||
bool state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, slot_info & sinfo, llama_seq_id dest_seq_id = -1);
|
||||
bool state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, const slot_info & sinfo);
|
||||
};
|
||||
|
||||
class llama_kv_cache_context : public llama_memory_context_i {
|
||||
|
|
|
|||
|
|
@ -3388,9 +3388,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
|||
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
|
||||
|
||||
// optional bias tensors
|
||||
layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0);
|
||||
layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0);
|
||||
layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0);
|
||||
layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
|
||||
layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED);
|
||||
layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED);
|
||||
|
||||
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
|
||||
|
||||
|
|
|
|||
|
|
@ -1895,7 +1895,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
|||
clean_spaces = false;
|
||||
} else if (
|
||||
tokenizer_pre == "qwen2" ||
|
||||
tokenizer_pre == "deepseek-r1-qwen") {
|
||||
tokenizer_pre == "deepseek-r1-qwen" ||
|
||||
tokenizer_pre == "kormo") {
|
||||
pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN2;
|
||||
clean_spaces = false;
|
||||
} else if (
|
||||
|
|
|
|||
|
|
@ -31,16 +31,25 @@ llm_build_qwen2::llm_build_qwen2(const llama_model & model, const llm_graph_para
|
|||
{
|
||||
// compute Q and K and RoPE them
|
||||
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
|
||||
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
|
||||
cb(Qcur, "Qcur", il);
|
||||
if (model.layers[il].bq) {
|
||||
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
|
||||
cb(Qcur, "Qcur", il);
|
||||
}
|
||||
|
||||
ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
|
||||
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
|
||||
cb(Kcur, "Kcur", il);
|
||||
if (model.layers[il].bk) {
|
||||
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
|
||||
cb(Kcur, "Kcur", il);
|
||||
}
|
||||
|
||||
ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
|
||||
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
|
||||
cb(Vcur, "Vcur", il);
|
||||
if (model.layers[il].bv) {
|
||||
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
|
||||
cb(Vcur, "Vcur", il);
|
||||
}
|
||||
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||
|
|
|
|||
|
|
@ -222,6 +222,14 @@ llama_build_and_test(test-backend-ops.cpp)
|
|||
llama_build_and_test(test-model-load-cancel.cpp LABEL "model")
|
||||
llama_build_and_test(test-autorelease.cpp LABEL "model")
|
||||
|
||||
# Test for state restore with fragmented KV cache
|
||||
# Requires a model, uses same args pattern as test-thread-safety
|
||||
if (NOT ${CMAKE_SYSTEM_PROCESSOR} MATCHES "s390x")
|
||||
llama_build_and_test(test-state-restore-fragmented.cpp LABEL "model" ARGS -hf ggml-org/models -hff tinyllamas/stories15M-q4_0.gguf)
|
||||
else()
|
||||
llama_build_and_test(test-state-restore-fragmented.cpp LABEL "model" ARGS -hf ggml-org/models -hff tinyllamas/stories15M-be.Q4_0.gguf)
|
||||
endif()
|
||||
|
||||
if (NOT GGML_BACKEND_DL)
|
||||
# these tests use the backends directly and cannot be built with dynamic loading
|
||||
llama_build_and_test(test-barrier.cpp)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,122 @@
|
|||
// Test for state restore with fragmented KV cache
|
||||
// This tests the fix for: https://github.com/ggml-org/llama.cpp/issues/17527
|
||||
// The issue was that state restore required contiguous KV cache slots,
|
||||
// which fails when the cache is fragmented.
|
||||
//
|
||||
// The fix changes find_slot(ubatch, true) to find_slot(ubatch, false)
|
||||
// in state_read_meta(), allowing non-contiguous slot allocation.
|
||||
|
||||
#include "arg.h"
|
||||
#include "common.h"
|
||||
#include "llama.h"
|
||||
|
||||
#include <vector>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
common_params params;
|
||||
|
||||
params.sampling.seed = 1234;
|
||||
params.kv_unified = true;
|
||||
params.n_parallel = 3;
|
||||
params.n_ctx = 256;
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
common_init();
|
||||
|
||||
// init
|
||||
common_init_result_ptr llama_init = common_init_from_params(params);
|
||||
|
||||
llama_model * model = llama_init->model();
|
||||
llama_context * ctx = llama_init->context();
|
||||
|
||||
if (model == nullptr || ctx == nullptr) {
|
||||
fprintf(stderr, "%s : failed to init\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
GGML_UNUSED(model);
|
||||
|
||||
// tokenize prompt
|
||||
std::vector<llama_token> tokens(70, 1);
|
||||
|
||||
// interleave the 3 sequences:
|
||||
// 01201230123...
|
||||
llama_batch batch = llama_batch_init(params.n_parallel*tokens.size(), 0, 1);
|
||||
for (size_t i = 0; i < tokens.size(); i++) {
|
||||
for (int s = 0; s < params.n_parallel; ++s) {
|
||||
common_batch_add(batch, tokens[i], i, {s}, false);
|
||||
}
|
||||
}
|
||||
batch.logits[batch.n_tokens - 1] = true;
|
||||
|
||||
if (llama_decode(ctx, batch)) {
|
||||
fprintf(stderr, "%s : failed to decode seq 0\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s : processed prompt on seq 0, 1, 2 (%zu tokens each)\n", __func__, tokens.size());
|
||||
|
||||
// Save state of seq 1
|
||||
std::vector<uint8_t> seq_state(llama_state_seq_get_size(ctx, 1));
|
||||
const size_t ncopy = llama_state_seq_get_data(ctx, seq_state.data(), seq_state.size(), 1);
|
||||
if (ncopy != seq_state.size()) {
|
||||
fprintf(stderr, "%s : failed to save seq 1 state\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
fprintf(stderr, "%s : saved seq 1 state, %zu bytes\n", __func__, ncopy);
|
||||
|
||||
// clear seq 1 to create a "hole" in the KV cache (fragmentation)
|
||||
// 0.20.20.20.2....
|
||||
llama_memory_t mem = llama_get_memory(ctx);
|
||||
llama_memory_seq_rm(mem, 1, -1, -1);
|
||||
fprintf(stderr, "%s : cleared seq 1 to create fragmentation\n", __func__);
|
||||
|
||||
// Now the cache has holes where seq 1 was
|
||||
// This creates fragmentation - there's no contiguous block large enough
|
||||
// for the seq 1 state if we only look for contiguous slots
|
||||
|
||||
// Restore seq 1 state into seq 1 (should work with non-contiguous allocation)
|
||||
// We use seq 1 since it's a valid sequence ID (0 to n_parallel-1)
|
||||
// Before the fix, this would fail with "failed to find available cells in kv cache"
|
||||
const size_t nset = llama_state_seq_set_data(ctx, seq_state.data(), seq_state.size(), 1);
|
||||
if (nset != seq_state.size()) {
|
||||
fprintf(stderr, "%s : FAILED to restore seq state into fragmented cache (got %zu, expected %zu)\n",
|
||||
__func__, nset, seq_state.size());
|
||||
fprintf(stderr, "%s : This is the bug - state restore fails with fragmented KV cache\n", __func__);
|
||||
llama_batch_free(batch);
|
||||
return 1;
|
||||
}
|
||||
fprintf(stderr, "%s : restored state into seq 1, %zu bytes\n", __func__, nset);
|
||||
|
||||
// Verify we can decode with the restored state
|
||||
// Generate one token to verify the restored state is usable
|
||||
auto sparams = llama_sampler_chain_default_params();
|
||||
llama_sampler * smpl = llama_sampler_chain_init(sparams);
|
||||
llama_sampler_chain_add(smpl, llama_sampler_init_dist(params.sampling.seed));
|
||||
|
||||
auto next_token = llama_sampler_sample(smpl, ctx, -1);
|
||||
auto next_token_str = common_token_to_piece(ctx, next_token);
|
||||
|
||||
common_batch_clear(batch);
|
||||
common_batch_add(batch, next_token, (int)tokens.size(), {1}, true);
|
||||
|
||||
if (llama_decode(ctx, batch)) {
|
||||
fprintf(stderr, "%s : failed to decode with restored state\n", __func__);
|
||||
llama_sampler_free(smpl);
|
||||
llama_batch_free(batch);
|
||||
return 1;
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s : successfully decoded with restored state, generated: '%s'\n", __func__, next_token_str.c_str());
|
||||
fprintf(stderr, "%s : SUCCESS - state restore works with fragmented KV cache\n", __func__);
|
||||
|
||||
llama_sampler_free(smpl);
|
||||
llama_batch_free(batch);
|
||||
|
||||
return 0;
|
||||
}
|
||||
Binary file not shown.
|
|
@ -1,6 +1,6 @@
|
|||
<script lang="ts">
|
||||
import { RemoveButton } from '$lib/components/app';
|
||||
import { getFileTypeLabel, getPreviewText, formatFileSize, isTextFile } from '$lib/utils';
|
||||
import { formatFileSize, getFileTypeLabel, getPreviewText, isTextFile } from '$lib/utils';
|
||||
import { AttachmentType } from '$lib/enums';
|
||||
|
||||
interface Props {
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@
|
|||
import Input from '$lib/components/ui/input/input.svelte';
|
||||
import { conversationsStore, conversations } from '$lib/stores/conversations.svelte';
|
||||
import { chatStore } from '$lib/stores/chat.svelte';
|
||||
import { getPreviewText } from '$lib/utils/text';
|
||||
import ChatSidebarActions from './ChatSidebarActions.svelte';
|
||||
|
||||
const sidebar = Sidebar.useSidebar();
|
||||
|
|
@ -20,6 +21,9 @@
|
|||
let showEditDialog = $state(false);
|
||||
let selectedConversation = $state<DatabaseConversation | null>(null);
|
||||
let editedName = $state('');
|
||||
let selectedConversationNamePreview = $derived.by(() =>
|
||||
selectedConversation ? getPreviewText(selectedConversation.name) : ''
|
||||
);
|
||||
|
||||
let filteredConversations = $derived.by(() => {
|
||||
if (searchQuery.trim().length > 0) {
|
||||
|
|
@ -162,7 +166,7 @@
|
|||
bind:open={showDeleteDialog}
|
||||
title="Delete Conversation"
|
||||
description={selectedConversation
|
||||
? `Are you sure you want to delete "${selectedConversation.name}"? This action cannot be undone and will permanently remove all messages in this conversation.`
|
||||
? `Are you sure you want to delete "${selectedConversationNamePreview}"? This action cannot be undone and will permanently remove all messages in this conversation.`
|
||||
: ''}
|
||||
confirmText="Delete"
|
||||
cancelText="Cancel"
|
||||
|
|
|
|||
|
|
@ -504,6 +504,14 @@
|
|||
background: hsl(var(--muted) / 0.1);
|
||||
}
|
||||
|
||||
/* User message markdown should keep table borders visible on light primary backgrounds */
|
||||
div.markdown-user-content :global(table),
|
||||
div.markdown-user-content :global(th),
|
||||
div.markdown-user-content :global(td),
|
||||
div.markdown-user-content :global(.table-wrapper) {
|
||||
border-color: currentColor;
|
||||
}
|
||||
|
||||
/* Horizontal rules */
|
||||
div :global(hr) {
|
||||
border: none;
|
||||
|
|
@ -642,6 +650,21 @@
|
|||
background: var(--muted);
|
||||
}
|
||||
|
||||
/* Disable hover effects when rendering user messages */
|
||||
.markdown-user-content :global(a),
|
||||
.markdown-user-content :global(a:hover) {
|
||||
color: var(--primary-foreground);
|
||||
}
|
||||
|
||||
.markdown-user-content :global(table:hover) {
|
||||
box-shadow: none;
|
||||
}
|
||||
|
||||
.markdown-user-content :global(th:hover),
|
||||
.markdown-user-content :global(td:hover) {
|
||||
background: inherit;
|
||||
}
|
||||
|
||||
/* Enhanced blockquotes */
|
||||
div :global(blockquote) {
|
||||
transition: all 0.2s ease;
|
||||
|
|
|
|||
|
|
@ -34,12 +34,3 @@ export function getFileTypeLabel(input: string | undefined): string {
|
|||
// Handle AttachmentType or other plain strings
|
||||
return input.toUpperCase();
|
||||
}
|
||||
|
||||
/**
|
||||
* Truncates text content for preview display
|
||||
* @param content - The text content to truncate
|
||||
* @returns Truncated content with ellipsis if needed
|
||||
*/
|
||||
export function getPreviewText(content: string): string {
|
||||
return content.length > 150 ? content.substring(0, 150) + '...' : content;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -43,7 +43,8 @@ export { createMessageCountMap, getMessageCount } from './conversation-utils';
|
|||
export { copyToClipboard, copyCodeToClipboard } from './copy';
|
||||
|
||||
// File preview utilities
|
||||
export { getFileTypeLabel, getPreviewText } from './file-preview';
|
||||
export { getFileTypeLabel } from './file-preview';
|
||||
export { getPreviewText } from './text';
|
||||
|
||||
// File type utilities
|
||||
export {
|
||||
|
|
|
|||
|
|
@ -0,0 +1,7 @@
|
|||
/**
|
||||
* Returns a shortened preview of the provided content capped at the given length.
|
||||
* Appends an ellipsis when the content exceeds the maximum.
|
||||
*/
|
||||
export function getPreviewText(content: string, max = 150): string {
|
||||
return content.length > max ? content.slice(0, max) + '...' : content;
|
||||
}
|
||||
Loading…
Reference in New Issue