From 3d1625d8c578605e3a618715f858714df3471867 Mon Sep 17 00:00:00 2001 From: Ray Smith Date: Thu, 21 Nov 2024 05:27:02 -0800 Subject: [PATCH 1/2] Improved consistency of compressor API, and added a universal method with a target type arg. Moved configs pybind up to root level. PiperOrigin-RevId: 698743417 --- compression/compress-inl.h | 3 +- compression/python/BUILD.bazel | 2 + compression/python/compression_clif_aux.cc | 42 +++++- compression/python/compression_clif_aux.h | 6 +- compression/python/compression_extension.cc | 14 +- compression/python/compression_test.py | 9 +- gemma/configs.cc | 3 + gemma/configs_test.cc | 1 - gemma/python/BUILD.bazel | 17 --- gemma/python/configs.cc | 156 -------------------- 10 files changed, 63 insertions(+), 190 deletions(-) delete mode 100644 gemma/python/BUILD.bazel delete mode 100644 gemma/python/configs.cc diff --git a/compression/compress-inl.h b/compression/compress-inl.h index 7fd097c..be4c5af 100644 --- a/compression/compress-inl.h +++ b/compression/compress-inl.h @@ -690,12 +690,13 @@ class Compressor { } } - void WriteAll(hwy::ThreadPool& pool, const Path& blob_filename) { + BlobError WriteAll(hwy::ThreadPool& pool, const Path& blob_filename) { const BlobError err = writer_.WriteAll(pool, blob_filename); if (err != 0) { fprintf(stderr, "Failed to write blobs to %s (error %d)\n", blob_filename.path.c_str(), err); } + return err; } private: diff --git a/compression/python/BUILD.bazel b/compression/python/BUILD.bazel index 6b451bd..f12d8be 100644 --- a/compression/python/BUILD.bazel +++ b/compression/python/BUILD.bazel @@ -28,6 +28,7 @@ pybind_extension( deps = [ ":compression_clif_aux", "@abseil-cpp//absl/types:span", + "//compression:sfp", ], ) @@ -38,6 +39,7 @@ py_test( deps = [ ":compression", "//testing/pybase", + "//python:configs", "//third_party/py/numpy", ], ) diff --git a/compression/python/compression_clif_aux.cc b/compression/python/compression_clif_aux.cc index ba91781..e313bbe 100644 --- a/compression/python/compression_clif_aux.cc +++ b/compression/python/compression_clif_aux.cc @@ -31,7 +31,9 @@ class WriterInterface { public: virtual ~WriterInterface() = default; - virtual void Insert(std::string name, absl::Span weights) = 0; + virtual void Insert(std::string name, absl::Span weights, + Type type) = 0; + virtual void InsertSfp(std::string name, absl::Span weights) = 0; virtual void InsertNUQ(std::string name, absl::Span weights) = 0; virtual void InsertBfloat16(std::string name, absl::Span weights) = 0; @@ -39,7 +41,7 @@ class WriterInterface { absl::Span weights) = 0; virtual void AddScales(const std::vector& scales) = 0; - virtual void Write(std::string path) = 0; + virtual int Write(std::string path) = 0; }; } // namespace gcpp @@ -67,7 +69,27 @@ class SbsWriterImpl : public WriterInterface { public: SbsWriterImpl() : pool_(0), compressor_(pool_) {} - void Insert(std::string name, absl::Span weights) override { + void Insert(std::string name, absl::Span weights, + Type type) override { + switch (type) { + case Type::kSFP: + AllocateAndCompress(name, weights); + break; + case Type::kNUQ: + AllocateAndCompress(name, weights); + break; + case Type::kBF16: + AllocateAndCompress(name, weights); + break; + case Type::kF32: + AllocateAndCompress(name, weights); + break; + default: + HWY_ABORT("Unsupported type"); + } + } + + void InsertSfp(std::string name, absl::Span weights) override { AllocateAndCompress(name, weights); } @@ -90,8 +112,8 @@ class SbsWriterImpl : public WriterInterface { compressor_.AddScales(scales_.data(), scales_.size()); } - void Write(std::string path) override { - compressor_.WriteAll(pool_, gcpp::Path(path)); + int Write(std::string path) override { + return compressor_.WriteAll(pool_, gcpp::Path(path)); } hwy::ThreadPool pool_; @@ -115,8 +137,12 @@ HWY_EXPORT(NewSbsWriter); SbsWriter::SbsWriter() : impl_(HWY_DYNAMIC_DISPATCH(NewSbsWriter)()) {} SbsWriter::~SbsWriter() = default; -void SbsWriter::Insert(std::string name, absl::Span weights) { - impl_->Insert(name, weights); +void SbsWriter::Insert(std::string name, absl::Span weights, + Type type) { + impl_->Insert(name, weights, type); +} +void SbsWriter::InsertSfp(std::string name, absl::Span weights) { + impl_->InsertSfp(name, weights); } void SbsWriter::InsertNUQ(std::string name, absl::Span weights) { impl_->InsertNUQ(name, weights); @@ -132,7 +158,7 @@ void SbsWriter::InsertFloat(std::string name, absl::Span weights) { void SbsWriter::AddScales(const std::vector& scales) { impl_->AddScales(scales); } -void SbsWriter::Write(std::string path) { impl_->Write(path); } +int SbsWriter::Write(std::string path) { return impl_->Write(path); } } // namespace gcpp #endif // HWY_ONCE diff --git a/compression/python/compression_clif_aux.h b/compression/python/compression_clif_aux.h index fd4efc8..cd2e4f1 100644 --- a/compression/python/compression_clif_aux.h +++ b/compression/python/compression_clif_aux.h @@ -6,6 +6,7 @@ #include #include "absl/types/span.h" +#include "compression/shared.h" namespace gcpp { @@ -16,13 +17,14 @@ class SbsWriter { SbsWriter(); ~SbsWriter(); - void Insert(std::string name, absl::Span weights); + void Insert(std::string name, absl::Span weights, Type type); + void InsertSfp(std::string name, absl::Span weights); void InsertNUQ(std::string name, absl::Span weights); void InsertBfloat16(std::string name, absl::Span weights); void InsertFloat(std::string name, absl::Span weights); void AddScales(const std::vector& scales); - void Write(std::string path); + int Write(std::string path); private: // Isolates Highway-dispatched types and other internals from CLIF. diff --git a/compression/python/compression_extension.cc b/compression/python/compression_extension.cc index c2916a8..c56f263 100644 --- a/compression/python/compression_extension.cc +++ b/compression/python/compression_extension.cc @@ -1,11 +1,11 @@ #include -#include #include #include #include "absl/types/span.h" #include "compression/python/compression_clif_aux.h" +#include "compression/shared.h" #include "pybind11/numpy.h" #include "pybind11/pybind11.h" #include "pybind11/stl.h" @@ -22,6 +22,15 @@ void wrap_span(SbsWriter& writer, std::string name, py::array_t data) { } std::invoke(Func, writer, name, absl::MakeSpan(data.data(0), data.size())); } +template +void wrap_span_typed(SbsWriter& writer, std::string name, + py::array_t data, gcpp::Type type) { + if (data.ndim() != 1 || data.strides(0) != sizeof(float)) { + throw std::domain_error("Input array must be 1D and densely packed."); + } + std::invoke(Func, writer, name, absl::MakeSpan(data.data(0), data.size()), + type); +} } // namespace PYBIND11_MODULE(compression, m) { @@ -29,7 +38,8 @@ PYBIND11_MODULE(compression, m) { .def(py::init<>()) // NOTE: Individual compression backends may impose constraints on the // array length, such as a minimum of (say) 32 elements. - .def("insert", wrap_span<&SbsWriter::Insert>) + .def("insert", wrap_span_typed<&SbsWriter::Insert>) + .def("insert_sfp", wrap_span<&SbsWriter::InsertSfp>) .def("insert_nuq", wrap_span<&SbsWriter::InsertNUQ>) .def("insert_bf16", wrap_span<&SbsWriter::InsertBfloat16>) .def("insert_float", wrap_span<&SbsWriter::InsertFloat>) diff --git a/compression/python/compression_test.py b/compression/python/compression_test.py index 7b7ff12..e25f06b 100644 --- a/compression/python/compression_test.py +++ b/compression/python/compression_test.py @@ -4,6 +4,7 @@ import numpy as np import unittest from compression.python import compression +from gemma.python import configs class CompressionTest(unittest.TestCase): @@ -13,9 +14,11 @@ class CompressionTest(unittest.TestCase): writer = compression.SbsWriter() writer.insert( - "foo", np.array([0.0012] * 128 + [0.001] * 64, dtype=np.float32) + "foo", + np.array([0.0012] * 128 + [0.001] * 64, dtype=np.float32), + configs.Type.kSFP, ) - writer.insert( + writer.insert_sfp( "bar", np.array([0.000375] * 128 + [0.00009] * 128, dtype=np.float32) ) writer.insert_nuq( @@ -27,7 +30,7 @@ class CompressionTest(unittest.TestCase): writer.insert_float( "quux", np.array([0.000375] * 128 + [0.00006] * 128, dtype=np.float32) ) - writer.write(temp_file.full_path) + self.assertEqual(writer.write(temp_file.full_path), 0) if __name__ == "__main__": diff --git a/gemma/configs.cc b/gemma/configs.cc index 7724c59..7a792cf 100644 --- a/gemma/configs.cc +++ b/gemma/configs.cc @@ -49,6 +49,7 @@ static ModelConfig ConfigGemma2_27B() { .heads = 32, .kv_heads = 16, .qkv_dim = 128, + .optimized_gating = false, .post_norm = PostNormType::Scale}; config.layer_configs = {46, layer_config}; config.num_tensor_scales = 4 * config.layer_configs.size(); @@ -70,6 +71,7 @@ static ModelConfig ConfigGemma2_9B() { .heads = 16, .kv_heads = 8, .qkv_dim = 256, + .optimized_gating = false, .post_norm = PostNormType::Scale}; config.layer_configs = {42, layer_config}; config.num_tensor_scales = 4 * config.layer_configs.size(); @@ -91,6 +93,7 @@ static ModelConfig ConfigGemma2_2B() { .heads = 8, .kv_heads = 4, .qkv_dim = 256, + .optimized_gating = false, .post_norm = PostNormType::Scale}; config.layer_configs = {26, layer_config}; config.num_tensor_scales = 4 * config.layer_configs.size(); diff --git a/gemma/configs_test.cc b/gemma/configs_test.cc index 8128baf..fa8d870 100644 --- a/gemma/configs_test.cc +++ b/gemma/configs_test.cc @@ -374,7 +374,6 @@ void AssertMatch(const ModelConfig& config) { } ASSERT_EQ(TConfig::kVocabSize, config.vocab_size); ASSERT_EQ(TConfig::kSeqLen, config.seq_len); - // ASSERT_EQ(TConfig::kTopK, config.top_k); - is now a runtime config value. ASSERT_EQ(TConfig::kAttCap, config.att_cap); ASSERT_EQ(TConfig::kFinalCap, config.final_cap); ASSERT_EQ(TConfig::kAbsolutePE, config.absolute_pe); diff --git a/gemma/python/BUILD.bazel b/gemma/python/BUILD.bazel deleted file mode 100644 index d6b09b9..0000000 --- a/gemma/python/BUILD.bazel +++ /dev/null @@ -1,17 +0,0 @@ -load("@pybind11_bazel//:build_defs.bzl", "pybind_extension") - -package( - default_applicable_licenses = [ - "//:license", # Placeholder comment, do not modify - ], - default_visibility = ["//visibility:public"], -) - -pybind_extension( - name = "configs", - srcs = ["configs.cc"], - deps = [ - "//:common", - "//compression:sfp", - ], -) diff --git a/gemma/python/configs.cc b/gemma/python/configs.cc deleted file mode 100644 index 8c37840..0000000 --- a/gemma/python/configs.cc +++ /dev/null @@ -1,156 +0,0 @@ -#include "gemma/configs.h" - -#include -#include - -#include "compression/shared.h" -#include "gemma/tensor_index.h" -#include "pybind11/cast.h" - -using gcpp::ActivationType; -using gcpp::LayerAttentionType; -using gcpp::LayerConfig; -using gcpp::Model; -using gcpp::ModelConfig; -using gcpp::ModelTraining; -using gcpp::PostNormType; -using gcpp::PostQKType; -using gcpp::QueryScaleType; -using gcpp::ResidualType; -using gcpp::TensorIndex; -using gcpp::TensorInfo; -using gcpp::Type; - -namespace pybind11 { - -PYBIND11_MODULE(configs, py_module) { - enum_(py_module, "ModelTraining") - .value("GEMMA_IT", ModelTraining::GEMMA_IT) - .value("GEMMA_PT", ModelTraining::GEMMA_PT) - .value("PALIGEMMA", ModelTraining::PALIGEMMA); - - enum_(py_module, "Type") - .value("kUnknown", Type::kUnknown) - .value("kF32", Type::kF32) - .value("kBF16", Type::kBF16) - .value("kSFP", Type::kSFP) - .value("kNUQ", Type::kNUQ) - .value("kF64", Type::kF64) - .value("kC64", Type::kC64) - .value("kU128", Type::kU128); - - enum_(py_module, "LayerAttentionType") - .value("kGemma", LayerAttentionType::kGemma) - .value("kGriffinRecurrentBlock", - LayerAttentionType::kGriffinRecurrentBlock) - .value("kVit", LayerAttentionType::kVit); - - enum_(py_module, "PostNormType") - .value("NoPostNorm", PostNormType::None) - .value("Scale", PostNormType::Scale); - - enum_(py_module, "PostQKType") - .value("Rope", PostQKType::Rope) - .value("HalfRope", PostQKType::HalfRope); - - enum_(py_module, "ActivationType") - .value("Gelu", ActivationType::Gelu); - - enum_(py_module, "QueryScaleType") - .value("SqrtKeySize", QueryScaleType::SqrtKeySize) - .value("SqrtModelDimDivNumHeads", - QueryScaleType::SqrtModelDimDivNumHeads); - - enum_(py_module, "ResidualType") - .value("Add", ResidualType::Add); - - enum_(py_module, "Model") - .value("UNKNOWN", Model::UNKNOWN) - .value("GEMMA_2B", Model::GEMMA_2B) - .value("GEMMA_7B", Model::GEMMA_7B) - .value("GEMMA2_9B", Model::GEMMA2_9B) - .value("GEMMA2_27B", Model::GEMMA2_27B) - .value("GRIFFIN_2B", Model::GRIFFIN_2B) - .value("GEMMA_TINY", Model::GEMMA_TINY) - .value("GEMMA2_2B", Model::GEMMA2_2B) - .value("PALIGEMMA_224", Model::PALIGEMMA_224); - - class_(py_module, "TensorInfo") - .def(init()) - .def_readwrite("name", &TensorInfo::name) - .def_readwrite("source_names", &TensorInfo::source_names) - .def_readwrite("preshape", &TensorInfo::preshape) - .def_readwrite("axes", &TensorInfo::axes) - .def_readwrite("shape", &TensorInfo::shape) - .def_readwrite("concat_names", &TensorInfo::concat_names) - .def_readwrite("concat_axis", &TensorInfo::concat_axis) - .def_readwrite("min_size", &TensorInfo::min_size) - .def_readwrite("scaled_softplus", &TensorInfo::scaled_softplus) - .def_readwrite("cols_take_extra_dims", &TensorInfo::cols_take_extra_dims); - - class_(py_module, "TensorIndex") - .def(init()) - .def("get_tensor_info", &TensorIndex::GetTensorInfo, arg("path")); - - class_(py_module, "LayerConfig") - .def(init()) - .def_readwrite("model_dim", &LayerConfig::model_dim) - .def_readwrite("griffin_dim", &LayerConfig::griffin_dim) - .def_readwrite("ff_hidden_dim", &LayerConfig::ff_hidden_dim) - .def_readwrite("heads", &LayerConfig::heads) - .def_readwrite("kv_heads", &LayerConfig::kv_heads) - .def_readwrite("qkv_dim", &LayerConfig::qkv_dim) - .def_readwrite("conv1d_width", &LayerConfig::conv1d_width) - .def_readwrite("ff_biases", &LayerConfig::ff_biases) - .def_readwrite("softmax_attn_output_biases", - &LayerConfig::softmax_attn_output_biases) - .def_readwrite("optimized_gating", &LayerConfig::optimized_gating) - .def_readwrite("post_norm", &LayerConfig::post_norm) - .def_readwrite("type", &LayerConfig::type) - .def_readwrite("activation", &LayerConfig::activation) - .def_readwrite("post_qk", &LayerConfig::post_qk); - - class_(py_module, "ModelConfig") - .def(init()) - .def_readwrite("model_name", &ModelConfig::model_name) - .def_readwrite("model", &ModelConfig::model) - .def_readwrite("training", &ModelConfig::training) - .def_readwrite("weight", &ModelConfig::weight) - .def_readwrite("num_layers", &ModelConfig::num_layers) - .def_readwrite("model_dim", &ModelConfig::model_dim) - .def_readwrite("vit_model_dim", &ModelConfig::vit_model_dim) - .def_readwrite("vocab_size", &ModelConfig::vocab_size) - .def_readwrite("seq_len", &ModelConfig::seq_len) - .def_readwrite("vit_seq_len", &ModelConfig::vit_seq_len) - .def_readwrite("num_tensor_scales", &ModelConfig::num_tensor_scales) - .def_readwrite("num_vit_scales", &ModelConfig::num_vit_scales) - .def_readwrite("att_cap", &ModelConfig::att_cap) - .def_readwrite("final_cap", &ModelConfig::final_cap) - .def_readwrite("absolute_pe", &ModelConfig::absolute_pe) - .def_readwrite("use_local_attention", &ModelConfig::use_local_attention) - .def_readwrite("query_scale", &ModelConfig::query_scale) - .def_readwrite("layer_configs", &ModelConfig::layer_configs) - .def_readwrite("attention_window_sizes", - &ModelConfig::attention_window_sizes) - .def_readwrite("vit_layer_configs", &ModelConfig::vit_layer_configs) - .def_readwrite("scale_names", &ModelConfig::scale_names) - .def_readwrite("norm_num_groups", &ModelConfig::norm_num_groups) - .def_readwrite("model_family_version", &ModelConfig::model_family_version) - .def_readwrite("patch_width", &ModelConfig::patch_width) - .def_readwrite("image_size", &ModelConfig::image_size) - .def("add_layer_config", &ModelConfig::AddLayerConfig, - arg("layer_config")) - .def("test_equal", &ModelConfig::TestEqual, arg("other"), arg("partial"), - arg("debug")); - - // Returns the config for the given model. - py_module.def("config_from_model", &gcpp::ConfigFromModel, arg("model")); - - // Returns the model for the given config, if it matches any standard model. - py_module.def("model_from_config", &gcpp::ModelFromConfig, arg("config")); - - // Returns the sub-config for the ViT model of the PaliGemma model. - py_module.def("vit_config", &gcpp::VitConfig, arg("config")); -} - -} // namespace pybind11 From 109a4d9f85d6e70ee320914629f47ffd2fd30661 Mon Sep 17 00:00:00 2001 From: Stanko Novakovic Date: Thu, 21 Nov 2024 10:59:16 -0800 Subject: [PATCH 2/2] Add a simple benchmark for batching. This is a simple Gemma benchmark with a fixed batch size of 32. PiperOrigin-RevId: 698843573 --- BUILD.bazel | 20 +++++ CMakeLists.txt | 3 + evals/gemma_batch_bench.cc | 146 +++++++++++++++++++++++++++++++++++++ 3 files changed, 169 insertions(+) create mode 100644 evals/gemma_batch_bench.cc diff --git a/BUILD.bazel b/BUILD.bazel index 669caf8..204f24e 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -388,6 +388,26 @@ cc_test( ], ) +cc_test( + name = "gemma_batch_bench", + srcs = ["evals/gemma_batch_bench.cc"], + # Requires model files + tags = [ + "local", + "manual", + "no_tap", + ], + deps = [ + ":benchmark_helper", + ":common", + ":gemma_lib", + ":tokenizer", + "@googletest//:gtest_main", + "@highway//:hwy", + "@highway//:hwy_test_util", + ], +) + cc_binary( name = "gemma", srcs = ["gemma/run.cc"], diff --git a/CMakeLists.txt b/CMakeLists.txt index 4d03da0..0b16c65 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -135,6 +135,9 @@ install(TARGETS gemma DESTINATION bin) add_executable(single_benchmark evals/benchmark.cc) target_link_libraries(single_benchmark libgemma hwy hwy_contrib nlohmann_json::nlohmann_json) +add_executable(gemma_batch_bench evals/gemma_batch_bench.cc) +target_link_libraries(gemma_batch_bench libgemma hwy hwy_contrib nlohmann_json::nlohmann_json) + add_executable(benchmarks evals/benchmarks.cc) target_link_libraries(benchmarks libgemma hwy hwy_contrib nlohmann_json::nlohmann_json benchmark) diff --git a/evals/gemma_batch_bench.cc b/evals/gemma_batch_bench.cc new file mode 100644 index 0000000..44b803f --- /dev/null +++ b/evals/gemma_batch_bench.cc @@ -0,0 +1,146 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "gemma/gemma.h" + +#include + +#include +#include + +#include "evals/benchmark_helper.h" +#include "gemma/common.h" +#include "hwy/base.h" +#include "hwy/tests/hwy_gtest.h" + +// This test can be run manually with the downloaded gemma weights. +// To run the test, pass the following flags: +// --model --tokenizer --weights +// It should pass for the following models: +// Gemma1: 2b-it (v1 and v1.1), 7b-it (v1 and v1.1), gr2b-it, +// Gemma2: gemma2-2b-it, 9b-it, 27b-it, + +namespace gcpp { +namespace { + +// Shared state. Requires argc/argv, so construct in main and use the same raw +// pointer approach as in benchmarks.cc. Note that the style guide forbids +// non-local static variables with dtors. +GemmaEnv* s_env = nullptr; + +class GemmaTest : public ::testing::Test { + protected: + std::vector BatchGemmaReply( + const std::vector& inputs) { + s_env->SetMaxGeneratedTokens(64); + s_env->MutableConfig().temperature = 0.0f; // deterministic + s_env->MutableConfig().verbosity = 5; + std::vector replies; + // Using the turn structure worsens results sometimes. + // However, some models need the turn structure to work. + // It would be good to make these tests more consistent. + if (s_env->GetModel()->Info().model == Model::GEMMA2_27B || + s_env->GetModel()->Info().model == Model::GRIFFIN_2B) { + for (QueryResult result : s_env->BatchQueryModel(inputs)) { + replies.push_back(result.response); + } + return replies; + } + // Otherwise, do not use turn structure. + std::vector> prompts_vector; + prompts_vector.reserve(inputs.size()); + for (const auto& input_string : inputs) { + prompts_vector.push_back(s_env->TokenizeAndPrependBOS(input_string)); + } + std::vector prompt_spans; + for (const auto& prompt : prompts_vector) { + prompt_spans.push_back(PromptTokens(prompt.data(), prompt.size())); + } + QueriesPromptTokens prompts(prompt_spans.data(), prompt_spans.size()); + for (const QueryResult& result : s_env->BatchQueryModel(prompts)) { + replies.push_back(result.response); + } + return replies; + } + + void GenerateTokens(std::vector &kQA, size_t num_questions) { + ASSERT_NE(s_env->GetModel(), nullptr); + + std::vector inputs; + for (size_t i = 0; i < num_questions; ++i) { + inputs.push_back(kQA[i]); + } + std::vector responses = BatchGemmaReply(inputs); + for (size_t i = 0; i < num_questions; ++i) { + std::string response = responses.at(i); + fprintf(stderr, "Batch answer %zu '%s'\n\n", i + 1, response.c_str()); + } + } +}; + +TEST_F(GemmaTest, RandomQuestionsBatched) { + s_env->MutableConfig().decode_qbatch_size = 3; + s_env->MutableConfig().verbosity = 5; + + static std::vector kQA = { + {"Write me a poem about Australia?"}, + {"What's the history of Denmark?"}, + {"Write me a comedy story about the USA."}, + {"Teach me about GPU programming."}, + {"Write me a story about the moon."}, + {"Write me a story about the universe."}, + {"Write a poem about planet earth."}, + {"Tell me more about olympic sports."}, + {"How would you describe Washington State?"}, + {"Write me a story about Silicon Valley."}, + {"Write me about your best friend."}, + {"How would you describe a unicorn?"}, + {"Tell me about world war history."}, + {"Tell me about Google."}, + {"Explain to me how to use Google Maps."}, + {"Explain to me how AI works."}, + {"Write me a poem about France."}, + {"What's the history of Great Britain?"}, + {"Write me a comedy story about Florida."}, + {"Teach me about dynamic programming."}, + {"Write me a story about Jupiter."}, + {"Write me a story about space ships."}, + {"Write a poem about some random planet."}, + {"Tell me more about team sports."}, + {"How would you describe Michigan State?"}, + {"Write me a story about Europe."}, + {"Write me about your best colleague."}, + {"How would you describe a horse?"}, + {"Tell me about World War 2."}, + {"Please share some good cooking tips."}, + {"Tell me about space travel."}, + {"Explain to me how electric cars work."}, + }; + static const size_t kNum = kQA.size(); + GenerateTokens(kQA, kNum); +} +} // namespace +} // namespace gcpp + +int main(int argc, char** argv) { + gcpp::GemmaEnv env(argc, argv); + gcpp::s_env = &env; + + testing::InitGoogleTest(&argc, argv); + + return RUN_ALL_TESTS(); +} + +