mirror of https://github.com/google/gemma.cpp.git
parent
6d43d6ee19
commit
16a7ba2d6e
|
|
@ -523,6 +523,7 @@ cc_library(
|
|||
":configs",
|
||||
":gemma_args",
|
||||
":mat",
|
||||
"//compression:types",
|
||||
"@highway//:hwy",
|
||||
],
|
||||
)
|
||||
|
|
@ -575,6 +576,7 @@ cc_library(
|
|||
":configs",
|
||||
":mat",
|
||||
":threading_context",
|
||||
"//compression:types",
|
||||
"//io",
|
||||
"@highway//:hwy",
|
||||
"@highway//:profiler",
|
||||
|
|
|
|||
|
|
@ -35,6 +35,7 @@
|
|||
namespace gcpp {
|
||||
|
||||
typedef std::vector<float, hwy::AlignedAllocator<float>> AlignedFloatVector;
|
||||
typedef std::vector<BF16, hwy::AlignedAllocator<BF16>> AlignedBF16Vector;
|
||||
|
||||
// Returns the scale value to use for the query in the attention computation.
|
||||
// Also called by ops_test.
|
||||
|
|
|
|||
|
|
@ -22,8 +22,10 @@
|
|||
#include <stdio.h>
|
||||
|
||||
#include <functional>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
|
||||
#include "compression/types.h"
|
||||
#include "gemma/configs.h"
|
||||
#include "io/io.h" // Path
|
||||
#include "util/args.h" // IWYU pragma: export
|
||||
|
|
|
|||
|
|
@ -16,8 +16,12 @@
|
|||
#include "gemma/kv_cache.h"
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "compression/types.h"
|
||||
#include "gemma/configs.h"
|
||||
#include "gemma/gemma_args.h"
|
||||
#include "util/mat.h" // ZeroInit
|
||||
|
|
|
|||
|
|
@ -19,12 +19,14 @@
|
|||
#include <stddef.h>
|
||||
|
||||
#include <optional>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "gemma/configs.h" // ModelConfig
|
||||
#include "gemma/gemma_args.h" // InferenceArgs
|
||||
#include "util/basics.h" // BF16
|
||||
#include "util/mat.h"
|
||||
#include "hwy/base.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
|
|
|
|||
|
|
@ -35,8 +35,13 @@ TEST(KVCacheTest, KVCacheToPtrs) {
|
|||
|
||||
std::vector<KVCachePtr> ptrs = ToKVCachePtrs({caches.data(), caches.size()});
|
||||
ASSERT_EQ(ptrs.size(), 2);
|
||||
EXPECT_EQ(ptrs[0].kv_cache.Row(0), caches[0].kv_cache.Row(0));
|
||||
EXPECT_EQ(ptrs[1].kv_cache.Row(0), caches[1].kv_cache.Row(0));
|
||||
if (caches[0].IsTiled()) {
|
||||
EXPECT_EQ(ptrs[0].cache, &caches[0]);
|
||||
EXPECT_EQ(ptrs[1].cache, &caches[1]);
|
||||
} else {
|
||||
EXPECT_EQ(ptrs[0].kv_cache.Row(0), caches[0].kv_cache.Row(0));
|
||||
EXPECT_EQ(ptrs[1].kv_cache.Row(0), caches[1].kv_cache.Row(0));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
|
|
|||
32
util/mat.h
32
util/mat.h
|
|
@ -469,6 +469,38 @@ decltype(auto) CallUpcastedKV(const MatPtr* base, const Func& func,
|
|||
}
|
||||
}
|
||||
|
||||
// Calls 'func' with a span of MatPtrT<T> for all elements in `base`.
|
||||
// T is dynamic type, read from base. It is assumed that all elements in `base`
|
||||
// have the same type.
|
||||
template <class Func, typename... Args>
|
||||
decltype(auto) CallUpcastedKVs(hwy::Span<const MatPtr> base, const Func& func,
|
||||
Args&&... args) {
|
||||
Type type = base[0].GetType();
|
||||
for ([[maybe_unused]] auto&& mat : base) {
|
||||
HWY_DASSERT(mat.GetType() == type);
|
||||
}
|
||||
auto convert_to_matptr_t = [&base]<typename T>() {
|
||||
std::vector<MatPtrT<T>> matptrs;
|
||||
matptrs.reserve(base.size());
|
||||
for (auto&& mat : base) {
|
||||
matptrs.emplace_back(mat);
|
||||
}
|
||||
return matptrs;
|
||||
};
|
||||
if (type == Type::kF32) {
|
||||
auto matptrs = convert_to_matptr_t.template operator()<float>();
|
||||
hwy::Span<const MatPtrT<float>> matptrs_span(matptrs.data(),
|
||||
matptrs.size());
|
||||
return func(matptrs_span, std::forward<Args>(args)...);
|
||||
} else if (type == Type::kBF16) {
|
||||
auto matptrs = convert_to_matptr_t.template operator()<BF16>();
|
||||
hwy::Span<const MatPtrT<BF16>> matptrs_span(matptrs.data(), matptrs.size());
|
||||
return func(matptrs_span, std::forward<Args>(args)...);
|
||||
} else {
|
||||
HWY_ABORT("Unhandled type %s.", TypeName(type));
|
||||
}
|
||||
}
|
||||
|
||||
void CopyMat(const MatPtr& from, MatPtr& to);
|
||||
void ZeroInit(MatPtr& mat);
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue