kv-cache : fix M-RoPE checkpoints (#20132)

This commit is contained in:
Georgi Gerganov 2026-03-06 08:46:51 +02:00 committed by GitHub
parent f7db3f3789
commit 17a4258946
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 15 additions and 3 deletions

View File

@@ -394,11 +394,13 @@ llama_ubatch llama_batch_allocr::ubatch_reserve(uint32_t n_seq_tokens, uint32_t
clear();
split_reset();
const int64_t n_pos_all = (int64_t) n_tokens*n_pos_per_embd;
auto udata = std::make_shared<llama_ubatch::data_t>();
udata->token .resize(n_tokens);
udata->embd .clear();
-udata->pos       .resize(n_tokens);
+udata->pos       .resize(n_pos_all);
udata->n_seq_id .resize(n_tokens);
udata->seq_id .resize(n_tokens);
udata->seq_id_unq.resize(0);

View File

@@ -1760,8 +1760,10 @@ void llama_kv_cache::state_write_meta(llama_io_write_i & io, const cell_ranges_t
io.write(&pos, sizeof(pos));
io.write(&n_seq_id, sizeof(n_seq_id));
// TODO: we also need to save llama_kv_cell_ext when apply_ubatch() support loading it
// see: https://github.com/ggml-org/llama.cpp/pull/16825#issuecomment-3460868350
if (hparams.n_pos_per_embd() > 1) {
const llama_kv_cell_ext ext = cells.ext_get(i);
io.write(&ext, sizeof(ext));
}
for (const auto & seq_id : seq_ids) {
io.write(&seq_id, sizeof(seq_id));
@@ -1895,6 +1897,14 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32
return false;
}
if (hparams.n_pos_per_embd() > 1) {
llama_kv_cell_ext ext;
io.read_to(&ext, sizeof(ext));
ubatch.pos[i + ubatch.n_tokens] = ext.y;
ubatch.pos[i + ubatch.n_tokens*2] = ext.x;
}
// read the sequence id, but directly discard it - we will use dest_seq_id instead
{
llama_seq_id seq_id;