diff --git a/gemma/kv_cache.cc b/gemma/kv_cache.cc index 4e94ad6..60ad5dd 100644 --- a/gemma/kv_cache.cc +++ b/gemma/kv_cache.cc @@ -76,4 +76,30 @@ KVCache KVCache::Create(const ModelConfig& weights_config, return kv_cache; } +KVCache KVCache::Copy(const ModelConfig& weights_config, + size_t prefill_tbatch_size) { + KVCache kv_cache_copy = Create(weights_config, prefill_tbatch_size); + + const size_t size_cache_pos = weights_config.CachePosSize(); + if (size_cache_pos != 0) { + std::copy(kv_cache.get(), kv_cache.get() + size_cache_pos * seq_len, + kv_cache_copy.kv_cache.get()); + } + + const size_t num_griffin_layers = weights_config.NumLayersOfType( + LayerAttentionType::kGriffinRecurrentBlock); + if (num_griffin_layers > 0) { + if (conv1d_cache_size != 0) { + std::copy(conv1d_cache.get(), conv1d_cache.get() + conv1d_cache_size, + kv_cache_copy.conv1d_cache.get()); + } + if (rglru_cache_size != 0) { + std::copy(rglru_cache.get(), + rglru_cache.get() + rglru_cache_size * sizeof(rglru_cache[0]), + kv_cache_copy.rglru_cache.get()); + } + } + return kv_cache_copy; +} + } // namespace gcpp diff --git a/gemma/kv_cache.h b/gemma/kv_cache.h index 69f9564..6052d0b 100644 --- a/gemma/kv_cache.h +++ b/gemma/kv_cache.h @@ -43,6 +43,9 @@ struct KVCache { static KVCache Create(const ModelConfig& weights_config, size_t prefill_tbatch_size); + + // Returns a deep copy of the KVCache. + KVCache Copy(const ModelConfig& weights_config, size_t prefill_tbatch_size); }; } // namespace gcpp