Internal changes

PiperOrigin-RevId: 837001762
This commit is contained in:
Krzysztof Rymski 2025-11-26 01:05:06 -08:00 committed by Copybara-Service
parent 8696f6dd17
commit c153d5255b
4 changed files with 36 additions and 6 deletions

View File

@ -518,6 +518,18 @@ cc_library(
],
)
# Test target for the KV cache; its only source is gemma/kv_cache_test.cc.
cc_test(
name = "kv_cache_test",
srcs = ["gemma/kv_cache_test.cc"],
deps = [
":configs",
":gemma_args",
":kv_cache",
":threading_context",
# NOTE(review): presumably supplies the gtest main() entry point — confirm.
"//testing/base/public:gunit_main",
],
)
cc_library(
name = "gemma_args",
hdrs = ["gemma/gemma_args.h"],

View File

@ -51,7 +51,6 @@ KVCache KVCache::Copy() {
KVCache copy(kv_cache.Extents(), allocator_);
CopyMat(kv_cache, copy.kv_cache);
return copy;
}
@ -59,7 +58,9 @@ std::vector<KVCachePtr> ToKVCachePtrs(const hwy::Span<KVCache>& kv_caches) {
std::vector<KVCachePtr> ptrs;
ptrs.reserve(kv_caches.size());
for (size_t i = 0; i < kv_caches.size(); ++i) {
ptrs.push_back(KVCachePtr{.kv_cache = kv_caches[i].kv_cache});
ptrs.push_back(KVCachePtr{
.kv_cache = kv_caches[i].kv_cache,
});
}
return ptrs;
}

View File

@ -17,9 +17,11 @@
#define THIRD_PARTY_GEMMA_CPP_GEMMA_KV_CACHE_H_
#include <stddef.h>
#include <optional>
#include <vector>
#include "gemma/configs.h" // ModelConfig
#include "gemma/configs.h" // ModelConfig
#include "gemma/gemma_args.h" // InferenceArgs
#include "util/basics.h" // BF16
#include "util/mat.h"
@ -31,19 +33,23 @@ using KV_t = float;
// A non-owning view of a KVCache.
struct KVCachePtr {
// True when the viewed matrix has zero rows.
bool IsEmpty() const { return kv_cache.Rows() == 0; }
// Sequence length, reported as the row count of the viewed matrix.
// NOTE(review): SeqLen() is spelled out twice below with the same body
// (one-line and multi-line form); only one definition can remain.
size_t SeqLen() const { return kv_cache.Rows(); }
size_t SeqLen() const {
return kv_cache.Rows();
}
// Matrix whose row count backs IsEmpty() and SeqLen().
MatPtrT<KV_t> kv_cache;
};
struct KVCache {
KVCache(const ModelConfig& config, const InferenceArgs& inference_args,
const Allocator& allocator);
// Returns a deep copy of the KVCache. Use explicit function instead of
// copy ctor to make the cost explicit.
KVCache Copy();
size_t SeqLen() const { return kv_cache.Rows(); }
size_t SeqLen() const {
return kv_cache.Rows();
}
MatStorageT<KV_t> kv_cache; // [seq_len, layers * kv_heads * qkv_dim * 2]

11
gemma/kv_cache_test.cc Normal file
View File

@ -0,0 +1,11 @@
#include "gemma/kv_cache.h"
#include "gtest/gtest.h"
#include "gemma/configs.h"
#include "gemma/gemma_args.h"
#include "util/threading_context.h"
namespace gcpp {
// File-local scope for test cases; it currently contains no tests.
namespace {
} // namespace
} // namespace gcpp