From 48d0801fb0faac588ec1c1a3702f359ce162d599 Mon Sep 17 00:00:00 2001 From: Apoorv Reddy Date: Wed, 21 Aug 2024 02:21:47 -0700 Subject: [PATCH] Vectorize Rope for qkv dim not evenly divisible by number of lanes. PiperOrigin-RevId: 665776602 --- backprop/optimize_test.cc | 12 +++++++----- ops/ops-inl.h | 35 +++++++++++++++++++++++++++++++++-- 2 files changed, 40 insertions(+), 7 deletions(-) diff --git a/backprop/optimize_test.cc b/backprop/optimize_test.cc index 2c196cf..cc67d3b 100644 --- a/backprop/optimize_test.cc +++ b/backprop/optimize_test.cc @@ -15,6 +15,7 @@ #include +#include #include #include @@ -81,15 +82,16 @@ TEST(OptimizeTest, GradientDescent) { return reply; }; + // Sanity check of reply tokens. + // 1) Its length should be greater than the prompt. + // 2) The prompt should be a prefix of the reply. auto verify = [&](const Prompt& prompt) { auto context = prompt.context(); std::vector reply = generate(context); bool ok = true; - for (size_t i = 0; ok && i < prompt.tokens.size(); ++i) { - if (i >= reply.size() || reply[i] != prompt.tokens[i]) { - ok = false; - } - } + ok &= (reply.size() > context.size()); + ok &= std::equal(prompt.tokens.begin(), prompt.tokens.end(), + reply.begin(), reply.begin() + prompt.tokens.size()); return ok; }; diff --git a/ops/ops-inl.h b/ops/ops-inl.h index 2a1da98..d471889 100644 --- a/ops/ops-inl.h +++ b/ops/ops-inl.h @@ -432,8 +432,9 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void VectorizedRopeAndMulBy( using V = hn::Vec; const D d; - HWY_DASSERT(half_dim_qkv % hn::Lanes(d) == 0); - for (size_t dim = 0; dim < half_dim_qkv; dim += hn::Lanes(d)) { + // Vectorize computation for half_dim_qkv - (half_dim_qkv % Lanes) + size_t vectorizable_dims = hwy::RoundDownTo(half_dim_qkv, hn::Lanes(d)); + for (size_t dim = 0; dim < vectorizable_dims; dim += hn::Lanes(d)) { // Compute thetas V pos_vec = hn::Set(d, pos); V inv_time_scale_vec = hn::LoadU(d, inv_timescale + dim); @@ -458,6 +459,36 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void VectorizedRopeAndMulBy( hn::StoreU(xout_0_vec, d, x_out + dim); hn::StoreU(xout_1_vec, d, x_out + dim + half_dim_qkv); } + + // Vectorize computation for remaining dims. + size_t remaining_dims = half_dim_qkv - vectorizable_dims; + for (size_t dim = vectorizable_dims; dim < half_dim_qkv; + dim += hn::Lanes(d)) { + // Compute thetas + V pos_vec = hn::Set(d, pos); + V inv_time_scale_vec = hn::LoadN(d, inv_timescale + dim, remaining_dims); + V theta_vec = hn::Mul(pos_vec, inv_time_scale_vec); + + // Compute rotations. + V cos_theta_vec; + V sin_theta_vec; + hn::SinCos(d, theta_vec, sin_theta_vec, cos_theta_vec); + + // Scale input with rotations and multiply with constant. + V mul_vec = hn::Set(d, mul); + V x0_vec = hn::Mul(mul_vec, hn::LoadN(d, x + dim, remaining_dims)); + V x1_vec = + hn::Mul(mul_vec, hn::LoadN(d, x + dim + half_dim_qkv, remaining_dims)); + + V xout_0_vec = + hn::MulSub(x0_vec, cos_theta_vec, hn::Mul(x1_vec, sin_theta_vec)); + V xout_1_vec = + hn::MulAdd(x0_vec, sin_theta_vec, hn::Mul(x1_vec, cos_theta_vec)); + + // Store + hn::StoreN(xout_0_vec, d, x_out + dim, remaining_dims); + hn::StoreN(xout_1_vec, d, x_out + dim + half_dim_qkv, remaining_dims); + } } static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom(