Fix Py binding/run_example: use GemmaEnv

PiperOrigin-RevId: 644318962
This commit is contained in:
Jan Wassenberg 2024-06-18 03:19:48 -07:00 committed by Copybara-Service
parent a07f60c9a1
commit 2ac47e4a06
2 changed files with 22 additions and 16 deletions

View File

@@ -53,15 +53,12 @@ void InitGenerator(const InferenceArgs& inference, std::mt19937& gen) {
} }
} }
GemmaEnv::GemmaEnv(int argc, char** argv) GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
: loader_(argc, argv), const AppArgs& app)
inference_args_(argc, argv), : loader_(loader),
app_(argc, argv), inference_args_(inference),
app_(app),
pool_(app_.num_threads) { pool_(app_.num_threads) {
{
// Placeholder for internal init, do not modify.
}
// For many-core, pinning workers to cores helps. // For many-core, pinning workers to cores helps.
if (app_.num_threads > 10) { if (app_.num_threads > 10) {
gcpp::PinWorkersToCores(pool_); gcpp::PinWorkersToCores(pool_);
@@ -89,6 +86,15 @@ GemmaEnv::GemmaEnv(int argc, char** argv)
}; };
} }
// Note: the delegating ctor above is called before any other initializers here.
GemmaEnv::GemmaEnv(int argc, char** argv)
: GemmaEnv(LoaderArgs(argc, argv), InferenceArgs(argc, argv),
AppArgs(argc, argv)) {
{ // So that indentation matches expectations.
// Placeholder for internal init, do not modify.
}
}
std::pair<std::string, size_t> GemmaEnv::QueryModel( std::pair<std::string, size_t> GemmaEnv::QueryModel(
const std::vector<int>& tokens) { const std::vector<int>& tokens) {
std::string res; std::string res;
@@ -98,10 +104,7 @@ std::pair<std::string, size_t> GemmaEnv::QueryModel(
const StreamFunc stream_token = [&res, &total_tokens, &time_start, this]( const StreamFunc stream_token = [&res, &total_tokens, &time_start, this](
int token, float) { int token, float) {
++total_tokens; ++total_tokens;
std::string token_text; res += StringFromTokens(std::vector<int>{token});
HWY_ASSERT(
model_->Tokenizer().Decode(std::vector<int>{token}, &token_text));
res += token_text;
if (app_.verbosity >= 1 && total_tokens % 128 == 0) { if (app_.verbosity >= 1 && total_tokens % 128 == 0) {
LogSpeedStats(time_start, total_tokens); LogSpeedStats(time_start, total_tokens);
} }

View File

@@ -37,7 +37,10 @@ void InitGenerator(const InferenceArgs& inference, std::mt19937& gen);
// Convenience class to load a model and run inference. // Convenience class to load a model and run inference.
class GemmaEnv { class GemmaEnv {
public: public:
// Calls the other constructor with *Args arguments initialized from argv.
GemmaEnv(int argc, char** argv); GemmaEnv(int argc, char** argv);
GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
const AppArgs& app);
size_t MaxTokens() const { return inference_args_.max_tokens; } size_t MaxTokens() const { return inference_args_.max_tokens; }
// Sets the maximum number of output tokens to generate. // Sets the maximum number of output tokens to generate.
@@ -81,7 +84,8 @@ class GemmaEnv {
return loader_.ModelTrainingType(); return loader_.ModelTrainingType();
} }
int Verbosity() const { return app_.verbosity; } int Verbosity() const { return app_.verbosity; }
gcpp::RuntimeConfig& MutableConfig() { return runtime_config_; } RuntimeConfig& MutableConfig() { return runtime_config_; }
InferenceArgs& MutableInferenceArgs() { return inference_args_; }
std::mt19937& MutableGen() { return gen_; } std::mt19937& MutableGen() { return gen_; }
KVCache& MutableKVCache() { return kv_cache_; } KVCache& MutableKVCache() { return kv_cache_; }
@@ -100,15 +104,14 @@ class GemmaEnv {
std::unique_ptr<Gemma> model_; std::unique_ptr<Gemma> model_;
// The KV cache to use for inference. // The KV cache to use for inference.
KVCache kv_cache_; KVCache kv_cache_;
gcpp::RuntimeConfig runtime_config_; RuntimeConfig runtime_config_;
}; };
// Logs the inference speed in tokens/sec. // Logs the inference speed in tokens/sec.
void LogSpeedStats(double time_start, size_t total_tokens); void LogSpeedStats(double time_start, size_t total_tokens);
void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app); void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app);
void ShowHelp(gcpp::LoaderArgs& loader, gcpp::InferenceArgs& inference, void ShowHelp(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app);
gcpp::AppArgs& app);
} // namespace gcpp } // namespace gcpp