Internal change

PiperOrigin-RevId: 657831926
This commit is contained in:
Phil Culliton 2024-07-30 20:24:21 -07:00 committed by Copybara-Service
parent a24eda8d02
commit 1982a6ba00
8 changed files with 110 additions and 6 deletions

View File

@ -170,6 +170,9 @@ cc_library(
"gemma/instantiations/gr2b_bf16.cc",
"gemma/instantiations/gr2b_f32.cc",
"gemma/instantiations/gr2b_sfp.cc",
"gemma/instantiations/gemma2_2b_bf16.cc",
"gemma/instantiations/gemma2_2b_f32.cc",
"gemma/instantiations/gemma2_2b_sfp.cc",
],
hdrs = [
"gemma/activations.h",

View File

@ -91,6 +91,9 @@ set(SOURCES
gemma/instantiations/tiny_bf16.cc
gemma/instantiations/tiny_f32.cc
gemma/instantiations/tiny_sfp.cc
gemma/instantiations/gemma2_2b_bf16.cc
gemma/instantiations/gemma2_2b_f32.cc
gemma/instantiations/gemma2_2b_sfp.cc
gemma/kv_cache.cc
gemma/kv_cache.h
gemma/tokenizer.cc

View File

@ -35,6 +35,7 @@ constexpr const char* kModelFlags[] = {
"27b-pt", "27b-it", // Gemma 27B
"gr2b-pt", "gr2b-it", // RecurrentGemma
"tiny", // Gemma Tiny (mostly for debugging)
"gemma2-2b-pt", "gemma2-2b-it", // Gemma2 2B
};
constexpr Model kModelTypes[] = {
Model::GEMMA_2B, Model::GEMMA_2B, // Gemma 2B
@ -43,6 +44,7 @@ constexpr Model kModelTypes[] = {
Model::GEMMA_27B, Model::GEMMA_27B, // Gemma 27B
Model::GRIFFIN_2B, Model::GRIFFIN_2B, // RecurrentGemma
Model::GEMMA_TINY, // Gemma Tiny
Model::GEMMA2_2B, Model::GEMMA2_2B, // Gemma2 2B
};
constexpr ModelTraining kModelTraining[] = {
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma 2B
@ -51,6 +53,7 @@ constexpr ModelTraining kModelTraining[] = {
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma 27B
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // RecurrentGemma
ModelTraining::GEMMA_IT, // Gemma Tiny
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma 2B2
};
constexpr size_t kNumModelFlags =

View File

@ -45,6 +45,7 @@ enum class Model {
GEMMA_27B,
GRIFFIN_2B,
GEMMA_TINY,
GEMMA2_2B,
};
// Instruction-tuned models require extra 'turn structure' tokens in prompts.
@ -99,6 +100,9 @@ decltype(auto) CallForModel(Model model, TArgs&&... args) {
return FuncT<ConfigGemma27B<TWeight>>()(std::forward<TArgs>(args)...);
case Model::GRIFFIN_2B:
return FuncT<ConfigGriffin2B<TWeight>>()(std::forward<TArgs>(args)...);
case Model::GEMMA2_2B:
return FuncT<ConfigGemma2_2B<TWeight>>()(std::forward<TArgs>(args)...);
default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
}
@ -142,6 +146,7 @@ decltype(auto) CallForModelAndWeight(Model model, Type weight,
GEMMA_FOREACH_WEIGHT(X, ConfigGemma9B) \
GEMMA_FOREACH_WEIGHT(X, ConfigGemma27B) \
GEMMA_FOREACH_WEIGHT(X, ConfigGriffin2B) \
GEMMA_FOREACH_WEIGHT(X, ConfigGemma2_2B) \
static_assert(true, "Allow trailing ;")
// Used by GEMMA_EXPORT_AND_DISPATCH. For a given TWEIGHT (e.g. float),
@ -178,6 +183,11 @@ decltype(auto) CallForModelAndWeight(Model model, Type weight,
ARGS; \
break; \
} \
case Model::GEMMA2_2B: { \
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGemma2_2B<TWEIGHT>>) \
ARGS; \
break; \
} \
default: \
HWY_ABORT("Model type %d unknown.", static_cast<int>(MODEL)); \
}

View File

@ -253,6 +253,28 @@ struct ConfigGemma2B : public ConfigBaseGemmaV1 {
static constexpr bool kAbsolutePE = false;
};
template <typename TWeight>
struct ConfigGemma2_2B : public ConfigBaseGemmaV2 {
using Weight = TWeight; // make accessible where we only have a TConfig
static constexpr int kSeqLen = 8192;
static constexpr int kVocabSize = 256000;
static constexpr std::array<LayerAttentionType, 26> kLayerConfig =
FixedLayerConfig<26>(LayerAttentionType::kGemma);
static constexpr std::array<size_t, 26> kAttentionWindowSizes =
RepeatedAttentionWindowSizes<26, 2>({4096, kSeqLen});
static constexpr int kLayers = kLayerConfig.size();
static constexpr int kGemmaLayers = kLayers;
static constexpr int kModelDim = 2304;
static constexpr int kFFHiddenDim = 8 * 2304 / 2; // = 9216
static constexpr int kHeads = 8;
static constexpr int kKVHeads = 4;
static constexpr int kQKVDim = 256; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false;
static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
};
template <typename TWeight>
struct ConfigGemmaTiny : public ConfigNoSSM {
using Weight = TWeight; // make accessible where we only have a TConfig

View File

@ -0,0 +1,21 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \
"gemma/instantiations/gemma2_2b_bf16.cc"
#include "hwy/foreach_target.h" // IWYU pragma: keep
#define GEMMA_CONFIG ConfigGemma2_2B<hwy::bfloat16_t>
#include "gemma/gemma-inl.h"

View File

@ -0,0 +1,21 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \
"gemma/instantiations/gemma2_2b_f32.cc"
#include "hwy/foreach_target.h" // IWYU pragma: keep
#define GEMMA_CONFIG ConfigGemma2_2B<float>
#include "gemma/gemma-inl.h"

View File

@ -0,0 +1,21 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \
"gemma/instantiations/gemma2_2b_sfp.cc"
#include "hwy/foreach_target.h" // IWYU pragma: keep
#define GEMMA_CONFIG ConfigGemma2_2B<SfpStream>
#include "gemma/gemma-inl.h"