mirror of https://github.com/google/gemma.cpp.git
[WIP] decouple GemmaImpl from CLI args
This commit is contained in:
parent
c378ac2c56
commit
10f7a086aa
|
|
@ -0,0 +1,2 @@
|
||||||
|
*
|
||||||
|
!.gitignore
|
||||||
|
|
@ -12,6 +12,9 @@
|
||||||
// copybara:import_next_line:gemma_cpp
|
// copybara:import_next_line:gemma_cpp
|
||||||
#include "util/args.h" // HasHelp
|
#include "util/args.h" // HasHelp
|
||||||
// copybara:end
|
// copybara:end
|
||||||
|
// copybara:import_next_line:gemma_cpp
|
||||||
|
#include "configs.h"
|
||||||
|
// copybara:end
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
#include "hwy/highway.h"
|
#include "hwy/highway.h"
|
||||||
|
|
@ -35,17 +38,13 @@ int main(int argc, char** argv) {
|
||||||
hwy::ThreadPool pool(app.num_threads);
|
hwy::ThreadPool pool(app.num_threads);
|
||||||
hwy::ThreadPool inner_pool(0);
|
hwy::ThreadPool inner_pool(0);
|
||||||
gcpp::Gemma model(loader, pool);
|
gcpp::Gemma model(loader, pool);
|
||||||
|
|
||||||
std::vector<int> tokens = tokenize("Hello, how are you?", model.Tokenizer());
|
|
||||||
|
|
||||||
std::mt19937 gen;
|
std::mt19937 gen;
|
||||||
std::random_device rd;
|
std::random_device rd;
|
||||||
gen.seed(rd());
|
gen.seed(rd());
|
||||||
|
|
||||||
|
std::vector<int> tokens = tokenize("Hello, how are you?", model.Tokenizer());
|
||||||
size_t ntokens = tokens.size();
|
size_t ntokens = tokens.size();
|
||||||
|
|
||||||
size_t pos = 0;
|
size_t pos = 0;
|
||||||
|
|
||||||
auto stream_token = [&pos, &gen, &ntokens, tokenizer = &model.Tokenizer()](int token, float) {
|
auto stream_token = [&pos, &gen, &ntokens, tokenizer = &model.Tokenizer()](int token, float) {
|
||||||
++pos;
|
++pos;
|
||||||
if (pos < ntokens) {
|
if (pos < ntokens) {
|
||||||
|
|
|
||||||
77
gemma.cc
77
gemma.cc
|
|
@ -19,18 +19,18 @@
|
||||||
// which we pass the filename via macro 'argument'.
|
// which we pass the filename via macro 'argument'.
|
||||||
#undef HWY_TARGET_INCLUDE
|
#undef HWY_TARGET_INCLUDE
|
||||||
#define HWY_TARGET_INCLUDE "gemma.cc" // NOLINT
|
#define HWY_TARGET_INCLUDE "gemma.cc" // NOLINT
|
||||||
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
||||||
// Must come after foreach_target.h to avoid redefinition errors.
|
// Must come after foreach_target.h to avoid redefinition errors.
|
||||||
// copybara:import_next_line:gemma_cpp
|
// copybara:import_next_line:gemma_cpp
|
||||||
#include "compression/compress-inl.h"
|
#include "compression/compress-inl.h"
|
||||||
// copybara:import_next_line:gemma_cpp
|
// copybara:import_next_line:gemma_cpp
|
||||||
#include "ops.h"
|
#include "ops.h"
|
||||||
// copybara:import_next_line:gemma_cpp
|
// copybara:import_next_line:gemma_cpp
|
||||||
#include "util/args.h" // Path
|
|
||||||
#include "hwy/contrib/matvec/matvec-inl.h"
|
#include "hwy/contrib/matvec/matvec-inl.h"
|
||||||
#include "hwy/highway.h"
|
#include "hwy/highway.h"
|
||||||
#include "hwy/profiler.h"
|
#include "hwy/profiler.h"
|
||||||
#include "hwy/timer.h"
|
#include "hwy/timer.h"
|
||||||
|
#include "util/args.h" // Path
|
||||||
|
|
||||||
// Non-SIMD includes and types. Note that HWY_ONCE is only true on the last
|
// Non-SIMD includes and types. Note that HWY_ONCE is only true on the last
|
||||||
// compile pass, whereas we want this defined in the first.
|
// compile pass, whereas we want this defined in the first.
|
||||||
|
|
@ -231,9 +231,8 @@ struct Activations {
|
||||||
struct GemmaInterface {
|
struct GemmaInterface {
|
||||||
virtual ~GemmaInterface() = default;
|
virtual ~GemmaInterface() = default;
|
||||||
|
|
||||||
virtual const sentencepiece::SentencePieceProcessor& Tokenizer() const = 0;
|
virtual const sentencepiece::SentencePieceProcessor* Tokenizer() const = 0;
|
||||||
|
|
||||||
// TODO: group pool/callbacks into struct
|
|
||||||
virtual void Generate(const InferenceArgs& args,
|
virtual void Generate(const InferenceArgs& args,
|
||||||
const std::vector<int>& prompt, size_t start_pos,
|
const std::vector<int>& prompt, size_t start_pos,
|
||||||
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
|
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
|
||||||
|
|
@ -244,7 +243,10 @@ struct GemmaInterface {
|
||||||
|
|
||||||
template <class Config>
|
template <class Config>
|
||||||
struct GemmaImpl : public GemmaInterface {
|
struct GemmaImpl : public GemmaInterface {
|
||||||
GemmaImpl(const LoaderArgs& args, hwy::ThreadPool& pool);
|
GemmaImpl( // const LoaderArgs& args,
|
||||||
|
std::unique_ptr<sentencepiece::SentencePieceProcessor>& tokenizer,
|
||||||
|
hwy::AlignedFreeUniquePtr<uint8_t[]>& compressed_weights,
|
||||||
|
hwy::ThreadPool& pool);
|
||||||
|
|
||||||
~GemmaImpl() {
|
~GemmaImpl() {
|
||||||
using CWeights = CompressedWeights<Config>;
|
using CWeights = CompressedWeights<Config>;
|
||||||
|
|
@ -252,8 +254,8 @@ struct GemmaImpl : public GemmaInterface {
|
||||||
c_weights->c_layer_ptrs.~CompressedLayerPointers<Config>();
|
c_weights->c_layer_ptrs.~CompressedLayerPointers<Config>();
|
||||||
}
|
}
|
||||||
|
|
||||||
const sentencepiece::SentencePieceProcessor& Tokenizer() const {
|
const sentencepiece::SentencePieceProcessor* Tokenizer() const {
|
||||||
return tokenizer;
|
return tokenizer.get();
|
||||||
}
|
}
|
||||||
|
|
||||||
void Generate(const InferenceArgs& args, const std::vector<int>& prompt,
|
void Generate(const InferenceArgs& args, const std::vector<int>& prompt,
|
||||||
|
|
@ -261,9 +263,8 @@ struct GemmaImpl : public GemmaInterface {
|
||||||
hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
|
hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
|
||||||
const AcceptFunc& accept_token, std::mt19937&, int verbosity);
|
const AcceptFunc& accept_token, std::mt19937&, int verbosity);
|
||||||
|
|
||||||
sentencepiece::SentencePieceProcessor tokenizer;
|
std::unique_ptr<sentencepiece::SentencePieceProcessor> tokenizer;
|
||||||
|
|
||||||
// CompressedWeights<Config>
|
|
||||||
hwy::AlignedFreeUniquePtr<uint8_t[]> compressed_weights;
|
hwy::AlignedFreeUniquePtr<uint8_t[]> compressed_weights;
|
||||||
hwy::AlignedUniquePtr<Activations<Config, kPrefillBatchSize>> prefill;
|
hwy::AlignedUniquePtr<Activations<Config, kPrefillBatchSize>> prefill;
|
||||||
hwy::AlignedUniquePtr<Activations<Config, 1>> state;
|
hwy::AlignedUniquePtr<Activations<Config, 1>> state;
|
||||||
|
|
@ -495,7 +496,8 @@ void Transformer(int token, size_t pos,
|
||||||
}
|
}
|
||||||
|
|
||||||
template <class TConfig>
|
template <class TConfig>
|
||||||
void GenerateImpl(GemmaImpl<TConfig>& gemma, const InferenceArgs& args,
|
void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
|
||||||
|
size_t max_generated_tokens, float temperature,
|
||||||
const std::vector<int>& prompt, size_t pos,
|
const std::vector<int>& prompt, size_t pos,
|
||||||
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
|
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
|
||||||
const StreamFunc& stream_token,
|
const StreamFunc& stream_token,
|
||||||
|
|
@ -549,7 +551,7 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, const InferenceArgs& args,
|
||||||
// should be available as observable state for frontend code to handle I/O.
|
// should be available as observable state for frontend code to handle I/O.
|
||||||
double prefill_end = hwy::platform::Now();
|
double prefill_end = hwy::platform::Now();
|
||||||
const double prefill_tok_sec = pos_offset / (prefill_end - prefill_start);
|
const double prefill_tok_sec = pos_offset / (prefill_end - prefill_start);
|
||||||
std::cout << "\n[ Prefill tokens / sec = " << prefill_tok_sec << " ]\n";
|
std::cout << "\n[ Prefill tokens / sec = " << prefill_tok_sec << " ]";
|
||||||
}
|
}
|
||||||
|
|
||||||
double gen_start = hwy::platform::Now();
|
double gen_start = hwy::platform::Now();
|
||||||
|
|
@ -558,10 +560,10 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, const InferenceArgs& args,
|
||||||
|
|
||||||
if (verbosity >= 2) {
|
if (verbosity >= 2) {
|
||||||
// Provide usage warnings if max_new_tokens is out of range.
|
// Provide usage warnings if max_new_tokens is out of range.
|
||||||
if (args.max_generated_tokens > args.max_tokens) {
|
if (max_generated_tokens > max_tokens) {
|
||||||
std::cout << "Warning: max_new_tokens should be <= max_tokens"
|
std::cout << "Warning: max_new_tokens should be <= max_tokens"
|
||||||
<< std::endl;
|
<< std::endl;
|
||||||
} else if ((prompt.size() + args.max_generated_tokens) > args.max_tokens) {
|
} else if ((prompt.size() + max_generated_tokens) > max_tokens) {
|
||||||
std::cout << "Warning: Prompt size + max_new_tokens exceeds max_tokens."
|
std::cout << "Warning: Prompt size + max_new_tokens exceeds max_tokens."
|
||||||
<< std::endl;
|
<< std::endl;
|
||||||
}
|
}
|
||||||
|
|
@ -570,7 +572,7 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, const InferenceArgs& args,
|
||||||
auto pos_gen_start = pos_offset;
|
auto pos_gen_start = pos_offset;
|
||||||
token = prompt.at(pos_offset);
|
token = prompt.at(pos_offset);
|
||||||
size_t generate_pos = 0;
|
size_t generate_pos = 0;
|
||||||
for (; pos < args.max_tokens && generate_pos < args.max_generated_tokens;
|
for (; pos < max_tokens && generate_pos < max_generated_tokens;
|
||||||
++pos, ++pos_offset, ++generate_pos) {
|
++pos, ++pos_offset, ++generate_pos) {
|
||||||
Transformer(token, pos, c_weights, activations, kv_cache, pool, inner_pool);
|
Transformer(token, pos, c_weights, activations, kv_cache, pool, inner_pool);
|
||||||
float* final_activation = activations.x.data();
|
float* final_activation = activations.x.data();
|
||||||
|
|
@ -583,7 +585,7 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, const InferenceArgs& args,
|
||||||
// Barrier: must have all logits so we can subtract max.
|
// Barrier: must have all logits so we can subtract max.
|
||||||
Softmax(activations.logits.data(), kVocabSize);
|
Softmax(activations.logits.data(), kVocabSize);
|
||||||
token = SampleTopK<kTopK>(activations.logits.data(), kVocabSize, gen,
|
token = SampleTopK<kTopK>(activations.logits.data(), kVocabSize, gen,
|
||||||
args.temperature, accept_token);
|
temperature, accept_token);
|
||||||
}
|
}
|
||||||
if (!stream_token(token, activations.logits[token])) {
|
if (!stream_token(token, activations.logits[token])) {
|
||||||
token = EOS_ID;
|
token = EOS_ID;
|
||||||
|
|
@ -593,7 +595,7 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, const InferenceArgs& args,
|
||||||
double gen_end = hwy::platform::Now();
|
double gen_end = hwy::platform::Now();
|
||||||
const double gen_tok_sec =
|
const double gen_tok_sec =
|
||||||
(pos_offset - pos_gen_start) / (gen_end - gen_start);
|
(pos_offset - pos_gen_start) / (gen_end - gen_start);
|
||||||
std::cout << "\n[ Generation tokens / sec = " << gen_tok_sec << " ]\n";
|
std::cout << "[ Generation tokens / sec = " << gen_tok_sec << " ]\n";
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
@ -605,8 +607,9 @@ void Generate2B(GemmaImpl<ConfigGemma2B>& gemma, const InferenceArgs& args,
|
||||||
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
|
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
|
||||||
const StreamFunc& stream_token, const AcceptFunc& accept_token,
|
const StreamFunc& stream_token, const AcceptFunc& accept_token,
|
||||||
std::mt19937& gen, int verbosity) {
|
std::mt19937& gen, int verbosity) {
|
||||||
GenerateImpl(gemma, args, prompt, start_pos, pool, inner_pool, stream_token,
|
GenerateImpl(gemma, args.max_tokens, args.max_generated_tokens,
|
||||||
accept_token, gen, verbosity);
|
args.temperature, prompt, start_pos, pool, inner_pool,
|
||||||
|
stream_token, accept_token, gen, verbosity);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Generate7B(GemmaImpl<ConfigGemma7B>& gemma, const InferenceArgs& args,
|
void Generate7B(GemmaImpl<ConfigGemma7B>& gemma, const InferenceArgs& args,
|
||||||
|
|
@ -614,8 +617,9 @@ void Generate7B(GemmaImpl<ConfigGemma7B>& gemma, const InferenceArgs& args,
|
||||||
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
|
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
|
||||||
const StreamFunc& stream_token, const AcceptFunc& accept_token,
|
const StreamFunc& stream_token, const AcceptFunc& accept_token,
|
||||||
std::mt19937& gen, int verbosity) {
|
std::mt19937& gen, int verbosity) {
|
||||||
GenerateImpl(gemma, args, prompt, start_pos, pool, inner_pool, stream_token,
|
GenerateImpl(gemma, args.max_tokens, args.max_generated_tokens,
|
||||||
accept_token, gen, verbosity);
|
args.temperature, prompt, start_pos, pool, inner_pool,
|
||||||
|
stream_token, accept_token, gen, verbosity);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Calls func(name, float*, CompressedArray&) for each tensor. float* is null
|
// Calls func(name, float*, CompressedArray&) for each tensor. float* is null
|
||||||
|
|
@ -729,17 +733,22 @@ KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len) {
|
||||||
}
|
}
|
||||||
|
|
||||||
template <class Config>
|
template <class Config>
|
||||||
GemmaImpl<Config>::GemmaImpl(const LoaderArgs& args, hwy::ThreadPool& pool)
|
GemmaImpl<Config>::GemmaImpl(
|
||||||
: compressed_weights(
|
std::unique_ptr<sentencepiece::SentencePieceProcessor>& tokenizer,
|
||||||
HWY_DYNAMIC_DISPATCH(GetCompressedWeightsT)(args, pool)),
|
hwy::AlignedFreeUniquePtr<uint8_t[]>& compressed_weights,
|
||||||
|
hwy::ThreadPool& pool)
|
||||||
|
// GemmaImpl<Config>::GemmaImpl(const LoaderArgs& args, hwy::ThreadPool&
|
||||||
|
// pool)
|
||||||
|
: compressed_weights(std::move(compressed_weights)),
|
||||||
|
// HWY_DYNAMIC_DISPATCH(GetCompressedWeightsT)(args, pool)),
|
||||||
prefill(hwy::MakeUniqueAligned<Activations<Config, kPrefillBatchSize>>()),
|
prefill(hwy::MakeUniqueAligned<Activations<Config, kPrefillBatchSize>>()),
|
||||||
state(hwy::MakeUniqueAligned<Activations<Config, 1>>()),
|
state(hwy::MakeUniqueAligned<Activations<Config, 1>>()),
|
||||||
kv_cache(
|
kv_cache(
|
||||||
CreateKVCache(Config::kLayers * Config::kKVHeads * Config::kQKVDim,
|
CreateKVCache(Config::kLayers * Config::kKVHeads * Config::kQKVDim,
|
||||||
Config::kSeqLen)) {
|
Config::kSeqLen)),
|
||||||
PROFILER_ZONE("Startup.tokenizer");
|
tokenizer(std::move(tokenizer)) {
|
||||||
|
// PROFILER_ZONE("Startup.tokenizer");
|
||||||
HWY_ASSERT(tokenizer.Load(args.tokenizer.path).ok());
|
// HWY_ASSERT(tokenizer.Load(args.tokenizer.path).ok());
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
|
|
@ -770,12 +779,20 @@ void GemmaImpl<ConfigGemma7B>::Generate(const InferenceArgs& args,
|
||||||
Gemma::Gemma(const LoaderArgs& args, hwy::ThreadPool& pool) {
|
Gemma::Gemma(const LoaderArgs& args, hwy::ThreadPool& pool) {
|
||||||
const Model model_type = args.ModelType();
|
const Model model_type = args.ModelType();
|
||||||
model_training = args.ModelTraining();
|
model_training = args.ModelTraining();
|
||||||
|
PROFILER_ZONE("Startup.tokenizer");
|
||||||
|
std::unique_ptr<sentencepiece::SentencePieceProcessor> tokenizer =
|
||||||
|
std::make_unique<sentencepiece::SentencePieceProcessor>();
|
||||||
|
HWY_ASSERT(tokenizer->Load(args.tokenizer.path).ok());
|
||||||
|
auto compressed_weights =
|
||||||
|
HWY_DYNAMIC_DISPATCH(GetCompressedWeightsT)(args, pool);
|
||||||
switch (model_type) {
|
switch (model_type) {
|
||||||
case Model::GEMMA_2B:
|
case Model::GEMMA_2B:
|
||||||
impl_.reset(new GemmaImpl<ConfigGemma2B>(args, pool));
|
impl_.reset(
|
||||||
|
new GemmaImpl<ConfigGemma2B>(tokenizer, compressed_weights, pool));
|
||||||
break;
|
break;
|
||||||
case Model::GEMMA_7B:
|
case Model::GEMMA_7B:
|
||||||
impl_.reset(new GemmaImpl<ConfigGemma7B>(args, pool));
|
impl_.reset(
|
||||||
|
new GemmaImpl<ConfigGemma7B>(tokenizer, compressed_weights, pool));
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
HWY_ABORT("Model type %d unknown.", static_cast<int>(model_type));
|
HWY_ABORT("Model type %d unknown.", static_cast<int>(model_type));
|
||||||
|
|
@ -783,7 +800,7 @@ Gemma::Gemma(const LoaderArgs& args, hwy::ThreadPool& pool) {
|
||||||
}
|
}
|
||||||
Gemma::~Gemma() = default; // after GemmaInterface is defined
|
Gemma::~Gemma() = default; // after GemmaInterface is defined
|
||||||
|
|
||||||
const sentencepiece::SentencePieceProcessor& Gemma::Tokenizer() const {
|
const sentencepiece::SentencePieceProcessor* Gemma::Tokenizer() const {
|
||||||
return impl_->Tokenizer();
|
return impl_->Tokenizer();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
24
gemma.h
24
gemma.h
|
|
@ -64,6 +64,15 @@ struct KVCache {
|
||||||
enum class Model { GEMMA_2B, GEMMA_7B };
|
enum class Model { GEMMA_2B, GEMMA_7B };
|
||||||
enum class ModelTraining { GEMMA_IT, GEMMA_PT };
|
enum class ModelTraining { GEMMA_IT, GEMMA_PT };
|
||||||
|
|
||||||
|
// TODO: incorporate
|
||||||
|
struct InferenceParams {
|
||||||
|
Model model;
|
||||||
|
ModelTraining model_training;
|
||||||
|
size_t max_generated_tokens;
|
||||||
|
size_t max_tokens;
|
||||||
|
float temperature;
|
||||||
|
};
|
||||||
|
|
||||||
struct LoaderArgs : public ArgsBase<LoaderArgs> {
|
struct LoaderArgs : public ArgsBase<LoaderArgs> {
|
||||||
LoaderArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
|
LoaderArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
|
||||||
|
|
||||||
|
|
@ -129,9 +138,9 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
|
||||||
"file if "
|
"file if "
|
||||||
"the compressed weights file does not exist.\n Required argument.");
|
"the compressed weights file does not exist.\n Required argument.");
|
||||||
visitor(model_type, "model", std::string(),
|
visitor(model_type, "model", std::string(),
|
||||||
"Model type\n 2b-it (2B parameters, instruction-tuned)\n "
|
"Model type\n 2b-it = 2B parameters, instruction-tuned\n "
|
||||||
"2b-pt (2B parameters, pretrained)\n 7b-it (7B parameters "
|
"2b-pt = 2B parameters, pretrained\n 7b-it = 7B parameters "
|
||||||
"instruction-tuned)\n 7b-pt (7B parameters, pretrained)\n"
|
"instruction-tuned\n 7b-pt = 7B parameters, pretrained\n"
|
||||||
" Required argument.");
|
" Required argument.");
|
||||||
visitor(model, "weights", Path(),
|
visitor(model, "weights", Path(),
|
||||||
"Path name of model weights (.sbs) file. Only required if "
|
"Path name of model weights (.sbs) file. Only required if "
|
||||||
|
|
@ -147,7 +156,10 @@ struct Gemma {
|
||||||
Gemma(const LoaderArgs& args, hwy::ThreadPool& pool);
|
Gemma(const LoaderArgs& args, hwy::ThreadPool& pool);
|
||||||
~Gemma(); // must be defined after GemmaInterface's dtor is defined.
|
~Gemma(); // must be defined after GemmaInterface's dtor is defined.
|
||||||
|
|
||||||
const sentencepiece::SentencePieceProcessor& Tokenizer() const;
|
// TODO: cleanup
|
||||||
|
// const sentencepiece::SentencePieceProcessor& Tokenizer() const;
|
||||||
|
// const std::unique_ptr<sentencepiece::SentencePieceProcessor> Tokenizer() const;
|
||||||
|
const sentencepiece::SentencePieceProcessor* Tokenizer() const;
|
||||||
|
|
||||||
std::unique_ptr<GemmaInterface> impl_;
|
std::unique_ptr<GemmaInterface> impl_;
|
||||||
gcpp::ModelTraining model_training;
|
gcpp::ModelTraining model_training;
|
||||||
|
|
@ -192,8 +204,8 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
|
||||||
visitor(deterministic, "deterministic", false,
|
visitor(deterministic, "deterministic", false,
|
||||||
"Make top-k sampling deterministic", 2);
|
"Make top-k sampling deterministic", 2);
|
||||||
visitor(multiturn, "multiturn", false,
|
visitor(multiturn, "multiturn", false,
|
||||||
"Multiturn mode (if 0, this clears the KV cache after every "
|
"Multiturn mode\n 0 = clear KV cache after every "
|
||||||
"interaction without quitting)\n Default : 0 (conversation "
|
"interaction\n 1 = continue KV cache after every interaction\n Default : 0 (conversation "
|
||||||
"resets every turn)");
|
"resets every turn)");
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
||||||
8
run.cc
8
run.cc
|
|
@ -115,7 +115,7 @@ void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool,
|
||||||
|
|
||||||
// callback function invoked for each generated token.
|
// callback function invoked for each generated token.
|
||||||
auto stream_token = [&abs_pos, &current_pos, &args, &gen, &prompt_size,
|
auto stream_token = [&abs_pos, &current_pos, &args, &gen, &prompt_size,
|
||||||
tokenizer = &model.Tokenizer(),
|
tokenizer = model.Tokenizer(),
|
||||||
verbosity](int token, float) {
|
verbosity](int token, float) {
|
||||||
++abs_pos;
|
++abs_pos;
|
||||||
++current_pos;
|
++current_pos;
|
||||||
|
|
@ -129,7 +129,7 @@ void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (verbosity >= 2) {
|
if (verbosity >= 2) {
|
||||||
std::cout << "\n[ End ]" << std::endl;
|
std::cout << "\n[ End ]\n";
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
std::string token_text;
|
std::string token_text;
|
||||||
|
|
@ -142,7 +142,6 @@ void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool,
|
||||||
std::cout << std::endl << std::endl;
|
std::cout << std::endl << std::endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// TODO(austinvhuang): is explicit space necessary?
|
|
||||||
std::cout << token_text << std::flush;
|
std::cout << token_text << std::flush;
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
|
|
@ -191,7 +190,8 @@ void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
HWY_ASSERT(model.Tokenizer().Encode(prompt_string, &prompt).ok());
|
// HWY_ASSERT(model.Tokenizer().Encode(prompt_string, &prompt).ok());
|
||||||
|
HWY_ASSERT(model.Tokenizer()->Encode(prompt_string, &prompt).ok());
|
||||||
|
|
||||||
// For both pre-trained and instruction-tuned models: prepend "<bos>" token
|
// For both pre-trained and instruction-tuned models: prepend "<bos>" token
|
||||||
// if needed.
|
// if needed.
|
||||||
|
|
|
||||||
|
|
@ -79,9 +79,9 @@ class AppArgs : public ArgsBase<AppArgs> {
|
||||||
template <class Visitor>
|
template <class Visitor>
|
||||||
void ForEach(const Visitor& visitor) {
|
void ForEach(const Visitor& visitor) {
|
||||||
visitor(verbosity, "verbosity", 1,
|
visitor(verbosity, "verbosity", 1,
|
||||||
"Show verbose developer information\n 0 = only print generation "
|
"Show verbose developer information\n 0 = only print generation "
|
||||||
"output\n 1 = standard user-facing terminal ui\n 2 = show "
|
"output\n 1 = standard user-facing terminal ui\n 2 = show "
|
||||||
"developer/debug info).\n Default = 1.",
|
"developer/debug info).\n Default = 1.",
|
||||||
2);
|
2);
|
||||||
visitor(num_threads, "num_threads",
|
visitor(num_threads, "num_threads",
|
||||||
kDefaultNumThreads, // see ChooseNumThreads
|
kDefaultNumThreads, // see ChooseNumThreads
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue