diff --git a/examples/hello_world/run.cc b/examples/hello_world/run.cc index a972154..fd6c762 100644 --- a/examples/hello_world/run.cc +++ b/examples/hello_world/run.cc @@ -21,20 +21,27 @@ std::vector tokenize( int main(int argc, char** argv) { gcpp::LoaderArgs loader(argc, argv); - // A rough heuristic for a reasonable number of threads given hardware - // concurrency estimate + + // A rough heuristic number of threads to use size_t num_threads = static_cast(std::clamp( static_cast(std::thread::hardware_concurrency()) - 2, 1, 18)); hwy::ThreadPool pool(num_threads); - hwy::ThreadPool inner_pool(0); + + // Instantiate model gcpp::Gemma model(loader, pool); + + // Setup random number generator std::mt19937 gen; std::random_device rd; gen.seed(rd()); + + // Tokenize instruction std::vector tokens = tokenize("Write a greeting to the world.", model.Tokenizer()); size_t ntokens = tokens.size(); size_t pos = 0; + + // Callback auto stream_token = [&pos, &gen, &ntokens, tokenizer = model.Tokenizer()]( int token, float) { ++pos; @@ -43,17 +50,16 @@ int main(int argc, char** argv) { } else if (token != gcpp::EOS_ID) { std::string token_text; HWY_ASSERT(tokenizer->Decode(std::vector{token}, &token_text).ok()); - if (pos == ntokens + 1) { - // first token of response - token_text.erase(0, token_text.find_first_not_of(" \t\n\n")); - } std::cout << token_text << std::flush; } return true; }; - GenerateGemma( - model, /*max_tokens=*/2048, /*max_generated_tokens=*/1024, - /*temperature=*/1.0, tokens, 0, pool, inner_pool, stream_token, - [](int) { return true; }, gen, 0); + + GenerateGemma(model, + {.max_tokens = 2048, + .max_generated_tokens = 1024, + .temperature = 1.0, + .verbosity = 0}, + tokens, /*KV cache position = */ 0, pool, stream_token, gen); std::cout << std::endl; } diff --git a/gemma.cc b/gemma.cc index f080dda..9c2df6c 100644 --- a/gemma.cc +++ b/gemma.cc @@ -844,5 +844,16 @@ void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens, pool.SetWaitMode(hwy::PoolWaitMode::kBlock); } +void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config, + const std::vector& prompt, size_t start_pos, + KVCache& kv_cache, hwy::ThreadPool& pool, + const StreamFunc& stream_token, std::mt19937& gen) { + hwy::ThreadPool inner_pool(0); + GenerateGemma( + gemma, runtime_config.max_tokens, runtime_config.max_generated_tokens, + runtime_config.temperature, prompt, start_pos, kv_cache, pool, inner_pool, + stream_token, [](int) { return true; }, gen, runtime_config.verbosity); +} + } // namespace gcpp #endif // HWY_ONCE diff --git a/gemma.h b/gemma.h index 1a6ca07..3bb4c28 100644 --- a/gemma.h +++ b/gemma.h @@ -64,17 +64,11 @@ struct KVCache { enum class Model { GEMMA_2B, GEMMA_7B }; enum class ModelTraining { GEMMA_IT, GEMMA_PT }; -// TODO: Incorporate this -struct Runtime { - // TODO: In the future we may fold ModelTraining into Model. - // As we add more variations of model_type, the cartesian set becomes - // unwieldy. - Model model_type; - ModelTraining model_training; +struct RuntimeConfig { size_t max_tokens; size_t max_generated_tokens; float temperature; - std::mt19937 gen; + int verbosity; }; struct LoaderArgs : public ArgsBase { @@ -205,9 +199,6 @@ struct Gemma { gcpp::ModelTraining model_training; }; -struct LoaderArgs; // forward declaration -void CreateGemma(const LoaderArgs& args, hwy::ThreadPool& pool, Gemma& model); - KVCache CreateKVCache(Model type); // convenient workaround for now KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len); @@ -223,6 +214,15 @@ void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens, const AcceptFunc& accept_token, std::mt19937& gen, int verbosity); +// Convenience function for the common case: +// - Bundle runtime parameters as RuntimeConfig +// - No threadpools within threadpools (inner_pool = dummy) +// - All tokens accepted +void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config, + const std::vector& prompt, size_t start_pos, + KVCache& kv_cache, hwy::ThreadPool& pool, + const StreamFunc& stream_token, std::mt19937& gen); + constexpr int EOS_ID = 1; } // namespace gcpp