mirror of https://github.com/google/gemma.cpp.git
Fix Griffin model:
- use HalfRope position encodings - zero-initialize the caches for each Generate at position 0 The lack of the latter made the tests in gemma_test dependent on each other. PiperOrigin-RevId: 694509054
This commit is contained in:
parent
d4050a2917
commit
e54d9cbddd
|
|
@ -246,7 +246,7 @@ TEST_F(GemmaTest, CrossEntropySmall) {
|
||||||
EXPECT_NEAR(entropy, 2.8f, 0.2f);
|
EXPECT_NEAR(entropy, 2.8f, 0.2f);
|
||||||
break;
|
break;
|
||||||
case gcpp::Model::GRIFFIN_2B:
|
case gcpp::Model::GRIFFIN_2B:
|
||||||
EXPECT_NEAR(entropy, 1.57f, 0.02f);
|
EXPECT_NEAR(entropy, 2.61f, 0.02f);
|
||||||
break;
|
break;
|
||||||
case gcpp::Model::GEMMA2_2B:
|
case gcpp::Model::GEMMA2_2B:
|
||||||
EXPECT_NEAR(entropy, 1.14f, 0.02f);
|
EXPECT_NEAR(entropy, 1.14f, 0.02f);
|
||||||
|
|
@ -277,7 +277,7 @@ TEST_F(GemmaTest, CrossEntropyJingleBells) {
|
||||||
EXPECT_NEAR(entropy, 1.07f, 0.05f);
|
EXPECT_NEAR(entropy, 1.07f, 0.05f);
|
||||||
break;
|
break;
|
||||||
case gcpp::Model::GRIFFIN_2B:
|
case gcpp::Model::GRIFFIN_2B:
|
||||||
EXPECT_NEAR(entropy, 2.09f, 0.02f);
|
EXPECT_NEAR(entropy, 1.62f, 0.02f);
|
||||||
break;
|
break;
|
||||||
case gcpp::Model::GEMMA2_2B:
|
case gcpp::Model::GEMMA2_2B:
|
||||||
EXPECT_NEAR(entropy, 0.49f, 0.02f);
|
EXPECT_NEAR(entropy, 0.49f, 0.02f);
|
||||||
|
|
@ -308,7 +308,7 @@ TEST_F(GemmaTest, CrossEntropyGettysburg) {
|
||||||
EXPECT_NEAR(entropy, 0.75f, 0.1f);
|
EXPECT_NEAR(entropy, 0.75f, 0.1f);
|
||||||
break;
|
break;
|
||||||
case gcpp::Model::GRIFFIN_2B:
|
case gcpp::Model::GRIFFIN_2B:
|
||||||
EXPECT_NEAR(entropy, 0.86f, 0.02f);
|
EXPECT_NEAR(entropy, 0.71f, 0.02f);
|
||||||
break;
|
break;
|
||||||
case gcpp::Model::GEMMA2_2B:
|
case gcpp::Model::GEMMA2_2B:
|
||||||
EXPECT_NEAR(entropy, 0.20f, 0.02f);
|
EXPECT_NEAR(entropy, 0.20f, 0.02f);
|
||||||
|
|
|
||||||
|
|
@ -183,7 +183,7 @@ static ModelConfig ConfigGriffin2B() {
|
||||||
.softmax_attn_output_biases = true,
|
.softmax_attn_output_biases = true,
|
||||||
.type = LayerAttentionType::kGriffinRecurrentBlock,
|
.type = LayerAttentionType::kGriffinRecurrentBlock,
|
||||||
.activation = ActivationType::Gelu,
|
.activation = ActivationType::Gelu,
|
||||||
.post_qk = PostQKType::Rope,
|
.post_qk = PostQKType::HalfRope,
|
||||||
};
|
};
|
||||||
config.layer_configs = {26, layer_config};
|
config.layer_configs = {26, layer_config};
|
||||||
for (size_t i = 2; i < config.layer_configs.size(); i += 3) {
|
for (size_t i = 2; i < config.layer_configs.size(); i += 3) {
|
||||||
|
|
|
||||||
|
|
@ -397,7 +397,11 @@ void AssertMatch(const ModelConfig& config) {
|
||||||
ASSERT_EQ(TConfig::kPostNorm, config.layer_configs[i].post_norm);
|
ASSERT_EQ(TConfig::kPostNorm, config.layer_configs[i].post_norm);
|
||||||
ASSERT_EQ(TConfig::kLayerConfig[i], config.layer_configs[i].type);
|
ASSERT_EQ(TConfig::kLayerConfig[i], config.layer_configs[i].type);
|
||||||
ASSERT_EQ(TConfig::kActivation, config.layer_configs[i].activation);
|
ASSERT_EQ(TConfig::kActivation, config.layer_configs[i].activation);
|
||||||
ASSERT_EQ(TConfig::kPostQK, config.layer_configs[i].post_qk);
|
PostQKType post_qk = TConfig::kPostQK;
|
||||||
|
if (TConfig::kUseHalfRope) {
|
||||||
|
post_qk = PostQKType::HalfRope;
|
||||||
|
}
|
||||||
|
ASSERT_EQ(post_qk, config.layer_configs[i].post_qk);
|
||||||
}
|
}
|
||||||
|
|
||||||
ASSERT_EQ(TConfig::kAttentionWindowSizes.size(),
|
ASSERT_EQ(TConfig::kAttentionWindowSizes.size(),
|
||||||
|
|
|
||||||
|
|
@ -1240,8 +1240,12 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations,
|
||||||
const QueriesPos& queries_prefix_end,
|
const QueriesPos& queries_prefix_end,
|
||||||
const size_t query_idx_start, const KVCaches& kv_caches,
|
const size_t query_idx_start, const KVCaches& kv_caches,
|
||||||
TimingInfo& timing_info) {
|
TimingInfo& timing_info) {
|
||||||
const size_t vocab_size = model.Config().vocab_size;
|
// Griffin assumes that the recurrent block cache is zero-initialized.
|
||||||
const ModelWeightsPtrs<T>& weights = *model.GetWeightsOfType<T>();
|
for (size_t i = 0; i < kv_caches.size(); ++i) {
|
||||||
|
if (queries_pos_in[i] == 0) {
|
||||||
|
kv_caches[i].ZeroGriffinCache(); // No-op for non-Griffin models.
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Copy so we can increment without requiring users to pass in a mutable span.
|
// Copy so we can increment without requiring users to pass in a mutable span.
|
||||||
std::vector<size_t> queries_pos_copy(queries_pos_in.cbegin(),
|
std::vector<size_t> queries_pos_copy(queries_pos_in.cbegin(),
|
||||||
|
|
@ -1268,7 +1272,7 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations,
|
||||||
HWY_ASSERT(queries_pos_in.size() == num_queries);
|
HWY_ASSERT(queries_pos_in.size() == num_queries);
|
||||||
HWY_ASSERT(kv_caches.size() == num_queries);
|
HWY_ASSERT(kv_caches.size() == num_queries);
|
||||||
const hwy::Divisor div_seq_len(static_cast<uint32_t>(kv_caches[0].seq_len));
|
const hwy::Divisor div_seq_len(static_cast<uint32_t>(kv_caches[0].seq_len));
|
||||||
|
const ModelWeightsPtrs<T>& weights = *model.GetWeightsOfType<T>();
|
||||||
size_t max_prompt_size = MaxQueryLength(queries_prompt);
|
size_t max_prompt_size = MaxQueryLength(queries_prompt);
|
||||||
size_t max_generated_tokens = runtime_config.max_generated_tokens;
|
size_t max_generated_tokens = runtime_config.max_generated_tokens;
|
||||||
RangeChecks(weights.weights_config, max_generated_tokens, max_prompt_size);
|
RangeChecks(weights.weights_config, max_generated_tokens, max_prompt_size);
|
||||||
|
|
@ -1314,6 +1318,7 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations,
|
||||||
0.0f);
|
0.0f);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const size_t vocab_size = model.Config().vocab_size;
|
||||||
const double gen_start = hwy::platform::Now();
|
const double gen_start = hwy::platform::Now();
|
||||||
for (size_t gen = 0; gen < max_generated_tokens; ++gen) {
|
for (size_t gen = 0; gen < max_generated_tokens; ++gen) {
|
||||||
// Decode generates one token per query and increments queries_mutable_pos.
|
// Decode generates one token per query and increments queries_mutable_pos.
|
||||||
|
|
|
||||||
|
|
@ -23,6 +23,17 @@
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
||||||
|
void KVCache::ZeroGriffinCache() {
|
||||||
|
if (conv1d_cache_size != 0) {
|
||||||
|
hwy::ZeroBytes(conv1d_cache.get(),
|
||||||
|
conv1d_cache_size * sizeof(conv1d_cache[0]));
|
||||||
|
}
|
||||||
|
if (rglru_cache_size != 0) {
|
||||||
|
hwy::ZeroBytes(rglru_cache.get(),
|
||||||
|
rglru_cache_size * sizeof(rglru_cache[0]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// prefill_tbatch_size is the maximum number of tokens from one query to
|
// prefill_tbatch_size is the maximum number of tokens from one query to
|
||||||
// prefill at a time.
|
// prefill at a time.
|
||||||
KVCache KVCache::Create(const ModelConfig& weights_config,
|
KVCache KVCache::Create(const ModelConfig& weights_config,
|
||||||
|
|
@ -37,9 +48,9 @@ KVCache KVCache::Create(const ModelConfig& weights_config,
|
||||||
kv_cache.kv_cache =
|
kv_cache.kv_cache =
|
||||||
hwy::AllocateAligned<float>(kv_cache.seq_len * size_cache_pos);
|
hwy::AllocateAligned<float>(kv_cache.seq_len * size_cache_pos);
|
||||||
}
|
}
|
||||||
size_t num_griffin_layers = weights_config.NumLayersOfType(
|
|
||||||
LayerAttentionType::kGriffinRecurrentBlock);
|
|
||||||
|
|
||||||
|
const size_t num_griffin_layers = weights_config.NumLayersOfType(
|
||||||
|
LayerAttentionType::kGriffinRecurrentBlock);
|
||||||
// TODO(patrickms): Add query batching support for Griffin.
|
// TODO(patrickms): Add query batching support for Griffin.
|
||||||
if (num_griffin_layers > 0) {
|
if (num_griffin_layers > 0) {
|
||||||
size_t conv1d_width = 0;
|
size_t conv1d_width = 0;
|
||||||
|
|
@ -49,20 +60,18 @@ KVCache KVCache::Create(const ModelConfig& weights_config,
|
||||||
const size_t conv1d_cache_size =
|
const size_t conv1d_cache_size =
|
||||||
num_griffin_layers * (conv1d_width == 0 ? 0 : conv1d_width - 1) *
|
num_griffin_layers * (conv1d_width == 0 ? 0 : conv1d_width - 1) *
|
||||||
weights_config.model_dim;
|
weights_config.model_dim;
|
||||||
|
kv_cache.conv1d_cache_size = conv1d_cache_size;
|
||||||
if (conv1d_cache_size != 0) {
|
if (conv1d_cache_size != 0) {
|
||||||
kv_cache.conv1d_cache = hwy::AllocateAligned<float>(conv1d_cache_size);
|
kv_cache.conv1d_cache = hwy::AllocateAligned<float>(conv1d_cache_size);
|
||||||
hwy::ZeroBytes(kv_cache.conv1d_cache.get(),
|
|
||||||
conv1d_cache_size * sizeof(kv_cache.conv1d_cache[0]));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const size_t rglru_cache_size =
|
const size_t rglru_cache_size =
|
||||||
num_griffin_layers * weights_config.model_dim;
|
num_griffin_layers * weights_config.model_dim;
|
||||||
|
kv_cache.rglru_cache_size = rglru_cache_size;
|
||||||
if (rglru_cache_size != 0) {
|
if (rglru_cache_size != 0) {
|
||||||
kv_cache.rglru_cache = hwy::AllocateAligned<float>(rglru_cache_size);
|
kv_cache.rglru_cache = hwy::AllocateAligned<float>(rglru_cache_size);
|
||||||
hwy::ZeroBytes(kv_cache.rglru_cache.get(),
|
|
||||||
rglru_cache_size * sizeof(kv_cache.rglru_cache[0]));
|
|
||||||
}
|
}
|
||||||
} // kGriffinLayers
|
} // num_griffin_layers
|
||||||
|
|
||||||
return kv_cache;
|
return kv_cache;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -31,9 +31,15 @@ struct KVCache {
|
||||||
|
|
||||||
// (kConv1dWidth - 1) * kModelDim * kGriffinLayers
|
// (kConv1dWidth - 1) * kModelDim * kGriffinLayers
|
||||||
hwy::AlignedFreeUniquePtr<float[]> conv1d_cache;
|
hwy::AlignedFreeUniquePtr<float[]> conv1d_cache;
|
||||||
|
size_t conv1d_cache_size = 0;
|
||||||
|
|
||||||
// kModelDim * kGriffinLayers
|
// kModelDim * kGriffinLayers
|
||||||
hwy::AlignedFreeUniquePtr<float[]> rglru_cache;
|
hwy::AlignedFreeUniquePtr<float[]> rglru_cache;
|
||||||
|
size_t rglru_cache_size = 0;
|
||||||
|
|
||||||
|
// Zero-initialize the Griffin recurrent block cache, i.e. the conv1d_cache
|
||||||
|
// and rglru_cache.
|
||||||
|
void ZeroGriffinCache();
|
||||||
|
|
||||||
static KVCache Create(const ModelConfig& weights_config,
|
static KVCache Create(const ModelConfig& weights_config,
|
||||||
size_t prefill_tbatch_size);
|
size_t prefill_tbatch_size);
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue