diff --git a/backprop/backward-inl.h b/backprop/backward-inl.h
index 96f1f08..49d866e 100644
--- a/backprop/backward-inl.h
+++ b/backprop/backward-inl.h
@@ -22,6 +22,7 @@
 #include
+#include <algorithm>
 #include
 #include
@@ -319,12 +320,24 @@ static HWY_NOINLINE void SoftcapVJP(const float* HWY_RESTRICT forward,
   const auto vcap = hn::Set(d, cap);
   const auto vinv_cap = hn::Div(hn::Set(d, 1.0f), vcap);
 
+  // TODO(szabadka): Investigate what to do when the argmax is not unique.
+  // TODO(szabadka): Use IndexOfMax from hwy when it is available.
+  size_t imax = std::max_element(forward, forward + size) - forward;
+
   hn::Transform1(
       d, backward, size, forward,
       [&](const auto d, const auto v, const auto y) HWY_ATTR {
         const auto scaled = hn::Mul(vinv_cap, y);
         return hn::Mul(v, hn::Sub(one, hn::Mul(scaled, scaled)));
       });
+
+  backward[imax] = 0;
+  auto sum = hn::Zero(d);
+  Foreach(d, backward, size, sum,
+          [&sum](const auto d, const auto value) HWY_ATTR {
+            sum = hn::Add(sum, value);
+          });
+  backward[imax] = -hn::ReduceSum(d, sum);
 }
 
 static HWY_NOINLINE void CrossEntropyLossGrad(
diff --git a/backprop/backward_scalar.h b/backprop/backward_scalar.h
index ed18546..024e2a4 100644
--- a/backprop/backward_scalar.h
+++ b/backprop/backward_scalar.h
@@ -277,12 +277,19 @@ void LayerVJP(const Layer& weights,
 
 template <typename T>
 void SoftcapVJPT(const T* y, T* dy, size_t N) {
+  size_t imax = std::max_element(y, y + N) - y;
   T cap = 30.0;
   T inv_cap = T(1.0) / cap;
   for (size_t i = 0; i < N; ++i) {
     T scaled = y[i] * inv_cap;
     dy[i] *= (T(1.0) - scaled * scaled);
   }
+  dy[imax] = T(0.0);
+  for (size_t i = 0; i < N; ++i) {
+    if (i != imax) {
+      dy[imax] -= dy[i];
+    }
+  }
 }
 
 template <typename T>
@@ -318,8 +325,11 @@ void CrossEntropyLossBackwardPass(const Prompt& prompt,
   SoftmaxVJPT(forward.probs.data(), backward.logits.data(), kVocabSize,
               num_tokens);
 
-  SoftcapVJPT(forward.logits.data(), backward.logits.data(),
-              num_tokens * kVocabSize);
+  for (size_t i = 0; i < num_tokens; ++i) {
+    SoftcapVJPT(forward.logits.data() + i * kVocabSize,
+                backward.logits.data() + i * kVocabSize,
+                kVocabSize);
+  }
 
   MatMulVJPT(weights.embedder_input_embedding.data(),
              forward.final_norm_output.data(),
diff --git a/backprop/backward_scalar_test.cc b/backprop/backward_scalar_test.cc
index 48d7234..9a94484 100644
--- a/backprop/backward_scalar_test.cc
+++ b/backprop/backward_scalar_test.cc
@@ -210,7 +210,7 @@ TEST(BackPropTest, SoftcapVJP) {
     Softcap(x.data(), N);
     memcpy(dx.data(), dy.data(), N * sizeof(dx[0]));
     SoftcapVJPT(x.data(), dx.data(), N);
-    TestGradient(dx, c_x, func, 1e-15, 1e-15, __LINE__);
+    TestGradient(dx, c_x, func, 1e-15, 1e-14, __LINE__);
   }
 }
diff --git a/backprop/forward_scalar.h b/backprop/forward_scalar.h
index 475f7c3..8bc125e 100644
--- a/backprop/forward_scalar.h
+++ b/backprop/forward_scalar.h
@@ -95,10 +95,19 @@ void Softmax(T* x, size_t N, size_t K) {
 }
 
 template <typename T>
 void Softcap(T* x, size_t N) {
+  auto maxreal = std::real(x[0]);
+  size_t imax = 0;
+  for (size_t i = 1; i < N; ++i) {
+    if (std::real(x[i]) > maxreal) {
+      maxreal = std::real(x[i]);
+      imax = i;
+    }
+  }
   T cap = 30.0;
   T inv_cap = T(1.0) / cap;
+  T xmax = x[imax];
   for (size_t i = 0; i < N; ++i) {
-    x[i] = cap * std::tanh(x[i] * inv_cap);
+    x[i] = cap * std::tanh((x[i] - xmax) * inv_cap);
   }
 }
 
@@ -275,7 +284,9 @@ T CrossEntropyLossForwardPass(const Prompt& prompt,
              forward.final_norm_output.data(),
              forward.logits.data(), kVocabSize, kModelDim, num_tokens);
 
-  Softcap(forward.logits.data(), num_tokens * kVocabSize);
+  for (size_t pos = 0; pos < num_tokens; ++pos) {
+    Softcap(forward.logits.data() + pos * kVocabSize, kVocabSize);
+  }
 
   memcpy(forward.probs.data(), forward.logits.data(),
          num_tokens * kVocabSize * sizeof(forward.logits[0]));
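The new backward code above is the VJP of the max-shifted softcap y_i = cap * tanh((x_i - x_m) / cap), where m is the argmax of x. Treating m as a fixed index, dy_i/dx_j = (1 - (y_i/cap)^2) * (delta_ij - delta_jm) for i != m, and y_m is identically zero, so the gradient with respect to x_m is minus the sum of all the other scaled gradients. That is exactly the dy[imax] bookkeeping in SoftcapVJPT and SoftcapVJP, and the per-row argmax is why the callers now loop over token positions instead of flattening num_tokens * kVocabSize. A minimal standalone sketch of the same derivation (plain float only, hypothetical names, not code from the patch):

#include <stddef.h>
#include <algorithm>
#include <cmath>

// Forward: y[i] = cap * tanh((x[i] - x[imax]) / cap). Mirrors the patched
// scalar Softcap, restricted to real-valued float for brevity.
void SoftcapShifted(const float* x, float* y, size_t n, float cap) {
  const size_t imax = std::max_element(x, x + n) - x;
  const float xmax = x[imax];  // saved first so in-place calls stay correct
  for (size_t i = 0; i < n; ++i) {
    y[i] = cap * std::tanh((x[i] - xmax) / cap);
  }
}

// Backward: given y from above and dL/dy in dy, overwrite dy with dL/dx.
// The argmax is recoverable from y alone: y[imax] == 0 and y[i] < 0 for
// every other i (assuming a unique maximum, as the patch's TODO notes).
void SoftcapShiftedVJP(const float* y, float* dy, size_t n, float cap) {
  const size_t imax = std::max_element(y, y + n) - y;
  for (size_t i = 0; i < n; ++i) {
    const float scaled = y[i] / cap;
    dy[i] *= 1.0f - scaled * scaled;  // diagonal term of the Jacobian
  }
  dy[imax] = 0.0f;  // y[imax] is constant zero, so its own row is zero
  for (size_t i = 0; i < n; ++i) {
    if (i != imax) dy[imax] -= dy[i];  // cross-terms through x[imax]
  }
}

Calling SoftcapShiftedVJP with cap = 30.0f reproduces what SoftcapVJPT computes for T = float.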
diff --git a/gemma/gemma_test.cc b/gemma/gemma_test.cc
index 4086cbf..8817a0c 100644
--- a/gemma/gemma_test.cc
+++ b/gemma/gemma_test.cc
@@ -23,6 +23,7 @@
 #include
 
 // Placeholder for internal header, do not modify.
+#include "gemma/common.h"
 #include "gemma/cross_entropy.h"
 #include "gemma/ops.h"
 #include "util/app.h"
@@ -38,20 +39,22 @@ char** s_argv = nullptr;
 
 class GemmaTest : public ::testing::Test {
  protected:
   static void SetUpTestSuite() {
-    gcpp::LoaderArgs loader(s_argc, s_argv);
     gcpp::AppArgs app(s_argc, s_argv);
+    gcpp::LoaderArgs loader(s_argc, s_argv);
     if (const char* err = loader.Validate()) {
       fprintf(stderr, "Insufficient LoaderArgs, skipping e2e tests.\n");
     } else {
+      fprintf(stderr, "Loading model..\n");
       s_pool = std::make_unique<hwy::ThreadPool>(app.num_threads);
-      s_model = AllocateGemma(loader, *s_pool);
+      s_gemma = AllocateGemma(loader, *s_pool);
       s_kv_cache = KVCache::Create(loader.ModelType());
+      s_model = loader.ModelType();
     }
   }
 
   static void TearDownTestSuite() {
     s_pool.reset();
-    s_model.reset();
+    s_gemma.reset();
   }
@@ -59,7 +62,7 @@ class GemmaTest : public ::testing::Test {
     gen.seed(42);
 
     std::vector<int> prompt;
-    HWY_ASSERT(s_model->Tokenizer().Encode(prompt_string, &prompt));
+    HWY_ASSERT(s_gemma->Tokenizer().Encode(prompt_string, &prompt));
     // For both pre-trained and instruction-tuned models: prepend "<bos>" token
     // if needed.
     prompt.insert(prompt.begin(), BOS_ID);
@@ -78,25 +81,25 @@
         .stream_token = stream_token,
     };
     gcpp::TimingInfo timing_info;
-    s_model->Generate(runtime_config, prompt, /*start_pos=*/0, s_kv_cache,
+    s_gemma->Generate(runtime_config, prompt, /*start_pos=*/0, s_kv_cache,
                       timing_info, /*layers_output=*/nullptr);
     std::string response_text;
-    HWY_ASSERT(s_model->Tokenizer().Decode(response, &response_text));
+    HWY_ASSERT(s_gemma->Tokenizer().Decode(response, &response_text));
     return response_text;
   }
 
   float GemmaCrossEntropy(const std::string& prompt_string) {
     std::vector<int> prompt;
-    HWY_ASSERT(s_model->Tokenizer().Encode(prompt_string, &prompt));
+    HWY_ASSERT(s_gemma->Tokenizer().Encode(prompt_string, &prompt));
     prompt.insert(prompt.begin(), BOS_ID);
-    return ComputeCrossEntropy(*s_model, /*max_tokens=*/3072, prompt,
+    return ComputeCrossEntropy(*s_gemma, /*max_tokens=*/3072, prompt,
                                s_kv_cache, /*verbosity=*/0) /
            prompt_string.size();
   }
 
   void TestQuestions(const char* kQA[][2], size_t num_questions) {
-    if (!s_model) return;
+    if (!s_gemma) return;
     for (size_t i = 0; i < num_questions; ++i) {
       fprintf(stderr, "Question %zu\n\n", i + 1);
       std::string response = GemmaReply(kQA[i][0]);
@@ -106,13 +109,15 @@
   }
 
   static std::unique_ptr<hwy::ThreadPool> s_pool;
-  static std::unique_ptr<gcpp::Gemma> s_model;
+  static std::unique_ptr<gcpp::Gemma> s_gemma;
   static gcpp::KVCache s_kv_cache;
+  static gcpp::Model s_model;
 };
 
 /*static*/ std::unique_ptr<hwy::ThreadPool> GemmaTest::s_pool;
-/*static*/ std::unique_ptr<gcpp::Gemma> GemmaTest::s_model;
+/*static*/ std::unique_ptr<gcpp::Gemma> GemmaTest::s_gemma;
 /*static*/ gcpp::KVCache GemmaTest::s_kv_cache;
+/*static*/ gcpp::Model GemmaTest::s_model;
 
 TEST_F(GemmaTest, Geography) {
   static const char* kQA[][2] = {
@@ -176,27 +181,26 @@ static const char kGettysburg[] = {
     "people, for the people, shall not perish from the earth.\n"};
 
 TEST_F(GemmaTest, CrossEntropySmall) {
-  if (!s_model) return;
+  if (!s_gemma) return;
   static const char kSmall[] =
       "The capital of Hungary is Budapest which is located in Europe.";
   float entropy = GemmaCrossEntropy(kSmall);
   fprintf(stderr, "per-byte entropy: %f\n", entropy);
-  // Note that entropy is 3x higher for the 7b-it model.
-  EXPECT_LT(entropy, 1.7f);
+  EXPECT_LT(entropy, (s_model == gcpp::Model::GEMMA_7B) ? 2.1f : 2.0f);
 }
 
 TEST_F(GemmaTest, CrossEntropyJingleBells) {
-  if (!s_model) return;
+  if (!s_gemma) return;
   float entropy = GemmaCrossEntropy(kJingleBells);
   fprintf(stderr, "per-byte entropy: %f\n", entropy);
-  EXPECT_LT(entropy, 1.7f);
+  EXPECT_LT(entropy, (s_model == gcpp::Model::GEMMA_7B) ? 0.9f : 1.8f);
 }
 
 TEST_F(GemmaTest, CrossEntropyGettysburg) {
-  if (!s_model) return;
+  if (!s_gemma) return;
   float entropy = GemmaCrossEntropy(kGettysburg);
   fprintf(stderr, "per-byte entropy: %f\n", entropy);
-  EXPECT_LT(entropy, 1.2f);
+  EXPECT_LT(entropy, (s_model == gcpp::Model::GEMMA_7B) ? 0.8f : 1.2f);
 }
 
 }  // namespace
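The looser, per-model bounds are expected rather than a test regression: unlike softmax, softcap is not invariant under subtracting the max, so the shifted definition changes the logits themselves, not just their floating-point robustness, and the measured cross-entropy moves with it. A standalone demonstration of the saturation problem the shift fixes, with made-up logit values (not code from the patch):

#include <cmath>
#include <cstdio>

int main() {
  const float cap = 30.0f;
  const float x0 = 150.0f, x1 = 140.0f;  // large logits do occur in practice
  // Old definition: both values saturate to ~cap and become nearly equal,
  // so the ordering information is almost lost.
  const float old_gap = cap * std::tanh(x0 / cap) - cap * std::tanh(x1 / cap);
  // Shifted definition: the maximum maps to 0 and the gap survives.
  const float new_gap = cap * std::tanh((x0 - x0) / cap) -
                        cap * std::tanh((x1 - x0) / cap);
  std::printf("old gap: %.4f, new gap: %.4f\n", old_gap, new_gap);
  return 0;
}

With these inputs the old gap is about 0.003 while the shifted gap is about 9.6, which also illustrates why the outputs, and therefore the test thresholds, change.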
diff --git a/gemma/ops.h b/gemma/ops.h
index ff6ba58..ca6fa8e 100644
--- a/gemma/ops.h
+++ b/gemma/ops.h
@@ -1384,13 +1384,29 @@ static HWY_NOINLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x,
   namespace hn = hwy::HWY_NAMESPACE;
   using D = hn::ScalableTag<float>;
   const D d;
+  using V = hn::Vec<D>;
 
-  const auto vcap = hn::Set(d, cap);
-  const auto vinv_cap = hn::Div(hn::Set(d, 1.0f), vcap);
+  const V vcap = hn::Set(d, cap);
+  const V vinv_cap = hn::Div(hn::Set(d, 1.0f), vcap);
 
-  hn::Transform(d, x, size, [&vcap, &vinv_cap](D d, hn::Vec<D> v) HWY_ATTR {
-    return hn::Mul(vcap, hn::Tanh(d, hn::Mul(v, vinv_cap)));
-  });
+  // If we do not subtract the max as in softmax, values > 100 (which do occur)
+  // will all saturate to the cap, and then this function is no longer
+  // monotonic, which would change the results of TopK.
+  const V vmin = hn::Set(d, hwy::LowestValue<float>());
+  V vmax = vmin;
+  Foreach(d, x, max_pos, vmin,
+          [&vmax](const auto d, const auto value)
+              HWY_ATTR { vmax = hn::Max(vmax, value); });
+  vmax = hn::MaxOfLanes(d, vmax);
+
+  // We want (v - vmax) * vinv_cap. To take advantage of FMA, multiply this
+  // out to v * vinv_cap + (-vmax * vinv_cap).
+  const V add = hn::Neg(hn::Mul(vmax, vinv_cap));
+
+  hn::Transform(
+      d, x, size, [&vcap, &vinv_cap, &add](D d, hn::Vec<D> v) HWY_ATTR {
+        return hn::Mul(vcap, hn::Tanh(d, hn::MulAdd(v, vinv_cap, add)));
+      });
 }
 
 static HWY_INLINE HWY_MAYBE_UNUSED void LogitsSoftCap(const float cap,
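In the vectorized version above, the max is reduced over the first max_pos elements (a per-lane Foreach followed by a cross-lane MaxOfLanes), and the shift is then folded into a fused multiply-add. A scalar sketch of that inner loop, assuming the caller passes in the precomputed maximum (hypothetical name, not code from the patch):

#include <stddef.h>
#include <cmath>

// Equivalent of hn::MulAdd(v, vinv_cap, add): x[i] * inv_cap + add equals
// (x[i] - xmax) * inv_cap once add = -xmax * inv_cap is hoisted out of the
// loop, so the subtraction disappears from the loop body.
void LogitsSoftCapScalar(float cap, float* x, size_t size, float xmax) {
  const float inv_cap = 1.0f / cap;
  const float add = -xmax * inv_cap;  // loop-invariant
  for (size_t i = 0; i < size; ++i) {
    x[i] = cap * std::tanh(std::fma(x[i], inv_cap, add));
  }
}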