mirror of https://github.com/google/gemma.cpp.git
Fix numerical issue in Softcap by subtracting max.
Also update test threshold. PiperOrigin-RevId: 642587468
This commit is contained in:
parent
e37447cfe2
commit
2a0e6ee976
|
|
@ -22,6 +22,7 @@
|
||||||
|
|
||||||
#include <stddef.h>
|
#include <stddef.h>
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
#include <array>
|
#include <array>
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
|
|
||||||
|
|
@ -319,12 +320,24 @@ static HWY_NOINLINE void SoftcapVJP(const float* HWY_RESTRICT forward,
|
||||||
const auto vcap = hn::Set(d, cap);
|
const auto vcap = hn::Set(d, cap);
|
||||||
const auto vinv_cap = hn::Div(hn::Set(d, 1.0f), vcap);
|
const auto vinv_cap = hn::Div(hn::Set(d, 1.0f), vcap);
|
||||||
|
|
||||||
|
// TODO(szabadka): Investigate what to do when the argmax is not unique.
|
||||||
|
// TODO(szabadka): Use IndexOfMax from hwy when it is available.
|
||||||
|
size_t imax = std::max_element(forward, forward + size) - forward;
|
||||||
|
|
||||||
hn::Transform1(
|
hn::Transform1(
|
||||||
d, backward, size, forward,
|
d, backward, size, forward,
|
||||||
[&](const auto d, const auto v, const auto y) HWY_ATTR {
|
[&](const auto d, const auto v, const auto y) HWY_ATTR {
|
||||||
const auto scaled = hn::Mul(vinv_cap, y);
|
const auto scaled = hn::Mul(vinv_cap, y);
|
||||||
return hn::Mul(v, hn::Sub(one, hn::Mul(scaled, scaled)));
|
return hn::Mul(v, hn::Sub(one, hn::Mul(scaled, scaled)));
|
||||||
});
|
});
|
||||||
|
|
||||||
|
backward[imax] = 0;
|
||||||
|
auto sum = hn::Zero(d);
|
||||||
|
Foreach(d, backward, size, sum,
|
||||||
|
[&sum](const auto d, const auto value) HWY_ATTR {
|
||||||
|
sum = hn::Add(sum, value);
|
||||||
|
});
|
||||||
|
backward[imax] = -hn::ReduceSum(d, sum);
|
||||||
}
|
}
|
||||||
|
|
||||||
static HWY_NOINLINE void CrossEntropyLossGrad(
|
static HWY_NOINLINE void CrossEntropyLossGrad(
|
||||||
|
|
|
||||||
|
|
@ -277,12 +277,19 @@ void LayerVJP(const Layer<T, TConfig>& weights,
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
void SoftcapVJPT(const T* y, T* dy, size_t N) {
|
void SoftcapVJPT(const T* y, T* dy, size_t N) {
|
||||||
|
size_t imax = std::max_element(y, y + N) - y;
|
||||||
T cap = 30.0;
|
T cap = 30.0;
|
||||||
T inv_cap = T(1.0) / cap;
|
T inv_cap = T(1.0) / cap;
|
||||||
for (size_t i = 0; i < N; ++i) {
|
for (size_t i = 0; i < N; ++i) {
|
||||||
T scaled = y[i] * inv_cap;
|
T scaled = y[i] * inv_cap;
|
||||||
dy[i] *= (T(1.0) - scaled * scaled);
|
dy[i] *= (T(1.0) - scaled * scaled);
|
||||||
}
|
}
|
||||||
|
dy[imax] = T(0.0);
|
||||||
|
for (size_t i = 0; i < N; ++i) {
|
||||||
|
if (i != imax) {
|
||||||
|
dy[imax] -= dy[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
|
|
@ -318,8 +325,11 @@ void CrossEntropyLossBackwardPass(const Prompt& prompt,
|
||||||
SoftmaxVJPT(forward.probs.data(), backward.logits.data(),
|
SoftmaxVJPT(forward.probs.data(), backward.logits.data(),
|
||||||
kVocabSize, num_tokens);
|
kVocabSize, num_tokens);
|
||||||
|
|
||||||
SoftcapVJPT(forward.logits.data(), backward.logits.data(),
|
for (size_t i = 0; i < num_tokens; ++i) {
|
||||||
num_tokens * kVocabSize);
|
SoftcapVJPT(forward.logits.data() + i * kVocabSize,
|
||||||
|
backward.logits.data() + i * kVocabSize,
|
||||||
|
kVocabSize);
|
||||||
|
}
|
||||||
|
|
||||||
MatMulVJPT(weights.embedder_input_embedding.data(),
|
MatMulVJPT(weights.embedder_input_embedding.data(),
|
||||||
forward.final_norm_output.data(),
|
forward.final_norm_output.data(),
|
||||||
|
|
|
||||||
|
|
@ -210,7 +210,7 @@ TEST(BackPropTest, SoftcapVJP) {
|
||||||
Softcap(x.data(), N);
|
Softcap(x.data(), N);
|
||||||
memcpy(dx.data(), dy.data(), N * sizeof(dx[0]));
|
memcpy(dx.data(), dy.data(), N * sizeof(dx[0]));
|
||||||
SoftcapVJPT(x.data(), dx.data(), N);
|
SoftcapVJPT(x.data(), dx.data(), N);
|
||||||
TestGradient(dx, c_x, func, 1e-15, 1e-15, __LINE__);
|
TestGradient(dx, c_x, func, 1e-15, 1e-14, __LINE__);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -95,10 +95,19 @@ void Softmax(T* x, size_t N, size_t K) {
|
||||||
}
|
}
|
||||||
template<typename T>
|
template<typename T>
|
||||||
void Softcap(T* x, size_t N) {
|
void Softcap(T* x, size_t N) {
|
||||||
|
auto maxreal = std::real(x[0]);
|
||||||
|
size_t imax = 0;
|
||||||
|
for (size_t i = 1; i < N; ++i) {
|
||||||
|
if (std::real(x[i]) > maxreal) {
|
||||||
|
maxreal = std::real(x[i]);
|
||||||
|
imax = i;
|
||||||
|
}
|
||||||
|
}
|
||||||
T cap = 30.0;
|
T cap = 30.0;
|
||||||
T inv_cap = T(1.0) / cap;
|
T inv_cap = T(1.0) / cap;
|
||||||
|
T xmax = x[imax];
|
||||||
for (size_t i = 0; i < N; ++i) {
|
for (size_t i = 0; i < N; ++i) {
|
||||||
x[i] = cap * std::tanh(x[i] * inv_cap);
|
x[i] = cap * std::tanh((x[i] - xmax) * inv_cap);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -275,7 +284,9 @@ T CrossEntropyLossForwardPass(const Prompt& prompt,
|
||||||
forward.final_norm_output.data(),
|
forward.final_norm_output.data(),
|
||||||
forward.logits.data(), kVocabSize, kModelDim, num_tokens);
|
forward.logits.data(), kVocabSize, kModelDim, num_tokens);
|
||||||
|
|
||||||
Softcap(forward.logits.data(), num_tokens * kVocabSize);
|
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||||
|
Softcap(forward.logits.data() + pos * kVocabSize, kVocabSize);
|
||||||
|
}
|
||||||
|
|
||||||
memcpy(forward.probs.data(), forward.logits.data(),
|
memcpy(forward.probs.data(), forward.logits.data(),
|
||||||
num_tokens * kVocabSize * sizeof(forward.logits[0]));
|
num_tokens * kVocabSize * sizeof(forward.logits[0]));
|
||||||
|
|
|
||||||
|
|
@ -23,6 +23,7 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
// Placeholder for internal header, do not modify.
|
// Placeholder for internal header, do not modify.
|
||||||
|
#include "gemma/common.h"
|
||||||
#include "gemma/cross_entropy.h"
|
#include "gemma/cross_entropy.h"
|
||||||
#include "gemma/ops.h"
|
#include "gemma/ops.h"
|
||||||
#include "util/app.h"
|
#include "util/app.h"
|
||||||
|
|
@ -38,20 +39,22 @@ char** s_argv = nullptr;
|
||||||
class GemmaTest : public ::testing::Test {
|
class GemmaTest : public ::testing::Test {
|
||||||
protected:
|
protected:
|
||||||
static void SetUpTestSuite() {
|
static void SetUpTestSuite() {
|
||||||
gcpp::LoaderArgs loader(s_argc, s_argv);
|
|
||||||
gcpp::AppArgs app(s_argc, s_argv);
|
gcpp::AppArgs app(s_argc, s_argv);
|
||||||
|
gcpp::LoaderArgs loader(s_argc, s_argv);
|
||||||
if (const char* err = loader.Validate()) {
|
if (const char* err = loader.Validate()) {
|
||||||
fprintf(stderr, "Insufficient LoaderArgs, skipping e2e tests.\n");
|
fprintf(stderr, "Insufficient LoaderArgs, skipping e2e tests.\n");
|
||||||
} else {
|
} else {
|
||||||
|
fprintf(stderr, "Loading model..\n");
|
||||||
s_pool = std::make_unique<hwy::ThreadPool>(app.num_threads);
|
s_pool = std::make_unique<hwy::ThreadPool>(app.num_threads);
|
||||||
s_model = AllocateGemma(loader, *s_pool);
|
s_gemma = AllocateGemma(loader, *s_pool);
|
||||||
s_kv_cache = KVCache::Create(loader.ModelType());
|
s_kv_cache = KVCache::Create(loader.ModelType());
|
||||||
|
s_model = loader.ModelType();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static void TearDownTestSuite() {
|
static void TearDownTestSuite() {
|
||||||
s_pool.reset();
|
s_pool.reset();
|
||||||
s_model.reset();
|
s_gemma.reset();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string GemmaReply(const std::string& prompt_string) {
|
std::string GemmaReply(const std::string& prompt_string) {
|
||||||
|
|
@ -59,7 +62,7 @@ class GemmaTest : public ::testing::Test {
|
||||||
gen.seed(42);
|
gen.seed(42);
|
||||||
|
|
||||||
std::vector<int> prompt;
|
std::vector<int> prompt;
|
||||||
HWY_ASSERT(s_model->Tokenizer().Encode(prompt_string, &prompt));
|
HWY_ASSERT(s_gemma->Tokenizer().Encode(prompt_string, &prompt));
|
||||||
// For both pre-trained and instruction-tuned models: prepend "<bos>" token
|
// For both pre-trained and instruction-tuned models: prepend "<bos>" token
|
||||||
// if needed.
|
// if needed.
|
||||||
prompt.insert(prompt.begin(), BOS_ID);
|
prompt.insert(prompt.begin(), BOS_ID);
|
||||||
|
|
@ -78,25 +81,25 @@ class GemmaTest : public ::testing::Test {
|
||||||
.stream_token = stream_token,
|
.stream_token = stream_token,
|
||||||
};
|
};
|
||||||
gcpp::TimingInfo timing_info;
|
gcpp::TimingInfo timing_info;
|
||||||
s_model->Generate(runtime_config, prompt, /*start_pos=*/0, s_kv_cache,
|
s_gemma->Generate(runtime_config, prompt, /*start_pos=*/0, s_kv_cache,
|
||||||
timing_info, /*layers_output=*/nullptr);
|
timing_info, /*layers_output=*/nullptr);
|
||||||
std::string response_text;
|
std::string response_text;
|
||||||
HWY_ASSERT(s_model->Tokenizer().Decode(response, &response_text));
|
HWY_ASSERT(s_gemma->Tokenizer().Decode(response, &response_text));
|
||||||
return response_text;
|
return response_text;
|
||||||
}
|
}
|
||||||
|
|
||||||
float GemmaCrossEntropy(const std::string& prompt_string) {
|
float GemmaCrossEntropy(const std::string& prompt_string) {
|
||||||
std::vector<int> prompt;
|
std::vector<int> prompt;
|
||||||
HWY_ASSERT(s_model->Tokenizer().Encode(prompt_string, &prompt));
|
HWY_ASSERT(s_gemma->Tokenizer().Encode(prompt_string, &prompt));
|
||||||
prompt.insert(prompt.begin(), BOS_ID);
|
prompt.insert(prompt.begin(), BOS_ID);
|
||||||
return ComputeCrossEntropy(*s_model, /*max_tokens=*/3072, prompt,
|
return ComputeCrossEntropy(*s_gemma, /*max_tokens=*/3072, prompt,
|
||||||
s_kv_cache,
|
s_kv_cache,
|
||||||
/*verbosity=*/0) /
|
/*verbosity=*/0) /
|
||||||
prompt_string.size();
|
prompt_string.size();
|
||||||
}
|
}
|
||||||
|
|
||||||
void TestQuestions(const char* kQA[][2], size_t num_questions) {
|
void TestQuestions(const char* kQA[][2], size_t num_questions) {
|
||||||
if (!s_model) return;
|
if (!s_gemma) return;
|
||||||
for (size_t i = 0; i < num_questions; ++i) {
|
for (size_t i = 0; i < num_questions; ++i) {
|
||||||
fprintf(stderr, "Question %zu\n\n", i + 1);
|
fprintf(stderr, "Question %zu\n\n", i + 1);
|
||||||
std::string response = GemmaReply(kQA[i][0]);
|
std::string response = GemmaReply(kQA[i][0]);
|
||||||
|
|
@ -106,13 +109,15 @@ class GemmaTest : public ::testing::Test {
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::unique_ptr<hwy::ThreadPool> s_pool;
|
static std::unique_ptr<hwy::ThreadPool> s_pool;
|
||||||
static std::unique_ptr<gcpp::Gemma> s_model;
|
static std::unique_ptr<gcpp::Gemma> s_gemma;
|
||||||
static gcpp::KVCache s_kv_cache;
|
static gcpp::KVCache s_kv_cache;
|
||||||
|
static gcpp::Model s_model;
|
||||||
};
|
};
|
||||||
|
|
||||||
/*static*/ std::unique_ptr<hwy::ThreadPool> GemmaTest::s_pool;
|
/*static*/ std::unique_ptr<hwy::ThreadPool> GemmaTest::s_pool;
|
||||||
/*static*/ std::unique_ptr<gcpp::Gemma> GemmaTest::s_model;
|
/*static*/ std::unique_ptr<gcpp::Gemma> GemmaTest::s_gemma;
|
||||||
/*static*/ gcpp::KVCache GemmaTest::s_kv_cache;
|
/*static*/ gcpp::KVCache GemmaTest::s_kv_cache;
|
||||||
|
/*static*/ gcpp::Model GemmaTest::s_model;
|
||||||
|
|
||||||
TEST_F(GemmaTest, Geography) {
|
TEST_F(GemmaTest, Geography) {
|
||||||
static const char* kQA[][2] = {
|
static const char* kQA[][2] = {
|
||||||
|
|
@ -176,27 +181,26 @@ static const char kGettysburg[] = {
|
||||||
"people, for the people, shall not perish from the earth.\n"};
|
"people, for the people, shall not perish from the earth.\n"};
|
||||||
|
|
||||||
TEST_F(GemmaTest, CrossEntropySmall) {
|
TEST_F(GemmaTest, CrossEntropySmall) {
|
||||||
if (!s_model) return;
|
if (!s_gemma) return;
|
||||||
static const char kSmall[] =
|
static const char kSmall[] =
|
||||||
"The capital of Hungary is Budapest which is located in Europe.";
|
"The capital of Hungary is Budapest which is located in Europe.";
|
||||||
float entropy = GemmaCrossEntropy(kSmall);
|
float entropy = GemmaCrossEntropy(kSmall);
|
||||||
fprintf(stderr, "per-byte entropy: %f\n", entropy);
|
fprintf(stderr, "per-byte entropy: %f\n", entropy);
|
||||||
// Note that entropy is 3x higher for the 7b-it model.
|
EXPECT_LT(entropy, (s_model == gcpp::Model::GEMMA_7B) ? 2.1f : 2.0f);
|
||||||
EXPECT_LT(entropy, 1.7f);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(GemmaTest, CrossEntropyJingleBells) {
|
TEST_F(GemmaTest, CrossEntropyJingleBells) {
|
||||||
if (!s_model) return;
|
if (!s_gemma) return;
|
||||||
float entropy = GemmaCrossEntropy(kJingleBells);
|
float entropy = GemmaCrossEntropy(kJingleBells);
|
||||||
fprintf(stderr, "per-byte entropy: %f\n", entropy);
|
fprintf(stderr, "per-byte entropy: %f\n", entropy);
|
||||||
EXPECT_LT(entropy, 1.7f);
|
EXPECT_LT(entropy, (s_model == gcpp::Model::GEMMA_7B) ? 0.9f : 1.8f);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(GemmaTest, CrossEntropyGettysburg) {
|
TEST_F(GemmaTest, CrossEntropyGettysburg) {
|
||||||
if (!s_model) return;
|
if (!s_gemma) return;
|
||||||
float entropy = GemmaCrossEntropy(kGettysburg);
|
float entropy = GemmaCrossEntropy(kGettysburg);
|
||||||
fprintf(stderr, "per-byte entropy: %f\n", entropy);
|
fprintf(stderr, "per-byte entropy: %f\n", entropy);
|
||||||
EXPECT_LT(entropy, 1.2f);
|
EXPECT_LT(entropy, (s_model == gcpp::Model::GEMMA_7B) ? 0.8f : 1.2f);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
|
||||||
24
gemma/ops.h
24
gemma/ops.h
|
|
@ -1384,12 +1384,28 @@ static HWY_NOINLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x,
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
using D = hn::ScalableTag<float>;
|
using D = hn::ScalableTag<float>;
|
||||||
const D d;
|
const D d;
|
||||||
|
using V = hn::Vec<D>;
|
||||||
|
|
||||||
const auto vcap = hn::Set(d, cap);
|
const V vcap = hn::Set(d, cap);
|
||||||
const auto vinv_cap = hn::Div(hn::Set(d, 1.0f), vcap);
|
const V vinv_cap = hn::Div(hn::Set(d, 1.0f), vcap);
|
||||||
|
|
||||||
hn::Transform(d, x, size, [&vcap, &vinv_cap](D d, hn::Vec<D> v) HWY_ATTR {
|
// If we do not subtract the max as in softmax, values > 100 (which do occur)
|
||||||
return hn::Mul(vcap, hn::Tanh(d, hn::Mul(v, vinv_cap)));
|
// will all saturate to the cap, and then this function is no longer
|
||||||
|
// monotonic, which would change the results on TopK.
|
||||||
|
const V vmin = hn::Set(d, hwy::LowestValue<float>());
|
||||||
|
V vmax = vmin;
|
||||||
|
Foreach(d, x, max_pos, vmin,
|
||||||
|
[&vmax](const auto d, const auto value)
|
||||||
|
HWY_ATTR { vmax = hn::Max(vmax, value); });
|
||||||
|
vmax = hn::MaxOfLanes(d, vmax);
|
||||||
|
|
||||||
|
// We want (v-vmax) * vinv_cap. To take advantage of FMA, multiply this out to
|
||||||
|
// v * vinv_cap + (-vmax*vinv_cap).
|
||||||
|
const V add = hn::Neg(hn::Mul(vmax, vinv_cap));
|
||||||
|
|
||||||
|
hn::Transform(
|
||||||
|
d, x, size, [&vcap, &vinv_cap, &add](D d, hn::Vec<D> v) HWY_ATTR {
|
||||||
|
return hn::Mul(vcap, hn::Tanh(d, hn::MulAdd(v, vinv_cap, add)));
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue