Add KVCache.DeepCopy() . Will be useful for implementing sampling functionality like beam sampling, parallel sampling, CoT Decoding (à la https://arxiv.org/abs/2402.10200)

PiperOrigin-RevId: 725156316
This commit is contained in:
Apoorv Reddy 2025-02-10 04:09:54 -08:00 committed by Copybara-Service
parent 9b3e7ea8a2
commit 780e376023
2 changed files with 29 additions and 0 deletions

View File

@ -76,4 +76,30 @@ KVCache KVCache::Create(const ModelConfig& weights_config,
return kv_cache;
}
KVCache KVCache::Copy(const ModelConfig& weights_config,
size_t prefill_tbatch_size) {
KVCache kv_cache_copy = Create(weights_config, prefill_tbatch_size);
const size_t size_cache_pos = weights_config.CachePosSize();
if (size_cache_pos != 0) {
std::copy(kv_cache.get(), kv_cache.get() + size_cache_pos * seq_len,
kv_cache_copy.kv_cache.get());
}
const size_t num_griffin_layers = weights_config.NumLayersOfType(
LayerAttentionType::kGriffinRecurrentBlock);
if (num_griffin_layers > 0) {
if (conv1d_cache_size != 0) {
std::copy(conv1d_cache.get(), conv1d_cache.get() + conv1d_cache_size,
kv_cache_copy.conv1d_cache.get());
}
if (rglru_cache_size != 0) {
std::copy(rglru_cache.get(),
rglru_cache.get() + rglru_cache_size * sizeof(rglru_cache[0]),
kv_cache_copy.rglru_cache.get());
}
}
return kv_cache_copy;
}
} // namespace gcpp

View File

@ -43,6 +43,9 @@ struct KVCache {
static KVCache Create(const ModelConfig& weights_config,
size_t prefill_tbatch_size);
// Returns a deep copy of the KVCache.
KVCache Copy(const ModelConfig& weights_config, size_t prefill_tbatch_size);
};
} // namespace gcpp