Internal changes

PiperOrigin-RevId: 835162918
This commit is contained in:
Krzysztof Rymski 2025-11-21 04:08:32 -08:00 committed by Copybara-Service
parent 5a500872b8
commit d6504d12a2
6 changed files with 20 additions and 8 deletions

View File

@@ -17,7 +17,9 @@
#include <stdint.h> #include <stdint.h>
#include <algorithm> #include <algorithm>
#include <array>
#include <cmath> #include <cmath>
#include <cstdlib>
#include <limits> #include <limits>
#include "compression/types.h" // GEMMA_DISABLED_TARGETS #include "compression/types.h" // GEMMA_DISABLED_TARGETS

View File

@@ -60,6 +60,7 @@ namespace gcpp {
size_t layer_idx, const MatPtr& query_norm_scale, \ size_t layer_idx, const MatPtr& query_norm_scale, \
AttentionActivationsPtrs& activations, QBatch& qbatch, \ AttentionActivationsPtrs& activations, QBatch& qbatch, \
ThreadingContext& ctx); \ ThreadingContext& ctx); \
\
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \ /* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
} // namespace NAMESPACE } // namespace NAMESPACE

View File

@@ -513,8 +513,10 @@ static size_t PrefillTBatchOrQBatch(const ModelConfig& config,
HWY_ASSERT(qbatch.KV(qi).SeqLen() == seq_len); HWY_ASSERT(qbatch.KV(qi).SeqLen() == seq_len);
} }
if (max_prompt_size > seq_len) { if (max_prompt_size > seq_len) {
HWY_ABORT("max_prompt_size = %zu, increase --seq_len to at least that.", HWY_ABORT(
max_prompt_size); "max_prompt_size = %zu, seq_len = %zu, increase --seq_len to at least "
"that.",
max_prompt_size, seq_len);
} }
HWY_ASSERT(activations.attention.div_seq_len.GetDivisor() == seq_len); HWY_ASSERT(activations.attention.div_seq_len.GetDivisor() == seq_len);

View File

@@ -51,7 +51,6 @@ KVCache KVCache::Copy() {
KVCache copy(kv_cache.Extents(), allocator_); KVCache copy(kv_cache.Extents(), allocator_);
CopyMat(kv_cache, copy.kv_cache); CopyMat(kv_cache, copy.kv_cache);
return copy; return copy;
} }
@@ -59,7 +58,9 @@ std::vector<KVCachePtr> ToKVCachePtrs(const hwy::Span<KVCache>& kv_caches) {
std::vector<KVCachePtr> ptrs; std::vector<KVCachePtr> ptrs;
ptrs.reserve(kv_caches.size()); ptrs.reserve(kv_caches.size());
for (size_t i = 0; i < kv_caches.size(); ++i) { for (size_t i = 0; i < kv_caches.size(); ++i) {
ptrs.push_back(KVCachePtr{.kv_cache = kv_caches[i].kv_cache}); ptrs.push_back(KVCachePtr{
.kv_cache = kv_caches[i].kv_cache,
});
} }
return ptrs; return ptrs;
} }

View File

@@ -17,6 +17,8 @@
#define THIRD_PARTY_GEMMA_CPP_GEMMA_KV_CACHE_H_ #define THIRD_PARTY_GEMMA_CPP_GEMMA_KV_CACHE_H_
#include <stddef.h> #include <stddef.h>
#include <optional>
#include <vector> #include <vector>
#include "gemma/configs.h" // ModelConfig #include "gemma/configs.h" // ModelConfig
@@ -31,12 +33,13 @@ using KV_t = float;
struct KVCache { struct KVCache {
KVCache(const ModelConfig& config, const InferenceArgs& inference_args, KVCache(const ModelConfig& config, const InferenceArgs& inference_args,
const Allocator& allocator); const Allocator& allocator);
// Returns a deep copy of the KVCache. Use explicit function instead of // Returns a deep copy of the KVCache. Use explicit function instead of
// copy ctor to make the cost explicit. // copy ctor to make the cost explicit.
KVCache Copy(); KVCache Copy();
size_t SeqLen() const { return kv_cache.Rows(); } size_t SeqLen() const {
return kv_cache.Rows();
}
MatStorageT<KV_t> kv_cache; // [seq_len, layers * kv_heads * qkv_dim * 2] MatStorageT<KV_t> kv_cache; // [seq_len, layers * kv_heads * qkv_dim * 2]
@@ -49,7 +52,9 @@ struct KVCache {
// A non-owning view of a KVCache. // A non-owning view of a KVCache.
struct KVCachePtr { struct KVCachePtr {
size_t SeqLen() const { return kv_cache.Rows(); } size_t SeqLen() const {
return kv_cache.Rows();
}
MatPtrT<KV_t> kv_cache; MatPtrT<KV_t> kv_cache;
}; };

View File

@@ -25,6 +25,7 @@
#include <cstdint> #include <cstdint>
#include <random> #include <random>
#include <type_traits> // std::enable_if_t #include <type_traits> // std::enable_if_t
#include <utility>
#include <vector> #include <vector>
#include "ops/matmul.h" #include "ops/matmul.h"