Use Loader/AppArgs to construct gemma_test model, simplify AcceptFunc

accept_token: allow default, check if empty when using
allow mixing sample_func and stream_func, call the latter after the former
Also fix missing includes/deps.
PiperOrigin-RevId: 642240012
This commit is contained in:
Jan Wassenberg 2024-06-11 05:52:32 -07:00 committed by Copybara-Service
parent a0e808e341
commit 3e2396f98c
15 changed files with 172 additions and 250 deletions

View File

@ -123,6 +123,7 @@ cc_library(
deps = [
":common",
":gemma_lib",
"@hwy//:hwy",
],
)
@ -159,11 +160,13 @@ cc_test(
"no_tap",
],
deps = [
":app",
":args",
":common",
":cross_entropy",
":gemma_lib",
":ops",
# "//base",
"@googletest//:gtest_main",
"//compression:io",
"@hwy//:hwy_test_util",
@ -355,6 +358,8 @@ cc_library(
deps = [
":common",
":weights",
"//compression:compress",
"@hwy//:hwy",
"@hwy//:thread_pool",
],
)

View File

@ -50,11 +50,6 @@ TEST(OptimizeTest, GradientDescent) {
ByteStorageT backward =
CallForModelAndWeight<AllocateForwardPass>(model_type, weight_type);
KVCache kv_cache = KVCache::Create(model_type);
size_t max_tokens = 32;
size_t max_generated_tokens = 16;
float temperature = 1.0f;
int verbosity = 0;
const auto accept_token = [](int) { return true; };
Gemma gemma(GemmaTokenizer(), model_type, weight_type, pool);
@ -65,8 +60,13 @@ TEST(OptimizeTest, GradientDescent) {
return token != ReverseSequenceSampler::kEndToken;
};
RuntimeConfig runtime = {
max_tokens, max_generated_tokens, temperature, verbosity, &gen,
stream_token, accept_token, nullptr, ReverseSequenceSampler::kEndToken,
.max_tokens = 32,
.max_generated_tokens = 16,
.temperature = 1.0f,
.verbosity = 0,
.gen = &gen,
.stream_token = stream_token,
.eos_id = ReverseSequenceSampler::kEndToken,
};
TimingInfo timing_info;
gemma.Generate(runtime, prompt, 0, kv_cache, timing_info);

View File

@ -18,8 +18,10 @@
#include <cmath>
#include <random>
#include "compression/compress.h"
#include "gemma/common.h"
#include "gemma/weights.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp {

View File

@ -44,7 +44,6 @@ std::pair<std::string, int> QueryModel(
prompt.insert(prompt.begin(), 2);
std::string res;
size_t total_tokens = 0;
auto accept_token = [](int) { return true; };
std::mt19937 gen;
gen.seed(42);
@ -67,7 +66,6 @@ std::pair<std::string, int> QueryModel(
.verbosity = app.verbosity,
.gen = &gen,
.stream_token = stream_token,
.accept_token = accept_token,
};
model.Generate(runtime_config, prompt, /*start_pos=*/0, kv_cache, timing_info,
layers_output);

View File

@ -13,12 +13,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include <stddef.h>
#include <random>
#include <string>
#include <vector>
#include "third_party/gemma_cpp/gemma.h"
#include "util/app.h" // LoaderArgs
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "util/args.h"
std::vector<int> tokenize(const std::string& prompt_string,
const gcpp::GemmaTokenizer* tokenizer) {
@ -26,19 +30,16 @@ std::vector<int> tokenize(const std::string& prompt_string,
"<end_of_turn>\n<start_of_turn>model\n";
std::vector<int> tokens;
HWY_ASSERT(tokenizer->Encode(formatted, &tokens));
tokens.insert(tokens.begin(), 2); // BOS token
tokens.insert(tokens.begin(), BOS_ID);
return tokens;
}
int main(int argc, char** argv) {
gcpp::LoaderArgs loader(argc, argv);
// Rough heuristic for the number of threads to use
size_t num_threads = static_cast<size_t>(std::clamp(
static_cast<int>(std::thread::hardware_concurrency()) - 2, 1, 18));
hwy::ThreadPool pool(num_threads);
gcpp::AppArgs app(argc, argv);
// Instantiate model and KV Cache
hwy::ThreadPool pool(app.num_threads);
gcpp::Gemma model = gcpp::CreateGemma(loader, pool);
gcpp::KVCache kv_cache = gcpp::KVCache::Create(loader.ModelType());
size_t pos = 0; // KV Cache position
@ -53,7 +54,7 @@ int main(int argc, char** argv) {
tokenize("Write a greeting to the world.", model.Tokenizer());
size_t ntokens = tokens.size();
// This callback function gets invoked everytime a token is generated
// This callback function gets invoked every time a token is generated
auto stream_token = [&pos, &ntokens, tokenizer = model.Tokenizer()](int token,
float) {
++pos;
@ -74,5 +75,4 @@ int main(int argc, char** argv) {
.verbosity = 0},
tokens, /*KV cache position = */ 0, kv_cache, pool,
stream_token, gen);
std::cout << "\n";
}

View File

@ -72,7 +72,6 @@ std::pair<std::string, int> QueryModel(
prompt.insert(prompt.begin(), 2);
std::string res;
size_t total_tokens = 0;
auto accept_token = [](int) { return true; };
std::mt19937 gen;
gen.seed(42);
@ -100,7 +99,6 @@ std::pair<std::string, int> QueryModel(
.verbosity = app.verbosity,
.gen = &gen,
.stream_token = stream_token,
.accept_token = accept_token,
};
model.Generate(runtime_config, prompt, /*start_pos=*/0, kv_cache, timing_info,
/*layers_output=*/nullptr);

View File

@ -32,6 +32,7 @@
#include <stdio.h>
#include <algorithm> // std::clamp
#include <cstdlib>
#include <iostream>
#include <string>
#include <thread> // NOLINT
@ -41,8 +42,8 @@
#include "gemma/weights.h"
#include "util/args.h"
#include "hwy/base.h"
#include "hwy/profiler.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/profiler.h"
namespace gcpp {

View File

@ -20,7 +20,6 @@
#include <algorithm>
#include <cmath>
#include <functional>
#include <regex> // NOLINT
#include <string>
#include <utility>
@ -28,6 +27,7 @@
#include "gemma/common.h"
#include "gemma/gemma.h"
#include "hwy/base.h"
namespace gcpp {
@ -67,51 +67,53 @@ void LogTopK(const GemmaTokenizer& tokenizer, const float* dist, size_t len,
float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
const std::vector<int>& prompt,
KVCache& kv_cache, int verbosity) {
auto stream_token = [](int, float) { return true; };
auto accept_token = [](int) { return true; };
const StreamFunc stream_token = [](int /*token*/, float) { return true; };
// TWeight is unused, but we have to pass it to Config*.
const int vocab_size =
CallForModel</*TWeight=*/float, GetVocabSize>(gemma.ModelType());
float cross_entropy = std::log(vocab_size); // first token
size_t pos = 1;
std::function<int(const float*, size_t)> sample_token =
[&](const float* probs, size_t vocab_size) -> int {
const int token = prompt[pos];
const float prob = probs[token];
cross_entropy -= std::max(std::log(prob), -64.0f);
const SampleFunc sample_token = [&](const float* probs,
size_t vocab_size) -> int {
// We are called for each token, but pos starts at 1. Clamping max_tokens
// to prompt.size() should prevent overrun.
HWY_ASSERT(pos < prompt.size());
const int token = prompt[pos];
const float prob = probs[token];
cross_entropy -= std::max(std::log(prob), -64.0f);
if (verbosity >= 4) {
LogTopK(gemma.Tokenizer(), probs, vocab_size, 10);
}
if (verbosity >= 3) {
printf("pos %4zu token %6d = %-12s %.10e %14.10f bits\n", pos,
token, TokenString(gemma.Tokenizer(), token).c_str(), prob,
-std::log(prob) / std::log(2.0));
}
if (verbosity >= 2 && pos % 100 == 99) {
printf("Processed %zu tokens, cross-entropy per token: %f\n", pos + 1,
cross_entropy / std::log(2.0) / (pos + 1));
}
++pos;
return token;
};
if (verbosity >= 4) {
LogTopK(gemma.Tokenizer(), probs, vocab_size, 10);
}
if (verbosity >= 3) {
printf("pos %4zu token %6d = %-12s %.10e %14.10f bits\n", pos, token,
TokenString(gemma.Tokenizer(), token).c_str(), prob,
-std::log(prob) / std::log(2.0));
}
if (verbosity >= 2 && pos % 100 == 99) {
printf("Processed %zu tokens, cross-entropy per token: %f\n", pos + 1,
cross_entropy / std::log(2.0) / (pos + 1));
}
++pos;
return token;
};
std::vector<int> prompt0 = { prompt[0] };
max_tokens = HWY_MIN(max_tokens, prompt.size());
RuntimeConfig runtime = {
.max_tokens = max_tokens,
.max_generated_tokens = max_tokens - 1,
.temperature = 0.0f,
.verbosity = verbosity,
.gen = nullptr,
.stream_token = stream_token,
.accept_token = accept_token,
.sample_func = &sample_token,
.max_tokens = max_tokens,
.max_generated_tokens = max_tokens - 1,
.temperature = 0.0f,
.verbosity = verbosity,
.gen = nullptr,
.stream_token = stream_token,
.sample_func = sample_token,
};
TimingInfo timing_info;
gemma.Generate(runtime, prompt0, 0, kv_cache, timing_info, nullptr);
const float scale = 1.0 / std::log(2.0);
const float scale = 1.0f / std::log(2.0f);
return cross_entropy * scale;
}

View File

@ -781,6 +781,15 @@ void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
}
HWY_ASSERT(prompt_size > 0);
const SampleFunc sample_token =
runtime_config.sample_func
? runtime_config.sample_func
: [&](const float* logits, size_t vocab_size) -> int {
return SampleTopK<TConfig::kTopK>(logits, vocab_size, *runtime_config.gen,
runtime_config.temperature,
runtime_config.accept_token);
};
// pos indexes the KV cache. In the first turn of a chat, pos = 0.
//
// After the first turn, pos gets passed in with > 0 corresponding to the
@ -844,16 +853,9 @@ void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
LogitsSoftCap(30.0f, activations.logits.data(), kVocabSize);
// Barrier: must have all logits so we can subtract max.
Softmax(activations.logits.data(), kVocabSize);
if (runtime_config.sample_func) {
token = (*runtime_config.sample_func)(activations.logits.data(),
kVocabSize);
} else {
token = SampleTopK<TConfig::kTopK>(
activations.logits.data(), kVocabSize, *runtime_config.gen,
runtime_config.temperature, runtime_config.accept_token);
if (!runtime_config.stream_token(token, activations.logits[token])) {
token = runtime_config.eos_id;
}
token = sample_token(activations.logits.data(), kVocabSize);
if (!runtime_config.stream_token(token, activations.logits[token])) {
token = runtime_config.eos_id;
}
if (generate_pos == 0) {
timing_info.time_to_first_token = hwy::platform::Now() - gen_start;

View File

@ -68,15 +68,15 @@ class GemmaTokenizer {
};
// StreamFunc is called with (token, probability). For prompt tokens,
// probability is 0.0f. StreamFunc should return False to stop generation and
// True to continue generation.
// probability is 0.0f. StreamFunc should return false to stop generation and
// true to continue generation.
using StreamFunc = std::function<bool(int, float)>;
// AcceptFunc is called with token. It should return False for tokens you don't
// want to generate and True for tokens you want to generate.
// If not empty, AcceptFunc is called with token. It should return false for
// tokens you don't want to generate and true for tokens you want to generate.
using AcceptFunc = std::function<bool(int)>;
// CustomSampleFunc is called with the probability distribution for the next
// token, and its return value is used as the next generated token.
using CustomSampleFunc = std::function<int(const float*, size_t)>;
// If not empty, SampleFunc is called with the probability distribution for the
// next token, and its return value is used as the next generated token.
using SampleFunc = std::function<int(const float*, size_t)>;
struct RuntimeConfig {
size_t max_tokens;
@ -84,9 +84,9 @@ struct RuntimeConfig {
float temperature;
int verbosity;
std::mt19937* gen;
const StreamFunc& stream_token;
const AcceptFunc& accept_token;
const CustomSampleFunc* sample_func = nullptr;
StreamFunc stream_token;
AcceptFunc accept_token; // if empty, accepts all tokens.
SampleFunc sample_func; // if empty, uses SampleTopK.
int eos_id = EOS_ID;
};

View File

@ -15,31 +15,43 @@
#include "gemma/gemma.h"
#include <algorithm>
#include <iostream>
#include <stdio.h>
#include <memory>
#include <random>
#include <string>
#include <thread> // NOLINT
#include <vector>
#include "compression/io.h" // Path
#include "gemma/common.h"
// Placeholder for internal header, do not modify.
#include "gemma/cross_entropy.h"
#include "gemma/ops.h"
#include "util/app.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/tests/test_util-inl.h"
namespace gcpp {
namespace {
int s_argc = 0;
char** s_argv = nullptr;
class GemmaTest : public ::testing::Test {
protected:
GemmaTest()
: weights("./2b-it-mqa.sbs"),
tokenizer("./tokenizer.spm"),
pool(std::min<int>(20, (std::thread::hardware_concurrency() - 1) / 2)),
model(tokenizer, weights, model_type, weight_type, pool) {
KVCache kv_cache = KVCache::Create(model_type);
static void SetUpTestSuite() {
gcpp::LoaderArgs loader(s_argc, s_argv);
gcpp::AppArgs app(s_argc, s_argv);
if (const char* err = loader.Validate()) {
fprintf(stderr, "Insufficient LoaderArgs, skipping e2e tests.\n");
} else {
s_pool = std::make_unique<hwy::ThreadPool>(app.num_threads);
s_model = AllocateGemma(loader, *s_pool);
s_kv_cache = KVCache::Create(loader.ModelType());
}
}
static void TearDownTestSuite() {
s_pool.reset();
s_model.reset();
}
std::string GemmaReply(const std::string& prompt_string) {
@ -47,10 +59,10 @@ class GemmaTest : public ::testing::Test {
gen.seed(42);
std::vector<int> prompt;
HWY_ASSERT(model.Tokenizer().Encode(prompt_string, &prompt));
HWY_ASSERT(s_model->Tokenizer().Encode(prompt_string, &prompt));
// For both pre-trained and instruction-tuned models: prepend "<bos>" token
// if needed.
prompt.insert(prompt.begin(), 2);
prompt.insert(prompt.begin(), BOS_ID);
std::vector<int> response;
auto stream_token = [&response](int token, float) {
@ -64,65 +76,65 @@ class GemmaTest : public ::testing::Test {
.verbosity = 0,
.gen = &gen,
.stream_token = stream_token,
.accept_token = [](int) { return true; },
};
gcpp::TimingInfo timing_info;
model.Generate(runtime_config, prompt, /*start_pos=*/0, kv_cache,
timing_info, /*layers_output=*/nullptr);
s_model->Generate(runtime_config, prompt, /*start_pos=*/0, s_kv_cache,
timing_info, /*layers_output=*/nullptr);
std::string response_text;
HWY_ASSERT(model.Tokenizer().Decode(response, &response_text));
HWY_ASSERT(s_model->Tokenizer().Decode(response, &response_text));
return response_text;
}
float GemmaCrossEntropy(const std::string& prompt_string) {
std::vector<int> prompt;
HWY_ASSERT(model.Tokenizer().Encode(prompt_string, &prompt));
return ComputeCrossEntropy(model, /*max_tokens=*/3072, prompt, kv_cache,
HWY_ASSERT(s_model->Tokenizer().Encode(prompt_string, &prompt));
prompt.insert(prompt.begin(), BOS_ID);
return ComputeCrossEntropy(*s_model, /*max_tokens=*/3072, prompt,
s_kv_cache,
/*verbosity=*/0) /
prompt_string.size();
prompt_string.size();
}
void TestQuestions(const char* kQA[][2], size_t num_questions) {
if (!s_model) return;
for (size_t i = 0; i < num_questions; ++i) {
std::cout << "Question " << i + 1 << "\n\n";
fprintf(stderr, "Question %zu\n\n", i + 1);
std::string response = GemmaReply(kQA[i][0]);
std::cout << response << "\n\n";
fprintf(stderr, "'%s'\n\n", response.c_str());
EXPECT_TRUE(response.find(kQA[i][1]) != std::string::npos); // NOLINT
}
}
gcpp::Path weights;
gcpp::Path tokenizer;
gcpp::KVCache kv_cache;
hwy::ThreadPool pool;
gcpp::Model model_type = gcpp::Model::GEMMA_2B;
gcpp::Type weight_type = gcpp::Type::kSFP;
gcpp::Gemma model;
static std::unique_ptr<hwy::ThreadPool> s_pool;
static std::unique_ptr<gcpp::Gemma> s_model;
static gcpp::KVCache s_kv_cache;
};
TEST_F(GemmaTest, DISABLED_Geography) {
/*static*/ std::unique_ptr<hwy::ThreadPool> GemmaTest::s_pool;
/*static*/ std::unique_ptr<gcpp::Gemma> GemmaTest::s_model;
/*static*/ gcpp::KVCache GemmaTest::s_kv_cache;
TEST_F(GemmaTest, Geography) {
static const char* kQA[][2] = {
{"What is the capital of Hungary?", "Budapest"},
{"How many states does the US have?", "50"},
{"list me ten biggest cities in the world", "Tokyo"},
};
static const size_t kNum = sizeof(kQA) / sizeof(kQA[0]);
TestQuestions(kQA, kNum);
}
TEST_F(GemmaTest, DISABLED_History) {
TEST_F(GemmaTest, History) {
static const char* kQA[][2] = {
{"When was the Battle of Hastings?", "1066"},
{"Who fought at the Battle of Marathon?", "Greek"},
};
static const size_t kNum = sizeof(kQA) / sizeof(kQA[0]);
TestQuestions(kQA, kNum);
}
TEST_F(GemmaTest, DISABLED_Arithmetic) {
TEST_F(GemmaTest, Arithmetic) {
static const char* kQA[][2] = {
{"what is 13 + 14?", "27"},
{"what is 7 * 8", "56"},
{"what is 7 * 8?", "56"},
};
static const size_t kNum = sizeof(kQA) / sizeof(kQA[0]);
TestQuestions(kQA, kNum);
@ -163,150 +175,44 @@ static const char kGettysburg[] = {
"a new birth of freedom; and that this government of the people, by the "
"people, for the people, shall not perish from the earth.\n"};
// The Declaration of Independence.
static const char kDeclaration[] = {
"IN CONGRESS, July 4, 1776.\n\nThe unanimous Declaration of the thirteen "
"united States of America,\n\nWhen in the Course of human events, it "
"becomes necessary for one people to dissolve the political bands which "
"have connected them with another, and to assume among the powers of the "
"earth, the separate and equal station to which the Laws of Nature and of "
"Nature's God entitle them, a decent respect to the opinions of mankind "
"requires that they should declare the causes which impel them to the "
"separation.\n\nWe hold these truths to be self-evident, that all men are "
"created equal, that they are endowed by their Creator with certain "
"unalienable Rights, that among these are Life, Liberty and the pursuit of "
"Happiness.--That to secure these rights, Governments are instituted among "
"Men, deriving their just powers from the consent of the governed, --That "
"whenever any Form of Government becomes destructive of these ends, it is "
"the Right of the People to alter or to abolish it, and to institute new "
"Government, laying its foundation on such principles and organizing its "
"powers in such form, as to them shall seem most likely to effect their "
"Safety and Happiness. Prudence, indeed, will dictate that Governments "
"long established should not be changed for light and transient causes; "
"and accordingly all experience hath shewn, that mankind are more disposed "
"to suffer, while evils are sufferable, than to right themselves by "
"abolishing the forms to which they are accustomed. But when a long train "
"of abuses and usurpations, pursuing invariably the same Object evinces a "
"design to reduce them under absolute Despotism, it is their right, it is "
"their duty, to throw off such Government, and to provide new Guards for "
"their future security.--Such has been the patient sufferance of these "
"Colonies; and such is now the necessity which constrains them to alter "
"their former Systems of Government. The history of the present King of "
"Great Britain is a history of repeated injuries and usurpations, all "
"having in direct object the establishment of an absolute Tyranny over "
"these States. To prove this, let Facts be submitted to a candid "
"world.\n\nHe has refused his Assent to Laws, the most wholesome and "
"necessary for the public good.\nHe has forbidden his Governors to pass "
"Laws of immediate and pressing importance, unless suspended in their "
"operation till his Assent should be obtained; and when so suspended, he "
"has utterly neglected to attend to them.\nHe has refused to pass other "
"Laws for the accommodation of large districts of people, unless those "
"people would relinquish the right of Representation in the Legislature, a "
"right inestimable to them and formidable to tyrants only.\nHe has called "
"together legislative bodies at places unusual, uncomfortable, and distant "
"from the depository of their public Records, for the sole purpose of "
"fatiguing them into compliance with his measures.\nHe has dissolved "
"Representative Houses repeatedly, for opposing with manly firmness his "
"invasions on the rights of the people.\nHe has refused for a long time, "
"after such dissolutions, to cause others to be elected; whereby the "
"Legislative powers, incapable of Annihilation, have returned to the "
"People at large for their exercise; the State remaining in the mean time "
"exposed to all the dangers of invasion from without, and convulsions "
"within.\nHe has endeavoured to prevent the population of these States; "
"for that purpose obstructing the Laws for Naturalization of Foreigners; "
"refusing to pass others to encourage their migrations hither, and raising "
"the conditions of new Appropriations of Lands.\nHe has obstructed the "
"Administration of Justice, by refusing his Assent to Laws for "
"establishing Judiciary powers.\nHe has made Judges dependent on his Will "
"alone, for the tenure of their offices, and the amount and payment of "
"their salaries.\nHe has erected a multitude of New Offices, and sent "
"hither swarms of Officers to harrass our people, and eat out their "
"substance.\nHe has kept among us, in times of peace, Standing Armies "
"without the Consent of our legislatures.\nHe has affected to render the "
"Military independent of and superior to the Civil power.\nHe has combined "
"with others to subject us to a jurisdiction foreign to our constitution, "
"and unacknowledged by our laws; giving his Assent to their Acts of "
"pretended Legislation:\nFor Quartering large bodies of armed troops among "
"us:\nFor protecting them, by a mock Trial, from punishment for any "
"Murders which they should commit on the Inhabitants of these States:\nFor "
"cutting off our Trade with all parts of the world:\nFor imposing Taxes on "
"us without our Consent:\nFor depriving us in many cases, of the benefits "
"of Trial by Jury:\nFor transporting us beyond Seas to be tried for "
"pretended offences\nFor abolishing the free System of English Laws in a "
"neighbouring Province, establishing therein an Arbitrary government, and "
"enlarging its Boundaries so as to render it at once an example and fit "
"instrument for introducing the same absolute rule into these "
"Colonies:\nFor taking away our Charters, abolishing our most valuable "
"Laws, and altering fundamentally the Forms of our Governments:\nFor "
"suspending our own Legislatures, and declaring themselves invested with "
"power to legislate for us in all cases whatsoever.\nHe has abdicated "
"Government here, by declaring us out of his Protection and waging War "
"against us.\nHe has plundered our seas, ravaged our Coasts, burnt our "
"towns, and destroyed the lives of our people.\nHe is at this time "
"transporting large Armies of foreign Mercenaries to compleat the works of "
"death, desolation and tyranny, already begun with circumstances of "
"Cruelty & perfidy scarcely paralleled in the most barbarous ages, and "
"totally unworthy the Head of a civilized nation.\nHe has constrained our "
"fellow Citizens taken Captive on the high Seas to bear Arms against their "
"Country, to become the executioners of their friends and Brethren, or to "
"fall themselves by their Hands.\nHe has excited domestic insurrections "
"amongst us, and has endeavoured to bring on the inhabitants of our "
"frontiers, the merciless Indian Savages, whose known rule of warfare, is "
"an undistinguished destruction of all ages, sexes and conditions.\n\nIn "
"every stage of these Oppressions We have Petitioned for Redress in the "
"most humble terms: Our repeated Petitions have been answered only by "
"repeated injury. A Prince whose character is thus marked by every act "
"which may define a Tyrant, is unfit to be the ruler of a free "
"people.\n\nNor have We been wanting in attentions to our Brittish "
"brethren. We have warned them from time to time of attempts by their "
"legislature to extend an unwarrantable jurisdiction over us. We have "
"reminded them of the circumstances of our emigration and settlement here. "
"We have appealed to their native justice and magnanimity, and we have "
"conjured them by the ties of our common kindred to disavow these "
"usurpations, which, would inevitably interrupt our connections and "
"correspondence. They too have been deaf to the voice of justice and of "
"consanguinity. We must, therefore, acquiesce in the necessity, which "
"denounces our Separation, and hold them, as we hold the rest of mankind, "
"Enemies in War, in Peace Friends.\n\nWe, therefore, the Representatives "
"of the united States of America, in General Congress, Assembled, "
"appealing to the Supreme Judge of the world for the rectitude of our "
"intentions, do, in the Name, and by Authority of the good People of these "
"Colonies, solemnly publish and declare, That these United Colonies are, "
"and of Right ought to be Free and Independent States; that they are "
"Absolved from all Allegiance to the British Crown, and that all political "
"connection between them and the State of Great Britain, is and ought to "
"be totally dissolved; and that as Free and Independent States, they have "
"full Power to levy War, conclude Peace, contract Alliances, establish "
"Commerce, and to do all other Acts and Things which Independent States "
"may of right do. And for the support of this Declaration, with a firm "
"reliance on the protection of divine Providence, we mutually pledge to "
"each other our Lives, our Fortunes and our sacred Honor.\n"};
TEST_F(GemmaTest, DISABLED_CrossEntropySmall) {
TEST_F(GemmaTest, CrossEntropySmall) {
if (!s_model) return;
static const char kSmall[] =
"The capital of Hungary is Budapest which is located in Europe.";
float entropy = GemmaCrossEntropy(kSmall);
std::cout << "per-byte entropy: " << entropy << "\n";
EXPECT_LT(entropy, 1.6f);
fprintf(stderr, "per-byte entropy: %f\n", entropy);
// Note that entropy is 3x higher for the 7b-it model.
EXPECT_LT(entropy, 1.7f);
}
TEST_F(GemmaTest, DISABLED_CrossEntropyJingleBells) {
TEST_F(GemmaTest, CrossEntropyJingleBells) {
if (!s_model) return;
float entropy = GemmaCrossEntropy(kJingleBells);
std::cout << "per-byte entropy: " << entropy << "\n";
EXPECT_LT(entropy, 2.3f);
fprintf(stderr, "per-byte entropy: %f\n", entropy);
EXPECT_LT(entropy, 1.7f);
}
TEST_F(GemmaTest, DISABLED_CrossEntropyGettysburg) {
TEST_F(GemmaTest, CrossEntropyGettysburg) {
if (!s_model) return;
float entropy = GemmaCrossEntropy(kGettysburg);
std::cout << "per-byte entropy: " << entropy << "\n";
fprintf(stderr, "per-byte entropy: %f\n", entropy);
EXPECT_LT(entropy, 1.2f);
}
TEST_F(GemmaTest, DISABLED_CrossEntropyDeclaration) {
float entropy = GemmaCrossEntropy(kDeclaration);
std::cout << "per-byte entropy: " << entropy << "\n";
EXPECT_LT(entropy, 1.0f);
}
} // namespace
} // namespace gcpp
int main(int argc, char** argv) {
{
// Placeholder for internal init, do not modify.
}
// For later use by SetUp.
gcpp::s_argc = argc;
gcpp::s_argv = argv;
// Probably should be called before SetUpTestSuite.
testing::InitGoogleTest(&gcpp::s_argc, argv);
return RUN_ALL_TESTS();
}

View File

@ -1340,11 +1340,13 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK(
std::array<float, k> top_k{}; // sorted from highest [0], to lowest [k-1]
std::array<int, k> indices{};
for (size_t i = 0; i < vocab_size; ++i) {
if (probabilities[i] < top_k[k - 1] && accept_token(StaticCast<int>(i))) {
if (probabilities[i] < top_k[k - 1] &&
(!accept_token || accept_token(StaticCast<int>(i)))) {
continue;
}
for (size_t j = 0; j < k; ++j) {
if (probabilities[i] > top_k[j] && accept_token(StaticCast<int>(i))) {
if (probabilities[i] > top_k[j] &&
(!accept_token || accept_token(StaticCast<int>(i)))) {
// shift elements by 1, insert the new value, move on to next value
for (size_t idx = k - 1; idx > j; --idx) {
top_k[idx] = top_k[idx - 1];

View File

@ -284,10 +284,8 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
std::cout << "\n" << instructions << "\n";
}
ReplGemma(
model, loader.ModelTrainingType(), kv_cache, pool, inference,
app.verbosity,
/*accept_token=*/[](int) { return true; }, app.eot_line);
ReplGemma(model, loader.ModelTrainingType(), kv_cache, pool, inference,
app.verbosity, AcceptFunc(), app.eot_line);
}
} // namespace gcpp

View File

@ -98,8 +98,8 @@ void JsonGemma(gcpp::Gemma& model, gcpp::KVCache& kv_cache,
return true;
};
auto accept_token = [&current_pos, &prompt_size,
&accept_token_set](int token) {
const AcceptFunc accept_token = [&current_pos, &prompt_size,
&accept_token_set](int token) {
// i.e. we have no constraints on accepted tokens
if (accept_token_set.empty()) {
return true;

View File

@ -18,6 +18,8 @@
#ifndef THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_
#define THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_
#include <memory>
#include "hwy/contrib/thread_pool/thread_pool.h"
#if HWY_OS_LINUX
#include <sched.h>
@ -242,6 +244,12 @@ static inline Gemma CreateGemma(const LoaderArgs& loader,
loader.WeightType(), pool);
}
static inline std::unique_ptr<Gemma> AllocateGemma(const LoaderArgs& loader,
hwy::ThreadPool& pool) {
return std::make_unique<Gemma>(loader.tokenizer, loader.weights,
loader.ModelType(), loader.WeightType(), pool);
}
struct InferenceArgs : public ArgsBase<InferenceArgs> {
InferenceArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }