diff --git a/BUILD.bazel b/BUILD.bazel index 2a33707..fd85acb 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -518,6 +518,18 @@ cc_library( ], ) +cc_test( + name = "kv_cache_test", + srcs = ["gemma/kv_cache_test.cc"], + deps = [ + ":configs", + ":gemma_args", + ":kv_cache", + ":threading_context", + "//testing/base/public:gunit_main", + ], +) + cc_library( name = "gemma_args", hdrs = ["gemma/gemma_args.h"], diff --git a/gemma/kv_cache.cc b/gemma/kv_cache.cc index ded8df5..8948644 100644 --- a/gemma/kv_cache.cc +++ b/gemma/kv_cache.cc @@ -51,7 +51,6 @@ KVCache KVCache::Copy() { KVCache copy(kv_cache.Extents(), allocator_); CopyMat(kv_cache, copy.kv_cache); - return copy; } @@ -59,7 +58,9 @@ std::vector ToKVCachePtrs(const hwy::Span& kv_caches) { std::vector ptrs; ptrs.reserve(kv_caches.size()); for (size_t i = 0; i < kv_caches.size(); ++i) { - ptrs.push_back(KVCachePtr{.kv_cache = kv_caches[i].kv_cache}); + ptrs.push_back(KVCachePtr{ + .kv_cache = kv_caches[i].kv_cache, + }); } return ptrs; } diff --git a/gemma/kv_cache.h b/gemma/kv_cache.h index 3697116..1d78f7d 100644 --- a/gemma/kv_cache.h +++ b/gemma/kv_cache.h @@ -17,9 +17,11 @@ #define THIRD_PARTY_GEMMA_CPP_GEMMA_KV_CACHE_H_ #include + +#include #include -#include "gemma/configs.h" // ModelConfig +#include "gemma/configs.h" // ModelConfig #include "gemma/gemma_args.h" // InferenceArgs #include "util/basics.h" // BF16 #include "util/mat.h" @@ -31,19 +33,23 @@ using KV_t = float; // A non-owning view of a KVCache. struct KVCachePtr { bool IsEmpty() const { return kv_cache.Rows() == 0; } - size_t SeqLen() const { return kv_cache.Rows(); } + size_t SeqLen() const { + return kv_cache.Rows(); + } + MatPtrT kv_cache; }; struct KVCache { KVCache(const ModelConfig& config, const InferenceArgs& inference_args, const Allocator& allocator); - // Returns a deep copy of the KVCache. Use explicit function instead of // copy ctor to make the cost explicit. KVCache Copy(); - size_t SeqLen() const { return kv_cache.Rows(); } + size_t SeqLen() const { + return kv_cache.Rows(); + } MatStorageT kv_cache; // [seq_len, layers * kv_heads * qkv_dim * 2] diff --git a/gemma/kv_cache_test.cc b/gemma/kv_cache_test.cc new file mode 100644 index 0000000..2849cfe --- /dev/null +++ b/gemma/kv_cache_test.cc @@ -0,0 +1,11 @@ +#include "gemma/kv_cache.h" + +#include "gtest/gtest.h" +#include "gemma/configs.h" +#include "gemma/gemma_args.h" +#include "util/threading_context.h" +namespace gcpp { +namespace { + +} // namespace +} // namespace gcpp