This commit is contained in:
copybara-service[bot] 2025-11-24 12:16:21 +05:30 committed by GitHub
commit 670281d31e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 16 additions and 6 deletions

View File

@ -17,7 +17,9 @@
#include <stdint.h> #include <stdint.h>
#include <algorithm> #include <algorithm>
#include <array>
#include <cmath> #include <cmath>
#include <cstdlib>
#include <limits> #include <limits>
#include "compression/types.h" // GEMMA_DISABLED_TARGETS #include "compression/types.h" // GEMMA_DISABLED_TARGETS

View File

@ -60,6 +60,7 @@ namespace gcpp {
size_t layer_idx, const MatPtr& query_norm_scale, \ size_t layer_idx, const MatPtr& query_norm_scale, \
AttentionActivationsPtrs& activations, QBatch& qbatch, \ AttentionActivationsPtrs& activations, QBatch& qbatch, \
ThreadingContext& ctx); \ ThreadingContext& ctx); \
\
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \ /* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
} // namespace NAMESPACE } // namespace NAMESPACE

View File

@ -51,7 +51,6 @@ KVCache KVCache::Copy() {
KVCache copy(kv_cache.Extents(), allocator_); KVCache copy(kv_cache.Extents(), allocator_);
CopyMat(kv_cache, copy.kv_cache); CopyMat(kv_cache, copy.kv_cache);
return copy; return copy;
} }
@ -59,7 +58,9 @@ std::vector<KVCachePtr> ToKVCachePtrs(const hwy::Span<KVCache>& kv_caches) {
std::vector<KVCachePtr> ptrs; std::vector<KVCachePtr> ptrs;
ptrs.reserve(kv_caches.size()); ptrs.reserve(kv_caches.size());
for (size_t i = 0; i < kv_caches.size(); ++i) { for (size_t i = 0; i < kv_caches.size(); ++i) {
ptrs.push_back(KVCachePtr{.kv_cache = kv_caches[i].kv_cache}); ptrs.push_back(KVCachePtr{
.kv_cache = kv_caches[i].kv_cache,
});
} }
return ptrs; return ptrs;
} }

View File

@ -17,9 +17,11 @@
#define THIRD_PARTY_GEMMA_CPP_GEMMA_KV_CACHE_H_ #define THIRD_PARTY_GEMMA_CPP_GEMMA_KV_CACHE_H_
#include <stddef.h> #include <stddef.h>
#include <optional>
#include <vector> #include <vector>
#include "gemma/configs.h" // ModelConfig #include "gemma/configs.h" // ModelConfig
#include "gemma/gemma_args.h" // InferenceArgs #include "gemma/gemma_args.h" // InferenceArgs
#include "util/basics.h" // BF16 #include "util/basics.h" // BF16
#include "util/mat.h" #include "util/mat.h"
@ -31,12 +33,13 @@ using KV_t = float;
struct KVCache { struct KVCache {
KVCache(const ModelConfig& config, const InferenceArgs& inference_args, KVCache(const ModelConfig& config, const InferenceArgs& inference_args,
const Allocator& allocator); const Allocator& allocator);
// Returns a deep copy of the KVCache. Use explicit function instead of // Returns a deep copy of the KVCache. Use explicit function instead of
// copy ctor to make the cost explicit. // copy ctor to make the cost explicit.
KVCache Copy(); KVCache Copy();
size_t SeqLen() const { return kv_cache.Rows(); } size_t SeqLen() const {
return kv_cache.Rows();
}
MatStorageT<KV_t> kv_cache; // [seq_len, layers * kv_heads * qkv_dim * 2] MatStorageT<KV_t> kv_cache; // [seq_len, layers * kv_heads * qkv_dim * 2]
@ -49,7 +52,9 @@ struct KVCache {
// A non-owning view of a KVCache. // A non-owning view of a KVCache.
struct KVCachePtr { struct KVCachePtr {
size_t SeqLen() const { return kv_cache.Rows(); } size_t SeqLen() const {
return kv_cache.Rows();
}
MatPtrT<KV_t> kv_cache; MatPtrT<KV_t> kv_cache;
}; };

View File

@ -25,6 +25,7 @@
#include <cstdint> #include <cstdint>
#include <random> #include <random>
#include <type_traits> // std::enable_if_t #include <type_traits> // std::enable_if_t
#include <utility>
#include <vector> #include <vector>
#include "ops/matmul.h" #include "ops/matmul.h"