diff --git a/.clang-format b/.clang-format index f6cb8ad..2bd8588 100644 --- a/.clang-format +++ b/.clang-format @@ -1 +1,2 @@ +Language: Cpp BasedOnStyle: Google diff --git a/.clang-tidy b/.clang-tidy new file mode 100644 index 0000000..abcd9d7 --- /dev/null +++ b/.clang-tidy @@ -0,0 +1,206 @@ +FormatStyle: file +Checks: "-*,\ + abseil-*,\ + -abseil-string-find-startswith,\ + -abseil-string-find-str-contains,\ + bugprone-*,\ + -bugprone-argument-comment,\ + -bugprone-assert-side-effect,\ + -bugprone-bad-signal-to-kill-thread,\ + -bugprone-bool-pointer-implicit-conversion,\ + -bugprone-branch-clone,\ + -bugprone-copy-constructor-init,\ + -bugprone-dangling-handle,\ + -bugprone-dynamic-static-initializers,\ + -bugprone-easily-swappable-parameters,\ + -bugprone-exception-escape,\ + -bugprone-fold-init-type,\ + -bugprone-forward-declaration-namespace,\ + -bugprone-forwarding-reference-overload,\ + -bugprone-implicit-widening-of-multiplication-result,\ + -bugprone-inaccurate-erase,\ + -bugprone-incorrect-roundings,\ + -bugprone-infinite-loop,\ + -bugprone-integer-division,\ + -bugprone-lambda-function-name,\ + -bugprone-macro-parentheses,\ + -bugprone-macro-repeated-side-effects,\ + -bugprone-misplaced-operator-in-strlen-in-alloc,\ + -bugprone-misplaced-widening-cast,\ + -bugprone-move-forwarding-reference,\ + -bugprone-multiple-statement-macro,\ + -bugprone-narrowing-conversions,\ + -bugprone-no-escape,\ + -bugprone-not-null-terminated-result,\ + -bugprone-parent-virtual-call,\ + -bugprone-posix-return,\ + -bugprone-redundant-branch-condition,\ + -bugprone-reserved-identifier,\ + -bugprone-signal-handler,\ + -bugprone-signed-char-misuse,\ + -bugprone-sizeof-container,\ + -bugprone-sizeof-expression,\ + -bugprone-spuriously-wake-up-functions,\ + -bugprone-string-constructor,\ + -bugprone-string-integer-assignment,\ + -bugprone-string-literal-with-embedded-nul,\ + -bugprone-stringview-nullptr,\ + -bugprone-suspicious-enum-usage,\ + -bugprone-suspicious-include,\ + -bugprone-suspicious-memory-comparison,\ + -bugprone-suspicious-memset-usage,\ + -bugprone-suspicious-missing-comma,\ + -bugprone-suspicious-semicolon,\ + -bugprone-suspicious-string-compare,\ + -bugprone-swapped-arguments,\ + -bugprone-terminating-continue,\ + -bugprone-throw-keyword-missing,\ + -bugprone-too-small-loop-variable,\ + -bugprone-undefined-memory-manipulation,\ + -bugprone-undelegated-constructor,\ + -bugprone-unhandled-exception-at-new,\ + -bugprone-unhandled-self-assignment,\ + -bugprone-unused-raii,\ + -bugprone-unused-return-value,\ + -bugprone-use-after-move,\ + -bugprone-virtual-near-miss,\ + cert-*,\ + -cert-dcl16-c,\ + -cert-dcl21-cpp,\ + -cert-dcl37-c,\ + -cert-dcl50-cpp,\ + -cert-dcl51-cpp,\ + -cert-dcl54-cpp,\ + -cert-dcl58-cpp,\ + -cert-err33-c,\ + -cert-msc30-c,\ + -cert-msc32-c,\ + -cert-msc50-cpp,\ + -cert-msc51-cpp,\ + -cert-oop54-cpp,\ + -cert-str34-c,\ + -cert-str34-c,\ + -cert-str34-c,\ + -cert-str34-c,\ + -clang-analyzer-*,\ + concurrency-*,\ + -concurrency-mt-unsafe,\ + cppcoreguidelines-*,\ + -concurrency-mt-unsafe,\ + -cppcoreguidelines-avoid-c-arrays,\ + -cppcoreguidelines-avoid-const-or-ref-data-members,\ + -cppcoreguidelines-avoid-goto,\ + -cppcoreguidelines-avoid-magic-numbers,\ + -cppcoreguidelines-avoid-non-const-global-variables,\ + -cppcoreguidelines-c-copy-assignment-signature,\ + -cppcoreguidelines-explicit-virtual-functions,\ + -cppcoreguidelines-init-variables,\ + -cppcoreguidelines-interfaces-global-init,\ + -cppcoreguidelines-macro-usage,\ + -cppcoreguidelines-narrowing-conversions,\ + -cppcoreguidelines-no-malloc,\ + -cppcoreguidelines-non-private-member-variables-in-classes,\ + -cppcoreguidelines-owning-memory,\ + -cppcoreguidelines-prefer-member-initializer,\ + -cppcoreguidelines-pro-bounds-array-to-pointer-decay,\ + -cppcoreguidelines-pro-bounds-constant-array-index,\ + -cppcoreguidelines-pro-bounds-pointer-arithmetic,\ + -cppcoreguidelines-pro-type-const-cast,\ + -cppcoreguidelines-pro-type-member-init,\ + -cppcoreguidelines-pro-type-reinterpret-cast,\ + -cppcoreguidelines-pro-type-static-cast-downcast,\ + -cppcoreguidelines-pro-type-union-access,\ + -cppcoreguidelines-pro-type-vararg,\ + -cppcoreguidelines-slicing,\ + -cppcoreguidelines-special-member-functions,\ + -cppcoreguidelines-virtual-class-destructor,\ + google-*,\ + -google-default-arguments,\ + -google-explicit-constructor,\ + -google-readability-avoid-underscore-in-googletest-name,\ + -google-readability-braces-around-statements,\ + -google-readability-casting,\ + -google-readability-namespace-comments,\ + -google-readability-todo,\ + -google-runtime-int,\ + -google-upgrade-googletest-case,\ + misc-*,\ + -misc-misplaced-const,\ + -misc-new-delete-overloads,\ + -misc-non-private-member-variables-in-classes,\ + -misc-no-recursion,\ + -misc-redundant-expression,\ + -misc-uniqueptr-reset-release,\ + -misc-unconventional-assign-operator,\ + -misc-unused-parameters,\ + -misc-unused-using-decls,\ + modernize-*,\ + -modernize-avoid-c-arrays,\ + -modernize-concat-nested-namespaces,\ + -modernize-deprecated-headers,\ + -modernize-loop-convert,\ + -modernize-macro-to-enum,\ + -modernize-make-unique,\ + -modernize-pass-by-value,\ + -modernize-raw-string-literal,\ + -modernize-redundant-void-arg,\ + -modernize-return-braced-init-list,\ + -modernize-unary-static-assert,\ + -modernize-use-auto,\ + -modernize-use-bool-literals,\ + -modernize-use-default-member-init,\ + -modernize-use-emplace,\ + -modernize-use-equals-default,\ + -modernize-use-equals-delete,\ + -modernize-use-nodiscard,\ + -modernize-use-nullptr,\ + -modernize-use-override,\ + -modernize-use-trailing-return-type,\ + -modernize-use-transparent-functors,\ + -modernize-use-using,\ + performance-*,\ + -performance-faster-string-find,\ + -performance-for-range-copy,\ + -performance-inefficient-algorithm,\ + -performance-inefficient-string-concatenation,\ + -performance-inefficient-vector-operation,\ + -performance-move-const-arg,\ + -performance-no-automatic-move,\ + -performance-noexcept-move-constructor,\ + -performance-no-int-to-ptr,\ + -performance-trivially-destructible,\ + -performance-unnecessary-copy-initialization,\ + -performance-unnecessary-value-param,\ + portability-*,\ + readability-*,\ + -readability-avoid-const-params-in-decls,\ + -readability-braces-around-statements,\ + -readability-const-return-type,\ + -readability-container-data-pointer,\ + -readability-container-size-empty,\ + -readability-convert-member-functions-to-static,\ + -readability-else-after-return,\ + -readability-function-cognitive-complexity,\ + -readability-identifier-length,\ + -readability-implicit-bool-conversion,\ + -readability-inconsistent-declaration-parameter-name,\ + -readability-isolate-declaration,\ + -readability-magic-numbers,\ + -readability-make-member-function-const,\ + -readability-named-parameter,\ + -readability-non-const-parameter,\ + -readability-qualified-auto,\ + -readability-redundant-access-specifiers,\ + -readability-redundant-control-flow,\ + -readability-redundant-declaration,\ + -readability-redundant-member-init,\ + -readability-redundant-smartptr-get,\ + -readability-redundant-string-cstr,\ + -readability-redundant-string-init,\ + -readability-simplify-boolean-expr,\ + -readability-static-accessed-through-instance,\ + -readability-static-definition-in-anonymous-namespace,\ + -readability-suspicious-call-argument,\ + -readability-uppercase-literal-suffix,\ + -readability-use-anyofallof + " diff --git a/CMakeLists.txt b/CMakeLists.txt index 308e258..1d2a7a0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -20,6 +20,7 @@ project(gemma) set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG da250571a45826b21eebbddc1e50d0c1137dee5f) FetchContent_MakeAvailable(highway) diff --git a/compression/compress-inl.h b/compression/compress-inl.h index 588f5c6..5f11ca1 100644 --- a/compression/compress-inl.h +++ b/compression/compress-inl.h @@ -149,7 +149,7 @@ struct CompressTraits { } } - size_t remaining = num - i; + const size_t remaining = num - i; if (remaining != 0) { const VF in0 = hn::LoadN(df, in + i, remaining); const size_t remaining1 = remaining - HWY_MIN(remaining, N / 2); @@ -195,7 +195,7 @@ struct CompressTraits { } } - size_t remaining = num - i; + const size_t remaining = num - i; if (remaining != 0) { const VBF in16 = hn::LoadN(dbf, in + in_ofs + i, remaining); const VF in0 = hn::PromoteLowerTo(df, in16); @@ -287,7 +287,7 @@ struct CompressTraits { if (COMPRESS_STATS) { for (size_t i = 0; i < num; ++i) { - tls.stats.NotifyIn(in[i] * 100 + 500); + tls.stats.NotifyIn(static_cast(lroundf(in[i] * 100.0f + 500.0f))); } const hn::Repartition dbf; @@ -358,7 +358,7 @@ HWY_NOINLINE void Compress(const float* in, size_t num, }); const double t1 = hwy::platform::Now(); - const double mb = num * sizeof(in[0]) * 1E-6; + const double mb = static_cast(num) * sizeof(in[0]) * 1E-6; const double mbps = mb / (t1 - t0); fprintf(stderr, "Compress %.1f MB/s\n", mbps); diff --git a/compression/distortion.h b/compression/distortion.h index 5fd778f..f272b52 100644 --- a/compression/distortion.h +++ b/compression/distortion.h @@ -68,15 +68,15 @@ class DistortionStats { double GeomeanValueDivL1() const { if (num_rel_ == 0) return 0.0; - return exp(sum_log_rel_ / num_rel_); + return exp(sum_log_rel_ / static_cast(num_rel_)); } double PNorm() const { // p-norms are a compromise between max-norm (penalizes the largest error // without dilution, but does not notice any other errors) and L1 (all // errors contribute, but large errors are diluted by smaller ones). - const double norm3 = pow(sum_pow3_ / n_, 1.0 / 3); - const double norm6 = pow(sum_pow6_ / n_, 1.0 / 6); + const double norm3 = pow(sum_pow3_ / static_cast(n_), 1.0 / 3); + const double norm6 = pow(sum_pow6_ / static_cast(n_), 1.0 / 6); return 0.5 * (norm3 + norm6); } diff --git a/compression/nuq-inl.h b/compression/nuq-inl.h index 767014a..932afd6 100644 --- a/compression/nuq-inl.h +++ b/compression/nuq-inl.h @@ -87,7 +87,7 @@ class NuqClustering { inv_len_[0] = 0.0f; // unused for (size_t i = 0; i <= kGroupSize; ++i) { - inv_len_[i] = 1.0f / i; + inv_len_[i] = 1.0f / static_cast(i); } } @@ -229,7 +229,7 @@ class NuqClustering { const float sum = cc.SumOfSorted(start, last); const int size = static_cast(last) - static_cast(start) + 1; HWY_DASSERT(0 < size && size <= kGroupSize); - centers[k] = sum / size; + centers[k] = sum / static_cast(size); // We know the range inside sorted_and_i[]; translate to original indices, // which are stored inside each of the sorted_and_i mantissas. diff --git a/gemma.cc b/gemma.cc index 9f1e4a0..eb43e81 100644 --- a/gemma.cc +++ b/gemma.cc @@ -525,7 +525,7 @@ void GenerateImpl(GemmaImpl& gemma, const InferenceArgs& args, // In single-turn (non-chat) usage, pos and pos_offset start at 0 and are // always equal. size_t pos_offset = 0; // offset relative to pos - double prefill_start = hwy::platform::Now(); + const double prefill_start = hwy::platform::Now(); // Prefill stops before prompt.size() - 1 since the last prompt token is the // first input token for generation. @@ -547,12 +547,12 @@ void GenerateImpl(GemmaImpl& gemma, const InferenceArgs& args, if (verbosity >= 2) { // in the future this output should not occur in GenerateImpl but instead // should be available as observable state for frontend code to handle I/O. - double prefill_end = hwy::platform::Now(); - const double prefill_tok_sec = pos_offset / (prefill_end - prefill_start); + const double prefill_end = hwy::platform::Now(); + const double prefill_tok_sec = static_cast(pos_offset) / (prefill_end - prefill_start); std::cout << "\n[ Prefill tokens / sec = " << prefill_tok_sec << " ]\n"; } - double gen_start = hwy::platform::Now(); + const double gen_start = hwy::platform::Now(); HWY_DASSERT(pos_offset == prompt.size() - 1); @@ -590,9 +590,9 @@ void GenerateImpl(GemmaImpl& gemma, const InferenceArgs& args, } if (token == EOS_ID) { if (verbosity >= 2) { - double gen_end = hwy::platform::Now(); + const double gen_end = hwy::platform::Now(); const double gen_tok_sec = - (pos_offset - pos_gen_start) / (gen_end - gen_start); + static_cast(pos_offset - pos_gen_start) / (gen_end - gen_start); std::cout << "\n[ Generation tokens / sec = " << gen_tok_sec << " ]\n"; } break; @@ -689,7 +689,7 @@ hwy::AlignedFreeUniquePtr GetCompressedWeights( if (loader.ReadAll(pool)) return c_weights_u8; // Get weights, compress, and store in cache. - hwy::AlignedUniquePtr> weights = LoadWeights(model); + const hwy::AlignedUniquePtr> weights = LoadWeights(model); Compressor compressor(pool); ForEachTensor(weights.get(), *c_weights, compressor); compressor.WriteAll(pool, cache.path.c_str()); diff --git a/ops.h b/ops.h index 7619b44..8f92d82 100644 --- a/ops.h +++ b/ops.h @@ -57,6 +57,17 @@ HWY_INLINE constexpr size_t MaxCols() { return 2048; } +template +HWY_INLINE constexpr std::enable_if_t< + std::is_arithmetic_v && std::is_arithmetic_v, To> +StaticCast(From from) noexcept { + if constexpr (std::is_unsigned_v && std::is_floating_point_v) + return static_cast( + static_cast>(from)); + else + return static_cast(from); +} + template HWY_INLINE constexpr size_t RowsPerStrip() { // Aim for 128 work items to reduce pool overhead. Must be at least one @@ -341,7 +352,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm( float* HWY_RESTRICT out, size_t size) { constexpr float eps = 1e-6f; float ss = SquaredL2(x, size); - ss = 1.0f / sqrtf(ss / static_cast(size) + eps); + ss = 1.0f / sqrtf(ss / StaticCast(size) + eps); for (size_t j = 0; j < size; j++) { // Note 1.0f centering here out[j] = (1.0f + weight[j]) * (ss * x[j]); @@ -353,7 +364,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm( float* HWY_RESTRICT out, size_t size) { constexpr float eps = 1e-6f; float ss = SquaredL2(x, size); - ss = 1.0f / sqrtf(ss / static_cast(size) + eps); + ss = 1.0f / sqrtf(ss / StaticCast(size) + eps); for (size_t j = 0; j < size; j++) { // Note 1.0f centering here out[j] = (1.0f + hwy::F32FromBF16(weight[j])) * (ss * x[j]); @@ -364,7 +375,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace( const float* HWY_RESTRICT weight, float* HWY_RESTRICT inout, size_t size) { constexpr float eps = 1e-6f; float ss = SquaredL2(inout, size); - ss = 1.0f / sqrtf(ss / static_cast(size) + eps); + ss = 1.0f / sqrtf(ss / StaticCast(size) + eps); for (size_t j = 0; j < size; j++) { // Note 1.0f centering here inout[j] = (1.0f + weight[j]) * (ss * inout[j]); @@ -383,7 +394,8 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace( constexpr float eps = 1e-6f; const float ss = SquaredL2(inout, size); - const VF vss = hn::Set(df32, 1.0f / sqrtf(ss / static_cast(size) + eps)); + const VF vss = + hn::Set(df32, 1.0f / sqrtf(ss / StaticCast(size) + eps)); HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0); for (size_t i = 0; i < size; i += 2 * N32) { @@ -411,7 +423,8 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm( constexpr float eps = 1e-6f; const float ss = SquaredL2(x, size); - const VF vss = hn::Set(df32, 1.0f / sqrtf(ss / static_cast(size) + eps)); + const VF vss = + hn::Set(df32, 1.0f / sqrtf(ss / StaticCast(size) + eps)); HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0); for (size_t i = 0; i < size; i += 2 * N32) { @@ -438,7 +451,8 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm( constexpr float eps = 1e-6f; const float ss = SquaredL2(x, size); - const VF vss = hn::Set(df32, 1.0f / sqrtf(ss / size + eps)); + const VF vss = + hn::Set(df32, 1.0f / sqrtf(ss / StaticCast(size) + eps)); HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0); for (size_t i = 0; i < size; i += 2 * N32) { @@ -459,14 +473,12 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddAbsolutePositionalEmbeddings( const size_t num_timescales = dim_model / 2; const float log_timescale_increment = logf(10000.0f) / - (num_timescales != 0 - ? static_cast(static_cast(num_timescales) - 1) - : 1.0f); + (num_timescales != 0 ? StaticCast(num_timescales - 1) : 1.0f); for (size_t dim = 0; dim < num_timescales; ++dim) { const float inv_timescale = - expf(static_cast(dim) * -log_timescale_increment); - x[dim] += sinf(pos * inv_timescale); - x[num_timescales + dim] += cosf(pos * inv_timescale); + expf(StaticCast(dim) * -log_timescale_increment); + x[dim] += sinf(StaticCast(pos) * inv_timescale); + x[num_timescales + dim] += cosf(StaticCast(pos) * inv_timescale); } } @@ -475,11 +487,11 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(float* HWY_RESTRICT x, HWY_DASSERT(dim_qkv % 2 == 0); const size_t half_dim_qkv = dim_qkv / 2; for (size_t dim = 0; dim < half_dim_qkv; ++dim) { - const float freq_exponents = static_cast(2 * static_cast(dim)) / - static_cast(dim_qkv); + const float freq_exponents = + StaticCast(2 * dim) / StaticCast(dim_qkv); // Replacing with expf(ln(1E4) * freq_exponents) changes results noticeably. const float timescale = powf(10000.0f, freq_exponents); - const float theta = pos / timescale; + const float theta = StaticCast(pos) / timescale; const float cos_val = cosf(theta); const float sin_val = sinf(theta); const float x0 = x[dim]; @@ -496,11 +508,11 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(const float mul, HWY_DASSERT(dim_qkv % 2 == 0); const size_t half_dim_qkv = dim_qkv / 2; for (size_t dim = 0; dim < half_dim_qkv; ++dim) { - const float freq_exponents = static_cast(2 * static_cast(dim)) / - static_cast(dim_qkv); + const float freq_exponents = + StaticCast(2 * dim) / StaticCast(dim_qkv); // Replacing with expf(ln(1E4) * freq_exponents) changes results noticeably. const float timescale = powf(10000.0f, freq_exponents); - const float theta = pos / timescale; + const float theta = StaticCast(pos) / timescale; const float cos_val = cosf(theta); const float sin_val = sinf(theta); const float x0 = x[dim]; @@ -674,18 +686,18 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK( std::array top_k{}; // sorted from highest [0], to lowest [k-1] std::array indices{}; for (size_t i = 0; i < vocab_size; ++i) { - if (probabilities[i] < top_k[k - 1] && accept_token(static_cast(i))) { + if (probabilities[i] < top_k[k - 1] && accept_token(StaticCast(i))) { continue; } for (size_t j = 0; j < k; ++j) { - if (probabilities[i] > top_k[j] && accept_token(static_cast(i))) { + if (probabilities[i] > top_k[j] && accept_token(StaticCast(i))) { // shift elements by 1, insert the new value, move on to next value for (size_t idx = k - 1; idx > j; --idx) { top_k[idx] = top_k[idx - 1]; indices[idx] = indices[idx - 1]; } top_k[j] = probabilities[i]; - indices[j] = static_cast(i); + indices[j] = StaticCast(i); break; } }