mirror of https://github.com/google/gemma.cpp.git
Reduce duplication in Config* by inheriting no-SSM
PiperOrigin-RevId: 643030629
This commit is contained in:
parent
ea525da967
commit
c15ff9529c
|
|
@ -36,7 +36,12 @@ ByteStorageT AllocateSizeof() {
|
|||
}
|
||||
|
||||
// Model variants: see configs.h for details.
|
||||
enum class Model { GEMMA_2B, GEMMA_7B, GRIFFIN_2B, GEMMA_TINY };
|
||||
enum class Model {
|
||||
GEMMA_2B,
|
||||
GEMMA_7B,
|
||||
GRIFFIN_2B,
|
||||
GEMMA_TINY,
|
||||
};
|
||||
|
||||
// Instruction-tuned models require extra 'turn structure' tokens in prompts.
|
||||
enum class ModelTraining { GEMMA_IT, GEMMA_PT };
|
||||
|
|
|
|||
|
|
@ -73,8 +73,18 @@ constexpr size_t NumLayersOfTypeBefore(
|
|||
return count;
|
||||
}
|
||||
|
||||
struct ConfigNoSSM {
|
||||
static constexpr int kConv1dWidth = 0;
|
||||
static constexpr bool kFFBiases = false;
|
||||
static constexpr bool kSoftmaxAttnOutputBiases = false;
|
||||
static constexpr bool kUseHalfRope = false;
|
||||
static constexpr bool kUseLocalAttention = false;
|
||||
static constexpr bool kInterleaveQKV = true;
|
||||
static constexpr int kNumTensorScales = 0;
|
||||
};
|
||||
|
||||
template <typename TWeight>
|
||||
struct ConfigGemma7B {
|
||||
struct ConfigGemma7B : public ConfigNoSSM {
|
||||
using Weight = TWeight; // make accessible where we only have a TConfig
|
||||
|
||||
static constexpr int kSeqLen = gcpp::kSeqLen;
|
||||
|
|
@ -96,19 +106,10 @@ struct ConfigGemma7B {
|
|||
static constexpr int kTopK = gcpp::kTopK;
|
||||
static constexpr bool kAbsolutePE = false;
|
||||
static constexpr bool kPostNormScale = false;
|
||||
|
||||
// SSM config.
|
||||
static constexpr int kConv1dWidth = 0;
|
||||
static constexpr bool kFFBiases = false;
|
||||
static constexpr bool kSoftmaxAttnOutputBiases = false;
|
||||
static constexpr bool kUseHalfRope = false;
|
||||
static constexpr bool kUseLocalAttention = false;
|
||||
static constexpr bool kInterleaveQKV = true;
|
||||
static constexpr int kNumTensorScales = 0;
|
||||
};
|
||||
|
||||
template <typename TWeight>
|
||||
struct ConfigGemma2B {
|
||||
struct ConfigGemma2B : public ConfigNoSSM {
|
||||
using Weight = TWeight; // make accessible where we only have a TConfig
|
||||
|
||||
static constexpr int kSeqLen = gcpp::kSeqLen;
|
||||
|
|
@ -118,10 +119,8 @@ struct ConfigGemma2B {
|
|||
static constexpr int kLayers = kLayerConfig.size();
|
||||
static constexpr int kGemmaLayers =
|
||||
NumLayersOfTypeBefore(kLayerConfig, LayerAttentionType::kGemma, kLayers);
|
||||
static constexpr int kGriffinLayers =
|
||||
NumLayersOfTypeBefore(kLayerConfig,
|
||||
LayerAttentionType::kGriffinRecurrentBlock,
|
||||
kLayers);
|
||||
static constexpr int kGriffinLayers = NumLayersOfTypeBefore(
|
||||
kLayerConfig, LayerAttentionType::kGriffinRecurrentBlock, kLayers);
|
||||
static constexpr int kModelDim = 2048;
|
||||
static constexpr int kFFHiddenDim = 16 * 2048 / 2; // = 16384
|
||||
static constexpr int kHeads = 8;
|
||||
|
|
@ -130,19 +129,10 @@ struct ConfigGemma2B {
|
|||
static constexpr int kTopK = gcpp::kTopK;
|
||||
static constexpr bool kAbsolutePE = false;
|
||||
static constexpr bool kPostNormScale = false;
|
||||
|
||||
// SSM config.
|
||||
static constexpr int kConv1dWidth = 0;
|
||||
static constexpr bool kFFBiases = false;
|
||||
static constexpr bool kSoftmaxAttnOutputBiases = false;
|
||||
static constexpr bool kUseHalfRope = false;
|
||||
static constexpr bool kUseLocalAttention = false;
|
||||
static constexpr bool kInterleaveQKV = true;
|
||||
static constexpr int kNumTensorScales = 0;
|
||||
};
|
||||
|
||||
template <typename TWeight>
|
||||
struct ConfigGemmaTiny {
|
||||
struct ConfigGemmaTiny : public ConfigNoSSM {
|
||||
using Weight = TWeight; // make accessible where we only have a TConfig
|
||||
|
||||
static constexpr int kSeqLen = 32;
|
||||
|
|
@ -164,15 +154,6 @@ struct ConfigGemmaTiny {
|
|||
static constexpr int kTopK = gcpp::kTopK;
|
||||
static constexpr bool kAbsolutePE = false;
|
||||
static constexpr bool kPostNormScale = false;
|
||||
|
||||
// SSM config.
|
||||
static constexpr int kConv1dWidth = 0;
|
||||
static constexpr bool kFFBiases = false;
|
||||
static constexpr bool kSoftmaxAttnOutputBiases = false;
|
||||
static constexpr bool kUseHalfRope = false;
|
||||
static constexpr bool kUseLocalAttention = false;
|
||||
static constexpr bool kInterleaveQKV = true;
|
||||
static constexpr int kNumTensorScales = 0;
|
||||
};
|
||||
|
||||
template <typename TWeight>
|
||||
|
|
|
|||
Loading…
Reference in New Issue