Fix numerical issue in Softcap by subtracting max.

Also update test threshold.

PiperOrigin-RevId: 642587468
Authored by The gemma.cpp Authors on 2024-06-12 05:41:33 -07:00; committed by Copybara-Service
parent e37447cfe2
commit 2a0e6ee976
6 changed files with 82 additions and 28 deletions
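
For context, a minimal standalone demonstration of the numerical issue (not part of the commit; the helper names are illustrative, and cap = 30.0f matches the constant used in the diffs below). In float, tanh saturates to exactly 1.0f once its argument exceeds roughly 9, so logits of 300 and 400 both map to exactly the cap: softcap stops being monotonic and TopK can no longer order them. Subtracting the row maximum keeps the arguments in tanh's well-resolved range:

#include <cmath>
#include <cstdio>

// Softcap as before the commit: saturates for large logits.
static float SoftcapNaive(float cap, float x) {
  return cap * std::tanh(x / cap);
}

// Softcap as after the commit: shift by the row maximum first.
static float SoftcapShifted(float cap, float x, float xmax) {
  return cap * std::tanh((x - xmax) / cap);
}

int main() {
  const float cap = 30.0f;
  const float xmax = 400.0f;  // assume 400 is the row maximum
  // Naive: both logits collapse to exactly 30, losing their ordering.
  std::printf("naive:   %.9g %.9g\n", SoftcapNaive(cap, 300.0f),
              SoftcapNaive(cap, 400.0f));
  // Shifted: distinct, strictly ordered outputs.
  std::printf("shifted: %.9g %.9g\n", SoftcapShifted(cap, 300.0f, xmax),
              SoftcapShifted(cap, 400.0f, xmax));
  return 0;
}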

View File

@@ -22,6 +22,7 @@
#include <stddef.h>
#include <algorithm>
#include <array>
#include <cmath>
@@ -319,12 +320,24 @@ static HWY_NOINLINE void SoftcapVJP(const float* HWY_RESTRICT forward,
const auto vcap = hn::Set(d, cap);
const auto vinv_cap = hn::Div(hn::Set(d, 1.0f), vcap);
// TODO(szabadka): Investigate what to do when the argmax is not unique.
// TODO(szabadka): Use IndexOfMax from hwy when it is available.
size_t imax = std::max_element(forward, forward + size) - forward;
hn::Transform1(
d, backward, size, forward,
[&](const auto d, const auto v, const auto y) HWY_ATTR {
const auto scaled = hn::Mul(vinv_cap, y);
return hn::Mul(v, hn::Sub(one, hn::Mul(scaled, scaled)));
});
backward[imax] = 0;
auto sum = hn::Zero(d);
Foreach(d, backward, size, sum,
[&sum](const auto d, const auto value) HWY_ATTR {
sum = hn::Add(sum, value);
});
backward[imax] = -hn::ReduceSum(d, sum);
}
static HWY_NOINLINE void CrossEntropyLossGrad(
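
For reference, the gradient being implemented here (a hand derivation, not text from the commit): with m = argmax_j x_j held fixed and y_i = cap * tanh((x_i - x_m) / cap), the chain rule gives dy_j/dx_i = (1 - (y_j/cap)^2) * (delta_ij - delta_im). The VJP is therefore dx_i = dy_i * (1 - (y_i/cap)^2) for i != m, while dx_m = -sum over j != m of dy_j * (1 - (y_j/cap)^2), since the max slot's own output is the constant cap * tanh(0) = 0. That is exactly the sequence above: scale every element with Transform1, zero the max slot, then write the negated ReduceSum of the rest into it.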

View File

@@ -277,12 +277,19 @@ void LayerVJP(const Layer<T, TConfig>& weights,
template<typename T>
void SoftcapVJPT(const T* y, T* dy, size_t N) {
size_t imax = std::max_element(y, y + N) - y;
T cap = 30.0;
T inv_cap = T(1.0) / cap;
for (size_t i = 0; i < N; ++i) {
T scaled = y[i] * inv_cap;
dy[i] *= (T(1.0) - scaled * scaled);
}
dy[imax] = T(0.0);
for (size_t i = 0; i < N; ++i) {
if (i != imax) {
dy[imax] -= dy[i];
}
}
}
template<typename T>
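
A quick way to sanity-check the scalar VJP against central finite differences (a sketch, assuming the Softcap and SoftcapVJPT templates from this file are in scope; not part of the diff):

#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  const size_t N = 4;
  const std::vector<double> x = {1.0, -2.0, 35.0, 0.5};  // 35.0 is the unique max
  const std::vector<double> dy = {0.3, -0.7, 0.2, 0.9};  // upstream gradient
  // Analytic: run the forward pass, then the in-place VJP.
  std::vector<double> y = x, dx = dy;
  Softcap(y.data(), N);
  SoftcapVJPT(y.data(), dx.data(), N);
  // Numeric: d/dx_i of dot(dy, Softcap(x)) via central differences.
  const double h = 1e-6;
  for (size_t i = 0; i < N; ++i) {
    std::vector<double> xp = x, xm = x;
    xp[i] += h;
    xm[i] -= h;
    Softcap(xp.data(), N);
    Softcap(xm.data(), N);
    double num = 0.0;
    for (size_t j = 0; j < N; ++j) num += dy[j] * (xp[j] - xm[j]) / (2.0 * h);
    std::printf("i=%zu analytic=%g numeric=%g\n", i, dx[i], num);
  }
  return 0;
}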
@@ -318,8 +325,11 @@ void CrossEntropyLossBackwardPass(const Prompt& prompt,
SoftmaxVJPT(forward.probs.data(), backward.logits.data(),
kVocabSize, num_tokens);
SoftcapVJPT(forward.logits.data(), backward.logits.data(),
num_tokens * kVocabSize);
for (size_t i = 0; i < num_tokens; ++i) {
SoftcapVJPT(forward.logits.data() + i * kVocabSize,
backward.logits.data() + i * kVocabSize,
kVocabSize);
}
MatMulVJPT(weights.embedder_input_embedding.data(),
forward.final_norm_output.data(),

View File

@@ -210,7 +210,7 @@ TEST(BackPropTest, SoftcapVJP) {
Softcap(x.data(), N);
memcpy(dx.data(), dy.data(), N * sizeof(dx[0]));
SoftcapVJPT(x.data(), dx.data(), N);
TestGradient(dx, c_x, func, 1e-15, 1e-15, __LINE__);
TestGradient(dx, c_x, func, 1e-15, 1e-14, __LINE__);
}
}
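
A plausible reason for loosening the bound from 1e-15 to 1e-14 (an inference; the commit does not say): dy[imax] is now an accumulated sum over the remaining N - 1 scaled gradients, so its rounding error grows with N instead of staying a single-operation error.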

View File

@@ -95,10 +95,19 @@ void Softmax(T* x, size_t N, size_t K) {
}
template<typename T>
void Softcap(T* x, size_t N) {
auto maxreal = std::real(x[0]);
size_t imax = 0;
for (size_t i = 1; i < N; ++i) {
if (std::real(x[i]) > maxreal) {
maxreal = std::real(x[i]);
imax = i;
}
}
T cap = 30.0;
T inv_cap = T(1.0) / cap;
T xmax = x[imax];
for (size_t i = 0; i < N; ++i) {
x[i] = cap * std::tanh(x[i] * inv_cap);
x[i] = cap * std::tanh((x[i] - xmax) * inv_cap);
}
}
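
Note that unlike softmax, softcap is not shift-invariant: subtracting xmax changes the outputs themselves (the row maximum now always maps to exactly 0), not merely the conditioning. Presumably this is also why the forward pass below switches to calling Softcap once per position, so that each row of logits is shifted by its own maximum rather than by one maximum taken over all num_tokens * kVocabSize entries.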
@@ -275,7 +284,9 @@ T CrossEntropyLossForwardPass(const Prompt& prompt,
forward.final_norm_output.data(),
forward.logits.data(), kVocabSize, kModelDim, num_tokens);
Softcap(forward.logits.data(), num_tokens * kVocabSize);
for (size_t pos = 0; pos < num_tokens; ++pos) {
Softcap(forward.logits.data() + pos * kVocabSize, kVocabSize);
}
memcpy(forward.probs.data(), forward.logits.data(),
num_tokens * kVocabSize * sizeof(forward.logits[0]));

View File

@@ -23,6 +23,7 @@
#include <vector>
// Placeholder for internal header, do not modify.
#include "gemma/common.h"
#include "gemma/cross_entropy.h"
#include "gemma/ops.h"
#include "util/app.h"
@@ -38,20 +39,22 @@ char** s_argv = nullptr;
class GemmaTest : public ::testing::Test {
protected:
static void SetUpTestSuite() {
gcpp::LoaderArgs loader(s_argc, s_argv);
gcpp::AppArgs app(s_argc, s_argv);
gcpp::LoaderArgs loader(s_argc, s_argv);
if (const char* err = loader.Validate()) {
fprintf(stderr, "Insufficient LoaderArgs, skipping e2e tests.\n");
} else {
fprintf(stderr, "Loading model..\n");
s_pool = std::make_unique<hwy::ThreadPool>(app.num_threads);
s_model = AllocateGemma(loader, *s_pool);
s_gemma = AllocateGemma(loader, *s_pool);
s_kv_cache = KVCache::Create(loader.ModelType());
s_model = loader.ModelType();
}
}
static void TearDownTestSuite() {
s_pool.reset();
s_model.reset();
s_gemma.reset();
}
std::string GemmaReply(const std::string& prompt_string) {
@@ -59,7 +62,7 @@ class GemmaTest : public ::testing::Test {
gen.seed(42);
std::vector<int> prompt;
HWY_ASSERT(s_model->Tokenizer().Encode(prompt_string, &prompt));
HWY_ASSERT(s_gemma->Tokenizer().Encode(prompt_string, &prompt));
// For both pre-trained and instruction-tuned models: prepend "<bos>" token
// if needed.
prompt.insert(prompt.begin(), BOS_ID);
@@ -78,25 +81,25 @@ class GemmaTest : public ::testing::Test {
.stream_token = stream_token,
};
gcpp::TimingInfo timing_info;
s_model->Generate(runtime_config, prompt, /*start_pos=*/0, s_kv_cache,
s_gemma->Generate(runtime_config, prompt, /*start_pos=*/0, s_kv_cache,
timing_info, /*layers_output=*/nullptr);
std::string response_text;
HWY_ASSERT(s_model->Tokenizer().Decode(response, &response_text));
HWY_ASSERT(s_gemma->Tokenizer().Decode(response, &response_text));
return response_text;
}
float GemmaCrossEntropy(const std::string& prompt_string) {
std::vector<int> prompt;
HWY_ASSERT(s_model->Tokenizer().Encode(prompt_string, &prompt));
HWY_ASSERT(s_gemma->Tokenizer().Encode(prompt_string, &prompt));
prompt.insert(prompt.begin(), BOS_ID);
return ComputeCrossEntropy(*s_model, /*max_tokens=*/3072, prompt,
return ComputeCrossEntropy(*s_gemma, /*max_tokens=*/3072, prompt,
s_kv_cache,
/*verbosity=*/0) /
prompt_string.size();
}
void TestQuestions(const char* kQA[][2], size_t num_questions) {
if (!s_model) return;
if (!s_gemma) return;
for (size_t i = 0; i < num_questions; ++i) {
fprintf(stderr, "Question %zu\n\n", i + 1);
std::string response = GemmaReply(kQA[i][0]);
@@ -106,13 +109,15 @@ class GemmaTest : public ::testing::Test {
}
static std::unique_ptr<hwy::ThreadPool> s_pool;
static std::unique_ptr<gcpp::Gemma> s_model;
static std::unique_ptr<gcpp::Gemma> s_gemma;
static gcpp::KVCache s_kv_cache;
static gcpp::Model s_model;
};
/*static*/ std::unique_ptr<hwy::ThreadPool> GemmaTest::s_pool;
/*static*/ std::unique_ptr<gcpp::Gemma> GemmaTest::s_model;
/*static*/ std::unique_ptr<gcpp::Gemma> GemmaTest::s_gemma;
/*static*/ gcpp::KVCache GemmaTest::s_kv_cache;
/*static*/ gcpp::Model GemmaTest::s_model;
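
(In this refactoring, the Gemma instance moves from s_model to s_gemma, and s_model is repurposed for the gcpp::Model enum returned by loader.ModelType(); the cross-entropy tests below use it to select per-model thresholds.)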
TEST_F(GemmaTest, Geography) {
static const char* kQA[][2] = {
@@ -176,27 +181,26 @@ static const char kGettysburg[] = {
"people, for the people, shall not perish from the earth.\n"};
TEST_F(GemmaTest, CrossEntropySmall) {
if (!s_model) return;
if (!s_gemma) return;
static const char kSmall[] =
"The capital of Hungary is Budapest which is located in Europe.";
float entropy = GemmaCrossEntropy(kSmall);
fprintf(stderr, "per-byte entropy: %f\n", entropy);
// Note that entropy is 3x higher for the 7b-it model.
EXPECT_LT(entropy, 1.7f);
EXPECT_LT(entropy, (s_model == gcpp::Model::GEMMA_7B) ? 2.1f : 2.0f);
}
TEST_F(GemmaTest, CrossEntropyJingleBells) {
if (!s_model) return;
if (!s_gemma) return;
float entropy = GemmaCrossEntropy(kJingleBells);
fprintf(stderr, "per-byte entropy: %f\n", entropy);
EXPECT_LT(entropy, 1.7f);
EXPECT_LT(entropy, (s_model == gcpp::Model::GEMMA_7B) ? 0.9f : 1.8f);
}
TEST_F(GemmaTest, CrossEntropyGettysburg) {
if (!s_model) return;
if (!s_gemma) return;
float entropy = GemmaCrossEntropy(kGettysburg);
fprintf(stderr, "per-byte entropy: %f\n", entropy);
EXPECT_LT(entropy, 1.2f);
EXPECT_LT(entropy, (s_model == gcpp::Model::GEMMA_7B) ? 0.8f : 1.2f);
}
} // namespace

View File

@@ -1384,13 +1384,29 @@ static HWY_NOINLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x,
namespace hn = hwy::HWY_NAMESPACE;
using D = hn::ScalableTag<float>;
const D d;
using V = hn::Vec<D>;
const auto vcap = hn::Set(d, cap);
const auto vinv_cap = hn::Div(hn::Set(d, 1.0f), vcap);
const V vcap = hn::Set(d, cap);
const V vinv_cap = hn::Div(hn::Set(d, 1.0f), vcap);
hn::Transform(d, x, size, [&vcap, &vinv_cap](D d, hn::Vec<D> v) HWY_ATTR {
return hn::Mul(vcap, hn::Tanh(d, hn::Mul(v, vinv_cap)));
});
// If we do not subtract the max as in softmax, values > 100 (which do occur)
// will all saturate to the cap, and then this function is no longer
// monotonic, which would change the results of TopK.
const V vmin = hn::Set(d, hwy::LowestValue<float>());
V vmax = vmin;
Foreach(d, x, max_pos, vmin,
[&vmax](const auto d, const auto value)
HWY_ATTR { vmax = hn::Max(vmax, value); });
vmax = hn::MaxOfLanes(d, vmax);
// We want (v-vmax) * vinv_cap. To take advantage of FMA, multiply this out to
// v * vinv_cap + (-vmax*vinv_cap).
const V add = hn::Neg(hn::Mul(vmax, vinv_cap));
hn::Transform(
d, x, size, [&vcap, &vinv_cap, &add](D d, hn::Vec<D> v) HWY_ATTR {
return hn::Mul(vcap, hn::Tanh(d, hn::MulAdd(v, vinv_cap, add)));
});
}
static HWY_INLINE HWY_MAYBE_UNUSED void LogitsSoftCap(const float cap,
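
The MulAdd form relies on the identity (v - vmax) * inv_cap = v * inv_cap + (-vmax * inv_cap), hoisting a constant addend out of the loop so each element costs a single fused multiply-add before the tanh. A scalar restatement of the same computation (illustration only; xmax stands for the horizontal maximum reduced above):

#include <cmath>
#include <cstddef>

void LogitsSoftCapScalar(float cap, float xmax, float* x, size_t size) {
  const float inv_cap = 1.0f / cap;
  const float add = -(xmax * inv_cap);  // hoisted: (v - xmax)/cap = v/cap + add
  for (size_t i = 0; i < size; ++i) {
    x[i] = cap * std::tanh(std::fma(x[i], inv_cap, add));
  }
}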