Internal changes

PiperOrigin-RevId: 835160876
This commit is contained in:
Krzysztof Rymski 2025-11-21 04:02:18 -08:00 committed by Copybara-Service
parent 5a500872b8
commit be30473dc6
5 changed files with 16 additions and 6 deletions

View File

@ -17,7 +17,9 @@
#include <stdint.h>
#include <algorithm>
#include <array>
#include <cmath>
#include <cstdlib>
#include <limits>
#include "compression/types.h" // GEMMA_DISABLED_TARGETS

View File

@ -60,6 +60,7 @@ namespace gcpp {
size_t layer_idx, const MatPtr& query_norm_scale, \
AttentionActivationsPtrs& activations, QBatch& qbatch, \
ThreadingContext& ctx); \
\
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
} // namespace NAMESPACE

View File

@ -51,7 +51,6 @@ KVCache KVCache::Copy() {
KVCache copy(kv_cache.Extents(), allocator_);
CopyMat(kv_cache, copy.kv_cache);
return copy;
}
@ -59,7 +58,9 @@ std::vector<KVCachePtr> ToKVCachePtrs(const hwy::Span<KVCache>& kv_caches) {
std::vector<KVCachePtr> ptrs;
ptrs.reserve(kv_caches.size());
for (size_t i = 0; i < kv_caches.size(); ++i) {
ptrs.push_back(KVCachePtr{.kv_cache = kv_caches[i].kv_cache});
ptrs.push_back(KVCachePtr{
.kv_cache = kv_caches[i].kv_cache,
});
}
return ptrs;
}

View File

@ -17,9 +17,11 @@
#define THIRD_PARTY_GEMMA_CPP_GEMMA_KV_CACHE_H_
#include <stddef.h>
#include <optional>
#include <vector>
#include "gemma/configs.h" // ModelConfig
#include "gemma/configs.h" // ModelConfig
#include "gemma/gemma_args.h" // InferenceArgs
#include "util/basics.h" // BF16
#include "util/mat.h"
@ -31,12 +33,13 @@ using KV_t = float;
struct KVCache {
KVCache(const ModelConfig& config, const InferenceArgs& inference_args,
const Allocator& allocator);
// Returns a deep copy of the KVCache. Use explicit function instead of
// copy ctor to make the cost explicit.
KVCache Copy();
size_t SeqLen() const { return kv_cache.Rows(); }
size_t SeqLen() const {
return kv_cache.Rows();
}
MatStorageT<KV_t> kv_cache; // [seq_len, layers * kv_heads * qkv_dim * 2]
@ -49,7 +52,9 @@ struct KVCache {
// A non-owning view of a KVCache.
// Holds a MatPtrT<KV_t> where KVCache itself holds a MatStorageT<KV_t>.
struct KVCachePtr {
// Returns kv_cache.Rows().
// NOTE(review): presumably the sequence length, going by the accessor's name
// and the [seq_len, ...] layout noted on KVCache::kv_cache — confirm.
// NOTE(review): the accessor appears twice in this listing, once in one-line
// form and once in multi-line form; both bodies are identical.
size_t SeqLen() const { return kv_cache.Rows(); }
size_t SeqLen() const {
return kv_cache.Rows();
}
// The viewed matrix; SeqLen() reports its row count.
MatPtrT<KV_t> kv_cache;
};

View File

@ -25,6 +25,7 @@
#include <cstdint>
#include <random>
#include <type_traits> // std::enable_if_t
#include <utility>
#include <vector>
#include "ops/matmul.h"