diff --git a/gemma.h b/gemma.h index 3528b50..f6361e1 100644 --- a/gemma.h +++ b/gemma.h @@ -64,13 +64,14 @@ struct KVCache { enum class Model { GEMMA_2B, GEMMA_7B }; enum class ModelTraining { GEMMA_IT, GEMMA_PT }; -// TODO: incorporate -struct InferenceParams { - Model model; +// TODO: Incorporate this +struct Runtime { + Model model_type; ModelTraining model_training; - size_t max_generated_tokens; size_t max_tokens; + size_t max_generated_tokens; float temperature; + std::mt19937 gen; }; struct LoaderArgs : public ArgsBase<LoaderArgs> { @@ -212,7 +213,7 @@ void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens, float temperature, const std::vector<int>& prompt, size_t start_pos, hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, const StreamFunc& stream_token, - const AcceptFunc& accept_token, std::mt19937& g, + const AcceptFunc& accept_token, std::mt19937& gen, int verbosity); constexpr int EOS_ID = 1;