Introduce QueryResult in GemmaEnv and add a shortcut for WrapAndTokenize.

Remove max_tokens (and rely only on max_generated_tokens).

PiperOrigin-RevId: 685662260
Daniel Keysers 2024-10-14 04:44:42 -07:00 committed by Copybara-Service
parent 2892e232e2
commit a4d6adbc43
17 changed files with 99 additions and 144 deletions
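For orientation (not part of the commit itself), a minimal sketch of how a caller uses the new API surface: GemmaEnv::QueryModel now returns a QueryResult instead of std::pair<std::string, size_t>, and the GemmaEnv-level WrapAndTokenize shortcut replaces the longer Tokenizer()/Info() call. The loader flags and the 256-token limit below are illustrative assumptions, not values from this commit.

// Sketch only; assumes a gemma.cpp build where evals/benchmark_helper.h
// provides GemmaEnv and QueryResult as introduced in this commit.
#include <cstdio>
#include <string>

#include "evals/benchmark_helper.h"

int main(int argc, char** argv) {
  // Expects the usual loader flags, e.g. --tokenizer/--weights/--model.
  gcpp::GemmaEnv env(argc, argv);
  env.SetMaxGeneratedTokens(256);  // the only remaining token limit

  std::string prompt = "Tell me a short story about a robot.";
  // Adds turn structure, tokenizes via the new WrapAndTokenize shortcut,
  // and runs generation.
  gcpp::QueryResult result = env.QueryModel(prompt);

  // The response also contains the echoed prompt tokens;
  // response_start_pos marks where generation begins.
  std::printf("%s\n",
              result.response.substr(result.response_start_pos).c_str());
  std::printf("tokens_generated: %zu\n", result.tokens_generated);
  return 0;
}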


@@ -385,7 +385,6 @@ tokenizer : tokenizer.spm
 compressed_weights : 2b-it-sfp.sbs
 model : 2b-it
 weights : [no path specified]
-max_tokens : 3072
 max_generated_tokens : 2048

 *Usage*


@@ -69,7 +69,6 @@ TEST(OptimizeTest, GradientDescent) {
     return token != ReverseSequenceSampler::kEndToken;
   };
   RuntimeConfig runtime = {
-      .max_tokens = 32,
       .max_generated_tokens = 16,
       .temperature = 1.0f,
       .verbosity = 0,


@@ -81,15 +81,15 @@ int BenchmarkGoldens(GemmaEnv& env, const std::string& golden_path) {
   size_t total_tokens = 0;
   const double time_start = hwy::platform::Now();
   for (auto& [question, expected_answer] : queries_answers) {
-    const auto [answer, token_count] = env.QueryModel(question);
-    total_tokens += token_count;
-    if (answer.find(expected_answer) != std::string::npos) {
+    QueryResult result = env.QueryModel(question);
+    total_tokens += result.tokens_generated;
+    if (result.response.find(expected_answer) != std::string::npos) {
       correct_answers++;
     } else {
       std::cout << "Wrong!\n";
       std::cout << "Input: " << question << "\n";
       std::cout << "Expected: " << expected_answer << "\n";
-      std::cout << "Output: " << answer << "\n\n" << std::flush;
+      std::cout << "Output: " << result.response << "\n\n" << std::flush;
     }
   }
   LogSpeedStats(time_start, total_tokens);
@@ -108,9 +108,10 @@ int BenchmarkSummary(GemmaEnv& env, const Path& text) {
   prompt.append(ReadFileToString(text));
   prompt.append("\nSummarize this text.\n");
   const double time_start = hwy::platform::Now();
-  const auto [answer, token_count] = env.QueryModel(prompt);
-  std::cout << answer.substr(prompt.size()) << "\n" << std::flush;
-  LogSpeedStats(time_start, token_count);
+  QueryResult result = env.QueryModel(prompt);
+  std::cout << result.response.substr(result.response_start_pos) << "\n"
+            << std::flush;
+  LogSpeedStats(time_start, result.tokens_generated);
   return EXIT_SUCCESS;
 }
@@ -118,7 +119,6 @@ int BenchmarkCrossEntropy(GemmaEnv& env, const Path& text,
                           size_t batch_tokens) {
   std::string input = ReadFileToString(text);
   std::vector<int> prompt = env.Tokenize(input);
-  prompt.resize(std::min<size_t>(env.MaxTokens(), prompt.size()));
   std::cout << "Number of input tokens: " << prompt.size() << "\n";
   const double time_start = hwy::platform::Now();
   float total_entropy = 0.0f;
@@ -156,11 +156,11 @@ int BenchmarkTriviaQA(GemmaEnv& env, const Path& json_file,
   while (std::getline(trivia_file, line)) {
     json data = json::parse(line);
     std::string q(data["question"]);
-    const auto [answer, token_count] = env.QueryModel(q);
-    std::cout << answer << "\n";
+    QueryResult result = env.QueryModel(q);
+    std::cout << result.response << "\n";
     bool correct = false;
     for (const std::string expected : data["answer"]["aliases"]) {
-      if (answer.find(expected) != std::string::npos) {
+      if (result.response.find(expected) != std::string::npos) {
         correct = true;
         break;
       }


@@ -18,14 +18,12 @@
 #include <stdio.h>
 #include <time.h>

-#include <algorithm>
 #include <cstdio>
 #include <iostream>
 #include <memory>
 #include <ostream>
 #include <random>
 #include <string>
-#include <utility>  // std::pair
 #include <vector>

 // Placeholder for internal header, do not modify.
@@ -76,7 +74,6 @@ GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
   }
   InitGenerator(inference, gen_);
   runtime_config_ = {
-      .max_tokens = inference.max_tokens,
      .max_generated_tokens = inference.max_generated_tokens,
      .temperature = inference.temperature,
      .verbosity = app.verbosity,
@@ -99,21 +96,21 @@ GemmaEnv::GemmaEnv(int argc, char** argv)
     : GemmaEnv(LoaderArgs(argc, argv), InferenceArgs(argc, argv),
                MakeAppArgs(argc, argv)) {}

-std::pair<std::string, size_t> GemmaEnv::QueryModel(
-    const std::vector<int>& tokens) {
-  std::string res;
-  size_t total_tokens = 0;
-  const BatchStreamFunc batch_stream_token = [&res, &total_tokens, this](
-                                                 size_t query_index, size_t pos,
-                                                 int token, float) {
-    ++total_tokens;
-    res += StringFromTokens(std::vector<int>{token});
+QueryResult GemmaEnv::QueryModel(const std::vector<int>& tokens) {
+  QueryResult result;
+  const BatchStreamFunc batch_stream_token =
+      [&result, &tokens, this](size_t /*query_index*/, size_t /*pos*/,
+                               int token, float /*score*/) {
+    ++result.tokens_generated;
+    result.response += StringFromTokens(std::vector<int>{token});
+    if (result.tokens_generated == tokens.size()) {
+      result.response_start_pos = result.response.size();
+    }
     return true;
   };
   if (runtime_config_.verbosity >= 2) {
-    std::cout << "Max tokens: " << runtime_config_.max_tokens
-              << "\tmax generated tokens: "
+    std::cout << "max generated tokens: "
               << runtime_config_.max_generated_tokens
               << "\ttemperature: " << runtime_config_.temperature << "\n";
   }
@@ -121,7 +118,7 @@ std::pair<std::string, size_t> GemmaEnv::QueryModel(
   runtime_config_.batch_stream_token = batch_stream_token;
   model_->Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0],
                    timing_info);
-  return {res, total_tokens};
+  return result;
 }

 void GemmaEnv::QueryModel(
@@ -134,27 +131,29 @@ void GemmaEnv::QueryModel(
   runtime_config_.stream_token = previous_stream_token;
 }

-std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel(
+std::vector<QueryResult> GemmaEnv::BatchQueryModel(
     const QueriesPromptTokens& queries_prompt) {
   const size_t num_queries = queries_prompt.size();
   HWY_ASSERT(num_queries != 0);
-  std::vector<std::pair<std::string, size_t>> res(num_queries);
-  std::fill(res.begin(), res.end(), std::make_pair("", 0));
-  const BatchStreamFunc batch_stream_token = [&res, this](size_t query_index,
-                                                          size_t pos, int token,
-                                                          float) {
+  std::vector<QueryResult> res(num_queries);
+  const BatchStreamFunc batch_stream_token = [&res, &queries_prompt, this](
+                                                 size_t query_index, size_t pos,
+                                                 int token, float) {
     std::string token_text;
     HWY_ASSERT(
         model_->Tokenizer().Decode(std::vector<int>{token}, &token_text));
-    res[query_index].first.append(token_text);
-    res[query_index].second += 1;
+    res[query_index].response.append(token_text);
+    res[query_index].tokens_generated += 1;
+    if (res[query_index].tokens_generated ==
+        queries_prompt[query_index].size()) {
+      res[query_index].response_start_pos = res[query_index].response.size();
+    }
     return true;
   };
   if (runtime_config_.verbosity >= 2) {
-    fprintf(stderr,
-            "Max tok: %zu max gen: %zu temp: %f tbatch: %zu qbatch: %zu\n",
-            runtime_config_.max_tokens, runtime_config_.max_generated_tokens,
-            runtime_config_.temperature, runtime_config_.prefill_tbatch_size,
+    fprintf(stderr, "Max gen: %zu temp: %f tbatch: %zu qbatch: %zu\n",
+            runtime_config_.max_generated_tokens, runtime_config_.temperature,
+            runtime_config_.prefill_tbatch_size,
             runtime_config_.decode_qbatch_size);
   }
@@ -178,21 +177,18 @@ std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel(
   return res;
 }

-std::pair<std::string, size_t> GemmaEnv::QueryModel(std::string& input) {
-  const std::vector<int> prompt =
-      WrapAndTokenize(model_->Tokenizer(), model_->Info(),
-                      /*pos=*/0, input);
+QueryResult GemmaEnv::QueryModel(std::string& input) {
+  const std::vector<int> prompt = WrapAndTokenize(input);
   return QueryModel(prompt);
 }

-std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel(
+std::vector<QueryResult> GemmaEnv::BatchQueryModel(
     const std::vector<std::string>& inputs) {
   std::vector<std::vector<int>> prompts;
   prompts.reserve(inputs.size());
   for (auto& input : inputs) {
     std::string mutable_prompt = input;
-    prompts.push_back(WrapAndTokenize(model_->Tokenizer(), model_->Info(),
-                                      /*pos=*/0, mutable_prompt));
+    prompts.push_back(WrapAndTokenize(mutable_prompt));
   }
   std::vector<PromptTokens> prompt_vector;
   prompt_vector.reserve(prompts.size());
@@ -206,7 +202,7 @@ std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel(
 float GemmaEnv::CrossEntropy(const std::string& input) {
   std::vector<int> prompt = Tokenize(input);
   prompt.insert(prompt.begin(), BOS_ID);
-  return ComputeCrossEntropy(*GetModel(), /*max_tokens=*/3072, prompt,
+  return ComputeCrossEntropy(*GetModel(), /*max_generated_tokens=*/3072, prompt,
                              MutableKVCache(),
                              /*verbosity=*/0) /
          static_cast<int>(input.size());


@@ -21,7 +21,6 @@
 #include <memory>
 #include <random>
 #include <string>
-#include <utility>
 #include <vector>

 #include "gemma/gemma.h"
@@ -33,6 +32,14 @@ namespace gcpp {

 void InitGenerator(const InferenceArgs& inference, std::mt19937& gen);

+// Return type for query model calls.
+struct QueryResult {
+  std::string response;
+  size_t tokens_generated = 0;
+  // The position in the response at which the generated tokens start.
+  size_t response_start_pos = 0;
+};
+
 // Convenience class to load a model and run inference.
 class GemmaEnv {
  public:
@@ -41,8 +48,9 @@ class GemmaEnv {
   GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
            const AppArgs& app);

-  size_t MaxTokens() const { return runtime_config_.max_tokens; }
-  // Sets the maximum number of output tokens to generate.
+  size_t MaxGeneratedTokens() const {
+    return runtime_config_.max_generated_tokens;
+  }
   void SetMaxGeneratedTokens(size_t max_generated_tokens) {
     runtime_config_.max_generated_tokens = max_generated_tokens;
   }
@@ -59,6 +67,10 @@ class GemmaEnv {
     return tokens;
   }

+  std::vector<int> WrapAndTokenize(std::string& input) const {
+    return gcpp::WrapAndTokenize(model_->Tokenizer(), model_->Info(), 0, input);
+  }
+
   std::string StringFromTokens(const std::vector<int>& tokens) const {
     std::string string;
     HWY_ASSERT(model_->Tokenizer().Decode(tokens, &string));
@@ -67,12 +79,12 @@
   // Runs inference on the given input and returns the top-1 result string and
   // the number of tokens that were generated.
-  std::pair<std::string, size_t> QueryModel(const std::vector<int>& tokens);
-  std::vector<std::pair<std::string, size_t>> BatchQueryModel(
+  QueryResult QueryModel(const std::vector<int>& tokens);
+  std::vector<QueryResult> BatchQueryModel(
       const QueriesPromptTokens& queries_prompt);
   // Adds turn structure to input, tokenizes and calls the above overload.
-  std::pair<std::string, size_t> QueryModel(std::string& input);
-  std::vector<std::pair<std::string, size_t>> BatchQueryModel(
+  QueryResult QueryModel(std::string& input);
+  std::vector<QueryResult> BatchQueryModel(
       const std::vector<std::string>& inputs);

   // Runs inference on the given input and calls the callback for each token.


@@ -33,11 +33,11 @@ void RunPrompt(const std::string& original_prompt, benchmark::State& state) {
   size_t total_tokens = 0;
   for (auto s : state) {
     std::string prompt = original_prompt;  // reset from original
-    auto [response, n] = s_env->QueryModel(prompt);
+    QueryResult result = s_env->QueryModel(prompt);
     if (s_env->Verbosity() != 0) {
-      fprintf(stdout, "|%s|\n", response.c_str());
+      fprintf(stdout, "|%s|\n", result.response.c_str());
     }
-    total_tokens += n;
+    total_tokens += result.tokens_generated;
   }
   state.SetItemsProcessed(total_tokens);


@@ -97,7 +97,7 @@ namespace gcpp {

 HWY_EXPORT(CallSoftmax);

-float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
+float ComputeCrossEntropy(Gemma& gemma, size_t max_generated_tokens,
                           const std::vector<int>& prompt, KVCache& kv_cache,
                           int verbosity) {
   const StreamFunc stream_token = [](int /*token*/, float) { return true; };
@@ -112,8 +112,8 @@ float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
                             size_t vocab_size) -> TokenAndProb {
     // input is logits, not yet probabilities
     HWY_DYNAMIC_DISPATCH(CallSoftmax)(probs, vocab_size);
-    // We are called for each token, but pos starts at 1. Clamping max_tokens
-    // to prompt.size() should prevent overrun.
+    // We are called for each token, but pos starts at 1. Clamping
+    // max_generated_tokens to prompt.size() should prevent overrun.
     HWY_ASSERT(pos < prompt.size());
     const int token = prompt[pos];
     const float prob = probs[token];
@@ -136,10 +136,9 @@ float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
   };

   std::vector<int> prompt0 = { prompt[0] };
-  max_tokens = HWY_MIN(max_tokens, prompt.size());
+  max_generated_tokens = HWY_MIN(max_generated_tokens, prompt.size());
   RuntimeConfig runtime = {
-      .max_tokens = max_tokens,
-      .max_generated_tokens = max_tokens - 1,
+      .max_generated_tokens = max_generated_tokens - 1,
       .temperature = 0.0f,
       .verbosity = verbosity,
       .gen = nullptr,


@@ -24,7 +24,7 @@
 namespace gcpp {

-float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
+float ComputeCrossEntropy(Gemma& gemma, size_t max_generated_tokens,
                           const std::vector<int>& prompt, KVCache& kv_cache,
                           int verbosity);


@@ -68,8 +68,9 @@ int Run(int argc, char** argv) {
     json_base[std::to_string(pos)][debug_key] = v;
   };

-  const auto [answer, token_count] = env.QueryModel(prompt_args.prompt);
-  std::cout << answer.substr(prompt_args.prompt.size()) << "\n" << std::flush;
+  QueryResult result = env.QueryModel(prompt_args.prompt);
+  std::cout << result.response.substr(result.response_start_pos) << "\n"
+            << std::flush;

   if (env.MutableConfig().layers_output) {
     std::ofstream output_f(prompt_args.layers_output.path, std::ofstream::out);


@@ -52,13 +52,13 @@ class GemmaTest : public ::testing::Test {
     if (s_env->GetModel()->Info().model == Model::GEMMA2_27B ||
         s_env->GetModel()->Info().model == Model::GRIFFIN_2B) {
       std::string mutable_prompt = prompt;
-      auto [response, n] = s_env->QueryModel(mutable_prompt);  // Uses turns.
-      return response;
+      QueryResult result = s_env->QueryModel(mutable_prompt);  // Uses turns.
+      return result.response;
     }
     // Otherwise, do not use turn structure.
     const std::vector<int> tokens = s_env->TokenizeAndPrependBOS(prompt);
-    auto [response, n] = s_env->QueryModel(tokens);
-    return response;
+    QueryResult result = s_env->QueryModel(tokens);
+    return result.response;
   }

   std::vector<std::string> BatchGemmaReply(
@@ -72,8 +72,8 @@ class GemmaTest : public ::testing::Test {
     // It would be good to make these tests more consistent.
     if (s_env->GetModel()->Info().model == Model::GEMMA2_27B ||
         s_env->GetModel()->Info().model == Model::GRIFFIN_2B) {
-      for (auto [response, n] : s_env->BatchQueryModel(inputs)) {
-        replies.push_back(response);
+      for (QueryResult result : s_env->BatchQueryModel(inputs)) {
+        replies.push_back(result.response);
       }
       return replies;
     }
@@ -88,8 +88,8 @@ class GemmaTest : public ::testing::Test {
       prompt_spans.push_back(PromptTokens(prompt.data(), prompt.size()));
     }
     QueriesPromptTokens prompts(prompt_spans.data(), prompt_spans.size());
-    for (auto [response, n] : s_env->BatchQueryModel(prompts)) {
-      replies.push_back(response);
+    for (const QueryResult& result : s_env->BatchQueryModel(prompts)) {
+      replies.push_back(result.response);
     }
     return replies;
   }
@@ -167,7 +167,6 @@ TEST_F(GemmaTest, Multiturn) {
     return true;
   };
   RuntimeConfig runtime_config{
-      .max_tokens = 128,
      .max_generated_tokens = 64,
      .temperature = 0.0f,
      .verbosity = 2,


@@ -103,9 +103,7 @@ void Run(GemmaEnv& env, JsonArgs& json) {
         "What is start of the line with the correct answer? "
         "Do not include any justifications or explanations. Reply only with a "
         "letter.";
-    const std::vector<int> prompt =
-        WrapAndTokenize(env.GetModel()->Tokenizer(), env.GetModel()->Info(),
-                        /*pos=*/0, prompt_string);
+    const std::vector<int> prompt = env.WrapAndTokenize(prompt_string);
     const size_t prompt_size = prompt.size();

     std::vector<int> predicted_token_ids;
@@ -127,7 +125,6 @@ void Run(GemmaEnv& env, JsonArgs& json) {
     // confused with the word "A".
     gcpp::TimingInfo timing_info;
     gcpp::RuntimeConfig runtime_config = {
-        .max_tokens = env.MaxTokens(),
        .max_generated_tokens = 30,
        .temperature = 0.0f,
        .verbosity = env.Verbosity(),


@@ -87,7 +87,6 @@ int main(int argc, char** argv) {
   gcpp::TimingInfo timing_info;
   gcpp::RuntimeConfig runtime_config = {
-      .max_tokens = 1536,
      .max_generated_tokens = 1024,
      .temperature = 1.0,
      .verbosity = 0,


@@ -1117,34 +1117,15 @@ HWY_NOINLINE void Transformer(
 }

 template <class TConfig>
-void RangeChecks(size_t& max_tokens, size_t& max_generated_tokens,
-                 size_t& prompt_size) {
-  if (!TConfig::kUseLocalAttention) {
-    if (max_tokens > TConfig::kSeqLen) {
-      fprintf(stderr, "WARNING: max_tokens %zu > kSeqLen %d, truncating.\n",
-              max_tokens, TConfig::kSeqLen);
-      max_tokens = static_cast<size_t>(TConfig::kSeqLen);
-    }
-  }
-  if (max_generated_tokens > max_tokens) {
-    fprintf(stderr,
-            "WARNING: max_generated_tokens %zu > max_tokens %zu, truncating.\n",
-            max_generated_tokens, max_tokens);
-    max_generated_tokens = max_tokens - 1;
-  }
-  if (!TConfig::kUseLocalAttention) {
-    if (prompt_size + max_generated_tokens > max_tokens) {
-      fprintf(stderr,
-              "WARNING: prompt_size %zu + max_generated_tokens %zu > "
-              "max_tokens %zu, truncating to ",
-              prompt_size, max_generated_tokens, max_tokens);
-      prompt_size = std::min(prompt_size, max_tokens - max_generated_tokens);
-      fprintf(stderr, "%zu\n", prompt_size);
+void RangeChecks(size_t& max_generated_tokens, const size_t prompt_size) {
+  if (!TConfig::kUseLocalAttention) {
+    if (max_generated_tokens > TConfig::kSeqLen) {
+      fprintf(stderr,
+              "WARNING: max_generated_tokens %zu > kSeqLen %d, truncating.\n",
+              max_generated_tokens, TConfig::kSeqLen);
+      max_generated_tokens = static_cast<size_t>(TConfig::kSeqLen);
     }
   }
   HWY_ASSERT(prompt_size > 0);
 }
@@ -1262,17 +1243,8 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
   const hwy::Divisor div_seq_len(static_cast<uint32_t>(kv_caches[0].seq_len));

   size_t max_prompt_size = MaxQueryLength(queries_prompt);
-  size_t max_tokens = runtime_config.max_tokens;
   size_t max_generated_tokens = runtime_config.max_generated_tokens;
-  RangeChecks<TConfig>(max_tokens, max_generated_tokens, max_prompt_size);
-  for (size_t pos : queries_pos_copy) {
-    if (pos >= max_tokens) {
-      fprintf(stderr, "Warning: pos %zu >= max_tokens %zu, aborting.\n", pos,
-              max_tokens);
-      return;
-    }
-  }
+  RangeChecks<TConfig>(max_generated_tokens, max_prompt_size);

   const SampleFunc sample_token = ChooseSampleFunc<TConfig>(runtime_config);

   // Prefill stops before min_prompt_size - 1 because the last prompt
@@ -1315,7 +1287,7 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
   }

   const double gen_start = hwy::platform::Now();
-  for (size_t gen = 0; gen < HWY_MIN(max_tokens, max_generated_tokens); ++gen) {
+  for (size_t gen = 0; gen < max_generated_tokens; ++gen) {
     // Decode generates one token per query and increments queries_mutable_pos.
     Transformer<TConfig>(
         QueriesToken(gen_tokens.data(), num_queries), queries_mutable_pos,


@@ -92,8 +92,7 @@ struct RuntimeConfig {
     return stream_token(token, prob);
   }

-  // Limits on the number of tokens generated.
-  size_t max_tokens;
+  // Limit on the number of tokens generated.
   size_t max_generated_tokens;

   // These defaults are overridden by InferenceArgs::CopyTo(*this):


@@ -131,7 +131,7 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const InferenceArgs& args,
     return true;
   };

-  while (abs_pos < args.max_tokens) {
+  while (true) {  // Loop until user quits.
     tokens_generated_this_turn = 0;
     std::string prompt_string = GetPrompt(std::cin, verbosity, eot_line);
     if (!std::cin) return;
@@ -183,10 +183,6 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const InferenceArgs& args,
                     timing_info);
     std::cout << "\n\n";
   }
-  std::cout
-      << "max_tokens (" << args.max_tokens
-      << ") exceeded. Use a larger value if desired using the --max_tokens "
-      << "command line flag.\n";
 }

 void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) { void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {


@@ -13,8 +13,6 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-#include "gemma/gemma.h"
-
 #include <cstdio>
 #include <memory>
 #include <string>
@@ -22,6 +20,7 @@
 #include "evals/benchmark_helper.h"
 #include "gemma/common.h"
+#include "gemma/gemma.h"
 #include "hwy/base.h"
 #include "hwy/tests/hwy_gtest.h"
@@ -63,15 +62,13 @@ void PaliGemmaTest::InitVit(const std::string& path) {
 std::string PaliGemmaTest::GemmaReply(const std::string& prompt_text) const {
   Gemma& model = *(s_env->GetModel());
   s_env->MutableGen().seed(0x12345678);
-  RuntimeConfig runtime_config = {.max_tokens = 1024,
-                                  .max_generated_tokens = 512,
+  RuntimeConfig runtime_config = {.max_generated_tokens = 512,
                                   .verbosity = 0,
                                   .gen = &s_env->MutableGen()};
   runtime_config.image_tokens = image_tokens_.get();
   size_t abs_pos = 0;
   std::string mutable_prompt = prompt_text;
-  std::vector<int> tokens =
-      WrapAndTokenize(model.Tokenizer(), model.Info(), abs_pos, mutable_prompt);
+  std::vector<int> tokens = s_env->WrapAndTokenize(mutable_prompt);
   std::string response;
   auto stream_token = [&](int token, float) {
     std::string token_text;


@@ -207,7 +207,6 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
   InferenceArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
   InferenceArgs() { Init(); };

-  size_t max_tokens;
   size_t max_generated_tokens;

   size_t prefill_tbatch_size;
@@ -220,21 +219,15 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
   // Returns error string or nullptr if OK.
   const char* Validate() const {
-    if (max_tokens > gcpp::kSeqLen) {
-      return "max_tokens is larger than the maximum sequence length (see "
-             "configs.h).";
-    }
-    if (max_generated_tokens > max_tokens) {
-      return "Maximum number of generated tokens is larger than the maximum "
-             "total tokens.";
+    if (max_generated_tokens > gcpp::kSeqLen) {
+      return "max_generated_tokens is larger than the maximum sequence length "
+             "(see configs.h).";
     }
     return nullptr;
   }

   template <class Visitor>
   void ForEach(const Visitor& visitor) {
-    visitor(max_tokens, "max_tokens", size_t{3072},
-            "Maximum number of tokens in prompt + generation.");
     visitor(max_generated_tokens, "max_generated_tokens", size_t{2048},
             "Maximum number of tokens to generate.");
@@ -255,12 +248,9 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
   }

   void CopyTo(RuntimeConfig& runtime_config) const {
-    runtime_config.max_tokens = max_tokens;
     runtime_config.max_generated_tokens = max_generated_tokens;
     runtime_config.prefill_tbatch_size = prefill_tbatch_size;
     runtime_config.decode_qbatch_size = decode_qbatch_size;
     runtime_config.temperature = temperature;
   }
 };