Merge pull request #65 from enum-class:narrowing-issues

PiperOrigin-RevId: 612279564
This commit is contained in:
Copybara-Service 2024-03-03 18:51:59 -08:00
commit cd7468199c
8 changed files with 257 additions and 37 deletions

View File

@ -1 +1,2 @@
Language: Cpp
BasedOnStyle: Google

206
.clang-tidy Normal file
View File

@ -0,0 +1,206 @@
FormatStyle: file
Checks: "-*,\
abseil-*,\
-abseil-string-find-startswith,\
-abseil-string-find-str-contains,\
bugprone-*,\
-bugprone-argument-comment,\
-bugprone-assert-side-effect,\
-bugprone-bad-signal-to-kill-thread,\
-bugprone-bool-pointer-implicit-conversion,\
-bugprone-branch-clone,\
-bugprone-copy-constructor-init,\
-bugprone-dangling-handle,\
-bugprone-dynamic-static-initializers,\
-bugprone-easily-swappable-parameters,\
-bugprone-exception-escape,\
-bugprone-fold-init-type,\
-bugprone-forward-declaration-namespace,\
-bugprone-forwarding-reference-overload,\
-bugprone-implicit-widening-of-multiplication-result,\
-bugprone-inaccurate-erase,\
-bugprone-incorrect-roundings,\
-bugprone-infinite-loop,\
-bugprone-integer-division,\
-bugprone-lambda-function-name,\
-bugprone-macro-parentheses,\
-bugprone-macro-repeated-side-effects,\
-bugprone-misplaced-operator-in-strlen-in-alloc,\
-bugprone-misplaced-widening-cast,\
-bugprone-move-forwarding-reference,\
-bugprone-multiple-statement-macro,\
-bugprone-narrowing-conversions,\
-bugprone-no-escape,\
-bugprone-not-null-terminated-result,\
-bugprone-parent-virtual-call,\
-bugprone-posix-return,\
-bugprone-redundant-branch-condition,\
-bugprone-reserved-identifier,\
-bugprone-signal-handler,\
-bugprone-signed-char-misuse,\
-bugprone-sizeof-container,\
-bugprone-sizeof-expression,\
-bugprone-spuriously-wake-up-functions,\
-bugprone-string-constructor,\
-bugprone-string-integer-assignment,\
-bugprone-string-literal-with-embedded-nul,\
-bugprone-stringview-nullptr,\
-bugprone-suspicious-enum-usage,\
-bugprone-suspicious-include,\
-bugprone-suspicious-memory-comparison,\
-bugprone-suspicious-memset-usage,\
-bugprone-suspicious-missing-comma,\
-bugprone-suspicious-semicolon,\
-bugprone-suspicious-string-compare,\
-bugprone-swapped-arguments,\
-bugprone-terminating-continue,\
-bugprone-throw-keyword-missing,\
-bugprone-too-small-loop-variable,\
-bugprone-undefined-memory-manipulation,\
-bugprone-undelegated-constructor,\
-bugprone-unhandled-exception-at-new,\
-bugprone-unhandled-self-assignment,\
-bugprone-unused-raii,\
-bugprone-unused-return-value,\
-bugprone-use-after-move,\
-bugprone-virtual-near-miss,\
cert-*,\
-cert-dcl16-c,\
-cert-dcl21-cpp,\
-cert-dcl37-c,\
-cert-dcl50-cpp,\
-cert-dcl51-cpp,\
-cert-dcl54-cpp,\
-cert-dcl58-cpp,\
-cert-err33-c,\
-cert-msc30-c,\
-cert-msc32-c,\
-cert-msc50-cpp,\
-cert-msc51-cpp,\
-cert-oop54-cpp,\
-cert-str34-c,\
-cert-str34-c,\
-cert-str34-c,\
-cert-str34-c,\
-clang-analyzer-*,\
concurrency-*,\
-concurrency-mt-unsafe,\
cppcoreguidelines-*,\
-concurrency-mt-unsafe,\
-cppcoreguidelines-avoid-c-arrays,\
-cppcoreguidelines-avoid-const-or-ref-data-members,\
-cppcoreguidelines-avoid-goto,\
-cppcoreguidelines-avoid-magic-numbers,\
-cppcoreguidelines-avoid-non-const-global-variables,\
-cppcoreguidelines-c-copy-assignment-signature,\
-cppcoreguidelines-explicit-virtual-functions,\
-cppcoreguidelines-init-variables,\
-cppcoreguidelines-interfaces-global-init,\
-cppcoreguidelines-macro-usage,\
-cppcoreguidelines-narrowing-conversions,\
-cppcoreguidelines-no-malloc,\
-cppcoreguidelines-non-private-member-variables-in-classes,\
-cppcoreguidelines-owning-memory,\
-cppcoreguidelines-prefer-member-initializer,\
-cppcoreguidelines-pro-bounds-array-to-pointer-decay,\
-cppcoreguidelines-pro-bounds-constant-array-index,\
-cppcoreguidelines-pro-bounds-pointer-arithmetic,\
-cppcoreguidelines-pro-type-const-cast,\
-cppcoreguidelines-pro-type-member-init,\
-cppcoreguidelines-pro-type-reinterpret-cast,\
-cppcoreguidelines-pro-type-static-cast-downcast,\
-cppcoreguidelines-pro-type-union-access,\
-cppcoreguidelines-pro-type-vararg,\
-cppcoreguidelines-slicing,\
-cppcoreguidelines-special-member-functions,\
-cppcoreguidelines-virtual-class-destructor,\
google-*,\
-google-default-arguments,\
-google-explicit-constructor,\
-google-readability-avoid-underscore-in-googletest-name,\
-google-readability-braces-around-statements,\
-google-readability-casting,\
-google-readability-namespace-comments,\
-google-readability-todo,\
-google-runtime-int,\
-google-upgrade-googletest-case,\
misc-*,\
-misc-misplaced-const,\
-misc-new-delete-overloads,\
-misc-non-private-member-variables-in-classes,\
-misc-no-recursion,\
-misc-redundant-expression,\
-misc-uniqueptr-reset-release,\
-misc-unconventional-assign-operator,\
-misc-unused-parameters,\
-misc-unused-using-decls,\
modernize-*,\
-modernize-avoid-c-arrays,\
-modernize-concat-nested-namespaces,\
-modernize-deprecated-headers,\
-modernize-loop-convert,\
-modernize-macro-to-enum,\
-modernize-make-unique,\
-modernize-pass-by-value,\
-modernize-raw-string-literal,\
-modernize-redundant-void-arg,\
-modernize-return-braced-init-list,\
-modernize-unary-static-assert,\
-modernize-use-auto,\
-modernize-use-bool-literals,\
-modernize-use-default-member-init,\
-modernize-use-emplace,\
-modernize-use-equals-default,\
-modernize-use-equals-delete,\
-modernize-use-nodiscard,\
-modernize-use-nullptr,\
-modernize-use-override,\
-modernize-use-trailing-return-type,\
-modernize-use-transparent-functors,\
-modernize-use-using,\
performance-*,\
-performance-faster-string-find,\
-performance-for-range-copy,\
-performance-inefficient-algorithm,\
-performance-inefficient-string-concatenation,\
-performance-inefficient-vector-operation,\
-performance-move-const-arg,\
-performance-no-automatic-move,\
-performance-noexcept-move-constructor,\
-performance-no-int-to-ptr,\
-performance-trivially-destructible,\
-performance-unnecessary-copy-initialization,\
-performance-unnecessary-value-param,\
portability-*,\
readability-*,\
-readability-avoid-const-params-in-decls,\
-readability-braces-around-statements,\
-readability-const-return-type,\
-readability-container-data-pointer,\
-readability-container-size-empty,\
-readability-convert-member-functions-to-static,\
-readability-else-after-return,\
-readability-function-cognitive-complexity,\
-readability-identifier-length,\
-readability-implicit-bool-conversion,\
-readability-inconsistent-declaration-parameter-name,\
-readability-isolate-declaration,\
-readability-magic-numbers,\
-readability-make-member-function-const,\
-readability-named-parameter,\
-readability-non-const-parameter,\
-readability-qualified-auto,\
-readability-redundant-access-specifiers,\
-readability-redundant-control-flow,\
-readability-redundant-declaration,\
-readability-redundant-member-init,\
-readability-redundant-smartptr-get,\
-readability-redundant-string-cstr,\
-readability-redundant-string-init,\
-readability-simplify-boolean-expr,\
-readability-static-accessed-through-instance,\
-readability-static-definition-in-anonymous-namespace,\
-readability-suspicious-call-argument,\
-readability-uppercase-literal-suffix,\
-readability-use-anyofallof
"

View File

@ -20,6 +20,7 @@ project(gemma)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG da250571a45826b21eebbddc1e50d0c1137dee5f)
FetchContent_MakeAvailable(highway)

View File

@ -149,7 +149,7 @@ struct CompressTraits<hwy::bfloat16_t> {
}
}
size_t remaining = num - i;
const size_t remaining = num - i;
if (remaining != 0) {
const VF in0 = hn::LoadN(df, in + i, remaining);
const size_t remaining1 = remaining - HWY_MIN(remaining, N / 2);
@ -195,7 +195,7 @@ struct CompressTraits<hwy::bfloat16_t> {
}
}
size_t remaining = num - i;
const size_t remaining = num - i;
if (remaining != 0) {
const VBF in16 = hn::LoadN(dbf, in + in_ofs + i, remaining);
const VF in0 = hn::PromoteLowerTo(df, in16);
@ -287,7 +287,7 @@ struct CompressTraits<NuqStream> {
if (COMPRESS_STATS) {
for (size_t i = 0; i < num; ++i) {
tls.stats.NotifyIn(in[i] * 100 + 500);
tls.stats.NotifyIn(static_cast<int>(lroundf(in[i] * 100.0f + 500.0f)));
}
const hn::Repartition<hwy::bfloat16_t, DF> dbf;
@ -358,7 +358,7 @@ HWY_NOINLINE void Compress(const float* in, size_t num,
});
const double t1 = hwy::platform::Now();
const double mb = num * sizeof(in[0]) * 1E-6;
const double mb = static_cast<double>(num) * sizeof(in[0]) * 1E-6;
const double mbps = mb / (t1 - t0);
fprintf(stderr, "Compress %.1f MB/s\n", mbps);

View File

@ -68,15 +68,15 @@ class DistortionStats {
double GeomeanValueDivL1() const {
if (num_rel_ == 0) return 0.0;
return exp(sum_log_rel_ / num_rel_);
return exp(sum_log_rel_ / static_cast<double>(num_rel_));
}
double PNorm() const {
// p-norms are a compromise between max-norm (penalizes the largest error
// without dilution, but does not notice any other errors) and L1 (all
// errors contribute, but large errors are diluted by smaller ones).
const double norm3 = pow(sum_pow3_ / n_, 1.0 / 3);
const double norm6 = pow(sum_pow6_ / n_, 1.0 / 6);
const double norm3 = pow(sum_pow3_ / static_cast<double>(n_), 1.0 / 3);
const double norm6 = pow(sum_pow6_ / static_cast<double>(n_), 1.0 / 6);
return 0.5 * (norm3 + norm6);
}

View File

@ -87,7 +87,7 @@ class NuqClustering {
inv_len_[0] = 0.0f; // unused
for (size_t i = 0; i <= kGroupSize; ++i) {
inv_len_[i] = 1.0f / i;
inv_len_[i] = 1.0f / static_cast<float>(i);
}
}
@ -229,7 +229,7 @@ class NuqClustering {
const float sum = cc.SumOfSorted(start, last);
const int size = static_cast<int>(last) - static_cast<int>(start) + 1;
HWY_DASSERT(0 < size && size <= kGroupSize);
centers[k] = sum / size;
centers[k] = sum / static_cast<float>(size);
// We know the range inside sorted_and_i[]; translate to original indices,
// which are stored inside each of the sorted_and_i mantissas.

View File

@ -525,7 +525,7 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, const InferenceArgs& args,
// In single-turn (non-chat) usage, pos and pos_offset start at 0 and are
// always equal.
size_t pos_offset = 0; // offset relative to pos
double prefill_start = hwy::platform::Now();
const double prefill_start = hwy::platform::Now();
// Prefill stops before prompt.size() - 1 since the last prompt token is the
// first input token for generation.
@ -547,12 +547,12 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, const InferenceArgs& args,
if (verbosity >= 2) {
// in the future this output should not occur in GenerateImpl but instead
// should be available as observable state for frontend code to handle I/O.
double prefill_end = hwy::platform::Now();
const double prefill_tok_sec = pos_offset / (prefill_end - prefill_start);
const double prefill_end = hwy::platform::Now();
const double prefill_tok_sec = static_cast<double>(pos_offset) / (prefill_end - prefill_start);
std::cout << "\n[ Prefill tokens / sec = " << prefill_tok_sec << " ]\n";
}
double gen_start = hwy::platform::Now();
const double gen_start = hwy::platform::Now();
HWY_DASSERT(pos_offset == prompt.size() - 1);
@ -590,9 +590,9 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, const InferenceArgs& args,
}
if (token == EOS_ID) {
if (verbosity >= 2) {
double gen_end = hwy::platform::Now();
const double gen_end = hwy::platform::Now();
const double gen_tok_sec =
(pos_offset - pos_gen_start) / (gen_end - gen_start);
static_cast<double>(pos_offset - pos_gen_start) / (gen_end - gen_start);
std::cout << "\n[ Generation tokens / sec = " << gen_tok_sec << " ]\n";
}
break;
@ -689,7 +689,7 @@ hwy::AlignedFreeUniquePtr<uint8_t[]> GetCompressedWeights(
if (loader.ReadAll(pool)) return c_weights_u8;
// Get weights, compress, and store in cache.
hwy::AlignedUniquePtr<Weights<TConfig>> weights = LoadWeights<TConfig>(model);
const hwy::AlignedUniquePtr<Weights<TConfig>> weights = LoadWeights<TConfig>(model);
Compressor compressor(pool);
ForEachTensor<TConfig>(weights.get(), *c_weights, compressor);
compressor.WriteAll(pool, cache.path.c_str());

54
ops.h
View File

@ -57,6 +57,17 @@ HWY_INLINE constexpr size_t MaxCols() {
return 2048;
}
template <typename To, typename From>
HWY_INLINE constexpr std::enable_if_t<
std::is_arithmetic_v<To> && std::is_arithmetic_v<From>, To>
StaticCast(From from) noexcept {
if constexpr (std::is_unsigned_v<From> && std::is_floating_point_v<To>)
return static_cast<To>(
static_cast<hwy::SignedFromSize<sizeof(From)>>(from));
else
return static_cast<To>(from);
}
template <size_t kOuter>
HWY_INLINE constexpr size_t RowsPerStrip() {
// Aim for 128 work items to reduce pool overhead. Must be at least one
@ -341,7 +352,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
float* HWY_RESTRICT out, size_t size) {
constexpr float eps = 1e-6f;
float ss = SquaredL2(x, size);
ss = 1.0f / sqrtf(ss / static_cast<int>(size) + eps);
ss = 1.0f / sqrtf(ss / StaticCast<float>(size) + eps);
for (size_t j = 0; j < size; j++) {
// Note 1.0f centering here
out[j] = (1.0f + weight[j]) * (ss * x[j]);
@ -353,7 +364,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
float* HWY_RESTRICT out, size_t size) {
constexpr float eps = 1e-6f;
float ss = SquaredL2(x, size);
ss = 1.0f / sqrtf(ss / static_cast<int>(size) + eps);
ss = 1.0f / sqrtf(ss / StaticCast<float>(size) + eps);
for (size_t j = 0; j < size; j++) {
// Note 1.0f centering here
out[j] = (1.0f + hwy::F32FromBF16(weight[j])) * (ss * x[j]);
@ -364,7 +375,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(
const float* HWY_RESTRICT weight, float* HWY_RESTRICT inout, size_t size) {
constexpr float eps = 1e-6f;
float ss = SquaredL2(inout, size);
ss = 1.0f / sqrtf(ss / static_cast<int>(size) + eps);
ss = 1.0f / sqrtf(ss / StaticCast<float>(size) + eps);
for (size_t j = 0; j < size; j++) {
// Note 1.0f centering here
inout[j] = (1.0f + weight[j]) * (ss * inout[j]);
@ -383,7 +394,8 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(
constexpr float eps = 1e-6f;
const float ss = SquaredL2(inout, size);
const VF vss = hn::Set(df32, 1.0f / sqrtf(ss / static_cast<int>(size) + eps));
const VF vss =
hn::Set(df32, 1.0f / sqrtf(ss / StaticCast<float>(size) + eps));
HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0);
for (size_t i = 0; i < size; i += 2 * N32) {
@ -411,7 +423,8 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
constexpr float eps = 1e-6f;
const float ss = SquaredL2(x, size);
const VF vss = hn::Set(df32, 1.0f / sqrtf(ss / static_cast<int>(size) + eps));
const VF vss =
hn::Set(df32, 1.0f / sqrtf(ss / StaticCast<float>(size) + eps));
HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0);
for (size_t i = 0; i < size; i += 2 * N32) {
@ -438,7 +451,8 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
constexpr float eps = 1e-6f;
const float ss = SquaredL2(x, size);
const VF vss = hn::Set(df32, 1.0f / sqrtf(ss / size + eps));
const VF vss =
hn::Set(df32, 1.0f / sqrtf(ss / StaticCast<float>(size) + eps));
HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0);
for (size_t i = 0; i < size; i += 2 * N32) {
@ -459,14 +473,12 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddAbsolutePositionalEmbeddings(
const size_t num_timescales = dim_model / 2;
const float log_timescale_increment =
logf(10000.0f) /
(num_timescales != 0
? static_cast<float>(static_cast<int>(num_timescales) - 1)
: 1.0f);
(num_timescales != 0 ? StaticCast<float>(num_timescales - 1) : 1.0f);
for (size_t dim = 0; dim < num_timescales; ++dim) {
const float inv_timescale =
expf(static_cast<int>(dim) * -log_timescale_increment);
x[dim] += sinf(pos * inv_timescale);
x[num_timescales + dim] += cosf(pos * inv_timescale);
expf(StaticCast<float>(dim) * -log_timescale_increment);
x[dim] += sinf(StaticCast<float>(pos) * inv_timescale);
x[num_timescales + dim] += cosf(StaticCast<float>(pos) * inv_timescale);
}
}
@ -475,11 +487,11 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(float* HWY_RESTRICT x,
HWY_DASSERT(dim_qkv % 2 == 0);
const size_t half_dim_qkv = dim_qkv / 2;
for (size_t dim = 0; dim < half_dim_qkv; ++dim) {
const float freq_exponents = static_cast<float>(2 * static_cast<int>(dim)) /
static_cast<float>(dim_qkv);
const float freq_exponents =
StaticCast<float>(2 * dim) / StaticCast<float>(dim_qkv);
// Replacing with expf(ln(1E4) * freq_exponents) changes results noticeably.
const float timescale = powf(10000.0f, freq_exponents);
const float theta = pos / timescale;
const float theta = StaticCast<float>(pos) / timescale;
const float cos_val = cosf(theta);
const float sin_val = sinf(theta);
const float x0 = x[dim];
@ -496,11 +508,11 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(const float mul,
HWY_DASSERT(dim_qkv % 2 == 0);
const size_t half_dim_qkv = dim_qkv / 2;
for (size_t dim = 0; dim < half_dim_qkv; ++dim) {
const float freq_exponents = static_cast<float>(2 * static_cast<int>(dim)) /
static_cast<float>(dim_qkv);
const float freq_exponents =
StaticCast<float>(2 * dim) / StaticCast<float>(dim_qkv);
// Replacing with expf(ln(1E4) * freq_exponents) changes results noticeably.
const float timescale = powf(10000.0f, freq_exponents);
const float theta = pos / timescale;
const float theta = StaticCast<float>(pos) / timescale;
const float cos_val = cosf(theta);
const float sin_val = sinf(theta);
const float x0 = x[dim];
@ -674,18 +686,18 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK(
std::array<float, k> top_k{}; // sorted from highest [0], to lowest [k-1]
std::array<int, k> indices{};
for (size_t i = 0; i < vocab_size; ++i) {
if (probabilities[i] < top_k[k - 1] && accept_token(static_cast<int>(i))) {
if (probabilities[i] < top_k[k - 1] && accept_token(StaticCast<int>(i))) {
continue;
}
for (size_t j = 0; j < k; ++j) {
if (probabilities[i] > top_k[j] && accept_token(static_cast<int>(i))) {
if (probabilities[i] > top_k[j] && accept_token(StaticCast<int>(i))) {
// shift elements by 1, insert the new value, move on to next value
for (size_t idx = k - 1; idx > j; --idx) {
top_k[idx] = top_k[idx - 1];
indices[idx] = indices[idx - 1];
}
top_k[j] = probabilities[i];
indices[j] = static_cast<int>(i);
indices[j] = StaticCast<int>(i);
break;
}
}