mirror of https://github.com/google/gemma.cpp.git
Add KVCache.DeepCopy() . Will be useful for implementing sampling functionality like beam sampling, parallel sampling, CoT Decoding (à la https://arxiv.org/abs/2402.10200)
PiperOrigin-RevId: 725156316
This commit is contained in:
parent
9b3e7ea8a2
commit
780e376023
|
|
@ -76,4 +76,30 @@ KVCache KVCache::Create(const ModelConfig& weights_config,
|
||||||
return kv_cache;
|
return kv_cache;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
KVCache KVCache::Copy(const ModelConfig& weights_config,
|
||||||
|
size_t prefill_tbatch_size) {
|
||||||
|
KVCache kv_cache_copy = Create(weights_config, prefill_tbatch_size);
|
||||||
|
|
||||||
|
const size_t size_cache_pos = weights_config.CachePosSize();
|
||||||
|
if (size_cache_pos != 0) {
|
||||||
|
std::copy(kv_cache.get(), kv_cache.get() + size_cache_pos * seq_len,
|
||||||
|
kv_cache_copy.kv_cache.get());
|
||||||
|
}
|
||||||
|
|
||||||
|
const size_t num_griffin_layers = weights_config.NumLayersOfType(
|
||||||
|
LayerAttentionType::kGriffinRecurrentBlock);
|
||||||
|
if (num_griffin_layers > 0) {
|
||||||
|
if (conv1d_cache_size != 0) {
|
||||||
|
std::copy(conv1d_cache.get(), conv1d_cache.get() + conv1d_cache_size,
|
||||||
|
kv_cache_copy.conv1d_cache.get());
|
||||||
|
}
|
||||||
|
if (rglru_cache_size != 0) {
|
||||||
|
std::copy(rglru_cache.get(),
|
||||||
|
rglru_cache.get() + rglru_cache_size * sizeof(rglru_cache[0]),
|
||||||
|
kv_cache_copy.rglru_cache.get());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return kv_cache_copy;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
|
|
|
||||||
|
|
@ -43,6 +43,9 @@ struct KVCache {
|
||||||
|
|
||||||
static KVCache Create(const ModelConfig& weights_config,
|
static KVCache Create(const ModelConfig& weights_config,
|
||||||
size_t prefill_tbatch_size);
|
size_t prefill_tbatch_size);
|
||||||
|
|
||||||
|
// Returns a deep copy of the KVCache.
|
||||||
|
KVCache Copy(const ModelConfig& weights_config, size_t prefill_tbatch_size);
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue