From 780e37602392e4635c17de62cf2f1c0ab1e3e4ea Mon Sep 17 00:00:00 2001
From: Apoorv Reddy
Date: Mon, 10 Feb 2025 04:09:54 -0800
Subject: [PATCH] =?UTF-8?q?Add=20KVCache.DeepCopy()=20.=20Will=20be=20usef?=
 =?UTF-8?q?ul=20for=20implementing=20sampling=20functionality=20like=20bea?=
 =?UTF-8?q?m=20sampling,=20parallel=20sampling,=20CoT=20Decoding=20(=C3=A0?=
 =?UTF-8?q?=20la=20https://arxiv.org/abs/2402.10200)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

PiperOrigin-RevId: 725156316
---
 gemma/kv_cache.cc | 30 ++++++++++++++++++++++++++++++
 gemma/kv_cache.h  |  3 +++
 2 files changed, 33 insertions(+)

diff --git a/gemma/kv_cache.cc b/gemma/kv_cache.cc
index 4e94ad6..60ad5dd 100644
--- a/gemma/kv_cache.cc
+++ b/gemma/kv_cache.cc
@@ -76,4 +76,34 @@ KVCache KVCache::Create(const ModelConfig& weights_config,
   return kv_cache;
 }
 
+// Returns a deep copy of this cache: allocates fresh buffers via Create() and
+// copies kv_cache plus, when the config has Griffin recurrent layers, the
+// conv1d_cache and rglru_cache contents. NOTE(review): presumably both
+// arguments must match the ones this cache was created with - confirm.
+KVCache KVCache::Copy(const ModelConfig& weights_config,
+                      size_t prefill_tbatch_size) {
+  KVCache kv_cache_copy = Create(weights_config, prefill_tbatch_size);
+
+  const size_t size_cache_pos = weights_config.CachePosSize();
+  if (size_cache_pos != 0) {
+    std::copy(kv_cache.get(), kv_cache.get() + size_cache_pos * seq_len,
+              kv_cache_copy.kv_cache.get());
+  }
+
+  const size_t num_griffin_layers = weights_config.NumLayersOfType(
+      LayerAttentionType::kGriffinRecurrentBlock);
+  if (num_griffin_layers > 0) {
+    if (conv1d_cache_size != 0) {
+      std::copy(conv1d_cache.get(), conv1d_cache.get() + conv1d_cache_size,
+                kv_cache_copy.conv1d_cache.get());
+    }
+    if (rglru_cache_size != 0) {
+      // Pointer arithmetic counts elements, not bytes: no sizeof() scaling.
+      std::copy(rglru_cache.get(), rglru_cache.get() + rglru_cache_size,
+                kv_cache_copy.rglru_cache.get());
+    }
+  }
+  return kv_cache_copy;
+}
+
 }  // namespace gcpp
diff --git a/gemma/kv_cache.h b/gemma/kv_cache.h
index 69f9564..6052d0b 100644
--- a/gemma/kv_cache.h
+++ b/gemma/kv_cache.h
@@ -43,6 +43,9 @@ struct KVCache {
 
   static KVCache Create(const ModelConfig& weights_config,
                         size_t prefill_tbatch_size);
+
+  // Returns a deep copy of the KVCache.
+  KVCache Copy(const ModelConfig& weights_config, size_t prefill_tbatch_size);
 };
 
 }  // namespace gcpp