mirror of https://github.com/google/gemma.cpp.git
Merge 139a8e0964 into 529c201eb6
This commit is contained in:
commit
f78c956ecb
|
|
@ -46,11 +46,18 @@ static const char* kQuestions =
|
|||
"Which people first proposed the quark model of hadrons, and when?";
|
||||
|
||||
// All phrases in kAnswers must appear in the response in the order given for
|
||||
// the test to pass.
|
||||
static const char* kAnswers[] = {
|
||||
"a ship's anchor", "a dark forest", "an hour",
|
||||
"enormous sand", "castles", "limpet shells",
|
||||
"Murray Gell-Mann", "George Zweig", "1964"};
|
||||
// the test to pass. Multiple acceptable answers can be provided for each
|
||||
// expected phrase.
|
||||
static const std::vector<std::vector<const char*>> kAnswers = {
|
||||
{"rusty metal", "ship's anchor"},
|
||||
{"dark forest"},
|
||||
{"an hour"},
|
||||
{"enormous sand"},
|
||||
{"castles"},
|
||||
{"limpet shells"},
|
||||
{"Murray Gell-Mann"},
|
||||
{"George Zweig"},
|
||||
{"1964"}};
|
||||
|
||||
std::string LoadPromptFile(const std::string& filename) {
|
||||
// If the filename is empty, return an empty string.
|
||||
|
|
@ -108,12 +115,22 @@ class GemmaTest : public ::testing::Test {
|
|||
void TestExpectations(const std::string& response) {
|
||||
fprintf(stderr, "Response: '%s'\n", response.c_str());
|
||||
size_t pos = 0;
|
||||
for (const char* answer : kAnswers) {
|
||||
auto found = response.find(answer, pos);
|
||||
EXPECT_NE(found, std::string::npos)
|
||||
<< "Response does not contain " << answer;
|
||||
if (found != std::string::npos) {
|
||||
pos = found + strlen(answer);
|
||||
for (const auto& answer_group : kAnswers) {
|
||||
size_t earliest_pos = std::string::npos;
|
||||
const char* matched_answer = nullptr;
|
||||
for (const char* answer : answer_group) {
|
||||
auto found = response.find(answer, pos);
|
||||
if (found != std::string::npos &&
|
||||
(earliest_pos == std::string::npos || found < earliest_pos)) {
|
||||
earliest_pos = found;
|
||||
matched_answer = answer;
|
||||
}
|
||||
}
|
||||
EXPECT_NE(earliest_pos, std::string::npos)
|
||||
<< "Response does not contain acceptable answers, e.g., "
|
||||
<< answer_group[0];
|
||||
if (earliest_pos != std::string::npos) {
|
||||
pos = earliest_pos + strlen(matched_answer);
|
||||
}
|
||||
}
|
||||
s_env->PrintProfileResults();
|
||||
|
|
|
|||
|
|
@ -489,7 +489,7 @@ HWY_INLINE float SingleFlashAttentionRowVector(DF df, size_t start_pos,
|
|||
}
|
||||
float m = hn::ReduceMax(df, x);
|
||||
m = std::max(m, old_max);
|
||||
x = hn::Exp(df, hn::Sub(x, hn::Set(df, m)));
|
||||
x = hn::FastExpMinusOrZero(df, hn::Sub(x, hn::Set(df, m)));
|
||||
float scale = old_d * std::exp(old_max - m);
|
||||
old_d = hn::ReduceSum(df, x) + scale;
|
||||
old_max = m;
|
||||
|
|
@ -538,8 +538,8 @@ HWY_INLINE float DoubleFlashAttentionRowVector(DF df, size_t start_pos,
|
|||
float m = hn::ReduceMax(df, x_max);
|
||||
m = std::max(m, old_max);
|
||||
VF m_vec = hn::Set(df, m);
|
||||
x0 = hn::Exp(df, hn::Sub(x0, m_vec));
|
||||
x1 = hn::Exp(df, hn::Sub(x1, m_vec));
|
||||
x0 = hn::FastExpMinusOrZero(df, hn::Sub(x0, m_vec));
|
||||
x1 = hn::FastExpMinusOrZero(df, hn::Sub(x1, m_vec));
|
||||
float scale = old_d * std::exp(old_max - m);
|
||||
VF x_sum = hn::Add(x0, x1);
|
||||
old_d = hn::ReduceSum(df, x_sum) + scale;
|
||||
|
|
@ -672,7 +672,8 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap4(
|
|||
x_sum = Reduce4(df, x_0_sum, x_1_sum, x_2_sum, x_3_sum,
|
||||
[](auto a, auto b) HWY_ATTR { return hn::Add(a, b); });
|
||||
}
|
||||
VF4 scale = hn::Mul(old_d_vf, hn::Exp(df4, hn::Sub(old_max_vf, new_max)));
|
||||
VF4 scale = hn::Mul(
|
||||
old_d_vf, hn::FastExpMinusOrZero(df4, hn::Sub(old_max_vf, new_max)));
|
||||
old_d_vf = hn::Add(scale, x_sum);
|
||||
auto non_zero_mask = hn::Gt(old_d_vf, hn::Set(df4, 0.0f));
|
||||
const VF zero = hn::Zero(df);
|
||||
|
|
@ -810,7 +811,8 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap8(
|
|||
x_6_sum, x_7_sum,
|
||||
[](auto a, auto b) HWY_ATTR { return hn::Add(a, b); });
|
||||
}
|
||||
VF8 scale = hn::Mul(old_d_vf, hn::Exp(df8, hn::Sub(old_max_vf, new_max)));
|
||||
VF8 scale = hn::Mul(
|
||||
old_d_vf, hn::FastExpMinusOrZero(df8, hn::Sub(old_max_vf, new_max)));
|
||||
old_d_vf = hn::Add(scale, x_sum);
|
||||
auto non_zero_mask = hn::Gt(old_d_vf, hn::Set(df8, 0.0f));
|
||||
const VF zero = hn::Zero(df);
|
||||
|
|
|
|||
Loading…
Reference in New Issue