From e588a7f45d474c2fdea4e9facfe028707a133067 Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Mon, 1 Jul 2024 09:44:41 -0700 Subject: [PATCH] Add config for att/final cap, skip max-subtract. Fixes #278 Also update includes/deps for backprop/. PiperOrigin-RevId: 648399222 --- BUILD.bazel | 3 +++ backprop/backward-inl.h | 36 +++++++++++--------------------- backprop/backward_scalar.h | 27 +++++++++--------------- backprop/backward_scalar_test.cc | 24 ++++++++++++--------- backprop/backward_test.cc | 9 +++----- backprop/forward-inl.h | 7 +++++-- backprop/forward.h | 2 -- backprop/forward_scalar.h | 23 +++++++------------- backprop/sampler.h | 3 +++ backprop/test_util.h | 1 + gemma/configs.h | 27 ++++++++++++++++++++---- gemma/gemma.cc | 11 ++++++++-- gemma/ops.h | 21 +++---------------- 13 files changed, 94 insertions(+), 100 deletions(-) diff --git a/BUILD.bazel b/BUILD.bazel index c744213..88b7b0e 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -361,6 +361,8 @@ cc_test( ], deps = [ ":backprop_scalar", + ":common", + ":gemma_lib", ":prompt", ":sampler", ":weights_raw", @@ -382,6 +384,7 @@ cc_test( deps = [ ":backprop", ":backprop_scalar", + ":common", ":gemma_lib", ":ops", ":sampler", diff --git a/backprop/backward-inl.h b/backprop/backward-inl.h index 837dc13..1ef6658 100644 --- a/backprop/backward-inl.h +++ b/backprop/backward-inl.h @@ -307,9 +307,9 @@ void LayerVJP(const LayerT& weights, } } -static HWY_NOINLINE void SoftcapVJP(const float* HWY_RESTRICT forward, +static HWY_NOINLINE void SoftcapVJP(const float cap, + const float* HWY_RESTRICT forward, float* HWY_RESTRICT backward, - const float cap, const size_t size) { namespace hn = hwy::HWY_NAMESPACE; using D = hn::ScalableTag; @@ -318,25 +318,11 @@ static HWY_NOINLINE void SoftcapVJP(const float* HWY_RESTRICT forward, const auto one = hn::Set(d, 1.0f); const auto vcap = hn::Set(d, cap); const auto vinv_cap = hn::Div(hn::Set(d, 1.0f), vcap); - - // TODO(szabadka): Investigate what to do when the argmax is not unique. - // TODO(szabadka): Use IndexOfMax from hwy when it is available. - size_t imax = std::max_element(forward, forward + size) - forward; - - hn::Transform1( - d, backward, size, forward, - [&](const auto d, const auto v, const auto y) HWY_ATTR { - const auto scaled = hn::Mul(vinv_cap, y); - return hn::Mul(v, hn::Sub(one, hn::Mul(scaled, scaled))); - }); - - backward[imax] = 0; - auto sum = hn::Zero(d); - Foreach(d, backward, size, sum, - [&sum](const auto d, const auto value) HWY_ATTR { - sum = hn::Add(sum, value); - }); - backward[imax] = -hn::ReduceSum(d, sum); + hn::Transform1(d, backward, size, forward, + [&](const auto d, const auto v, const auto y) HWY_ATTR { + const auto scaled = hn::Mul(vinv_cap, y); // = tanh + return hn::Mul(v, hn::Sub(one, hn::Mul(scaled, scaled))); + }); } static HWY_NOINLINE void CrossEntropyLossGrad( @@ -385,9 +371,11 @@ void CrossEntropyLossBackwardPass(const Prompt& prompt, kVocabSize); } - for (size_t pos = 0; pos < num_tokens; ++pos) { - SoftcapVJP(forward.logits.data() + pos * kVocabSize, - backward.logits.data() + pos * kVocabSize, 30.0f, kVocabSize); + if constexpr (TConfig::kFinalCap > 0.0f) { + for (size_t pos = 0; pos < num_tokens; ++pos) { + SoftcapVJP(TConfig::kFinalCap, forward.logits.data() + pos * kVocabSize, + backward.logits.data() + pos * kVocabSize, kVocabSize); + } } MatMulVJP( diff --git a/backprop/backward_scalar.h b/backprop/backward_scalar.h index aa652ac..77cd76f 100644 --- a/backprop/backward_scalar.h +++ b/backprop/backward_scalar.h @@ -274,20 +274,12 @@ void LayerVJP(const Layer& weights, num_tokens * kModelDim); } -template -void SoftcapVJPT(const T* y, T* dy, size_t N) { - size_t imax = std::max_element(y, y + N) - y; - T cap = 30.0; - T inv_cap = T(1.0) / cap; +template +void SoftcapVJPT(float cap, const T* y, T* dy, size_t N) { + const T inv_cap = T{1.0} / static_cast(cap); for (size_t i = 0; i < N; ++i) { - T scaled = y[i] * inv_cap; - dy[i] *= (T(1.0) - scaled * scaled); - } - dy[imax] = T(0.0); - for (size_t i = 0; i < N; ++i) { - if (i != imax) { - dy[imax] -= dy[i]; - } + T scaled = y[i] * inv_cap; // tanh + dy[i] *= (T{1.0} - scaled * scaled); } } @@ -324,10 +316,11 @@ void CrossEntropyLossBackwardPass(const Prompt& prompt, SoftmaxVJPT(forward.probs.data(), backward.logits.data(), kVocabSize, num_tokens); - for (size_t i = 0; i < num_tokens; ++i) { - SoftcapVJPT(forward.logits.data() + i * kVocabSize, - backward.logits.data() + i * kVocabSize, - kVocabSize); + if constexpr (TConfig::kFinalCap > 0.0f) { + for (size_t i = 0; i < num_tokens; ++i) { + SoftcapVJPT(TConfig::kFinalCap, forward.logits.data() + i * kVocabSize, + backward.logits.data() + i * kVocabSize, kVocabSize); + } } MatMulVJPT(weights.embedder_input_embedding.data(), diff --git a/backprop/backward_scalar_test.cc b/backprop/backward_scalar_test.cc index 2a9d99b..85f63bc 100644 --- a/backprop/backward_scalar_test.cc +++ b/backprop/backward_scalar_test.cc @@ -16,16 +16,23 @@ #include "backprop/backward_scalar.h" #include +#include #include // memset #include #include +#include #include +#include #include "gtest/gtest.h" +#include "backprop/common_scalar.h" #include "backprop/forward_scalar.h" +#include "backprop/prompt.h" #include "backprop/sampler.h" #include "backprop/test_util.h" +#include "gemma/activations.h" +#include "gemma/configs.h" #include "gemma/weights_raw.h" namespace gcpp { @@ -202,18 +209,19 @@ TEST(BackPropTest, SoftcapVJP) { std::array c_x; std::array c_y; + constexpr float kCap = 30.0f; for (int iter = 0; iter < 10; ++iter) { RandInit(x, 1.0 * (1 << iter), gen); Complexify(x, c_x); RandInit(dy, 1.0, gen); auto func = [&]() { memcpy(c_y.data(), c_x.data(), N * sizeof(c_x[0])); - Softcap(c_y.data(), N); + Softcap(kCap, c_y.data(), N); return DotT(dy.data(), c_y.data(), N); }; - Softcap(x.data(), N); + Softcap(kCap, x.data(), N); memcpy(dx.data(), dy.data(), N * sizeof(dx[0])); - SoftcapVJPT(x.data(), dx.data(), N); + SoftcapVJPT(kCap, x.data(), dx.data(), N); TestGradient(dx, c_x, func, 1e-15, 1e-14, __LINE__); } } @@ -230,10 +238,11 @@ TEST(BackPropTest, CrossEntropyLossGrad) { Prompt prompt; prompt.tokens = { 0, 1, 2, 3, 0, 3, 2, 1, 0 }; + const float kCap = 30.0f; for (int iter = 0; iter < 10; ++iter) { prompt.context_size = 1 + (iter % 6); RandInit(x, 1.0 * (1 << iter), gen); - Softcap(x.data(), V * K); + Softcap(kCap, x.data(), V * K); Softmax(x.data(), V, K); CrossEntropyLossGrad(x.data(), dx.data(), prompt, V); Complexify(x, c_x); @@ -368,7 +377,7 @@ TEST(BackPropTest, InputEmbeddingVJP) { } } -struct TestConfig { +struct TestConfig : ConfigCapNoSSM { static constexpr int kSeqLen = 18; static constexpr int kVocabSize = 12; static constexpr int kModelDim = 32; @@ -382,12 +391,7 @@ struct TestConfig { static constexpr bool kPostNormScale = false; static constexpr int kKVHeads = 1; - static constexpr int kConv1dWidth = 0; - static constexpr bool kFFBiases = false; - static constexpr bool kSoftmaxAttnOutputBiases = false; static constexpr int kGemmaLayers = kLayers; - static constexpr int kGriffinLayers = 0; - static constexpr int kNumTensorScales = 0; }; TEST(BackPropTest, LayerVJP) { diff --git a/backprop/backward_test.cc b/backprop/backward_test.cc index 146ce67..432882e 100644 --- a/backprop/backward_test.cc +++ b/backprop/backward_test.cc @@ -25,10 +25,12 @@ #include #include "backprop/backward_scalar.h" +#include "backprop/common_scalar.h" #include "backprop/forward_scalar.h" #include "backprop/sampler.h" #include "backprop/test_util.h" #include "gemma/activations.h" +#include "gemma/configs.h" #include "gemma/weights_raw.h" #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" @@ -180,7 +182,7 @@ void TestRMSNormVJP() { } } -struct TestConfig { +struct TestConfig : public ConfigCapNoSSM { static constexpr int kSeqLen = 24; static constexpr int kVocabSize = 16; static constexpr int kModelDim = 32; @@ -194,12 +196,7 @@ struct TestConfig { static constexpr bool kPostNormScale = false; static constexpr int kKVHeads = 1; - static constexpr int kConv1dWidth = 0; - static constexpr bool kFFBiases = false; - static constexpr bool kSoftmaxAttnOutputBiases = false; static constexpr int kGemmaLayers = kLayers; - static constexpr int kGriffinLayers = 0; - static constexpr int kNumTensorScales = 0; }; void TestEndToEnd() { diff --git a/backprop/forward-inl.h b/backprop/forward-inl.h index b322061..4b58036 100644 --- a/backprop/forward-inl.h +++ b/backprop/forward-inl.h @@ -266,8 +266,11 @@ float CrossEntropyLossForwardPass(const std::vector& prompt, forward.logits.data() + pos * kVocabSize, pool); } - for (size_t pos = 0; pos < num_tokens; ++pos) { - LogitsSoftCap(30.0f, forward.logits.data() + pos * kVocabSize, kVocabSize); + if constexpr (TConfig::kFinalCap > 0.0f) { + for (size_t pos = 0; pos < num_tokens; ++pos) { + LogitsSoftCap(TConfig::kFinalCap, + forward.logits.data() + pos * kVocabSize, kVocabSize); + } } hwy::CopyBytes(forward.logits.data(), forward.probs.data(), diff --git a/backprop/forward.h b/backprop/forward.h index f17c898..4950f37 100644 --- a/backprop/forward.h +++ b/backprop/forward.h @@ -16,8 +16,6 @@ #ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_H_ #define THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_H_ -#include - #include "backprop/prompt.h" #include "gemma/common.h" #include "hwy/contrib/thread_pool/thread_pool.h" diff --git a/backprop/forward_scalar.h b/backprop/forward_scalar.h index 95c5f0c..6fd58d2 100644 --- a/backprop/forward_scalar.h +++ b/backprop/forward_scalar.h @@ -93,21 +93,11 @@ void Softmax(T* x, size_t N, size_t K) { Softmax(x + i * N, N); } } -template -void Softcap(T* x, size_t N) { - auto maxreal = std::real(x[0]); - size_t imax = 0; - for (size_t i = 1; i < N; ++i) { - if (std::real(x[i]) > maxreal) { - maxreal = std::real(x[i]); - imax = i; - } - } - T cap = 30.0; - T inv_cap = T(1.0) / cap; - T xmax = x[imax]; +template +void Softcap(float cap, T* x, size_t N) { + const T inv_cap = T{1.0} / static_cast(cap); for (size_t i = 0; i < N; ++i) { - x[i] = cap * std::tanh((x[i] - xmax) * inv_cap); + x[i] = static_cast(cap) * std::tanh(x[i] * inv_cap); } } @@ -285,7 +275,10 @@ T CrossEntropyLossForwardPass(const Prompt& prompt, forward.logits.data(), kVocabSize, kModelDim, num_tokens); for (size_t pos = 0; pos < num_tokens; ++pos) { - Softcap(forward.logits.data() + pos * kVocabSize, kVocabSize); + if constexpr (TConfig::kFinalCap > 0.0f) { + Softcap(TConfig::kFinalCap, forward.logits.data() + pos * kVocabSize, + kVocabSize); + } } memcpy(forward.probs.data(), forward.logits.data(), diff --git a/backprop/sampler.h b/backprop/sampler.h index 9e67fa1..17f5762 100644 --- a/backprop/sampler.h +++ b/backprop/sampler.h @@ -16,6 +16,9 @@ #ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_SAMPLER_H_ #define THIRD_PARTY_GEMMA_CPP_GEMMA_SAMPLER_H_ +#include +#include + #include #include diff --git a/backprop/test_util.h b/backprop/test_util.h index 387b979..45ab97b 100644 --- a/backprop/test_util.h +++ b/backprop/test_util.h @@ -19,6 +19,7 @@ #include #include +#include #include #include "gtest/gtest.h" diff --git a/gemma/configs.h b/gemma/configs.h index 32d021d..1ffdfc2 100644 --- a/gemma/configs.h +++ b/gemma/configs.h @@ -99,8 +99,19 @@ struct ConfigNoSSM { static constexpr int kNumTensorScales = 0; }; +struct ConfigNoCapNoSSM : ConfigNoSSM { + static constexpr float kAttCap = 0.0f; + static constexpr float kFinalCap = 0.0f; +}; + +// For Gemma2 with SoftCap +struct ConfigCapNoSSM : ConfigNoSSM { + static constexpr float kAttCap = 50.0f; + static constexpr float kFinalCap = 30.0f; +}; + template -struct ConfigGemma27B : public ConfigNoSSM { +struct ConfigGemma27B : public ConfigCapNoSSM { using Weight = TWeight; // make accessible where we only have a TConfig static constexpr int kSeqLen = gcpp::kSeqLen; @@ -120,7 +131,7 @@ struct ConfigGemma27B : public ConfigNoSSM { }; template -struct ConfigGemma9B : public ConfigNoSSM { +struct ConfigGemma9B : public ConfigCapNoSSM { using Weight = TWeight; // make accessible where we only have a TConfig static constexpr int kSeqLen = gcpp::kSeqLen; @@ -140,7 +151,7 @@ struct ConfigGemma9B : public ConfigNoSSM { }; template -struct ConfigGemma7B : public ConfigNoSSM { +struct ConfigGemma7B : public ConfigNoCapNoSSM { using Weight = TWeight; // make accessible where we only have a TConfig static constexpr int kSeqLen = gcpp::kSeqLen; @@ -160,7 +171,7 @@ struct ConfigGemma7B : public ConfigNoSSM { }; template -struct ConfigGemma2B : public ConfigNoSSM { +struct ConfigGemma2B : public ConfigNoCapNoSSM { using Weight = TWeight; // make accessible where we only have a TConfig static constexpr int kSeqLen = gcpp::kSeqLen; @@ -197,6 +208,10 @@ struct ConfigGemmaTiny : public ConfigNoSSM { static constexpr int kTopK = gcpp::kTopK; static constexpr bool kAbsolutePE = false; static constexpr bool kPostNormScale = false; + + static constexpr float kAttCap = 0.0f; + // This is required for optimize_test to pass. + static constexpr float kFinalCap = 30.0f; }; template @@ -251,6 +266,10 @@ struct ConfigGriffin2B { static constexpr bool kAbsolutePE = false; static constexpr bool kPostNormScale = false; + // No SoftCap. + static constexpr float kAttCap = 0.0f; + static constexpr float kFinalCap = 0.0f; + // SSM config. static constexpr int kConv1dWidth = 4; static constexpr bool kFFBiases = true; diff --git a/gemma/gemma.cc b/gemma/gemma.cc index 09a5aab..3cfceb3 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -484,7 +484,11 @@ HWY_NOINLINE void Attention( const float score = Dot(q, k2, kQKVDim); head_att[pos2 % kSeqLen] = score; } - Softmax(head_att, std::min(pos + 1, kSeqLen)); + const size_t head_att_len = std::min(pos + 1, kSeqLen); + if constexpr (TConfig::kAttCap > 0.0f) { + LogitsSoftCap(TConfig::kAttCap, head_att, head_att_len); + } + Softmax(head_att, head_att_len); // Weighted summation float* HWY_RESTRICT att_out = activations.att_out.data() + head * kQKVDim + @@ -979,7 +983,10 @@ void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8, MatVec( weights.embedder_input_embedding, 0, x, activations.even_odd.data(), logits, pool); - LogitsSoftCap(30.0f, logits, kVocabSize); + if constexpr (TConfig::kFinalCap > 0.0f) { + LogitsSoftCap(TConfig::kFinalCap, activations.logits.data(), + kVocabSize); + } // Barrier: must have all logits so we can subtract max. Softmax(logits, kVocabSize); token = sample_token(logits, kVocabSize); diff --git a/gemma/ops.h b/gemma/ops.h index 0396883..cdb7541 100644 --- a/gemma/ops.h +++ b/gemma/ops.h @@ -1773,24 +1773,9 @@ static HWY_NOINLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x, const V vcap = hn::Set(d, cap); const V vinv_cap = hn::Div(hn::Set(d, 1.0f), vcap); - // If we do not subtract the max as in softmax, values > 100 (which do occur) - // will all saturate to the cap, and then this function is no longer - // monotonic, which would change the results on TopK. - const V vmin = hn::Set(d, hwy::LowestValue()); - V vmax = vmin; - Foreach(d, x, max_pos, vmin, - [&vmax](const auto d, const auto value) - HWY_ATTR { vmax = hn::Max(vmax, value); }); - vmax = hn::MaxOfLanes(d, vmax); - - // We want (v-vmax) * vinv_cap. To take advantage of FMA, multiply this out to - // v * vinv_cap + (-vmax*vinv_cap). - const V add = hn::Neg(hn::Mul(vmax, vinv_cap)); - - hn::Transform( - d, x, size, [&vcap, &vinv_cap, &add](D d, hn::Vec v) HWY_ATTR { - return hn::Mul(vcap, hn::Tanh(d, hn::MulAdd(v, vinv_cap, add))); - }); + hn::Transform(d, x, max_pos, [&vcap, &vinv_cap](D d, hn::Vec v) HWY_ATTR { + return hn::Mul(vcap, hn::Tanh(d, hn::Mul(v, vinv_cap))); + }); } static HWY_INLINE HWY_MAYBE_UNUSED void LogitsSoftCap(const float cap,