mirror of https://github.com/google/gemma.cpp.git
parent
6d43d6ee19
commit
16a7ba2d6e
|
|
@ -523,6 +523,7 @@ cc_library(
|
||||||
":configs",
|
":configs",
|
||||||
":gemma_args",
|
":gemma_args",
|
||||||
":mat",
|
":mat",
|
||||||
|
"//compression:types",
|
||||||
"@highway//:hwy",
|
"@highway//:hwy",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
@ -575,6 +576,7 @@ cc_library(
|
||||||
":configs",
|
":configs",
|
||||||
":mat",
|
":mat",
|
||||||
":threading_context",
|
":threading_context",
|
||||||
|
"//compression:types",
|
||||||
"//io",
|
"//io",
|
||||||
"@highway//:hwy",
|
"@highway//:hwy",
|
||||||
"@highway//:profiler",
|
"@highway//:profiler",
|
||||||
|
|
|
||||||
|
|
@ -35,6 +35,7 @@
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
||||||
typedef std::vector<float, hwy::AlignedAllocator<float>> AlignedFloatVector;
|
typedef std::vector<float, hwy::AlignedAllocator<float>> AlignedFloatVector;
|
||||||
|
typedef std::vector<BF16, hwy::AlignedAllocator<BF16>> AlignedBF16Vector;
|
||||||
|
|
||||||
// Returns the scale value to use for the query in the attention computation.
|
// Returns the scale value to use for the query in the attention computation.
|
||||||
// Also called by ops_test.
|
// Also called by ops_test.
|
||||||
|
|
|
||||||
|
|
@ -22,8 +22,10 @@
|
||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
|
|
||||||
#include <functional>
|
#include <functional>
|
||||||
|
#include <optional>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
|
#include "compression/types.h"
|
||||||
#include "gemma/configs.h"
|
#include "gemma/configs.h"
|
||||||
#include "io/io.h" // Path
|
#include "io/io.h" // Path
|
||||||
#include "util/args.h" // IWYU pragma: export
|
#include "util/args.h" // IWYU pragma: export
|
||||||
|
|
|
||||||
|
|
@ -16,8 +16,12 @@
|
||||||
#include "gemma/kv_cache.h"
|
#include "gemma/kv_cache.h"
|
||||||
|
|
||||||
#include <stddef.h>
|
#include <stddef.h>
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "compression/types.h"
|
||||||
#include "gemma/configs.h"
|
#include "gemma/configs.h"
|
||||||
#include "gemma/gemma_args.h"
|
#include "gemma/gemma_args.h"
|
||||||
#include "util/mat.h" // ZeroInit
|
#include "util/mat.h" // ZeroInit
|
||||||
|
|
|
||||||
|
|
@ -19,12 +19,14 @@
|
||||||
#include <stddef.h>
|
#include <stddef.h>
|
||||||
|
|
||||||
#include <optional>
|
#include <optional>
|
||||||
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "gemma/configs.h" // ModelConfig
|
#include "gemma/configs.h" // ModelConfig
|
||||||
#include "gemma/gemma_args.h" // InferenceArgs
|
#include "gemma/gemma_args.h" // InferenceArgs
|
||||||
#include "util/basics.h" // BF16
|
#include "util/basics.h" // BF16
|
||||||
#include "util/mat.h"
|
#include "util/mat.h"
|
||||||
|
#include "hwy/base.h"
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -35,8 +35,13 @@ TEST(KVCacheTest, KVCacheToPtrs) {
|
||||||
|
|
||||||
std::vector<KVCachePtr> ptrs = ToKVCachePtrs({caches.data(), caches.size()});
|
std::vector<KVCachePtr> ptrs = ToKVCachePtrs({caches.data(), caches.size()});
|
||||||
ASSERT_EQ(ptrs.size(), 2);
|
ASSERT_EQ(ptrs.size(), 2);
|
||||||
EXPECT_EQ(ptrs[0].kv_cache.Row(0), caches[0].kv_cache.Row(0));
|
if (caches[0].IsTiled()) {
|
||||||
EXPECT_EQ(ptrs[1].kv_cache.Row(0), caches[1].kv_cache.Row(0));
|
EXPECT_EQ(ptrs[0].cache, &caches[0]);
|
||||||
|
EXPECT_EQ(ptrs[1].cache, &caches[1]);
|
||||||
|
} else {
|
||||||
|
EXPECT_EQ(ptrs[0].kv_cache.Row(0), caches[0].kv_cache.Row(0));
|
||||||
|
EXPECT_EQ(ptrs[1].kv_cache.Row(0), caches[1].kv_cache.Row(0));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
|
||||||
32
util/mat.h
32
util/mat.h
|
|
@ -469,6 +469,38 @@ decltype(auto) CallUpcastedKV(const MatPtr* base, const Func& func,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Calls 'func' with a span of MatPtrT<T> for all elements in `base`.
|
||||||
|
// T is dynamic type, read from base. It is assumed that all elements in `base`
|
||||||
|
// have the same type.
|
||||||
|
template <class Func, typename... Args>
|
||||||
|
decltype(auto) CallUpcastedKVs(hwy::Span<const MatPtr> base, const Func& func,
|
||||||
|
Args&&... args) {
|
||||||
|
Type type = base[0].GetType();
|
||||||
|
for ([[maybe_unused]] auto&& mat : base) {
|
||||||
|
HWY_DASSERT(mat.GetType() == type);
|
||||||
|
}
|
||||||
|
auto convert_to_matptr_t = [&base]<typename T>() {
|
||||||
|
std::vector<MatPtrT<T>> matptrs;
|
||||||
|
matptrs.reserve(base.size());
|
||||||
|
for (auto&& mat : base) {
|
||||||
|
matptrs.emplace_back(mat);
|
||||||
|
}
|
||||||
|
return matptrs;
|
||||||
|
};
|
||||||
|
if (type == Type::kF32) {
|
||||||
|
auto matptrs = convert_to_matptr_t.template operator()<float>();
|
||||||
|
hwy::Span<const MatPtrT<float>> matptrs_span(matptrs.data(),
|
||||||
|
matptrs.size());
|
||||||
|
return func(matptrs_span, std::forward<Args>(args)...);
|
||||||
|
} else if (type == Type::kBF16) {
|
||||||
|
auto matptrs = convert_to_matptr_t.template operator()<BF16>();
|
||||||
|
hwy::Span<const MatPtrT<BF16>> matptrs_span(matptrs.data(), matptrs.size());
|
||||||
|
return func(matptrs_span, std::forward<Args>(args)...);
|
||||||
|
} else {
|
||||||
|
HWY_ABORT("Unhandled type %s.", TypeName(type));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void CopyMat(const MatPtr& from, MatPtr& to);
|
void CopyMat(const MatPtr& from, MatPtr& to);
|
||||||
void ZeroInit(MatPtr& mat);
|
void ZeroInit(MatPtr& mat);
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue