mirror of https://github.com/google/gemma.cpp.git
Fix Py binding/run_example: use GemmaEnv
PiperOrigin-RevId: 644318962
This commit is contained in:
parent
a07f60c9a1
commit
2ac47e4a06
|
|
@ -53,15 +53,12 @@ void InitGenerator(const InferenceArgs& inference, std::mt19937& gen) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
GemmaEnv::GemmaEnv(int argc, char** argv)
|
GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
|
||||||
: loader_(argc, argv),
|
const AppArgs& app)
|
||||||
inference_args_(argc, argv),
|
: loader_(loader),
|
||||||
app_(argc, argv),
|
inference_args_(inference),
|
||||||
|
app_(app),
|
||||||
pool_(app_.num_threads) {
|
pool_(app_.num_threads) {
|
||||||
{
|
|
||||||
// Placeholder for internal init, do not modify.
|
|
||||||
}
|
|
||||||
|
|
||||||
// For many-core, pinning workers to cores helps.
|
// For many-core, pinning workers to cores helps.
|
||||||
if (app_.num_threads > 10) {
|
if (app_.num_threads > 10) {
|
||||||
gcpp::PinWorkersToCores(pool_);
|
gcpp::PinWorkersToCores(pool_);
|
||||||
|
|
@ -89,6 +86,15 @@ GemmaEnv::GemmaEnv(int argc, char** argv)
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Note: the delegating ctor above is called before any other initializers here.
|
||||||
|
GemmaEnv::GemmaEnv(int argc, char** argv)
|
||||||
|
: GemmaEnv(LoaderArgs(argc, argv), InferenceArgs(argc, argv),
|
||||||
|
AppArgs(argc, argv)) {
|
||||||
|
{ // So that indentation matches expectations.
|
||||||
|
// Placeholder for internal init, do not modify.
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
std::pair<std::string, size_t> GemmaEnv::QueryModel(
|
std::pair<std::string, size_t> GemmaEnv::QueryModel(
|
||||||
const std::vector<int>& tokens) {
|
const std::vector<int>& tokens) {
|
||||||
std::string res;
|
std::string res;
|
||||||
|
|
@ -98,10 +104,7 @@ std::pair<std::string, size_t> GemmaEnv::QueryModel(
|
||||||
const StreamFunc stream_token = [&res, &total_tokens, &time_start, this](
|
const StreamFunc stream_token = [&res, &total_tokens, &time_start, this](
|
||||||
int token, float) {
|
int token, float) {
|
||||||
++total_tokens;
|
++total_tokens;
|
||||||
std::string token_text;
|
res += StringFromTokens(std::vector<int>{token});
|
||||||
HWY_ASSERT(
|
|
||||||
model_->Tokenizer().Decode(std::vector<int>{token}, &token_text));
|
|
||||||
res += token_text;
|
|
||||||
if (app_.verbosity >= 1 && total_tokens % 128 == 0) {
|
if (app_.verbosity >= 1 && total_tokens % 128 == 0) {
|
||||||
LogSpeedStats(time_start, total_tokens);
|
LogSpeedStats(time_start, total_tokens);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -37,7 +37,10 @@ void InitGenerator(const InferenceArgs& inference, std::mt19937& gen);
|
||||||
// Convenience class to load a model and run inference.
|
// Convenience class to load a model and run inference.
|
||||||
class GemmaEnv {
|
class GemmaEnv {
|
||||||
public:
|
public:
|
||||||
|
// Calls the other constructor with *Args arguments initialized from argv.
|
||||||
GemmaEnv(int argc, char** argv);
|
GemmaEnv(int argc, char** argv);
|
||||||
|
GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
|
||||||
|
const AppArgs& app);
|
||||||
|
|
||||||
size_t MaxTokens() const { return inference_args_.max_tokens; }
|
size_t MaxTokens() const { return inference_args_.max_tokens; }
|
||||||
// Sets the maximum number of output tokens to generate.
|
// Sets the maximum number of output tokens to generate.
|
||||||
|
|
@ -81,7 +84,8 @@ class GemmaEnv {
|
||||||
return loader_.ModelTrainingType();
|
return loader_.ModelTrainingType();
|
||||||
}
|
}
|
||||||
int Verbosity() const { return app_.verbosity; }
|
int Verbosity() const { return app_.verbosity; }
|
||||||
gcpp::RuntimeConfig& MutableConfig() { return runtime_config_; }
|
RuntimeConfig& MutableConfig() { return runtime_config_; }
|
||||||
|
InferenceArgs& MutableInferenceArgs() { return inference_args_; }
|
||||||
std::mt19937& MutableGen() { return gen_; }
|
std::mt19937& MutableGen() { return gen_; }
|
||||||
KVCache& MutableKVCache() { return kv_cache_; }
|
KVCache& MutableKVCache() { return kv_cache_; }
|
||||||
|
|
||||||
|
|
@ -100,15 +104,14 @@ class GemmaEnv {
|
||||||
std::unique_ptr<Gemma> model_;
|
std::unique_ptr<Gemma> model_;
|
||||||
// The KV cache to use for inference.
|
// The KV cache to use for inference.
|
||||||
KVCache kv_cache_;
|
KVCache kv_cache_;
|
||||||
gcpp::RuntimeConfig runtime_config_;
|
RuntimeConfig runtime_config_;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Logs the inference speed in tokens/sec.
|
// Logs the inference speed in tokens/sec.
|
||||||
void LogSpeedStats(double time_start, size_t total_tokens);
|
void LogSpeedStats(double time_start, size_t total_tokens);
|
||||||
|
|
||||||
void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app);
|
void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app);
|
||||||
void ShowHelp(gcpp::LoaderArgs& loader, gcpp::InferenceArgs& inference,
|
void ShowHelp(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app);
|
||||||
gcpp::AppArgs& app);
|
|
||||||
|
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue