mirror of https://github.com/google/gemma.cpp.git
add StaticCast
This commit is contained in:
parent
06dd013397
commit
626be6deab
53
ops.h
53
ops.h
|
|
@ -57,6 +57,16 @@ HWY_INLINE constexpr size_t MaxCols() {
|
||||||
return 2048;
|
return 2048;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Converts `from` to the arithmetic type `To`. Only participates in overload
// resolution when both `To` and `From` are arithmetic types.
//
// Unsigned -> floating-point conversions go through int64_t when that detour
// preserves the value (presumably because signed integer -> float conversion
// is cheaper on some targets; confirm with the author). Unsigned values that
// int64_t cannot represent, e.g. a 64-bit size_t with its top bit set, are
// converted directly so they do not come out negative.
// Every other type combination is a plain static_cast.
template <typename To, typename From>
HWY_INLINE constexpr std::enable_if_t<
    std::is_arithmetic_v<To> && std::is_arithmetic_v<From>, To>
StaticCast(From from) noexcept {
  if constexpr (std::is_unsigned_v<From> && std::is_floating_point_v<To>) {
    const int64_t as_signed = static_cast<int64_t>(from);
    // Use the signed route only if it round-trips: non-negative rules out a
    // wrapped top bit, and the equality check rules out truncation of a
    // source type wider than 64 bits.
    if (as_signed >= 0 && static_cast<From>(as_signed) == from) {
      return static_cast<To>(as_signed);
    }
    return static_cast<To>(from);
  } else {
    return static_cast<To>(from);
  }
}
|
||||||
|
|
||||||
template <size_t kOuter>
|
template <size_t kOuter>
|
||||||
HWY_INLINE constexpr size_t RowsPerStrip() {
|
HWY_INLINE constexpr size_t RowsPerStrip() {
|
||||||
// Aim for 128 work items to reduce pool overhead. Must be at least one
|
// Aim for 128 work items to reduce pool overhead. Must be at least one
|
||||||
|
|
@ -341,7 +351,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
|
||||||
float* HWY_RESTRICT out, size_t size) {
|
float* HWY_RESTRICT out, size_t size) {
|
||||||
constexpr float eps = 1e-6f;
|
constexpr float eps = 1e-6f;
|
||||||
float ss = SquaredL2(x, size);
|
float ss = SquaredL2(x, size);
|
||||||
ss = 1.0f / sqrtf(ss / static_cast<float>(size) + eps);
|
ss = 1.0f / sqrtf(ss / StaticCast<float>(size) + eps);
|
||||||
for (size_t j = 0; j < size; j++) {
|
for (size_t j = 0; j < size; j++) {
|
||||||
// Note 1.0f centering here
|
// Note 1.0f centering here
|
||||||
out[j] = (1.0f + weight[j]) * (ss * x[j]);
|
out[j] = (1.0f + weight[j]) * (ss * x[j]);
|
||||||
|
|
@ -353,7 +363,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
|
||||||
float* HWY_RESTRICT out, size_t size) {
|
float* HWY_RESTRICT out, size_t size) {
|
||||||
constexpr float eps = 1e-6f;
|
constexpr float eps = 1e-6f;
|
||||||
float ss = SquaredL2(x, size);
|
float ss = SquaredL2(x, size);
|
||||||
ss = 1.0f / sqrtf(ss / static_cast<float>(size) + eps);
|
ss = 1.0f / sqrtf(ss / StaticCast<float>(size) + eps);
|
||||||
for (size_t j = 0; j < size; j++) {
|
for (size_t j = 0; j < size; j++) {
|
||||||
// Note 1.0f centering here
|
// Note 1.0f centering here
|
||||||
out[j] = (1.0f + hwy::F32FromBF16(weight[j])) * (ss * x[j]);
|
out[j] = (1.0f + hwy::F32FromBF16(weight[j])) * (ss * x[j]);
|
||||||
|
|
@ -364,7 +374,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(
|
||||||
const float* HWY_RESTRICT weight, float* HWY_RESTRICT inout, size_t size) {
|
const float* HWY_RESTRICT weight, float* HWY_RESTRICT inout, size_t size) {
|
||||||
constexpr float eps = 1e-6f;
|
constexpr float eps = 1e-6f;
|
||||||
float ss = SquaredL2(inout, size);
|
float ss = SquaredL2(inout, size);
|
||||||
ss = 1.0f / sqrtf(ss / static_cast<float>(size) + eps);
|
ss = 1.0f / sqrtf(ss / StaticCast<float>(size) + eps);
|
||||||
for (size_t j = 0; j < size; j++) {
|
for (size_t j = 0; j < size; j++) {
|
||||||
// Note 1.0f centering here
|
// Note 1.0f centering here
|
||||||
inout[j] = (1.0f + weight[j]) * (ss * inout[j]);
|
inout[j] = (1.0f + weight[j]) * (ss * inout[j]);
|
||||||
|
|
@ -383,7 +393,8 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(
|
||||||
|
|
||||||
constexpr float eps = 1e-6f;
|
constexpr float eps = 1e-6f;
|
||||||
const float ss = SquaredL2(inout, size);
|
const float ss = SquaredL2(inout, size);
|
||||||
const VF vss = hn::Set(df32, 1.0f / sqrtf(ss / static_cast<float>(size) + eps));
|
const VF vss =
|
||||||
|
hn::Set(df32, 1.0f / sqrtf(ss / StaticCast<float>(size) + eps));
|
||||||
|
|
||||||
HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0);
|
HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0);
|
||||||
for (size_t i = 0; i < size; i += 2 * N32) {
|
for (size_t i = 0; i < size; i += 2 * N32) {
|
||||||
|
|
@ -411,7 +422,8 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
|
||||||
|
|
||||||
constexpr float eps = 1e-6f;
|
constexpr float eps = 1e-6f;
|
||||||
const float ss = SquaredL2(x, size);
|
const float ss = SquaredL2(x, size);
|
||||||
const VF vss = hn::Set(df32, 1.0f / sqrtf(ss / static_cast<float>(size) + eps));
|
const VF vss =
|
||||||
|
hn::Set(df32, 1.0f / sqrtf(ss / StaticCast<float>(size) + eps));
|
||||||
|
|
||||||
HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0);
|
HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0);
|
||||||
for (size_t i = 0; i < size; i += 2 * N32) {
|
for (size_t i = 0; i < size; i += 2 * N32) {
|
||||||
|
|
@ -438,7 +450,8 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
|
||||||
|
|
||||||
constexpr float eps = 1e-6f;
|
constexpr float eps = 1e-6f;
|
||||||
const float ss = SquaredL2(x, size);
|
const float ss = SquaredL2(x, size);
|
||||||
const VF vss = hn::Set(df32, 1.0f / sqrtf(ss / static_cast<float>(size) + eps));
|
const VF vss =
|
||||||
|
hn::Set(df32, 1.0f / sqrtf(ss / StaticCast<float>(size) + eps));
|
||||||
|
|
||||||
HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0);
|
HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0);
|
||||||
for (size_t i = 0; i < size; i += 2 * N32) {
|
for (size_t i = 0; i < size; i += 2 * N32) {
|
||||||
|
|
@ -459,14 +472,12 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddAbsolutePositionalEmbeddings(
|
||||||
const size_t num_timescales = dim_model / 2;
|
const size_t num_timescales = dim_model / 2;
|
||||||
const float log_timescale_increment =
|
const float log_timescale_increment =
|
||||||
logf(10000.0f) /
|
logf(10000.0f) /
|
||||||
(num_timescales != 0
|
(num_timescales != 0 ? StaticCast<float>(num_timescales - 1) : 1.0f);
|
||||||
? static_cast<float>(static_cast<int>(num_timescales) - 1)
|
|
||||||
: 1.0f);
|
|
||||||
for (size_t dim = 0; dim < num_timescales; ++dim) {
|
for (size_t dim = 0; dim < num_timescales; ++dim) {
|
||||||
const float inv_timescale =
|
const float inv_timescale =
|
||||||
expf(static_cast<float>(dim) * -log_timescale_increment);
|
expf(StaticCast<float>(dim) * -log_timescale_increment);
|
||||||
x[dim] += sinf(static_cast<float>(pos) * inv_timescale);
|
x[dim] += sinf(StaticCast<float>(pos) * inv_timescale);
|
||||||
x[num_timescales + dim] += cosf(static_cast<float>(pos) * inv_timescale);
|
x[num_timescales + dim] += cosf(StaticCast<float>(pos) * inv_timescale);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -475,11 +486,11 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(float* HWY_RESTRICT x,
|
||||||
HWY_DASSERT(dim_qkv % 2 == 0);
|
HWY_DASSERT(dim_qkv % 2 == 0);
|
||||||
const size_t half_dim_qkv = dim_qkv / 2;
|
const size_t half_dim_qkv = dim_qkv / 2;
|
||||||
for (size_t dim = 0; dim < half_dim_qkv; ++dim) {
|
for (size_t dim = 0; dim < half_dim_qkv; ++dim) {
|
||||||
const float freq_exponents = static_cast<float>(2 * static_cast<int>(dim)) /
|
const float freq_exponents =
|
||||||
static_cast<float>(dim_qkv);
|
StaticCast<float>(2 * dim) / StaticCast<float>(dim_qkv);
|
||||||
// Replacing with expf(ln(1E4) * freq_exponents) changes results noticeably.
|
// Replacing with expf(ln(1E4) * freq_exponents) changes results noticeably.
|
||||||
const float timescale = powf(10000.0f, freq_exponents);
|
const float timescale = powf(10000.0f, freq_exponents);
|
||||||
const float theta = static_cast<float>(pos) / timescale;
|
const float theta = StaticCast<float>(pos) / timescale;
|
||||||
const float cos_val = cosf(theta);
|
const float cos_val = cosf(theta);
|
||||||
const float sin_val = sinf(theta);
|
const float sin_val = sinf(theta);
|
||||||
const float x0 = x[dim];
|
const float x0 = x[dim];
|
||||||
|
|
@ -496,11 +507,11 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(const float mul,
|
||||||
HWY_DASSERT(dim_qkv % 2 == 0);
|
HWY_DASSERT(dim_qkv % 2 == 0);
|
||||||
const size_t half_dim_qkv = dim_qkv / 2;
|
const size_t half_dim_qkv = dim_qkv / 2;
|
||||||
for (size_t dim = 0; dim < half_dim_qkv; ++dim) {
|
for (size_t dim = 0; dim < half_dim_qkv; ++dim) {
|
||||||
const float freq_exponents = static_cast<float>(2 * static_cast<int>(dim)) /
|
const float freq_exponents =
|
||||||
static_cast<float>(dim_qkv);
|
StaticCast<float>(2 * dim) / StaticCast<float>(dim_qkv);
|
||||||
// Replacing with expf(ln(1E4) * freq_exponents) changes results noticeably.
|
// Replacing with expf(ln(1E4) * freq_exponents) changes results noticeably.
|
||||||
const float timescale = powf(10000.0f, freq_exponents);
|
const float timescale = powf(10000.0f, freq_exponents);
|
||||||
const float theta = static_cast<float>(pos) / timescale;
|
const float theta = StaticCast<float>(pos) / timescale;
|
||||||
const float cos_val = cosf(theta);
|
const float cos_val = cosf(theta);
|
||||||
const float sin_val = sinf(theta);
|
const float sin_val = sinf(theta);
|
||||||
const float x0 = x[dim];
|
const float x0 = x[dim];
|
||||||
|
|
@ -674,18 +685,18 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK(
|
||||||
std::array<float, k> top_k{}; // sorted from highest [0], to lowest [k-1]
|
std::array<float, k> top_k{}; // sorted from highest [0], to lowest [k-1]
|
||||||
std::array<int, k> indices{};
|
std::array<int, k> indices{};
|
||||||
for (size_t i = 0; i < vocab_size; ++i) {
|
for (size_t i = 0; i < vocab_size; ++i) {
|
||||||
if (probabilities[i] < top_k[k - 1] && accept_token(static_cast<int>(i))) {
|
if (probabilities[i] < top_k[k - 1] && accept_token(StaticCast<int>(i))) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
for (size_t j = 0; j < k; ++j) {
|
for (size_t j = 0; j < k; ++j) {
|
||||||
if (probabilities[i] > top_k[j] && accept_token(static_cast<int>(i))) {
|
if (probabilities[i] > top_k[j] && accept_token(StaticCast<int>(i))) {
|
||||||
// shift elements by 1, insert the new value, move on to next value
|
// shift elements by 1, insert the new value, move on to next value
|
||||||
for (size_t idx = k - 1; idx > j; --idx) {
|
for (size_t idx = k - 1; idx > j; --idx) {
|
||||||
top_k[idx] = top_k[idx - 1];
|
top_k[idx] = top_k[idx - 1];
|
||||||
indices[idx] = indices[idx - 1];
|
indices[idx] = indices[idx - 1];
|
||||||
}
|
}
|
||||||
top_k[j] = probabilities[i];
|
top_k[j] = probabilities[i];
|
||||||
indices[j] = static_cast<int>(i);
|
indices[j] = StaticCast<int>(i);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue