mirror of https://github.com/google/gemma.cpp.git
Support absolute positional embeddings from vanilla transformer
PiperOrigin-RevId: 628100831
This commit is contained in:
parent
75eca87039
commit
2d4de6b08b
|
|
@@ -90,6 +90,7 @@ struct ConfigGemma7B {
|
||||||
static constexpr int kKVHeads = 16; // standard MHA
|
static constexpr int kKVHeads = 16; // standard MHA
|
||||||
static constexpr int kQKVDim = 256; // query size == key size == value size
|
static constexpr int kQKVDim = 256; // query size == key size == value size
|
||||||
static constexpr int kTopK = gcpp::kTopK;
|
static constexpr int kTopK = gcpp::kTopK;
|
||||||
|
static constexpr bool kAbsolutePE = false;
|
||||||
|
|
||||||
// SSM config.
|
// SSM config.
|
||||||
static constexpr int kConv1dWidth = 0;
|
static constexpr int kConv1dWidth = 0;
|
||||||
|
|
@@ -120,6 +121,7 @@ struct ConfigGemma2B {
|
||||||
static constexpr int kKVHeads = 1;
|
static constexpr int kKVHeads = 1;
|
||||||
static constexpr int kQKVDim = 256; // query size == key size == value size
|
static constexpr int kQKVDim = 256; // query size == key size == value size
|
||||||
static constexpr int kTopK = gcpp::kTopK;
|
static constexpr int kTopK = gcpp::kTopK;
|
||||||
|
static constexpr bool kAbsolutePE = false;
|
||||||
|
|
||||||
// SSM config.
|
// SSM config.
|
||||||
static constexpr int kConv1dWidth = 0;
|
static constexpr int kConv1dWidth = 0;
|
||||||
|
|
@@ -178,6 +180,7 @@ struct ConfigGriffin2B {
|
||||||
static constexpr int kKVHeads = 1;
|
static constexpr int kKVHeads = 1;
|
||||||
static constexpr int kQKVDim = 256; // query size == key size == value size
|
static constexpr int kQKVDim = 256; // query size == key size == value size
|
||||||
static constexpr int kTopK = gcpp::kTopK;
|
static constexpr int kTopK = gcpp::kTopK;
|
||||||
|
static constexpr bool kAbsolutePE = false;
|
||||||
|
|
||||||
// SSM config.
|
// SSM config.
|
||||||
static constexpr int kConv1dWidth = 4;
|
static constexpr int kConv1dWidth = 4;
|
||||||
|
|
|
||||||
|
|
@@ -896,6 +896,11 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
|
||||||
activations.x.data() + token_idx * kModelDim, kModelDim);
|
activations.x.data() + token_idx * kModelDim, kModelDim);
|
||||||
MulByConst(kEmbScaling, activations.x.data() + token_idx * kModelDim,
|
MulByConst(kEmbScaling, activations.x.data() + token_idx * kModelDim,
|
||||||
kModelDim);
|
kModelDim);
|
||||||
|
if constexpr (TConfig::kAbsolutePE) {
|
||||||
|
AddAbsolutePositionalEmbeddings(
|
||||||
|
activations.x.data() + token_idx * kModelDim, TConfig::kModelDim,
|
||||||
|
pos);
|
||||||
|
};
|
||||||
});
|
});
|
||||||
|
|
||||||
for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
|
for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
|
||||||
|
|
@@ -958,6 +963,10 @@ void Transformer(int token, size_t pos, const WeightArrayT& weights,
|
||||||
GEMMA_CONSTEXPR_EMBSCALING const float kEmbScaling =
|
GEMMA_CONSTEXPR_EMBSCALING const float kEmbScaling =
|
||||||
EmbeddingScaling<TConfig>();
|
EmbeddingScaling<TConfig>();
|
||||||
MulByConst(kEmbScaling, activations.x.data(), kModelDim);
|
MulByConst(kEmbScaling, activations.x.data(), kModelDim);
|
||||||
|
if constexpr (TConfig::kAbsolutePE) {
|
||||||
|
AddAbsolutePositionalEmbeddings(activations.x.data(), TConfig::kModelDim,
|
||||||
|
pos);
|
||||||
|
};
|
||||||
for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
|
for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
|
||||||
auto type = TConfig::kLayerConfig[layer];
|
auto type = TConfig::kLayerConfig[layer];
|
||||||
const auto* layer_weights = weights.GetLayer(layer);
|
const auto* layer_weights = weights.GetLayer(layer);
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue