diff --git a/BUILD.bazel b/BUILD.bazel index 2449a5d..5f51fe1 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -123,6 +123,7 @@ cc_library( deps = [ ":common", ":gemma_lib", + "@hwy//:hwy", ], ) @@ -159,11 +160,13 @@ cc_test( "no_tap", ], deps = [ + ":app", ":args", ":common", ":cross_entropy", ":gemma_lib", ":ops", + # "//base", "@googletest//:gtest_main", "//compression:io", "@hwy//:hwy_test_util", @@ -355,6 +358,8 @@ cc_library( deps = [ ":common", ":weights", + "//compression:compress", + "@hwy//:hwy", "@hwy//:thread_pool", ], ) diff --git a/backprop/optimize_test.cc b/backprop/optimize_test.cc index 8195b98..5a8e343 100644 --- a/backprop/optimize_test.cc +++ b/backprop/optimize_test.cc @@ -50,11 +50,6 @@ TEST(OptimizeTest, GradientDescent) { ByteStorageT backward = CallForModelAndWeight(model_type, weight_type); KVCache kv_cache = KVCache::Create(model_type); - size_t max_tokens = 32; - size_t max_generated_tokens = 16; - float temperature = 1.0f; - int verbosity = 0; - const auto accept_token = [](int) { return true; }; Gemma gemma(GemmaTokenizer(), model_type, weight_type, pool); @@ -65,8 +60,13 @@ TEST(OptimizeTest, GradientDescent) { return token != ReverseSequenceSampler::kEndToken; }; RuntimeConfig runtime = { - max_tokens, max_generated_tokens, temperature, verbosity, &gen, - stream_token, accept_token, nullptr, ReverseSequenceSampler::kEndToken, + .max_tokens = 32, + .max_generated_tokens = 16, + .temperature = 1.0f, + .verbosity = 0, + .gen = &gen, + .stream_token = stream_token, + .eos_id = ReverseSequenceSampler::kEndToken, }; TimingInfo timing_info; gemma.Generate(runtime, prompt, 0, kv_cache, timing_info); diff --git a/backprop/optimizer.cc b/backprop/optimizer.cc index 198bb90..ed51183 100644 --- a/backprop/optimizer.cc +++ b/backprop/optimizer.cc @@ -18,8 +18,10 @@ #include #include +#include "compression/compress.h" #include "gemma/common.h" #include "gemma/weights.h" +#include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" namespace gcpp { diff --git a/debug_prompt.cc b/debug_prompt.cc index cb49fe3..36f824c 100644 --- a/debug_prompt.cc +++ b/debug_prompt.cc @@ -44,7 +44,6 @@ std::pair QueryModel( prompt.insert(prompt.begin(), 2); std::string res; size_t total_tokens = 0; - auto accept_token = [](int) { return true; }; std::mt19937 gen; gen.seed(42); @@ -67,7 +66,6 @@ std::pair QueryModel( .verbosity = app.verbosity, .gen = &gen, .stream_token = stream_token, - .accept_token = accept_token, }; model.Generate(runtime_config, prompt, /*start_pos=*/0, kv_cache, timing_info, layers_output); diff --git a/examples/hello_world/run.cc b/examples/hello_world/run.cc index ac917af..5984fa6 100644 --- a/examples/hello_world/run.cc +++ b/examples/hello_world/run.cc @@ -13,12 +13,16 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include +#include + +#include +#include +#include #include "third_party/gemma_cpp/gemma.h" #include "util/app.h" // LoaderArgs +#include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" -#include "util/args.h" std::vector tokenize(const std::string& prompt_string, const gcpp::GemmaTokenizer* tokenizer) { @@ -26,19 +30,16 @@ std::vector tokenize(const std::string& prompt_string, "\nmodel\n"; std::vector tokens; HWY_ASSERT(tokenizer->Encode(formatted, &tokens)); - tokens.insert(tokens.begin(), 2); // BOS token + tokens.insert(tokens.begin(), BOS_ID); return tokens; } int main(int argc, char** argv) { gcpp::LoaderArgs loader(argc, argv); - - // Rough heuristic for the number of threads to use - size_t num_threads = static_cast(std::clamp( - static_cast(std::thread::hardware_concurrency()) - 2, 1, 18)); - hwy::ThreadPool pool(num_threads); + gcpp::AppArgs app(argc, argv); // Instantiate model and KV Cache + hwy::ThreadPool pool(app.num_threads); gcpp::Gemma model = gcpp::CreateGemma(loader, pool); gcpp::KVCache kv_cache = gcpp::KVCache::Create(loader.ModelType()); size_t pos = 0; // KV Cache position @@ -53,7 +54,7 @@ int main(int argc, char** argv) { tokenize("Write a greeting to the world.", model.Tokenizer()); size_t ntokens = tokens.size(); - // This callback function gets invoked everytime a token is generated + // This callback function gets invoked every time a token is generated auto stream_token = [&pos, &ntokens, tokenizer = model.Tokenizer()](int token, float) { ++pos; @@ -74,5 +75,4 @@ int main(int argc, char** argv) { .verbosity = 0}, tokens, /*KV cache position = */ 0, kv_cache, pool, stream_token, gen); - std::cout << "\n"; } diff --git a/gemma/benchmark.cc b/gemma/benchmark.cc index db4a3ef..e8d9539 100644 --- a/gemma/benchmark.cc +++ b/gemma/benchmark.cc @@ -72,7 +72,6 @@ std::pair QueryModel( prompt.insert(prompt.begin(), 2); std::string res; size_t total_tokens = 0; - auto accept_token = [](int) { return true; }; std::mt19937 gen; gen.seed(42); @@ -100,7 +99,6 @@ std::pair QueryModel( .verbosity = app.verbosity, .gen = &gen, .stream_token = stream_token, - .accept_token = accept_token, }; model.Generate(runtime_config, prompt, /*start_pos=*/0, kv_cache, timing_info, /*layers_output=*/nullptr); diff --git a/gemma/compress_weights.cc b/gemma/compress_weights.cc index d92b2b3..8552802 100644 --- a/gemma/compress_weights.cc +++ b/gemma/compress_weights.cc @@ -32,6 +32,7 @@ #include #include // std::clamp +#include #include #include #include // NOLINT @@ -41,8 +42,8 @@ #include "gemma/weights.h" #include "util/args.h" #include "hwy/base.h" -#include "hwy/profiler.h" #include "hwy/contrib/thread_pool/thread_pool.h" +#include "hwy/profiler.h" namespace gcpp { diff --git a/gemma/cross_entropy.cc b/gemma/cross_entropy.cc index 985b0b6..9d345e2 100644 --- a/gemma/cross_entropy.cc +++ b/gemma/cross_entropy.cc @@ -20,7 +20,6 @@ #include #include -#include #include // NOLINT #include #include @@ -28,6 +27,7 @@ #include "gemma/common.h" #include "gemma/gemma.h" +#include "hwy/base.h" namespace gcpp { @@ -67,51 +67,53 @@ void LogTopK(const GemmaTokenizer& tokenizer, const float* dist, size_t len, float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens, const std::vector& prompt, KVCache& kv_cache, int verbosity) { - auto stream_token = [](int, float) { return true; }; - auto accept_token = [](int) { return true; }; + const StreamFunc stream_token = [](int /*token*/, float) { return true; }; // TWeight is unused, but we have to pass it to Config*. const int vocab_size = CallForModel(gemma.ModelType()); float cross_entropy = std::log(vocab_size); // first token size_t pos = 1; - std::function sample_token = - [&](const float* probs, size_t vocab_size) -> int { - const int token = prompt[pos]; - const float prob = probs[token]; - cross_entropy -= std::max(std::log(prob), -64.0f); + const SampleFunc sample_token = [&](const float* probs, + size_t vocab_size) -> int { + // We are called for each token, but pos starts at 1. Clamping max_tokens + // to prompt.size() should prevent overrun. + HWY_ASSERT(pos < prompt.size()); + const int token = prompt[pos]; + const float prob = probs[token]; + cross_entropy -= std::max(std::log(prob), -64.0f); - if (verbosity >= 4) { - LogTopK(gemma.Tokenizer(), probs, vocab_size, 10); - } - if (verbosity >= 3) { - printf("pos %4zu token %6d = %-12s %.10e %14.10f bits\n", pos, - token, TokenString(gemma.Tokenizer(), token).c_str(), prob, - -std::log(prob) / std::log(2.0)); - } - if (verbosity >= 2 && pos % 100 == 99) { - printf("Processed %zu tokens, cross-entropy per token: %f\n", pos + 1, - cross_entropy / std::log(2.0) / (pos + 1)); - } - ++pos; - return token; - }; + if (verbosity >= 4) { + LogTopK(gemma.Tokenizer(), probs, vocab_size, 10); + } + if (verbosity >= 3) { + printf("pos %4zu token %6d = %-12s %.10e %14.10f bits\n", pos, token, + TokenString(gemma.Tokenizer(), token).c_str(), prob, + -std::log(prob) / std::log(2.0)); + } + if (verbosity >= 2 && pos % 100 == 99) { + printf("Processed %zu tokens, cross-entropy per token: %f\n", pos + 1, + cross_entropy / std::log(2.0) / (pos + 1)); + } + ++pos; + return token; + }; std::vector prompt0 = { prompt[0] }; + max_tokens = HWY_MIN(max_tokens, prompt.size()); RuntimeConfig runtime = { - .max_tokens = max_tokens, - .max_generated_tokens = max_tokens - 1, - .temperature = 0.0f, - .verbosity = verbosity, - .gen = nullptr, - .stream_token = stream_token, - .accept_token = accept_token, - .sample_func = &sample_token, + .max_tokens = max_tokens, + .max_generated_tokens = max_tokens - 1, + .temperature = 0.0f, + .verbosity = verbosity, + .gen = nullptr, + .stream_token = stream_token, + .sample_func = sample_token, }; TimingInfo timing_info; gemma.Generate(runtime, prompt0, 0, kv_cache, timing_info, nullptr); - const float scale = 1.0 / std::log(2.0); + const float scale = 1.0f / std::log(2.0f); return cross_entropy * scale; } diff --git a/gemma/gemma.cc b/gemma/gemma.cc index ff551c2..4ca5f20 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -781,6 +781,15 @@ void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8, } HWY_ASSERT(prompt_size > 0); + const SampleFunc sample_token = + runtime_config.sample_func + ? runtime_config.sample_func + : [&](const float* logits, size_t vocab_size) -> int { + return SampleTopK(logits, vocab_size, *runtime_config.gen, + runtime_config.temperature, + runtime_config.accept_token); + }; + // pos indexes the KV cache. In the first turn of a chat, pos = 0. // // After the first turn, pos gets passed in with > 0 corresponding to the @@ -844,16 +853,9 @@ void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8, LogitsSoftCap(30.0f, activations.logits.data(), kVocabSize); // Barrier: must have all logits so we can subtract max. Softmax(activations.logits.data(), kVocabSize); - if (runtime_config.sample_func) { - token = (*runtime_config.sample_func)(activations.logits.data(), - kVocabSize); - } else { - token = SampleTopK( - activations.logits.data(), kVocabSize, *runtime_config.gen, - runtime_config.temperature, runtime_config.accept_token); - if (!runtime_config.stream_token(token, activations.logits[token])) { - token = runtime_config.eos_id; - } + token = sample_token(activations.logits.data(), kVocabSize); + if (!runtime_config.stream_token(token, activations.logits[token])) { + token = runtime_config.eos_id; } if (generate_pos == 0) { timing_info.time_to_first_token = hwy::platform::Now() - gen_start; diff --git a/gemma/gemma.h b/gemma/gemma.h index 60a0f04..4e896ed 100644 --- a/gemma/gemma.h +++ b/gemma/gemma.h @@ -68,15 +68,15 @@ class GemmaTokenizer { }; // StreamFunc is called with (token, probability). For prompt tokens, -// probability is 0.0f. StreamFunc should return False to stop generation and -// True to continue generation. +// probability is 0.0f. StreamFunc should return false to stop generation and +// true to continue generation. using StreamFunc = std::function; -// AcceptFunc is called with token. It should return False for tokens you don't -// want to generate and True for tokens you want to generate. +// If not empty, AcceptFunc is called with token. It should return false for +// tokens you don't want to generate and true for tokens you want to generate. using AcceptFunc = std::function; -// CustomSampleFunc is called with the probability distribution for the next -// token, and its return value is used as the next generated token. -using CustomSampleFunc = std::function; +// If not empty, SampleFunc is called with the probability distribution for the +// next token, and its return value is used as the next generated token. +using SampleFunc = std::function; struct RuntimeConfig { size_t max_tokens; @@ -84,9 +84,9 @@ struct RuntimeConfig { float temperature; int verbosity; std::mt19937* gen; - const StreamFunc& stream_token; - const AcceptFunc& accept_token; - const CustomSampleFunc* sample_func = nullptr; + StreamFunc stream_token; + AcceptFunc accept_token; // if empty, accepts all tokens. + SampleFunc sample_func; // if empty, uses SampleTopK. int eos_id = EOS_ID; }; diff --git a/gemma/gemma_test.cc b/gemma/gemma_test.cc index 6f194a1..4086cbf 100644 --- a/gemma/gemma_test.cc +++ b/gemma/gemma_test.cc @@ -15,31 +15,43 @@ #include "gemma/gemma.h" -#include -#include +#include + +#include #include #include -#include // NOLINT #include -#include "compression/io.h" // Path -#include "gemma/common.h" +// Placeholder for internal header, do not modify. #include "gemma/cross_entropy.h" #include "gemma/ops.h" +#include "util/app.h" #include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/tests/test_util-inl.h" namespace gcpp { namespace { +int s_argc = 0; +char** s_argv = nullptr; + class GemmaTest : public ::testing::Test { protected: - GemmaTest() - : weights("./2b-it-mqa.sbs"), - tokenizer("./tokenizer.spm"), - pool(std::min(20, (std::thread::hardware_concurrency() - 1) / 2)), - model(tokenizer, weights, model_type, weight_type, pool) { - KVCache kv_cache = KVCache::Create(model_type); + static void SetUpTestSuite() { + gcpp::LoaderArgs loader(s_argc, s_argv); + gcpp::AppArgs app(s_argc, s_argv); + if (const char* err = loader.Validate()) { + fprintf(stderr, "Insufficient LoaderArgs, skipping e2e tests.\n"); + } else { + s_pool = std::make_unique(app.num_threads); + s_model = AllocateGemma(loader, *s_pool); + s_kv_cache = KVCache::Create(loader.ModelType()); + } + } + + static void TearDownTestSuite() { + s_pool.reset(); + s_model.reset(); } std::string GemmaReply(const std::string& prompt_string) { @@ -47,10 +59,10 @@ class GemmaTest : public ::testing::Test { gen.seed(42); std::vector prompt; - HWY_ASSERT(model.Tokenizer().Encode(prompt_string, &prompt)); + HWY_ASSERT(s_model->Tokenizer().Encode(prompt_string, &prompt)); // For both pre-trained and instruction-tuned models: prepend "" token // if needed. - prompt.insert(prompt.begin(), 2); + prompt.insert(prompt.begin(), BOS_ID); std::vector response; auto stream_token = [&response](int token, float) { @@ -64,65 +76,65 @@ class GemmaTest : public ::testing::Test { .verbosity = 0, .gen = &gen, .stream_token = stream_token, - .accept_token = [](int) { return true; }, }; gcpp::TimingInfo timing_info; - model.Generate(runtime_config, prompt, /*start_pos=*/0, kv_cache, - timing_info, /*layers_output=*/nullptr); + s_model->Generate(runtime_config, prompt, /*start_pos=*/0, s_kv_cache, + timing_info, /*layers_output=*/nullptr); std::string response_text; - HWY_ASSERT(model.Tokenizer().Decode(response, &response_text)); + HWY_ASSERT(s_model->Tokenizer().Decode(response, &response_text)); return response_text; } float GemmaCrossEntropy(const std::string& prompt_string) { std::vector prompt; - HWY_ASSERT(model.Tokenizer().Encode(prompt_string, &prompt)); - return ComputeCrossEntropy(model, /*max_tokens=*/3072, prompt, kv_cache, + HWY_ASSERT(s_model->Tokenizer().Encode(prompt_string, &prompt)); + prompt.insert(prompt.begin(), BOS_ID); + return ComputeCrossEntropy(*s_model, /*max_tokens=*/3072, prompt, + s_kv_cache, /*verbosity=*/0) / - prompt_string.size(); + prompt_string.size(); } void TestQuestions(const char* kQA[][2], size_t num_questions) { + if (!s_model) return; for (size_t i = 0; i < num_questions; ++i) { - std::cout << "Question " << i + 1 << "\n\n"; + fprintf(stderr, "Question %zu\n\n", i + 1); std::string response = GemmaReply(kQA[i][0]); - std::cout << response << "\n\n"; + fprintf(stderr, "'%s'\n\n", response.c_str()); EXPECT_TRUE(response.find(kQA[i][1]) != std::string::npos); // NOLINT } } - gcpp::Path weights; - gcpp::Path tokenizer; - gcpp::KVCache kv_cache; - hwy::ThreadPool pool; - gcpp::Model model_type = gcpp::Model::GEMMA_2B; - gcpp::Type weight_type = gcpp::Type::kSFP; - gcpp::Gemma model; + static std::unique_ptr s_pool; + static std::unique_ptr s_model; + static gcpp::KVCache s_kv_cache; }; -TEST_F(GemmaTest, DISABLED_Geography) { +/*static*/ std::unique_ptr GemmaTest::s_pool; +/*static*/ std::unique_ptr GemmaTest::s_model; +/*static*/ gcpp::KVCache GemmaTest::s_kv_cache; + +TEST_F(GemmaTest, Geography) { static const char* kQA[][2] = { {"What is the capital of Hungary?", "Budapest"}, {"How many states does the US have?", "50"}, - {"list me ten biggest cities in the world", "Tokyo"}, }; static const size_t kNum = sizeof(kQA) / sizeof(kQA[0]); TestQuestions(kQA, kNum); } -TEST_F(GemmaTest, DISABLED_History) { +TEST_F(GemmaTest, History) { static const char* kQA[][2] = { {"When was the Battle of Hastings?", "1066"}, - {"Who fought at the Battle of Marathon?", "Greek"}, }; static const size_t kNum = sizeof(kQA) / sizeof(kQA[0]); TestQuestions(kQA, kNum); } -TEST_F(GemmaTest, DISABLED_Arithmetic) { +TEST_F(GemmaTest, Arithmetic) { static const char* kQA[][2] = { {"what is 13 + 14?", "27"}, - {"what is 7 * 8", "56"}, + {"what is 7 * 8?", "56"}, }; static const size_t kNum = sizeof(kQA) / sizeof(kQA[0]); TestQuestions(kQA, kNum); @@ -163,150 +175,44 @@ static const char kGettysburg[] = { "a new birth of freedom; and that this government of the people, by the " "people, for the people, shall not perish from the earth.\n"}; -// The Declaration of Independence. -static const char kDeclaration[] = { - "IN CONGRESS, July 4, 1776.\n\nThe unanimous Declaration of the thirteen " - "united States of America,\n\nWhen in the Course of human events, it " - "becomes necessary for one people to dissolve the political bands which " - "have connected them with another, and to assume among the powers of the " - "earth, the separate and equal station to which the Laws of Nature and of " - "Nature's God entitle them, a decent respect to the opinions of mankind " - "requires that they should declare the causes which impel them to the " - "separation.\n\nWe hold these truths to be self-evident, that all men are " - "created equal, that they are endowed by their Creator with certain " - "unalienable Rights, that among these are Life, Liberty and the pursuit of " - "Happiness.--That to secure these rights, Governments are instituted among " - "Men, deriving their just powers from the consent of the governed, --That " - "whenever any Form of Government becomes destructive of these ends, it is " - "the Right of the People to alter or to abolish it, and to institute new " - "Government, laying its foundation on such principles and organizing its " - "powers in such form, as to them shall seem most likely to effect their " - "Safety and Happiness. Prudence, indeed, will dictate that Governments " - "long established should not be changed for light and transient causes; " - "and accordingly all experience hath shewn, that mankind are more disposed " - "to suffer, while evils are sufferable, than to right themselves by " - "abolishing the forms to which they are accustomed. But when a long train " - "of abuses and usurpations, pursuing invariably the same Object evinces a " - "design to reduce them under absolute Despotism, it is their right, it is " - "their duty, to throw off such Government, and to provide new Guards for " - "their future security.--Such has been the patient sufferance of these " - "Colonies; and such is now the necessity which constrains them to alter " - "their former Systems of Government. The history of the present King of " - "Great Britain is a history of repeated injuries and usurpations, all " - "having in direct object the establishment of an absolute Tyranny over " - "these States. To prove this, let Facts be submitted to a candid " - "world.\n\nHe has refused his Assent to Laws, the most wholesome and " - "necessary for the public good.\nHe has forbidden his Governors to pass " - "Laws of immediate and pressing importance, unless suspended in their " - "operation till his Assent should be obtained; and when so suspended, he " - "has utterly neglected to attend to them.\nHe has refused to pass other " - "Laws for the accommodation of large districts of people, unless those " - "people would relinquish the right of Representation in the Legislature, a " - "right inestimable to them and formidable to tyrants only.\nHe has called " - "together legislative bodies at places unusual, uncomfortable, and distant " - "from the depository of their public Records, for the sole purpose of " - "fatiguing them into compliance with his measures.\nHe has dissolved " - "Representative Houses repeatedly, for opposing with manly firmness his " - "invasions on the rights of the people.\nHe has refused for a long time, " - "after such dissolutions, to cause others to be elected; whereby the " - "Legislative powers, incapable of Annihilation, have returned to the " - "People at large for their exercise; the State remaining in the mean time " - "exposed to all the dangers of invasion from without, and convulsions " - "within.\nHe has endeavoured to prevent the population of these States; " - "for that purpose obstructing the Laws for Naturalization of Foreigners; " - "refusing to pass others to encourage their migrations hither, and raising " - "the conditions of new Appropriations of Lands.\nHe has obstructed the " - "Administration of Justice, by refusing his Assent to Laws for " - "establishing Judiciary powers.\nHe has made Judges dependent on his Will " - "alone, for the tenure of their offices, and the amount and payment of " - "their salaries.\nHe has erected a multitude of New Offices, and sent " - "hither swarms of Officers to harrass our people, and eat out their " - "substance.\nHe has kept among us, in times of peace, Standing Armies " - "without the Consent of our legislatures.\nHe has affected to render the " - "Military independent of and superior to the Civil power.\nHe has combined " - "with others to subject us to a jurisdiction foreign to our constitution, " - "and unacknowledged by our laws; giving his Assent to their Acts of " - "pretended Legislation:\nFor Quartering large bodies of armed troops among " - "us:\nFor protecting them, by a mock Trial, from punishment for any " - "Murders which they should commit on the Inhabitants of these States:\nFor " - "cutting off our Trade with all parts of the world:\nFor imposing Taxes on " - "us without our Consent:\nFor depriving us in many cases, of the benefits " - "of Trial by Jury:\nFor transporting us beyond Seas to be tried for " - "pretended offences\nFor abolishing the free System of English Laws in a " - "neighbouring Province, establishing therein an Arbitrary government, and " - "enlarging its Boundaries so as to render it at once an example and fit " - "instrument for introducing the same absolute rule into these " - "Colonies:\nFor taking away our Charters, abolishing our most valuable " - "Laws, and altering fundamentally the Forms of our Governments:\nFor " - "suspending our own Legislatures, and declaring themselves invested with " - "power to legislate for us in all cases whatsoever.\nHe has abdicated " - "Government here, by declaring us out of his Protection and waging War " - "against us.\nHe has plundered our seas, ravaged our Coasts, burnt our " - "towns, and destroyed the lives of our people.\nHe is at this time " - "transporting large Armies of foreign Mercenaries to compleat the works of " - "death, desolation and tyranny, already begun with circumstances of " - "Cruelty & perfidy scarcely paralleled in the most barbarous ages, and " - "totally unworthy the Head of a civilized nation.\nHe has constrained our " - "fellow Citizens taken Captive on the high Seas to bear Arms against their " - "Country, to become the executioners of their friends and Brethren, or to " - "fall themselves by their Hands.\nHe has excited domestic insurrections " - "amongst us, and has endeavoured to bring on the inhabitants of our " - "frontiers, the merciless Indian Savages, whose known rule of warfare, is " - "an undistinguished destruction of all ages, sexes and conditions.\n\nIn " - "every stage of these Oppressions We have Petitioned for Redress in the " - "most humble terms: Our repeated Petitions have been answered only by " - "repeated injury. A Prince whose character is thus marked by every act " - "which may define a Tyrant, is unfit to be the ruler of a free " - "people.\n\nNor have We been wanting in attentions to our Brittish " - "brethren. We have warned them from time to time of attempts by their " - "legislature to extend an unwarrantable jurisdiction over us. We have " - "reminded them of the circumstances of our emigration and settlement here. " - "We have appealed to their native justice and magnanimity, and we have " - "conjured them by the ties of our common kindred to disavow these " - "usurpations, which, would inevitably interrupt our connections and " - "correspondence. They too have been deaf to the voice of justice and of " - "consanguinity. We must, therefore, acquiesce in the necessity, which " - "denounces our Separation, and hold them, as we hold the rest of mankind, " - "Enemies in War, in Peace Friends.\n\nWe, therefore, the Representatives " - "of the united States of America, in General Congress, Assembled, " - "appealing to the Supreme Judge of the world for the rectitude of our " - "intentions, do, in the Name, and by Authority of the good People of these " - "Colonies, solemnly publish and declare, That these United Colonies are, " - "and of Right ought to be Free and Independent States; that they are " - "Absolved from all Allegiance to the British Crown, and that all political " - "connection between them and the State of Great Britain, is and ought to " - "be totally dissolved; and that as Free and Independent States, they have " - "full Power to levy War, conclude Peace, contract Alliances, establish " - "Commerce, and to do all other Acts and Things which Independent States " - "may of right do. And for the support of this Declaration, with a firm " - "reliance on the protection of divine Providence, we mutually pledge to " - "each other our Lives, our Fortunes and our sacred Honor.\n"}; - -TEST_F(GemmaTest, DISABLED_CrossEntropySmall) { +TEST_F(GemmaTest, CrossEntropySmall) { + if (!s_model) return; static const char kSmall[] = "The capital of Hungary is Budapest which is located in Europe."; float entropy = GemmaCrossEntropy(kSmall); - std::cout << "per-byte entropy: " << entropy << "\n"; - EXPECT_LT(entropy, 1.6f); + fprintf(stderr, "per-byte entropy: %f\n", entropy); + // Note that entropy is 3x higher for the 7b-it model. + EXPECT_LT(entropy, 1.7f); } -TEST_F(GemmaTest, DISABLED_CrossEntropyJingleBells) { +TEST_F(GemmaTest, CrossEntropyJingleBells) { + if (!s_model) return; float entropy = GemmaCrossEntropy(kJingleBells); - std::cout << "per-byte entropy: " << entropy << "\n"; - EXPECT_LT(entropy, 2.3f); + fprintf(stderr, "per-byte entropy: %f\n", entropy); + EXPECT_LT(entropy, 1.7f); } -TEST_F(GemmaTest, DISABLED_CrossEntropyGettysburg) { +TEST_F(GemmaTest, CrossEntropyGettysburg) { + if (!s_model) return; float entropy = GemmaCrossEntropy(kGettysburg); - std::cout << "per-byte entropy: " << entropy << "\n"; + fprintf(stderr, "per-byte entropy: %f\n", entropy); EXPECT_LT(entropy, 1.2f); } -TEST_F(GemmaTest, DISABLED_CrossEntropyDeclaration) { - float entropy = GemmaCrossEntropy(kDeclaration); - std::cout << "per-byte entropy: " << entropy << "\n"; - EXPECT_LT(entropy, 1.0f); -} - } // namespace } // namespace gcpp + +int main(int argc, char** argv) { + { + // Placeholder for internal init, do not modify. + } + + // For later use by SetUp. + gcpp::s_argc = argc; + gcpp::s_argv = argv; + + // Probably should be called before SetUpTestSuite. + testing::InitGoogleTest(&gcpp::s_argc, argv); + + return RUN_ALL_TESTS(); +} \ No newline at end of file diff --git a/gemma/ops.h b/gemma/ops.h index eea838d..9089745 100644 --- a/gemma/ops.h +++ b/gemma/ops.h @@ -1340,11 +1340,13 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK( std::array top_k{}; // sorted from highest [0], to lowest [k-1] std::array indices{}; for (size_t i = 0; i < vocab_size; ++i) { - if (probabilities[i] < top_k[k - 1] && accept_token(StaticCast(i))) { + if (probabilities[i] < top_k[k - 1] && + (!accept_token || accept_token(StaticCast(i)))) { continue; } for (size_t j = 0; j < k; ++j) { - if (probabilities[i] > top_k[j] && accept_token(StaticCast(i))) { + if (probabilities[i] > top_k[j] && + (!accept_token || accept_token(StaticCast(i)))) { // shift elements by 1, insert the new value, move on to next value for (size_t idx = k - 1; idx > j; --idx) { top_k[idx] = top_k[idx - 1]; diff --git a/gemma/run.cc b/gemma/run.cc index 3538184..ebd131e 100644 --- a/gemma/run.cc +++ b/gemma/run.cc @@ -284,10 +284,8 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) { std::cout << "\n" << instructions << "\n"; } - ReplGemma( - model, loader.ModelTrainingType(), kv_cache, pool, inference, - app.verbosity, - /*accept_token=*/[](int) { return true; }, app.eot_line); + ReplGemma(model, loader.ModelTrainingType(), kv_cache, pool, inference, + app.verbosity, AcceptFunc(), app.eot_line); } } // namespace gcpp diff --git a/gemma/run_mmlu.cc b/gemma/run_mmlu.cc index 9f6252b..6de8b37 100644 --- a/gemma/run_mmlu.cc +++ b/gemma/run_mmlu.cc @@ -98,8 +98,8 @@ void JsonGemma(gcpp::Gemma& model, gcpp::KVCache& kv_cache, return true; }; - auto accept_token = [¤t_pos, &prompt_size, - &accept_token_set](int token) { + const AcceptFunc accept_token = [¤t_pos, &prompt_size, + &accept_token_set](int token) { // i.e. we have no constraints on accepted tokens if (accept_token_set.empty()) { return true; diff --git a/util/app.h b/util/app.h index 3bfeb00..e4fd15d 100644 --- a/util/app.h +++ b/util/app.h @@ -18,6 +18,8 @@ #ifndef THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_ #define THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_ +#include + #include "hwy/contrib/thread_pool/thread_pool.h" #if HWY_OS_LINUX #include @@ -242,6 +244,12 @@ static inline Gemma CreateGemma(const LoaderArgs& loader, loader.WeightType(), pool); } +static inline std::unique_ptr AllocateGemma(const LoaderArgs& loader, + hwy::ThreadPool& pool) { + return std::make_unique(loader.tokenizer, loader.weights, + loader.ModelType(), loader.WeightType(), pool); +} + struct InferenceArgs : public ArgsBase { InferenceArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }