mirror of https://github.com/google/gemma.cpp.git
Reduce duplication in Config* by inheriting no-SSM
PiperOrigin-RevId: 643030629
This commit is contained in:
parent
ea525da967
commit
c15ff9529c
|
|
@@ -36,7 +36,12 @@ ByteStorageT AllocateSizeof() {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Model variants: see configs.h for details.
|
// Model variants: see configs.h for details.
|
||||||
enum class Model { GEMMA_2B, GEMMA_7B, GRIFFIN_2B, GEMMA_TINY };
|
enum class Model {
|
||||||
|
GEMMA_2B,
|
||||||
|
GEMMA_7B,
|
||||||
|
GRIFFIN_2B,
|
||||||
|
GEMMA_TINY,
|
||||||
|
};
|
||||||
|
|
||||||
// Instruction-tuned models require extra 'turn structure' tokens in prompts.
|
// Instruction-tuned models require extra 'turn structure' tokens in prompts.
|
||||||
enum class ModelTraining { GEMMA_IT, GEMMA_PT };
|
enum class ModelTraining { GEMMA_IT, GEMMA_PT };
|
||||||
|
|
|
||||||
|
|
@@ -73,8 +73,18 @@ constexpr size_t NumLayersOfTypeBefore(
|
||||||
return count;
|
return count;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Defaults shared by the model configs that declare `: public ConfigNoSSM`
// in this diff (ConfigGemma7B, ConfigGemma2B, ConfigGemmaTiny). The values
// below are the same ones each of those structs previously repeated under its
// own "// SSM config." section, so inheriting them leaves the effective
// constants unchanged while removing the duplication.
// NOTE(review): the name suggests these are the settings for models without
// SSM/recurrent layers - confirm against the Griffin config, which is not
// visible in this diff.
struct ConfigNoSSM {
|
||||||
|
static constexpr int kConv1dWidth = 0;
|
||||||
|
static constexpr bool kFFBiases = false;
|
||||||
|
static constexpr bool kSoftmaxAttnOutputBiases = false;
|
||||||
|
static constexpr bool kUseHalfRope = false;
|
||||||
|
static constexpr bool kUseLocalAttention = false;
|
||||||
|
static constexpr bool kInterleaveQKV = true;
|
||||||
|
static constexpr int kNumTensorScales = 0;
|
||||||
|
};
|
||||||
|
|
||||||
template <typename TWeight>
|
template <typename TWeight>
|
||||||
struct ConfigGemma7B {
|
struct ConfigGemma7B : public ConfigNoSSM {
|
||||||
using Weight = TWeight; // make accessible where we only have a TConfig
|
using Weight = TWeight; // make accessible where we only have a TConfig
|
||||||
|
|
||||||
static constexpr int kSeqLen = gcpp::kSeqLen;
|
static constexpr int kSeqLen = gcpp::kSeqLen;
|
||||||
|
|
@@ -96,19 +106,10 @@ struct ConfigGemma7B {
|
||||||
static constexpr int kTopK = gcpp::kTopK;
|
static constexpr int kTopK = gcpp::kTopK;
|
||||||
static constexpr bool kAbsolutePE = false;
|
static constexpr bool kAbsolutePE = false;
|
||||||
static constexpr bool kPostNormScale = false;
|
static constexpr bool kPostNormScale = false;
|
||||||
|
|
||||||
// SSM config.
|
|
||||||
static constexpr int kConv1dWidth = 0;
|
|
||||||
static constexpr bool kFFBiases = false;
|
|
||||||
static constexpr bool kSoftmaxAttnOutputBiases = false;
|
|
||||||
static constexpr bool kUseHalfRope = false;
|
|
||||||
static constexpr bool kUseLocalAttention = false;
|
|
||||||
static constexpr bool kInterleaveQKV = true;
|
|
||||||
static constexpr int kNumTensorScales = 0;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename TWeight>
|
template <typename TWeight>
|
||||||
struct ConfigGemma2B {
|
struct ConfigGemma2B : public ConfigNoSSM {
|
||||||
using Weight = TWeight; // make accessible where we only have a TConfig
|
using Weight = TWeight; // make accessible where we only have a TConfig
|
||||||
|
|
||||||
static constexpr int kSeqLen = gcpp::kSeqLen;
|
static constexpr int kSeqLen = gcpp::kSeqLen;
|
||||||
|
|
@@ -118,10 +119,8 @@ struct ConfigGemma2B {
|
||||||
static constexpr int kLayers = kLayerConfig.size();
|
static constexpr int kLayers = kLayerConfig.size();
|
||||||
static constexpr int kGemmaLayers =
|
static constexpr int kGemmaLayers =
|
||||||
NumLayersOfTypeBefore(kLayerConfig, LayerAttentionType::kGemma, kLayers);
|
NumLayersOfTypeBefore(kLayerConfig, LayerAttentionType::kGemma, kLayers);
|
||||||
static constexpr int kGriffinLayers =
|
static constexpr int kGriffinLayers = NumLayersOfTypeBefore(
|
||||||
NumLayersOfTypeBefore(kLayerConfig,
|
kLayerConfig, LayerAttentionType::kGriffinRecurrentBlock, kLayers);
|
||||||
LayerAttentionType::kGriffinRecurrentBlock,
|
|
||||||
kLayers);
|
|
||||||
static constexpr int kModelDim = 2048;
|
static constexpr int kModelDim = 2048;
|
||||||
static constexpr int kFFHiddenDim = 16 * 2048 / 2; // = 16384
|
static constexpr int kFFHiddenDim = 16 * 2048 / 2; // = 16384
|
||||||
static constexpr int kHeads = 8;
|
static constexpr int kHeads = 8;
|
||||||
|
|
@@ -130,19 +129,10 @@ struct ConfigGemma2B {
|
||||||
static constexpr int kTopK = gcpp::kTopK;
|
static constexpr int kTopK = gcpp::kTopK;
|
||||||
static constexpr bool kAbsolutePE = false;
|
static constexpr bool kAbsolutePE = false;
|
||||||
static constexpr bool kPostNormScale = false;
|
static constexpr bool kPostNormScale = false;
|
||||||
|
|
||||||
// SSM config.
|
|
||||||
static constexpr int kConv1dWidth = 0;
|
|
||||||
static constexpr bool kFFBiases = false;
|
|
||||||
static constexpr bool kSoftmaxAttnOutputBiases = false;
|
|
||||||
static constexpr bool kUseHalfRope = false;
|
|
||||||
static constexpr bool kUseLocalAttention = false;
|
|
||||||
static constexpr bool kInterleaveQKV = true;
|
|
||||||
static constexpr int kNumTensorScales = 0;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename TWeight>
|
template <typename TWeight>
|
||||||
struct ConfigGemmaTiny {
|
struct ConfigGemmaTiny : public ConfigNoSSM {
|
||||||
using Weight = TWeight; // make accessible where we only have a TConfig
|
using Weight = TWeight; // make accessible where we only have a TConfig
|
||||||
|
|
||||||
static constexpr int kSeqLen = 32;
|
static constexpr int kSeqLen = 32;
|
||||||
|
|
@@ -164,15 +154,6 @@ struct ConfigGemmaTiny {
|
||||||
static constexpr int kTopK = gcpp::kTopK;
|
static constexpr int kTopK = gcpp::kTopK;
|
||||||
static constexpr bool kAbsolutePE = false;
|
static constexpr bool kAbsolutePE = false;
|
||||||
static constexpr bool kPostNormScale = false;
|
static constexpr bool kPostNormScale = false;
|
||||||
|
|
||||||
// SSM config.
|
|
||||||
static constexpr int kConv1dWidth = 0;
|
|
||||||
static constexpr bool kFFBiases = false;
|
|
||||||
static constexpr bool kSoftmaxAttnOutputBiases = false;
|
|
||||||
static constexpr bool kUseHalfRope = false;
|
|
||||||
static constexpr bool kUseLocalAttention = false;
|
|
||||||
static constexpr bool kInterleaveQKV = true;
|
|
||||||
static constexpr int kNumTensorScales = 0;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename TWeight>
|
template <typename TWeight>
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue