mirror of https://github.com/google/gemma.cpp.git
Small code cleanup suggestions while reading the code.
PiperOrigin-RevId: 641220788
This commit is contained in:
parent
f7ac7092d6
commit
06f814fc8b
|
|
@ -44,7 +44,9 @@ struct KVCache {
|
||||||
static KVCache Create(Model type);
|
static KVCache Create(Model type);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// The tokenizer's end of sentence and beginning of sentence token ids.
|
||||||
constexpr int EOS_ID = 1;
|
constexpr int EOS_ID = 1;
|
||||||
|
constexpr int BOS_ID = 2;
|
||||||
|
|
||||||
class GemmaTokenizer {
|
class GemmaTokenizer {
|
||||||
public:
|
public:
|
||||||
|
|
@ -87,7 +89,7 @@ struct RuntimeConfig {
|
||||||
struct TimingInfo {
|
struct TimingInfo {
|
||||||
double prefill_tok_sec = 0.0;
|
double prefill_tok_sec = 0.0;
|
||||||
double gen_tok_sec = 0.0;
|
double gen_tok_sec = 0.0;
|
||||||
double time_to_first_token = 0;
|
double time_to_first_token = 0.0;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Will be called for layers output with:
|
// Will be called for layers output with:
|
||||||
|
|
|
||||||
18
gemma/run.cc
18
gemma/run.cc
|
|
@ -25,6 +25,7 @@
|
||||||
|
|
||||||
// Placeholder for internal header, do not modify.
|
// Placeholder for internal header, do not modify.
|
||||||
#include "compression/compress.h"
|
#include "compression/compress.h"
|
||||||
|
#include "gemma/common.h"
|
||||||
#include "gemma/configs.h"
|
#include "gemma/configs.h"
|
||||||
#include "gemma/gemma.h" // Gemma
|
#include "gemma/gemma.h" // Gemma
|
||||||
#include "util/app.h"
|
#include "util/app.h"
|
||||||
|
|
@ -98,6 +99,7 @@ void ShowHelp(gcpp::LoaderArgs& loader, gcpp::InferenceArgs& inference,
|
||||||
std::cerr << "\n";
|
std::cerr << "\n";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// The main Read-Eval-Print Loop.
|
||||||
void ReplGemma(gcpp::Gemma& model, ModelTraining training,
|
void ReplGemma(gcpp::Gemma& model, ModelTraining training,
|
||||||
gcpp::KVCache& kv_cache, hwy::ThreadPool& pool,
|
gcpp::KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||||
const InferenceArgs& args, int verbosity,
|
const InferenceArgs& args, int verbosity,
|
||||||
|
|
@ -160,7 +162,7 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
|
||||||
std::cout << "> " << std::flush;
|
std::cout << "> " << std::flush;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (eot_line.size() == 0) {
|
if (eot_line.empty()) {
|
||||||
std::getline(std::cin, prompt_string);
|
std::getline(std::cin, prompt_string);
|
||||||
} else {
|
} else {
|
||||||
std::string line;
|
std::string line;
|
||||||
|
|
@ -198,7 +200,7 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
|
||||||
// For both pre-trained and instruction-tuned models: prepend "<bos>" token
|
// For both pre-trained and instruction-tuned models: prepend "<bos>" token
|
||||||
// if needed.
|
// if needed.
|
||||||
if (abs_pos == 0) {
|
if (abs_pos == 0) {
|
||||||
prompt.insert(prompt.begin(), 2);
|
prompt.insert(prompt.begin(), gcpp::BOS_ID);
|
||||||
}
|
}
|
||||||
|
|
||||||
prompt_size = prompt.size();
|
prompt_size = prompt.size();
|
||||||
|
|
@ -207,7 +209,7 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
|
||||||
<< "[ Reading prompt ] " << std::flush;
|
<< "[ Reading prompt ] " << std::flush;
|
||||||
|
|
||||||
if constexpr (kVerboseLogTokens) {
|
if constexpr (kVerboseLogTokens) {
|
||||||
for (int i = 0; i < static_cast<int>(prompt.size()); ++i) {
|
for (int i = 0; i < prompt_size; ++i) {
|
||||||
fprintf(stderr, "DDD TOKEN %3d: %6d\n", i, prompt[i]);
|
fprintf(stderr, "DDD TOKEN %3d: %6d\n", i, prompt[i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -253,11 +255,6 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
|
||||||
|
|
||||||
KVCache kv_cache = KVCache::Create(loader.ModelType());
|
KVCache kv_cache = KVCache::Create(loader.ModelType());
|
||||||
|
|
||||||
if (const char* error = inference.Validate()) {
|
|
||||||
ShowHelp(loader, inference, app);
|
|
||||||
HWY_ABORT("\nInvalid args: %s", error);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (app.verbosity >= 1) {
|
if (app.verbosity >= 1) {
|
||||||
const std::string instructions =
|
const std::string instructions =
|
||||||
"*Usage*\n"
|
"*Usage*\n"
|
||||||
|
|
@ -307,6 +304,11 @@ int main(int argc, char** argv) {
|
||||||
HWY_ABORT("\nInvalid args: %s", error);
|
HWY_ABORT("\nInvalid args: %s", error);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (const char* error = inference.Validate()) {
|
||||||
|
ShowHelp(loader, inference, app);
|
||||||
|
HWY_ABORT("\nInvalid args: %s", error);
|
||||||
|
}
|
||||||
|
|
||||||
gcpp::Run(loader, inference, app);
|
gcpp::Run(loader, inference, app);
|
||||||
}
|
}
|
||||||
PROFILER_PRINT_RESULTS(); // Must call outside the zone above.
|
PROFILER_PRINT_RESULTS(); // Must call outside the zone above.
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue