parent
08011c2ca1
commit
2645a7d9a9
|
|
@ -3657,6 +3657,40 @@ ggml_tensor * llama_context_kv_self::build_inp_kq_mask_cross(
|
|||
return inp_kq_mask_cross;
|
||||
}
|
||||
|
||||
// state save/load
|
||||
|
||||
size_t llama_context_kv_self::state_get_data(llama_io_write_i & io) {
|
||||
llama_context::state_get_data(io);
|
||||
|
||||
kv_self.state_write(io);
|
||||
|
||||
return io.n_bytes();
|
||||
}
|
||||
|
||||
size_t llama_context_kv_self::state_set_data(llama_io_read_i & io) {
|
||||
llama_context::state_set_data(io);
|
||||
|
||||
kv_self.state_read(io);
|
||||
|
||||
return io.n_bytes();
|
||||
}
|
||||
|
||||
size_t llama_context_kv_self::state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) {
|
||||
llama_context::state_seq_get_data(io, seq_id);
|
||||
|
||||
kv_self.state_write(io, seq_id);
|
||||
|
||||
return io.n_bytes();
|
||||
}
|
||||
|
||||
size_t llama_context_kv_self::state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id) {
|
||||
llama_context::state_seq_set_data(io, seq_id);
|
||||
|
||||
kv_self.state_read(io, seq_id);
|
||||
|
||||
return io.n_bytes();
|
||||
}
|
||||
|
||||
//
|
||||
// llama_context_recurrent
|
||||
//
|
||||
|
|
@ -4527,7 +4561,7 @@ ggml_tensor * llama_context_recurrent::build_rwkv6_time_mix(
|
|||
|
||||
// state save/load
|
||||
|
||||
size_t llama_context_kv_self::state_get_data(llama_io_write_i & io) {
|
||||
size_t llama_context_recurrent::state_get_data(llama_io_write_i & io) {
|
||||
llama_context::state_get_data(io);
|
||||
|
||||
kv_self.state_write(io);
|
||||
|
|
@ -4535,7 +4569,7 @@ size_t llama_context_kv_self::state_get_data(llama_io_write_i & io) {
|
|||
return io.n_bytes();
|
||||
}
|
||||
|
||||
size_t llama_context_kv_self::state_set_data(llama_io_read_i & io) {
|
||||
size_t llama_context_recurrent::state_set_data(llama_io_read_i & io) {
|
||||
llama_context::state_set_data(io);
|
||||
|
||||
kv_self.state_read(io);
|
||||
|
|
@ -4543,7 +4577,7 @@ size_t llama_context_kv_self::state_set_data(llama_io_read_i & io) {
|
|||
return io.n_bytes();
|
||||
}
|
||||
|
||||
size_t llama_context_kv_self::state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) {
|
||||
size_t llama_context_recurrent::state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) {
|
||||
llama_context::state_seq_get_data(io, seq_id);
|
||||
|
||||
kv_self.state_write(io, seq_id);
|
||||
|
|
@ -4551,7 +4585,7 @@ size_t llama_context_kv_self::state_seq_get_data(llama_io_write_i & io, llama_se
|
|||
return io.n_bytes();
|
||||
}
|
||||
|
||||
size_t llama_context_kv_self::state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id) {
|
||||
size_t llama_context_recurrent::state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id) {
|
||||
llama_context::state_seq_set_data(io, seq_id);
|
||||
|
||||
kv_self.state_read(io, seq_id);
|
||||
|
|
|
|||
|
|
@ -525,6 +525,12 @@ public:
|
|||
bool worst_case) override;
|
||||
|
||||
protected:
|
||||
virtual size_t state_get_data(llama_io_write_i & io) override;
|
||||
virtual size_t state_set_data(llama_io_read_i & io) override;
|
||||
|
||||
virtual size_t state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) override;
|
||||
virtual size_t state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id) override;
|
||||
|
||||
virtual void input_set(const llama_ubatch & ubatch) override;
|
||||
|
||||
// TODO: change name to something more meaningful -- does "KV cache" make sense for recurrent models?
|
||||
|
|
|
|||
Loading…
Reference in New Issue