diff --git a/gemma/attention.cc b/gemma/attention.cc
index 84e41b2..7d58980 100644
--- a/gemma/attention.cc
+++ b/gemma/attention.cc
@@ -299,19 +299,21 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
             layer_idx * cache_layer_size + head * qkv_dim * 2;
 
+        HWY_ALIGN float kv_f32[2 * kMaxQKVDim];
+        const hn::ScalableTag<float> df;
+        DecompressAndZeroPad(df, MakeSpan(kv, 2 * qkv_dim), 0, kv_f32,
+                             2 * qkv_dim);
+
         // Apply further processing to K.
         if (layer.key_norm_scale.HasPtr()) {
           CallUpcasted(&layer.key_norm_scale, [&](const auto* weights_t) {
-            RMSNormInplace(weights_t->PackedScale1(), 0, kv, qkv_dim);
+            RMSNormInplace(weights_t->PackedScale1(), 0, kv_f32, qkv_dim);
          });
         }
 
-        HWY_ALIGN float kv_f32[kMaxQKVDim];
-        const hn::ScalableTag<float> df;
-        DecompressAndZeroPad(df, MakeSpan(kv, qkv_dim), 0, kv_f32, qkv_dim);
         PositionalEncodingQK(kv_f32, layer_idx, layer, activations, pos);
 
         CompressPerThread tls;
-        Compress(kv_f32, qkv_dim, tls, MakeSpan(kv, qkv_dim), 0);
+        Compress(kv_f32, 2 * qkv_dim, tls, MakeSpan(kv, 2 * qkv_dim), 0);
       });
 }
diff --git a/gemma/gemma.cc b/gemma/gemma.cc
index a585532..1f21ec0 100644
--- a/gemma/gemma.cc
+++ b/gemma/gemma.cc
@@ -480,7 +480,10 @@ static void GenerateT(const ModelConfig& config,
     // We use a single divisor, so all sequence lengths must be the same.
     HWY_ASSERT(qbatch.KV(qi).SeqLen() == seq_len);
   }
-  HWY_ASSERT(max_prompt_size < seq_len);
+  if (max_prompt_size >= seq_len) {
+    HWY_ABORT("max_prompt_size = %zu, increase --seq_len to at least that.",
+              max_prompt_size);
+  }
   HWY_ASSERT(activations.attention.div_seq_len.GetDivisor() == seq_len);
 
   // Lacks a constructor to bulk-set, hence initialized by Prefill* which have
diff --git a/ops/ops-inl.h b/ops/ops-inl.h
index 6dc1b46..fb6d09a 100644
--- a/ops/ops-inl.h
+++ b/ops/ops-inl.h
@@ -406,7 +406,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddAbsolutePositionalEmbeddings(
 // `inv_timescale[dim_qkv / 2]` is precomputed in AttentionActivations.
 // This overload is called if `post_qk == PostQKType::HalfRope`.
 static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(
-    float* HWY_RESTRICT x, size_t dim_qkv,
+    float* HWY_RESTRICT x, const size_t dim_qkv,
     const float* HWY_RESTRICT inv_timescale, const int pos) {
   PROFILER_ZONE("ops.Rope");
   HWY_DASSERT(dim_qkv % 2 == 0);
@@ -430,13 +430,13 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(
     hn::SinCos(df, vtheta, vsin_theta, vcos_theta);
 
     // Scale input with rotations.
-    VF vx0 = hn::LoadU(df, x + dim);
-    VF vx1 = hn::LoadU(df, x + dim + half_dim_qkv);
-    vx0 = hn::MulSub(vx0, vcos_theta, hn::Mul(vx1, vsin_theta));
-    vx1 = hn::MulAdd(vx0, vsin_theta, hn::Mul(vx1, vcos_theta));
+    const VF vx0 = hn::LoadU(df, x + dim);
+    const VF vx1 = hn::LoadU(df, x + dim + half_dim_qkv);
+    const VF vout0 = hn::MulSub(vx0, vcos_theta, hn::Mul(vx1, vsin_theta));
+    const VF vout1 = hn::MulAdd(vx0, vsin_theta, hn::Mul(vx1, vcos_theta));
 
-    hn::StoreU(vx0, df, x + dim);
-    hn::StoreU(vx1, df, x + dim + half_dim_qkv);
+    hn::StoreU(vout0, df, x + dim);
+    hn::StoreU(vout1, df, x + dim + half_dim_qkv);
   }
 
   // Vectorize computation for remaining dims - same as above, but with LoadN.
@@ -452,19 +452,19 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(
     hn::SinCos(df, vtheta, vsin_theta, vcos_theta);
 
     // Scale input with rotations.
-    VF vx0 = hn::LoadN(df, x + dim, remaining_dims);
-    VF vx1 = hn::LoadN(df, x + dim + half_dim_qkv, remaining_dims);
-    vx0 = hn::MulSub(vx0, vcos_theta, hn::Mul(vx1, vsin_theta));
-    vx1 = hn::MulAdd(vx0, vsin_theta, hn::Mul(vx1, vcos_theta));
+    const VF vx0 = hn::LoadN(df, x + dim, remaining_dims);
+    const VF vx1 = hn::LoadN(df, x + dim + half_dim_qkv, remaining_dims);
+    const VF vout0 = hn::MulSub(vx0, vcos_theta, hn::Mul(vx1, vsin_theta));
+    const VF vout1 = hn::MulAdd(vx0, vsin_theta, hn::Mul(vx1, vcos_theta));
 
-    hn::StoreN(vx0, df, x + dim, remaining_dims);
-    hn::StoreN(vx1, df, x + dim + half_dim_qkv, remaining_dims);
+    hn::StoreN(vout0, df, x + dim, remaining_dims);
+    hn::StoreN(vout1, df, x + dim + half_dim_qkv, remaining_dims);
   }
 }
 
 // `inv_timescale[dim_qkv / 2]` is precomputed in AttentionActivations.
 static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(
-    const float mul, float* HWY_RESTRICT x, size_t dim_qkv,
+    const float mul, float* HWY_RESTRICT x, const size_t dim_qkv,
     const float* HWY_RESTRICT inv_timescale, const int pos) {
   PROFILER_ZONE("ops.RopeAndMulBy");
   HWY_DASSERT(dim_qkv % 2 == 0);
@@ -489,13 +489,13 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(
     hn::SinCos(df, vtheta, vsin_theta, vcos_theta);
 
     // Scale input with rotations and multiply with constant.
-    VF vx0 = hn::Mul(vmul, hn::LoadU(df, x + dim));
-    VF vx1 = hn::Mul(vmul, hn::LoadU(df, x + dim + half_dim_qkv));
-    vx0 = hn::MulSub(vx0, vcos_theta, hn::Mul(vx1, vsin_theta));
-    vx1 = hn::MulAdd(vx0, vsin_theta, hn::Mul(vx1, vcos_theta));
+    const VF vx0 = hn::Mul(vmul, hn::LoadU(df, x + dim));
+    const VF vx1 = hn::Mul(vmul, hn::LoadU(df, x + dim + half_dim_qkv));
+    const VF vout0 = hn::MulSub(vx0, vcos_theta, hn::Mul(vx1, vsin_theta));
+    const VF vout1 = hn::MulAdd(vx0, vsin_theta, hn::Mul(vx1, vcos_theta));
 
-    hn::StoreU(vx0, df, x + dim);
-    hn::StoreU(vx1, df, x + dim + half_dim_qkv);
+    hn::StoreU(vout0, df, x + dim);
+    hn::StoreU(vout1, df, x + dim + half_dim_qkv);
   }
 
   // Vectorize computation for remaining dims - same as above, but with LoadN.
@@ -511,14 +511,14 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(
     hn::SinCos(df, vtheta, vsin_theta, vcos_theta);
 
     // Scale input with rotations and multiply with constant.
-    VF vx0 = hn::Mul(vmul, hn::LoadN(df, x + dim, remaining_dims));
-    VF vx1 =
+    const VF vx0 = hn::Mul(vmul, hn::LoadN(df, x + dim, remaining_dims));
+    const VF vx1 =
         hn::Mul(vmul, hn::LoadN(df, x + dim + half_dim_qkv, remaining_dims));
-    vx0 = hn::MulSub(vx0, vcos_theta, hn::Mul(vx1, vsin_theta));
-    vx1 = hn::MulAdd(vx0, vsin_theta, hn::Mul(vx1, vcos_theta));
+    const VF vout0 = hn::MulSub(vx0, vcos_theta, hn::Mul(vx1, vsin_theta));
+    const VF vout1 = hn::MulAdd(vx0, vsin_theta, hn::Mul(vx1, vcos_theta));
 
-    hn::StoreN(vx0, df, x + dim, remaining_dims);
-    hn::StoreN(vx1, df, x + dim + half_dim_qkv, remaining_dims);
+    hn::StoreN(vout0, df, x + dim, remaining_dims);
+    hn::StoreN(vout1, df, x + dim + half_dim_qkv, remaining_dims);
   }
 }
diff --git a/ops/ops_test.cc b/ops/ops_test.cc
index c424d7a..d2cf821 100644
--- a/ops/ops_test.cc
+++ b/ops/ops_test.cc
@@ -83,48 +83,44 @@ T Random(hwy::RandomState& rng) {
                  HWY_MAX(hwy::ConvertScalarTo<T>(hwy::LowestValue<T>()), val));
 }
 
-HWY_NOINLINE void SourceAddFrom(const float* HWY_RESTRICT other,
+HWY_NOINLINE void SimpleAddFrom(const float* HWY_RESTRICT other,
                                 float* HWY_RESTRICT x, size_t size) {
   for (size_t i = 0; i < size; ++i) {
     x[i] += other[i];
   }
 }
 
-HWY_NOINLINE void SourceMulBy(const float* HWY_RESTRICT other,
-                              float* HWY_RESTRICT x, size_t size,
-                              size_t max_pos) {
-  HWY_DASSERT(max_pos <= size);
-  for (size_t i = 0; i < max_pos; ++i) {
+HWY_NOINLINE void SimpleMulBy(const float* HWY_RESTRICT other,
+                              float* HWY_RESTRICT x, size_t size) {
+  for (size_t i = 0; i < size; ++i) {
     x[i] *= other[i];
   }
 }
 
-HWY_NOINLINE void SourceMulByConst(float c, float* HWY_RESTRICT x, size_t size,
-                                   size_t max_pos) {
-  for (size_t i = 0; i < max_pos; ++i) {
+HWY_NOINLINE void SimpleMulByConst(float c, float* HWY_RESTRICT x,
+                                   size_t size) {
+  for (size_t i = 0; i < size; ++i) {
     x[i] *= c;
   }
 }
 
-HWY_NOINLINE void SourceMulByConstAndAdd(float c, const float* HWY_RESTRICT x,
+HWY_NOINLINE void SimpleMulByConstAndAdd(float c, const float* HWY_RESTRICT x,
                                          float* HWY_RESTRICT out, size_t size) {
   for (size_t i = 0; i < size; ++i) {
     out[i] += x[i] * c;
   }
 }
 
-HWY_NOINLINE void SourceSoftmax(float* HWY_RESTRICT x, size_t size,
-                                size_t mask_pos) {
+HWY_NOINLINE void SimpleSoftmax(float* HWY_RESTRICT x, size_t size) {
   HWY_DASSERT(size != 0);
-  HWY_DASSERT(mask_pos <= size);
   float sum = 0.0;
-  const float maxval = *std::max_element(x, x + mask_pos);
-  for (size_t i = 0; i < mask_pos; ++i) {
+  const float maxval = *std::max_element(x, x + size);
+  for (size_t i = 0; i < size; ++i) {
     x[i] = std::exp(x[i] - maxval);
     sum += x[i];
   }
   const float scale = 1.0f / sum;
-  for (size_t i = 0; i < mask_pos; ++i) {
+  for (size_t i = 0; i < size; ++i) {
     x[i] *= scale;
   }
 }
@@ -169,7 +165,7 @@ struct TestAddFrom {
       o[i] = Random<T>(rng);
     }
 
-    SourceAddFrom(o, e, count);
+    SimpleAddFrom(o, e, count);
     AddFrom(o, x, count);
 
     hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
@@ -177,38 +173,6 @@
   }
 };
 
-struct TestMulBy {
-  template <class D>
-  void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b,
-                  hwy::RandomState& rng) {
-    using T = hn::TFromD<D>;
-
-    hwy::AlignedFreeUniquePtr<T[]> px =
-        hwy::AllocateAligned<T>(HWY_MAX(1, misalign_a + count));
-    hwy::AlignedFreeUniquePtr<T[]> pe =
-        hwy::AllocateAligned<T>(HWY_MAX(1, misalign_a + count));
-    hwy::AlignedFreeUniquePtr<T[]> po =
-        hwy::AllocateAligned<T>(HWY_MAX(1, misalign_b + count));
-    HWY_ASSERT(px && pe && po);
-
-    T* x = px.get() + misalign_a;
-    T* e = pe.get() + misalign_a;
-    T* o = po.get() + misalign_b;
-
-    for (size_t i = 0; i < count; ++i) {
-      x[i] = Random<T>(rng);
-      e[i] = x[i];
-      o[i] = Random<T>(rng);
-    }
-
-    SourceMulBy(o, e, count, count);
-    MulBy(o, x, count, count);
-
-    hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
-                            __LINE__);
-  }
-};
-
 struct TestMulByConstAndAdd {
   template <class D>
   void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b,
@@ -234,7 +198,7 @@
     }
     T constant = Random<T>(rng);
-    SourceMulByConstAndAdd(constant, o, e, count);
+    SimpleMulByConstAndAdd(constant, o, e, count);
     MulByConstAndAdd(constant, o, x, count);
 
     hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
@@ -264,8 +228,8 @@ struct TestMulByConst {
     }
     T constant = Random<T>(rng);
-    SourceMulByConst(constant, e, count, count);
-    MulByConst(constant, x, count, count);
+    SimpleMulByConst(constant, e, count);
+    MulByConst(constant, x, count);
 
     hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
                             __LINE__);
@@ -294,8 +258,8 @@ struct TestSoftmax {
       e[i] = x[i];
     }
 
-    SourceSoftmax(e, count, count);
-    Softmax(x, count, count);
+    SimpleSoftmax(e, count);
+    Softmax(x, count);
 
     T sum = 0.0f;
     for (size_t i = 0; i < count; ++i) {
@@ -331,10 +295,6 @@ void TestAllAddFrom() {
   hn::ForPartialVectors<ForeachCountAndMisalign<TestAddFrom>>()(float());
 }
 
-void TestAllMulBy() {
-  hn::ForPartialVectors<ForeachCountAndMisalign<TestMulBy>>()(float());
-}
-
 void TestAllMulByConst() {
   hn::ForPartialVectors<ForeachCountAndMisalign<TestMulByConst>>()(float());
 }
@@ -371,8 +331,8 @@ void TestSigmoid() {
 }
 
 static HWY_NOINLINE HWY_MAYBE_UNUSED void ScalarRopeAndMulBy(
-    const float mul, float* HWY_RESTRICT x, size_t dim_qkv,
-    const float* HWY_RESTRICT inv_timescale, int pos) {
+    const float mul, float* HWY_RESTRICT x, const size_t dim_qkv,
+    const float* HWY_RESTRICT inv_timescale, const int pos) {
   HWY_DASSERT(dim_qkv % 2 == 0);
   const size_t half_dim_qkv = dim_qkv / 2;
   for (size_t dim = 0; dim < half_dim_qkv; ++dim) {
@@ -387,9 +347,9 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void ScalarRopeAndMulBy(
 }
 
 void TestRopeAndMulBy() {
-  ModelConfig config(Model::GEMMA2_9B, Type::kSFP,
-                     ChooseWrapping(Model::GEMMA2_9B));
-  int dim_qkv = config.layer_configs[0].qkv_dim;
+  const ModelConfig config(Model::GEMMA2_9B, Type::kSFP,
+                           ChooseWrapping(Model::GEMMA2_9B));
+  const size_t dim_qkv = config.layer_configs[0].qkv_dim;
   MatStorageT<float> x("x", dim_qkv);
 
   std::mt19937 gen;
@@ -397,44 +357,58 @@ void TestRopeAndMulBy() {
   std::normal_distribution<float> r{0.0, 5.0};
   auto random_float = [&r, &gen] { return r(gen); };
 
-  for (int i = 0; i < dim_qkv; ++i) {
+  for (size_t i = 0; i < dim_qkv; ++i) {
     x.Row(0)[i] = random_float();
   }
 
   const float qmul = AttentionActivations::ChooseQueryScale(config);
-  const float kmul = 1.0;
+  constexpr float kmul = 1.0f;
 
   MatStorageT<float> qexpected("qexpected", dim_qkv);
   MatStorageT<float> qactual("qactual", dim_qkv);
   MatStorageT<float> kexpected("kexpected", dim_qkv);
   MatStorageT<float> kactual("kactual", dim_qkv);
+  MatStorageT<float> kactual2("kactual2", dim_qkv);
   MatStorageT<float> inv_timescale = CreateInvTimescale(
       config.layer_configs[0].qkv_dim,
      config.layer_configs[0].post_qk == PostQKType::HalfRope);
 
   // Assert VectorizedRope computation is same as regular rope at different
   // pos.
-  for (int pos = 1; pos < 500; pos++) {
-    // Rope'd Q embeddings
-    CopyMat(x, qactual);
+  for (size_t pos = 1; pos < 500; pos++) {
+    // Rope'd Q embeddings with query scale
     CopyMat(x, qexpected);
+    CopyMat(x, qactual);
     ScalarRopeAndMulBy(qmul, qexpected.Row(0), dim_qkv, inv_timescale.Row(0),
                        pos);
     RopeAndMulBy(qmul, qactual.Row(0), dim_qkv, inv_timescale.Row(0), pos);
+    for (size_t i = 0; i < dim_qkv; ++i) {
+      EXPECT_NEAR(qexpected.Row(0)[i], qactual.Row(0)[i], 1e-4) << " " << i;
+    }
 
-    for (int i = 0; i < dim_qkv; ++i) {
-      EXPECT_NEAR(qactual.Row(0)[i], qexpected.Row(0)[i], 1e-4)
-          << "qIndex:" << i << "qInput:" << qactual.Row(0)[i];
+    // Same without query scale
+    CopyMat(x, qexpected);
+    CopyMat(x, qactual);
+    ScalarRopeAndMulBy(1.0f, qexpected.Row(0), dim_qkv, inv_timescale.Row(0),
+                       pos);
+    Rope(qactual.Row(0), dim_qkv, inv_timescale.Row(0), pos);
+    for (size_t i = 0; i < dim_qkv; ++i) {
+      EXPECT_NEAR(qexpected.Row(0)[i], qactual.Row(0)[i], 1e-4) << " " << i;
     }
 
     // Rope'd K embeddings
-    CopyMat(x, kactual);
     CopyMat(x, kexpected);
+    CopyMat(x, kactual);
+    CopyMat(x, kactual2);
     ScalarRopeAndMulBy(kmul, kexpected.Row(0), dim_qkv, inv_timescale.Row(0),
                        pos);
     RopeAndMulBy(kmul, kactual.Row(0), dim_qkv, inv_timescale.Row(0), pos);
+    static_assert(kmul == 1.0f, "");
+    Rope(kactual2.Row(0), dim_qkv, inv_timescale.Row(0), pos);
 
-    for (int i = 0; i < dim_qkv; ++i) {
-      EXPECT_NEAR(kactual.Row(0)[i], kexpected.Row(0)[i], 1e-4)
-          << "kIndex:" << i << "kInput:" << kactual.Row(0)[i];
+    for (size_t i = 0; i < dim_qkv; ++i) {
+      EXPECT_NEAR(kexpected.Row(0)[i], kactual.Row(0)[i], 1e-4) << " " << i;
+    }
+    for (size_t i = 0; i < dim_qkv; ++i) {
+      EXPECT_NEAR(kexpected.Row(0)[i], kactual2.Row(0)[i], 1e-4) << " " << i;
     }
   }
 }
@@ -662,7 +636,6 @@ HWY_AFTER_NAMESPACE();
 namespace gcpp {
 HWY_BEFORE_TEST(OpsTest);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllAddFrom);
-HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulBy);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConst);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConstAndAdd);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSoftmax);
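Reviewer note, not part of the patch: the ops-inl.h change is more than a const/rename cleanup. The removed lines overwrote vx0 with the rotated value and then reused it in the multiply-add that produces the second half, so the stored pair was no longer a pure rotation of the original (x0, x1); the new vout0/vout1 keep the inputs intact. A minimal scalar sketch of the difference, with hypothetical helper names and plain floats standing in for the Highway vectors:

#include <cmath>
#include <cstdio>

// In-place update, mirroring the removed lines: the second statement
// reads the already-rotated x0.
void RotatePairInPlace(float& x0, float& x1, float theta) {
  x0 = x0 * std::cos(theta) - x1 * std::sin(theta);
  x1 = x0 * std::sin(theta) + x1 * std::cos(theta);  // uses updated x0!
}

// Separate outputs, mirroring the new vout0/vout1: a true 2-D rotation.
void RotatePair(float& x0, float& x1, float theta) {
  const float out0 = x0 * std::cos(theta) - x1 * std::sin(theta);
  const float out1 = x0 * std::sin(theta) + x1 * std::cos(theta);
  x0 = out0;
  x1 = out1;
}

int main() {
  float a0 = 1.0f, a1 = 0.0f;
  float b0 = 1.0f, b1 = 0.0f;
  RotatePairInPlace(a0, a1, 0.5f);
  RotatePair(b0, b1, 0.5f);
  // A rotation preserves the vector norm; the in-place variant does not.
  std::printf("in-place: (%f, %f), norm %f\n", a0, a1,
              std::sqrt(a0 * a0 + a1 * a1));
  std::printf("separate: (%f, %f), norm %f\n", b0, b1,
              std::sqrt(b0 * b0 + b1 * b1));
  return 0;
}

This is also what the extended TestRopeAndMulBy guards: it now cross-checks Rope (and RopeAndMulBy with kmul == 1) against ScalarRopeAndMulBy at many positions, so a regression back to the in-place form would fail the EXPECT_NEAR comparisons.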