Internal changes

PiperOrigin-RevId: 837001762
This commit is contained in:
Krzysztof Rymski 2025-11-26 01:05:06 -08:00 committed by Copybara-Service
parent 8696f6dd17
commit c153d5255b
4 changed files with 36 additions and 6 deletions

View File

@ -518,6 +518,18 @@ cc_library(
], ],
) )
# Test target for the KV cache, built from gemma/kv_cache_test.cc.
# Links against the local :configs, :gemma_args, :kv_cache and
# :threading_context targets listed in deps.
cc_test(
name = "kv_cache_test",
srcs = ["gemma/kv_cache_test.cc"],
deps = [
":configs",
":gemma_args",
":kv_cache",
":threading_context",
# NOTE(review): this label looks like an internal-repository path; other
# open-source test targets presumably depend on "@googletest//:gtest_main".
# Confirm this resolves in the external Bazel workspace.
"//testing/base/public:gunit_main",
],
)
cc_library( cc_library(
name = "gemma_args", name = "gemma_args",
hdrs = ["gemma/gemma_args.h"], hdrs = ["gemma/gemma_args.h"],

View File

@ -51,7 +51,6 @@ KVCache KVCache::Copy() {
KVCache copy(kv_cache.Extents(), allocator_); KVCache copy(kv_cache.Extents(), allocator_);
CopyMat(kv_cache, copy.kv_cache); CopyMat(kv_cache, copy.kv_cache);
return copy; return copy;
} }
@ -59,7 +58,9 @@ std::vector<KVCachePtr> ToKVCachePtrs(const hwy::Span<KVCache>& kv_caches) {
std::vector<KVCachePtr> ptrs; std::vector<KVCachePtr> ptrs;
ptrs.reserve(kv_caches.size()); ptrs.reserve(kv_caches.size());
for (size_t i = 0; i < kv_caches.size(); ++i) { for (size_t i = 0; i < kv_caches.size(); ++i) {
ptrs.push_back(KVCachePtr{.kv_cache = kv_caches[i].kv_cache}); ptrs.push_back(KVCachePtr{
.kv_cache = kv_caches[i].kv_cache,
});
} }
return ptrs; return ptrs;
} }

View File

@ -17,6 +17,8 @@
#define THIRD_PARTY_GEMMA_CPP_GEMMA_KV_CACHE_H_ #define THIRD_PARTY_GEMMA_CPP_GEMMA_KV_CACHE_H_
#include <stddef.h> #include <stddef.h>
#include <optional>
#include <vector> #include <vector>
#include "gemma/configs.h" // ModelConfig #include "gemma/configs.h" // ModelConfig
@ -31,19 +33,23 @@ using KV_t = float;
// A non-owning view of a KVCache. // A non-owning view of a KVCache.
struct KVCachePtr { struct KVCachePtr {
bool IsEmpty() const { return kv_cache.Rows() == 0; } bool IsEmpty() const { return kv_cache.Rows() == 0; }
size_t SeqLen() const { return kv_cache.Rows(); } size_t SeqLen() const {
return kv_cache.Rows();
}
MatPtrT<KV_t> kv_cache; MatPtrT<KV_t> kv_cache;
}; };
struct KVCache { struct KVCache {
KVCache(const ModelConfig& config, const InferenceArgs& inference_args, KVCache(const ModelConfig& config, const InferenceArgs& inference_args,
const Allocator& allocator); const Allocator& allocator);
// Returns a deep copy of the KVCache. Use explicit function instead of // Returns a deep copy of the KVCache. Use explicit function instead of
// copy ctor to make the cost explicit. // copy ctor to make the cost explicit.
KVCache Copy(); KVCache Copy();
size_t SeqLen() const { return kv_cache.Rows(); } size_t SeqLen() const {
return kv_cache.Rows();
}
MatStorageT<KV_t> kv_cache; // [seq_len, layers * kv_heads * qkv_dim * 2] MatStorageT<KV_t> kv_cache; // [seq_len, layers * kv_heads * qkv_dim * 2]

11
gemma/kv_cache_test.cc Normal file
View File

@ -0,0 +1,11 @@
#include "gemma/kv_cache.h"
#include "gtest/gtest.h"
#include "gemma/configs.h"
#include "gemma/gemma_args.h"
#include "util/threading_context.h"
// Scaffolding for the kv_cache tests. Both namespaces below are empty in this
// revision, so this translation unit defines no TEST cases yet.
namespace gcpp {
// Unnamed namespace: anything added here has internal linkage.
namespace {
} // namespace
} // namespace gcpp