mirror of https://github.com/google/gemma.cpp.git
Fix Py binding/run_example: use GemmaEnv
PiperOrigin-RevId: 644318962
This commit is contained in:
parent
a07f60c9a1
commit
2ac47e4a06
|
|
@ -53,15 +53,12 @@ void InitGenerator(const InferenceArgs& inference, std::mt19937& gen) {
|
|||
}
|
||||
}
|
||||
|
||||
GemmaEnv::GemmaEnv(int argc, char** argv)
|
||||
: loader_(argc, argv),
|
||||
inference_args_(argc, argv),
|
||||
app_(argc, argv),
|
||||
GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
|
||||
const AppArgs& app)
|
||||
: loader_(loader),
|
||||
inference_args_(inference),
|
||||
app_(app),
|
||||
pool_(app_.num_threads) {
|
||||
{
|
||||
// Placeholder for internal init, do not modify.
|
||||
}
|
||||
|
||||
// For many-core, pinning workers to cores helps.
|
||||
if (app_.num_threads > 10) {
|
||||
gcpp::PinWorkersToCores(pool_);
|
||||
|
|
@ -89,6 +86,15 @@ GemmaEnv::GemmaEnv(int argc, char** argv)
|
|||
};
|
||||
}
|
||||
|
||||
// Note: the target ctor above runs to completion before the body of this delegating ctor.
|
||||
// Convenience constructor: builds LoaderArgs, InferenceArgs and AppArgs from
// argc/argv and forwards them to the three-argument constructor. This
// constructor's own body contains only the internal-init placeholder block.
GemmaEnv::GemmaEnv(int argc, char** argv)
|
||||
: GemmaEnv(LoaderArgs(argc, argv), InferenceArgs(argc, argv),
|
||||
AppArgs(argc, argv)) {
|
||||
{ // So that indentation matches expectations.
|
||||
// Placeholder for internal init, do not modify.
|
||||
}
|
||||
}
|
||||
|
||||
std::pair<std::string, size_t> GemmaEnv::QueryModel(
|
||||
const std::vector<int>& tokens) {
|
||||
std::string res;
|
||||
|
|
@ -98,10 +104,7 @@ std::pair<std::string, size_t> GemmaEnv::QueryModel(
|
|||
const StreamFunc stream_token = [&res, &total_tokens, &time_start, this](
|
||||
int token, float) {
|
||||
++total_tokens;
|
||||
std::string token_text;
|
||||
HWY_ASSERT(
|
||||
model_->Tokenizer().Decode(std::vector<int>{token}, &token_text));
|
||||
res += token_text;
|
||||
res += StringFromTokens(std::vector<int>{token});
|
||||
if (app_.verbosity >= 1 && total_tokens % 128 == 0) {
|
||||
LogSpeedStats(time_start, total_tokens);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -37,7 +37,10 @@ void InitGenerator(const InferenceArgs& inference, std::mt19937& gen);
|
|||
// Convenience class to load a model and run inference.
|
||||
class GemmaEnv {
|
||||
public:
|
||||
// Calls the other constructor with *Args arguments initialized from argv.
|
||||
GemmaEnv(int argc, char** argv);
|
||||
GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
|
||||
const AppArgs& app);
|
||||
|
||||
size_t MaxTokens() const { return inference_args_.max_tokens; }
|
||||
// Sets the maximum number of output tokens to generate.
|
||||
|
|
@ -81,7 +84,8 @@ class GemmaEnv {
|
|||
return loader_.ModelTrainingType();
|
||||
}
|
||||
int Verbosity() const { return app_.verbosity; }
|
||||
gcpp::RuntimeConfig& MutableConfig() { return runtime_config_; }
|
||||
RuntimeConfig& MutableConfig() { return runtime_config_; }
|
||||
InferenceArgs& MutableInferenceArgs() { return inference_args_; }
|
||||
std::mt19937& MutableGen() { return gen_; }
|
||||
KVCache& MutableKVCache() { return kv_cache_; }
|
||||
|
||||
|
|
@ -100,15 +104,14 @@ class GemmaEnv {
|
|||
std::unique_ptr<Gemma> model_;
|
||||
// The KV cache to use for inference.
|
||||
KVCache kv_cache_;
|
||||
gcpp::RuntimeConfig runtime_config_;
|
||||
RuntimeConfig runtime_config_;
|
||||
};
|
||||
|
||||
// Logs the inference speed in tokens/sec.
|
||||
void LogSpeedStats(double time_start, size_t total_tokens);
|
||||
|
||||
void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app);
|
||||
void ShowHelp(gcpp::LoaderArgs& loader, gcpp::InferenceArgs& inference,
|
||||
gcpp::AppArgs& app);
|
||||
void ShowHelp(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app);
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue