diff --git a/ops.h b/ops.h index 0cc38f6..5fea717 100644 --- a/ops.h +++ b/ops.h @@ -57,6 +57,16 @@ HWY_INLINE constexpr size_t MaxCols() { return 2048; } +template +HWY_INLINE constexpr std::enable_if_t< + std::is_arithmetic_v && std::is_arithmetic_v, To> +StaticCast(From from) noexcept { + if constexpr (std::is_unsigned_v && std::is_floating_point_v) + return static_cast(static_cast(from)); + else + return static_cast(from); +} + template HWY_INLINE constexpr size_t RowsPerStrip() { // Aim for 128 work items to reduce pool overhead. Must be at least one @@ -341,7 +351,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm( float* HWY_RESTRICT out, size_t size) { constexpr float eps = 1e-6f; float ss = SquaredL2(x, size); - ss = 1.0f / sqrtf(ss / static_cast(size) + eps); + ss = 1.0f / sqrtf(ss / StaticCast(size) + eps); for (size_t j = 0; j < size; j++) { // Note 1.0f centering here out[j] = (1.0f + weight[j]) * (ss * x[j]); @@ -353,7 +363,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm( float* HWY_RESTRICT out, size_t size) { constexpr float eps = 1e-6f; float ss = SquaredL2(x, size); - ss = 1.0f / sqrtf(ss / static_cast(size) + eps); + ss = 1.0f / sqrtf(ss / StaticCast(size) + eps); for (size_t j = 0; j < size; j++) { // Note 1.0f centering here out[j] = (1.0f + hwy::F32FromBF16(weight[j])) * (ss * x[j]); @@ -364,7 +374,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace( const float* HWY_RESTRICT weight, float* HWY_RESTRICT inout, size_t size) { constexpr float eps = 1e-6f; float ss = SquaredL2(inout, size); - ss = 1.0f / sqrtf(ss / static_cast(size) + eps); + ss = 1.0f / sqrtf(ss / StaticCast(size) + eps); for (size_t j = 0; j < size; j++) { // Note 1.0f centering here inout[j] = (1.0f + weight[j]) * (ss * inout[j]); @@ -383,7 +393,8 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace( constexpr float eps = 1e-6f; const float ss = SquaredL2(inout, size); - const VF vss = hn::Set(df32, 1.0f / sqrtf(ss / static_cast(size) + eps)); + const VF vss = + hn::Set(df32, 1.0f / sqrtf(ss / StaticCast(size) + eps)); HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0); for (size_t i = 0; i < size; i += 2 * N32) { @@ -411,7 +422,8 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm( constexpr float eps = 1e-6f; const float ss = SquaredL2(x, size); - const VF vss = hn::Set(df32, 1.0f / sqrtf(ss / static_cast(size) + eps)); + const VF vss = + hn::Set(df32, 1.0f / sqrtf(ss / StaticCast(size) + eps)); HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0); for (size_t i = 0; i < size; i += 2 * N32) { @@ -438,7 +450,8 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm( constexpr float eps = 1e-6f; const float ss = SquaredL2(x, size); - const VF vss = hn::Set(df32, 1.0f / sqrtf(ss / static_cast(size) + eps)); + const VF vss = + hn::Set(df32, 1.0f / sqrtf(ss / StaticCast(size) + eps)); HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0); for (size_t i = 0; i < size; i += 2 * N32) { @@ -459,14 +472,12 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddAbsolutePositionalEmbeddings( const size_t num_timescales = dim_model / 2; const float log_timescale_increment = logf(10000.0f) / - (num_timescales != 0 - ? static_cast(static_cast(num_timescales) - 1) - : 1.0f); + (num_timescales != 0 ? StaticCast(num_timescales - 1) : 1.0f); for (size_t dim = 0; dim < num_timescales; ++dim) { const float inv_timescale = - expf(static_cast(dim) * -log_timescale_increment); - x[dim] += sinf(static_cast(pos) * inv_timescale); - x[num_timescales + dim] += cosf(static_cast(pos) * inv_timescale); + expf(StaticCast(dim) * -log_timescale_increment); + x[dim] += sinf(StaticCast(pos) * inv_timescale); + x[num_timescales + dim] += cosf(StaticCast(pos) * inv_timescale); } } @@ -475,11 +486,11 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(float* HWY_RESTRICT x, HWY_DASSERT(dim_qkv % 2 == 0); const size_t half_dim_qkv = dim_qkv / 2; for (size_t dim = 0; dim < half_dim_qkv; ++dim) { - const float freq_exponents = static_cast(2 * static_cast(dim)) / - static_cast(dim_qkv); + const float freq_exponents = + StaticCast(2 * dim) / StaticCast(dim_qkv); // Replacing with expf(ln(1E4) * freq_exponents) changes results noticeably. const float timescale = powf(10000.0f, freq_exponents); - const float theta = static_cast(pos) / timescale; + const float theta = StaticCast(pos) / timescale; const float cos_val = cosf(theta); const float sin_val = sinf(theta); const float x0 = x[dim]; @@ -496,11 +507,11 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(const float mul, HWY_DASSERT(dim_qkv % 2 == 0); const size_t half_dim_qkv = dim_qkv / 2; for (size_t dim = 0; dim < half_dim_qkv; ++dim) { - const float freq_exponents = static_cast(2 * static_cast(dim)) / - static_cast(dim_qkv); + const float freq_exponents = + StaticCast(2 * dim) / StaticCast(dim_qkv); // Replacing with expf(ln(1E4) * freq_exponents) changes results noticeably. const float timescale = powf(10000.0f, freq_exponents); - const float theta = static_cast(pos) / timescale; + const float theta = StaticCast(pos) / timescale; const float cos_val = cosf(theta); const float sin_val = sinf(theta); const float x0 = x[dim]; @@ -674,18 +685,18 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK( std::array top_k{}; // sorted from highest [0], to lowest [k-1] std::array indices{}; for (size_t i = 0; i < vocab_size; ++i) { - if (probabilities[i] < top_k[k - 1] && accept_token(static_cast(i))) { + if (probabilities[i] < top_k[k - 1] && accept_token(StaticCast(i))) { continue; } for (size_t j = 0; j < k; ++j) { - if (probabilities[i] > top_k[j] && accept_token(static_cast(i))) { + if (probabilities[i] > top_k[j] && accept_token(StaticCast(i))) { // shift elements by 1, insert the new value, move on to next value for (size_t idx = k - 1; idx > j; --idx) { top_k[idx] = top_k[idx - 1]; indices[idx] = indices[idx - 1]; } top_k[j] = probabilities[i]; - indices[j] = static_cast(i); + indices[j] = StaticCast(i); break; } }