refactor: Use llama_memory_state_ptr for child states in hybrid memory state
Branch: HybridRecurrentCache Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
This commit is contained in:
parent
7ba463b38c
commit
4ec4e6a801
|
|
@ -244,9 +244,9 @@ const llama_ubatch & llama_kv_cache_hybrid_recurrent_state::get_ubatch() const {
|
||||||
}
|
}
|
||||||
|
|
||||||
const llama_kv_cache_unified_state * llama_kv_cache_hybrid_recurrent_state::get_state_attn() const {
|
const llama_kv_cache_unified_state * llama_kv_cache_hybrid_recurrent_state::get_state_attn() const {
|
||||||
return state_attn.get();
|
return static_cast<const llama_kv_cache_unified_state *>(state_attn.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
const llama_kv_cache_recurrent_state * llama_kv_cache_hybrid_recurrent_state::get_state_recurrent() const {
|
const llama_kv_cache_recurrent_state * llama_kv_cache_hybrid_recurrent_state::get_state_recurrent() const {
|
||||||
return state_recurrent.get();
|
return static_cast<const llama_kv_cache_recurrent_state *>(state_recurrent.get());
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -145,6 +145,6 @@ private:
|
||||||
std::vector<uint32_t> heads_attn;
|
std::vector<uint32_t> heads_attn;
|
||||||
std::vector<llama_ubatch> ubatches;
|
std::vector<llama_ubatch> ubatches;
|
||||||
|
|
||||||
const llama_kv_cache_unified_state_ptr state_attn;
|
const llama_memory_state_ptr state_attn;
|
||||||
const llama_kv_cache_recurrent_state_ptr state_recurrent;
|
const llama_memory_state_ptr state_recurrent;
|
||||||
};
|
};
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue