Support absolute positional embeddings from vanilla transformer

PiperOrigin-RevId: 628100831
This commit is contained in:
Paul Chang 2024-04-25 09:31:29 -07:00 committed by Copybara-Service
parent 75eca87039
commit 2d4de6b08b
2 changed files with 12 additions and 0 deletions

View File

@ -90,6 +90,7 @@ struct ConfigGemma7B {
static constexpr int kKVHeads = 16; // standard MHA static constexpr int kKVHeads = 16; // standard MHA
static constexpr int kQKVDim = 256; // query size == key size == value size static constexpr int kQKVDim = 256; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK; static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false;
// SSM config. // SSM config.
static constexpr int kConv1dWidth = 0; static constexpr int kConv1dWidth = 0;
@ -120,6 +121,7 @@ struct ConfigGemma2B {
static constexpr int kKVHeads = 1; static constexpr int kKVHeads = 1;
static constexpr int kQKVDim = 256; // query size == key size == value size static constexpr int kQKVDim = 256; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK; static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false;
// SSM config. // SSM config.
static constexpr int kConv1dWidth = 0; static constexpr int kConv1dWidth = 0;
@ -178,6 +180,7 @@ struct ConfigGriffin2B {
static constexpr int kKVHeads = 1; static constexpr int kKVHeads = 1;
static constexpr int kQKVDim = 256; // query size == key size == value size static constexpr int kQKVDim = 256; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK; static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false;
// SSM config. // SSM config.
static constexpr int kConv1dWidth = 4; static constexpr int kConv1dWidth = 4;

View File

@ -896,6 +896,11 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
activations.x.data() + token_idx * kModelDim, kModelDim); activations.x.data() + token_idx * kModelDim, kModelDim);
MulByConst(kEmbScaling, activations.x.data() + token_idx * kModelDim, MulByConst(kEmbScaling, activations.x.data() + token_idx * kModelDim,
kModelDim); kModelDim);
// Only configs that opt in via kAbsolutePE take this branch; it is discarded
// at compile time for all others.
if constexpr (TConfig::kAbsolutePE) {
  // The destination is the activation row of token `token_idx`, so the
  // position must be offset by the same index. Passing plain `pos` would give
  // every token in the batch the embedding of one and the same position.
  // NOTE(review): assumes `pos` is the position of the first token in
  // `tokens` - confirm against the callers of Prefill.
  AddAbsolutePositionalEmbeddings(
      activations.x.data() + token_idx * kModelDim, TConfig::kModelDim,
      pos + token_idx);
}
}); });
for (size_t layer = 0; layer < TConfig::kLayers; ++layer) { for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
@ -958,6 +963,10 @@ void Transformer(int token, size_t pos, const WeightArrayT& weights,
GEMMA_CONSTEXPR_EMBSCALING const float kEmbScaling = GEMMA_CONSTEXPR_EMBSCALING const float kEmbScaling =
EmbeddingScaling<TConfig>(); EmbeddingScaling<TConfig>();
MulByConst(kEmbScaling, activations.x.data(), kModelDim); MulByConst(kEmbScaling, activations.x.data(), kModelDim);
// Compile-time opt-in: only configs with kAbsolutePE == true call
// AddAbsolutePositionalEmbeddings on the already scaled embedding in
// activations.x, using this token's position `pos`, before the layer loop.
if constexpr (TConfig::kAbsolutePE) {
AddAbsolutePositionalEmbeddings(activations.x.data(), TConfig::kModelDim,
pos);
};
for (size_t layer = 0; layer < TConfig::kLayers; ++layer) { for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
auto type = TConfig::kLayerConfig[layer]; auto type = TConfig::kLayerConfig[layer];
const auto* layer_weights = weights.GetLayer(layer); const auto* layer_weights = weights.GetLayer(layer);