Decouple gemma constructor from loader args, update hello_world example, add convenience version of constructor (no uncompressed weights)

This commit is contained in:
austinvhuang 2024-03-08 17:26:03 -05:00
parent 42e53e2da8
commit dfd2fdc1dd
5 changed files with 36 additions and 15 deletions

View File

@@ -22,7 +22,18 @@ FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.gi
FetchContent_MakeAvailable(highway)
FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c)
FetchContent_MakeAvailable(sentencepiece)
FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp.git GIT_TAG e781007836ec034236e90cc4d313d0a8c481bce6)
# Allow for both local and remote building
option(BUILD_MODE "'local' or 'remote' git fetch for builds")
if (NOT BUILD_MODE)
set(BUILD_MODE "remote")
endif()
if (BUILD_MODE STREQUAL "local")
FetchContent_Declare(gemma SOURCE_DIR ../../..)
else()
FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp.git GIT_TAG 42e53e2da89f80dc46399c7037fbbfb15cdc3de3)
endif()
FetchContent_MakeAvailable(gemma)
if(NOT CMAKE_BUILD_TYPE)

View File

@@ -22,15 +22,17 @@ std::vector<int> tokenize(
int main(int argc, char** argv) {
gcpp::LoaderArgs loader(argc, argv);
// A rough heuristic number of threads to use
// Rough heuristic for the number of threads to use
size_t num_threads = static_cast<size_t>(std::clamp(
static_cast<int>(std::thread::hardware_concurrency()) - 2, 1, 18));
hwy::ThreadPool pool(num_threads);
// Instantiate model
gcpp::Gemma model(loader, pool);
// Instantiate model and KV Cache
gcpp::Gemma model(loader.tokenizer, loader.cache, loader.ModelType(), pool);
auto kv_cache = CreateKVCache(loader.ModelType());
size_t pos = 0; // KV Cache position
// Setup random number generator
// Initialize random number generator
std::mt19937 gen;
std::random_device rd;
gen.seed(rd());
@@ -39,7 +41,6 @@ int main(int argc, char** argv) {
std::vector<int> tokens =
tokenize("Write a greeting to the world.", model.Tokenizer());
size_t ntokens = tokens.size();
size_t pos = 0;
// Callback
auto stream_token = [&pos, &gen, &ntokens, tokenizer = model.Tokenizer()](
@@ -60,6 +61,7 @@ int main(int argc, char** argv) {
.max_generated_tokens = 1024,
.temperature = 1.0,
.verbosity = 0},
tokens, /*KV cache position = */ 0, pool, stream_token, gen);
tokens, /*KV cache position = */ 0, kv_cache, pool,
stream_token, gen);
std::cout << std::endl;
}

View File

@@ -285,7 +285,6 @@ struct GemmaImpl : public GemmaInterface {
int verbosity) override;
std::unique_ptr<sentencepiece::SentencePieceProcessor> tokenizer;
hwy::AlignedFreeUniquePtr<uint8_t[]> compressed_weights;
hwy::AlignedUniquePtr<Activations<Config, kPrefillBatchSize>> prefill;
hwy::AlignedUniquePtr<Activations<Config, 1>> state;
@@ -803,15 +802,15 @@ void GemmaImpl<ConfigGemma7B>::Generate(
kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity);
}
Gemma::Gemma(const LoaderArgs& args, hwy::ThreadPool& pool) {
const Model model_type = args.ModelType();
model_training = args.ModelTraining();
Gemma::Gemma(const Path& tokenizer_path, const Path& compressed_weights_path,
const Path& weights_path, Model model_type,
hwy::ThreadPool& pool) {
PROFILER_ZONE("Startup.tokenizer");
std::unique_ptr<sentencepiece::SentencePieceProcessor> tokenizer =
std::make_unique<sentencepiece::SentencePieceProcessor>();
HWY_ASSERT(tokenizer->Load(args.tokenizer.path).ok());
HWY_ASSERT(tokenizer->Load(tokenizer_path.path).ok());
auto compressed_weights = HWY_DYNAMIC_DISPATCH(GetCompressedWeightsT)(
args.ModelType(), args.model, args.cache, pool);
model_type, weights_path, compressed_weights_path, pool);
switch (model_type) {
case Model::GEMMA_2B:
impl_.reset(
@@ -825,6 +824,12 @@ Gemma::Gemma(const LoaderArgs& args, hwy::ThreadPool& pool) {
HWY_ABORT("Model type %d unknown.", static_cast<int>(model_type));
}
}
Gemma::Gemma(const Path& tokenizer_path, const Path& compressed_weights_path,
Model model_type, hwy::ThreadPool& pool)
: Gemma(tokenizer_path, compressed_weights_path, Path{""}, model_type,
pool) {}
Gemma::~Gemma() = default; // after GemmaInterface is defined
const sentencepiece::SentencePieceProcessor* Gemma::Tokenizer() const {

View File

@@ -192,7 +192,10 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
struct GemmaInterface;
struct Gemma {
Gemma(const LoaderArgs& args, hwy::ThreadPool& pool);
Gemma(const Path& tokenizer_path, const Path& compressed_weights_path,
const Path& weights_path, Model model_type, hwy::ThreadPool& pool);
Gemma(const Path& tokenizer_path, const Path& compressed_weights_path,
Model model_type, hwy::ThreadPool& pool);
~Gemma(); // must be defined after GemmaInterface's dtor is defined.
const sentencepiece::SentencePieceProcessor* Tokenizer() const;
std::unique_ptr<GemmaInterface> impl_;

2
run.cc
View File

@@ -235,7 +235,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
[](uint64_t /*task*/, size_t thread) { PinThreadToCore(thread); });
}
gcpp::Gemma model(loader, pool);
gcpp::Gemma model(loader.tokenizer, loader.cache, loader.ModelType(), pool);
auto kv_cache = CreateKVCache(loader.ModelType());