mirror of https://github.com/google/gemma.cpp.git
Merge pull request #65 from enum-class:narrowing-issues
PiperOrigin-RevId: 612279564
This commit is contained in:
commit
cd7468199c
|
|
@ -1 +1,2 @@
|
|||
Language: Cpp
|
||||
BasedOnStyle: Google
|
||||
|
|
|
|||
|
|
@ -0,0 +1,206 @@
|
|||
FormatStyle: file
|
||||
Checks: "-*,\
|
||||
abseil-*,\
|
||||
-abseil-string-find-startswith,\
|
||||
-abseil-string-find-str-contains,\
|
||||
bugprone-*,\
|
||||
-bugprone-argument-comment,\
|
||||
-bugprone-assert-side-effect,\
|
||||
-bugprone-bad-signal-to-kill-thread,\
|
||||
-bugprone-bool-pointer-implicit-conversion,\
|
||||
-bugprone-branch-clone,\
|
||||
-bugprone-copy-constructor-init,\
|
||||
-bugprone-dangling-handle,\
|
||||
-bugprone-dynamic-static-initializers,\
|
||||
-bugprone-easily-swappable-parameters,\
|
||||
-bugprone-exception-escape,\
|
||||
-bugprone-fold-init-type,\
|
||||
-bugprone-forward-declaration-namespace,\
|
||||
-bugprone-forwarding-reference-overload,\
|
||||
-bugprone-implicit-widening-of-multiplication-result,\
|
||||
-bugprone-inaccurate-erase,\
|
||||
-bugprone-incorrect-roundings,\
|
||||
-bugprone-infinite-loop,\
|
||||
-bugprone-integer-division,\
|
||||
-bugprone-lambda-function-name,\
|
||||
-bugprone-macro-parentheses,\
|
||||
-bugprone-macro-repeated-side-effects,\
|
||||
-bugprone-misplaced-operator-in-strlen-in-alloc,\
|
||||
-bugprone-misplaced-widening-cast,\
|
||||
-bugprone-move-forwarding-reference,\
|
||||
-bugprone-multiple-statement-macro,\
|
||||
-bugprone-narrowing-conversions,\
|
||||
-bugprone-no-escape,\
|
||||
-bugprone-not-null-terminated-result,\
|
||||
-bugprone-parent-virtual-call,\
|
||||
-bugprone-posix-return,\
|
||||
-bugprone-redundant-branch-condition,\
|
||||
-bugprone-reserved-identifier,\
|
||||
-bugprone-signal-handler,\
|
||||
-bugprone-signed-char-misuse,\
|
||||
-bugprone-sizeof-container,\
|
||||
-bugprone-sizeof-expression,\
|
||||
-bugprone-spuriously-wake-up-functions,\
|
||||
-bugprone-string-constructor,\
|
||||
-bugprone-string-integer-assignment,\
|
||||
-bugprone-string-literal-with-embedded-nul,\
|
||||
-bugprone-stringview-nullptr,\
|
||||
-bugprone-suspicious-enum-usage,\
|
||||
-bugprone-suspicious-include,\
|
||||
-bugprone-suspicious-memory-comparison,\
|
||||
-bugprone-suspicious-memset-usage,\
|
||||
-bugprone-suspicious-missing-comma,\
|
||||
-bugprone-suspicious-semicolon,\
|
||||
-bugprone-suspicious-string-compare,\
|
||||
-bugprone-swapped-arguments,\
|
||||
-bugprone-terminating-continue,\
|
||||
-bugprone-throw-keyword-missing,\
|
||||
-bugprone-too-small-loop-variable,\
|
||||
-bugprone-undefined-memory-manipulation,\
|
||||
-bugprone-undelegated-constructor,\
|
||||
-bugprone-unhandled-exception-at-new,\
|
||||
-bugprone-unhandled-self-assignment,\
|
||||
-bugprone-unused-raii,\
|
||||
-bugprone-unused-return-value,\
|
||||
-bugprone-use-after-move,\
|
||||
-bugprone-virtual-near-miss,\
|
||||
cert-*,\
|
||||
-cert-dcl16-c,\
|
||||
-cert-dcl21-cpp,\
|
||||
-cert-dcl37-c,\
|
||||
-cert-dcl50-cpp,\
|
||||
-cert-dcl51-cpp,\
|
||||
-cert-dcl54-cpp,\
|
||||
-cert-dcl58-cpp,\
|
||||
-cert-err33-c,\
|
||||
-cert-msc30-c,\
|
||||
-cert-msc32-c,\
|
||||
-cert-msc50-cpp,\
|
||||
-cert-msc51-cpp,\
|
||||
-cert-oop54-cpp,\
|
||||
-cert-str34-c,\
|
||||
-clang-analyzer-*,\
|
||||
concurrency-*,\
|
||||
-concurrency-mt-unsafe,\
|
||||
cppcoreguidelines-*,\
|
||||
-cppcoreguidelines-avoid-c-arrays,\
|
||||
-cppcoreguidelines-avoid-const-or-ref-data-members,\
|
||||
-cppcoreguidelines-avoid-goto,\
|
||||
-cppcoreguidelines-avoid-magic-numbers,\
|
||||
-cppcoreguidelines-avoid-non-const-global-variables,\
|
||||
-cppcoreguidelines-c-copy-assignment-signature,\
|
||||
-cppcoreguidelines-explicit-virtual-functions,\
|
||||
-cppcoreguidelines-init-variables,\
|
||||
-cppcoreguidelines-interfaces-global-init,\
|
||||
-cppcoreguidelines-macro-usage,\
|
||||
-cppcoreguidelines-narrowing-conversions,\
|
||||
-cppcoreguidelines-no-malloc,\
|
||||
-cppcoreguidelines-non-private-member-variables-in-classes,\
|
||||
-cppcoreguidelines-owning-memory,\
|
||||
-cppcoreguidelines-prefer-member-initializer,\
|
||||
-cppcoreguidelines-pro-bounds-array-to-pointer-decay,\
|
||||
-cppcoreguidelines-pro-bounds-constant-array-index,\
|
||||
-cppcoreguidelines-pro-bounds-pointer-arithmetic,\
|
||||
-cppcoreguidelines-pro-type-const-cast,\
|
||||
-cppcoreguidelines-pro-type-member-init,\
|
||||
-cppcoreguidelines-pro-type-reinterpret-cast,\
|
||||
-cppcoreguidelines-pro-type-static-cast-downcast,\
|
||||
-cppcoreguidelines-pro-type-union-access,\
|
||||
-cppcoreguidelines-pro-type-vararg,\
|
||||
-cppcoreguidelines-slicing,\
|
||||
-cppcoreguidelines-special-member-functions,\
|
||||
-cppcoreguidelines-virtual-class-destructor,\
|
||||
google-*,\
|
||||
-google-default-arguments,\
|
||||
-google-explicit-constructor,\
|
||||
-google-readability-avoid-underscore-in-googletest-name,\
|
||||
-google-readability-braces-around-statements,\
|
||||
-google-readability-casting,\
|
||||
-google-readability-namespace-comments,\
|
||||
-google-readability-todo,\
|
||||
-google-runtime-int,\
|
||||
-google-upgrade-googletest-case,\
|
||||
misc-*,\
|
||||
-misc-misplaced-const,\
|
||||
-misc-new-delete-overloads,\
|
||||
-misc-non-private-member-variables-in-classes,\
|
||||
-misc-no-recursion,\
|
||||
-misc-redundant-expression,\
|
||||
-misc-uniqueptr-reset-release,\
|
||||
-misc-unconventional-assign-operator,\
|
||||
-misc-unused-parameters,\
|
||||
-misc-unused-using-decls,\
|
||||
modernize-*,\
|
||||
-modernize-avoid-c-arrays,\
|
||||
-modernize-concat-nested-namespaces,\
|
||||
-modernize-deprecated-headers,\
|
||||
-modernize-loop-convert,\
|
||||
-modernize-macro-to-enum,\
|
||||
-modernize-make-unique,\
|
||||
-modernize-pass-by-value,\
|
||||
-modernize-raw-string-literal,\
|
||||
-modernize-redundant-void-arg,\
|
||||
-modernize-return-braced-init-list,\
|
||||
-modernize-unary-static-assert,\
|
||||
-modernize-use-auto,\
|
||||
-modernize-use-bool-literals,\
|
||||
-modernize-use-default-member-init,\
|
||||
-modernize-use-emplace,\
|
||||
-modernize-use-equals-default,\
|
||||
-modernize-use-equals-delete,\
|
||||
-modernize-use-nodiscard,\
|
||||
-modernize-use-nullptr,\
|
||||
-modernize-use-override,\
|
||||
-modernize-use-trailing-return-type,\
|
||||
-modernize-use-transparent-functors,\
|
||||
-modernize-use-using,\
|
||||
performance-*,\
|
||||
-performance-faster-string-find,\
|
||||
-performance-for-range-copy,\
|
||||
-performance-inefficient-algorithm,\
|
||||
-performance-inefficient-string-concatenation,\
|
||||
-performance-inefficient-vector-operation,\
|
||||
-performance-move-const-arg,\
|
||||
-performance-no-automatic-move,\
|
||||
-performance-noexcept-move-constructor,\
|
||||
-performance-no-int-to-ptr,\
|
||||
-performance-trivially-destructible,\
|
||||
-performance-unnecessary-copy-initialization,\
|
||||
-performance-unnecessary-value-param,\
|
||||
portability-*,\
|
||||
readability-*,\
|
||||
-readability-avoid-const-params-in-decls,\
|
||||
-readability-braces-around-statements,\
|
||||
-readability-const-return-type,\
|
||||
-readability-container-data-pointer,\
|
||||
-readability-container-size-empty,\
|
||||
-readability-convert-member-functions-to-static,\
|
||||
-readability-else-after-return,\
|
||||
-readability-function-cognitive-complexity,\
|
||||
-readability-identifier-length,\
|
||||
-readability-implicit-bool-conversion,\
|
||||
-readability-inconsistent-declaration-parameter-name,\
|
||||
-readability-isolate-declaration,\
|
||||
-readability-magic-numbers,\
|
||||
-readability-make-member-function-const,\
|
||||
-readability-named-parameter,\
|
||||
-readability-non-const-parameter,\
|
||||
-readability-qualified-auto,\
|
||||
-readability-redundant-access-specifiers,\
|
||||
-readability-redundant-control-flow,\
|
||||
-readability-redundant-declaration,\
|
||||
-readability-redundant-member-init,\
|
||||
-readability-redundant-smartptr-get,\
|
||||
-readability-redundant-string-cstr,\
|
||||
-readability-redundant-string-init,\
|
||||
-readability-simplify-boolean-expr,\
|
||||
-readability-static-accessed-through-instance,\
|
||||
-readability-static-definition-in-anonymous-namespace,\
|
||||
-readability-suspicious-call-argument,\
|
||||
-readability-uppercase-literal-suffix,\
|
||||
-readability-use-anyofallof
|
||||
"
|
||||
|
|
@ -20,6 +20,7 @@ project(gemma)
|
|||
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
||||
|
||||
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG da250571a45826b21eebbddc1e50d0c1137dee5f)
|
||||
FetchContent_MakeAvailable(highway)
|
||||
|
|
|
|||
|
|
@ -149,7 +149,7 @@ struct CompressTraits<hwy::bfloat16_t> {
|
|||
}
|
||||
}
|
||||
|
||||
size_t remaining = num - i;
|
||||
const size_t remaining = num - i;
|
||||
if (remaining != 0) {
|
||||
const VF in0 = hn::LoadN(df, in + i, remaining);
|
||||
const size_t remaining1 = remaining - HWY_MIN(remaining, N / 2);
|
||||
|
|
@ -195,7 +195,7 @@ struct CompressTraits<hwy::bfloat16_t> {
|
|||
}
|
||||
}
|
||||
|
||||
size_t remaining = num - i;
|
||||
const size_t remaining = num - i;
|
||||
if (remaining != 0) {
|
||||
const VBF in16 = hn::LoadN(dbf, in + in_ofs + i, remaining);
|
||||
const VF in0 = hn::PromoteLowerTo(df, in16);
|
||||
|
|
@ -287,7 +287,7 @@ struct CompressTraits<NuqStream> {
|
|||
|
||||
if (COMPRESS_STATS) {
|
||||
for (size_t i = 0; i < num; ++i) {
|
||||
tls.stats.NotifyIn(in[i] * 100 + 500);
|
||||
tls.stats.NotifyIn(static_cast<int>(lroundf(in[i] * 100.0f + 500.0f)));
|
||||
}
|
||||
|
||||
const hn::Repartition<hwy::bfloat16_t, DF> dbf;
|
||||
|
|
@ -358,7 +358,7 @@ HWY_NOINLINE void Compress(const float* in, size_t num,
|
|||
});
|
||||
|
||||
const double t1 = hwy::platform::Now();
|
||||
const double mb = num * sizeof(in[0]) * 1E-6;
|
||||
const double mb = static_cast<double>(num) * sizeof(in[0]) * 1E-6;
|
||||
const double mbps = mb / (t1 - t0);
|
||||
fprintf(stderr, "Compress %.1f MB/s\n", mbps);
|
||||
|
||||
|
|
|
|||
|
|
@ -68,15 +68,15 @@ class DistortionStats {
|
|||
|
||||
double GeomeanValueDivL1() const {
|
||||
if (num_rel_ == 0) return 0.0;
|
||||
return exp(sum_log_rel_ / num_rel_);
|
||||
return exp(sum_log_rel_ / static_cast<double>(num_rel_));
|
||||
}
|
||||
|
||||
double PNorm() const {
|
||||
// p-norms are a compromise between max-norm (penalizes the largest error
|
||||
// without dilution, but does not notice any other errors) and L1 (all
|
||||
// errors contribute, but large errors are diluted by smaller ones).
|
||||
const double norm3 = pow(sum_pow3_ / n_, 1.0 / 3);
|
||||
const double norm6 = pow(sum_pow6_ / n_, 1.0 / 6);
|
||||
const double norm3 = pow(sum_pow3_ / static_cast<double>(n_), 1.0 / 3);
|
||||
const double norm6 = pow(sum_pow6_ / static_cast<double>(n_), 1.0 / 6);
|
||||
return 0.5 * (norm3 + norm6);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -87,7 +87,7 @@ class NuqClustering {
|
|||
|
||||
inv_len_[0] = 0.0f; // unused
|
||||
for (size_t i = 0; i <= kGroupSize; ++i) {
|
||||
inv_len_[i] = 1.0f / i;
|
||||
inv_len_[i] = 1.0f / static_cast<float>(i);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -229,7 +229,7 @@ class NuqClustering {
|
|||
const float sum = cc.SumOfSorted(start, last);
|
||||
const int size = static_cast<int>(last) - static_cast<int>(start) + 1;
|
||||
HWY_DASSERT(0 < size && size <= kGroupSize);
|
||||
centers[k] = sum / size;
|
||||
centers[k] = sum / static_cast<float>(size);
|
||||
|
||||
// We know the range inside sorted_and_i[]; translate to original indices,
|
||||
// which are stored inside each of the sorted_and_i mantissas.
|
||||
|
|
|
|||
14
gemma.cc
14
gemma.cc
|
|
@ -525,7 +525,7 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, const InferenceArgs& args,
|
|||
// In single-turn (non-chat) usage, pos and pos_offset start at 0 and are
|
||||
// always equal.
|
||||
size_t pos_offset = 0; // offset relative to pos
|
||||
double prefill_start = hwy::platform::Now();
|
||||
const double prefill_start = hwy::platform::Now();
|
||||
|
||||
// Prefill stops before prompt.size() - 1 since the last prompt token is the
|
||||
// first input token for generation.
|
||||
|
|
@ -547,12 +547,12 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, const InferenceArgs& args,
|
|||
if (verbosity >= 2) {
|
||||
// in the future this output should not occur in GenerateImpl but instead
|
||||
// should be available as observable state for frontend code to handle I/O.
|
||||
double prefill_end = hwy::platform::Now();
|
||||
const double prefill_tok_sec = pos_offset / (prefill_end - prefill_start);
|
||||
const double prefill_end = hwy::platform::Now();
|
||||
const double prefill_tok_sec = static_cast<double>(pos_offset) / (prefill_end - prefill_start);
|
||||
std::cout << "\n[ Prefill tokens / sec = " << prefill_tok_sec << " ]\n";
|
||||
}
|
||||
|
||||
double gen_start = hwy::platform::Now();
|
||||
const double gen_start = hwy::platform::Now();
|
||||
|
||||
HWY_DASSERT(pos_offset == prompt.size() - 1);
|
||||
|
||||
|
|
@ -590,9 +590,9 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, const InferenceArgs& args,
|
|||
}
|
||||
if (token == EOS_ID) {
|
||||
if (verbosity >= 2) {
|
||||
double gen_end = hwy::platform::Now();
|
||||
const double gen_end = hwy::platform::Now();
|
||||
const double gen_tok_sec =
|
||||
(pos_offset - pos_gen_start) / (gen_end - gen_start);
|
||||
static_cast<double>(pos_offset - pos_gen_start) / (gen_end - gen_start);
|
||||
std::cout << "\n[ Generation tokens / sec = " << gen_tok_sec << " ]\n";
|
||||
}
|
||||
break;
|
||||
|
|
@ -689,7 +689,7 @@ hwy::AlignedFreeUniquePtr<uint8_t[]> GetCompressedWeights(
|
|||
if (loader.ReadAll(pool)) return c_weights_u8;
|
||||
|
||||
// Get weights, compress, and store in cache.
|
||||
hwy::AlignedUniquePtr<Weights<TConfig>> weights = LoadWeights<TConfig>(model);
|
||||
const hwy::AlignedUniquePtr<Weights<TConfig>> weights = LoadWeights<TConfig>(model);
|
||||
Compressor compressor(pool);
|
||||
ForEachTensor<TConfig>(weights.get(), *c_weights, compressor);
|
||||
compressor.WriteAll(pool, cache.path.c_str());
|
||||
|
|
|
|||
54
ops.h
54
ops.h
|
|
@ -57,6 +57,17 @@ HWY_INLINE constexpr size_t MaxCols() {
|
|||
return 2048;
|
||||
}
|
||||
|
||||
template <typename To, typename From>
|
||||
HWY_INLINE constexpr std::enable_if_t<
|
||||
std::is_arithmetic_v<To> && std::is_arithmetic_v<From>, To>
|
||||
StaticCast(From from) noexcept {
|
||||
if constexpr (std::is_unsigned_v<From> && std::is_floating_point_v<To>)
|
||||
return static_cast<To>(
|
||||
static_cast<hwy::SignedFromSize<sizeof(From)>>(from));
|
||||
else
|
||||
return static_cast<To>(from);
|
||||
}
|
||||
|
||||
template <size_t kOuter>
|
||||
HWY_INLINE constexpr size_t RowsPerStrip() {
|
||||
// Aim for 128 work items to reduce pool overhead. Must be at least one
|
||||
|
|
@ -341,7 +352,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
|
|||
float* HWY_RESTRICT out, size_t size) {
|
||||
constexpr float eps = 1e-6f;
|
||||
float ss = SquaredL2(x, size);
|
||||
ss = 1.0f / sqrtf(ss / static_cast<int>(size) + eps);
|
||||
ss = 1.0f / sqrtf(ss / StaticCast<float>(size) + eps);
|
||||
for (size_t j = 0; j < size; j++) {
|
||||
// Note 1.0f centering here
|
||||
out[j] = (1.0f + weight[j]) * (ss * x[j]);
|
||||
|
|
@ -353,7 +364,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
|
|||
float* HWY_RESTRICT out, size_t size) {
|
||||
constexpr float eps = 1e-6f;
|
||||
float ss = SquaredL2(x, size);
|
||||
ss = 1.0f / sqrtf(ss / static_cast<int>(size) + eps);
|
||||
ss = 1.0f / sqrtf(ss / StaticCast<float>(size) + eps);
|
||||
for (size_t j = 0; j < size; j++) {
|
||||
// Note 1.0f centering here
|
||||
out[j] = (1.0f + hwy::F32FromBF16(weight[j])) * (ss * x[j]);
|
||||
|
|
@ -364,7 +375,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(
|
|||
const float* HWY_RESTRICT weight, float* HWY_RESTRICT inout, size_t size) {
|
||||
constexpr float eps = 1e-6f;
|
||||
float ss = SquaredL2(inout, size);
|
||||
ss = 1.0f / sqrtf(ss / static_cast<int>(size) + eps);
|
||||
ss = 1.0f / sqrtf(ss / StaticCast<float>(size) + eps);
|
||||
for (size_t j = 0; j < size; j++) {
|
||||
// Note 1.0f centering here
|
||||
inout[j] = (1.0f + weight[j]) * (ss * inout[j]);
|
||||
|
|
@ -383,7 +394,8 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(
|
|||
|
||||
constexpr float eps = 1e-6f;
|
||||
const float ss = SquaredL2(inout, size);
|
||||
const VF vss = hn::Set(df32, 1.0f / sqrtf(ss / static_cast<int>(size) + eps));
|
||||
const VF vss =
|
||||
hn::Set(df32, 1.0f / sqrtf(ss / StaticCast<float>(size) + eps));
|
||||
|
||||
HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0);
|
||||
for (size_t i = 0; i < size; i += 2 * N32) {
|
||||
|
|
@ -411,7 +423,8 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
|
|||
|
||||
constexpr float eps = 1e-6f;
|
||||
const float ss = SquaredL2(x, size);
|
||||
const VF vss = hn::Set(df32, 1.0f / sqrtf(ss / static_cast<int>(size) + eps));
|
||||
const VF vss =
|
||||
hn::Set(df32, 1.0f / sqrtf(ss / StaticCast<float>(size) + eps));
|
||||
|
||||
HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0);
|
||||
for (size_t i = 0; i < size; i += 2 * N32) {
|
||||
|
|
@ -438,7 +451,8 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
|
|||
|
||||
constexpr float eps = 1e-6f;
|
||||
const float ss = SquaredL2(x, size);
|
||||
const VF vss = hn::Set(df32, 1.0f / sqrtf(ss / size + eps));
|
||||
const VF vss =
|
||||
hn::Set(df32, 1.0f / sqrtf(ss / StaticCast<float>(size) + eps));
|
||||
|
||||
HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0);
|
||||
for (size_t i = 0; i < size; i += 2 * N32) {
|
||||
|
|
@ -459,14 +473,12 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddAbsolutePositionalEmbeddings(
|
|||
const size_t num_timescales = dim_model / 2;
|
||||
const float log_timescale_increment =
|
||||
logf(10000.0f) /
|
||||
(num_timescales != 0
|
||||
? static_cast<float>(static_cast<int>(num_timescales) - 1)
|
||||
: 1.0f);
|
||||
(num_timescales != 0 ? StaticCast<float>(num_timescales - 1) : 1.0f);
|
||||
for (size_t dim = 0; dim < num_timescales; ++dim) {
|
||||
const float inv_timescale =
|
||||
expf(static_cast<int>(dim) * -log_timescale_increment);
|
||||
x[dim] += sinf(pos * inv_timescale);
|
||||
x[num_timescales + dim] += cosf(pos * inv_timescale);
|
||||
expf(StaticCast<float>(dim) * -log_timescale_increment);
|
||||
x[dim] += sinf(StaticCast<float>(pos) * inv_timescale);
|
||||
x[num_timescales + dim] += cosf(StaticCast<float>(pos) * inv_timescale);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -475,11 +487,11 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(float* HWY_RESTRICT x,
|
|||
HWY_DASSERT(dim_qkv % 2 == 0);
|
||||
const size_t half_dim_qkv = dim_qkv / 2;
|
||||
for (size_t dim = 0; dim < half_dim_qkv; ++dim) {
|
||||
const float freq_exponents = static_cast<float>(2 * static_cast<int>(dim)) /
|
||||
static_cast<float>(dim_qkv);
|
||||
const float freq_exponents =
|
||||
StaticCast<float>(2 * dim) / StaticCast<float>(dim_qkv);
|
||||
// Replacing with expf(ln(1E4) * freq_exponents) changes results noticeably.
|
||||
const float timescale = powf(10000.0f, freq_exponents);
|
||||
const float theta = pos / timescale;
|
||||
const float theta = StaticCast<float>(pos) / timescale;
|
||||
const float cos_val = cosf(theta);
|
||||
const float sin_val = sinf(theta);
|
||||
const float x0 = x[dim];
|
||||
|
|
@ -496,11 +508,11 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(const float mul,
|
|||
HWY_DASSERT(dim_qkv % 2 == 0);
|
||||
const size_t half_dim_qkv = dim_qkv / 2;
|
||||
for (size_t dim = 0; dim < half_dim_qkv; ++dim) {
|
||||
const float freq_exponents = static_cast<float>(2 * static_cast<int>(dim)) /
|
||||
static_cast<float>(dim_qkv);
|
||||
const float freq_exponents =
|
||||
StaticCast<float>(2 * dim) / StaticCast<float>(dim_qkv);
|
||||
// Replacing with expf(ln(1E4) * freq_exponents) changes results noticeably.
|
||||
const float timescale = powf(10000.0f, freq_exponents);
|
||||
const float theta = pos / timescale;
|
||||
const float theta = StaticCast<float>(pos) / timescale;
|
||||
const float cos_val = cosf(theta);
|
||||
const float sin_val = sinf(theta);
|
||||
const float x0 = x[dim];
|
||||
|
|
@ -674,18 +686,18 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK(
|
|||
std::array<float, k> top_k{}; // sorted from highest [0], to lowest [k-1]
|
||||
std::array<int, k> indices{};
|
||||
for (size_t i = 0; i < vocab_size; ++i) {
|
||||
if (probabilities[i] < top_k[k - 1] && accept_token(static_cast<int>(i))) {
|
||||
if (probabilities[i] < top_k[k - 1] && accept_token(StaticCast<int>(i))) {
|
||||
continue;
|
||||
}
|
||||
for (size_t j = 0; j < k; ++j) {
|
||||
if (probabilities[i] > top_k[j] && accept_token(static_cast<int>(i))) {
|
||||
if (probabilities[i] > top_k[j] && accept_token(StaticCast<int>(i))) {
|
||||
// shift elements by 1, insert the new value, move on to next value
|
||||
for (size_t idx = k - 1; idx > j; --idx) {
|
||||
top_k[idx] = top_k[idx - 1];
|
||||
indices[idx] = indices[idx - 1];
|
||||
}
|
||||
top_k[j] = probabilities[i];
|
||||
indices[j] = static_cast<int>(i);
|
||||
indices[j] = StaticCast<int>(i);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue