mirror of https://github.com/google/gemma.cpp.git
[WIP] update GemmaInterface, Gemma, and Generate input parameter specs to remove InferenceArgs. TODO: update hello_world example after git commit hash is available for fetching
This commit is contained in:
parent
0f6a4b49d5
commit
7042316013
|
|
@ -66,7 +66,7 @@ int main(int argc, char** argv) {
|
|||
inference.multiturn = false;
|
||||
|
||||
GenerateGemma(
|
||||
model, inference, tokens, 0, pool, inner_pool, stream_token,
|
||||
model, /*max_tokens=*/2048, /*max_generated_tokens=*/1024, /*temperature=*/1.0, tokens, 0, pool, inner_pool, stream_token,
|
||||
[](int) {return true;}, gen, 0);
|
||||
|
||||
std::cout << std::endl;
|
||||
|
|
|
|||
82
gemma.cc
82
gemma.cc
|
|
@ -233,9 +233,10 @@ struct GemmaInterface {
|
|||
|
||||
virtual const sentencepiece::SentencePieceProcessor* Tokenizer() const = 0;
|
||||
|
||||
virtual void Generate(const InferenceArgs& args,
|
||||
const std::vector<int>& prompt, size_t start_pos,
|
||||
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
|
||||
virtual void Generate(size_t max_tokens, size_t max_generated_tokens,
|
||||
float temperature, const std::vector<int>& prompt,
|
||||
size_t start_pos, hwy::ThreadPool& pool,
|
||||
hwy::ThreadPool& inner_pool,
|
||||
const StreamFunc& stream_token,
|
||||
const AcceptFunc& accept_token, std::mt19937& gen,
|
||||
int verbosity) = 0;
|
||||
|
|
@ -258,7 +259,8 @@ struct GemmaImpl : public GemmaInterface {
|
|||
return tokenizer.get();
|
||||
}
|
||||
|
||||
void Generate(const InferenceArgs& args, const std::vector<int>& prompt,
|
||||
void Generate(size_t max_tokens, size_t max_generated_tokens,
|
||||
float temperature, const std::vector<int>& prompt,
|
||||
size_t start_pos, hwy::ThreadPool& pool,
|
||||
hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
|
||||
const AcceptFunc& accept_token, std::mt19937&, int verbosity);
|
||||
|
|
@ -295,7 +297,8 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
|
|||
static constexpr size_t kModelDim =
|
||||
gcpp::Activations<TConfig, kBatchSize>::kModelDim;
|
||||
static constexpr size_t kHeads = TConfig::kHeads;
|
||||
static const float kQueryScale = static_cast<float>(1.0 / sqrt(static_cast<double>(kQKVDim)));
|
||||
static const float kQueryScale =
|
||||
static_cast<float>(1.0 / sqrt(static_cast<double>(kQKVDim)));
|
||||
|
||||
pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
|
||||
// linear projections to QKV
|
||||
|
|
@ -418,7 +421,8 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
|
|||
hwy::ThreadPool& inner_pool) {
|
||||
PROFILER_ZONE("Gen.Prefill\\Att\\FFW");
|
||||
static constexpr size_t kModelDim = TConfig::kModelDim;
|
||||
static const float kEmbScaling = static_cast<float>(sqrt(static_cast<double>(kModelDim)));
|
||||
static const float kEmbScaling =
|
||||
static_cast<float>(sqrt(static_cast<double>(kModelDim)));
|
||||
|
||||
pool.Run(
|
||||
0, num_tokens, [&](const uint64_t token_idx, size_t /*thread*/) HWY_ATTR {
|
||||
|
|
@ -473,7 +477,8 @@ void Transformer(int token, size_t pos,
|
|||
static constexpr size_t kLayers = TConfig::kLayers;
|
||||
static constexpr size_t kModelDim = TConfig::kModelDim;
|
||||
|
||||
static const float kEmbScaling = static_cast<float>(sqrt(static_cast<double>(kModelDim)));
|
||||
static const float kEmbScaling =
|
||||
static_cast<float>(sqrt(static_cast<double>(kModelDim)));
|
||||
|
||||
Decompress(c_weights.c_embedder_input_embedding, token * kModelDim,
|
||||
activations.x.data(), kModelDim);
|
||||
|
|
@ -604,24 +609,26 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
|
|||
}
|
||||
}
|
||||
|
||||
void Generate2B(GemmaImpl<ConfigGemma2B>& gemma, const InferenceArgs& args,
|
||||
void Generate2B(GemmaImpl<ConfigGemma2B>& gemma, size_t max_tokens,
|
||||
size_t max_generated_tokens, float temperature,
|
||||
const std::vector<int>& prompt, size_t start_pos,
|
||||
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
|
||||
const StreamFunc& stream_token, const AcceptFunc& accept_token,
|
||||
std::mt19937& gen, int verbosity) {
|
||||
GenerateImpl(gemma, args.max_tokens, args.max_generated_tokens,
|
||||
args.temperature, prompt, start_pos, pool, inner_pool,
|
||||
stream_token, accept_token, gen, verbosity);
|
||||
GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
|
||||
start_pos, pool, inner_pool, stream_token, accept_token, gen,
|
||||
verbosity);
|
||||
}
|
||||
|
||||
void Generate7B(GemmaImpl<ConfigGemma7B>& gemma, const InferenceArgs& args,
|
||||
void Generate7B(GemmaImpl<ConfigGemma7B>& gemma, size_t max_tokens,
|
||||
size_t max_generated_tokens, float temperature,
|
||||
const std::vector<int>& prompt, size_t start_pos,
|
||||
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
|
||||
const StreamFunc& stream_token, const AcceptFunc& accept_token,
|
||||
std::mt19937& gen, int verbosity) {
|
||||
GenerateImpl(gemma, args.max_tokens, args.max_generated_tokens,
|
||||
args.temperature, prompt, start_pos, pool, inner_pool,
|
||||
stream_token, accept_token, gen, verbosity);
|
||||
GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
|
||||
start_pos, pool, inner_pool, stream_token, accept_token, gen,
|
||||
verbosity);
|
||||
}
|
||||
|
||||
// Calls func(name, float*, CompressedArray&) for each tensor. float* is null
|
||||
|
|
@ -755,28 +762,24 @@ GemmaImpl<Config>::GemmaImpl(
|
|||
}
|
||||
|
||||
template <>
|
||||
void GemmaImpl<ConfigGemma2B>::Generate(const InferenceArgs& args,
|
||||
const std::vector<int>& prompt,
|
||||
size_t start_pos, hwy::ThreadPool& pool,
|
||||
hwy::ThreadPool& inner_pool,
|
||||
const StreamFunc& stream_token,
|
||||
const AcceptFunc& accept_token,
|
||||
std::mt19937& gen, int verbosity) {
|
||||
void GemmaImpl<ConfigGemma2B>::Generate(
|
||||
size_t max_tokens, size_t max_generated_tokens, float temperature,
|
||||
const std::vector<int>& prompt, size_t start_pos, hwy::ThreadPool& pool,
|
||||
hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
|
||||
const AcceptFunc& accept_token, std::mt19937& gen, int verbosity) {
|
||||
HWY_DYNAMIC_DISPATCH(Generate2B)
|
||||
(*this, args, prompt, start_pos, pool, inner_pool, stream_token, accept_token,
|
||||
gen, verbosity);
|
||||
(*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
|
||||
pool, inner_pool, stream_token, accept_token, gen, verbosity);
|
||||
}
|
||||
template <>
|
||||
void GemmaImpl<ConfigGemma7B>::Generate(const InferenceArgs& args,
|
||||
const std::vector<int>& prompt,
|
||||
size_t start_pos, hwy::ThreadPool& pool,
|
||||
hwy::ThreadPool& inner_pool,
|
||||
const StreamFunc& stream_token,
|
||||
const AcceptFunc& accept_token,
|
||||
std::mt19937& gen, int verbosity) {
|
||||
void GemmaImpl<ConfigGemma7B>::Generate(
|
||||
size_t max_tokens, size_t max_generated_tokens, float temperature,
|
||||
const std::vector<int>& prompt, size_t start_pos, hwy::ThreadPool& pool,
|
||||
hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
|
||||
const AcceptFunc& accept_token, std::mt19937& gen, int verbosity) {
|
||||
HWY_DYNAMIC_DISPATCH(Generate7B)
|
||||
(*this, args, prompt, start_pos, pool, inner_pool, stream_token, accept_token,
|
||||
gen, verbosity);
|
||||
(*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
|
||||
pool, inner_pool, stream_token, accept_token, gen, verbosity);
|
||||
}
|
||||
|
||||
Gemma::Gemma(const LoaderArgs& args, hwy::ThreadPool& pool) {
|
||||
|
|
@ -807,15 +810,16 @@ const sentencepiece::SentencePieceProcessor* Gemma::Tokenizer() const {
|
|||
return impl_->Tokenizer();
|
||||
}
|
||||
|
||||
void GenerateGemma(Gemma& gemma, const InferenceArgs& args,
|
||||
const std::vector<int>& prompt, size_t start_pos,
|
||||
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
|
||||
const StreamFunc& stream_token,
|
||||
void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
|
||||
float temperature, const std::vector<int>& prompt,
|
||||
size_t start_pos, hwy::ThreadPool& pool,
|
||||
hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
|
||||
const AcceptFunc& accept_token, std::mt19937& gen,
|
||||
int verbosity) {
|
||||
pool.SetWaitMode(hwy::PoolWaitMode::kSpin);
|
||||
gemma.impl_->Generate(args, prompt, start_pos, pool, inner_pool, stream_token,
|
||||
accept_token, gen, verbosity);
|
||||
gemma.impl_->Generate(max_tokens, max_generated_tokens,
|
||||
temperature, prompt, start_pos, pool, inner_pool,
|
||||
stream_token, accept_token, gen, verbosity);
|
||||
pool.SetWaitMode(hwy::PoolWaitMode::kBlock);
|
||||
}
|
||||
|
||||
|
|
|
|||
14
gemma.h
14
gemma.h
|
|
@ -156,9 +156,6 @@ struct Gemma {
|
|||
Gemma(const LoaderArgs& args, hwy::ThreadPool& pool);
|
||||
~Gemma(); // must be defined after GemmaInterface's dtor is defined.
|
||||
|
||||
// TODO: cleanup
|
||||
// const sentencepiece::SentencePieceProcessor& Tokenizer() const;
|
||||
// const std::unique_ptr<sentencepiece::SentencePieceProcessor> Tokenizer() const;
|
||||
const sentencepiece::SentencePieceProcessor* Tokenizer() const;
|
||||
|
||||
std::unique_ptr<GemmaInterface> impl_;
|
||||
|
|
@ -205,15 +202,16 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
|
|||
"Make top-k sampling deterministic", 2);
|
||||
visitor(multiturn, "multiturn", false,
|
||||
"Multiturn mode\n 0 = clear KV cache after every "
|
||||
"interaction\n 1 = continue KV cache after every interaction\n Default : 0 (conversation "
|
||||
"interaction\n 1 = continue KV cache after every interaction\n "
|
||||
" Default : 0 (conversation "
|
||||
"resets every turn)");
|
||||
}
|
||||
};
|
||||
|
||||
void GenerateGemma(Gemma& gemma, const InferenceArgs& args,
|
||||
const std::vector<int>& prompt, size_t start_pos,
|
||||
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
|
||||
const StreamFunc& stream_token,
|
||||
void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
|
||||
float temperature, const std::vector<int>& prompt,
|
||||
size_t start_pos, hwy::ThreadPool& pool,
|
||||
hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
|
||||
const AcceptFunc& accept_token, std::mt19937& g,
|
||||
int verbosity);
|
||||
|
||||
|
|
|
|||
5
run.cc
5
run.cc
|
|
@ -204,8 +204,9 @@ void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool,
|
|||
std::cerr << std::endl << "[ Reading prompt ] " << std::flush;
|
||||
|
||||
const double time_start = hwy::platform::Now();
|
||||
GenerateGemma(model, args, prompt, abs_pos, pool, inner_pool, stream_token,
|
||||
accept_token, gen, verbosity);
|
||||
GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
|
||||
args.temperature, prompt, abs_pos, pool, inner_pool,
|
||||
stream_token, accept_token, gen, verbosity);
|
||||
const double time_end = hwy::platform::Now();
|
||||
const double tok_sec = current_pos / (time_end - time_start);
|
||||
if (verbosity >= 2) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue