mirror of https://github.com/google/gemma.cpp.git
Introduce QueryResult in GemmaEnv and add a shortcut for WrapAndTokenize.
Remove max_tokens (and rely only on max_generated_tokens).

PiperOrigin-RevId: 685662260
This commit is contained in: parent 2892e232e2, commit a4d6adbc43
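For context, the sketch below is not part of the commit: it only illustrates how a caller migrates from the old `std::pair<std::string, size_t>` return value to the new `QueryResult`. The `main` wrapper and the prompt text are hypothetical; `GemmaEnv`, `QueryModel`, and the `QueryResult` fields are the ones changed in the diff that follows.

```cpp
// Migration sketch only; assumes a model is selected via the usual
// GemmaEnv(argc, argv) command-line flags from evals/benchmark_helper.h.
#include <iostream>
#include <string>

#include "evals/benchmark_helper.h"

int main(int argc, char** argv) {
  gcpp::GemmaEnv env(argc, argv);

  std::string prompt = "What is the capital of France?";  // hypothetical prompt
  // Old: const auto [answer, token_count] = env.QueryModel(prompt);
  gcpp::QueryResult result = env.QueryModel(prompt);

  // response holds the wrapped prompt followed by the generation;
  // response_start_pos marks where the generated text begins.
  std::cout << result.response.substr(result.response_start_pos) << "\n";
  std::cout << "tokens generated: " << result.tokens_generated << "\n";
  return 0;
}
```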
@@ -385,7 +385,6 @@ tokenizer : tokenizer.spm
 compressed_weights : 2b-it-sfp.sbs
 model : 2b-it
 weights : [no path specified]
-max_tokens : 3072
 max_generated_tokens : 2048

 *Usage*
@@ -69,7 +69,6 @@ TEST(OptimizeTest, GradientDescent) {
     return token != ReverseSequenceSampler::kEndToken;
   };
   RuntimeConfig runtime = {
-      .max_tokens = 32,
       .max_generated_tokens = 16,
       .temperature = 1.0f,
       .verbosity = 0,
@@ -81,15 +81,15 @@ int BenchmarkGoldens(GemmaEnv& env, const std::string& golden_path) {
   size_t total_tokens = 0;
   const double time_start = hwy::platform::Now();
   for (auto& [question, expected_answer] : queries_answers) {
-    const auto [answer, token_count] = env.QueryModel(question);
-    total_tokens += token_count;
-    if (answer.find(expected_answer) != std::string::npos) {
+    QueryResult result = env.QueryModel(question);
+    total_tokens += result.tokens_generated;
+    if (result.response.find(expected_answer) != std::string::npos) {
       correct_answers++;
     } else {
       std::cout << "Wrong!\n";
       std::cout << "Input: " << question << "\n";
       std::cout << "Expected: " << expected_answer << "\n";
-      std::cout << "Output: " << answer << "\n\n" << std::flush;
+      std::cout << "Output: " << result.response << "\n\n" << std::flush;
     }
   }
   LogSpeedStats(time_start, total_tokens);
@@ -108,9 +108,10 @@ int BenchmarkSummary(GemmaEnv& env, const Path& text) {
   prompt.append(ReadFileToString(text));
   prompt.append("\nSummarize this text.\n");
   const double time_start = hwy::platform::Now();
-  const auto [answer, token_count] = env.QueryModel(prompt);
-  std::cout << answer.substr(prompt.size()) << "\n" << std::flush;
-  LogSpeedStats(time_start, token_count);
+  QueryResult result = env.QueryModel(prompt);
+  std::cout << result.response.substr(result.response_start_pos) << "\n"
+            << std::flush;
+  LogSpeedStats(time_start, result.tokens_generated);
   return EXIT_SUCCESS;
 }
@@ -118,7 +119,6 @@ int BenchmarkCrossEntropy(GemmaEnv& env, const Path& text,
                           size_t batch_tokens) {
   std::string input = ReadFileToString(text);
   std::vector<int> prompt = env.Tokenize(input);
-  prompt.resize(std::min<size_t>(env.MaxTokens(), prompt.size()));
   std::cout << "Number of input tokens: " << prompt.size() << "\n";
   const double time_start = hwy::platform::Now();
   float total_entropy = 0.0f;
@@ -156,11 +156,11 @@ int BenchmarkTriviaQA(GemmaEnv& env, const Path& json_file,
   while (std::getline(trivia_file, line)) {
     json data = json::parse(line);
     std::string q(data["question"]);
-    const auto [answer, token_count] = env.QueryModel(q);
-    std::cout << answer << "\n";
+    QueryResult result = env.QueryModel(q);
+    std::cout << result.response << "\n";
     bool correct = false;
     for (const std::string expected : data["answer"]["aliases"]) {
-      if (answer.find(expected) != std::string::npos) {
+      if (result.response.find(expected) != std::string::npos) {
        correct = true;
        break;
      }
@@ -18,14 +18,12 @@
 #include <stdio.h>
 #include <time.h>

-#include <algorithm>
 #include <cstdio>
 #include <iostream>
 #include <memory>
 #include <ostream>
 #include <random>
 #include <string>
-#include <utility>  // std::pair
 #include <vector>

 // Placeholder for internal header, do not modify.
@@ -76,7 +74,6 @@ GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
   }
   InitGenerator(inference, gen_);
   runtime_config_ = {
-      .max_tokens = inference.max_tokens,
       .max_generated_tokens = inference.max_generated_tokens,
       .temperature = inference.temperature,
       .verbosity = app.verbosity,
@@ -99,21 +96,21 @@ GemmaEnv::GemmaEnv(int argc, char** argv)
     : GemmaEnv(LoaderArgs(argc, argv), InferenceArgs(argc, argv),
                MakeAppArgs(argc, argv)) {}

-std::pair<std::string, size_t> GemmaEnv::QueryModel(
-    const std::vector<int>& tokens) {
-  std::string res;
-  size_t total_tokens = 0;
+QueryResult GemmaEnv::QueryModel(const std::vector<int>& tokens) {
+  QueryResult result;

-  const BatchStreamFunc batch_stream_token = [&res, &total_tokens, this](
-                                                 size_t query_index, size_t pos,
-                                                 int token, float) {
-    ++total_tokens;
-    res += StringFromTokens(std::vector<int>{token});
+  const BatchStreamFunc batch_stream_token =
+      [&result, &tokens, this](size_t /*query_index*/, size_t /*pos*/,
+                               int token, float /*score*/) {
+        ++result.tokens_generated;
+        result.response += StringFromTokens(std::vector<int>{token});
+        if (result.tokens_generated == tokens.size()) {
+          result.response_start_pos = result.response.size();
+        }
         return true;
       };
   if (runtime_config_.verbosity >= 2) {
-    std::cout << "Max tokens: " << runtime_config_.max_tokens
-              << "\tmax generated tokens: "
+    std::cout << "max generated tokens: "
               << runtime_config_.max_generated_tokens
               << "\ttemperature: " << runtime_config_.temperature << "\n";
   }
@@ -121,7 +118,7 @@ std::pair<std::string, size_t> GemmaEnv::QueryModel(
   runtime_config_.batch_stream_token = batch_stream_token;
   model_->Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0],
                    timing_info);
-  return {res, total_tokens};
+  return result;
 }

 void GemmaEnv::QueryModel(
@@ -134,27 +131,29 @@ void GemmaEnv::QueryModel(
   runtime_config_.stream_token = previous_stream_token;
 }

-std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel(
+std::vector<QueryResult> GemmaEnv::BatchQueryModel(
     const QueriesPromptTokens& queries_prompt) {
   const size_t num_queries = queries_prompt.size();
   HWY_ASSERT(num_queries != 0);
-  std::vector<std::pair<std::string, size_t>> res(num_queries);
-  std::fill(res.begin(), res.end(), std::make_pair("", 0));
-  const BatchStreamFunc batch_stream_token = [&res, this](size_t query_index,
-                                                          size_t pos, int token,
-                                                          float) {
+  std::vector<QueryResult> res(num_queries);
+  const BatchStreamFunc batch_stream_token = [&res, &queries_prompt, this](
                                                  size_t query_index, size_t pos,
                                                  int token, float) {
     std::string token_text;
     HWY_ASSERT(
         model_->Tokenizer().Decode(std::vector<int>{token}, &token_text));
-    res[query_index].first.append(token_text);
-    res[query_index].second += 1;
+    res[query_index].response.append(token_text);
+    res[query_index].tokens_generated += 1;
+    if (res[query_index].tokens_generated ==
+        queries_prompt[query_index].size()) {
+      res[query_index].response_start_pos = res[query_index].response.size();
+    }
     return true;
   };
   if (runtime_config_.verbosity >= 2) {
-    fprintf(stderr,
-            "Max tok: %zu max gen: %zu temp: %f tbatch: %zu qbatch: %zu\n",
-            runtime_config_.max_tokens, runtime_config_.max_generated_tokens,
-            runtime_config_.temperature, runtime_config_.prefill_tbatch_size,
+    fprintf(stderr, "Max gen: %zu temp: %f tbatch: %zu qbatch: %zu\n",
+            runtime_config_.max_generated_tokens, runtime_config_.temperature,
+            runtime_config_.prefill_tbatch_size,
             runtime_config_.decode_qbatch_size);
   }
@@ -178,21 +177,18 @@ std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel(
   return res;
 }

-std::pair<std::string, size_t> GemmaEnv::QueryModel(std::string& input) {
-  const std::vector<int> prompt =
-      WrapAndTokenize(model_->Tokenizer(), model_->Info(),
-                      /*pos=*/0, input);
+QueryResult GemmaEnv::QueryModel(std::string& input) {
+  const std::vector<int> prompt = WrapAndTokenize(input);
   return QueryModel(prompt);
 }

-std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel(
+std::vector<QueryResult> GemmaEnv::BatchQueryModel(
     const std::vector<std::string>& inputs) {
   std::vector<std::vector<int>> prompts;
   prompts.reserve(inputs.size());
   for (auto& input : inputs) {
     std::string mutable_prompt = input;
-    prompts.push_back(WrapAndTokenize(model_->Tokenizer(), model_->Info(),
-                                      /*pos=*/0, mutable_prompt));
+    prompts.push_back(WrapAndTokenize(mutable_prompt));
   }
   std::vector<PromptTokens> prompt_vector;
   prompt_vector.reserve(prompts.size());
@@ -206,7 +202,7 @@ std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel(
 float GemmaEnv::CrossEntropy(const std::string& input) {
   std::vector<int> prompt = Tokenize(input);
   prompt.insert(prompt.begin(), BOS_ID);
-  return ComputeCrossEntropy(*GetModel(), /*max_tokens=*/3072, prompt,
+  return ComputeCrossEntropy(*GetModel(), /*max_generated_tokens=*/3072, prompt,
                              MutableKVCache(),
                              /*verbosity=*/0) /
          static_cast<int>(input.size());
@@ -21,7 +21,6 @@
 #include <memory>
 #include <random>
 #include <string>
-#include <utility>
 #include <vector>

 #include "gemma/gemma.h"
@@ -33,6 +32,14 @@ namespace gcpp {

 void InitGenerator(const InferenceArgs& inference, std::mt19937& gen);

+// Return type for query model calls.
+struct QueryResult {
+  std::string response;
+  size_t tokens_generated = 0;
+  // The position in the response at which the generated tokens start.
+  size_t response_start_pos = 0;
+};
+
 // Convenience class to load a model and run inference.
 class GemmaEnv {
  public:
@@ -41,8 +48,9 @@ class GemmaEnv {
   GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
            const AppArgs& app);

-  size_t MaxTokens() const { return runtime_config_.max_tokens; }
   // Sets the maximum number of output tokens to generate.
+  size_t MaxGeneratedTokens() const {
+    return runtime_config_.max_generated_tokens;
+  }
   void SetMaxGeneratedTokens(size_t max_generated_tokens) {
     runtime_config_.max_generated_tokens = max_generated_tokens;
   }
@@ -59,6 +67,10 @@ class GemmaEnv {
     return tokens;
   }

+  std::vector<int> WrapAndTokenize(std::string& input) const {
+    return gcpp::WrapAndTokenize(model_->Tokenizer(), model_->Info(), 0, input);
+  }
+
   std::string StringFromTokens(const std::vector<int>& tokens) const {
     std::string string;
     HWY_ASSERT(model_->Tokenizer().Decode(tokens, &string));
@@ -67,12 +79,12 @@ class GemmaEnv {

   // Runs inference on the given input and returns the top-1 result string and
   // the number of tokens that were generated.
-  std::pair<std::string, size_t> QueryModel(const std::vector<int>& tokens);
-  std::vector<std::pair<std::string, size_t>> BatchQueryModel(
+  QueryResult QueryModel(const std::vector<int>& tokens);
+  std::vector<QueryResult> BatchQueryModel(
       const QueriesPromptTokens& queries_prompt);
   // Adds turn structure to input, tokenizes and calls the above overload.
-  std::pair<std::string, size_t> QueryModel(std::string& input);
-  std::vector<std::pair<std::string, size_t>> BatchQueryModel(
+  QueryResult QueryModel(std::string& input);
+  std::vector<QueryResult> BatchQueryModel(
       const std::vector<std::string>& inputs);

   // Runs inference on the given input and calls the callback for each token.
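As an illustration only (not part of the commit), batched callers follow the same pattern with the new return type. The helper function and prompt strings below are hypothetical; `BatchQueryModel`, `QueryResult`, and `response_start_pos` are the declarations changed in the hunk above.

```cpp
// Batched-usage sketch, assuming an already constructed GemmaEnv.
#include <iostream>
#include <string>
#include <vector>

#include "evals/benchmark_helper.h"

void RunBatch(gcpp::GemmaEnv& env) {
  // Hypothetical prompts for illustration.
  const std::vector<std::string> inputs = {"Write a haiku.", "Name three colors."};
  const std::vector<gcpp::QueryResult> results = env.BatchQueryModel(inputs);
  for (const gcpp::QueryResult& r : results) {
    // Skip the echoed prompt tokens; the generated text starts at
    // response_start_pos, as in the updated BenchmarkSummary above.
    std::cout << r.response.substr(r.response_start_pos) << "\n"
              << "(" << r.tokens_generated << " tokens streamed)\n";
  }
}
```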
@@ -33,11 +33,11 @@ void RunPrompt(const std::string& original_prompt, benchmark::State& state) {
   size_t total_tokens = 0;
   for (auto s : state) {
     std::string prompt = original_prompt;  // reset from original
-    auto [response, n] = s_env->QueryModel(prompt);
+    QueryResult result = s_env->QueryModel(prompt);
     if (s_env->Verbosity() != 0) {
-      fprintf(stdout, "|%s|\n", response.c_str());
+      fprintf(stdout, "|%s|\n", result.response.c_str());
     }
-    total_tokens += n;
+    total_tokens += result.tokens_generated;
   }

   state.SetItemsProcessed(total_tokens);
@@ -97,7 +97,7 @@ namespace gcpp {

 HWY_EXPORT(CallSoftmax);

-float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
+float ComputeCrossEntropy(Gemma& gemma, size_t max_generated_tokens,
                           const std::vector<int>& prompt, KVCache& kv_cache,
                           int verbosity) {
   const StreamFunc stream_token = [](int /*token*/, float) { return true; };
@@ -112,8 +112,8 @@ float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
                            size_t vocab_size) -> TokenAndProb {
     // input is logits, not yet probabilities
     HWY_DYNAMIC_DISPATCH(CallSoftmax)(probs, vocab_size);
-    // We are called for each token, but pos starts at 1. Clamping max_tokens
-    // to prompt.size() should prevent overrun.
+    // We are called for each token, but pos starts at 1. Clamping
+    // max_generated_tokens to prompt.size() should prevent overrun.
     HWY_ASSERT(pos < prompt.size());
     const int token = prompt[pos];
     const float prob = probs[token];
@@ -136,10 +136,9 @@ float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
   };

   std::vector<int> prompt0 = { prompt[0] };
-  max_tokens = HWY_MIN(max_tokens, prompt.size());
+  max_generated_tokens = HWY_MIN(max_generated_tokens, prompt.size());
   RuntimeConfig runtime = {
-      .max_tokens = max_tokens,
-      .max_generated_tokens = max_tokens - 1,
+      .max_generated_tokens = max_generated_tokens - 1,
       .temperature = 0.0f,
       .verbosity = verbosity,
       .gen = nullptr,
@@ -24,7 +24,7 @@

 namespace gcpp {

-float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
+float ComputeCrossEntropy(Gemma& gemma, size_t max_generated_tokens,
                           const std::vector<int>& prompt, KVCache& kv_cache,
                           int verbosity);
@@ -68,8 +68,9 @@ int Run(int argc, char** argv) {
     json_base[std::to_string(pos)][debug_key] = v;
   };

-  const auto [answer, token_count] = env.QueryModel(prompt_args.prompt);
-  std::cout << answer.substr(prompt_args.prompt.size()) << "\n" << std::flush;
+  QueryResult result = env.QueryModel(prompt_args.prompt);
+  std::cout << result.response.substr(result.response_start_pos) << "\n"
+            << std::flush;

   if (env.MutableConfig().layers_output) {
     std::ofstream output_f(prompt_args.layers_output.path, std::ofstream::out);
@@ -52,13 +52,13 @@ class GemmaTest : public ::testing::Test {
     if (s_env->GetModel()->Info().model == Model::GEMMA2_27B ||
         s_env->GetModel()->Info().model == Model::GRIFFIN_2B) {
       std::string mutable_prompt = prompt;
-      auto [response, n] = s_env->QueryModel(mutable_prompt);  // Uses turns.
-      return response;
+      QueryResult result = s_env->QueryModel(mutable_prompt);  // Uses turns.
+      return result.response;
     }
     // Otherwise, do not use turn structure.
     const std::vector<int> tokens = s_env->TokenizeAndPrependBOS(prompt);
-    auto [response, n] = s_env->QueryModel(tokens);
-    return response;
+    QueryResult result = s_env->QueryModel(tokens);
+    return result.response;
   }

   std::vector<std::string> BatchGemmaReply(
@@ -72,8 +72,8 @@ class GemmaTest : public ::testing::Test {
     // It would be good to make these tests more consistent.
     if (s_env->GetModel()->Info().model == Model::GEMMA2_27B ||
         s_env->GetModel()->Info().model == Model::GRIFFIN_2B) {
-      for (auto [response, n] : s_env->BatchQueryModel(inputs)) {
-        replies.push_back(response);
+      for (QueryResult result : s_env->BatchQueryModel(inputs)) {
+        replies.push_back(result.response);
       }
       return replies;
     }
@@ -88,8 +88,8 @@ class GemmaTest : public ::testing::Test {
       prompt_spans.push_back(PromptTokens(prompt.data(), prompt.size()));
     }
     QueriesPromptTokens prompts(prompt_spans.data(), prompt_spans.size());
-    for (auto [response, n] : s_env->BatchQueryModel(prompts)) {
-      replies.push_back(response);
+    for (const QueryResult& result : s_env->BatchQueryModel(prompts)) {
+      replies.push_back(result.response);
     }
     return replies;
   }
@@ -167,7 +167,6 @@ TEST_F(GemmaTest, Multiturn) {
     return true;
   };
   RuntimeConfig runtime_config{
-      .max_tokens = 128,
       .max_generated_tokens = 64,
       .temperature = 0.0f,
       .verbosity = 2,
@@ -103,9 +103,7 @@ void Run(GemmaEnv& env, JsonArgs& json) {
       "What is start of the line with the correct answer? "
       "Do not include any justifications or explanations. Reply only with a "
       "letter.";
-  const std::vector<int> prompt =
-      WrapAndTokenize(env.GetModel()->Tokenizer(), env.GetModel()->Info(),
-                      /*pos=*/0, prompt_string);
+  const std::vector<int> prompt = env.WrapAndTokenize(prompt_string);
   const size_t prompt_size = prompt.size();

   std::vector<int> predicted_token_ids;
@@ -127,7 +125,6 @@ void Run(GemmaEnv& env, JsonArgs& json) {
     // confused with the word "A".
     gcpp::TimingInfo timing_info;
     gcpp::RuntimeConfig runtime_config = {
-        .max_tokens = env.MaxTokens(),
        .max_generated_tokens = 30,
        .temperature = 0.0f,
        .verbosity = env.Verbosity(),
@@ -87,7 +87,6 @@ int main(int argc, char** argv) {

   gcpp::TimingInfo timing_info;
   gcpp::RuntimeConfig runtime_config = {
-      .max_tokens = 1536,
       .max_generated_tokens = 1024,
       .temperature = 1.0,
       .verbosity = 0,
@@ -1117,34 +1117,15 @@ HWY_NOINLINE void Transformer(
 }

 template <class TConfig>
-void RangeChecks(size_t& max_tokens, size_t& max_generated_tokens,
-                 size_t& prompt_size) {
-  if (!TConfig::kUseLocalAttention) {
-    if (max_tokens > TConfig::kSeqLen) {
-      fprintf(stderr, "WARNING: max_tokens %zu > kSeqLen %d, truncating.\n",
-              max_tokens, TConfig::kSeqLen);
-      max_tokens = static_cast<size_t>(TConfig::kSeqLen);
-    }
-  }
-
-  if (max_generated_tokens > max_tokens) {
-    fprintf(stderr,
-            "WARNING: max_generated_tokens %zu > max_tokens %zu, truncating.\n",
-            max_generated_tokens, max_tokens);
-    max_generated_tokens = max_tokens - 1;
-  }
-
-  if (!TConfig::kUseLocalAttention) {
-    if (prompt_size + max_generated_tokens > max_tokens) {
-      fprintf(stderr,
-              "WARNING: prompt_size %zu + max_generated_tokens %zu > "
-              "max_tokens %zu, truncating to ",
-              prompt_size, max_generated_tokens, max_tokens);
-      prompt_size = std::min(prompt_size, max_tokens - max_generated_tokens);
-      fprintf(stderr, "%zu\n", prompt_size);
+void RangeChecks(size_t& max_generated_tokens, const size_t prompt_size) {
+  if (!TConfig::kUseLocalAttention) {
+    if (max_generated_tokens > TConfig::kSeqLen) {
+      fprintf(stderr,
+              "WARNING: max_generated_tokens %zu > kSeqLen %d, truncating.\n",
+              max_generated_tokens, TConfig::kSeqLen);
+      max_generated_tokens = static_cast<size_t>(TConfig::kSeqLen);
     }
   }

   HWY_ASSERT(prompt_size > 0);
 }
@@ -1262,17 +1243,8 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
   const hwy::Divisor div_seq_len(static_cast<uint32_t>(kv_caches[0].seq_len));

   size_t max_prompt_size = MaxQueryLength(queries_prompt);
-  size_t max_tokens = runtime_config.max_tokens;
   size_t max_generated_tokens = runtime_config.max_generated_tokens;
-  RangeChecks<TConfig>(max_tokens, max_generated_tokens, max_prompt_size);
-  for (size_t pos : queries_pos_copy) {
-    if (pos >= max_tokens) {
-      fprintf(stderr, "Warning: pos %zu >= max_tokens %zu, aborting.\n", pos,
-              max_tokens);
-      return;
-    }
-  }
-
+  RangeChecks<TConfig>(max_generated_tokens, max_prompt_size);
   const SampleFunc sample_token = ChooseSampleFunc<TConfig>(runtime_config);

   // Prefill stops before min_prompt_size - 1 because the last prompt
@@ -1315,7 +1287,7 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
   }

   const double gen_start = hwy::platform::Now();
-  for (size_t gen = 0; gen < HWY_MIN(max_tokens, max_generated_tokens); ++gen) {
+  for (size_t gen = 0; gen < max_generated_tokens; ++gen) {
     // Decode generates one token per query and increments queries_mutable_pos.
     Transformer<TConfig>(
         QueriesToken(gen_tokens.data(), num_queries), queries_mutable_pos,
@@ -92,8 +92,7 @@ struct RuntimeConfig {
     return stream_token(token, prob);
   }

-  // Limits on the number of tokens generated.
-  size_t max_tokens;
+  // Limit on the number of tokens generated.
   size_t max_generated_tokens;

   // These defaults are overridden by InferenceArgs::CopyTo(*this):
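For illustration only: with `max_tokens` removed, call sites build a `RuntimeConfig` with just `max_generated_tokens`, as in the updated tests earlier in this diff. The sketch below is not part of the commit; it assumes C++20 designated initializers and omits the stream callbacks, batch sizes, and KV-cache setup.

```cpp
// Sketch of constructing the slimmed-down RuntimeConfig.
#include <random>

#include "gemma/gemma.h"

gcpp::RuntimeConfig MakeConfig(std::mt19937& gen, int verbosity) {
  gcpp::RuntimeConfig runtime_config = {
      // Generation length is bounded only by this field; RangeChecks clamps it
      // to the model's sequence length.
      .max_generated_tokens = 1024,
      .temperature = 1.0f,
      .verbosity = verbosity,
      .gen = &gen,
  };
  return runtime_config;
}
```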
@@ -131,7 +131,7 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const InferenceArgs& args,
     return true;
   };

-  while (abs_pos < args.max_tokens) {
+  while (true) {  // Loop until user quits.
     tokens_generated_this_turn = 0;
     std::string prompt_string = GetPrompt(std::cin, verbosity, eot_line);
     if (!std::cin) return;
@@ -183,10 +183,6 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const InferenceArgs& args,
                     timing_info);
     std::cout << "\n\n";
   }
-  std::cout
-      << "max_tokens (" << args.max_tokens
-      << ") exceeded. Use a larger value if desired using the --max_tokens "
-      << "command line flag.\n";
 }

 void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
@@ -13,8 +13,6 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-#include "gemma/gemma.h"
-
 #include <cstdio>
 #include <memory>
 #include <string>
@@ -22,6 +20,7 @@

 #include "evals/benchmark_helper.h"
 #include "gemma/common.h"
+#include "gemma/gemma.h"
 #include "hwy/base.h"
 #include "hwy/tests/hwy_gtest.h"
@@ -63,15 +62,13 @@ void PaliGemmaTest::InitVit(const std::string& path) {
 std::string PaliGemmaTest::GemmaReply(const std::string& prompt_text) const{
   Gemma& model = *(s_env->GetModel());
   s_env->MutableGen().seed(0x12345678);
-  RuntimeConfig runtime_config = {.max_tokens = 1024,
-                                  .max_generated_tokens = 512,
+  RuntimeConfig runtime_config = {.max_generated_tokens = 512,
                                   .verbosity = 0,
                                   .gen = &s_env->MutableGen()};
   runtime_config.image_tokens = image_tokens_.get();
   size_t abs_pos = 0;
   std::string mutable_prompt = prompt_text;
-  std::vector<int> tokens =
-      WrapAndTokenize(model.Tokenizer(), model.Info(), abs_pos, mutable_prompt);
+  std::vector<int> tokens = s_env->WrapAndTokenize(mutable_prompt);
   std::string response;
   auto stream_token = [&](int token, float) {
     std::string token_text;
util/app.h
@@ -207,7 +207,6 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
   InferenceArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
   InferenceArgs() { Init(); };

-  size_t max_tokens;
   size_t max_generated_tokens;

   size_t prefill_tbatch_size;
@@ -220,21 +219,15 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {

   // Returns error string or nullptr if OK.
   const char* Validate() const {
-    if (max_tokens > gcpp::kSeqLen) {
-      return "max_tokens is larger than the maximum sequence length (see "
-             "configs.h).";
-    }
-    if (max_generated_tokens > max_tokens) {
-      return "Maximum number of generated tokens is larger than the maximum "
-             "total tokens.";
+    if (max_generated_tokens > gcpp::kSeqLen) {
+      return "max_generated_tokens is larger than the maximum sequence length "
+             "(see configs.h).";
     }
     return nullptr;
   }

   template <class Visitor>
   void ForEach(const Visitor& visitor) {
-    visitor(max_tokens, "max_tokens", size_t{3072},
-            "Maximum number of tokens in prompt + generation.");
     visitor(max_generated_tokens, "max_generated_tokens", size_t{2048},
             "Maximum number of tokens to generate.");
@@ -255,12 +248,9 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
   }

   void CopyTo(RuntimeConfig& runtime_config) const {
-    runtime_config.max_tokens = max_tokens;
     runtime_config.max_generated_tokens = max_generated_tokens;
-
     runtime_config.prefill_tbatch_size = prefill_tbatch_size;
     runtime_config.decode_qbatch_size = decode_qbatch_size;
-
     runtime_config.temperature = temperature;
   }
 };