mirror of https://github.com/google/gemma.cpp.git
parent
a24eda8d02
commit
1982a6ba00
|
|
@ -170,6 +170,9 @@ cc_library(
|
|||
"gemma/instantiations/gr2b_bf16.cc",
|
||||
"gemma/instantiations/gr2b_f32.cc",
|
||||
"gemma/instantiations/gr2b_sfp.cc",
|
||||
"gemma/instantiations/gemma2_2b_bf16.cc",
|
||||
"gemma/instantiations/gemma2_2b_f32.cc",
|
||||
"gemma/instantiations/gemma2_2b_sfp.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"gemma/activations.h",
|
||||
|
|
|
|||
|
|
@ -91,6 +91,9 @@ set(SOURCES
|
|||
gemma/instantiations/tiny_bf16.cc
|
||||
gemma/instantiations/tiny_f32.cc
|
||||
gemma/instantiations/tiny_sfp.cc
|
||||
gemma/instantiations/gemma2_2b_bf16.cc
|
||||
gemma/instantiations/gemma2_2b_f32.cc
|
||||
gemma/instantiations/gemma2_2b_sfp.cc
|
||||
gemma/kv_cache.cc
|
||||
gemma/kv_cache.h
|
||||
gemma/tokenizer.cc
|
||||
|
|
|
|||
|
|
@ -35,6 +35,7 @@ constexpr const char* kModelFlags[] = {
|
|||
"27b-pt", "27b-it", // Gemma 27B
|
||||
"gr2b-pt", "gr2b-it", // RecurrentGemma
|
||||
"tiny", // Gemma Tiny (mostly for debugging)
|
||||
"gemma2-2b-pt", "gemma2-2b-it", // Gemma2 2B
|
||||
};
|
||||
constexpr Model kModelTypes[] = {
|
||||
Model::GEMMA_2B, Model::GEMMA_2B, // Gemma 2B
|
||||
|
|
@ -43,6 +44,7 @@ constexpr Model kModelTypes[] = {
|
|||
Model::GEMMA_27B, Model::GEMMA_27B, // Gemma 27B
|
||||
Model::GRIFFIN_2B, Model::GRIFFIN_2B, // RecurrentGemma
|
||||
Model::GEMMA_TINY, // Gemma Tiny
|
||||
Model::GEMMA2_2B, Model::GEMMA2_2B, // Gemma2 2B
|
||||
};
|
||||
constexpr ModelTraining kModelTraining[] = {
|
||||
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma 2B
|
||||
|
|
@ -51,6 +53,7 @@ constexpr ModelTraining kModelTraining[] = {
|
|||
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma 27B
|
||||
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // RecurrentGemma
|
||||
ModelTraining::GEMMA_IT, // Gemma Tiny
|
||||
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma2 2B
|
||||
};
|
||||
|
||||
constexpr size_t kNumModelFlags =
|
||||
|
|
|
|||
|
|
@ -45,6 +45,7 @@ enum class Model {
|
|||
GEMMA_27B,
|
||||
GRIFFIN_2B,
|
||||
GEMMA_TINY,
|
||||
GEMMA2_2B,
|
||||
};
|
||||
|
||||
// Instruction-tuned models require extra 'turn structure' tokens in prompts.
|
||||
|
|
@ -99,6 +100,9 @@ decltype(auto) CallForModel(Model model, TArgs&&... args) {
|
|||
return FuncT<ConfigGemma27B<TWeight>>()(std::forward<TArgs>(args)...);
|
||||
case Model::GRIFFIN_2B:
|
||||
return FuncT<ConfigGriffin2B<TWeight>>()(std::forward<TArgs>(args)...);
|
||||
case Model::GEMMA2_2B:
|
||||
return FuncT<ConfigGemma2_2B<TWeight>>()(std::forward<TArgs>(args)...);
|
||||
|
||||
default:
|
||||
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
|
||||
}
|
||||
|
|
@ -142,6 +146,7 @@ decltype(auto) CallForModelAndWeight(Model model, Type weight,
|
|||
GEMMA_FOREACH_WEIGHT(X, ConfigGemma9B) \
|
||||
GEMMA_FOREACH_WEIGHT(X, ConfigGemma27B) \
|
||||
GEMMA_FOREACH_WEIGHT(X, ConfigGriffin2B) \
|
||||
GEMMA_FOREACH_WEIGHT(X, ConfigGemma2_2B) \
|
||||
static_assert(true, "Allow trailing ;")
|
||||
|
||||
// Used by GEMMA_EXPORT_AND_DISPATCH. For a given TWEIGHT (e.g. float),
|
||||
|
|
@ -178,6 +183,11 @@ decltype(auto) CallForModelAndWeight(Model model, Type weight,
|
|||
ARGS; \
|
||||
break; \
|
||||
} \
|
||||
case Model::GEMMA2_2B: { \
|
||||
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGemma2_2B<TWEIGHT>>) \
|
||||
ARGS; \
|
||||
break; \
|
||||
} \
|
||||
default: \
|
||||
HWY_ABORT("Model type %d unknown.", static_cast<int>(MODEL)); \
|
||||
}
|
||||
|
|
|
|||
|
|
@ -253,6 +253,28 @@ struct ConfigGemma2B : public ConfigBaseGemmaV1 {
|
|||
static constexpr bool kAbsolutePE = false;
|
||||
};
|
||||
|
||||
// Compile-time configuration for the Gemma2 2B model; this is the config
// selected for Model::GEMMA2_2B by the model dispatch code.
// TWeight is re-exported as `Weight` (instantiated in this change with
// hwy::bfloat16_t, float and SfpStream).
template <typename TWeight>
|
||||
struct ConfigGemma2_2B : public ConfigBaseGemmaV2 {
|
||||
using Weight = TWeight; // make accessible where we only have a TConfig
|
||||
|
||||
static constexpr int kSeqLen = 8192;
|
||||
static constexpr int kVocabSize = 256000;
|
||||
// All 26 layers use the same attention type (LayerAttentionType::kGemma).
|
||||
static constexpr std::array<LayerAttentionType, 26> kLayerConfig =
|
||||
FixedLayerConfig<26>(LayerAttentionType::kGemma);
|
||||
// NOTE(review): presumably the {4096, kSeqLen} pair is repeated across the
|
||||
// 26 layers, i.e. window sizes alternate between 4096 and the full sequence
|
||||
// length -- confirm against RepeatedAttentionWindowSizes.
|
||||
static constexpr std::array<size_t, 26> kAttentionWindowSizes =
|
||||
RepeatedAttentionWindowSizes<26, 2>({4096, kSeqLen});
|
||||
// Layer counts are derived from kLayerConfig so they cannot drift from it.
|
||||
static constexpr int kLayers = kLayerConfig.size();
|
||||
static constexpr int kGemmaLayers = kLayers;
|
||||
static constexpr int kModelDim = 2304;
|
||||
// The literal 2304 below is the same value as kModelDim.
|
||||
static constexpr int kFFHiddenDim = 8 * 2304 / 2; // = 9216
|
||||
static constexpr int kHeads = 8;
|
||||
// kKVHeads is half of kHeads.
|
||||
static constexpr int kKVHeads = 4;
|
||||
static constexpr int kQKVDim = 256; // query size == key size == value size
|
||||
static constexpr int kTopK = gcpp::kTopK;
|
||||
static constexpr bool kAbsolutePE = false;
|
||||
static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
|
||||
};
|
||||
|
||||
template <typename TWeight>
|
||||
struct ConfigGemmaTiny : public ConfigNoSSM {
|
||||
using Weight = TWeight; // make accessible where we only have a TConfig
|
||||
|
|
|
|||
|
|
@ -0,0 +1,21 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Translation unit for the ConfigGemma2_2B<hwy::bfloat16_t> configuration.
// HWY_TARGET_INCLUDE must name this very file: per Highway's convention,
// hwy/foreach_target.h re-includes it once for each enabled SIMD target.
#undef HWY_TARGET_INCLUDE
|
||||
#define HWY_TARGET_INCLUDE \
|
||||
"gemma/instantiations/gemma2_2b_bf16.cc"
|
||||
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
||||
// NOTE(review): gemma-inl.h presumably reads GEMMA_CONFIG to decide which
// model configuration to compile -- confirm in gemma-inl.h.
#define GEMMA_CONFIG ConfigGemma2_2B<hwy::bfloat16_t>
|
||||
#include "gemma/gemma-inl.h"
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Translation unit for the ConfigGemma2_2B<float> configuration.
// HWY_TARGET_INCLUDE must name this very file: per Highway's convention,
// hwy/foreach_target.h re-includes it once for each enabled SIMD target.
#undef HWY_TARGET_INCLUDE
|
||||
#define HWY_TARGET_INCLUDE \
|
||||
"gemma/instantiations/gemma2_2b_f32.cc"
|
||||
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
||||
// NOTE(review): gemma-inl.h presumably reads GEMMA_CONFIG to decide which
// model configuration to compile -- confirm in gemma-inl.h.
#define GEMMA_CONFIG ConfigGemma2_2B<float>
|
||||
#include "gemma/gemma-inl.h"
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Translation unit for the ConfigGemma2_2B<SfpStream> configuration.
// HWY_TARGET_INCLUDE must name this very file: per Highway's convention,
// hwy/foreach_target.h re-includes it once for each enabled SIMD target.
#undef HWY_TARGET_INCLUDE
|
||||
#define HWY_TARGET_INCLUDE \
|
||||
"gemma/instantiations/gemma2_2b_sfp.cc"
|
||||
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
||||
// NOTE(review): gemma-inl.h presumably reads GEMMA_CONFIG to decide which
// model configuration to compile -- confirm in gemma-inl.h.
#define GEMMA_CONFIG ConfigGemma2_2B<SfpStream>
|
||||
#include "gemma/gemma-inl.h"
|
||||
Loading…
Reference in New Issue