mirror of https://github.com/google/gemma.cpp.git
Support absolute positional embeddings, as used in the vanilla Transformer
PiperOrigin-RevId: 628100831
This commit is contained in:
parent
75eca87039
commit
2d4de6b08b
|
|
@ -90,6 +90,7 @@ struct ConfigGemma7B {
|
|||
static constexpr int kKVHeads = 16; // standard MHA
|
||||
static constexpr int kQKVDim = 256; // query size == key size == value size
|
||||
static constexpr int kTopK = gcpp::kTopK;
|
||||
static constexpr bool kAbsolutePE = false;
|
||||
|
||||
// SSM config.
|
||||
static constexpr int kConv1dWidth = 0;
|
||||
|
|
@ -120,6 +121,7 @@ struct ConfigGemma2B {
|
|||
static constexpr int kKVHeads = 1;
|
||||
static constexpr int kQKVDim = 256; // query size == key size == value size
|
||||
static constexpr int kTopK = gcpp::kTopK;
|
||||
static constexpr bool kAbsolutePE = false;
|
||||
|
||||
// SSM config.
|
||||
static constexpr int kConv1dWidth = 0;
|
||||
|
|
@ -178,6 +180,7 @@ struct ConfigGriffin2B {
|
|||
static constexpr int kKVHeads = 1;
|
||||
static constexpr int kQKVDim = 256; // query size == key size == value size
|
||||
static constexpr int kTopK = gcpp::kTopK;
|
||||
static constexpr bool kAbsolutePE = false;
|
||||
|
||||
// SSM config.
|
||||
static constexpr int kConv1dWidth = 4;
|
||||
|
|
|
|||
|
|
@ -896,6 +896,11 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
|
|||
activations.x.data() + token_idx * kModelDim, kModelDim);
|
||||
MulByConst(kEmbScaling, activations.x.data() + token_idx * kModelDim,
|
||||
kModelDim);
|
||||
if constexpr (TConfig::kAbsolutePE) {
|
||||
AddAbsolutePositionalEmbeddings(
|
||||
activations.x.data() + token_idx * kModelDim, TConfig::kModelDim,
|
||||
pos);
|
||||
};
|
||||
});
|
||||
|
||||
for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
|
||||
|
|
@ -958,6 +963,10 @@ void Transformer(int token, size_t pos, const WeightArrayT& weights,
|
|||
GEMMA_CONSTEXPR_EMBSCALING const float kEmbScaling =
|
||||
EmbeddingScaling<TConfig>();
|
||||
MulByConst(kEmbScaling, activations.x.data(), kModelDim);
|
||||
if constexpr (TConfig::kAbsolutePE) {
|
||||
AddAbsolutePositionalEmbeddings(activations.x.data(), TConfig::kModelDim,
|
||||
pos);
|
||||
};
|
||||
for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
|
||||
auto type = TConfig::kLayerConfig[layer];
|
||||
const auto* layer_weights = weights.GetLayer(layer);
|
||||
|
|
|
|||
Loading…
Reference in New Issue