Internal changes

PiperOrigin-RevId: 854171429
This commit is contained in:
Krzysztof Rymski 2026-01-09 06:35:05 -08:00 committed by Copybara-Service
parent 6d43d6ee19
commit 16a7ba2d6e
7 changed files with 50 additions and 2 deletions

View File

@ -523,6 +523,7 @@ cc_library(
":configs",
":gemma_args",
":mat",
"//compression:types",
"@highway//:hwy",
],
)
@ -575,6 +576,7 @@ cc_library(
":configs",
":mat",
":threading_context",
"//compression:types",
"//io",
"@highway//:hwy",
"@highway//:profiler",

View File

@ -35,6 +35,7 @@
namespace gcpp {
// Vector of float whose storage is obtained through hwy::AlignedAllocator.
using AlignedFloatVector = std::vector<float, hwy::AlignedAllocator<float>>;
// Vector of BF16 whose storage is obtained through hwy::AlignedAllocator.
using AlignedBF16Vector = std::vector<BF16, hwy::AlignedAllocator<BF16>>;
// Returns the scale value to use for the query in the attention computation.
// Also called by ops_test.

View File

@ -22,8 +22,10 @@
#include <stdio.h>
#include <functional>
#include <optional>
#include <string>
#include "compression/types.h"
#include "gemma/configs.h"
#include "io/io.h" // Path
#include "util/args.h" // IWYU pragma: export

View File

@ -16,8 +16,12 @@
#include "gemma/kv_cache.h"
#include <stddef.h>
#include <algorithm>
#include <utility>
#include <vector>
#include "compression/types.h"
#include "gemma/configs.h"
#include "gemma/gemma_args.h"
#include "util/mat.h" // ZeroInit

View File

@ -19,12 +19,14 @@
#include <stddef.h>
#include <optional>
#include <utility>
#include <vector>
#include "gemma/configs.h" // ModelConfig
#include "gemma/gemma_args.h" // InferenceArgs
#include "util/basics.h" // BF16
#include "util/mat.h"
#include "hwy/base.h"
namespace gcpp {

View File

@ -35,8 +35,13 @@ TEST(KVCacheTest, KVCacheToPtrs) {
std::vector<KVCachePtr> ptrs = ToKVCachePtrs({caches.data(), caches.size()});
ASSERT_EQ(ptrs.size(), 2);
EXPECT_EQ(ptrs[0].kv_cache.Row(0), caches[0].kv_cache.Row(0));
EXPECT_EQ(ptrs[1].kv_cache.Row(0), caches[1].kv_cache.Row(0));
if (caches[0].IsTiled()) {
EXPECT_EQ(ptrs[0].cache, &caches[0]);
EXPECT_EQ(ptrs[1].cache, &caches[1]);
} else {
EXPECT_EQ(ptrs[0].kv_cache.Row(0), caches[0].kv_cache.Row(0));
EXPECT_EQ(ptrs[1].kv_cache.Row(0), caches[1].kv_cache.Row(0));
}
}
} // namespace

View File

@ -469,6 +469,38 @@ decltype(auto) CallUpcastedKV(const MatPtr* base, const Func& func,
}
}
// Calls `func` with a span of MatPtrT<T> built from every element of `base`,
// followed by the forwarded `args`, and returns whatever `func` returns.
// T is the dynamic element type, read from `base[0]`. All elements of `base`
// are assumed to have the same type; this is verified by HWY_DASSERT only.
// `base` must be non-empty because the type is taken from its first element.
// Only Type::kF32 and Type::kBF16 are handled; any other type aborts.
// The span handed to `func` refers to a vector local to this call, so neither
// the span nor references into it may be kept after `func` returns.
template <class Func, typename... Args>
decltype(auto) CallUpcastedKVs(hwy::Span<const MatPtr> base, const Func& func,
                               Args&&... args) {
  // Reading `base[0]` from an empty span would be out of bounds, so reject it
  // up front with an O(1) check rather than relying on debug-only assertions.
  HWY_ASSERT(base.size() != 0);
  const Type type = base[0].GetType();
  for ([[maybe_unused]] auto&& mat : base) {
    HWY_DASSERT(mat.GetType() == type);
  }

  // Converts each untyped MatPtr in `base` into a MatPtrT<T>.
  auto convert_to_matptr_t = [&base]<typename T>() {
    std::vector<MatPtrT<T>> matptrs;
    matptrs.reserve(base.size());
    for (auto&& mat : base) {
      matptrs.emplace_back(mat);
    }
    return matptrs;
  };

  // `matptrs` must outlive the call to `func`, hence it is a named local in
  // each branch rather than a temporary. Only one branch runs, so forwarding
  // `args` in both is safe.
  if (type == Type::kF32) {
    auto matptrs = convert_to_matptr_t.template operator()<float>();
    hwy::Span<const MatPtrT<float>> matptrs_span(matptrs.data(),
                                                 matptrs.size());
    return func(matptrs_span, std::forward<Args>(args)...);
  } else if (type == Type::kBF16) {
    auto matptrs = convert_to_matptr_t.template operator()<BF16>();
    hwy::Span<const MatPtrT<BF16>> matptrs_span(matptrs.data(), matptrs.size());
    return func(matptrs_span, std::forward<Args>(args)...);
  } else {
    HWY_ABORT("Unhandled type %s.", TypeName(type));
  }
}
// Declaration only; the definition is not visible in this section.
// NOTE(review): presumably copies the contents of `from` into `to` - confirm
// against the definition.
void CopyMat(const MatPtr& from, MatPtr& to);
// Declaration only; the definition is not visible in this section.
// NOTE(review): presumably sets every element of `mat` to zero - confirm
// against the definition.
void ZeroInit(MatPtr& mat);