No public description

PiperOrigin-RevId: 617315030
This commit is contained in:
Eric Ye 2024-03-19 23:35:58 +01:00 committed by Jan Wassenberg
parent 7d5364bb80
commit ffd02c59ad
9 changed files with 41 additions and 27 deletions

View File

@ -72,4 +72,4 @@ jobs:
with: with:
path: ~/.cache/bazel path: ~/.cache/bazel
key: bazel-${{ runner.os }} key: bazel-${{ runner.os }}
- run: bazel build --cxxopt=-std=c++20 //... - run: bazel build -c opt --cxxopt=-std=c++20 //...

View File

@ -4,9 +4,7 @@
load("@rules_license//rules:license.bzl", "license") load("@rules_license//rules:license.bzl", "license")
package( package(
default_applicable_licenses = [ default_applicable_licenses = ["//:license"],
"//:license", # Placeholder comment, do not modify
],
default_visibility = ["//visibility:public"], default_visibility = ["//visibility:public"],
) )

View File

@ -1,4 +1,3 @@
# Required for referencing bazel:com_google_sentencepiece.patch
package( package(
default_applicable_licenses = ["//:license"], default_applicable_licenses = ["//:license"],
default_visibility = ["//:__subpackages__"], default_visibility = ["//:__subpackages__"],

View File

@ -1,12 +1,10 @@
# Weight compression, I/O and analysis # Weight compression, I/O and analysis
package( package(
default_applicable_licenses = [ default_applicable_licenses = ["//:license"],
"//:license", # Placeholder comment, do not modify
],
default_visibility = [ default_visibility = [
# Placeholder for internal visibility, "//learning/gemini/prod/contrib/gemini_cpp:__subpackages__",
"//:__subpackages__", # Placeholder, do not modify "//:__subpackages__",
], ],
) )

View File

@ -17,10 +17,13 @@
// copybara:import_next_line:gemma_cpp // copybara:import_next_line:gemma_cpp
#include "gemma.h" #include "gemma.h"
// copybara:import_next_line:gemma_cpp // copybara:end
#include "util/app.h" // LoaderArgs
// copybara:import_next_line:gemma_cpp // copybara:import_next_line:gemma_cpp
#include "util/args.h" #include "util/args.h"
// copybara:end
// copybara:import_next_line:gemma_cpp
#include "util/app.h" // LoaderArgs
// copybara:end
#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/contrib/thread_pool/thread_pool.h"
std::vector<int> tokenize( std::vector<int> tokenize(

View File

@ -25,8 +25,6 @@
#include "compression/compress-inl.h" #include "compression/compress-inl.h"
// copybara:import_next_line:gemma_cpp // copybara:import_next_line:gemma_cpp
#include "ops.h" #include "ops.h"
// copybara:import_next_line:gemma_cpp
#include "util/args.h" // Path
#include "hwy/contrib/matvec/matvec-inl.h" #include "hwy/contrib/matvec/matvec-inl.h"
#include "hwy/highway.h" #include "hwy/highway.h"
#include "hwy/profiler.h" #include "hwy/profiler.h"
@ -52,8 +50,6 @@
#include <string> #include <string>
#include <vector> #include <vector>
// Placeholder for internal header, do not modify.
// copybara:import_next_line:gemma_cpp // copybara:import_next_line:gemma_cpp
#include "compression/compress.h" #include "compression/compress.h"
// copybara:import_next_line:gemma_cpp // copybara:import_next_line:gemma_cpp
@ -817,9 +813,8 @@ void GemmaImpl<ConfigGemma7B>::Generate(
} }
Gemma::Gemma(const Path& tokenizer_path, const Path& compressed_weights_path, Gemma::Gemma(const Path& tokenizer_path, const Path& compressed_weights_path,
const Path& weights_path, Model model_type, ModelTraining training, const Path& weights_path, Model model_type,
hwy::ThreadPool& pool) hwy::ThreadPool& pool) {
: model_training(training) {
std::unique_ptr<sentencepiece::SentencePieceProcessor> tokenizer; std::unique_ptr<sentencepiece::SentencePieceProcessor> tokenizer;
{ {
PROFILER_ZONE("Startup.tokenizer"); PROFILER_ZONE("Startup.tokenizer");
@ -844,6 +839,11 @@ Gemma::Gemma(const Path& tokenizer_path, const Path& compressed_weights_path,
} }
} }
Gemma::Gemma(const Path& tokenizer_path, const Path& compressed_weights_path,
Model model_type, hwy::ThreadPool& pool)
: Gemma(tokenizer_path, compressed_weights_path, Path{""}, model_type,
pool) {}
Gemma::~Gemma() = default; // after GemmaInterface is defined Gemma::~Gemma() = default; // after GemmaInterface is defined
const sentencepiece::SentencePieceProcessor* Gemma::Tokenizer() const { const sentencepiece::SentencePieceProcessor* Gemma::Tokenizer() const {

16
gemma.h
View File

@ -16,20 +16,29 @@
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_H_ #ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_H_ #define THIRD_PARTY_GEMMA_CPP_GEMMA_H_
#include <algorithm>
#include <cctype>
#include <functional> #include <functional>
#include <memory> #include <memory>
#include <random> #include <random>
#include <string>
#include <vector> #include <vector>
// copybara:import_next_line:gemma_cpp // copybara:import_next_line:gemma_cpp
#include "compression/compress.h" // SfpStream/NuqStream #include "compression/compress.h" // SfpStream/NuqStream
// copybara:end
// copybara:import_next_line:gemma_cpp // copybara:import_next_line:gemma_cpp
#include "util/args.h" // Path #include "configs.h" // kSeqLen
// copybara:end
// copybara:import_next_line:gemma_cpp
#include "util/args.h" // ArgsBase
// copybara:end
#include "hwy/aligned_allocator.h" #include "hwy/aligned_allocator.h"
#include "hwy/base.h" // hwy::bfloat16_t #include "hwy/base.h" // hwy::bfloat16_t
#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/contrib/thread_pool/thread_pool.h"
// copybara:import_next_line:sentencepiece // copybara:import_next_line:sentencepiece
#include "src/sentencepiece_processor.h" #include "src/sentencepiece_processor.h"
// copybara:end
namespace gcpp { namespace gcpp {
@ -66,8 +75,9 @@ struct GemmaInterface;
struct Gemma { struct Gemma {
Gemma(const Path& tokenizer_path, const Path& compressed_weights_path, Gemma(const Path& tokenizer_path, const Path& compressed_weights_path,
const Path& weights_path, Model model_type, ModelTraining training, const Path& weights_path, Model model_type, hwy::ThreadPool& pool);
hwy::ThreadPool& pool); Gemma(const Path& tokenizer_path, const Path& compressed_weights_path,
Model model_type, hwy::ThreadPool& pool);
~Gemma(); // must be defined after GemmaInterface's dtor is defined. ~Gemma(); // must be defined after GemmaInterface's dtor is defined.
const sentencepiece::SentencePieceProcessor* Tokenizer() const; const sentencepiece::SentencePieceProcessor* Tokenizer() const;
std::unique_ptr<GemmaInterface> impl_; std::unique_ptr<GemmaInterface> impl_;

11
run.cc
View File

@ -22,15 +22,18 @@
#include <thread> // NOLINT #include <thread> // NOLINT
#include <vector> #include <vector>
// Placeholder for internal header, do not modify.
// copybara:import_next_line:gemma_cpp // copybara:import_next_line:gemma_cpp
#include "compression/compress.h" #include "compression/compress.h"
// copybara:end
// copybara:import_next_line:gemma_cpp // copybara:import_next_line:gemma_cpp
#include "gemma.h" // Gemma #include "gemma.h" // Gemma
// copybara:end
// copybara:import_next_line:gemma_cpp // copybara:import_next_line:gemma_cpp
#include "util/app.h" #include "util/app.h"
// copybara:end
// copybara:import_next_line:gemma_cpp // copybara:import_next_line:gemma_cpp
#include "util/args.h" // HasHelp #include "util/args.h" // HasHelp
// copybara:end
#include "hwy/base.h" #include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/highway.h" #include "hwy/highway.h"
@ -231,8 +234,8 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
[](uint64_t /*task*/, size_t thread) { PinThreadToCore(thread); }); [](uint64_t /*task*/, size_t thread) { PinThreadToCore(thread); });
} }
gcpp::Gemma model(loader.tokenizer, loader.compressed_weights, loader.weights, gcpp::Gemma model(loader.tokenizer, loader.compressed_weights,
loader.ModelType(), loader.ModelTraining(), pool); loader.ModelType(), pool);
auto kv_cache = CreateKVCache(loader.ModelType()); auto kv_cache = CreateKVCache(loader.ModelType());
@ -274,8 +277,6 @@ int main(int argc, char** argv) {
{ {
PROFILER_ZONE("Startup.misc"); PROFILER_ZONE("Startup.misc");
// Placeholder for internal init, do not modify.
gcpp::LoaderArgs loader(argc, argv); gcpp::LoaderArgs loader(argc, argv);
gcpp::InferenceArgs inference(argc, argv); gcpp::InferenceArgs inference(argc, argv);
gcpp::AppArgs app(argc, argv); gcpp::AppArgs app(argc, argv);

View File

@ -34,10 +34,15 @@
// copybara:import_next_line:gemma_cpp // copybara:import_next_line:gemma_cpp
#include "configs.h" #include "configs.h"
// copybara:end
// copybara:import_next_line:gemma_cpp // copybara:import_next_line:gemma_cpp
#include "gemma.h" #include "gemma.h"
// copybara:end
// copybara:import_next_line:gemma_cpp // copybara:import_next_line:gemma_cpp
#include "util/args.h" #include "util/args.h"
// copybara:end
#include "hwy/base.h" // HWY_ASSERT #include "hwy/base.h" // HWY_ASSERT
namespace gcpp { namespace gcpp {