From 2d4de6b08b63e206d12ab34340cb476a23c8d0ba Mon Sep 17 00:00:00 2001 From: Paul Chang Date: Thu, 25 Apr 2024 09:31:29 -0700 Subject: [PATCH] Support absolute positional embeddings from vanilla transformer PiperOrigin-RevId: 628100831 --- gemma/configs.h | 3 +++ gemma/gemma.cc | 9 +++++++++ 2 files changed, 12 insertions(+) diff --git a/gemma/configs.h b/gemma/configs.h index 9388c9d..bedecee 100644 --- a/gemma/configs.h +++ b/gemma/configs.h @@ -90,6 +90,7 @@ struct ConfigGemma7B { static constexpr int kKVHeads = 16; // standard MHA static constexpr int kQKVDim = 256; // query size == key size == value size static constexpr int kTopK = gcpp::kTopK; + static constexpr bool kAbsolutePE = false; // SSM config. static constexpr int kConv1dWidth = 0; @@ -120,6 +121,7 @@ struct ConfigGemma2B { static constexpr int kKVHeads = 1; static constexpr int kQKVDim = 256; // query size == key size == value size static constexpr int kTopK = gcpp::kTopK; + static constexpr bool kAbsolutePE = false; // SSM config. static constexpr int kConv1dWidth = 0; @@ -178,6 +180,7 @@ struct ConfigGriffin2B { static constexpr int kKVHeads = 1; static constexpr int kQKVDim = 256; // query size == key size == value size static constexpr int kTopK = gcpp::kTopK; + static constexpr bool kAbsolutePE = false; // SSM config. static constexpr int kConv1dWidth = 4; diff --git a/gemma/gemma.cc b/gemma/gemma.cc index 1517771..58078a9 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -896,6 +896,11 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos, activations.x.data() + token_idx * kModelDim, kModelDim); MulByConst(kEmbScaling, activations.x.data() + token_idx * kModelDim, kModelDim); + if constexpr (TConfig::kAbsolutePE) { + AddAbsolutePositionalEmbeddings( + activations.x.data() + token_idx * kModelDim, TConfig::kModelDim, + pos); + }; }); for (size_t layer = 0; layer < TConfig::kLayers; ++layer) { @@ -958,6 +963,10 @@ void Transformer(int token, size_t pos, const WeightArrayT& weights, GEMMA_CONSTEXPR_EMBSCALING const float kEmbScaling = EmbeddingScaling(); MulByConst(kEmbScaling, activations.x.data(), kModelDim); + if constexpr (TConfig::kAbsolutePE) { + AddAbsolutePositionalEmbeddings(activations.x.data(), TConfig::kModelDim, + pos); + }; for (size_t layer = 0; layer < TConfig::kLayers; ++layer) { auto type = TConfig::kLayerConfig[layer]; const auto* layer_weights = weights.GetLayer(layer);