mirror of https://github.com/google/gemma.cpp.git
Merge branch 'dev' into refactor-tidy
This commit is contained in:
commit
4aa8d0584e
|
|
@ -1,4 +1,5 @@
|
||||||
FormatStyle: file
|
FormatStyle: file
|
||||||
|
WarningsAsErrors: "*"
|
||||||
Checks: "-*,\
|
Checks: "-*,\
|
||||||
abseil-*,\
|
abseil-*,\
|
||||||
-abseil-string-find-startswith,\
|
-abseil-string-find-startswith,\
|
||||||
|
|
@ -204,3 +205,6 @@ Checks: "-*,\
|
||||||
-readability-uppercase-literal-suffix,\
|
-readability-uppercase-literal-suffix,\
|
||||||
-readability-use-anyofallof
|
-readability-use-anyofallof
|
||||||
"
|
"
|
||||||
|
CheckOptions:
|
||||||
|
- { key: readability-identifier-naming.ConstexprVariableCase, value: CamelCase }
|
||||||
|
- { key: readability-identifier-naming.ConstexprVariablePrefix, value: k }
|
||||||
|
|
|
||||||
|
|
@ -69,6 +69,7 @@ cc_library(
|
||||||
deps = [
|
deps = [
|
||||||
":args",
|
":args",
|
||||||
":transformer_ops",
|
":transformer_ops",
|
||||||
|
# "//base",
|
||||||
"//compression:compress",
|
"//compression:compress",
|
||||||
"@hwy//:hwy",
|
"@hwy//:hwy",
|
||||||
"@hwy//:matvec",
|
"@hwy//:matvec",
|
||||||
|
|
@ -88,6 +89,7 @@ cc_binary(
|
||||||
":app",
|
":app",
|
||||||
":args",
|
":args",
|
||||||
":gemma_lib",
|
":gemma_lib",
|
||||||
|
# "//base",
|
||||||
"//compression:compress",
|
"//compression:compress",
|
||||||
"@hwy//:hwy",
|
"@hwy//:hwy",
|
||||||
"@hwy//:nanobenchmark",
|
"@hwy//:nanobenchmark",
|
||||||
|
|
|
||||||
|
|
@ -341,7 +341,7 @@ BlobError BlobReader::Open(const char* filename) {
|
||||||
#endif
|
#endif
|
||||||
if (fd_ < 0) return __LINE__;
|
if (fd_ < 0) return __LINE__;
|
||||||
|
|
||||||
#if HWY_OS_LINUX
|
#if HWY_OS_LINUX && (!defined(__ANDROID_API__) || __ANDROID_API__ >= 21)
|
||||||
// Doubles the readahead window, which seems slightly faster when cached.
|
// Doubles the readahead window, which seems slightly faster when cached.
|
||||||
(void)posix_fadvise(fd_, 0, 0, POSIX_FADV_SEQUENTIAL);
|
(void)posix_fadvise(fd_, 0, 0, POSIX_FADV_SEQUENTIAL);
|
||||||
#endif
|
#endif
|
||||||
|
|
|
||||||
10
gemma.cc
10
gemma.cc
|
|
@ -19,18 +19,16 @@
|
||||||
// which we pass the filename via macro 'argument'.
|
// which we pass the filename via macro 'argument'.
|
||||||
#undef HWY_TARGET_INCLUDE
|
#undef HWY_TARGET_INCLUDE
|
||||||
#define HWY_TARGET_INCLUDE "gemma.cc" // NOLINT
|
#define HWY_TARGET_INCLUDE "gemma.cc" // NOLINT
|
||||||
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
||||||
// Must come after foreach_target.h to avoid redefinition errors.
|
// Must come after foreach_target.h to avoid redefinition errors.
|
||||||
// copybara:import_next_line:gemma_cpp
|
// copybara:import_next_line:gemma_cpp
|
||||||
#include "compression/compress-inl.h"
|
#include "compression/compress-inl.h"
|
||||||
// copybara:import_next_line:gemma_cpp
|
// copybara:import_next_line:gemma_cpp
|
||||||
#include "ops.h"
|
#include "ops.h"
|
||||||
// copybara:import_next_line:gemma_cpp
|
|
||||||
#include "hwy/contrib/matvec/matvec-inl.h"
|
#include "hwy/contrib/matvec/matvec-inl.h"
|
||||||
#include "hwy/highway.h"
|
#include "hwy/highway.h"
|
||||||
#include "hwy/profiler.h"
|
#include "hwy/profiler.h"
|
||||||
#include "hwy/timer.h"
|
#include "hwy/timer.h"
|
||||||
#include "util/args.h" // Path
|
|
||||||
|
|
||||||
// Non-SIMD includes and types. Note that HWY_ONCE is only true on the last
|
// Non-SIMD includes and types. Note that HWY_ONCE is only true on the last
|
||||||
// compile pass, whereas we want this defined in the first.
|
// compile pass, whereas we want this defined in the first.
|
||||||
|
|
@ -766,10 +764,10 @@ GemmaImpl<Config>::GemmaImpl(
|
||||||
std::unique_ptr<sentencepiece::SentencePieceProcessor>& tokenizer,
|
std::unique_ptr<sentencepiece::SentencePieceProcessor>& tokenizer,
|
||||||
hwy::AlignedFreeUniquePtr<uint8_t[]>& compressed_weights,
|
hwy::AlignedFreeUniquePtr<uint8_t[]>& compressed_weights,
|
||||||
hwy::ThreadPool& pool)
|
hwy::ThreadPool& pool)
|
||||||
: compressed_weights(std::move(compressed_weights)),
|
: tokenizer(std::move(tokenizer)),
|
||||||
|
compressed_weights(std::move(compressed_weights)),
|
||||||
prefill(hwy::MakeUniqueAligned<Activations<Config, kPrefillBatchSize>>()),
|
prefill(hwy::MakeUniqueAligned<Activations<Config, kPrefillBatchSize>>()),
|
||||||
state(hwy::MakeUniqueAligned<Activations<Config, 1>>()),
|
state(hwy::MakeUniqueAligned<Activations<Config, 1>>()) {}
|
||||||
tokenizer(std::move(tokenizer)) {}
|
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
void GemmaImpl<ConfigGemma2B>::Generate(
|
void GemmaImpl<ConfigGemma2B>::Generate(
|
||||||
|
|
|
||||||
48
ops.h
48
ops.h
|
|
@ -340,11 +340,21 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED float Dot(const float* HWY_RESTRICT a,
|
||||||
// = Dot(a, a, size), but that is not allowed due to HWY_RESTRICT.
|
// = Dot(a, a, size), but that is not allowed due to HWY_RESTRICT.
|
||||||
// Returns the sum of squares of a[0, size), accumulated with SIMD vectors.
//
// Preconditions (checked via HWY_DASSERT): `size` is at least, and a multiple
// of, twice the number of float lanes in a vector. Elements are loaded with
// unaligned loads, so `a` needs no particular alignment.
static HWY_NOINLINE HWY_MAYBE_UNUSED float SquaredL2(
    const float* HWY_RESTRICT a, size_t size) {
  const hn::ScalableTag<float> d;
  const size_t N = hn::Lanes(d);
  HWY_DASSERT(size >= 2 * N);
  HWY_DASSERT(size % (2 * N) == 0);

  // Two independent accumulators, one per unrolled vector of the iteration.
  auto sum0 = hn::Zero(d);
  auto sum1 = hn::Zero(d);

  // `i + 2 * N <= size` rather than `i <= size - 2 * N`: the latter wraps
  // around for size < 2 * N (size_t is unsigned) and would then read far out
  // of bounds when the asserts above are compiled out.
  for (size_t i = 0; i + 2 * N <= size; i += 2 * N) {
    const auto a0 = hn::LoadU(d, a + i);
    sum0 = hn::MulAdd(a0, a0, sum0);
    const auto a1 = hn::LoadU(d, a + i + N);
    sum1 = hn::MulAdd(a1, a1, sum1);
  }

  // Combine both accumulators, then sum across lanes.
  return hn::ReduceSum(d, hn::Add(sum0, sum1));
}
|
||||||
|
|
||||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
|
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
|
||||||
|
|
@ -362,12 +372,30 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
|
||||||
// RMS-normalizes x[0, size) and scales by (1 + weight), with bf16 weights:
//   out[i] = (1 + weight[i]) * x[i] / sqrt(mean(x^2) + kEps)
//
// Each iteration loads one vector of bf16 weights, promotes its lower and
// upper halves to f32, and therefore handles kUnrollSize f32 vectors of x.
// Precondition (checked via HWY_DASSERT): `size` is a multiple of
// kUnrollSize * MaxLanes(df32).
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
    const float* HWY_RESTRICT x, const hwy::bfloat16_t* HWY_RESTRICT weight,
    float* HWY_RESTRICT out, size_t size) {
  namespace hn = hwy::HWY_NAMESPACE;

  constexpr float kEps = 1e-6f;
  constexpr size_t kUnrollSize = 2;

  const hn::ScalableTag<hwy::bfloat16_t> dbf;
  const hn::Repartition<float, decltype(dbf)> df32;
  const size_t N32 = hn::Lanes(df32);
  const size_t step = kUnrollSize * N32;  // f32 elements per iteration

  // Reciprocal of the root-mean-square, broadcast to every lane.
  const float ss = SquaredL2(x, size);
  const auto vss =
      hn::Set(df32, 1.0f / sqrtf(ss / StaticCast<float>(size) + kEps));

  // NOTE(review): MaxLanes is the compile-time upper bound on lanes, whereas
  // the loop stride uses the runtime Lanes; presumably the stricter check is
  // intentional - confirm.
  HWY_DASSERT(size % (kUnrollSize * hn::MaxLanes(df32)) == 0);

  // Only whole strides are processed: with `i < size`, a size that is not a
  // multiple of `step` would read and write past the end of the buffers when
  // the assert above is compiled out.
  for (size_t i = 0; i + step <= size; i += step) {
    const hn::Vec<decltype(dbf)> w16 = hn::LoadU(dbf, weight + i);
    const auto w0 = hn::PromoteLowerTo(df32, w16);
    const auto w1 = hn::PromoteUpperTo(df32, w16);
    const auto m0 = hn::Mul(vss, hn::LoadU(df32, x + i));
    const auto m1 = hn::Mul(vss, hn::LoadU(df32, x + i + N32));

    // (1+weight) * m = m + weight*m = one FMA.
    hn::StoreU(hn::MulAdd(m0, w0, m0), df32, out + i);
    hn::StoreU(hn::MulAdd(m1, w1, m1), df32, out + i + N32);
  }
}
|
||||||
|
|
||||||
|
|
|
||||||
6
run.cc
6
run.cc
|
|
@ -66,8 +66,8 @@ void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
|
||||||
<< std::thread::hardware_concurrency() << std::endl
|
<< std::thread::hardware_concurrency() << std::endl
|
||||||
<< "Instruction set : "
|
<< "Instruction set : "
|
||||||
<< hwy::TargetName(hwy::DispatchedTarget()) << " ("
|
<< hwy::TargetName(hwy::DispatchedTarget()) << " ("
|
||||||
<< hwy::VectorBytes() * 8 << " bits)"
|
<< hwy::VectorBytes() * 8 << " bits)" << "\n"
|
||||||
<< "\n"
|
<< "Compiled config : " << CompiledConfig() << "\n"
|
||||||
<< "Weight Type : "
|
<< "Weight Type : "
|
||||||
<< gcpp::TypeName(gcpp::WeightT()) << "\n"
|
<< gcpp::TypeName(gcpp::WeightT()) << "\n"
|
||||||
<< "EmbedderInput Type : "
|
<< "EmbedderInput Type : "
|
||||||
|
|
@ -119,7 +119,7 @@ void ReplGemma(gcpp::Gemma& model, gcpp::KVCache& kv_cache,
|
||||||
verbosity](int token, float) {
|
verbosity](int token, float) {
|
||||||
++abs_pos;
|
++abs_pos;
|
||||||
++current_pos;
|
++current_pos;
|
||||||
if (current_pos < prompt_size) {
|
if (current_pos <= prompt_size) {
|
||||||
std::cerr << "." << std::flush;
|
std::cerr << "." << std::flush;
|
||||||
} else if (token == gcpp::EOS_ID) {
|
} else if (token == gcpp::EOS_ID) {
|
||||||
if (!args.multiturn) {
|
if (!args.multiturn) {
|
||||||
|
|
|
||||||
18
util/app.h
18
util/app.h
|
|
@ -36,6 +36,24 @@
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
||||||
|
static inline const char* CompiledConfig() {
|
||||||
|
if (HWY_IS_ASAN) {
|
||||||
|
return "asan";
|
||||||
|
} else if (HWY_IS_MSAN) {
|
||||||
|
return "msan";
|
||||||
|
} else if (HWY_IS_TSAN) {
|
||||||
|
return "tsan";
|
||||||
|
#if defined(HWY_IS_UBSAN)
|
||||||
|
} else if (HWY_IS_UBSAN) {
|
||||||
|
return "ubsan";
|
||||||
|
#endif
|
||||||
|
} else if (HWY_IS_DEBUG_BUILD) {
|
||||||
|
return "dbg";
|
||||||
|
} else {
|
||||||
|
return "opt";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
static inline void PinThreadToCore(size_t cpu_index) {
|
static inline void PinThreadToCore(size_t cpu_index) {
|
||||||
#if HWY_OS_LINUX
|
#if HWY_OS_LINUX
|
||||||
// Forces the thread to run on the logical processor with the same number.
|
// Forces the thread to run on the logical processor with the same number.
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue