Small code cleanup suggestions while reading the code.

PiperOrigin-RevId: 641220788
This commit is contained in:
Daniel Keysers 2024-06-07 05:32:50 -07:00 committed by Copybara-Service
parent f7ac7092d6
commit 06f814fc8b
2 changed files with 14 additions and 10 deletions

View File

@ -44,7 +44,9 @@ struct KVCache {
static KVCache Create(Model type);
};
// The tokenizer's end of sentence and beginning of sentence token ids.
constexpr int EOS_ID = 1;
constexpr int BOS_ID = 2;
class GemmaTokenizer {
public:
@ -87,7 +89,7 @@ struct RuntimeConfig {
struct TimingInfo {
double prefill_tok_sec = 0.0;
double gen_tok_sec = 0.0;
double time_to_first_token = 0;
double time_to_first_token = 0.0;
};
// Will be called for layers output with:

View File

@ -25,6 +25,7 @@
// Placeholder for internal header, do not modify.
#include "compression/compress.h"
#include "gemma/common.h"
#include "gemma/configs.h"
#include "gemma/gemma.h" // Gemma
#include "util/app.h"
@ -98,6 +99,7 @@ void ShowHelp(gcpp::LoaderArgs& loader, gcpp::InferenceArgs& inference,
std::cerr << "\n";
}
// The main Read-Eval-Print Loop.
void ReplGemma(gcpp::Gemma& model, ModelTraining training,
gcpp::KVCache& kv_cache, hwy::ThreadPool& pool,
const InferenceArgs& args, int verbosity,
@ -160,7 +162,7 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
std::cout << "> " << std::flush;
}
if (eot_line.size() == 0) {
if (eot_line.empty()) {
std::getline(std::cin, prompt_string);
} else {
std::string line;
@ -198,7 +200,7 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
// For both pre-trained and instruction-tuned models: prepend "<bos>" token
// if needed.
if (abs_pos == 0) {
prompt.insert(prompt.begin(), 2);
prompt.insert(prompt.begin(), gcpp::BOS_ID);
}
prompt_size = prompt.size();
@ -207,7 +209,7 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
<< "[ Reading prompt ] " << std::flush;
if constexpr (kVerboseLogTokens) {
for (int i = 0; i < static_cast<int>(prompt.size()); ++i) {
for (int i = 0; i < prompt_size; ++i) {
fprintf(stderr, "DDD TOKEN %3d: %6d\n", i, prompt[i]);
}
}
@ -253,11 +255,6 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
KVCache kv_cache = KVCache::Create(loader.ModelType());
if (const char* error = inference.Validate()) {
ShowHelp(loader, inference, app);
HWY_ABORT("\nInvalid args: %s", error);
}
if (app.verbosity >= 1) {
const std::string instructions =
"*Usage*\n"
@ -307,6 +304,11 @@ int main(int argc, char** argv) {
HWY_ABORT("\nInvalid args: %s", error);
}
if (const char* error = inference.Validate()) {
ShowHelp(loader, inference, app);
HWY_ABORT("\nInvalid args: %s", error);
}
gcpp::Run(loader, inference, app);
}
PROFILER_PRINT_RESULTS(); // Must call outside the zone above.