From 1982a6ba00a2f94fdf7178dfa0e8946acaea7c7b Mon Sep 17 00:00:00 2001 From: Phil Culliton Date: Tue, 30 Jul 2024 20:24:21 -0700 Subject: [PATCH] Internal change PiperOrigin-RevId: 657831926 --- BUILD.bazel | 3 +++ CMakeLists.txt | 3 +++ gemma/common.cc | 15 +++++++++------ gemma/common.h | 10 ++++++++++ gemma/configs.h | 22 ++++++++++++++++++++++ gemma/instantiations/gemma2_2b_bf16.cc | 21 +++++++++++++++++++++ gemma/instantiations/gemma2_2b_f32.cc | 21 +++++++++++++++++++++ gemma/instantiations/gemma2_2b_sfp.cc | 21 +++++++++++++++++++++ 8 files changed, 110 insertions(+), 6 deletions(-) create mode 100644 gemma/instantiations/gemma2_2b_bf16.cc create mode 100644 gemma/instantiations/gemma2_2b_f32.cc create mode 100644 gemma/instantiations/gemma2_2b_sfp.cc diff --git a/BUILD.bazel b/BUILD.bazel index 25df65c..4dd4697 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -170,6 +170,9 @@ cc_library( "gemma/instantiations/gr2b_bf16.cc", "gemma/instantiations/gr2b_f32.cc", "gemma/instantiations/gr2b_sfp.cc", + "gemma/instantiations/gemma2_2b_bf16.cc", + "gemma/instantiations/gemma2_2b_f32.cc", + "gemma/instantiations/gemma2_2b_sfp.cc", ], hdrs = [ "gemma/activations.h", diff --git a/CMakeLists.txt b/CMakeLists.txt index 38dd73a..6bc3d86 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -91,6 +91,9 @@ set(SOURCES gemma/instantiations/tiny_bf16.cc gemma/instantiations/tiny_f32.cc gemma/instantiations/tiny_sfp.cc + gemma/instantiations/gemma2_2b_bf16.cc + gemma/instantiations/gemma2_2b_f32.cc + gemma/instantiations/gemma2_2b_sfp.cc gemma/kv_cache.cc gemma/kv_cache.h gemma/tokenizer.cc diff --git a/gemma/common.cc b/gemma/common.cc index 7fa60c9..1d50743 100644 --- a/gemma/common.cc +++ b/gemma/common.cc @@ -29,12 +29,13 @@ namespace gcpp { constexpr const char* kModelFlags[] = { - "2b-pt", "2b-it", // Gemma 2B - "7b-pt", "7b-it", // Gemma 7B - "9b-pt", "9b-it", // Gemma 9B - "27b-pt", "27b-it", // Gemma 27B - "gr2b-pt", "gr2b-it", // RecurrentGemma - "tiny", // Gemma Tiny (mostly for debugging) + "2b-pt", "2b-it", // Gemma 2B + "7b-pt", "7b-it", // Gemma 7B + "9b-pt", "9b-it", // Gemma 9B + "27b-pt", "27b-it", // Gemma 27B + "gr2b-pt", "gr2b-it", // RecurrentGemma + "tiny", // Gemma Tiny (mostly for debugging) + "gemma2-2b-pt", "gemma2-2b-it", // Gemma2 2B }; constexpr Model kModelTypes[] = { Model::GEMMA_2B, Model::GEMMA_2B, // Gemma 2B @@ -43,6 +44,7 @@ constexpr Model kModelTypes[] = { Model::GEMMA_27B, Model::GEMMA_27B, // Gemma 27B Model::GRIFFIN_2B, Model::GRIFFIN_2B, // RecurrentGemma Model::GEMMA_TINY, // Gemma Tiny + Model::GEMMA2_2B, Model::GEMMA2_2B, // Gemma2 2B }; constexpr ModelTraining kModelTraining[] = { ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma 2B @@ -51,6 +53,7 @@ constexpr ModelTraining kModelTraining[] = { ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma 27B ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // RecurrentGemma ModelTraining::GEMMA_IT, // Gemma Tiny + ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma 2B2 }; constexpr size_t kNumModelFlags = diff --git a/gemma/common.h b/gemma/common.h index 7471ceb..f0498ba 100644 --- a/gemma/common.h +++ b/gemma/common.h @@ -45,6 +45,7 @@ enum class Model { GEMMA_27B, GRIFFIN_2B, GEMMA_TINY, + GEMMA2_2B, }; // Instruction-tuned models require extra 'turn structure' tokens in prompts. @@ -99,6 +100,9 @@ decltype(auto) CallForModel(Model model, TArgs&&... args) { return FuncT>()(std::forward(args)...); case Model::GRIFFIN_2B: return FuncT>()(std::forward(args)...); + case Model::GEMMA2_2B: + return FuncT>()(std::forward(args)...); + default: HWY_ABORT("Model type %d unknown.", static_cast(model)); } @@ -142,6 +146,7 @@ decltype(auto) CallForModelAndWeight(Model model, Type weight, GEMMA_FOREACH_WEIGHT(X, ConfigGemma9B) \ GEMMA_FOREACH_WEIGHT(X, ConfigGemma27B) \ GEMMA_FOREACH_WEIGHT(X, ConfigGriffin2B) \ + GEMMA_FOREACH_WEIGHT(X, ConfigGemma2_2B) \ static_assert(true, "Allow trailing ;") // Used by GEMMA_EXPORT_AND_DISPATCH. For a given TWEIGHT (e.g. float), @@ -178,6 +183,11 @@ decltype(auto) CallForModelAndWeight(Model model, Type weight, ARGS; \ break; \ } \ + case Model::GEMMA2_2B: { \ + HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC>) \ + ARGS; \ + break; \ + } \ default: \ HWY_ABORT("Model type %d unknown.", static_cast(MODEL)); \ } diff --git a/gemma/configs.h b/gemma/configs.h index efe5476..be995f9 100644 --- a/gemma/configs.h +++ b/gemma/configs.h @@ -253,6 +253,28 @@ struct ConfigGemma2B : public ConfigBaseGemmaV1 { static constexpr bool kAbsolutePE = false; }; +template +struct ConfigGemma2_2B : public ConfigBaseGemmaV2 { + using Weight = TWeight; // make accessible where we only have a TConfig + + static constexpr int kSeqLen = 8192; + static constexpr int kVocabSize = 256000; + static constexpr std::array kLayerConfig = + FixedLayerConfig<26>(LayerAttentionType::kGemma); + static constexpr std::array kAttentionWindowSizes = + RepeatedAttentionWindowSizes<26, 2>({4096, kSeqLen}); + static constexpr int kLayers = kLayerConfig.size(); + static constexpr int kGemmaLayers = kLayers; + static constexpr int kModelDim = 2304; + static constexpr int kFFHiddenDim = 8 * 2304 / 2; // = 9216 + static constexpr int kHeads = 8; + static constexpr int kKVHeads = 4; + static constexpr int kQKVDim = 256; // query size == key size == value size + static constexpr int kTopK = gcpp::kTopK; + static constexpr bool kAbsolutePE = false; + static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize; +}; + template struct ConfigGemmaTiny : public ConfigNoSSM { using Weight = TWeight; // make accessible where we only have a TConfig diff --git a/gemma/instantiations/gemma2_2b_bf16.cc b/gemma/instantiations/gemma2_2b_bf16.cc new file mode 100644 index 0000000..d817137 --- /dev/null +++ b/gemma/instantiations/gemma2_2b_bf16.cc @@ -0,0 +1,21 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE \ + "gemma/instantiations/gemma2_2b_bf16.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#define GEMMA_CONFIG ConfigGemma2_2B +#include "gemma/gemma-inl.h" diff --git a/gemma/instantiations/gemma2_2b_f32.cc b/gemma/instantiations/gemma2_2b_f32.cc new file mode 100644 index 0000000..c2f52a1 --- /dev/null +++ b/gemma/instantiations/gemma2_2b_f32.cc @@ -0,0 +1,21 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE \ + "gemma/instantiations/gemma2_2b_f32.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#define GEMMA_CONFIG ConfigGemma2_2B +#include "gemma/gemma-inl.h" diff --git a/gemma/instantiations/gemma2_2b_sfp.cc b/gemma/instantiations/gemma2_2b_sfp.cc new file mode 100644 index 0000000..1122ba9 --- /dev/null +++ b/gemma/instantiations/gemma2_2b_sfp.cc @@ -0,0 +1,21 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE \ + "gemma/instantiations/gemma2_2b_sfp.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#define GEMMA_CONFIG ConfigGemma2_2B +#include "gemma/gemma-inl.h"