From 30b8a3c1acd8c1461dd4725f483b906b323c1d0d Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Wed, 20 Mar 2024 20:07:12 -0700 Subject: [PATCH] Fix build for RPi, missing hn::. Refs #112, thanks long568 PiperOrigin-RevId: 617704418 --- compression/BUILD | 2 +- gemma.cc | 2 +- ops.h | 15 ++++++++------- run.cc | 8 ++------ 4 files changed, 12 insertions(+), 15 deletions(-) diff --git a/compression/BUILD b/compression/BUILD index 7b7040b..cfbeb99 100644 --- a/compression/BUILD +++ b/compression/BUILD @@ -5,7 +5,7 @@ package( "//:license", # Placeholder comment, do not modify ], default_visibility = [ - "//learning/gemini/prod/contrib/gemini_cpp:__subpackages__", + # Placeholder for internal visibility, "//:__subpackages__", # Placeholder, do not modify ], ) diff --git a/gemma.cc b/gemma.cc index 743e13b..a230bea 100644 --- a/gemma.cc +++ b/gemma.cc @@ -52,7 +52,7 @@ #include #include -#include "base/init_google.h" +// Placeholder for internal header, do not modify. // copybara:import_next_line:gemma_cpp #include "compression/compress.h" diff --git a/ops.h b/ops.h index 481e1d7..7aa7b62 100644 --- a/ops.h +++ b/ops.h @@ -341,20 +341,21 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED float Dot(const float* HWY_RESTRICT a, static HWY_NOINLINE HWY_MAYBE_UNUSED float SquaredL2( const float* HWY_RESTRICT a, size_t size) { const hn::ScalableTag d; + using V = hn::Vec; const size_t N = hn::Lanes(d); HWY_DASSERT(size >= 2 * N); HWY_DASSERT(size % (2 * N) == 0); - auto sum0 = hn::Zero(d); - auto sum1 = hn::Zero(d); + V sum0 = hn::Zero(d); + V sum1 = hn::Zero(d); for (size_t i = 0; i <= size - 2 * N; i += 2 * N) { - const auto a0 = LoadU(d, a + i); - sum0 = MulAdd(a0, a0, sum0); - const auto a1 = LoadU(d, a + i + N); - sum1 = MulAdd(a1, a1, sum1); + const V a0 = hn::LoadU(d, a + i); + sum0 = hn::MulAdd(a0, a0, sum0); + const V a1 = hn::LoadU(d, a + i + N); + sum1 = hn::MulAdd(a1, a1, sum1); } - return ReduceSum(d, Add(sum0, sum1)); + return hn::ReduceSum(d, hn::Add(sum0, sum1)); } static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm( diff --git a/run.cc b/run.cc index 04767d7..3f38031 100644 --- a/run.cc +++ b/run.cc @@ -22,7 +22,7 @@ #include // NOLINT #include -#include "base/init_google.h" +// Placeholder for internal header, do not modify. // copybara:import_next_line:gemma_cpp #include "compression/compress.h" // copybara:import_next_line:gemma_cpp @@ -274,11 +274,7 @@ int main(int argc, char** argv) { { PROFILER_ZONE("Startup.misc"); - int argc_dummy = 1; - // Required because sentencepiece uses Google I/O which requires InitGoogle. - // argc_dummy = 1 avoids sentencepiece absl flags attempting to parse - // arguments - InitGoogle("usage", &argc_dummy, &argv, false); + // Placeholder for internal init, do not modify. gcpp::LoaderArgs loader(argc, argv); gcpp::InferenceArgs inference(argc, argv);