Introduce QueryResult in GemmaEnv and add a shortcut for WrapAndTokenize.

Remove max_tokens (and rely on only max_generated_tokens).

PiperOrigin-RevId: 685662260
This commit is contained in:
Daniel Keysers 2024-10-14 04:44:42 -07:00 committed by Copybara-Service
parent 2892e232e2
commit a4d6adbc43
17 changed files with 99 additions and 144 deletions

View File

@ -385,7 +385,6 @@ tokenizer : tokenizer.spm
compressed_weights : 2b-it-sfp.sbs
model : 2b-it
weights : [no path specified]
max_tokens : 3072
max_generated_tokens : 2048
*Usage*

View File

@ -69,7 +69,6 @@ TEST(OptimizeTest, GradientDescent) {
return token != ReverseSequenceSampler::kEndToken;
};
RuntimeConfig runtime = {
.max_tokens = 32,
.max_generated_tokens = 16,
.temperature = 1.0f,
.verbosity = 0,

View File

@ -81,15 +81,15 @@ int BenchmarkGoldens(GemmaEnv& env, const std::string& golden_path) {
size_t total_tokens = 0;
const double time_start = hwy::platform::Now();
for (auto& [question, expected_answer] : queries_answers) {
const auto [answer, token_count] = env.QueryModel(question);
total_tokens += token_count;
if (answer.find(expected_answer) != std::string::npos) {
QueryResult result = env.QueryModel(question);
total_tokens += result.tokens_generated;
if (result.response.find(expected_answer) != std::string::npos) {
correct_answers++;
} else {
std::cout << "Wrong!\n";
std::cout << "Input: " << question << "\n";
std::cout << "Expected: " << expected_answer << "\n";
std::cout << "Output: " << answer << "\n\n" << std::flush;
std::cout << "Output: " << result.response << "\n\n" << std::flush;
}
}
LogSpeedStats(time_start, total_tokens);
@ -108,9 +108,10 @@ int BenchmarkSummary(GemmaEnv& env, const Path& text) {
prompt.append(ReadFileToString(text));
prompt.append("\nSummarize this text.\n");
const double time_start = hwy::platform::Now();
const auto [answer, token_count] = env.QueryModel(prompt);
std::cout << answer.substr(prompt.size()) << "\n" << std::flush;
LogSpeedStats(time_start, token_count);
QueryResult result = env.QueryModel(prompt);
std::cout << result.response.substr(result.response_start_pos) << "\n"
<< std::flush;
LogSpeedStats(time_start, result.tokens_generated);
return EXIT_SUCCESS;
}
@ -118,7 +119,6 @@ int BenchmarkCrossEntropy(GemmaEnv& env, const Path& text,
size_t batch_tokens) {
std::string input = ReadFileToString(text);
std::vector<int> prompt = env.Tokenize(input);
prompt.resize(std::min<size_t>(env.MaxTokens(), prompt.size()));
std::cout << "Number of input tokens: " << prompt.size() << "\n";
const double time_start = hwy::platform::Now();
float total_entropy = 0.0f;
@ -156,11 +156,11 @@ int BenchmarkTriviaQA(GemmaEnv& env, const Path& json_file,
while (std::getline(trivia_file, line)) {
json data = json::parse(line);
std::string q(data["question"]);
const auto [answer, token_count] = env.QueryModel(q);
std::cout << answer << "\n";
QueryResult result = env.QueryModel(q);
std::cout << result.response << "\n";
bool correct = false;
for (const std::string expected : data["answer"]["aliases"]) {
if (answer.find(expected) != std::string::npos) {
if (result.response.find(expected) != std::string::npos) {
correct = true;
break;
}

View File

@ -18,14 +18,12 @@
#include <stdio.h>
#include <time.h>
#include <algorithm>
#include <cstdio>
#include <iostream>
#include <memory>
#include <ostream>
#include <random>
#include <string>
#include <utility> // std::pair
#include <vector>
// Placeholder for internal header, do not modify.
@ -76,7 +74,6 @@ GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
}
InitGenerator(inference, gen_);
runtime_config_ = {
.max_tokens = inference.max_tokens,
.max_generated_tokens = inference.max_generated_tokens,
.temperature = inference.temperature,
.verbosity = app.verbosity,
@ -99,21 +96,21 @@ GemmaEnv::GemmaEnv(int argc, char** argv)
: GemmaEnv(LoaderArgs(argc, argv), InferenceArgs(argc, argv),
MakeAppArgs(argc, argv)) {}
std::pair<std::string, size_t> GemmaEnv::QueryModel(
const std::vector<int>& tokens) {
std::string res;
size_t total_tokens = 0;
QueryResult GemmaEnv::QueryModel(const std::vector<int>& tokens) {
QueryResult result;
const BatchStreamFunc batch_stream_token = [&res, &total_tokens, this](
size_t query_index, size_t pos,
int token, float) {
++total_tokens;
res += StringFromTokens(std::vector<int>{token});
const BatchStreamFunc batch_stream_token =
[&result, &tokens, this](size_t /*query_index*/, size_t /*pos*/,
int token, float /*score*/) {
++result.tokens_generated;
result.response += StringFromTokens(std::vector<int>{token});
if (result.tokens_generated == tokens.size()) {
result.response_start_pos = result.response.size();
}
return true;
};
if (runtime_config_.verbosity >= 2) {
std::cout << "Max tokens: " << runtime_config_.max_tokens
<< "\tmax generated tokens: "
std::cout << "max generated tokens: "
<< runtime_config_.max_generated_tokens
<< "\ttemperature: " << runtime_config_.temperature << "\n";
}
@ -121,7 +118,7 @@ std::pair<std::string, size_t> GemmaEnv::QueryModel(
runtime_config_.batch_stream_token = batch_stream_token;
model_->Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0],
timing_info);
return {res, total_tokens};
return result;
}
void GemmaEnv::QueryModel(
@ -134,27 +131,29 @@ void GemmaEnv::QueryModel(
runtime_config_.stream_token = previous_stream_token;
}
std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel(
std::vector<QueryResult> GemmaEnv::BatchQueryModel(
const QueriesPromptTokens& queries_prompt) {
const size_t num_queries = queries_prompt.size();
HWY_ASSERT(num_queries != 0);
std::vector<std::pair<std::string, size_t>> res(num_queries);
std::fill(res.begin(), res.end(), std::make_pair("", 0));
const BatchStreamFunc batch_stream_token = [&res, this](size_t query_index,
size_t pos, int token,
float) {
std::vector<QueryResult> res(num_queries);
const BatchStreamFunc batch_stream_token = [&res, &queries_prompt, this](
size_t query_index, size_t pos,
int token, float) {
std::string token_text;
HWY_ASSERT(
model_->Tokenizer().Decode(std::vector<int>{token}, &token_text));
res[query_index].first.append(token_text);
res[query_index].second += 1;
res[query_index].response.append(token_text);
res[query_index].tokens_generated += 1;
if (res[query_index].tokens_generated ==
queries_prompt[query_index].size()) {
res[query_index].response_start_pos = res[query_index].response.size();
}
return true;
};
if (runtime_config_.verbosity >= 2) {
fprintf(stderr,
"Max tok: %zu max gen: %zu temp: %f tbatch: %zu qbatch: %zu\n",
runtime_config_.max_tokens, runtime_config_.max_generated_tokens,
runtime_config_.temperature, runtime_config_.prefill_tbatch_size,
fprintf(stderr, "Max gen: %zu temp: %f tbatch: %zu qbatch: %zu\n",
runtime_config_.max_generated_tokens, runtime_config_.temperature,
runtime_config_.prefill_tbatch_size,
runtime_config_.decode_qbatch_size);
}
@ -178,21 +177,18 @@ std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel(
return res;
}
std::pair<std::string, size_t> GemmaEnv::QueryModel(std::string& input) {
const std::vector<int> prompt =
WrapAndTokenize(model_->Tokenizer(), model_->Info(),
/*pos=*/0, input);
QueryResult GemmaEnv::QueryModel(std::string& input) {
const std::vector<int> prompt = WrapAndTokenize(input);
return QueryModel(prompt);
}
std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel(
std::vector<QueryResult> GemmaEnv::BatchQueryModel(
const std::vector<std::string>& inputs) {
std::vector<std::vector<int>> prompts;
prompts.reserve(inputs.size());
for (auto& input : inputs) {
std::string mutable_prompt = input;
prompts.push_back(WrapAndTokenize(model_->Tokenizer(), model_->Info(),
/*pos=*/0, mutable_prompt));
prompts.push_back(WrapAndTokenize(mutable_prompt));
}
std::vector<PromptTokens> prompt_vector;
prompt_vector.reserve(prompts.size());
@ -206,7 +202,7 @@ std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel(
float GemmaEnv::CrossEntropy(const std::string& input) {
std::vector<int> prompt = Tokenize(input);
prompt.insert(prompt.begin(), BOS_ID);
return ComputeCrossEntropy(*GetModel(), /*max_tokens=*/3072, prompt,
return ComputeCrossEntropy(*GetModel(), /*max_generated_tokens=*/3072, prompt,
MutableKVCache(),
/*verbosity=*/0) /
static_cast<int>(input.size());

View File

@ -21,7 +21,6 @@
#include <memory>
#include <random>
#include <string>
#include <utility>
#include <vector>
#include "gemma/gemma.h"
@ -33,6 +32,14 @@ namespace gcpp {
void InitGenerator(const InferenceArgs& inference, std::mt19937& gen);
// Return type for query model calls.
struct QueryResult {
std::string response;
size_t tokens_generated = 0;
// The position in the response at which the generated tokens start.
size_t response_start_pos = 0;
};
// Convenience class to load a model and run inference.
class GemmaEnv {
public:
@ -41,8 +48,9 @@ class GemmaEnv {
GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
const AppArgs& app);
size_t MaxTokens() const { return runtime_config_.max_tokens; }
// Sets the maximum number of output tokens to generate.
size_t MaxGeneratedTokens() const {
return runtime_config_.max_generated_tokens;
}
void SetMaxGeneratedTokens(size_t max_generated_tokens) {
runtime_config_.max_generated_tokens = max_generated_tokens;
}
@ -59,6 +67,10 @@ class GemmaEnv {
return tokens;
}
std::vector<int> WrapAndTokenize(std::string& input) const {
return gcpp::WrapAndTokenize(model_->Tokenizer(), model_->Info(), 0, input);
}
std::string StringFromTokens(const std::vector<int>& tokens) const {
std::string string;
HWY_ASSERT(model_->Tokenizer().Decode(tokens, &string));
@ -67,12 +79,12 @@ class GemmaEnv {
// Runs inference on the given input and returns the top-1 result string and
// the number of tokens that were generated.
std::pair<std::string, size_t> QueryModel(const std::vector<int>& tokens);
std::vector<std::pair<std::string, size_t>> BatchQueryModel(
QueryResult QueryModel(const std::vector<int>& tokens);
std::vector<QueryResult> BatchQueryModel(
const QueriesPromptTokens& queries_prompt);
// Adds turn structure to input, tokenizes and calls the above overload.
std::pair<std::string, size_t> QueryModel(std::string& input);
std::vector<std::pair<std::string, size_t>> BatchQueryModel(
QueryResult QueryModel(std::string& input);
std::vector<QueryResult> BatchQueryModel(
const std::vector<std::string>& inputs);
// Runs inference on the given input and calls the callback for each token.

View File

@ -33,11 +33,11 @@ void RunPrompt(const std::string& original_prompt, benchmark::State& state) {
size_t total_tokens = 0;
for (auto s : state) {
std::string prompt = original_prompt; // reset from original
auto [response, n] = s_env->QueryModel(prompt);
QueryResult result = s_env->QueryModel(prompt);
if (s_env->Verbosity() != 0) {
fprintf(stdout, "|%s|\n", response.c_str());
fprintf(stdout, "|%s|\n", result.response.c_str());
}
total_tokens += n;
total_tokens += result.tokens_generated;
}
state.SetItemsProcessed(total_tokens);

View File

@ -97,7 +97,7 @@ namespace gcpp {
HWY_EXPORT(CallSoftmax);
float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
float ComputeCrossEntropy(Gemma& gemma, size_t max_generated_tokens,
const std::vector<int>& prompt, KVCache& kv_cache,
int verbosity) {
const StreamFunc stream_token = [](int /*token*/, float) { return true; };
@ -112,8 +112,8 @@ float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
size_t vocab_size) -> TokenAndProb {
// input is logits, not yet probabilities
HWY_DYNAMIC_DISPATCH(CallSoftmax)(probs, vocab_size);
// We are called for each token, but pos starts at 1. Clamping max_tokens
// to prompt.size() should prevent overrun.
// We are called for each token, but pos starts at 1. Clamping
// max_generated_tokens to prompt.size() should prevent overrun.
HWY_ASSERT(pos < prompt.size());
const int token = prompt[pos];
const float prob = probs[token];
@ -136,10 +136,9 @@ float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
};
std::vector<int> prompt0 = { prompt[0] };
max_tokens = HWY_MIN(max_tokens, prompt.size());
max_generated_tokens = HWY_MIN(max_generated_tokens, prompt.size());
RuntimeConfig runtime = {
.max_tokens = max_tokens,
.max_generated_tokens = max_tokens - 1,
.max_generated_tokens = max_generated_tokens - 1,
.temperature = 0.0f,
.verbosity = verbosity,
.gen = nullptr,

View File

@ -24,7 +24,7 @@
namespace gcpp {
float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
float ComputeCrossEntropy(Gemma& gemma, size_t max_generated_tokens,
const std::vector<int>& prompt, KVCache& kv_cache,
int verbosity);

View File

@ -68,8 +68,9 @@ int Run(int argc, char** argv) {
json_base[std::to_string(pos)][debug_key] = v;
};
const auto [answer, token_count] = env.QueryModel(prompt_args.prompt);
std::cout << answer.substr(prompt_args.prompt.size()) << "\n" << std::flush;
QueryResult result = env.QueryModel(prompt_args.prompt);
std::cout << result.response.substr(result.response_start_pos) << "\n"
<< std::flush;
if (env.MutableConfig().layers_output) {
std::ofstream output_f(prompt_args.layers_output.path, std::ofstream::out);

View File

@ -52,13 +52,13 @@ class GemmaTest : public ::testing::Test {
if (s_env->GetModel()->Info().model == Model::GEMMA2_27B ||
s_env->GetModel()->Info().model == Model::GRIFFIN_2B) {
std::string mutable_prompt = prompt;
auto [response, n] = s_env->QueryModel(mutable_prompt); // Uses turns.
return response;
QueryResult result = s_env->QueryModel(mutable_prompt); // Uses turns.
return result.response;
}
// Otherwise, do not use turn structure.
const std::vector<int> tokens = s_env->TokenizeAndPrependBOS(prompt);
auto [response, n] = s_env->QueryModel(tokens);
return response;
QueryResult result = s_env->QueryModel(tokens);
return result.response;
}
std::vector<std::string> BatchGemmaReply(
@ -72,8 +72,8 @@ class GemmaTest : public ::testing::Test {
// It would be good to make these tests more consistent.
if (s_env->GetModel()->Info().model == Model::GEMMA2_27B ||
s_env->GetModel()->Info().model == Model::GRIFFIN_2B) {
for (auto [response, n] : s_env->BatchQueryModel(inputs)) {
replies.push_back(response);
for (QueryResult result : s_env->BatchQueryModel(inputs)) {
replies.push_back(result.response);
}
return replies;
}
@ -88,8 +88,8 @@ class GemmaTest : public ::testing::Test {
prompt_spans.push_back(PromptTokens(prompt.data(), prompt.size()));
}
QueriesPromptTokens prompts(prompt_spans.data(), prompt_spans.size());
for (auto [response, n] : s_env->BatchQueryModel(prompts)) {
replies.push_back(response);
for (const QueryResult& result : s_env->BatchQueryModel(prompts)) {
replies.push_back(result.response);
}
return replies;
}
@ -167,7 +167,6 @@ TEST_F(GemmaTest, Multiturn) {
return true;
};
RuntimeConfig runtime_config{
.max_tokens = 128,
.max_generated_tokens = 64,
.temperature = 0.0f,
.verbosity = 2,

View File

@ -103,9 +103,7 @@ void Run(GemmaEnv& env, JsonArgs& json) {
"What is start of the line with the correct answer? "
"Do not include any justifications or explanations. Reply only with a "
"letter.";
const std::vector<int> prompt =
WrapAndTokenize(env.GetModel()->Tokenizer(), env.GetModel()->Info(),
/*pos=*/0, prompt_string);
const std::vector<int> prompt = env.WrapAndTokenize(prompt_string);
const size_t prompt_size = prompt.size();
std::vector<int> predicted_token_ids;
@ -127,7 +125,6 @@ void Run(GemmaEnv& env, JsonArgs& json) {
// confused with the word "A".
gcpp::TimingInfo timing_info;
gcpp::RuntimeConfig runtime_config = {
.max_tokens = env.MaxTokens(),
.max_generated_tokens = 30,
.temperature = 0.0f,
.verbosity = env.Verbosity(),

View File

@ -87,7 +87,6 @@ int main(int argc, char** argv) {
gcpp::TimingInfo timing_info;
gcpp::RuntimeConfig runtime_config = {
.max_tokens = 1536,
.max_generated_tokens = 1024,
.temperature = 1.0,
.verbosity = 0,

View File

@ -1117,34 +1117,15 @@ HWY_NOINLINE void Transformer(
}
template <class TConfig>
void RangeChecks(size_t& max_tokens, size_t& max_generated_tokens,
size_t& prompt_size) {
void RangeChecks(size_t& max_generated_tokens, const size_t prompt_size) {
if (!TConfig::kUseLocalAttention) {
if (max_tokens > TConfig::kSeqLen) {
fprintf(stderr, "WARNING: max_tokens %zu > kSeqLen %d, truncating.\n",
max_tokens, TConfig::kSeqLen);
max_tokens = static_cast<size_t>(TConfig::kSeqLen);
}
}
if (max_generated_tokens > max_tokens) {
if (max_generated_tokens > TConfig::kSeqLen) {
fprintf(stderr,
"WARNING: max_generated_tokens %zu > max_tokens %zu, truncating.\n",
max_generated_tokens, max_tokens);
max_generated_tokens = max_tokens - 1;
}
if (!TConfig::kUseLocalAttention) {
if (prompt_size + max_generated_tokens > max_tokens) {
fprintf(stderr,
"WARNING: prompt_size %zu + max_generated_tokens %zu > "
"max_tokens %zu, truncating to ",
prompt_size, max_generated_tokens, max_tokens);
prompt_size = std::min(prompt_size, max_tokens - max_generated_tokens);
fprintf(stderr, "%zu\n", prompt_size);
"WARNING: max_generated_tokens %zu > kSeqLen %d, truncating.\n",
max_generated_tokens, TConfig::kSeqLen);
max_generated_tokens = static_cast<size_t>(TConfig::kSeqLen);
}
}
HWY_ASSERT(prompt_size > 0);
}
@ -1262,17 +1243,8 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
const hwy::Divisor div_seq_len(static_cast<uint32_t>(kv_caches[0].seq_len));
size_t max_prompt_size = MaxQueryLength(queries_prompt);
size_t max_tokens = runtime_config.max_tokens;
size_t max_generated_tokens = runtime_config.max_generated_tokens;
RangeChecks<TConfig>(max_tokens, max_generated_tokens, max_prompt_size);
for (size_t pos : queries_pos_copy) {
if (pos >= max_tokens) {
fprintf(stderr, "Warning: pos %zu >= max_tokens %zu, aborting.\n", pos,
max_tokens);
return;
}
}
RangeChecks<TConfig>(max_generated_tokens, max_prompt_size);
const SampleFunc sample_token = ChooseSampleFunc<TConfig>(runtime_config);
// Prefill stops before min_prompt_size - 1 because the last prompt
@ -1315,7 +1287,7 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
}
const double gen_start = hwy::platform::Now();
for (size_t gen = 0; gen < HWY_MIN(max_tokens, max_generated_tokens); ++gen) {
for (size_t gen = 0; gen < max_generated_tokens; ++gen) {
// Decode generates one token per query and increments queries_mutable_pos.
Transformer<TConfig>(
QueriesToken(gen_tokens.data(), num_queries), queries_mutable_pos,

View File

@ -92,8 +92,7 @@ struct RuntimeConfig {
return stream_token(token, prob);
}
// Limits on the number of tokens generated.
size_t max_tokens;
// Limit on the number of tokens generated.
size_t max_generated_tokens;
// These defaults are overridden by InferenceArgs::CopyTo(*this):

View File

@ -131,7 +131,7 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const InferenceArgs& args,
return true;
};
while (abs_pos < args.max_tokens) {
while (true) { // Loop until user quits.
tokens_generated_this_turn = 0;
std::string prompt_string = GetPrompt(std::cin, verbosity, eot_line);
if (!std::cin) return;
@ -183,10 +183,6 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const InferenceArgs& args,
timing_info);
std::cout << "\n\n";
}
std::cout
<< "max_tokens (" << args.max_tokens
<< ") exceeded. Use a larger value if desired using the --max_tokens "
<< "command line flag.\n";
}
void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {

View File

@ -13,8 +13,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "gemma/gemma.h"
#include <cstdio>
#include <memory>
#include <string>
@ -22,6 +20,7 @@
#include "evals/benchmark_helper.h"
#include "gemma/common.h"
#include "gemma/gemma.h"
#include "hwy/base.h"
#include "hwy/tests/hwy_gtest.h"
@ -63,15 +62,13 @@ void PaliGemmaTest::InitVit(const std::string& path) {
std::string PaliGemmaTest::GemmaReply(const std::string& prompt_text) const{
Gemma& model = *(s_env->GetModel());
s_env->MutableGen().seed(0x12345678);
RuntimeConfig runtime_config = {.max_tokens = 1024,
.max_generated_tokens = 512,
RuntimeConfig runtime_config = {.max_generated_tokens = 512,
.verbosity = 0,
.gen = &s_env->MutableGen()};
runtime_config.image_tokens = image_tokens_.get();
size_t abs_pos = 0;
std::string mutable_prompt = prompt_text;
std::vector<int> tokens =
WrapAndTokenize(model.Tokenizer(), model.Info(), abs_pos, mutable_prompt);
std::vector<int> tokens = s_env->WrapAndTokenize(mutable_prompt);
std::string response;
auto stream_token = [&](int token, float) {
std::string token_text;

View File

@ -207,7 +207,6 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
InferenceArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
InferenceArgs() { Init(); };
size_t max_tokens;
size_t max_generated_tokens;
size_t prefill_tbatch_size;
@ -220,21 +219,15 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
// Returns error string or nullptr if OK.
const char* Validate() const {
if (max_tokens > gcpp::kSeqLen) {
return "max_tokens is larger than the maximum sequence length (see "
"configs.h).";
}
if (max_generated_tokens > max_tokens) {
return "Maximum number of generated tokens is larger than the maximum "
"total tokens.";
if (max_generated_tokens > gcpp::kSeqLen) {
return "max_generated_tokens is larger than the maximum sequence length "
"(see configs.h).";
}
return nullptr;
}
template <class Visitor>
void ForEach(const Visitor& visitor) {
visitor(max_tokens, "max_tokens", size_t{3072},
"Maximum number of tokens in prompt + generation.");
visitor(max_generated_tokens, "max_generated_tokens", size_t{2048},
"Maximum number of tokens to generate.");
@ -255,12 +248,9 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
}
void CopyTo(RuntimeConfig& runtime_config) const {
runtime_config.max_tokens = max_tokens;
runtime_config.max_generated_tokens = max_generated_tokens;
runtime_config.prefill_tbatch_size = prefill_tbatch_size;
runtime_config.decode_qbatch_size = decode_qbatch_size;
runtime_config.temperature = temperature;
}
};