Fix build for RPi, missing hn::. Refs #112, thanks long568

PiperOrigin-RevId: 617704418
This commit is contained in:
Jan Wassenberg 2024-03-20 20:07:12 -07:00 committed by Copybara-Service
parent 06cea2bcdb
commit 30b8a3c1ac
4 changed files with 12 additions and 15 deletions

View File

@ -5,7 +5,7 @@ package(
"//:license", # Placeholder comment, do not modify "//:license", # Placeholder comment, do not modify
], ],
default_visibility = [ default_visibility = [
"//learning/gemini/prod/contrib/gemini_cpp:__subpackages__", # Placeholder for internal visibility,
"//:__subpackages__", # Placeholder, do not modify "//:__subpackages__", # Placeholder, do not modify
], ],
) )

View File

@ -52,7 +52,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "base/init_google.h" // Placeholder for internal header, do not modify.
// copybara:import_next_line:gemma_cpp // copybara:import_next_line:gemma_cpp
#include "compression/compress.h" #include "compression/compress.h"

15
ops.h
View File

@ -341,20 +341,21 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED float Dot(const float* HWY_RESTRICT a,
static HWY_NOINLINE HWY_MAYBE_UNUSED float SquaredL2( static HWY_NOINLINE HWY_MAYBE_UNUSED float SquaredL2(
const float* HWY_RESTRICT a, size_t size) { const float* HWY_RESTRICT a, size_t size) {
const hn::ScalableTag<float> d; const hn::ScalableTag<float> d;
using V = hn::Vec<decltype(d)>;
const size_t N = hn::Lanes(d); const size_t N = hn::Lanes(d);
HWY_DASSERT(size >= 2 * N); HWY_DASSERT(size >= 2 * N);
HWY_DASSERT(size % (2 * N) == 0); HWY_DASSERT(size % (2 * N) == 0);
auto sum0 = hn::Zero(d); V sum0 = hn::Zero(d);
auto sum1 = hn::Zero(d); V sum1 = hn::Zero(d);
for (size_t i = 0; i <= size - 2 * N; i += 2 * N) { for (size_t i = 0; i <= size - 2 * N; i += 2 * N) {
const auto a0 = LoadU(d, a + i); const V a0 = hn::LoadU(d, a + i);
sum0 = MulAdd(a0, a0, sum0); sum0 = hn::MulAdd(a0, a0, sum0);
const auto a1 = LoadU(d, a + i + N); const V a1 = hn::LoadU(d, a + i + N);
sum1 = MulAdd(a1, a1, sum1); sum1 = hn::MulAdd(a1, a1, sum1);
} }
return ReduceSum(d, Add(sum0, sum1)); return hn::ReduceSum(d, hn::Add(sum0, sum1));
} }
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm( static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(

8
run.cc
View File

@ -22,7 +22,7 @@
#include <thread> // NOLINT #include <thread> // NOLINT
#include <vector> #include <vector>
#include "base/init_google.h" // Placeholder for internal header, do not modify.
// copybara:import_next_line:gemma_cpp // copybara:import_next_line:gemma_cpp
#include "compression/compress.h" #include "compression/compress.h"
// copybara:import_next_line:gemma_cpp // copybara:import_next_line:gemma_cpp
@ -274,11 +274,7 @@ int main(int argc, char** argv) {
{ {
PROFILER_ZONE("Startup.misc"); PROFILER_ZONE("Startup.misc");
int argc_dummy = 1; // Placeholder for internal init, do not modify.
// Required because sentencepiece uses Google I/O which requires InitGoogle.
// argc_dummy = 1 avoids sentencepiece absl flags attempting to parse
// arguments
InitGoogle("usage", &argc_dummy, &argv, false);
gcpp::LoaderArgs loader(argc, argv); gcpp::LoaderArgs loader(argc, argv);
gcpp::InferenceArgs inference(argc, argv); gcpp::InferenceArgs inference(argc, argv);