mirror of https://github.com/google/gemma.cpp.git
Fix numerical issue in Softcap by subtracting max.
Also update test threshold. PiperOrigin-RevId: 642587468
This commit is contained in:
parent
e37447cfe2
commit
2a0e6ee976
|
|
@ -22,6 +22,7 @@
|
|||
|
||||
#include <stddef.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
#include <cmath>
|
||||
|
||||
|
|
@ -319,12 +320,24 @@ static HWY_NOINLINE void SoftcapVJP(const float* HWY_RESTRICT forward,
|
|||
const auto vcap = hn::Set(d, cap);
|
||||
const auto vinv_cap = hn::Div(hn::Set(d, 1.0f), vcap);
|
||||
|
||||
// TODO(szabadka): Investigate what to do when the argmax is not unique.
|
||||
// TODO(szabadka): Use IndexOfMax from hwy when it is available.
|
||||
size_t imax = std::max_element(forward, forward + size) - forward;
|
||||
|
||||
hn::Transform1(
|
||||
d, backward, size, forward,
|
||||
[&](const auto d, const auto v, const auto y) HWY_ATTR {
|
||||
const auto scaled = hn::Mul(vinv_cap, y);
|
||||
return hn::Mul(v, hn::Sub(one, hn::Mul(scaled, scaled)));
|
||||
});
|
||||
|
||||
backward[imax] = 0;
|
||||
auto sum = hn::Zero(d);
|
||||
Foreach(d, backward, size, sum,
|
||||
[&sum](const auto d, const auto value) HWY_ATTR {
|
||||
sum = hn::Add(sum, value);
|
||||
});
|
||||
backward[imax] = -hn::ReduceSum(d, sum);
|
||||
}
|
||||
|
||||
static HWY_NOINLINE void CrossEntropyLossGrad(
|
||||
|
|
|
|||
|
|
@ -277,12 +277,19 @@ void LayerVJP(const Layer<T, TConfig>& weights,
|
|||
|
||||
template<typename T>
|
||||
void SoftcapVJPT(const T* y, T* dy, size_t N) {
|
||||
size_t imax = std::max_element(y, y + N) - y;
|
||||
T cap = 30.0;
|
||||
T inv_cap = T(1.0) / cap;
|
||||
for (size_t i = 0; i < N; ++i) {
|
||||
T scaled = y[i] * inv_cap;
|
||||
dy[i] *= (T(1.0) - scaled * scaled);
|
||||
}
|
||||
dy[imax] = T(0.0);
|
||||
for (size_t i = 0; i < N; ++i) {
|
||||
if (i != imax) {
|
||||
dy[imax] -= dy[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
|
|
@ -318,8 +325,11 @@ void CrossEntropyLossBackwardPass(const Prompt& prompt,
|
|||
SoftmaxVJPT(forward.probs.data(), backward.logits.data(),
|
||||
kVocabSize, num_tokens);
|
||||
|
||||
SoftcapVJPT(forward.logits.data(), backward.logits.data(),
|
||||
num_tokens * kVocabSize);
|
||||
for (size_t i = 0; i < num_tokens; ++i) {
|
||||
SoftcapVJPT(forward.logits.data() + i * kVocabSize,
|
||||
backward.logits.data() + i * kVocabSize,
|
||||
kVocabSize);
|
||||
}
|
||||
|
||||
MatMulVJPT(weights.embedder_input_embedding.data(),
|
||||
forward.final_norm_output.data(),
|
||||
|
|
|
|||
|
|
@ -210,7 +210,7 @@ TEST(BackPropTest, SoftcapVJP) {
|
|||
Softcap(x.data(), N);
|
||||
memcpy(dx.data(), dy.data(), N * sizeof(dx[0]));
|
||||
SoftcapVJPT(x.data(), dx.data(), N);
|
||||
TestGradient(dx, c_x, func, 1e-15, 1e-15, __LINE__);
|
||||
TestGradient(dx, c_x, func, 1e-15, 1e-14, __LINE__);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -95,10 +95,19 @@ void Softmax(T* x, size_t N, size_t K) {
|
|||
}
|
||||
template<typename T>
|
||||
void Softcap(T* x, size_t N) {
|
||||
auto maxreal = std::real(x[0]);
|
||||
size_t imax = 0;
|
||||
for (size_t i = 1; i < N; ++i) {
|
||||
if (std::real(x[i]) > maxreal) {
|
||||
maxreal = std::real(x[i]);
|
||||
imax = i;
|
||||
}
|
||||
}
|
||||
T cap = 30.0;
|
||||
T inv_cap = T(1.0) / cap;
|
||||
T xmax = x[imax];
|
||||
for (size_t i = 0; i < N; ++i) {
|
||||
x[i] = cap * std::tanh(x[i] * inv_cap);
|
||||
x[i] = cap * std::tanh((x[i] - xmax) * inv_cap);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -275,7 +284,9 @@ T CrossEntropyLossForwardPass(const Prompt& prompt,
|
|||
forward.final_norm_output.data(),
|
||||
forward.logits.data(), kVocabSize, kModelDim, num_tokens);
|
||||
|
||||
Softcap(forward.logits.data(), num_tokens * kVocabSize);
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
Softcap(forward.logits.data() + pos * kVocabSize, kVocabSize);
|
||||
}
|
||||
|
||||
memcpy(forward.probs.data(), forward.logits.data(),
|
||||
num_tokens * kVocabSize * sizeof(forward.logits[0]));
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@
|
|||
#include <vector>
|
||||
|
||||
// Placeholder for internal header, do not modify.
|
||||
#include "gemma/common.h"
|
||||
#include "gemma/cross_entropy.h"
|
||||
#include "gemma/ops.h"
|
||||
#include "util/app.h"
|
||||
|
|
@ -38,20 +39,22 @@ char** s_argv = nullptr;
|
|||
class GemmaTest : public ::testing::Test {
|
||||
protected:
|
||||
static void SetUpTestSuite() {
|
||||
gcpp::LoaderArgs loader(s_argc, s_argv);
|
||||
gcpp::AppArgs app(s_argc, s_argv);
|
||||
gcpp::LoaderArgs loader(s_argc, s_argv);
|
||||
if (const char* err = loader.Validate()) {
|
||||
fprintf(stderr, "Insufficient LoaderArgs, skipping e2e tests.\n");
|
||||
} else {
|
||||
fprintf(stderr, "Loading model..\n");
|
||||
s_pool = std::make_unique<hwy::ThreadPool>(app.num_threads);
|
||||
s_model = AllocateGemma(loader, *s_pool);
|
||||
s_gemma = AllocateGemma(loader, *s_pool);
|
||||
s_kv_cache = KVCache::Create(loader.ModelType());
|
||||
s_model = loader.ModelType();
|
||||
}
|
||||
}
|
||||
|
||||
static void TearDownTestSuite() {
|
||||
s_pool.reset();
|
||||
s_model.reset();
|
||||
s_gemma.reset();
|
||||
}
|
||||
|
||||
std::string GemmaReply(const std::string& prompt_string) {
|
||||
|
|
@ -59,7 +62,7 @@ class GemmaTest : public ::testing::Test {
|
|||
gen.seed(42);
|
||||
|
||||
std::vector<int> prompt;
|
||||
HWY_ASSERT(s_model->Tokenizer().Encode(prompt_string, &prompt));
|
||||
HWY_ASSERT(s_gemma->Tokenizer().Encode(prompt_string, &prompt));
|
||||
// For both pre-trained and instruction-tuned models: prepend "<bos>" token
|
||||
// if needed.
|
||||
prompt.insert(prompt.begin(), BOS_ID);
|
||||
|
|
@ -78,25 +81,25 @@ class GemmaTest : public ::testing::Test {
|
|||
.stream_token = stream_token,
|
||||
};
|
||||
gcpp::TimingInfo timing_info;
|
||||
s_model->Generate(runtime_config, prompt, /*start_pos=*/0, s_kv_cache,
|
||||
s_gemma->Generate(runtime_config, prompt, /*start_pos=*/0, s_kv_cache,
|
||||
timing_info, /*layers_output=*/nullptr);
|
||||
std::string response_text;
|
||||
HWY_ASSERT(s_model->Tokenizer().Decode(response, &response_text));
|
||||
HWY_ASSERT(s_gemma->Tokenizer().Decode(response, &response_text));
|
||||
return response_text;
|
||||
}
|
||||
|
||||
float GemmaCrossEntropy(const std::string& prompt_string) {
|
||||
std::vector<int> prompt;
|
||||
HWY_ASSERT(s_model->Tokenizer().Encode(prompt_string, &prompt));
|
||||
HWY_ASSERT(s_gemma->Tokenizer().Encode(prompt_string, &prompt));
|
||||
prompt.insert(prompt.begin(), BOS_ID);
|
||||
return ComputeCrossEntropy(*s_model, /*max_tokens=*/3072, prompt,
|
||||
return ComputeCrossEntropy(*s_gemma, /*max_tokens=*/3072, prompt,
|
||||
s_kv_cache,
|
||||
/*verbosity=*/0) /
|
||||
prompt_string.size();
|
||||
}
|
||||
|
||||
void TestQuestions(const char* kQA[][2], size_t num_questions) {
|
||||
if (!s_model) return;
|
||||
if (!s_gemma) return;
|
||||
for (size_t i = 0; i < num_questions; ++i) {
|
||||
fprintf(stderr, "Question %zu\n\n", i + 1);
|
||||
std::string response = GemmaReply(kQA[i][0]);
|
||||
|
|
@ -106,13 +109,15 @@ class GemmaTest : public ::testing::Test {
|
|||
}
|
||||
|
||||
static std::unique_ptr<hwy::ThreadPool> s_pool;
|
||||
static std::unique_ptr<gcpp::Gemma> s_model;
|
||||
static std::unique_ptr<gcpp::Gemma> s_gemma;
|
||||
static gcpp::KVCache s_kv_cache;
|
||||
static gcpp::Model s_model;
|
||||
};
|
||||
|
||||
/*static*/ std::unique_ptr<hwy::ThreadPool> GemmaTest::s_pool;
|
||||
/*static*/ std::unique_ptr<gcpp::Gemma> GemmaTest::s_model;
|
||||
/*static*/ std::unique_ptr<gcpp::Gemma> GemmaTest::s_gemma;
|
||||
/*static*/ gcpp::KVCache GemmaTest::s_kv_cache;
|
||||
/*static*/ gcpp::Model GemmaTest::s_model;
|
||||
|
||||
TEST_F(GemmaTest, Geography) {
|
||||
static const char* kQA[][2] = {
|
||||
|
|
@ -176,27 +181,26 @@ static const char kGettysburg[] = {
|
|||
"people, for the people, shall not perish from the earth.\n"};
|
||||
|
||||
TEST_F(GemmaTest, CrossEntropySmall) {
|
||||
if (!s_model) return;
|
||||
if (!s_gemma) return;
|
||||
static const char kSmall[] =
|
||||
"The capital of Hungary is Budapest which is located in Europe.";
|
||||
float entropy = GemmaCrossEntropy(kSmall);
|
||||
fprintf(stderr, "per-byte entropy: %f\n", entropy);
|
||||
// Note that entropy is 3x higher for the 7b-it model.
|
||||
EXPECT_LT(entropy, 1.7f);
|
||||
EXPECT_LT(entropy, (s_model == gcpp::Model::GEMMA_7B) ? 2.1f : 2.0f);
|
||||
}
|
||||
|
||||
TEST_F(GemmaTest, CrossEntropyJingleBells) {
|
||||
if (!s_model) return;
|
||||
if (!s_gemma) return;
|
||||
float entropy = GemmaCrossEntropy(kJingleBells);
|
||||
fprintf(stderr, "per-byte entropy: %f\n", entropy);
|
||||
EXPECT_LT(entropy, 1.7f);
|
||||
EXPECT_LT(entropy, (s_model == gcpp::Model::GEMMA_7B) ? 0.9f : 1.8f);
|
||||
}
|
||||
|
||||
TEST_F(GemmaTest, CrossEntropyGettysburg) {
|
||||
if (!s_model) return;
|
||||
if (!s_gemma) return;
|
||||
float entropy = GemmaCrossEntropy(kGettysburg);
|
||||
fprintf(stderr, "per-byte entropy: %f\n", entropy);
|
||||
EXPECT_LT(entropy, 1.2f);
|
||||
EXPECT_LT(entropy, (s_model == gcpp::Model::GEMMA_7B) ? 0.8f : 1.2f);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
|
|
|||
24
gemma/ops.h
24
gemma/ops.h
|
|
@ -1384,12 +1384,28 @@ static HWY_NOINLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x,
|
|||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
using D = hn::ScalableTag<float>;
|
||||
const D d;
|
||||
using V = hn::Vec<D>;
|
||||
|
||||
const auto vcap = hn::Set(d, cap);
|
||||
const auto vinv_cap = hn::Div(hn::Set(d, 1.0f), vcap);
|
||||
const V vcap = hn::Set(d, cap);
|
||||
const V vinv_cap = hn::Div(hn::Set(d, 1.0f), vcap);
|
||||
|
||||
hn::Transform(d, x, size, [&vcap, &vinv_cap](D d, hn::Vec<D> v) HWY_ATTR {
|
||||
return hn::Mul(vcap, hn::Tanh(d, hn::Mul(v, vinv_cap)));
|
||||
// If we do not subtract the max as in softmax, values > 100 (which do occur)
|
||||
// will all saturate to the cap, and then this function is no longer
|
||||
// monotonic, which would change the results on TopK.
|
||||
const V vmin = hn::Set(d, hwy::LowestValue<float>());
|
||||
V vmax = vmin;
|
||||
Foreach(d, x, max_pos, vmin,
|
||||
[&vmax](const auto d, const auto value)
|
||||
HWY_ATTR { vmax = hn::Max(vmax, value); });
|
||||
vmax = hn::MaxOfLanes(d, vmax);
|
||||
|
||||
// We want (v-vmax) * vinv_cap. To take advantage of FMA, multiply this out to
|
||||
// v * vinv_cap + (-vmax*vinv_cap).
|
||||
const V add = hn::Neg(hn::Mul(vmax, vinv_cap));
|
||||
|
||||
hn::Transform(
|
||||
d, x, size, [&vcap, &vinv_cap, &add](D d, hn::Vec<D> v) HWY_ATTR {
|
||||
return hn::Mul(vcap, hn::Tanh(d, hn::MulAdd(v, vinv_cap, add)));
|
||||
});
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue