Minor follow-up: remainder handling is now a single iteration

Also add profiler annotations.

PiperOrigin-RevId: 667883774
This commit is contained in:
Jan Wassenberg 2024-08-27 01:19:08 -07:00 committed by Copybara-Service
parent c4303cd89b
commit b6d0ca8a14
1 changed file with 17 additions and 6 deletions

View File

@ -161,6 +161,7 @@ static HWY_INLINE hn::Vec<D> Sigmoid(D d, hn::Vec<D> v) {
// Sigmoid using the logistic function 1 / (1 + exp(-x[i]))
static HWY_NOINLINE HWY_MAYBE_UNUSED void Sigmoid(float* HWY_RESTRICT x,
size_t size) {
PROFILER_ZONE("ops.Sigmoid");
namespace hn = hwy::HWY_NAMESPACE;
using D = hn::ScalableTag<float>;
hn::Transform(D(), x, size,
@ -170,6 +171,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void Sigmoid(float* HWY_RESTRICT x,
static HWY_NOINLINE HWY_MAYBE_UNUSED float Dot(const float* HWY_RESTRICT a,
const float* HWY_RESTRICT b,
size_t size) {
PROFILER_ZONE("ops.Dot");
const hn::ScalableTag<float> d;
HWY_DASSERT(size >= hn::Lanes(d));
HWY_DASSERT(size % hn::Lanes(d) == 0);
@ -181,6 +183,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED float Dot(const float* HWY_RESTRICT a,
// = Dot(a, a, size), but that is not allowed due to HWY_RESTRICT.
static HWY_NOINLINE HWY_MAYBE_UNUSED float SquaredL2(
const float* HWY_RESTRICT a, size_t size) {
PROFILER_ZONE("ops.SquaredL2");
const hn::ScalableTag<float> d;
using V = hn::Vec<decltype(d)>;
const size_t N = hn::Lanes(d);
@ -203,6 +206,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED float SquaredL2(
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
const float* HWY_RESTRICT x, const float* HWY_RESTRICT weight,
float* HWY_RESTRICT out, size_t size) {
PROFILER_ZONE("ops.RMSNormF");
constexpr float kEps = 1e-6f;
float ss = SquaredL2(x, size);
ss = 1.0f / sqrtf(ss / StaticCast<float>(size) + kEps);
@ -216,6 +220,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
const float* HWY_RESTRICT x, const hwy::bfloat16_t* HWY_RESTRICT weight,
float* HWY_RESTRICT out, size_t size) {
PROFILER_ZONE("ops.RMSNormBF16");
namespace hn = hwy::HWY_NAMESPACE;
constexpr float kEps = 1e-6f;
@ -246,6 +251,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
// float -> float; simple loop.
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(
const float* HWY_RESTRICT weight, float* HWY_RESTRICT inout, size_t size) {
PROFILER_ZONE("ops.RMSNormInplaceF");
constexpr float kEps = 1e-6f;
float ss = SquaredL2(inout, size);
ss = 1.0f / sqrtf(ss / StaticCast<float>(size) + kEps);
@ -259,6 +265,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(
const hwy::bfloat16_t* HWY_RESTRICT weight, float* HWY_RESTRICT inout,
const size_t size) {
PROFILER_ZONE("ops.RMSNormInplaceBF");
namespace hn = hwy::HWY_NAMESPACE;
const hn::ScalableTag<hwy::bfloat16_t> dbf;
const hn::Repartition<float, decltype(dbf)> df32;
@ -288,6 +295,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
const float* HWY_RESTRICT x, const float* HWY_RESTRICT weight,
hwy::bfloat16_t* HWY_RESTRICT out, const size_t size) {
PROFILER_ZONE("ops.RMSNormF F BF");
namespace hn = hwy::HWY_NAMESPACE;
const hn::ScalableTag<hwy::bfloat16_t> dbf;
const hn::Repartition<float, decltype(dbf)> df32;
@ -316,6 +324,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
const float* HWY_RESTRICT x, const hwy::bfloat16_t* HWY_RESTRICT weight,
hwy::bfloat16_t* HWY_RESTRICT out, const size_t size) {
PROFILER_ZONE("ops.RMSNormF BF BF");
namespace hn = hwy::HWY_NAMESPACE;
const hn::ScalableTag<hwy::bfloat16_t> dbf;
const hn::Repartition<float, decltype(dbf)> df32;
@ -343,6 +352,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
static HWY_NOINLINE HWY_MAYBE_UNUSED void AddAbsolutePositionalEmbeddings(
float* HWY_RESTRICT x, size_t dim_model, size_t pos) {
PROFILER_ZONE("ops.AddAbsolutePositionalEmbeddings");
const size_t num_timescales = dim_model / 2;
const float log_timescale_increment =
logf(10000.0f) /
@ -433,8 +443,9 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void VectorizedRopeAndMulBy(
const D d;
// Vectorize computation for half_dim_qkv - (half_dim_qkv % Lanes)
size_t vectorizable_dims = hwy::RoundDownTo(half_dim_qkv, hn::Lanes(d));
for (size_t dim = 0; dim < vectorizable_dims; dim += hn::Lanes(d)) {
const size_t vectorizable_dims = hwy::RoundDownTo(half_dim_qkv, hn::Lanes(d));
size_t dim = 0;
for (; dim < vectorizable_dims; dim += hn::Lanes(d)) {
// Compute thetas
V pos_vec = hn::Set(d, pos);
V inv_time_scale_vec = hn::LoadU(d, inv_timescale + dim);
@ -460,10 +471,10 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void VectorizedRopeAndMulBy(
hn::StoreU(xout_1_vec, d, x_out + dim + half_dim_qkv);
}
// Vectorize computation for remaining dims.
size_t remaining_dims = half_dim_qkv - vectorizable_dims;
for (size_t dim = vectorizable_dims; dim < half_dim_qkv;
dim += hn::Lanes(d)) {
// Vectorize computation for remaining dims - same as above, but with LoadN.
const size_t remaining_dims = half_dim_qkv - dim;
HWY_DASSERT(remaining_dims < hn::Lanes(d)); // at most one iteration
if (remaining_dims != 0) {
// Compute thetas
V pos_vec = hn::Set(d, pos);
V inv_time_scale_vec = hn::LoadN(d, inv_timescale + dim, remaining_dims);