diff --git a/README.md b/README.md
index 5a453fa..ed22335 100644
--- a/README.md
+++ b/README.md
@@ -385,7 +385,6 @@ tokenizer : tokenizer.spm
 compressed_weights : 2b-it-sfp.sbs
 model : 2b-it
 weights : [no path specified]
-max_tokens : 3072
 max_generated_tokens : 2048
 
 *Usage*
diff --git a/backprop/optimize_test.cc b/backprop/optimize_test.cc
index 6c81aa3..26698c6 100644
--- a/backprop/optimize_test.cc
+++ b/backprop/optimize_test.cc
@@ -69,7 +69,6 @@ TEST(OptimizeTest, GradientDescent) {
     return token != ReverseSequenceSampler::kEndToken;
   };
   RuntimeConfig runtime = {
-      .max_tokens = 32,
       .max_generated_tokens = 16,
       .temperature = 1.0f,
       .verbosity = 0,
diff --git a/evals/benchmark.cc b/evals/benchmark.cc
index 4495c57..b59079a 100644
--- a/evals/benchmark.cc
+++ b/evals/benchmark.cc
@@ -81,15 +81,15 @@ int BenchmarkGoldens(GemmaEnv& env, const std::string& golden_path) {
   size_t total_tokens = 0;
   const double time_start = hwy::platform::Now();
   for (auto& [question, expected_answer] : queries_answers) {
-    const auto [answer, token_count] = env.QueryModel(question);
-    total_tokens += token_count;
-    if (answer.find(expected_answer) != std::string::npos) {
+    QueryResult result = env.QueryModel(question);
+    total_tokens += result.tokens_generated;
+    if (result.response.find(expected_answer) != std::string::npos) {
       correct_answers++;
     } else {
       std::cout << "Wrong!\n";
       std::cout << "Input: " << question << "\n";
       std::cout << "Expected: " << expected_answer << "\n";
-      std::cout << "Output: " << answer << "\n\n" << std::flush;
+      std::cout << "Output: " << result.response << "\n\n" << std::flush;
     }
   }
   LogSpeedStats(time_start, total_tokens);
@@ -108,9 +108,10 @@ int BenchmarkSummary(GemmaEnv& env, const Path& text) {
   prompt.append(ReadFileToString(text));
   prompt.append("\nSummarize this text.\n");
   const double time_start = hwy::platform::Now();
-  const auto [answer, token_count] = env.QueryModel(prompt);
-  std::cout << answer.substr(prompt.size()) << "\n" << std::flush;
-  LogSpeedStats(time_start, token_count);
+  QueryResult result = env.QueryModel(prompt);
+  std::cout << result.response.substr(result.response_start_pos) << "\n"
+            << std::flush;
+  LogSpeedStats(time_start, result.tokens_generated);
   return EXIT_SUCCESS;
 }
 
@@ -118,7 +119,6 @@ int BenchmarkCrossEntropy(GemmaEnv& env, const Path& text,
                           size_t batch_tokens) {
   std::string input = ReadFileToString(text);
   std::vector<int> prompt = env.Tokenize(input);
-  prompt.resize(std::min(env.MaxTokens(), prompt.size()));
   std::cout << "Number of input tokens: " << prompt.size() << "\n";
   const double time_start = hwy::platform::Now();
   float total_entropy = 0.0f;
@@ -156,11 +156,11 @@ int BenchmarkTriviaQA(GemmaEnv& env, const Path& json_file,
   while (std::getline(trivia_file, line)) {
     json data = json::parse(line);
     std::string q(data["question"]);
-    const auto [answer, token_count] = env.QueryModel(q);
-    std::cout << answer << "\n";
+    QueryResult result = env.QueryModel(q);
+    std::cout << result.response << "\n";
     bool correct = false;
     for (const std::string expected : data["answer"]["aliases"]) {
-      if (answer.find(expected) != std::string::npos) {
+      if (result.response.find(expected) != std::string::npos) {
         correct = true;
         break;
       }
diff --git a/evals/benchmark_helper.cc b/evals/benchmark_helper.cc
index dca4fb8..63553aa 100644
--- a/evals/benchmark_helper.cc
+++ b/evals/benchmark_helper.cc
@@ -18,14 +18,12 @@
 #include
 #include
 
-#include
 #include
 #include
 #include
 #include
 #include
 #include
-#include <utility>  // std::pair
 #include
 
 // Placeholder for internal header, do not modify.
@@ -76,7 +74,6 @@ GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
   }
   InitGenerator(inference, gen_);
   runtime_config_ = {
-      .max_tokens = inference.max_tokens,
       .max_generated_tokens = inference.max_generated_tokens,
       .temperature = inference.temperature,
       .verbosity = app.verbosity,
@@ -99,21 +96,21 @@ GemmaEnv::GemmaEnv(int argc, char** argv)
     : GemmaEnv(LoaderArgs(argc, argv), InferenceArgs(argc, argv),
                MakeAppArgs(argc, argv)) {}
 
-std::pair<std::string, size_t> GemmaEnv::QueryModel(
-    const std::vector<int>& tokens) {
-  std::string res;
-  size_t total_tokens = 0;
+QueryResult GemmaEnv::QueryModel(const std::vector<int>& tokens) {
+  QueryResult result;
 
-  const BatchStreamFunc batch_stream_token = [&res, &total_tokens, this](
-                                                 size_t query_index, size_t pos,
-                                                 int token, float) {
-    ++total_tokens;
-    res += StringFromTokens(std::vector<int>{token});
-    return true;
-  };
+  const BatchStreamFunc batch_stream_token =
+      [&result, &tokens, this](size_t /*query_index*/, size_t /*pos*/,
+                               int token, float /*score*/) {
+        ++result.tokens_generated;
+        result.response += StringFromTokens(std::vector<int>{token});
+        if (result.tokens_generated == tokens.size()) {
+          result.response_start_pos = result.response.size();
+        }
+        return true;
+      };
   if (runtime_config_.verbosity >= 2) {
-    std::cout << "Max tokens: " << runtime_config_.max_tokens
-              << "\tmax generated tokens: "
+    std::cout << "max generated tokens: "
              << runtime_config_.max_generated_tokens
              << "\ttemperature: " << runtime_config_.temperature << "\n";
   }
@@ -121,7 +118,7 @@ std::pair<std::string, size_t> GemmaEnv::QueryModel(
   runtime_config_.batch_stream_token = batch_stream_token;
   model_->Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0],
                    timing_info);
-  return {res, total_tokens};
+  return result;
 }
 
 void GemmaEnv::QueryModel(
@@ -134,27 +131,29 @@ void GemmaEnv::QueryModel(
   runtime_config_.stream_token = previous_stream_token;
 }
 
-std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel(
+std::vector<QueryResult> GemmaEnv::BatchQueryModel(
     const QueriesPromptTokens& queries_prompt) {
   const size_t num_queries = queries_prompt.size();
   HWY_ASSERT(num_queries != 0);
-  std::vector<std::pair<std::string, size_t>> res(num_queries);
-  std::fill(res.begin(), res.end(), std::make_pair("", 0));
-  const BatchStreamFunc batch_stream_token = [&res, this](size_t query_index,
-                                                          size_t pos, int token,
-                                                          float) {
+  std::vector<QueryResult> res(num_queries);
+  const BatchStreamFunc batch_stream_token = [&res, &queries_prompt, this](
                                                 size_t query_index, size_t pos,
                                                 int token, float) {
     std::string token_text;
     HWY_ASSERT(
        model_->Tokenizer().Decode(std::vector<int>{token}, &token_text));
-    res[query_index].first.append(token_text);
-    res[query_index].second += 1;
+    res[query_index].response.append(token_text);
+    res[query_index].tokens_generated += 1;
+    if (res[query_index].tokens_generated ==
+        queries_prompt[query_index].size()) {
+      res[query_index].response_start_pos = res[query_index].response.size();
+    }
     return true;
   };
   if (runtime_config_.verbosity >= 2) {
-    fprintf(stderr,
-            "Max tok: %zu max gen: %zu temp: %f tbatch: %zu qbatch: %zu\n",
-            runtime_config_.max_tokens, runtime_config_.max_generated_tokens,
-            runtime_config_.temperature, runtime_config_.prefill_tbatch_size,
+    fprintf(stderr, "Max gen: %zu temp: %f tbatch: %zu qbatch: %zu\n",
+            runtime_config_.max_generated_tokens, runtime_config_.temperature,
+            runtime_config_.prefill_tbatch_size,
             runtime_config_.decode_qbatch_size);
   }
 
@@ -178,21 +177,18 @@ std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel(
   return res;
 }
 
-std::pair<std::string, size_t> GemmaEnv::QueryModel(std::string& input) {
-  const std::vector<int> prompt =
-      WrapAndTokenize(model_->Tokenizer(), model_->Info(),
-                      /*pos=*/0, input);
+QueryResult GemmaEnv::QueryModel(std::string& input) {
+  const std::vector<int> prompt = WrapAndTokenize(input);
   return QueryModel(prompt);
 }
 
-std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel(
+std::vector<QueryResult> GemmaEnv::BatchQueryModel(
     const std::vector<std::string>& inputs) {
   std::vector<std::vector<int>> prompts;
   prompts.reserve(inputs.size());
   for (auto& input : inputs) {
     std::string mutable_prompt = input;
-    prompts.push_back(WrapAndTokenize(model_->Tokenizer(), model_->Info(),
-                                      /*pos=*/0, mutable_prompt));
+    prompts.push_back(WrapAndTokenize(mutable_prompt));
   }
   std::vector<PromptTokens> prompt_vector;
   prompt_vector.reserve(prompts.size());
@@ -206,7 +202,7 @@ std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel(
 float GemmaEnv::CrossEntropy(const std::string& input) {
   std::vector<int> prompt = Tokenize(input);
   prompt.insert(prompt.begin(), BOS_ID);
-  return ComputeCrossEntropy(*GetModel(), /*max_tokens=*/3072, prompt,
+  return ComputeCrossEntropy(*GetModel(), /*max_generated_tokens=*/3072, prompt,
                              MutableKVCache(), /*verbosity=*/0) /
          static_cast<float>(input.size());
 
diff --git a/evals/benchmark_helper.h b/evals/benchmark_helper.h
index 03a9f07..397fc20 100644
--- a/evals/benchmark_helper.h
+++ b/evals/benchmark_helper.h
@@ -21,7 +21,6 @@
 #include
 #include
 #include
-#include <utility>
 #include
 
 #include "gemma/gemma.h"
@@ -33,6 +32,14 @@ namespace gcpp {
 
 void InitGenerator(const InferenceArgs& inference, std::mt19937& gen);
 
+// Return type for query model calls.
+struct QueryResult {
+  std::string response;
+  size_t tokens_generated = 0;
+  // The position in the response at which the generated tokens start.
+  size_t response_start_pos = 0;
+};
+
 // Convenience class to load a model and run inference.
 class GemmaEnv {
  public:
@@ -41,8 +48,9 @@ class GemmaEnv {
   GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
            const AppArgs& app);
 
-  size_t MaxTokens() const { return runtime_config_.max_tokens; }
-  // Sets the maximum number of output tokens to generate.
+  size_t MaxGeneratedTokens() const {
+    return runtime_config_.max_generated_tokens;
+  }
   void SetMaxGeneratedTokens(size_t max_generated_tokens) {
     runtime_config_.max_generated_tokens = max_generated_tokens;
   }
@@ -59,6 +67,10 @@ class GemmaEnv {
     return tokens;
   }
 
+  std::vector<int> WrapAndTokenize(std::string& input) const {
+    return gcpp::WrapAndTokenize(model_->Tokenizer(), model_->Info(), 0, input);
+  }
+
   std::string StringFromTokens(const std::vector<int>& tokens) const {
     std::string string;
     HWY_ASSERT(model_->Tokenizer().Decode(tokens, &string));
@@ -67,12 +79,12 @@ class GemmaEnv {
   // Runs inference on the given input and returns the top-1 result string and
   // the number of tokens that were generated.
-  std::pair<std::string, size_t> QueryModel(const std::vector<int>& tokens);
-  std::vector<std::pair<std::string, size_t>> BatchQueryModel(
+  QueryResult QueryModel(const std::vector<int>& tokens);
+  std::vector<QueryResult> BatchQueryModel(
       const QueriesPromptTokens& queries_prompt);
 
   // Adds turn structure to input, tokenizes and calls the above overload.
-  std::pair<std::string, size_t> QueryModel(std::string& input);
-  std::vector<std::pair<std::string, size_t>> BatchQueryModel(
+  QueryResult QueryModel(std::string& input);
+  std::vector<QueryResult> BatchQueryModel(
      const std::vector<std::string>& inputs);
 
   // Runs inference on the given input and calls the callback for each token.
diff --git a/evals/benchmarks.cc b/evals/benchmarks.cc
index 5e7cf15..3cb3d3f 100644
--- a/evals/benchmarks.cc
+++ b/evals/benchmarks.cc
@@ -33,11 +33,11 @@ void RunPrompt(const std::string& original_prompt, benchmark::State& state) {
   size_t total_tokens = 0;
   for (auto s : state) {
     std::string prompt = original_prompt;  // reset from original
-    auto [response, n] = s_env->QueryModel(prompt);
+    QueryResult result = s_env->QueryModel(prompt);
     if (s_env->Verbosity() != 0) {
-      fprintf(stdout, "|%s|\n", response.c_str());
+      fprintf(stdout, "|%s|\n", result.response.c_str());
     }
-    total_tokens += n;
+    total_tokens += result.tokens_generated;
   }
 
   state.SetItemsProcessed(total_tokens);
diff --git a/evals/cross_entropy.cc b/evals/cross_entropy.cc
index 39600b5..870f84c 100644
--- a/evals/cross_entropy.cc
+++ b/evals/cross_entropy.cc
@@ -97,7 +97,7 @@ namespace gcpp {
 
 HWY_EXPORT(CallSoftmax);
 
-float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
+float ComputeCrossEntropy(Gemma& gemma, size_t max_generated_tokens,
                           const std::vector<int>& prompt,
                           KVCache& kv_cache, int verbosity) {
   const StreamFunc stream_token = [](int /*token*/, float) { return true; };
@@ -112,8 +112,8 @@ float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
                             size_t vocab_size) -> TokenAndProb {
     // input is logits, not yet probabilities
     HWY_DYNAMIC_DISPATCH(CallSoftmax)(probs, vocab_size);
-    // We are called for each token, but pos starts at 1. Clamping max_tokens
-    // to prompt.size() should prevent overrun.
+    // We are called for each token, but pos starts at 1. Clamping
+    // max_generated_tokens to prompt.size() should prevent overrun.
     HWY_ASSERT(pos < prompt.size());
     const int token = prompt[pos];
     const float prob = probs[token];
@@ -136,10 +136,9 @@ float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
   };
 
   std::vector<int> prompt0 = { prompt[0] };
-  max_tokens = HWY_MIN(max_tokens, prompt.size());
+  max_generated_tokens = HWY_MIN(max_generated_tokens, prompt.size());
   RuntimeConfig runtime = {
-      .max_tokens = max_tokens,
-      .max_generated_tokens = max_tokens - 1,
+      .max_generated_tokens = max_generated_tokens - 1,
       .temperature = 0.0f,
       .verbosity = verbosity,
       .gen = nullptr,
diff --git a/evals/cross_entropy.h b/evals/cross_entropy.h
index 202ce22..fed224c 100644
--- a/evals/cross_entropy.h
+++ b/evals/cross_entropy.h
@@ -24,7 +24,7 @@
 
 namespace gcpp {
 
-float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
+float ComputeCrossEntropy(Gemma& gemma, size_t max_generated_tokens,
                           const std::vector<int>& prompt,
                           KVCache& kv_cache, int verbosity);
 
diff --git a/evals/debug_prompt.cc b/evals/debug_prompt.cc
index 7ea32fa..2d02b3a 100644
--- a/evals/debug_prompt.cc
+++ b/evals/debug_prompt.cc
@@ -68,8 +68,9 @@ int Run(int argc, char** argv) {
     json_base[std::to_string(pos)][debug_key] = v;
   };
 
-  const auto [answer, token_count] = env.QueryModel(prompt_args.prompt);
-  std::cout << answer.substr(prompt_args.prompt.size()) << "\n" << std::flush;
+  QueryResult result = env.QueryModel(prompt_args.prompt);
+  std::cout << result.response.substr(result.response_start_pos) << "\n"
+            << std::flush;
 
   if (env.MutableConfig().layers_output) {
     std::ofstream output_f(prompt_args.layers_output.path, std::ofstream::out);
diff --git a/evals/gemma_test.cc b/evals/gemma_test.cc
index 6f36b7a..98029fe 100644
--- a/evals/gemma_test.cc
+++ b/evals/gemma_test.cc
@@ -52,13 +52,13 @@ class GemmaTest : public ::testing::Test {
     if (s_env->GetModel()->Info().model == Model::GEMMA2_27B ||
         s_env->GetModel()->Info().model == Model::GRIFFIN_2B) {
       std::string mutable_prompt = prompt;
-      auto [response, n] = s_env->QueryModel(mutable_prompt);  // Uses turns.
-      return response;
+      QueryResult result = s_env->QueryModel(mutable_prompt);  // Uses turns.
+      return result.response;
     }
     // Otherwise, do not use turn structure.
     const std::vector<int> tokens = s_env->TokenizeAndPrependBOS(prompt);
-    auto [response, n] = s_env->QueryModel(tokens);
-    return response;
+    QueryResult result = s_env->QueryModel(tokens);
+    return result.response;
   }
 
   std::vector<std::string> BatchGemmaReply(
@@ -72,8 +72,8 @@ class GemmaTest : public ::testing::Test {
     // It would be good to make these tests more consistent.
     if (s_env->GetModel()->Info().model == Model::GEMMA2_27B ||
         s_env->GetModel()->Info().model == Model::GRIFFIN_2B) {
-      for (auto [response, n] : s_env->BatchQueryModel(inputs)) {
-        replies.push_back(response);
+      for (QueryResult result : s_env->BatchQueryModel(inputs)) {
+        replies.push_back(result.response);
       }
       return replies;
     }
@@ -88,8 +88,8 @@ class GemmaTest : public ::testing::Test {
       prompt_spans.push_back(PromptTokens(prompt.data(), prompt.size()));
     }
     QueriesPromptTokens prompts(prompt_spans.data(), prompt_spans.size());
-    for (auto [response, n] : s_env->BatchQueryModel(prompts)) {
-      replies.push_back(response);
+    for (const QueryResult& result : s_env->BatchQueryModel(prompts)) {
+      replies.push_back(result.response);
     }
     return replies;
   }
@@ -167,7 +167,6 @@ TEST_F(GemmaTest, Multiturn) {
     return true;
   };
   RuntimeConfig runtime_config{
-      .max_tokens = 128,
       .max_generated_tokens = 64,
       .temperature = 0.0f,
       .verbosity = 2,
diff --git a/evals/run_mmlu.cc b/evals/run_mmlu.cc
index 729f7e9..d3618db 100644
--- a/evals/run_mmlu.cc
+++ b/evals/run_mmlu.cc
@@ -103,9 +103,7 @@ void Run(GemmaEnv& env, JsonArgs& json) {
       "What is start of the line with the correct answer? "
       "Do not include any justifications or explanations. Reply only with a "
       "letter.";
-  const std::vector<int> prompt =
-      WrapAndTokenize(env.GetModel()->Tokenizer(), env.GetModel()->Info(),
-                      /*pos=*/0, prompt_string);
+  const std::vector<int> prompt = env.WrapAndTokenize(prompt_string);
   const size_t prompt_size = prompt.size();
 
   std::vector<int> predicted_token_ids;
@@ -127,7 +125,6 @@ void Run(GemmaEnv& env, JsonArgs& json) {
   // confused with the word "A".
   gcpp::TimingInfo timing_info;
   gcpp::RuntimeConfig runtime_config = {
-      .max_tokens = env.MaxTokens(),
       .max_generated_tokens = 30,
       .temperature = 0.0f,
       .verbosity = env.Verbosity(),
diff --git a/examples/hello_world/run.cc b/examples/hello_world/run.cc
index cfd02e8..39d4f9c 100644
--- a/examples/hello_world/run.cc
+++ b/examples/hello_world/run.cc
@@ -87,7 +87,6 @@ int main(int argc, char** argv) {
 
   gcpp::TimingInfo timing_info;
   gcpp::RuntimeConfig runtime_config = {
-      .max_tokens = 1536,
       .max_generated_tokens = 1024,
       .temperature = 1.0,
       .verbosity = 0,
diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h
index 9e38490..5e3135b 100644
--- a/gemma/gemma-inl.h
+++ b/gemma/gemma-inl.h
@@ -1117,34 +1117,15 @@ HWY_NOINLINE void Transformer(
 }
 
 template <class TConfig>
-void RangeChecks(size_t& max_tokens, size_t& max_generated_tokens,
-                 size_t& prompt_size) {
+void RangeChecks(size_t& max_generated_tokens, const size_t prompt_size) {
   if (!TConfig::kUseLocalAttention) {
-    if (max_tokens > TConfig::kSeqLen) {
-      fprintf(stderr, "WARNING: max_tokens %zu > kSeqLen %d, truncating.\n",
-              max_tokens, TConfig::kSeqLen);
-      max_tokens = static_cast<size_t>(TConfig::kSeqLen);
-    }
-  }
-
-  if (max_generated_tokens > max_tokens) {
-    fprintf(stderr,
-            "WARNING: max_generated_tokens %zu > max_tokens %zu, truncating.\n",
-            max_generated_tokens, max_tokens);
-    max_generated_tokens = max_tokens - 1;
-  }
-
-  if (!TConfig::kUseLocalAttention) {
-    if (prompt_size + max_generated_tokens > max_tokens) {
+    if (max_generated_tokens > TConfig::kSeqLen) {
       fprintf(stderr,
-              "WARNING: prompt_size %zu + max_generated_tokens %zu > "
-              "max_tokens %zu, truncating to ",
-              prompt_size, max_generated_tokens, max_tokens);
-      prompt_size = std::min(prompt_size, max_tokens - max_generated_tokens);
-      fprintf(stderr, "%zu\n", prompt_size);
+              "WARNING: max_generated_tokens %zu > kSeqLen %d, truncating.\n",
+              max_generated_tokens, TConfig::kSeqLen);
+      max_generated_tokens = static_cast<size_t>(TConfig::kSeqLen);
     }
   }
-
   HWY_ASSERT(prompt_size > 0);
 }
 
@@ -1262,17 +1243,8 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
 
   const hwy::Divisor div_seq_len(static_cast<uint32_t>(kv_caches[0].seq_len));
   size_t max_prompt_size = MaxQueryLength(queries_prompt);
-  size_t max_tokens = runtime_config.max_tokens;
   size_t max_generated_tokens = runtime_config.max_generated_tokens;
-  RangeChecks<TConfig>(max_tokens, max_generated_tokens, max_prompt_size);
-  for (size_t pos : queries_pos_copy) {
-    if (pos >= max_tokens) {
-      fprintf(stderr, "Warning: pos %zu >= max_tokens %zu, aborting.\n", pos,
-              max_tokens);
-      return;
-    }
-  }
-
+  RangeChecks<TConfig>(max_generated_tokens, max_prompt_size);
   const SampleFunc sample_token = ChooseSampleFunc(runtime_config);
 
   // Prefill stops before min_prompt_size - 1 because the last prompt
@@ -1315,7 +1287,7 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
   }
 
   const double gen_start = hwy::platform::Now();
-  for (size_t gen = 0; gen < HWY_MIN(max_tokens, max_generated_tokens); ++gen) {
+  for (size_t gen = 0; gen < max_generated_tokens; ++gen) {
     // Decode generates one token per query and increments queries_mutable_pos.
     Transformer(
         QueriesToken(gen_tokens.data(), num_queries), queries_mutable_pos,
diff --git a/gemma/gemma.h b/gemma/gemma.h
index 654871d..6b74008 100644
--- a/gemma/gemma.h
+++ b/gemma/gemma.h
@@ -92,8 +92,7 @@ struct RuntimeConfig {
     return stream_token(token, prob);
   }
 
-  // Limits on the number of tokens generated.
-  size_t max_tokens;
+  // Limit on the number of tokens generated.
   size_t max_generated_tokens;
 
   // These defaults are overridden by InferenceArgs::CopyTo(*this):
diff --git a/gemma/run.cc b/gemma/run.cc
index 032b694..42e54a4 100644
--- a/gemma/run.cc
+++ b/gemma/run.cc
@@ -131,7 +131,7 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const InferenceArgs& args,
     return true;
   };
 
-  while (abs_pos < args.max_tokens) {
+  while (true) {  // Loop until user quits.
     tokens_generated_this_turn = 0;
     std::string prompt_string = GetPrompt(std::cin, verbosity, eot_line);
     if (!std::cin) return;
@@ -183,10 +183,6 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const InferenceArgs& args,
                timing_info);
     std::cout << "\n\n";
   }
-  std::cout
-      << "max_tokens (" << args.max_tokens
-      << ") exceeded. Use a larger value if desired using the --max_tokens "
-      << "command line flag.\n";
 }
 
 void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
diff --git a/paligemma/paligemma_test.cc b/paligemma/paligemma_test.cc
index 3e2e9e8..2575b61 100644
--- a/paligemma/paligemma_test.cc
+++ b/paligemma/paligemma_test.cc
@@ -13,8 +13,6 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "gemma/gemma.h"
-
 #include
 #include
 #include
@@ -22,6 +20,7 @@
 
 #include "evals/benchmark_helper.h"
 #include "gemma/common.h"
+#include "gemma/gemma.h"
 #include "hwy/base.h"
 #include "hwy/tests/hwy_gtest.h"
 
@@ -63,15 +62,13 @@ void PaliGemmaTest::InitVit(const std::string& path) {
 std::string PaliGemmaTest::GemmaReply(const std::string& prompt_text) const{
   Gemma& model = *(s_env->GetModel());
   s_env->MutableGen().seed(0x12345678);
-  RuntimeConfig runtime_config = {.max_tokens = 1024,
-                                  .max_generated_tokens = 512,
+  RuntimeConfig runtime_config = {.max_generated_tokens = 512,
                                   .verbosity = 0,
                                   .gen = &s_env->MutableGen()};
   runtime_config.image_tokens = image_tokens_.get();
   size_t abs_pos = 0;
   std::string mutable_prompt = prompt_text;
-  std::vector<int> tokens =
-      WrapAndTokenize(model.Tokenizer(), model.Info(), abs_pos, mutable_prompt);
+  std::vector<int> tokens = s_env->WrapAndTokenize(mutable_prompt);
   std::string response;
   auto stream_token = [&](int token, float) {
     std::string token_text;
diff --git a/util/app.h b/util/app.h
index 69a1f88..e46e8df 100644
--- a/util/app.h
+++ b/util/app.h
@@ -207,7 +207,6 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
   InferenceArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
   InferenceArgs() { Init(); };
 
-  size_t max_tokens;
   size_t max_generated_tokens;
 
   size_t prefill_tbatch_size;
@@ -220,21 +219,15 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
 
   // Returns error string or nullptr if OK.
   const char* Validate() const {
-    if (max_tokens > gcpp::kSeqLen) {
-      return "max_tokens is larger than the maximum sequence length (see "
-             "configs.h).";
-    }
-    if (max_generated_tokens > max_tokens) {
-      return "Maximum number of generated tokens is larger than the maximum "
-             "total tokens.";
+    if (max_generated_tokens > gcpp::kSeqLen) {
+      return "max_generated_tokens is larger than the maximum sequence length "
+             "(see configs.h).";
     }
     return nullptr;
   }
 
   template <class Visitor>
   void ForEach(const Visitor& visitor) {
-    visitor(max_tokens, "max_tokens", size_t{3072},
-            "Maximum number of tokens in prompt + generation.");
     visitor(max_generated_tokens, "max_generated_tokens", size_t{2048},
             "Maximum number of tokens to generate.");
 
@@ -255,12 +248,9 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
   }
 
   void CopyTo(RuntimeConfig& runtime_config) const {
-    runtime_config.max_tokens = max_tokens;
     runtime_config.max_generated_tokens = max_generated_tokens;
-
     runtime_config.prefill_tbatch_size = prefill_tbatch_size;
     runtime_config.decode_qbatch_size = decode_qbatch_size;
-
     runtime_config.temperature = temperature;
   }
 };
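
Usage note (addendum, not part of the diff): the sketch below shows how calling code might consume the new QueryResult API after this change, based on the GemmaEnv and QueryResult declarations in evals/benchmark_helper.h above. The prompt text and the 256-token limit are illustrative; flags are parsed from argv as in the existing tools, except that --max_tokens no longer exists and only --max_generated_tokens remains.

#include <iostream>
#include <string>

#include "evals/benchmark_helper.h"

int main(int argc, char** argv) {
  // Loads the model from the usual loader/inference/app flags.
  gcpp::GemmaEnv env(argc, argv);
  env.SetMaxGeneratedTokens(256);  // Output-length budget; the separate max_tokens limit is gone.

  std::string prompt = "What is the capital of France?";
  // Adds turn structure, tokenizes, and generates.
  gcpp::QueryResult result = env.QueryModel(prompt);

  // response holds the streamed prompt tokens followed by the generation;
  // response_start_pos marks where the generated text begins, so the echoed
  // prompt can be skipped without knowing its character length.
  std::cout << result.response.substr(result.response_start_pos) << "\n";
  std::cout << "tokens generated: " << result.tokens_generated << "\n";
  return 0;
}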