mirror of https://github.com/google/gemma.cpp.git
Minor followup: remainder handling is a single iteration
Also add profiler annotations.
PiperOrigin-RevId: 667883774
This commit is contained in:
parent
c4303cd89b
commit
b6d0ca8a14
|
|
@ -161,6 +161,7 @@ static HWY_INLINE hn::Vec<D> Sigmoid(D d, hn::Vec<D> v) {
|
||||||
// Sigmoid using the logistic function 1 / (1 + exp(-x[i]))
|
// Sigmoid using the logistic function 1 / (1 + exp(-x[i]))
|
||||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void Sigmoid(float* HWY_RESTRICT x,
|
static HWY_NOINLINE HWY_MAYBE_UNUSED void Sigmoid(float* HWY_RESTRICT x,
|
||||||
size_t size) {
|
size_t size) {
|
||||||
|
PROFILER_ZONE("ops.Sigmoid");
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
using D = hn::ScalableTag<float>;
|
using D = hn::ScalableTag<float>;
|
||||||
hn::Transform(D(), x, size,
|
hn::Transform(D(), x, size,
|
||||||
|
|
@ -170,6 +171,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void Sigmoid(float* HWY_RESTRICT x,
|
||||||
static HWY_NOINLINE HWY_MAYBE_UNUSED float Dot(const float* HWY_RESTRICT a,
|
static HWY_NOINLINE HWY_MAYBE_UNUSED float Dot(const float* HWY_RESTRICT a,
|
||||||
const float* HWY_RESTRICT b,
|
const float* HWY_RESTRICT b,
|
||||||
size_t size) {
|
size_t size) {
|
||||||
|
PROFILER_ZONE("ops.Dot");
|
||||||
const hn::ScalableTag<float> d;
|
const hn::ScalableTag<float> d;
|
||||||
HWY_DASSERT(size >= hn::Lanes(d));
|
HWY_DASSERT(size >= hn::Lanes(d));
|
||||||
HWY_DASSERT(size % hn::Lanes(d) == 0);
|
HWY_DASSERT(size % hn::Lanes(d) == 0);
|
||||||
|
|
@ -181,6 +183,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED float Dot(const float* HWY_RESTRICT a,
|
||||||
// = Dot(a, a, size), but that is not allowed due to HWY_RESTRICT.
|
// = Dot(a, a, size), but that is not allowed due to HWY_RESTRICT.
|
||||||
static HWY_NOINLINE HWY_MAYBE_UNUSED float SquaredL2(
|
static HWY_NOINLINE HWY_MAYBE_UNUSED float SquaredL2(
|
||||||
const float* HWY_RESTRICT a, size_t size) {
|
const float* HWY_RESTRICT a, size_t size) {
|
||||||
|
PROFILER_ZONE("ops.SquaredL2");
|
||||||
const hn::ScalableTag<float> d;
|
const hn::ScalableTag<float> d;
|
||||||
using V = hn::Vec<decltype(d)>;
|
using V = hn::Vec<decltype(d)>;
|
||||||
const size_t N = hn::Lanes(d);
|
const size_t N = hn::Lanes(d);
|
||||||
|
|
@ -203,6 +206,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED float SquaredL2(
|
||||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
|
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
|
||||||
const float* HWY_RESTRICT x, const float* HWY_RESTRICT weight,
|
const float* HWY_RESTRICT x, const float* HWY_RESTRICT weight,
|
||||||
float* HWY_RESTRICT out, size_t size) {
|
float* HWY_RESTRICT out, size_t size) {
|
||||||
|
PROFILER_ZONE("ops.RMSNormF");
|
||||||
constexpr float kEps = 1e-6f;
|
constexpr float kEps = 1e-6f;
|
||||||
float ss = SquaredL2(x, size);
|
float ss = SquaredL2(x, size);
|
||||||
ss = 1.0f / sqrtf(ss / StaticCast<float>(size) + kEps);
|
ss = 1.0f / sqrtf(ss / StaticCast<float>(size) + kEps);
|
||||||
|
|
@ -216,6 +220,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
|
||||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
|
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
|
||||||
const float* HWY_RESTRICT x, const hwy::bfloat16_t* HWY_RESTRICT weight,
|
const float* HWY_RESTRICT x, const hwy::bfloat16_t* HWY_RESTRICT weight,
|
||||||
float* HWY_RESTRICT out, size_t size) {
|
float* HWY_RESTRICT out, size_t size) {
|
||||||
|
PROFILER_ZONE("ops.RMSNormBF16");
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
|
|
||||||
constexpr float kEps = 1e-6f;
|
constexpr float kEps = 1e-6f;
|
||||||
|
|
@ -246,6 +251,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
|
||||||
// float -> float; simple loop.
|
// float -> float; simple loop.
|
||||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(
|
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(
|
||||||
const float* HWY_RESTRICT weight, float* HWY_RESTRICT inout, size_t size) {
|
const float* HWY_RESTRICT weight, float* HWY_RESTRICT inout, size_t size) {
|
||||||
|
PROFILER_ZONE("ops.RMSNormInplaceF");
|
||||||
constexpr float kEps = 1e-6f;
|
constexpr float kEps = 1e-6f;
|
||||||
float ss = SquaredL2(inout, size);
|
float ss = SquaredL2(inout, size);
|
||||||
ss = 1.0f / sqrtf(ss / StaticCast<float>(size) + kEps);
|
ss = 1.0f / sqrtf(ss / StaticCast<float>(size) + kEps);
|
||||||
|
|
@ -259,6 +265,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(
|
||||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(
|
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(
|
||||||
const hwy::bfloat16_t* HWY_RESTRICT weight, float* HWY_RESTRICT inout,
|
const hwy::bfloat16_t* HWY_RESTRICT weight, float* HWY_RESTRICT inout,
|
||||||
const size_t size) {
|
const size_t size) {
|
||||||
|
PROFILER_ZONE("ops.RMSNormInplaceBF");
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
const hn::ScalableTag<hwy::bfloat16_t> dbf;
|
const hn::ScalableTag<hwy::bfloat16_t> dbf;
|
||||||
const hn::Repartition<float, decltype(dbf)> df32;
|
const hn::Repartition<float, decltype(dbf)> df32;
|
||||||
|
|
@ -288,6 +295,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(
|
||||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
|
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
|
||||||
const float* HWY_RESTRICT x, const float* HWY_RESTRICT weight,
|
const float* HWY_RESTRICT x, const float* HWY_RESTRICT weight,
|
||||||
hwy::bfloat16_t* HWY_RESTRICT out, const size_t size) {
|
hwy::bfloat16_t* HWY_RESTRICT out, const size_t size) {
|
||||||
|
PROFILER_ZONE("ops.RMSNormF F BF");
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
const hn::ScalableTag<hwy::bfloat16_t> dbf;
|
const hn::ScalableTag<hwy::bfloat16_t> dbf;
|
||||||
const hn::Repartition<float, decltype(dbf)> df32;
|
const hn::Repartition<float, decltype(dbf)> df32;
|
||||||
|
|
@ -316,6 +324,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
|
||||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
|
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
|
||||||
const float* HWY_RESTRICT x, const hwy::bfloat16_t* HWY_RESTRICT weight,
|
const float* HWY_RESTRICT x, const hwy::bfloat16_t* HWY_RESTRICT weight,
|
||||||
hwy::bfloat16_t* HWY_RESTRICT out, const size_t size) {
|
hwy::bfloat16_t* HWY_RESTRICT out, const size_t size) {
|
||||||
|
PROFILER_ZONE("ops.RMSNormF BF BF");
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
const hn::ScalableTag<hwy::bfloat16_t> dbf;
|
const hn::ScalableTag<hwy::bfloat16_t> dbf;
|
||||||
const hn::Repartition<float, decltype(dbf)> df32;
|
const hn::Repartition<float, decltype(dbf)> df32;
|
||||||
|
|
@ -343,6 +352,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
|
||||||
|
|
||||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void AddAbsolutePositionalEmbeddings(
|
static HWY_NOINLINE HWY_MAYBE_UNUSED void AddAbsolutePositionalEmbeddings(
|
||||||
float* HWY_RESTRICT x, size_t dim_model, size_t pos) {
|
float* HWY_RESTRICT x, size_t dim_model, size_t pos) {
|
||||||
|
PROFILER_ZONE("ops.AddAbsolutePositionalEmbeddings");
|
||||||
const size_t num_timescales = dim_model / 2;
|
const size_t num_timescales = dim_model / 2;
|
||||||
const float log_timescale_increment =
|
const float log_timescale_increment =
|
||||||
logf(10000.0f) /
|
logf(10000.0f) /
|
||||||
|
|
@ -433,8 +443,9 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void VectorizedRopeAndMulBy(
|
||||||
const D d;
|
const D d;
|
||||||
|
|
||||||
// Vectorize computation for half_dim_qkv - (half_dim_qkv % Lanes)
|
// Vectorize computation for half_dim_qkv - (half_dim_qkv % Lanes)
|
||||||
size_t vectorizable_dims = hwy::RoundDownTo(half_dim_qkv, hn::Lanes(d));
|
const size_t vectorizable_dims = hwy::RoundDownTo(half_dim_qkv, hn::Lanes(d));
|
||||||
for (size_t dim = 0; dim < vectorizable_dims; dim += hn::Lanes(d)) {
|
size_t dim = 0;
|
||||||
|
for (; dim < vectorizable_dims; dim += hn::Lanes(d)) {
|
||||||
// Compute thetas
|
// Compute thetas
|
||||||
V pos_vec = hn::Set(d, pos);
|
V pos_vec = hn::Set(d, pos);
|
||||||
V inv_time_scale_vec = hn::LoadU(d, inv_timescale + dim);
|
V inv_time_scale_vec = hn::LoadU(d, inv_timescale + dim);
|
||||||
|
|
@ -460,10 +471,10 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void VectorizedRopeAndMulBy(
|
||||||
hn::StoreU(xout_1_vec, d, x_out + dim + half_dim_qkv);
|
hn::StoreU(xout_1_vec, d, x_out + dim + half_dim_qkv);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Vectorize computation for remaining dims.
|
// Vectorize computation for remaining dims - same as above, but with LoadN.
|
||||||
size_t remaining_dims = half_dim_qkv - vectorizable_dims;
|
const size_t remaining_dims = half_dim_qkv - dim;
|
||||||
for (size_t dim = vectorizable_dims; dim < half_dim_qkv;
|
HWY_DASSERT(remaining_dims < hn::Lanes(d)); // at most one iteration
|
||||||
dim += hn::Lanes(d)) {
|
if (remaining_dims != 0) {
|
||||||
// Compute thetas
|
// Compute thetas
|
||||||
V pos_vec = hn::Set(d, pos);
|
V pos_vec = hn::Set(d, pos);
|
||||||
V inv_time_scale_vec = hn::LoadN(d, inv_timescale + dim, remaining_dims);
|
V inv_time_scale_vec = hn::LoadN(d, inv_timescale + dim, remaining_dims);
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue