Fix numerical issue in Softcap by subtracting max.

Also update test threshold.

PiperOrigin-RevId: 642587468
This commit is contained in:
The gemma.cpp Authors 2024-06-12 05:41:33 -07:00 committed by Copybara-Service
parent e37447cfe2
commit 2a0e6ee976
6 changed files with 82 additions and 28 deletions

View File

@ -22,6 +22,7 @@
#include <stddef.h> #include <stddef.h>
#include <algorithm>
#include <array> #include <array>
#include <cmath> #include <cmath>
@ -319,12 +320,24 @@ static HWY_NOINLINE void SoftcapVJP(const float* HWY_RESTRICT forward,
const auto vcap = hn::Set(d, cap); const auto vcap = hn::Set(d, cap);
const auto vinv_cap = hn::Div(hn::Set(d, 1.0f), vcap); const auto vinv_cap = hn::Div(hn::Set(d, 1.0f), vcap);
// TODO(szabadka): Investigate what to do when the argmax is not unique.
// TODO(szabadka): Use IndexOfMax from hwy when it is available.
size_t imax = std::max_element(forward, forward + size) - forward;
hn::Transform1( hn::Transform1(
d, backward, size, forward, d, backward, size, forward,
[&](const auto d, const auto v, const auto y) HWY_ATTR { [&](const auto d, const auto v, const auto y) HWY_ATTR {
const auto scaled = hn::Mul(vinv_cap, y); const auto scaled = hn::Mul(vinv_cap, y);
return hn::Mul(v, hn::Sub(one, hn::Mul(scaled, scaled))); return hn::Mul(v, hn::Sub(one, hn::Mul(scaled, scaled)));
}); });
backward[imax] = 0;
auto sum = hn::Zero(d);
Foreach(d, backward, size, sum,
[&sum](const auto d, const auto value) HWY_ATTR {
sum = hn::Add(sum, value);
});
backward[imax] = -hn::ReduceSum(d, sum);
} }
static HWY_NOINLINE void CrossEntropyLossGrad( static HWY_NOINLINE void CrossEntropyLossGrad(

View File

@ -277,12 +277,19 @@ void LayerVJP(const Layer<T, TConfig>& weights,
template<typename T> template<typename T>
void SoftcapVJPT(const T* y, T* dy, size_t N) { void SoftcapVJPT(const T* y, T* dy, size_t N) {
size_t imax = std::max_element(y, y + N) - y;
T cap = 30.0; T cap = 30.0;
T inv_cap = T(1.0) / cap; T inv_cap = T(1.0) / cap;
for (size_t i = 0; i < N; ++i) { for (size_t i = 0; i < N; ++i) {
T scaled = y[i] * inv_cap; T scaled = y[i] * inv_cap;
dy[i] *= (T(1.0) - scaled * scaled); dy[i] *= (T(1.0) - scaled * scaled);
} }
dy[imax] = T(0.0);
for (size_t i = 0; i < N; ++i) {
if (i != imax) {
dy[imax] -= dy[i];
}
}
} }
template<typename T> template<typename T>
@ -318,8 +325,11 @@ void CrossEntropyLossBackwardPass(const Prompt& prompt,
SoftmaxVJPT(forward.probs.data(), backward.logits.data(), SoftmaxVJPT(forward.probs.data(), backward.logits.data(),
kVocabSize, num_tokens); kVocabSize, num_tokens);
SoftcapVJPT(forward.logits.data(), backward.logits.data(), for (size_t i = 0; i < num_tokens; ++i) {
num_tokens * kVocabSize); SoftcapVJPT(forward.logits.data() + i * kVocabSize,
backward.logits.data() + i * kVocabSize,
kVocabSize);
}
MatMulVJPT(weights.embedder_input_embedding.data(), MatMulVJPT(weights.embedder_input_embedding.data(),
forward.final_norm_output.data(), forward.final_norm_output.data(),

View File

@ -210,7 +210,7 @@ TEST(BackPropTest, SoftcapVJP) {
Softcap(x.data(), N); Softcap(x.data(), N);
memcpy(dx.data(), dy.data(), N * sizeof(dx[0])); memcpy(dx.data(), dy.data(), N * sizeof(dx[0]));
SoftcapVJPT(x.data(), dx.data(), N); SoftcapVJPT(x.data(), dx.data(), N);
TestGradient(dx, c_x, func, 1e-15, 1e-15, __LINE__); TestGradient(dx, c_x, func, 1e-15, 1e-14, __LINE__);
} }
} }

View File

@ -95,10 +95,19 @@ void Softmax(T* x, size_t N, size_t K) {
} }
template<typename T> template<typename T>
void Softcap(T* x, size_t N) { void Softcap(T* x, size_t N) {
auto maxreal = std::real(x[0]);
size_t imax = 0;
for (size_t i = 1; i < N; ++i) {
if (std::real(x[i]) > maxreal) {
maxreal = std::real(x[i]);
imax = i;
}
}
T cap = 30.0; T cap = 30.0;
T inv_cap = T(1.0) / cap; T inv_cap = T(1.0) / cap;
T xmax = x[imax];
for (size_t i = 0; i < N; ++i) { for (size_t i = 0; i < N; ++i) {
x[i] = cap * std::tanh(x[i] * inv_cap); x[i] = cap * std::tanh((x[i] - xmax) * inv_cap);
} }
} }
@ -275,7 +284,9 @@ T CrossEntropyLossForwardPass(const Prompt& prompt,
forward.final_norm_output.data(), forward.final_norm_output.data(),
forward.logits.data(), kVocabSize, kModelDim, num_tokens); forward.logits.data(), kVocabSize, kModelDim, num_tokens);
Softcap(forward.logits.data(), num_tokens * kVocabSize); for (size_t pos = 0; pos < num_tokens; ++pos) {
Softcap(forward.logits.data() + pos * kVocabSize, kVocabSize);
}
memcpy(forward.probs.data(), forward.logits.data(), memcpy(forward.probs.data(), forward.logits.data(),
num_tokens * kVocabSize * sizeof(forward.logits[0])); num_tokens * kVocabSize * sizeof(forward.logits[0]));

View File

@ -23,6 +23,7 @@
#include <vector> #include <vector>
// Placeholder for internal header, do not modify. // Placeholder for internal header, do not modify.
#include "gemma/common.h"
#include "gemma/cross_entropy.h" #include "gemma/cross_entropy.h"
#include "gemma/ops.h" #include "gemma/ops.h"
#include "util/app.h" #include "util/app.h"
@ -38,20 +39,22 @@ char** s_argv = nullptr;
class GemmaTest : public ::testing::Test { class GemmaTest : public ::testing::Test {
protected: protected:
static void SetUpTestSuite() { static void SetUpTestSuite() {
gcpp::LoaderArgs loader(s_argc, s_argv);
gcpp::AppArgs app(s_argc, s_argv); gcpp::AppArgs app(s_argc, s_argv);
gcpp::LoaderArgs loader(s_argc, s_argv);
if (const char* err = loader.Validate()) { if (const char* err = loader.Validate()) {
fprintf(stderr, "Insufficient LoaderArgs, skipping e2e tests.\n"); fprintf(stderr, "Insufficient LoaderArgs, skipping e2e tests.\n");
} else { } else {
fprintf(stderr, "Loading model..\n");
s_pool = std::make_unique<hwy::ThreadPool>(app.num_threads); s_pool = std::make_unique<hwy::ThreadPool>(app.num_threads);
s_model = AllocateGemma(loader, *s_pool); s_gemma = AllocateGemma(loader, *s_pool);
s_kv_cache = KVCache::Create(loader.ModelType()); s_kv_cache = KVCache::Create(loader.ModelType());
s_model = loader.ModelType();
} }
} }
static void TearDownTestSuite() { static void TearDownTestSuite() {
s_pool.reset(); s_pool.reset();
s_model.reset(); s_gemma.reset();
} }
std::string GemmaReply(const std::string& prompt_string) { std::string GemmaReply(const std::string& prompt_string) {
@ -59,7 +62,7 @@ class GemmaTest : public ::testing::Test {
gen.seed(42); gen.seed(42);
std::vector<int> prompt; std::vector<int> prompt;
HWY_ASSERT(s_model->Tokenizer().Encode(prompt_string, &prompt)); HWY_ASSERT(s_gemma->Tokenizer().Encode(prompt_string, &prompt));
// For both pre-trained and instruction-tuned models: prepend "<bos>" token // For both pre-trained and instruction-tuned models: prepend "<bos>" token
// if needed. // if needed.
prompt.insert(prompt.begin(), BOS_ID); prompt.insert(prompt.begin(), BOS_ID);
@ -78,25 +81,25 @@ class GemmaTest : public ::testing::Test {
.stream_token = stream_token, .stream_token = stream_token,
}; };
gcpp::TimingInfo timing_info; gcpp::TimingInfo timing_info;
s_model->Generate(runtime_config, prompt, /*start_pos=*/0, s_kv_cache, s_gemma->Generate(runtime_config, prompt, /*start_pos=*/0, s_kv_cache,
timing_info, /*layers_output=*/nullptr); timing_info, /*layers_output=*/nullptr);
std::string response_text; std::string response_text;
HWY_ASSERT(s_model->Tokenizer().Decode(response, &response_text)); HWY_ASSERT(s_gemma->Tokenizer().Decode(response, &response_text));
return response_text; return response_text;
} }
float GemmaCrossEntropy(const std::string& prompt_string) { float GemmaCrossEntropy(const std::string& prompt_string) {
std::vector<int> prompt; std::vector<int> prompt;
HWY_ASSERT(s_model->Tokenizer().Encode(prompt_string, &prompt)); HWY_ASSERT(s_gemma->Tokenizer().Encode(prompt_string, &prompt));
prompt.insert(prompt.begin(), BOS_ID); prompt.insert(prompt.begin(), BOS_ID);
return ComputeCrossEntropy(*s_model, /*max_tokens=*/3072, prompt, return ComputeCrossEntropy(*s_gemma, /*max_tokens=*/3072, prompt,
s_kv_cache, s_kv_cache,
/*verbosity=*/0) / /*verbosity=*/0) /
prompt_string.size(); prompt_string.size();
} }
void TestQuestions(const char* kQA[][2], size_t num_questions) { void TestQuestions(const char* kQA[][2], size_t num_questions) {
if (!s_model) return; if (!s_gemma) return;
for (size_t i = 0; i < num_questions; ++i) { for (size_t i = 0; i < num_questions; ++i) {
fprintf(stderr, "Question %zu\n\n", i + 1); fprintf(stderr, "Question %zu\n\n", i + 1);
std::string response = GemmaReply(kQA[i][0]); std::string response = GemmaReply(kQA[i][0]);
@ -106,13 +109,15 @@ class GemmaTest : public ::testing::Test {
} }
static std::unique_ptr<hwy::ThreadPool> s_pool; static std::unique_ptr<hwy::ThreadPool> s_pool;
static std::unique_ptr<gcpp::Gemma> s_model; static std::unique_ptr<gcpp::Gemma> s_gemma;
static gcpp::KVCache s_kv_cache; static gcpp::KVCache s_kv_cache;
static gcpp::Model s_model;
}; };
/*static*/ std::unique_ptr<hwy::ThreadPool> GemmaTest::s_pool; /*static*/ std::unique_ptr<hwy::ThreadPool> GemmaTest::s_pool;
/*static*/ std::unique_ptr<gcpp::Gemma> GemmaTest::s_model; /*static*/ std::unique_ptr<gcpp::Gemma> GemmaTest::s_gemma;
/*static*/ gcpp::KVCache GemmaTest::s_kv_cache; /*static*/ gcpp::KVCache GemmaTest::s_kv_cache;
/*static*/ gcpp::Model GemmaTest::s_model;
TEST_F(GemmaTest, Geography) { TEST_F(GemmaTest, Geography) {
static const char* kQA[][2] = { static const char* kQA[][2] = {
@ -176,27 +181,26 @@ static const char kGettysburg[] = {
"people, for the people, shall not perish from the earth.\n"}; "people, for the people, shall not perish from the earth.\n"};
TEST_F(GemmaTest, CrossEntropySmall) { TEST_F(GemmaTest, CrossEntropySmall) {
if (!s_model) return; if (!s_gemma) return;
static const char kSmall[] = static const char kSmall[] =
"The capital of Hungary is Budapest which is located in Europe."; "The capital of Hungary is Budapest which is located in Europe.";
float entropy = GemmaCrossEntropy(kSmall); float entropy = GemmaCrossEntropy(kSmall);
fprintf(stderr, "per-byte entropy: %f\n", entropy); fprintf(stderr, "per-byte entropy: %f\n", entropy);
// Note that entropy is 3x higher for the 7b-it model. EXPECT_LT(entropy, (s_model == gcpp::Model::GEMMA_7B) ? 2.1f : 2.0f);
EXPECT_LT(entropy, 1.7f);
} }
TEST_F(GemmaTest, CrossEntropyJingleBells) { TEST_F(GemmaTest, CrossEntropyJingleBells) {
if (!s_model) return; if (!s_gemma) return;
float entropy = GemmaCrossEntropy(kJingleBells); float entropy = GemmaCrossEntropy(kJingleBells);
fprintf(stderr, "per-byte entropy: %f\n", entropy); fprintf(stderr, "per-byte entropy: %f\n", entropy);
EXPECT_LT(entropy, 1.7f); EXPECT_LT(entropy, (s_model == gcpp::Model::GEMMA_7B) ? 0.9f : 1.8f);
} }
TEST_F(GemmaTest, CrossEntropyGettysburg) { TEST_F(GemmaTest, CrossEntropyGettysburg) {
if (!s_model) return; if (!s_gemma) return;
float entropy = GemmaCrossEntropy(kGettysburg); float entropy = GemmaCrossEntropy(kGettysburg);
fprintf(stderr, "per-byte entropy: %f\n", entropy); fprintf(stderr, "per-byte entropy: %f\n", entropy);
EXPECT_LT(entropy, 1.2f); EXPECT_LT(entropy, (s_model == gcpp::Model::GEMMA_7B) ? 0.8f : 1.2f);
} }
} // namespace } // namespace

View File

@ -1384,12 +1384,28 @@ static HWY_NOINLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x,
namespace hn = hwy::HWY_NAMESPACE; namespace hn = hwy::HWY_NAMESPACE;
using D = hn::ScalableTag<float>; using D = hn::ScalableTag<float>;
const D d; const D d;
using V = hn::Vec<D>;
const auto vcap = hn::Set(d, cap); const V vcap = hn::Set(d, cap);
const auto vinv_cap = hn::Div(hn::Set(d, 1.0f), vcap); const V vinv_cap = hn::Div(hn::Set(d, 1.0f), vcap);
hn::Transform(d, x, size, [&vcap, &vinv_cap](D d, hn::Vec<D> v) HWY_ATTR { // If we do not subtract the max as in softmax, values > 100 (which do occur)
return hn::Mul(vcap, hn::Tanh(d, hn::Mul(v, vinv_cap))); // will all saturate to the cap, and then this function is no longer
// monotonic, which would change the results on TopK.
const V vmin = hn::Set(d, hwy::LowestValue<float>());
V vmax = vmin;
Foreach(d, x, max_pos, vmin,
[&vmax](const auto d, const auto value)
HWY_ATTR { vmax = hn::Max(vmax, value); });
vmax = hn::MaxOfLanes(d, vmax);
// We want (v-vmax) * vinv_cap. To take advantage of FMA, multiply this out to
// v * vinv_cap + (-vmax*vinv_cap).
const V add = hn::Neg(hn::Mul(vmax, vinv_cap));
hn::Transform(
d, x, size, [&vcap, &vinv_cap, &add](D d, hn::Vec<D> v) HWY_ATTR {
return hn::Mul(vcap, hn::Tanh(d, hn::MulAdd(v, vinv_cap, add)));
}); });
} }