[WIP] decouple GemmaImpl from CLI args

austinvhuang 2024-03-06 15:06:41 -05:00
parent c378ac2c56
commit 10f7a086aa
6 changed files with 78 additions and 48 deletions
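Tokenizer and weight loading move out of GemmaImpl and into Gemma::Gemma, which now constructs the SentencePiece tokenizer and the compressed weights once and hands both to GemmaImpl, so the implementation no longer depends on LoaderArgs. GenerateImpl takes max_tokens, max_generated_tokens, and temperature as plain values instead of reading them from InferenceArgs, Tokenizer() returns a pointer instead of a reference, and gemma.h sketches an InferenceParams struct (not yet wired up) to carry these values.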

.gitignore (new file)

@@ -0,0 +1,2 @@
+*
+!.gitignore

(file name not shown)

@@ -12,6 +12,9 @@
 // copybara:import_next_line:gemma_cpp
 #include "util/args.h"  // HasHelp
 // copybara:end
+// copybara:import_next_line:gemma_cpp
+#include "configs.h"
+// copybara:end
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 #include "hwy/highway.h"
@@ -35,17 +38,13 @@ int main(int argc, char** argv) {
   hwy::ThreadPool pool(app.num_threads);
   hwy::ThreadPool inner_pool(0);
   gcpp::Gemma model(loader, pool);
-  std::vector<int> tokens = tokenize("Hello, how are you?", model.Tokenizer());
   std::mt19937 gen;
   std::random_device rd;
   gen.seed(rd());
+  std::vector<int> tokens = tokenize("Hello, how are you?", model.Tokenizer());
   size_t ntokens = tokens.size();
   size_t pos = 0;
   auto stream_token = [&pos, &gen, &ntokens, tokenizer = &model.Tokenizer()](int token, float) {
     ++pos;
     if (pos < ntokens) {
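The tokenize() helper called here is presumably defined earlier in this file and is not shown in the diff. A minimal sketch of what such a helper could look like against the pointer-returning Tokenizer() this commit introduces (the hard-coded BOS id of 2 is an assumption, not taken from the diff):

    std::vector<int> tokenize(
        const std::string& prompt_string,
        const sentencepiece::SentencePieceProcessor* tokenizer) {
      std::vector<int> tokens;
      // Encode the prompt into SentencePiece ids; abort on failure.
      HWY_ASSERT(tokenizer->Encode(prompt_string, &tokens).ok());
      // Prepend "<bos>", which both model variants expect (see run.cc below).
      tokens.insert(tokens.begin(), 2);
      return tokens;
    }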

(file name not shown)

@@ -26,11 +26,11 @@
 // copybara:import_next_line:gemma_cpp
 #include "ops.h"
 // copybara:import_next_line:gemma_cpp
-#include "util/args.h"  // Path
 #include "hwy/contrib/matvec/matvec-inl.h"
 #include "hwy/highway.h"
 #include "hwy/profiler.h"
 #include "hwy/timer.h"
+#include "util/args.h"  // Path

 // Non-SIMD includes and types. Note that HWY_ONCE is only true on the last
 // compile pass, whereas we want this defined in the first.
@@ -231,9 +231,8 @@ struct Activations {
 struct GemmaInterface {
   virtual ~GemmaInterface() = default;
-  virtual const sentencepiece::SentencePieceProcessor& Tokenizer() const = 0;
-  // TODO: group pool/callbacks into struct
+  virtual const sentencepiece::SentencePieceProcessor* Tokenizer() const = 0;
   virtual void Generate(const InferenceArgs& args,
                         const std::vector<int>& prompt, size_t start_pos,
                         hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
@@ -244,7 +243,10 @@ struct GemmaInterface {
 template <class Config>
 struct GemmaImpl : public GemmaInterface {
-  GemmaImpl(const LoaderArgs& args, hwy::ThreadPool& pool);
+  GemmaImpl(  // const LoaderArgs& args,
+      std::unique_ptr<sentencepiece::SentencePieceProcessor>& tokenizer,
+      hwy::AlignedFreeUniquePtr<uint8_t[]>& compressed_weights,
+      hwy::ThreadPool& pool);
   ~GemmaImpl() {
     using CWeights = CompressedWeights<Config>;
@ -252,8 +254,8 @@ struct GemmaImpl : public GemmaInterface {
c_weights->c_layer_ptrs.~CompressedLayerPointers<Config>(); c_weights->c_layer_ptrs.~CompressedLayerPointers<Config>();
} }
const sentencepiece::SentencePieceProcessor& Tokenizer() const { const sentencepiece::SentencePieceProcessor* Tokenizer() const {
return tokenizer; return tokenizer.get();
} }
void Generate(const InferenceArgs& args, const std::vector<int>& prompt, void Generate(const InferenceArgs& args, const std::vector<int>& prompt,
@@ -261,9 +263,8 @@ struct GemmaImpl : public GemmaInterface {
                 hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
                 const AcceptFunc& accept_token, std::mt19937&, int verbosity);
-  sentencepiece::SentencePieceProcessor tokenizer;
-  // CompressedWeights<Config>
+  std::unique_ptr<sentencepiece::SentencePieceProcessor> tokenizer;
   hwy::AlignedFreeUniquePtr<uint8_t[]> compressed_weights;
   hwy::AlignedUniquePtr<Activations<Config, kPrefillBatchSize>> prefill;
   hwy::AlignedUniquePtr<Activations<Config, 1>> state;
@@ -495,7 +496,8 @@ void Transformer(int token, size_t pos,
 }
 template <class TConfig>
-void GenerateImpl(GemmaImpl<TConfig>& gemma, const InferenceArgs& args,
+void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
+                  size_t max_generated_tokens, float temperature,
                   const std::vector<int>& prompt, size_t pos,
                   hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
                   const StreamFunc& stream_token,
@@ -549,7 +551,7 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, const InferenceArgs& args,
     // should be available as observable state for frontend code to handle I/O.
     double prefill_end = hwy::platform::Now();
     const double prefill_tok_sec = pos_offset / (prefill_end - prefill_start);
-    std::cout << "\n[ Prefill tokens / sec = " << prefill_tok_sec << " ]\n";
+    std::cout << "\n[ Prefill tokens / sec = " << prefill_tok_sec << " ]";
   }
   double gen_start = hwy::platform::Now();
@@ -558,10 +560,10 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, const InferenceArgs& args,
   if (verbosity >= 2) {
     // Provide usage warnings if max_new_tokens is out of range.
-    if (args.max_generated_tokens > args.max_tokens) {
+    if (max_generated_tokens > max_tokens) {
       std::cout << "Warning: max_new_tokens should be <= max_tokens"
                 << std::endl;
-    } else if ((prompt.size() + args.max_generated_tokens) > args.max_tokens) {
+    } else if ((prompt.size() + max_generated_tokens) > max_tokens) {
       std::cout << "Warning: Prompt size + max_new_tokens exceeds max_tokens."
                 << std::endl;
     }
@@ -570,7 +572,7 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, const InferenceArgs& args,
   auto pos_gen_start = pos_offset;
   token = prompt.at(pos_offset);
   size_t generate_pos = 0;
-  for (; pos < args.max_tokens && generate_pos < args.max_generated_tokens;
+  for (; pos < max_tokens && generate_pos < max_generated_tokens;
        ++pos, ++pos_offset, ++generate_pos) {
     Transformer(token, pos, c_weights, activations, kv_cache, pool, inner_pool);
     float* final_activation = activations.x.data();
@@ -583,7 +585,7 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, const InferenceArgs& args,
       // Barrier: must have all logits so we can subtract max.
       Softmax(activations.logits.data(), kVocabSize);
       token = SampleTopK<kTopK>(activations.logits.data(), kVocabSize, gen,
-                                args.temperature, accept_token);
+                                temperature, accept_token);
     }
     if (!stream_token(token, activations.logits[token])) {
       token = EOS_ID;
@@ -593,7 +595,7 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, const InferenceArgs& args,
       double gen_end = hwy::platform::Now();
       const double gen_tok_sec =
           (pos_offset - pos_gen_start) / (gen_end - gen_start);
-      std::cout << "\n[ Generation tokens / sec = " << gen_tok_sec << " ]\n";
+      std::cout << "[ Generation tokens / sec = " << gen_tok_sec << " ]\n";
     }
     break;
   }
@@ -605,8 +607,9 @@ void Generate2B(GemmaImpl<ConfigGemma2B>& gemma, const InferenceArgs& args,
                 hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
                 const StreamFunc& stream_token, const AcceptFunc& accept_token,
                 std::mt19937& gen, int verbosity) {
-  GenerateImpl(gemma, args, prompt, start_pos, pool, inner_pool, stream_token,
-               accept_token, gen, verbosity);
+  GenerateImpl(gemma, args.max_tokens, args.max_generated_tokens,
+               args.temperature, prompt, start_pos, pool, inner_pool,
+               stream_token, accept_token, gen, verbosity);
 }
 void Generate7B(GemmaImpl<ConfigGemma7B>& gemma, const InferenceArgs& args,
@@ -614,8 +617,9 @@ void Generate7B(GemmaImpl<ConfigGemma7B>& gemma, const InferenceArgs& args,
                 hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
                 const StreamFunc& stream_token, const AcceptFunc& accept_token,
                 std::mt19937& gen, int verbosity) {
-  GenerateImpl(gemma, args, prompt, start_pos, pool, inner_pool, stream_token,
-               accept_token, gen, verbosity);
+  GenerateImpl(gemma, args.max_tokens, args.max_generated_tokens,
+               args.temperature, prompt, start_pos, pool, inner_pool,
+               stream_token, accept_token, gen, verbosity);
 }
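Both wrappers now unpack InferenceArgs before reaching the implementation, so the inference core no longer depends on CLI argument types. Under the new GenerateImpl signature, a non-CLI caller could in principle drive the core with plain values; a minimal sketch (the numeric limits here are illustrative only, not defaults from the diff):

    GenerateImpl(gemma, /*max_tokens=*/2048, /*max_generated_tokens=*/1024,
                 /*temperature=*/1.0f, prompt, /*pos=*/0, pool, inner_pool,
                 stream_token, accept_token, gen, /*verbosity=*/1);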
 // Calls func(name, float*, CompressedArray&) for each tensor. float* is null
@@ -729,17 +733,22 @@ KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len) {
 }
 template <class Config>
-GemmaImpl<Config>::GemmaImpl(const LoaderArgs& args, hwy::ThreadPool& pool)
-    : compressed_weights(
-          HWY_DYNAMIC_DISPATCH(GetCompressedWeightsT)(args, pool)),
+GemmaImpl<Config>::GemmaImpl(
+    std::unique_ptr<sentencepiece::SentencePieceProcessor>& tokenizer,
+    hwy::AlignedFreeUniquePtr<uint8_t[]>& compressed_weights,
+    hwy::ThreadPool& pool)
+    // GemmaImpl<Config>::GemmaImpl(const LoaderArgs& args, hwy::ThreadPool&
+    // pool)
+    : compressed_weights(std::move(compressed_weights)),
+      // HWY_DYNAMIC_DISPATCH(GetCompressedWeightsT)(args, pool)),
       prefill(hwy::MakeUniqueAligned<Activations<Config, kPrefillBatchSize>>()),
       state(hwy::MakeUniqueAligned<Activations<Config, 1>>()),
       kv_cache(
           CreateKVCache(Config::kLayers * Config::kKVHeads * Config::kQKVDim,
-                        Config::kSeqLen)) {
-  PROFILER_ZONE("Startup.tokenizer");
-  HWY_ASSERT(tokenizer.Load(args.tokenizer.path).ok());
+                        Config::kSeqLen)),
+      tokenizer(std::move(tokenizer)) {
+  // PROFILER_ZONE("Startup.tokenizer");
+  // HWY_ASSERT(tokenizer.Load(args.tokenizer.path).ok());
 }
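Note the ownership handoff here: the constructor takes the tokenizer and the compressed-weights buffer by lvalue reference and std::move()s from both, leaving the caller's locals empty after construction. That is what lets Gemma::Gemma below create each object once and pass it to whichever GemmaImpl the switch selects.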
 template <>
@@ -770,12 +779,20 @@ void GemmaImpl<ConfigGemma7B>::Generate(const InferenceArgs& args,
 Gemma::Gemma(const LoaderArgs& args, hwy::ThreadPool& pool) {
   const Model model_type = args.ModelType();
   model_training = args.ModelTraining();
+  PROFILER_ZONE("Startup.tokenizer");
+  std::unique_ptr<sentencepiece::SentencePieceProcessor> tokenizer =
+      std::make_unique<sentencepiece::SentencePieceProcessor>();
+  HWY_ASSERT(tokenizer->Load(args.tokenizer.path).ok());
+  auto compressed_weights =
+      HWY_DYNAMIC_DISPATCH(GetCompressedWeightsT)(args, pool);
   switch (model_type) {
     case Model::GEMMA_2B:
-      impl_.reset(new GemmaImpl<ConfigGemma2B>(args, pool));
+      impl_.reset(
+          new GemmaImpl<ConfigGemma2B>(tokenizer, compressed_weights, pool));
       break;
     case Model::GEMMA_7B:
-      impl_.reset(new GemmaImpl<ConfigGemma7B>(args, pool));
+      impl_.reset(
+          new GemmaImpl<ConfigGemma7B>(tokenizer, compressed_weights, pool));
       break;
     default:
       HWY_ABORT("Model type %d unknown.", static_cast<int>(model_type));
@@ -783,7 +800,7 @@ Gemma::Gemma(const LoaderArgs& args, hwy::ThreadPool& pool) {
 }
 Gemma::~Gemma() = default;  // after GemmaInterface is defined
-const sentencepiece::SentencePieceProcessor& Gemma::Tokenizer() const {
+const sentencepiece::SentencePieceProcessor* Gemma::Tokenizer() const {
   return impl_->Tokenizer();
 }
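Taken together, construction now follows a single path: Gemma::Gemma loads the tokenizer and compressed weights once, then hands both to the config-specific GemmaImpl. A minimal sketch of the resulting call pattern (thread count and prompt text are illustrative):

    hwy::ThreadPool pool(/*num_threads=*/4);
    gcpp::LoaderArgs loader(argc, argv);  // parses the model/tokenizer/weights flags
    gcpp::Gemma model(loader, pool);      // loads tokenizer + weights internally

    // Tokenizer() now returns a const pointer rather than a reference.
    std::vector<int> prompt_tokens;
    HWY_ASSERT(
        model.Tokenizer()->Encode("Hello, how are you?", &prompt_tokens).ok());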

gemma.h (24 changed lines)

@@ -64,6 +64,15 @@ struct KVCache {
 enum class Model { GEMMA_2B, GEMMA_7B };
 enum class ModelTraining { GEMMA_IT, GEMMA_PT };
+// TODO: incorporate
+struct InferenceParams {
+  Model model;
+  ModelTraining model_training;
+  size_t max_generated_tokens;
+  size_t max_tokens;
+  float temperature;
+};
 struct LoaderArgs : public ArgsBase<LoaderArgs> {
   LoaderArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
@ -129,9 +138,9 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
"file if " "file if "
"the compressed weights file does not exist.\n Required argument."); "the compressed weights file does not exist.\n Required argument.");
visitor(model_type, "model", std::string(), visitor(model_type, "model", std::string(),
"Model type\n 2b-it (2B parameters, instruction-tuned)\n " "Model type\n 2b-it = 2B parameters, instruction-tuned\n "
"2b-pt (2B parameters, pretrained)\n 7b-it (7B parameters " "2b-pt = 2B parameters, pretrained\n 7b-it = 7B parameters "
"instruction-tuned)\n 7b-pt (7B parameters, pretrained)\n" "instruction-tuned\n 7b-pt = 7B parameters, pretrained\n"
" Required argument."); " Required argument.");
visitor(model, "weights", Path(), visitor(model, "weights", Path(),
"Path name of model weights (.sbs) file. Only required if " "Path name of model weights (.sbs) file. Only required if "
@@ -147,7 +156,10 @@ struct Gemma {
   Gemma(const LoaderArgs& args, hwy::ThreadPool& pool);
   ~Gemma();  // must be defined after GemmaInterface's dtor is defined.
-  const sentencepiece::SentencePieceProcessor& Tokenizer() const;
+  // TODO: cleanup
+  // const sentencepiece::SentencePieceProcessor& Tokenizer() const;
+  // const std::unique_ptr<sentencepiece::SentencePieceProcessor> Tokenizer() const;
+  const sentencepiece::SentencePieceProcessor* Tokenizer() const;
   std::unique_ptr<GemmaInterface> impl_;
   gcpp::ModelTraining model_training;
@@ -192,8 +204,8 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
     visitor(deterministic, "deterministic", false,
             "Make top-k sampling deterministic", 2);
     visitor(multiturn, "multiturn", false,
-            "Multiturn mode (if 0, this clears the KV cache after every "
-            "interaction without quitting)\n    Default : 0 (conversation "
+            "Multiturn mode\n    0 = clear KV cache after every "
+            "interaction\n    1 = continue KV cache after every interaction\n    Default : 0 (conversation "
             "resets every turn)");
   }
 };
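The InferenceParams struct added above is marked "TODO: incorporate" and is not used anywhere else in this diff. One way it might eventually be populated, assuming InferenceArgs is constructed the same way as LoaderArgs and keeps its max_tokens, max_generated_tokens, and temperature members:

    gcpp::LoaderArgs loader(argc, argv);
    gcpp::InferenceArgs inference(argc, argv);  // assumed ArgsBase-style ctor

    gcpp::InferenceParams params;
    params.model = loader.ModelType();
    params.model_training = loader.ModelTraining();
    params.max_generated_tokens = inference.max_generated_tokens;
    params.max_tokens = inference.max_tokens;
    params.temperature = inference.temperature;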

run.cc (8 changed lines)

@@ -115,7 +115,7 @@ void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool,
   // callback function invoked for each generated token.
   auto stream_token = [&abs_pos, &current_pos, &args, &gen, &prompt_size,
-                       tokenizer = &model.Tokenizer(),
+                       tokenizer = model.Tokenizer(),
                        verbosity](int token, float) {
     ++abs_pos;
     ++current_pos;
@@ -129,7 +129,7 @@ void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool,
         }
       }
       if (verbosity >= 2) {
-        std::cout << "\n[ End ]" << std::endl;
+        std::cout << "\n[ End ]\n";
       }
     } else {
       std::string token_text;
@@ -142,7 +142,6 @@ void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool,
           std::cout << std::endl << std::endl;
         }
       }
-      // TODO(austinvhuang): is explicit space necessary?
       std::cout << token_text << std::flush;
     }
     return true;
@@ -191,7 +190,8 @@ void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool,
       }
     }
-    HWY_ASSERT(model.Tokenizer().Encode(prompt_string, &prompt).ok());
+    // HWY_ASSERT(model.Tokenizer().Encode(prompt_string, &prompt).ok());
+    HWY_ASSERT(model.Tokenizer()->Encode(prompt_string, &prompt).ok());
     // For both pre-trained and instruction-tuned models: prepend "<bos>" token
     // if needed.