Reduce duplication in Config* by inheriting no-SSM

PiperOrigin-RevId: 643030629
This commit is contained in:
Jan Wassenberg 2024-06-13 09:48:24 -07:00 committed by Copybara-Service
parent ea525da967
commit c15ff9529c
2 changed files with 21 additions and 35 deletions

View File

@ -36,7 +36,12 @@ ByteStorageT AllocateSizeof() {
}
// Model variants: see configs.h for details.
enum class Model { GEMMA_2B, GEMMA_7B, GRIFFIN_2B, GEMMA_TINY };
enum class Model {
GEMMA_2B,
GEMMA_7B,
GRIFFIN_2B,
GEMMA_TINY,
};
// Instruction-tuned models require extra 'turn structure' tokens in prompts.
enum class ModelTraining { GEMMA_IT, GEMMA_PT };

View File

@ -73,8 +73,18 @@ constexpr size_t NumLayersOfTypeBefore(
return count;
}
struct ConfigNoSSM {
static constexpr int kConv1dWidth = 0;
static constexpr bool kFFBiases = false;
static constexpr bool kSoftmaxAttnOutputBiases = false;
static constexpr bool kUseHalfRope = false;
static constexpr bool kUseLocalAttention = false;
static constexpr bool kInterleaveQKV = true;
static constexpr int kNumTensorScales = 0;
};
template <typename TWeight>
struct ConfigGemma7B {
struct ConfigGemma7B : public ConfigNoSSM {
using Weight = TWeight; // make accessible where we only have a TConfig
static constexpr int kSeqLen = gcpp::kSeqLen;
@ -96,19 +106,10 @@ struct ConfigGemma7B {
static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false;
static constexpr bool kPostNormScale = false;
// SSM config.
static constexpr int kConv1dWidth = 0;
static constexpr bool kFFBiases = false;
static constexpr bool kSoftmaxAttnOutputBiases = false;
static constexpr bool kUseHalfRope = false;
static constexpr bool kUseLocalAttention = false;
static constexpr bool kInterleaveQKV = true;
static constexpr int kNumTensorScales = 0;
};
template <typename TWeight>
struct ConfigGemma2B {
struct ConfigGemma2B : public ConfigNoSSM {
using Weight = TWeight; // make accessible where we only have a TConfig
static constexpr int kSeqLen = gcpp::kSeqLen;
@ -118,10 +119,8 @@ struct ConfigGemma2B {
static constexpr int kLayers = kLayerConfig.size();
static constexpr int kGemmaLayers =
NumLayersOfTypeBefore(kLayerConfig, LayerAttentionType::kGemma, kLayers);
static constexpr int kGriffinLayers =
NumLayersOfTypeBefore(kLayerConfig,
LayerAttentionType::kGriffinRecurrentBlock,
kLayers);
static constexpr int kGriffinLayers = NumLayersOfTypeBefore(
kLayerConfig, LayerAttentionType::kGriffinRecurrentBlock, kLayers);
static constexpr int kModelDim = 2048;
static constexpr int kFFHiddenDim = 16 * 2048 / 2; // = 16384
static constexpr int kHeads = 8;
@ -130,19 +129,10 @@ struct ConfigGemma2B {
static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false;
static constexpr bool kPostNormScale = false;
// SSM config.
static constexpr int kConv1dWidth = 0;
static constexpr bool kFFBiases = false;
static constexpr bool kSoftmaxAttnOutputBiases = false;
static constexpr bool kUseHalfRope = false;
static constexpr bool kUseLocalAttention = false;
static constexpr bool kInterleaveQKV = true;
static constexpr int kNumTensorScales = 0;
};
template <typename TWeight>
struct ConfigGemmaTiny {
struct ConfigGemmaTiny : public ConfigNoSSM {
using Weight = TWeight; // make accessible where we only have a TConfig
static constexpr int kSeqLen = 32;
@ -164,15 +154,6 @@ struct ConfigGemmaTiny {
static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false;
static constexpr bool kPostNormScale = false;
// SSM config.
static constexpr int kConv1dWidth = 0;
static constexpr bool kFFBiases = false;
static constexpr bool kSoftmaxAttnOutputBiases = false;
static constexpr bool kUseHalfRope = false;
static constexpr bool kUseLocalAttention = false;
static constexpr bool kInterleaveQKV = true;
static constexpr int kNumTensorScales = 0;
};
template <typename TWeight>