mirror of https://github.com/google/gemma.cpp.git
Add KVCache.DeepCopy() . Will be useful for implementing sampling functionality like beam sampling, parallel sampling, CoT Decoding (à la https://arxiv.org/abs/2402.10200)
PiperOrigin-RevId: 725156316
This commit is contained in:
parent
9b3e7ea8a2
commit
780e376023
|
|
@ -76,4 +76,30 @@ KVCache KVCache::Create(const ModelConfig& weights_config,
|
|||
return kv_cache;
|
||||
}
|
||||
|
||||
KVCache KVCache::Copy(const ModelConfig& weights_config,
|
||||
size_t prefill_tbatch_size) {
|
||||
KVCache kv_cache_copy = Create(weights_config, prefill_tbatch_size);
|
||||
|
||||
const size_t size_cache_pos = weights_config.CachePosSize();
|
||||
if (size_cache_pos != 0) {
|
||||
std::copy(kv_cache.get(), kv_cache.get() + size_cache_pos * seq_len,
|
||||
kv_cache_copy.kv_cache.get());
|
||||
}
|
||||
|
||||
const size_t num_griffin_layers = weights_config.NumLayersOfType(
|
||||
LayerAttentionType::kGriffinRecurrentBlock);
|
||||
if (num_griffin_layers > 0) {
|
||||
if (conv1d_cache_size != 0) {
|
||||
std::copy(conv1d_cache.get(), conv1d_cache.get() + conv1d_cache_size,
|
||||
kv_cache_copy.conv1d_cache.get());
|
||||
}
|
||||
if (rglru_cache_size != 0) {
|
||||
std::copy(rglru_cache.get(),
|
||||
rglru_cache.get() + rglru_cache_size * sizeof(rglru_cache[0]),
|
||||
kv_cache_copy.rglru_cache.get());
|
||||
}
|
||||
}
|
||||
return kv_cache_copy;
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
|
|||
|
|
@ -43,6 +43,9 @@ struct KVCache {
|
|||
|
||||
static KVCache Create(const ModelConfig& weights_config,
|
||||
size_t prefill_tbatch_size);
|
||||
|
||||
// Returns a deep copy of the KVCache.
|
||||
KVCache Copy(const ModelConfig& weights_config, size_t prefill_tbatch_size);
|
||||
};
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
|
|||
Loading…
Reference in New Issue