mirror of https://github.com/google/gemma.cpp.git
Small code cleanup suggestions while reading the code.
PiperOrigin-RevId: 641220788
This commit is contained in:
parent
f7ac7092d6
commit
06f814fc8b
|
|
@ -44,7 +44,9 @@ struct KVCache {
|
|||
static KVCache Create(Model type);
|
||||
};
|
||||
|
||||
// The tokenizer's end of sentence and beginning of sentence token ids.
|
||||
constexpr int EOS_ID = 1;
|
||||
constexpr int BOS_ID = 2;
|
||||
|
||||
class GemmaTokenizer {
|
||||
public:
|
||||
|
|
@ -87,7 +89,7 @@ struct RuntimeConfig {
|
|||
struct TimingInfo {
|
||||
double prefill_tok_sec = 0.0;
|
||||
double gen_tok_sec = 0.0;
|
||||
double time_to_first_token = 0;
|
||||
double time_to_first_token = 0.0;
|
||||
};
|
||||
|
||||
// Will be called for layers output with:
|
||||
|
|
|
|||
20
gemma/run.cc
20
gemma/run.cc
|
|
@ -25,6 +25,7 @@
|
|||
|
||||
// Placeholder for internal header, do not modify.
|
||||
#include "compression/compress.h"
|
||||
#include "gemma/common.h"
|
||||
#include "gemma/configs.h"
|
||||
#include "gemma/gemma.h" // Gemma
|
||||
#include "util/app.h"
|
||||
|
|
@ -98,12 +99,13 @@ void ShowHelp(gcpp::LoaderArgs& loader, gcpp::InferenceArgs& inference,
|
|||
std::cerr << "\n";
|
||||
}
|
||||
|
||||
// The main Read-Eval-Print Loop.
|
||||
void ReplGemma(gcpp::Gemma& model, ModelTraining training,
|
||||
gcpp::KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||
const InferenceArgs& args, int verbosity,
|
||||
const gcpp::AcceptFunc& accept_token, std::string& eot_line) {
|
||||
PROFILER_ZONE("Gen.misc");
|
||||
size_t abs_pos = 0; // absolute token index over all turns
|
||||
size_t abs_pos = 0; // absolute token index over all turns
|
||||
int current_pos = 0; // token index within the current turn
|
||||
int prompt_size{};
|
||||
|
||||
|
|
@ -160,7 +162,7 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
|
|||
std::cout << "> " << std::flush;
|
||||
}
|
||||
|
||||
if (eot_line.size() == 0) {
|
||||
if (eot_line.empty()) {
|
||||
std::getline(std::cin, prompt_string);
|
||||
} else {
|
||||
std::string line;
|
||||
|
|
@ -198,7 +200,7 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
|
|||
// For both pre-trained and instruction-tuned models: prepend "<bos>" token
|
||||
// if needed.
|
||||
if (abs_pos == 0) {
|
||||
prompt.insert(prompt.begin(), 2);
|
||||
prompt.insert(prompt.begin(), gcpp::BOS_ID);
|
||||
}
|
||||
|
||||
prompt_size = prompt.size();
|
||||
|
|
@ -207,7 +209,7 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
|
|||
<< "[ Reading prompt ] " << std::flush;
|
||||
|
||||
if constexpr (kVerboseLogTokens) {
|
||||
for (int i = 0; i < static_cast<int>(prompt.size()); ++i) {
|
||||
for (int i = 0; i < prompt_size; ++i) {
|
||||
fprintf(stderr, "DDD TOKEN %3d: %6d\n", i, prompt[i]);
|
||||
}
|
||||
}
|
||||
|
|
@ -253,11 +255,6 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
|
|||
|
||||
KVCache kv_cache = KVCache::Create(loader.ModelType());
|
||||
|
||||
if (const char* error = inference.Validate()) {
|
||||
ShowHelp(loader, inference, app);
|
||||
HWY_ABORT("\nInvalid args: %s", error);
|
||||
}
|
||||
|
||||
if (app.verbosity >= 1) {
|
||||
const std::string instructions =
|
||||
"*Usage*\n"
|
||||
|
|
@ -307,6 +304,11 @@ int main(int argc, char** argv) {
|
|||
HWY_ABORT("\nInvalid args: %s", error);
|
||||
}
|
||||
|
||||
if (const char* error = inference.Validate()) {
|
||||
ShowHelp(loader, inference, app);
|
||||
HWY_ABORT("\nInvalid args: %s", error);
|
||||
}
|
||||
|
||||
gcpp::Run(loader, inference, app);
|
||||
}
|
||||
PROFILER_PRINT_RESULTS(); // Must call outside the zone above.
|
||||
|
|
|
|||
Loading…
Reference in New Issue