mirror of https://github.com/google/gemma.cpp.git
Decouple gemma constructor from loader args, update hello_world example, add convenience version of constructor (no uncompressed weights)
This commit is contained in:
parent
42e53e2da8
commit
dfd2fdc1dd
|
|
@ -22,7 +22,18 @@ FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.gi
|
||||||
FetchContent_MakeAvailable(highway)
|
FetchContent_MakeAvailable(highway)
|
||||||
FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c)
|
FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c)
|
||||||
FetchContent_MakeAvailable(sentencepiece)
|
FetchContent_MakeAvailable(sentencepiece)
|
||||||
FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp.git GIT_TAG e781007836ec034236e90cc4d313d0a8c481bce6)
|
|
||||||
|
|
||||||
|
# Allow for both local and remote building
|
||||||
|
option(BUILD_MODE "'local' or 'remote' git fetch for builds")
|
||||||
|
if (NOT BUILD_MODE)
|
||||||
|
set(BUILD_MODE "remote")
|
||||||
|
endif()
|
||||||
|
if (BUILD_MODE STREQUAL "local")
|
||||||
|
FetchContent_Declare(gemma SOURCE_DIR ../../..)
|
||||||
|
else()
|
||||||
|
FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp.git GIT_TAG 42e53e2da89f80dc46399c7037fbbfb15cdc3de3)
|
||||||
|
endif()
|
||||||
FetchContent_MakeAvailable(gemma)
|
FetchContent_MakeAvailable(gemma)
|
||||||
|
|
||||||
if(NOT CMAKE_BUILD_TYPE)
|
if(NOT CMAKE_BUILD_TYPE)
|
||||||
|
|
|
||||||
|
|
@ -22,15 +22,17 @@ std::vector<int> tokenize(
|
||||||
int main(int argc, char** argv) {
|
int main(int argc, char** argv) {
|
||||||
gcpp::LoaderArgs loader(argc, argv);
|
gcpp::LoaderArgs loader(argc, argv);
|
||||||
|
|
||||||
// A rough heuristic number of threads to use
|
// Rough heuristic for the number of threads to use
|
||||||
size_t num_threads = static_cast<size_t>(std::clamp(
|
size_t num_threads = static_cast<size_t>(std::clamp(
|
||||||
static_cast<int>(std::thread::hardware_concurrency()) - 2, 1, 18));
|
static_cast<int>(std::thread::hardware_concurrency()) - 2, 1, 18));
|
||||||
hwy::ThreadPool pool(num_threads);
|
hwy::ThreadPool pool(num_threads);
|
||||||
|
|
||||||
// Instantiate model
|
// Instantiate model and KV Cache
|
||||||
gcpp::Gemma model(loader, pool);
|
gcpp::Gemma model(loader.tokenizer, loader.cache, loader.ModelType(), pool);
|
||||||
|
auto kv_cache = CreateKVCache(loader.ModelType());
|
||||||
|
size_t pos = 0; // KV Cache position
|
||||||
|
|
||||||
// Setup random number generator
|
// Initialize random number generator
|
||||||
std::mt19937 gen;
|
std::mt19937 gen;
|
||||||
std::random_device rd;
|
std::random_device rd;
|
||||||
gen.seed(rd());
|
gen.seed(rd());
|
||||||
|
|
@ -39,7 +41,6 @@ int main(int argc, char** argv) {
|
||||||
std::vector<int> tokens =
|
std::vector<int> tokens =
|
||||||
tokenize("Write a greeting to the world.", model.Tokenizer());
|
tokenize("Write a greeting to the world.", model.Tokenizer());
|
||||||
size_t ntokens = tokens.size();
|
size_t ntokens = tokens.size();
|
||||||
size_t pos = 0;
|
|
||||||
|
|
||||||
// Callback
|
// Callback
|
||||||
auto stream_token = [&pos, &gen, &ntokens, tokenizer = model.Tokenizer()](
|
auto stream_token = [&pos, &gen, &ntokens, tokenizer = model.Tokenizer()](
|
||||||
|
|
@ -60,6 +61,7 @@ int main(int argc, char** argv) {
|
||||||
.max_generated_tokens = 1024,
|
.max_generated_tokens = 1024,
|
||||||
.temperature = 1.0,
|
.temperature = 1.0,
|
||||||
.verbosity = 0},
|
.verbosity = 0},
|
||||||
tokens, /*KV cache position = */ 0, pool, stream_token, gen);
|
tokens, /*KV cache position = */ 0, kv_cache, pool,
|
||||||
|
stream_token, gen);
|
||||||
std::cout << std::endl;
|
std::cout << std::endl;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
17
gemma.cc
17
gemma.cc
|
|
@ -285,7 +285,6 @@ struct GemmaImpl : public GemmaInterface {
|
||||||
int verbosity) override;
|
int verbosity) override;
|
||||||
|
|
||||||
std::unique_ptr<sentencepiece::SentencePieceProcessor> tokenizer;
|
std::unique_ptr<sentencepiece::SentencePieceProcessor> tokenizer;
|
||||||
|
|
||||||
hwy::AlignedFreeUniquePtr<uint8_t[]> compressed_weights;
|
hwy::AlignedFreeUniquePtr<uint8_t[]> compressed_weights;
|
||||||
hwy::AlignedUniquePtr<Activations<Config, kPrefillBatchSize>> prefill;
|
hwy::AlignedUniquePtr<Activations<Config, kPrefillBatchSize>> prefill;
|
||||||
hwy::AlignedUniquePtr<Activations<Config, 1>> state;
|
hwy::AlignedUniquePtr<Activations<Config, 1>> state;
|
||||||
|
|
@ -803,15 +802,15 @@ void GemmaImpl<ConfigGemma7B>::Generate(
|
||||||
kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity);
|
kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity);
|
||||||
}
|
}
|
||||||
|
|
||||||
Gemma::Gemma(const LoaderArgs& args, hwy::ThreadPool& pool) {
|
Gemma::Gemma(const Path& tokenizer_path, const Path& compressed_weights_path,
|
||||||
const Model model_type = args.ModelType();
|
const Path& weights_path, Model model_type,
|
||||||
model_training = args.ModelTraining();
|
hwy::ThreadPool& pool) {
|
||||||
PROFILER_ZONE("Startup.tokenizer");
|
PROFILER_ZONE("Startup.tokenizer");
|
||||||
std::unique_ptr<sentencepiece::SentencePieceProcessor> tokenizer =
|
std::unique_ptr<sentencepiece::SentencePieceProcessor> tokenizer =
|
||||||
std::make_unique<sentencepiece::SentencePieceProcessor>();
|
std::make_unique<sentencepiece::SentencePieceProcessor>();
|
||||||
HWY_ASSERT(tokenizer->Load(args.tokenizer.path).ok());
|
HWY_ASSERT(tokenizer->Load(tokenizer_path.path).ok());
|
||||||
auto compressed_weights = HWY_DYNAMIC_DISPATCH(GetCompressedWeightsT)(
|
auto compressed_weights = HWY_DYNAMIC_DISPATCH(GetCompressedWeightsT)(
|
||||||
args.ModelType(), args.model, args.cache, pool);
|
model_type, weights_path, compressed_weights_path, pool);
|
||||||
switch (model_type) {
|
switch (model_type) {
|
||||||
case Model::GEMMA_2B:
|
case Model::GEMMA_2B:
|
||||||
impl_.reset(
|
impl_.reset(
|
||||||
|
|
@ -825,6 +824,12 @@ Gemma::Gemma(const LoaderArgs& args, hwy::ThreadPool& pool) {
|
||||||
HWY_ABORT("Model type %d unknown.", static_cast<int>(model_type));
|
HWY_ABORT("Model type %d unknown.", static_cast<int>(model_type));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Gemma::Gemma(const Path& tokenizer_path, const Path& compressed_weights_path,
|
||||||
|
Model model_type, hwy::ThreadPool& pool)
|
||||||
|
: Gemma(tokenizer_path, compressed_weights_path, Path{""}, model_type,
|
||||||
|
pool) {}
|
||||||
|
|
||||||
Gemma::~Gemma() = default; // after GemmaInterface is defined
|
Gemma::~Gemma() = default; // after GemmaInterface is defined
|
||||||
|
|
||||||
const sentencepiece::SentencePieceProcessor* Gemma::Tokenizer() const {
|
const sentencepiece::SentencePieceProcessor* Gemma::Tokenizer() const {
|
||||||
|
|
|
||||||
5
gemma.h
5
gemma.h
|
|
@ -192,7 +192,10 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
|
||||||
struct GemmaInterface;
|
struct GemmaInterface;
|
||||||
|
|
||||||
struct Gemma {
|
struct Gemma {
|
||||||
Gemma(const LoaderArgs& args, hwy::ThreadPool& pool);
|
Gemma(const Path& tokenizer_path, const Path& compressed_weights_path,
|
||||||
|
const Path& weights_path, Model model_type, hwy::ThreadPool& pool);
|
||||||
|
Gemma(const Path& tokenizer_path, const Path& compressed_weights_path,
|
||||||
|
Model model_type, hwy::ThreadPool& pool);
|
||||||
~Gemma(); // must be defined after GemmaInterface's dtor is defined.
|
~Gemma(); // must be defined after GemmaInterface's dtor is defined.
|
||||||
const sentencepiece::SentencePieceProcessor* Tokenizer() const;
|
const sentencepiece::SentencePieceProcessor* Tokenizer() const;
|
||||||
std::unique_ptr<GemmaInterface> impl_;
|
std::unique_ptr<GemmaInterface> impl_;
|
||||||
|
|
|
||||||
2
run.cc
2
run.cc
|
|
@ -235,7 +235,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
|
||||||
[](uint64_t /*task*/, size_t thread) { PinThreadToCore(thread); });
|
[](uint64_t /*task*/, size_t thread) { PinThreadToCore(thread); });
|
||||||
}
|
}
|
||||||
|
|
||||||
gcpp::Gemma model(loader, pool);
|
gcpp::Gemma model(loader.tokenizer, loader.cache, loader.ModelType(), pool);
|
||||||
|
|
||||||
auto kv_cache = CreateKVCache(loader.ModelType());
|
auto kv_cache = CreateKVCache(loader.ModelType());
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue