[WIP] update GemmaInterface, Gemma, and Generate input parameter specs to remove InferenceArgs. TODO: update hello_world example after git commit hash is available for fetching

austinvhuang 2024-03-06 22:22:59 -05:00
parent 0f6a4b49d5
commit 7042316013
4 changed files with 53 additions and 50 deletions
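In practice this change means callers pass max_tokens, max_generated_tokens, and temperature to GenerateGemma directly instead of threading an InferenceArgs struct through the API. Below is a minimal sketch of the new call shape, mirroring the hello_world and run.cc hunks further down; everything other than GenerateGemma here is an assumed placeholder (model, prompt, pool, inner_pool), and the callback lambdas are written against the StreamFunc/AcceptFunc aliases declared in gemma.h:

    // Sketch only: assumes a loaded `model`, a tokenized `prompt`, and thread
    // pools set up as in the existing examples; values match the hello_world hunk.
    std::mt19937 gen(42);  // RNG used for sampling
    GenerateGemma(model,
                  /*max_tokens=*/2048, /*max_generated_tokens=*/1024,
                  /*temperature=*/1.0f,
                  prompt, /*start_pos=*/0, pool, inner_pool,
                  /*stream_token=*/[](int /*token*/, float /*prob*/) { return true; },
                  /*accept_token=*/[](int) { return true; },
                  gen, /*verbosity=*/0);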


@@ -66,7 +66,7 @@ int main(int argc, char** argv) {
inference.multiturn = false;
GenerateGemma(
- model, inference, tokens, 0, pool, inner_pool, stream_token,
+ model, /*max_tokens=*/2048, /*max_generated_tokens=*/1024, /*temperature=*/1.0, tokens, 0, pool, inner_pool, stream_token,
[](int) {return true;}, gen, 0);
std::cout << std::endl;


@@ -233,9 +233,10 @@ struct GemmaInterface {
virtual const sentencepiece::SentencePieceProcessor* Tokenizer() const = 0;
- virtual void Generate(const InferenceArgs& args,
- const std::vector<int>& prompt, size_t start_pos,
- hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
+ virtual void Generate(size_t max_tokens, size_t max_generated_tokens,
+ float temperature, const std::vector<int>& prompt,
+ size_t start_pos, hwy::ThreadPool& pool,
+ hwy::ThreadPool& inner_pool,
const StreamFunc& stream_token,
const AcceptFunc& accept_token, std::mt19937& gen,
int verbosity) = 0;
@@ -258,7 +259,8 @@ struct GemmaImpl : public GemmaInterface {
return tokenizer.get();
}
- void Generate(const InferenceArgs& args, const std::vector<int>& prompt,
+ void Generate(size_t max_tokens, size_t max_generated_tokens,
+ float temperature, const std::vector<int>& prompt,
size_t start_pos, hwy::ThreadPool& pool,
hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
const AcceptFunc& accept_token, std::mt19937&, int verbosity);
@@ -295,7 +297,8 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
static constexpr size_t kModelDim =
gcpp::Activations<TConfig, kBatchSize>::kModelDim;
static constexpr size_t kHeads = TConfig::kHeads;
- static const float kQueryScale = static_cast<float>(1.0 / sqrt(static_cast<double>(kQKVDim)));
+ static const float kQueryScale =
+ static_cast<float>(1.0 / sqrt(static_cast<double>(kQKVDim)));
pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
// linear projections to QKV
@@ -418,7 +421,8 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
hwy::ThreadPool& inner_pool) {
PROFILER_ZONE("Gen.Prefill\\Att\\FFW");
static constexpr size_t kModelDim = TConfig::kModelDim;
- static const float kEmbScaling = static_cast<float>(sqrt(static_cast<double>(kModelDim)));
+ static const float kEmbScaling =
+ static_cast<float>(sqrt(static_cast<double>(kModelDim)));
pool.Run(
0, num_tokens, [&](const uint64_t token_idx, size_t /*thread*/) HWY_ATTR {
@@ -473,7 +477,8 @@ void Transformer(int token, size_t pos,
static constexpr size_t kLayers = TConfig::kLayers;
static constexpr size_t kModelDim = TConfig::kModelDim;
- static const float kEmbScaling = static_cast<float>(sqrt(static_cast<double>(kModelDim)));
+ static const float kEmbScaling =
+ static_cast<float>(sqrt(static_cast<double>(kModelDim)));
Decompress(c_weights.c_embedder_input_embedding, token * kModelDim,
activations.x.data(), kModelDim);
@@ -604,24 +609,26 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
}
}
- void Generate2B(GemmaImpl<ConfigGemma2B>& gemma, const InferenceArgs& args,
+ void Generate2B(GemmaImpl<ConfigGemma2B>& gemma, size_t max_tokens,
+ size_t max_generated_tokens, float temperature,
const std::vector<int>& prompt, size_t start_pos,
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
const StreamFunc& stream_token, const AcceptFunc& accept_token,
std::mt19937& gen, int verbosity) {
- GenerateImpl(gemma, args.max_tokens, args.max_generated_tokens,
- args.temperature, prompt, start_pos, pool, inner_pool,
- stream_token, accept_token, gen, verbosity);
+ GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
+ start_pos, pool, inner_pool, stream_token, accept_token, gen,
+ verbosity);
}
- void Generate7B(GemmaImpl<ConfigGemma7B>& gemma, const InferenceArgs& args,
+ void Generate7B(GemmaImpl<ConfigGemma7B>& gemma, size_t max_tokens,
+ size_t max_generated_tokens, float temperature,
const std::vector<int>& prompt, size_t start_pos,
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
const StreamFunc& stream_token, const AcceptFunc& accept_token,
std::mt19937& gen, int verbosity) {
- GenerateImpl(gemma, args.max_tokens, args.max_generated_tokens,
- args.temperature, prompt, start_pos, pool, inner_pool,
- stream_token, accept_token, gen, verbosity);
+ GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
+ start_pos, pool, inner_pool, stream_token, accept_token, gen,
+ verbosity);
}
// Calls func(name, float*, CompressedArray&) for each tensor. float* is null
@@ -755,28 +762,24 @@ GemmaImpl<Config>::GemmaImpl(
}
template <>
- void GemmaImpl<ConfigGemma2B>::Generate(const InferenceArgs& args,
- const std::vector<int>& prompt,
- size_t start_pos, hwy::ThreadPool& pool,
- hwy::ThreadPool& inner_pool,
- const StreamFunc& stream_token,
- const AcceptFunc& accept_token,
- std::mt19937& gen, int verbosity) {
+ void GemmaImpl<ConfigGemma2B>::Generate(
+ size_t max_tokens, size_t max_generated_tokens, float temperature,
+ const std::vector<int>& prompt, size_t start_pos, hwy::ThreadPool& pool,
+ hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
+ const AcceptFunc& accept_token, std::mt19937& gen, int verbosity) {
HWY_DYNAMIC_DISPATCH(Generate2B)
- (*this, args, prompt, start_pos, pool, inner_pool, stream_token, accept_token,
- gen, verbosity);
+ (*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
+ pool, inner_pool, stream_token, accept_token, gen, verbosity);
}
template <>
- void GemmaImpl<ConfigGemma7B>::Generate(const InferenceArgs& args,
- const std::vector<int>& prompt,
- size_t start_pos, hwy::ThreadPool& pool,
- hwy::ThreadPool& inner_pool,
- const StreamFunc& stream_token,
- const AcceptFunc& accept_token,
- std::mt19937& gen, int verbosity) {
+ void GemmaImpl<ConfigGemma7B>::Generate(
+ size_t max_tokens, size_t max_generated_tokens, float temperature,
+ const std::vector<int>& prompt, size_t start_pos, hwy::ThreadPool& pool,
+ hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
+ const AcceptFunc& accept_token, std::mt19937& gen, int verbosity) {
HWY_DYNAMIC_DISPATCH(Generate7B)
- (*this, args, prompt, start_pos, pool, inner_pool, stream_token, accept_token,
- gen, verbosity);
+ (*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
+ pool, inner_pool, stream_token, accept_token, gen, verbosity);
}
Gemma::Gemma(const LoaderArgs& args, hwy::ThreadPool& pool) {
@@ -807,15 +810,16 @@ const sentencepiece::SentencePieceProcessor* Gemma::Tokenizer() const {
return impl_->Tokenizer();
}
- void GenerateGemma(Gemma& gemma, const InferenceArgs& args,
- const std::vector<int>& prompt, size_t start_pos,
- hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
- const StreamFunc& stream_token,
+ void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
+ float temperature, const std::vector<int>& prompt,
+ size_t start_pos, hwy::ThreadPool& pool,
+ hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
const AcceptFunc& accept_token, std::mt19937& gen,
int verbosity) {
pool.SetWaitMode(hwy::PoolWaitMode::kSpin);
- gemma.impl_->Generate(args, prompt, start_pos, pool, inner_pool, stream_token,
- accept_token, gen, verbosity);
+ gemma.impl_->Generate(max_tokens, max_generated_tokens,
+ temperature, prompt, start_pos, pool, inner_pool,
+ stream_token, accept_token, gen, verbosity);
pool.SetWaitMode(hwy::PoolWaitMode::kBlock);
}

gemma.h

@@ -156,9 +156,6 @@ struct Gemma {
Gemma(const LoaderArgs& args, hwy::ThreadPool& pool);
~Gemma(); // must be defined after GemmaInterface's dtor is defined.
- // TODO: cleanup
- // const sentencepiece::SentencePieceProcessor& Tokenizer() const;
- // const std::unique_ptr<sentencepiece::SentencePieceProcessor> Tokenizer() const;
const sentencepiece::SentencePieceProcessor* Tokenizer() const;
std::unique_ptr<GemmaInterface> impl_;
@@ -205,15 +202,16 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
"Make top-k sampling deterministic", 2);
visitor(multiturn, "multiturn", false,
"Multiturn mode\n 0 = clear KV cache after every "
"interaction\n 1 = continue KV cache after every interaction\n Default : 0 (conversation "
"interaction\n 1 = continue KV cache after every interaction\n "
" Default : 0 (conversation "
"resets every turn)");
}
};
- void GenerateGemma(Gemma& gemma, const InferenceArgs& args,
- const std::vector<int>& prompt, size_t start_pos,
- hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
- const StreamFunc& stream_token,
+ void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
+ float temperature, const std::vector<int>& prompt,
+ size_t start_pos, hwy::ThreadPool& pool,
+ hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
const AcceptFunc& accept_token, std::mt19937& g,
int verbosity);

run.cc

@@ -204,8 +204,9 @@ void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool,
std::cerr << std::endl << "[ Reading prompt ] " << std::flush;
const double time_start = hwy::platform::Now();
- GenerateGemma(model, args, prompt, abs_pos, pool, inner_pool, stream_token,
- accept_token, gen, verbosity);
+ GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
+ args.temperature, prompt, abs_pos, pool, inner_pool,
+ stream_token, accept_token, gen, verbosity);
const double time_end = hwy::platform::Now();
const double tok_sec = current_pos / (time_end - time_start);
if (verbosity >= 2) {