add StaticCast

This commit is contained in:
enum-class 2024-02-29 21:00:54 +08:00
parent 06dd013397
commit 626be6deab
1 changed files with 32 additions and 21 deletions

53
ops.h
View File

@ -57,6 +57,16 @@ HWY_INLINE constexpr size_t MaxCols() {
return 2048; return 2048;
} }
template <typename To, typename From>
HWY_INLINE constexpr std::enable_if_t<
std::is_arithmetic_v<To> && std::is_arithmetic_v<From>, To>
StaticCast(From from) noexcept {
if constexpr (std::is_unsigned_v<From> && std::is_floating_point_v<To>)
return static_cast<To>(static_cast<int64_t>(from));
else
return static_cast<To>(from);
}
template <size_t kOuter> template <size_t kOuter>
HWY_INLINE constexpr size_t RowsPerStrip() { HWY_INLINE constexpr size_t RowsPerStrip() {
// Aim for 128 work items to reduce pool overhead. Must be at least one // Aim for 128 work items to reduce pool overhead. Must be at least one
@ -341,7 +351,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
float* HWY_RESTRICT out, size_t size) { float* HWY_RESTRICT out, size_t size) {
constexpr float eps = 1e-6f; constexpr float eps = 1e-6f;
float ss = SquaredL2(x, size); float ss = SquaredL2(x, size);
ss = 1.0f / sqrtf(ss / static_cast<float>(size) + eps); ss = 1.0f / sqrtf(ss / StaticCast<float>(size) + eps);
for (size_t j = 0; j < size; j++) { for (size_t j = 0; j < size; j++) {
// Note 1.0f centering here // Note 1.0f centering here
out[j] = (1.0f + weight[j]) * (ss * x[j]); out[j] = (1.0f + weight[j]) * (ss * x[j]);
@ -353,7 +363,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
float* HWY_RESTRICT out, size_t size) { float* HWY_RESTRICT out, size_t size) {
constexpr float eps = 1e-6f; constexpr float eps = 1e-6f;
float ss = SquaredL2(x, size); float ss = SquaredL2(x, size);
ss = 1.0f / sqrtf(ss / static_cast<float>(size) + eps); ss = 1.0f / sqrtf(ss / StaticCast<float>(size) + eps);
for (size_t j = 0; j < size; j++) { for (size_t j = 0; j < size; j++) {
// Note 1.0f centering here // Note 1.0f centering here
out[j] = (1.0f + hwy::F32FromBF16(weight[j])) * (ss * x[j]); out[j] = (1.0f + hwy::F32FromBF16(weight[j])) * (ss * x[j]);
@ -364,7 +374,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(
const float* HWY_RESTRICT weight, float* HWY_RESTRICT inout, size_t size) { const float* HWY_RESTRICT weight, float* HWY_RESTRICT inout, size_t size) {
constexpr float eps = 1e-6f; constexpr float eps = 1e-6f;
float ss = SquaredL2(inout, size); float ss = SquaredL2(inout, size);
ss = 1.0f / sqrtf(ss / static_cast<float>(size) + eps); ss = 1.0f / sqrtf(ss / StaticCast<float>(size) + eps);
for (size_t j = 0; j < size; j++) { for (size_t j = 0; j < size; j++) {
// Note 1.0f centering here // Note 1.0f centering here
inout[j] = (1.0f + weight[j]) * (ss * inout[j]); inout[j] = (1.0f + weight[j]) * (ss * inout[j]);
@ -383,7 +393,8 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(
constexpr float eps = 1e-6f; constexpr float eps = 1e-6f;
const float ss = SquaredL2(inout, size); const float ss = SquaredL2(inout, size);
const VF vss = hn::Set(df32, 1.0f / sqrtf(ss / static_cast<float>(size) + eps)); const VF vss =
hn::Set(df32, 1.0f / sqrtf(ss / StaticCast<float>(size) + eps));
HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0); HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0);
for (size_t i = 0; i < size; i += 2 * N32) { for (size_t i = 0; i < size; i += 2 * N32) {
@ -411,7 +422,8 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
constexpr float eps = 1e-6f; constexpr float eps = 1e-6f;
const float ss = SquaredL2(x, size); const float ss = SquaredL2(x, size);
const VF vss = hn::Set(df32, 1.0f / sqrtf(ss / static_cast<float>(size) + eps)); const VF vss =
hn::Set(df32, 1.0f / sqrtf(ss / StaticCast<float>(size) + eps));
HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0); HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0);
for (size_t i = 0; i < size; i += 2 * N32) { for (size_t i = 0; i < size; i += 2 * N32) {
@ -438,7 +450,8 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
constexpr float eps = 1e-6f; constexpr float eps = 1e-6f;
const float ss = SquaredL2(x, size); const float ss = SquaredL2(x, size);
const VF vss = hn::Set(df32, 1.0f / sqrtf(ss / static_cast<float>(size) + eps)); const VF vss =
hn::Set(df32, 1.0f / sqrtf(ss / StaticCast<float>(size) + eps));
HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0); HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0);
for (size_t i = 0; i < size; i += 2 * N32) { for (size_t i = 0; i < size; i += 2 * N32) {
@ -459,14 +472,12 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddAbsolutePositionalEmbeddings(
const size_t num_timescales = dim_model / 2; const size_t num_timescales = dim_model / 2;
const float log_timescale_increment = const float log_timescale_increment =
logf(10000.0f) / logf(10000.0f) /
(num_timescales != 0 (num_timescales != 0 ? StaticCast<float>(num_timescales - 1) : 1.0f);
? static_cast<float>(static_cast<int>(num_timescales) - 1)
: 1.0f);
for (size_t dim = 0; dim < num_timescales; ++dim) { for (size_t dim = 0; dim < num_timescales; ++dim) {
const float inv_timescale = const float inv_timescale =
expf(static_cast<float>(dim) * -log_timescale_increment); expf(StaticCast<float>(dim) * -log_timescale_increment);
x[dim] += sinf(static_cast<float>(pos) * inv_timescale); x[dim] += sinf(StaticCast<float>(pos) * inv_timescale);
x[num_timescales + dim] += cosf(static_cast<float>(pos) * inv_timescale); x[num_timescales + dim] += cosf(StaticCast<float>(pos) * inv_timescale);
} }
} }
@ -475,11 +486,11 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(float* HWY_RESTRICT x,
HWY_DASSERT(dim_qkv % 2 == 0); HWY_DASSERT(dim_qkv % 2 == 0);
const size_t half_dim_qkv = dim_qkv / 2; const size_t half_dim_qkv = dim_qkv / 2;
for (size_t dim = 0; dim < half_dim_qkv; ++dim) { for (size_t dim = 0; dim < half_dim_qkv; ++dim) {
const float freq_exponents = static_cast<float>(2 * static_cast<int>(dim)) / const float freq_exponents =
static_cast<float>(dim_qkv); StaticCast<float>(2 * dim) / StaticCast<float>(dim_qkv);
// Replacing with expf(ln(1E4) * freq_exponents) changes results noticeably. // Replacing with expf(ln(1E4) * freq_exponents) changes results noticeably.
const float timescale = powf(10000.0f, freq_exponents); const float timescale = powf(10000.0f, freq_exponents);
const float theta = static_cast<float>(pos) / timescale; const float theta = StaticCast<float>(pos) / timescale;
const float cos_val = cosf(theta); const float cos_val = cosf(theta);
const float sin_val = sinf(theta); const float sin_val = sinf(theta);
const float x0 = x[dim]; const float x0 = x[dim];
@ -496,11 +507,11 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(const float mul,
HWY_DASSERT(dim_qkv % 2 == 0); HWY_DASSERT(dim_qkv % 2 == 0);
const size_t half_dim_qkv = dim_qkv / 2; const size_t half_dim_qkv = dim_qkv / 2;
for (size_t dim = 0; dim < half_dim_qkv; ++dim) { for (size_t dim = 0; dim < half_dim_qkv; ++dim) {
const float freq_exponents = static_cast<float>(2 * static_cast<int>(dim)) / const float freq_exponents =
static_cast<float>(dim_qkv); StaticCast<float>(2 * dim) / StaticCast<float>(dim_qkv);
// Replacing with expf(ln(1E4) * freq_exponents) changes results noticeably. // Replacing with expf(ln(1E4) * freq_exponents) changes results noticeably.
const float timescale = powf(10000.0f, freq_exponents); const float timescale = powf(10000.0f, freq_exponents);
const float theta = static_cast<float>(pos) / timescale; const float theta = StaticCast<float>(pos) / timescale;
const float cos_val = cosf(theta); const float cos_val = cosf(theta);
const float sin_val = sinf(theta); const float sin_val = sinf(theta);
const float x0 = x[dim]; const float x0 = x[dim];
@ -674,18 +685,18 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK(
std::array<float, k> top_k{}; // sorted from highest [0], to lowest [k-1] std::array<float, k> top_k{}; // sorted from highest [0], to lowest [k-1]
std::array<int, k> indices{}; std::array<int, k> indices{};
for (size_t i = 0; i < vocab_size; ++i) { for (size_t i = 0; i < vocab_size; ++i) {
if (probabilities[i] < top_k[k - 1] && accept_token(static_cast<int>(i))) { if (probabilities[i] < top_k[k - 1] && accept_token(StaticCast<int>(i))) {
continue; continue;
} }
for (size_t j = 0; j < k; ++j) { for (size_t j = 0; j < k; ++j) {
if (probabilities[i] > top_k[j] && accept_token(static_cast<int>(i))) { if (probabilities[i] > top_k[j] && accept_token(StaticCast<int>(i))) {
// shift elements by 1, insert the new value, move on to next value // shift elements by 1, insert the new value, move on to next value
for (size_t idx = k - 1; idx > j; --idx) { for (size_t idx = k - 1; idx > j; --idx) {
top_k[idx] = top_k[idx - 1]; top_k[idx] = top_k[idx - 1];
indices[idx] = indices[idx - 1]; indices[idx] = indices[idx - 1];
} }
top_k[j] = probabilities[i]; top_k[j] = probabilities[i];
indices[j] = static_cast<int>(i); indices[j] = StaticCast<int>(i);
break; break;
} }
} }