mirror of https://github.com/google/gemma.cpp.git
Introduce QueryResult in GemmaEnv and add a shortcut for WrapAndTokenize.
Remove max_tokens and rely only on max_generated_tokens.

PiperOrigin-RevId: 685662260
commit a4d6adbc43
parent 2892e232e2
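Illustrative sketch (not part of the commit) of how the updated GemmaEnv API reads after this change: QueryModel returns a QueryResult instead of std::pair<std::string, size_t>, GemmaEnv gains a WrapAndTokenize shortcut, and RuntimeConfig no longer has a max_tokens field. The types and methods are taken from the diff below; the prompt text and main() scaffolding are assumptions for illustration only.

// Hypothetical usage sketch; GemmaEnv, QueryResult and WrapAndTokenize are from
// evals/benchmark_helper.h as changed in this commit, the rest is illustrative.
#include <iostream>
#include <string>
#include <vector>

#include "evals/benchmark_helper.h"

int main(int argc, char** argv) {
  // Loads the model according to LoaderArgs/InferenceArgs/AppArgs flags.
  gcpp::GemmaEnv env(argc, argv);

  // New shortcut: adds turn structure and tokenizes in one call.
  std::string prompt = "What is the capital of France?";
  const std::vector<int> tokens = env.WrapAndTokenize(prompt);

  // QueryModel now returns QueryResult instead of std::pair<std::string, size_t>.
  gcpp::QueryResult result = env.QueryModel(tokens);

  // response holds every streamed token; response_start_pos marks where the
  // generated part begins, replacing the old substr(prompt.size()) pattern.
  std::cout << result.response.substr(result.response_start_pos) << "\n";
  std::cout << result.tokens_generated << " tokens streamed\n";
  return 0;
}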
@@ -385,7 +385,6 @@ tokenizer : tokenizer.spm
 compressed_weights : 2b-it-sfp.sbs
 model : 2b-it
 weights : [no path specified]
-max_tokens : 3072
 max_generated_tokens : 2048
 
 *Usage*
@@ -69,7 +69,6 @@ TEST(OptimizeTest, GradientDescent) {
     return token != ReverseSequenceSampler::kEndToken;
   };
   RuntimeConfig runtime = {
-      .max_tokens = 32,
       .max_generated_tokens = 16,
       .temperature = 1.0f,
       .verbosity = 0,
@@ -81,15 +81,15 @@ int BenchmarkGoldens(GemmaEnv& env, const std::string& golden_path) {
   size_t total_tokens = 0;
   const double time_start = hwy::platform::Now();
   for (auto& [question, expected_answer] : queries_answers) {
-    const auto [answer, token_count] = env.QueryModel(question);
-    total_tokens += token_count;
-    if (answer.find(expected_answer) != std::string::npos) {
+    QueryResult result = env.QueryModel(question);
+    total_tokens += result.tokens_generated;
+    if (result.response.find(expected_answer) != std::string::npos) {
       correct_answers++;
     } else {
       std::cout << "Wrong!\n";
       std::cout << "Input: " << question << "\n";
       std::cout << "Expected: " << expected_answer << "\n";
-      std::cout << "Output: " << answer << "\n\n" << std::flush;
+      std::cout << "Output: " << result.response << "\n\n" << std::flush;
     }
   }
   LogSpeedStats(time_start, total_tokens);
@@ -108,9 +108,10 @@ int BenchmarkSummary(GemmaEnv& env, const Path& text) {
   prompt.append(ReadFileToString(text));
   prompt.append("\nSummarize this text.\n");
   const double time_start = hwy::platform::Now();
-  const auto [answer, token_count] = env.QueryModel(prompt);
-  std::cout << answer.substr(prompt.size()) << "\n" << std::flush;
-  LogSpeedStats(time_start, token_count);
+  QueryResult result = env.QueryModel(prompt);
+  std::cout << result.response.substr(result.response_start_pos) << "\n"
+            << std::flush;
+  LogSpeedStats(time_start, result.tokens_generated);
   return EXIT_SUCCESS;
 }
 
@@ -118,7 +119,6 @@ int BenchmarkCrossEntropy(GemmaEnv& env, const Path& text,
                           size_t batch_tokens) {
   std::string input = ReadFileToString(text);
   std::vector<int> prompt = env.Tokenize(input);
-  prompt.resize(std::min<size_t>(env.MaxTokens(), prompt.size()));
   std::cout << "Number of input tokens: " << prompt.size() << "\n";
   const double time_start = hwy::platform::Now();
   float total_entropy = 0.0f;
@@ -156,11 +156,11 @@ int BenchmarkTriviaQA(GemmaEnv& env, const Path& json_file,
   while (std::getline(trivia_file, line)) {
     json data = json::parse(line);
     std::string q(data["question"]);
-    const auto [answer, token_count] = env.QueryModel(q);
-    std::cout << answer << "\n";
+    QueryResult result = env.QueryModel(q);
+    std::cout << result.response << "\n";
     bool correct = false;
     for (const std::string expected : data["answer"]["aliases"]) {
-      if (answer.find(expected) != std::string::npos) {
+      if (result.response.find(expected) != std::string::npos) {
        correct = true;
        break;
      }
@@ -18,14 +18,12 @@
 #include <stdio.h>
 #include <time.h>
 
-#include <algorithm>
 #include <cstdio>
 #include <iostream>
 #include <memory>
 #include <ostream>
 #include <random>
 #include <string>
-#include <utility>  // std::pair
 #include <vector>
 
 // Placeholder for internal header, do not modify.
@@ -76,7 +74,6 @@ GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
   }
   InitGenerator(inference, gen_);
   runtime_config_ = {
-      .max_tokens = inference.max_tokens,
       .max_generated_tokens = inference.max_generated_tokens,
       .temperature = inference.temperature,
       .verbosity = app.verbosity,
@@ -99,21 +96,21 @@ GemmaEnv::GemmaEnv(int argc, char** argv)
     : GemmaEnv(LoaderArgs(argc, argv), InferenceArgs(argc, argv),
                MakeAppArgs(argc, argv)) {}
 
-std::pair<std::string, size_t> GemmaEnv::QueryModel(
-    const std::vector<int>& tokens) {
-  std::string res;
-  size_t total_tokens = 0;
+QueryResult GemmaEnv::QueryModel(const std::vector<int>& tokens) {
+  QueryResult result;
 
-  const BatchStreamFunc batch_stream_token = [&res, &total_tokens, this](
-                                                 size_t query_index, size_t pos,
-                                                 int token, float) {
-    ++total_tokens;
-    res += StringFromTokens(std::vector<int>{token});
-    return true;
-  };
+  const BatchStreamFunc batch_stream_token =
+      [&result, &tokens, this](size_t /*query_index*/, size_t /*pos*/,
+                               int token, float /*score*/) {
+        ++result.tokens_generated;
+        result.response += StringFromTokens(std::vector<int>{token});
+        if (result.tokens_generated == tokens.size()) {
+          result.response_start_pos = result.response.size();
+        }
+        return true;
+      };
   if (runtime_config_.verbosity >= 2) {
-    std::cout << "Max tokens: " << runtime_config_.max_tokens
-              << "\tmax generated tokens: "
+    std::cout << "max generated tokens: "
               << runtime_config_.max_generated_tokens
               << "\ttemperature: " << runtime_config_.temperature << "\n";
   }
@@ -121,7 +118,7 @@ std::pair<std::string, size_t> GemmaEnv::QueryModel(
   runtime_config_.batch_stream_token = batch_stream_token;
   model_->Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0],
                    timing_info);
-  return {res, total_tokens};
+  return result;
 }
 
 void GemmaEnv::QueryModel(
@@ -134,27 +131,29 @@ void GemmaEnv::QueryModel(
   runtime_config_.stream_token = previous_stream_token;
 }
 
-std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel(
+std::vector<QueryResult> GemmaEnv::BatchQueryModel(
     const QueriesPromptTokens& queries_prompt) {
   const size_t num_queries = queries_prompt.size();
   HWY_ASSERT(num_queries != 0);
-  std::vector<std::pair<std::string, size_t>> res(num_queries);
-  std::fill(res.begin(), res.end(), std::make_pair("", 0));
-  const BatchStreamFunc batch_stream_token = [&res, this](size_t query_index,
-                                                          size_t pos, int token,
-                                                          float) {
+  std::vector<QueryResult> res(num_queries);
+  const BatchStreamFunc batch_stream_token = [&res, &queries_prompt, this](
+                                                 size_t query_index, size_t pos,
+                                                 int token, float) {
     std::string token_text;
     HWY_ASSERT(
         model_->Tokenizer().Decode(std::vector<int>{token}, &token_text));
-    res[query_index].first.append(token_text);
-    res[query_index].second += 1;
+    res[query_index].response.append(token_text);
+    res[query_index].tokens_generated += 1;
+    if (res[query_index].tokens_generated ==
+        queries_prompt[query_index].size()) {
+      res[query_index].response_start_pos = res[query_index].response.size();
+    }
     return true;
   };
   if (runtime_config_.verbosity >= 2) {
-    fprintf(stderr,
-            "Max tok: %zu max gen: %zu temp: %f tbatch: %zu qbatch: %zu\n",
-            runtime_config_.max_tokens, runtime_config_.max_generated_tokens,
-            runtime_config_.temperature, runtime_config_.prefill_tbatch_size,
+    fprintf(stderr, "Max gen: %zu temp: %f tbatch: %zu qbatch: %zu\n",
+            runtime_config_.max_generated_tokens, runtime_config_.temperature,
+            runtime_config_.prefill_tbatch_size,
             runtime_config_.decode_qbatch_size);
   }
 
@@ -178,21 +177,18 @@ std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel(
   return res;
 }
 
-std::pair<std::string, size_t> GemmaEnv::QueryModel(std::string& input) {
-  const std::vector<int> prompt =
-      WrapAndTokenize(model_->Tokenizer(), model_->Info(),
-                      /*pos=*/0, input);
+QueryResult GemmaEnv::QueryModel(std::string& input) {
+  const std::vector<int> prompt = WrapAndTokenize(input);
   return QueryModel(prompt);
 }
 
-std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel(
+std::vector<QueryResult> GemmaEnv::BatchQueryModel(
     const std::vector<std::string>& inputs) {
   std::vector<std::vector<int>> prompts;
   prompts.reserve(inputs.size());
   for (auto& input : inputs) {
     std::string mutable_prompt = input;
-    prompts.push_back(WrapAndTokenize(model_->Tokenizer(), model_->Info(),
-                                      /*pos=*/0, mutable_prompt));
+    prompts.push_back(WrapAndTokenize(mutable_prompt));
   }
   std::vector<PromptTokens> prompt_vector;
   prompt_vector.reserve(prompts.size());
@@ -206,7 +202,7 @@ std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel(
 float GemmaEnv::CrossEntropy(const std::string& input) {
   std::vector<int> prompt = Tokenize(input);
   prompt.insert(prompt.begin(), BOS_ID);
-  return ComputeCrossEntropy(*GetModel(), /*max_tokens=*/3072, prompt,
+  return ComputeCrossEntropy(*GetModel(), /*max_generated_tokens=*/3072, prompt,
                              MutableKVCache(),
                              /*verbosity=*/0) /
          static_cast<int>(input.size());
@@ -21,7 +21,6 @@
 #include <memory>
 #include <random>
 #include <string>
-#include <utility>
 #include <vector>
 
 #include "gemma/gemma.h"
@@ -33,6 +32,14 @@ namespace gcpp {
 
 void InitGenerator(const InferenceArgs& inference, std::mt19937& gen);
 
+// Return type for query model calls.
+struct QueryResult {
+  std::string response;
+  size_t tokens_generated = 0;
+  // The position in the response at which the generated tokens start.
+  size_t response_start_pos = 0;
+};
+
 // Convenience class to load a model and run inference.
 class GemmaEnv {
  public:
@@ -41,8 +48,9 @@ class GemmaEnv {
   GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
            const AppArgs& app);
 
-  size_t MaxTokens() const { return runtime_config_.max_tokens; }
-  // Sets the maximum number of output tokens to generate.
+  size_t MaxGeneratedTokens() const {
+    return runtime_config_.max_generated_tokens;
+  }
   void SetMaxGeneratedTokens(size_t max_generated_tokens) {
     runtime_config_.max_generated_tokens = max_generated_tokens;
   }
@@ -59,6 +67,10 @@ class GemmaEnv {
     return tokens;
   }
 
+  std::vector<int> WrapAndTokenize(std::string& input) const {
+    return gcpp::WrapAndTokenize(model_->Tokenizer(), model_->Info(), 0, input);
+  }
+
   std::string StringFromTokens(const std::vector<int>& tokens) const {
     std::string string;
     HWY_ASSERT(model_->Tokenizer().Decode(tokens, &string));
@@ -67,12 +79,12 @@
 
   // Runs inference on the given input and returns the top-1 result string and
   // the number of tokens that were generated.
-  std::pair<std::string, size_t> QueryModel(const std::vector<int>& tokens);
-  std::vector<std::pair<std::string, size_t>> BatchQueryModel(
+  QueryResult QueryModel(const std::vector<int>& tokens);
+  std::vector<QueryResult> BatchQueryModel(
       const QueriesPromptTokens& queries_prompt);
   // Adds turn structure to input, tokenizes and calls the above overload.
-  std::pair<std::string, size_t> QueryModel(std::string& input);
-  std::vector<std::pair<std::string, size_t>> BatchQueryModel(
+  QueryResult QueryModel(std::string& input);
+  std::vector<QueryResult> BatchQueryModel(
      const std::vector<std::string>& inputs);
 
   // Runs inference on the given input and calls the callback for each token.
@@ -33,11 +33,11 @@ void RunPrompt(const std::string& original_prompt, benchmark::State& state) {
   size_t total_tokens = 0;
   for (auto s : state) {
     std::string prompt = original_prompt;  // reset from original
-    auto [response, n] = s_env->QueryModel(prompt);
+    QueryResult result = s_env->QueryModel(prompt);
    if (s_env->Verbosity() != 0) {
-      fprintf(stdout, "|%s|\n", response.c_str());
+      fprintf(stdout, "|%s|\n", result.response.c_str());
     }
-    total_tokens += n;
+    total_tokens += result.tokens_generated;
   }
 
   state.SetItemsProcessed(total_tokens);
@@ -97,7 +97,7 @@ namespace gcpp {
 
 HWY_EXPORT(CallSoftmax);
 
-float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
+float ComputeCrossEntropy(Gemma& gemma, size_t max_generated_tokens,
                           const std::vector<int>& prompt, KVCache& kv_cache,
                           int verbosity) {
   const StreamFunc stream_token = [](int /*token*/, float) { return true; };
@@ -112,8 +112,8 @@ float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
                              size_t vocab_size) -> TokenAndProb {
     // input is logits, not yet probabilities
     HWY_DYNAMIC_DISPATCH(CallSoftmax)(probs, vocab_size);
-    // We are called for each token, but pos starts at 1. Clamping max_tokens
-    // to prompt.size() should prevent overrun.
+    // We are called for each token, but pos starts at 1. Clamping
+    // max_generated_tokens to prompt.size() should prevent overrun.
     HWY_ASSERT(pos < prompt.size());
     const int token = prompt[pos];
     const float prob = probs[token];
@@ -136,10 +136,9 @@ float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
   };
 
   std::vector<int> prompt0 = { prompt[0] };
-  max_tokens = HWY_MIN(max_tokens, prompt.size());
+  max_generated_tokens = HWY_MIN(max_generated_tokens, prompt.size());
   RuntimeConfig runtime = {
-      .max_tokens = max_tokens,
-      .max_generated_tokens = max_tokens - 1,
+      .max_generated_tokens = max_generated_tokens - 1,
       .temperature = 0.0f,
       .verbosity = verbosity,
       .gen = nullptr,
@@ -24,7 +24,7 @@
 
 namespace gcpp {
 
-float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
+float ComputeCrossEntropy(Gemma& gemma, size_t max_generated_tokens,
                           const std::vector<int>& prompt, KVCache& kv_cache,
                           int verbosity);
 
@@ -68,8 +68,9 @@ int Run(int argc, char** argv) {
     json_base[std::to_string(pos)][debug_key] = v;
   };
 
-  const auto [answer, token_count] = env.QueryModel(prompt_args.prompt);
-  std::cout << answer.substr(prompt_args.prompt.size()) << "\n" << std::flush;
+  QueryResult result = env.QueryModel(prompt_args.prompt);
+  std::cout << result.response.substr(result.response_start_pos) << "\n"
+            << std::flush;
 
   if (env.MutableConfig().layers_output) {
     std::ofstream output_f(prompt_args.layers_output.path, std::ofstream::out);
@@ -52,13 +52,13 @@ class GemmaTest : public ::testing::Test {
     if (s_env->GetModel()->Info().model == Model::GEMMA2_27B ||
         s_env->GetModel()->Info().model == Model::GRIFFIN_2B) {
       std::string mutable_prompt = prompt;
-      auto [response, n] = s_env->QueryModel(mutable_prompt);  // Uses turns.
-      return response;
+      QueryResult result = s_env->QueryModel(mutable_prompt);  // Uses turns.
+      return result.response;
     }
     // Otherwise, do not use turn structure.
     const std::vector<int> tokens = s_env->TokenizeAndPrependBOS(prompt);
-    auto [response, n] = s_env->QueryModel(tokens);
-    return response;
+    QueryResult result = s_env->QueryModel(tokens);
+    return result.response;
   }
 
   std::vector<std::string> BatchGemmaReply(
@@ -72,8 +72,8 @@ class GemmaTest : public ::testing::Test {
     // It would be good to make these tests more consistent.
     if (s_env->GetModel()->Info().model == Model::GEMMA2_27B ||
        s_env->GetModel()->Info().model == Model::GRIFFIN_2B) {
-      for (auto [response, n] : s_env->BatchQueryModel(inputs)) {
-        replies.push_back(response);
+      for (QueryResult result : s_env->BatchQueryModel(inputs)) {
+        replies.push_back(result.response);
       }
       return replies;
     }
@@ -88,8 +88,8 @@ class GemmaTest : public ::testing::Test {
      prompt_spans.push_back(PromptTokens(prompt.data(), prompt.size()));
    }
    QueriesPromptTokens prompts(prompt_spans.data(), prompt_spans.size());
-    for (auto [response, n] : s_env->BatchQueryModel(prompts)) {
-      replies.push_back(response);
+    for (const QueryResult& result : s_env->BatchQueryModel(prompts)) {
+      replies.push_back(result.response);
    }
    return replies;
  }
@@ -167,7 +167,6 @@ TEST_F(GemmaTest, Multiturn) {
     return true;
   };
   RuntimeConfig runtime_config{
-      .max_tokens = 128,
       .max_generated_tokens = 64,
       .temperature = 0.0f,
       .verbosity = 2,
@@ -103,9 +103,7 @@ void Run(GemmaEnv& env, JsonArgs& json) {
       "What is start of the line with the correct answer? "
       "Do not include any justifications or explanations. Reply only with a "
       "letter.";
-  const std::vector<int> prompt =
-      WrapAndTokenize(env.GetModel()->Tokenizer(), env.GetModel()->Info(),
-                      /*pos=*/0, prompt_string);
+  const std::vector<int> prompt = env.WrapAndTokenize(prompt_string);
   const size_t prompt_size = prompt.size();
 
   std::vector<int> predicted_token_ids;
@@ -127,7 +125,6 @@ void Run(GemmaEnv& env, JsonArgs& json) {
   // confused with the word "A".
   gcpp::TimingInfo timing_info;
   gcpp::RuntimeConfig runtime_config = {
-      .max_tokens = env.MaxTokens(),
       .max_generated_tokens = 30,
       .temperature = 0.0f,
       .verbosity = env.Verbosity(),
@@ -87,7 +87,6 @@ int main(int argc, char** argv) {
 
   gcpp::TimingInfo timing_info;
   gcpp::RuntimeConfig runtime_config = {
-      .max_tokens = 1536,
       .max_generated_tokens = 1024,
       .temperature = 1.0,
       .verbosity = 0,
@@ -1117,34 +1117,15 @@ HWY_NOINLINE void Transformer(
 }
 
 template <class TConfig>
-void RangeChecks(size_t& max_tokens, size_t& max_generated_tokens,
-                 size_t& prompt_size) {
+void RangeChecks(size_t& max_generated_tokens, const size_t prompt_size) {
   if (!TConfig::kUseLocalAttention) {
-    if (max_tokens > TConfig::kSeqLen) {
-      fprintf(stderr, "WARNING: max_tokens %zu > kSeqLen %d, truncating.\n",
-              max_tokens, TConfig::kSeqLen);
-      max_tokens = static_cast<size_t>(TConfig::kSeqLen);
-    }
-  }
-
-  if (max_generated_tokens > max_tokens) {
-    fprintf(stderr,
-            "WARNING: max_generated_tokens %zu > max_tokens %zu, truncating.\n",
-            max_generated_tokens, max_tokens);
-    max_generated_tokens = max_tokens - 1;
-  }
-
-  if (!TConfig::kUseLocalAttention) {
-    if (prompt_size + max_generated_tokens > max_tokens) {
+    if (max_generated_tokens > TConfig::kSeqLen) {
       fprintf(stderr,
-              "WARNING: prompt_size %zu + max_generated_tokens %zu > "
-              "max_tokens %zu, truncating to ",
-              prompt_size, max_generated_tokens, max_tokens);
-      prompt_size = std::min(prompt_size, max_tokens - max_generated_tokens);
-      fprintf(stderr, "%zu\n", prompt_size);
+              "WARNING: max_generated_tokens %zu > kSeqLen %d, truncating.\n",
+              max_generated_tokens, TConfig::kSeqLen);
+      max_generated_tokens = static_cast<size_t>(TConfig::kSeqLen);
     }
   }
 
   HWY_ASSERT(prompt_size > 0);
 }
@@ -1262,17 +1243,8 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
   const hwy::Divisor div_seq_len(static_cast<uint32_t>(kv_caches[0].seq_len));
 
   size_t max_prompt_size = MaxQueryLength(queries_prompt);
-  size_t max_tokens = runtime_config.max_tokens;
   size_t max_generated_tokens = runtime_config.max_generated_tokens;
-  RangeChecks<TConfig>(max_tokens, max_generated_tokens, max_prompt_size);
-  for (size_t pos : queries_pos_copy) {
-    if (pos >= max_tokens) {
-      fprintf(stderr, "Warning: pos %zu >= max_tokens %zu, aborting.\n", pos,
-              max_tokens);
-      return;
-    }
-  }
-
+  RangeChecks<TConfig>(max_generated_tokens, max_prompt_size);
   const SampleFunc sample_token = ChooseSampleFunc<TConfig>(runtime_config);
 
   // Prefill stops before min_prompt_size - 1 because the last prompt
@@ -1315,7 +1287,7 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
   }
 
   const double gen_start = hwy::platform::Now();
-  for (size_t gen = 0; gen < HWY_MIN(max_tokens, max_generated_tokens); ++gen) {
+  for (size_t gen = 0; gen < max_generated_tokens; ++gen) {
     // Decode generates one token per query and increments queries_mutable_pos.
     Transformer<TConfig>(
         QueriesToken(gen_tokens.data(), num_queries), queries_mutable_pos,
|
|
@ -92,8 +92,7 @@ struct RuntimeConfig {
|
||||||
return stream_token(token, prob);
|
return stream_token(token, prob);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Limits on the number of tokens generated.
|
// Limit on the number of tokens generated.
|
||||||
size_t max_tokens;
|
|
||||||
size_t max_generated_tokens;
|
size_t max_generated_tokens;
|
||||||
|
|
||||||
// These defaults are overridden by InferenceArgs::CopyTo(*this):
|
// These defaults are overridden by InferenceArgs::CopyTo(*this):
|
||||||
|
|
|
||||||
|
|
@@ -131,7 +131,7 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const InferenceArgs& args,
     return true;
   };
 
-  while (abs_pos < args.max_tokens) {
+  while (true) {  // Loop until user quits.
     tokens_generated_this_turn = 0;
     std::string prompt_string = GetPrompt(std::cin, verbosity, eot_line);
     if (!std::cin) return;
@@ -183,10 +183,6 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const InferenceArgs& args,
                    timing_info);
     std::cout << "\n\n";
   }
-  std::cout
-      << "max_tokens (" << args.max_tokens
-      << ") exceeded. Use a larger value if desired using the --max_tokens "
-      << "command line flag.\n";
 }
 
 void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
|
|
@ -13,8 +13,6 @@
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
#include "gemma/gemma.h"
|
|
||||||
|
|
||||||
#include <cstdio>
|
#include <cstdio>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
@@ -22,6 +20,7 @@
 
 #include "evals/benchmark_helper.h"
 #include "gemma/common.h"
+#include "gemma/gemma.h"
 #include "hwy/base.h"
 #include "hwy/tests/hwy_gtest.h"
 
@@ -63,15 +62,13 @@ void PaliGemmaTest::InitVit(const std::string& path) {
 std::string PaliGemmaTest::GemmaReply(const std::string& prompt_text) const{
   Gemma& model = *(s_env->GetModel());
   s_env->MutableGen().seed(0x12345678);
-  RuntimeConfig runtime_config = {.max_tokens = 1024,
-                                  .max_generated_tokens = 512,
+  RuntimeConfig runtime_config = {.max_generated_tokens = 512,
                                   .verbosity = 0,
                                   .gen = &s_env->MutableGen()};
   runtime_config.image_tokens = image_tokens_.get();
   size_t abs_pos = 0;
   std::string mutable_prompt = prompt_text;
-  std::vector<int> tokens =
-      WrapAndTokenize(model.Tokenizer(), model.Info(), abs_pos, mutable_prompt);
+  std::vector<int> tokens = s_env->WrapAndTokenize(mutable_prompt);
   std::string response;
   auto stream_token = [&](int token, float) {
     std::string token_text;
util/app.h
@@ -207,7 +207,6 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
   InferenceArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
   InferenceArgs() { Init(); };
 
-  size_t max_tokens;
   size_t max_generated_tokens;
 
   size_t prefill_tbatch_size;
@@ -220,21 +219,15 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
 
   // Returns error string or nullptr if OK.
   const char* Validate() const {
-    if (max_tokens > gcpp::kSeqLen) {
-      return "max_tokens is larger than the maximum sequence length (see "
-             "configs.h).";
-    }
-    if (max_generated_tokens > max_tokens) {
-      return "Maximum number of generated tokens is larger than the maximum "
-             "total tokens.";
+    if (max_generated_tokens > gcpp::kSeqLen) {
+      return "max_generated_tokens is larger than the maximum sequence length "
+             "(see configs.h).";
     }
     return nullptr;
   }
 
   template <class Visitor>
   void ForEach(const Visitor& visitor) {
-    visitor(max_tokens, "max_tokens", size_t{3072},
-            "Maximum number of tokens in prompt + generation.");
     visitor(max_generated_tokens, "max_generated_tokens", size_t{2048},
             "Maximum number of tokens to generate.");
 
@@ -255,12 +248,9 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
   }
 
   void CopyTo(RuntimeConfig& runtime_config) const {
-    runtime_config.max_tokens = max_tokens;
     runtime_config.max_generated_tokens = max_generated_tokens;
-
     runtime_config.prefill_tbatch_size = prefill_tbatch_size;
     runtime_config.decode_qbatch_size = decode_qbatch_size;
-
     runtime_config.temperature = temperature;
   }
 };