mirror of https://github.com/google/gemma.cpp.git
parent
a24eda8d02
commit
1982a6ba00
|
|
@ -170,6 +170,9 @@ cc_library(
|
||||||
"gemma/instantiations/gr2b_bf16.cc",
|
"gemma/instantiations/gr2b_bf16.cc",
|
||||||
"gemma/instantiations/gr2b_f32.cc",
|
"gemma/instantiations/gr2b_f32.cc",
|
||||||
"gemma/instantiations/gr2b_sfp.cc",
|
"gemma/instantiations/gr2b_sfp.cc",
|
||||||
|
"gemma/instantiations/gemma2_2b_bf16.cc",
|
||||||
|
"gemma/instantiations/gemma2_2b_f32.cc",
|
||||||
|
"gemma/instantiations/gemma2_2b_sfp.cc",
|
||||||
],
|
],
|
||||||
hdrs = [
|
hdrs = [
|
||||||
"gemma/activations.h",
|
"gemma/activations.h",
|
||||||
|
|
|
||||||
|
|
@ -91,6 +91,9 @@ set(SOURCES
|
||||||
gemma/instantiations/tiny_bf16.cc
|
gemma/instantiations/tiny_bf16.cc
|
||||||
gemma/instantiations/tiny_f32.cc
|
gemma/instantiations/tiny_f32.cc
|
||||||
gemma/instantiations/tiny_sfp.cc
|
gemma/instantiations/tiny_sfp.cc
|
||||||
|
gemma/instantiations/gemma2_2b_bf16.cc
|
||||||
|
gemma/instantiations/gemma2_2b_f32.cc
|
||||||
|
gemma/instantiations/gemma2_2b_sfp.cc
|
||||||
gemma/kv_cache.cc
|
gemma/kv_cache.cc
|
||||||
gemma/kv_cache.h
|
gemma/kv_cache.h
|
||||||
gemma/tokenizer.cc
|
gemma/tokenizer.cc
|
||||||
|
|
|
||||||
|
|
@ -29,12 +29,13 @@
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
||||||
constexpr const char* kModelFlags[] = {
|
constexpr const char* kModelFlags[] = {
|
||||||
"2b-pt", "2b-it", // Gemma 2B
|
"2b-pt", "2b-it", // Gemma 2B
|
||||||
"7b-pt", "7b-it", // Gemma 7B
|
"7b-pt", "7b-it", // Gemma 7B
|
||||||
"9b-pt", "9b-it", // Gemma 9B
|
"9b-pt", "9b-it", // Gemma 9B
|
||||||
"27b-pt", "27b-it", // Gemma 27B
|
"27b-pt", "27b-it", // Gemma 27B
|
||||||
"gr2b-pt", "gr2b-it", // RecurrentGemma
|
"gr2b-pt", "gr2b-it", // RecurrentGemma
|
||||||
"tiny", // Gemma Tiny (mostly for debugging)
|
"tiny", // Gemma Tiny (mostly for debugging)
|
||||||
|
"gemma2-2b-pt", "gemma2-2b-it", // Gemma2 2B
|
||||||
};
|
};
|
||||||
constexpr Model kModelTypes[] = {
|
constexpr Model kModelTypes[] = {
|
||||||
Model::GEMMA_2B, Model::GEMMA_2B, // Gemma 2B
|
Model::GEMMA_2B, Model::GEMMA_2B, // Gemma 2B
|
||||||
|
|
@ -43,6 +44,7 @@ constexpr Model kModelTypes[] = {
|
||||||
Model::GEMMA_27B, Model::GEMMA_27B, // Gemma 27B
|
Model::GEMMA_27B, Model::GEMMA_27B, // Gemma 27B
|
||||||
Model::GRIFFIN_2B, Model::GRIFFIN_2B, // RecurrentGemma
|
Model::GRIFFIN_2B, Model::GRIFFIN_2B, // RecurrentGemma
|
||||||
Model::GEMMA_TINY, // Gemma Tiny
|
Model::GEMMA_TINY, // Gemma Tiny
|
||||||
|
Model::GEMMA2_2B, Model::GEMMA2_2B, // Gemma2 2B
|
||||||
};
|
};
|
||||||
constexpr ModelTraining kModelTraining[] = {
|
constexpr ModelTraining kModelTraining[] = {
|
||||||
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma 2B
|
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma 2B
|
||||||
|
|
@ -51,6 +53,7 @@ constexpr ModelTraining kModelTraining[] = {
|
||||||
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma 27B
|
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma 27B
|
||||||
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // RecurrentGemma
|
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // RecurrentGemma
|
||||||
ModelTraining::GEMMA_IT, // Gemma Tiny
|
ModelTraining::GEMMA_IT, // Gemma Tiny
|
||||||
|
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma2 2B
|
||||||
};
|
};
|
||||||
|
|
||||||
constexpr size_t kNumModelFlags =
|
constexpr size_t kNumModelFlags =
|
||||||
|
|
|
||||||
|
|
@ -45,6 +45,7 @@ enum class Model {
|
||||||
GEMMA_27B,
|
GEMMA_27B,
|
||||||
GRIFFIN_2B,
|
GRIFFIN_2B,
|
||||||
GEMMA_TINY,
|
GEMMA_TINY,
|
||||||
|
GEMMA2_2B,
|
||||||
};
|
};
|
||||||
|
|
||||||
// Instruction-tuned models require extra 'turn structure' tokens in prompts.
|
// Instruction-tuned models require extra 'turn structure' tokens in prompts.
|
||||||
|
|
@ -99,6 +100,9 @@ decltype(auto) CallForModel(Model model, TArgs&&... args) {
|
||||||
return FuncT<ConfigGemma27B<TWeight>>()(std::forward<TArgs>(args)...);
|
return FuncT<ConfigGemma27B<TWeight>>()(std::forward<TArgs>(args)...);
|
||||||
case Model::GRIFFIN_2B:
|
case Model::GRIFFIN_2B:
|
||||||
return FuncT<ConfigGriffin2B<TWeight>>()(std::forward<TArgs>(args)...);
|
return FuncT<ConfigGriffin2B<TWeight>>()(std::forward<TArgs>(args)...);
|
||||||
|
case Model::GEMMA2_2B:
|
||||||
|
return FuncT<ConfigGemma2_2B<TWeight>>()(std::forward<TArgs>(args)...);
|
||||||
|
|
||||||
default:
|
default:
|
||||||
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
|
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
|
||||||
}
|
}
|
||||||
|
|
@ -142,6 +146,7 @@ decltype(auto) CallForModelAndWeight(Model model, Type weight,
|
||||||
GEMMA_FOREACH_WEIGHT(X, ConfigGemma9B) \
|
GEMMA_FOREACH_WEIGHT(X, ConfigGemma9B) \
|
||||||
GEMMA_FOREACH_WEIGHT(X, ConfigGemma27B) \
|
GEMMA_FOREACH_WEIGHT(X, ConfigGemma27B) \
|
||||||
GEMMA_FOREACH_WEIGHT(X, ConfigGriffin2B) \
|
GEMMA_FOREACH_WEIGHT(X, ConfigGriffin2B) \
|
||||||
|
GEMMA_FOREACH_WEIGHT(X, ConfigGemma2_2B) \
|
||||||
static_assert(true, "Allow trailing ;")
|
static_assert(true, "Allow trailing ;")
|
||||||
|
|
||||||
// Used by GEMMA_EXPORT_AND_DISPATCH. For a given TWEIGHT (e.g. float),
|
// Used by GEMMA_EXPORT_AND_DISPATCH. For a given TWEIGHT (e.g. float),
|
||||||
|
|
@ -178,6 +183,11 @@ decltype(auto) CallForModelAndWeight(Model model, Type weight,
|
||||||
ARGS; \
|
ARGS; \
|
||||||
break; \
|
break; \
|
||||||
} \
|
} \
|
||||||
|
case Model::GEMMA2_2B: { \
|
||||||
|
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGemma2_2B<TWEIGHT>>) \
|
||||||
|
ARGS; \
|
||||||
|
break; \
|
||||||
|
} \
|
||||||
default: \
|
default: \
|
||||||
HWY_ABORT("Model type %d unknown.", static_cast<int>(MODEL)); \
|
HWY_ABORT("Model type %d unknown.", static_cast<int>(MODEL)); \
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -253,6 +253,28 @@ struct ConfigGemma2B : public ConfigBaseGemmaV1 {
|
||||||
static constexpr bool kAbsolutePE = false;
|
static constexpr bool kAbsolutePE = false;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Compile-time configuration for the Gemma2 2B model, i.e. the config that the
// Model::GEMMA2_2B dispatch cases instantiate. TWeight is the weight element
// type; this change instantiates it with hwy::bfloat16_t, float and SfpStream.
// Any setting not listed here comes from ConfigBaseGemmaV2.
template <typename TWeight>
|
||||||
|
struct ConfigGemma2_2B : public ConfigBaseGemmaV2 {
|
||||||
|
using Weight = TWeight; // make accessible where we only have a TConfig
|
||||||
|
|
||||||
|
static constexpr int kSeqLen = 8192;
|
||||||
|
static constexpr int kVocabSize = 256000;
|
||||||
|
// 26-entry per-layer attention type table built from kGemma.
// NOTE(review): presumably every entry is kGemma - confirm against
// FixedLayerConfig.
static constexpr std::array<LayerAttentionType, 26> kLayerConfig =
|
||||||
|
FixedLayerConfig<26>(LayerAttentionType::kGemma);
|
||||||
|
// 26-entry per-layer attention window table built from {4096, kSeqLen}.
// NOTE(review): presumably the 2-entry pattern is repeated across the layers
// (4096, kSeqLen, 4096, ...) - confirm against RepeatedAttentionWindowSizes.
static constexpr std::array<size_t, 26> kAttentionWindowSizes =
|
||||||
|
RepeatedAttentionWindowSizes<26, 2>({4096, kSeqLen});
|
||||||
|
// Layer count is derived from the table above, so it is 26.
static constexpr int kLayers = kLayerConfig.size();
|
||||||
|
static constexpr int kGemmaLayers = kLayers;
|
||||||
|
static constexpr int kModelDim = 2304;
|
||||||
|
static constexpr int kFFHiddenDim = 8 * 2304 / 2; // = 9216
|
||||||
|
static constexpr int kHeads = 8;
|
||||||
|
static constexpr int kKVHeads = 4;
|
||||||
|
static constexpr int kQKVDim = 256; // query size == key size == value size
|
||||||
|
static constexpr int kTopK = gcpp::kTopK;
|
||||||
|
static constexpr bool kAbsolutePE = false;
|
||||||
|
static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
|
||||||
|
};
|
||||||
|
|
||||||
template <typename TWeight>
|
template <typename TWeight>
|
||||||
struct ConfigGemmaTiny : public ConfigNoSSM {
|
struct ConfigGemmaTiny : public ConfigNoSSM {
|
||||||
using Weight = TWeight; // make accessible where we only have a TConfig
|
using Weight = TWeight; // make accessible where we only have a TConfig
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,21 @@
|
||||||
|
// Copyright 2024 Google LLC
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#undef HWY_TARGET_INCLUDE
|
||||||
|
#define HWY_TARGET_INCLUDE \
|
||||||
|
"gemma/instantiations/gemma2_2b_bf16.cc"
|
||||||
|
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
||||||
|
#define GEMMA_CONFIG ConfigGemma2_2B<hwy::bfloat16_t>
|
||||||
|
#include "gemma/gemma-inl.h"
|
||||||
|
|
@ -0,0 +1,21 @@
|
||||||
|
// Copyright 2024 Google LLC
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#undef HWY_TARGET_INCLUDE
|
||||||
|
#define HWY_TARGET_INCLUDE \
|
||||||
|
"gemma/instantiations/gemma2_2b_f32.cc"
|
||||||
|
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
||||||
|
#define GEMMA_CONFIG ConfigGemma2_2B<float>
|
||||||
|
#include "gemma/gemma-inl.h"
|
||||||
|
|
@ -0,0 +1,21 @@
|
||||||
|
// Copyright 2024 Google LLC
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#undef HWY_TARGET_INCLUDE
|
||||||
|
#define HWY_TARGET_INCLUDE \
|
||||||
|
"gemma/instantiations/gemma2_2b_sfp.cc"
|
||||||
|
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
||||||
|
#define GEMMA_CONFIG ConfigGemma2_2B<SfpStream>
|
||||||
|
#include "gemma/gemma-inl.h"
|
||||||
Loading…
Reference in New Issue