mirror of https://github.com/google/gemma.cpp.git
Decouple gemma constructor from loader args, update hello_world example, add convenience version of constructor (no uncompressed weights)
parent 42e53e2da8
commit dfd2fdc1dd
examples/hello_world/CMakeLists.txt

@@ -22,7 +22,18 @@ FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.gi
 FetchContent_MakeAvailable(highway)
 FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c)
 FetchContent_MakeAvailable(sentencepiece)
-FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp.git GIT_TAG e781007836ec034236e90cc4d313d0a8c481bce6)
+
+
+# Allow for both local and remote building
+option(BUILD_MODE "'local' or 'remote' git fetch for builds")
+if (NOT BUILD_MODE)
+  set(BUILD_MODE "remote")
+endif()
+if (BUILD_MODE STREQUAL "local")
+  FetchContent_Declare(gemma SOURCE_DIR ../../..)
+else()
+  FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp.git GIT_TAG 42e53e2da89f80dc46399c7037fbbfb15cdc3de3)
+endif()
 FetchContent_MakeAvailable(gemma)
 
 if(NOT CMAKE_BUILD_TYPE)
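The new BUILD_MODE option lets the example build against the enclosing checkout rather than a pinned GitHub tag: configuring with -DBUILD_MODE=local points FetchContent at the source tree three directories up (SOURCE_DIR ../../..), while the default remote mode fetches gemma.cpp at the tag above. For example, cmake -B build -DBUILD_MODE=local run from the example directory picks up local changes without pushing them first.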
examples/hello_world/run.cc

@@ -22,15 +22,17 @@ std::vector<int> tokenize(
 int main(int argc, char** argv) {
   gcpp::LoaderArgs loader(argc, argv);
 
-  // A rough heuristic number of threads to use
+  // Rough heuristic for the number of threads to use
   size_t num_threads = static_cast<size_t>(std::clamp(
       static_cast<int>(std::thread::hardware_concurrency()) - 2, 1, 18));
   hwy::ThreadPool pool(num_threads);
 
-  // Instantiate model
-  gcpp::Gemma model(loader, pool);
+  // Instantiate model and KV Cache
+  gcpp::Gemma model(loader.tokenizer, loader.cache, loader.ModelType(), pool);
+  auto kv_cache = CreateKVCache(loader.ModelType());
+  size_t pos = 0;  // KV Cache position
 
-  // Setup random number generator
+  // Initialize random number generator
   std::mt19937 gen;
   std::random_device rd;
   gen.seed(rd());

@@ -39,7 +41,6 @@ int main(int argc, char** argv) {
   std::vector<int> tokens =
       tokenize("Write a greeting to the world.", model.Tokenizer());
   size_t ntokens = tokens.size();
-  size_t pos = 0;
 
   // Callback
   auto stream_token = [&pos, &gen, &ntokens, tokenizer = model.Tokenizer()](

@@ -60,6 +61,7 @@ int main(int argc, char** argv) {
       .max_generated_tokens = 1024,
       .temperature = 1.0,
       .verbosity = 0},
-      tokens, /*KV cache position = */ 0, pool, stream_token, gen);
+      tokens, /*KV cache position = */ 0, kv_cache, pool,
+      stream_token, gen);
   std::cout << std::endl;
 }
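Taken together, the decoupled flow for library users looks roughly like the sketch below. This is a condensed, non-authoritative rendering of the example above: the thread count is simplified, tokenize() is the helper defined earlier in the example file, and the callback signature (token, probability) returning bool is assumed from the capture list shown in the diff.

  #include <random>
  #include <vector>

  #include "gemma.h"  // gcpp::Gemma, gcpp::LoaderArgs, CreateKVCache, GenerateGemma
  #include "hwy/contrib/thread_pool/thread_pool.h"

  int main(int argc, char** argv) {
    gcpp::LoaderArgs loader(argc, argv);  // parses tokenizer/weights/model flags
    hwy::ThreadPool pool(/*num_threads=*/4);

    // The constructor now takes paths and a model type instead of LoaderArgs,
    // and the KV cache is created and owned by the caller.
    gcpp::Gemma model(loader.tokenizer, loader.cache, loader.ModelType(), pool);
    auto kv_cache = CreateKVCache(loader.ModelType());
    size_t pos = 0;  // KV cache position, advanced as tokens stream out

    std::mt19937 gen(std::random_device{}());
    std::vector<int> tokens =
        tokenize("Write a greeting to the world.", model.Tokenizer());

    auto stream_token = [&pos](int token, float /*probability*/) {
      ++pos;  // decode and print `token` here, as the example's callback does
      return true;
    };

    GenerateGemma(model,
                  {.max_generated_tokens = 1024, .temperature = 1.0, .verbosity = 0},
                  tokens, /*start_pos=*/0, kv_cache, pool, stream_token, gen);
  }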
gemma.cc (17 lines changed)
@@ -285,7 +285,6 @@ struct GemmaImpl : public GemmaInterface {
             int verbosity) override;
 
   std::unique_ptr<sentencepiece::SentencePieceProcessor> tokenizer;
-
   hwy::AlignedFreeUniquePtr<uint8_t[]> compressed_weights;
   hwy::AlignedUniquePtr<Activations<Config, kPrefillBatchSize>> prefill;
   hwy::AlignedUniquePtr<Activations<Config, 1>> state;
@@ -803,15 +802,15 @@ void GemmaImpl<ConfigGemma7B>::Generate(
       kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity);
 }
 
-Gemma::Gemma(const LoaderArgs& args, hwy::ThreadPool& pool) {
-  const Model model_type = args.ModelType();
-  model_training = args.ModelTraining();
+Gemma::Gemma(const Path& tokenizer_path, const Path& compressed_weights_path,
+             const Path& weights_path, Model model_type,
+             hwy::ThreadPool& pool) {
   PROFILER_ZONE("Startup.tokenizer");
   std::unique_ptr<sentencepiece::SentencePieceProcessor> tokenizer =
       std::make_unique<sentencepiece::SentencePieceProcessor>();
-  HWY_ASSERT(tokenizer->Load(args.tokenizer.path).ok());
+  HWY_ASSERT(tokenizer->Load(tokenizer_path.path).ok());
   auto compressed_weights = HWY_DYNAMIC_DISPATCH(GetCompressedWeightsT)(
-      args.ModelType(), args.model, args.cache, pool);
+      model_type, weights_path, compressed_weights_path, pool);
   switch (model_type) {
     case Model::GEMMA_2B:
       impl_.reset(
@@ -825,6 +824,12 @@ Gemma::Gemma(const LoaderArgs& args, hwy::ThreadPool& pool) {
       HWY_ABORT("Model type %d unknown.", static_cast<int>(model_type));
   }
 }
 
+Gemma::Gemma(const Path& tokenizer_path, const Path& compressed_weights_path,
+             Model model_type, hwy::ThreadPool& pool)
+    : Gemma(tokenizer_path, compressed_weights_path, Path{""}, model_type,
+            pool) {}
+
 Gemma::~Gemma() = default;  // after GemmaInterface is defined
 
 const sentencepiece::SentencePieceProcessor* Gemma::Tokenizer() const {
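The convenience overload just forwards an empty weights_path, so when only compressed weights exist on disk the two spellings below are interchangeable. A minimal sketch: the file names are placeholders, and Path is assumed to be gcpp::Path, as the unqualified uses in the diff suggest.

  // Convenience constructor added by this commit:
  gcpp::Gemma model(gcpp::Path{"tokenizer.spm"}, gcpp::Path{"2b-it-sfp.sbs"},
                    gcpp::Model::GEMMA_2B, pool);

  // Equivalent explicit form; the empty Path signals that no uncompressed
  // weights are available, matching the delegation above:
  gcpp::Gemma model2(gcpp::Path{"tokenizer.spm"}, gcpp::Path{"2b-it-sfp.sbs"},
                     gcpp::Path{""}, gcpp::Model::GEMMA_2B, pool);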
gemma.h (5 lines changed)
@@ -192,7 +192,10 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
 struct GemmaInterface;
 
 struct Gemma {
-  Gemma(const LoaderArgs& args, hwy::ThreadPool& pool);
+  Gemma(const Path& tokenizer_path, const Path& compressed_weights_path,
+        const Path& weights_path, Model model_type, hwy::ThreadPool& pool);
+  Gemma(const Path& tokenizer_path, const Path& compressed_weights_path,
+        Model model_type, hwy::ThreadPool& pool);
   ~Gemma();  // must be defined after GemmaInterface's dtor is defined.
   const sentencepiece::SentencePieceProcessor* Tokenizer() const;
   std::unique_ptr<GemmaInterface> impl_;
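Dropping the LoaderArgs overload removes the header's coupling to the CLI argument structs: embedders can now construct a Gemma from Paths and a Model enum alone. Per the commit message, the three-Path overload keeps a slot for uncompressed weights (for flows that still populate the compressed cache), while the two-Path convenience overload covers the common case where only compressed weights are present.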
run.cc (2 lines changed)
@@ -235,7 +235,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
         [](uint64_t /*task*/, size_t thread) { PinThreadToCore(thread); });
   }
 
-  gcpp::Gemma model(loader, pool);
+  gcpp::Gemma model(loader.tokenizer, loader.cache, loader.ModelType(), pool);
 
   auto kv_cache = CreateKVCache(loader.ModelType());
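run.cc now constructs the model the same way the hello_world example does; the CreateKVCache call below the constructor was already present here, so only the constructor call changes in this file.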