Small cleanups. Fixes gemma_test build.

PiperOrigin-RevId: 649008524
Daniel Keysers 2024-07-03 03:12:57 -07:00, committed by Copybara-Service
parent 7e4b20455e
commit a40165dea2
4 changed files with 22 additions and 31 deletions

View File

@@ -153,12 +153,8 @@ std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel2(
     std::string token_text;
     HWY_ASSERT(
        model_->Tokenizer().Decode(std::vector<int>{token}, &token_text));
-    // fprintf(stderr, "Query %zu returned token \"%s\"\n\n", query_index,
-    //         token_text.c_str());
-    std::string single_res = res[query_index].first + token_text;
-    size_t current_token_count = res[query_index].second + 1;
-    res[query_index] = std::make_pair(single_res, current_token_count);
+    res[query_index].first.append(token_text);
+    res[query_index].second += 1;
     ++total_tokens;
     if (app_.verbosity >= 1 && total_tokens % 128 == 0) {
       LogSpeedStats(time_start, total_tokens);
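
The rewrite above replaces a copy-and-reassign with an in-place update: the old code built a fresh string and a fresh pair per streamed token, the new code appends to the existing pair. A minimal standalone sketch of the two patterns (toy values, not the benchmark code):

#include <string>
#include <utility>
#include <vector>

int main() {
  std::vector<std::pair<std::string, size_t>> res(1);
  const std::string token_text = "hello";

  // Before: concatenation copies the accumulated string, then the pair is
  // rebuilt and copy-assigned back into the vector.
  std::string single_res = res[0].first + token_text;
  res[0] = std::make_pair(single_res, res[0].second + 1);

  // After: append in place; no temporary string, no pair reassignment.
  res[0].first.append(token_text);
  res[0].second += 1;
  return 0;
}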
@@ -185,21 +181,20 @@ std::pair<std::string, size_t> GemmaEnv::QueryModel(std::string& input) {
                                     /*pos=*/0, input);
   return QueryModel(prompt);
 }
 
 std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel(
     const std::vector<std::string>& inputs) {
-  std::vector<std::unique_ptr<std::vector<int>>> prompts;
+  std::vector<std::vector<int>> prompts;
   prompts.reserve(inputs.size());
   for (auto& input : inputs) {
     std::string mutable_prompt = input;
-    prompts.push_back(std::make_unique<std::vector<int>>(
-        WrapAndTokenize(model_->Tokenizer(), model_->Info(),
-                        /*pos=*/0, mutable_prompt)));
+    prompts.push_back(WrapAndTokenize(model_->Tokenizer(), model_->Info(),
+                                      /*pos=*/0, mutable_prompt));
   }
   std::vector<hwy::Span<int>> prompt_vector;
   prompt_vector.reserve(prompts.size());
   for (auto& prompt : prompts) {
-    prompt_vector.push_back(hwy::Span<int>(
-        prompt->data(), prompt->size()));
+    prompt_vector.push_back(hwy::Span<int>(prompt.data(), prompt.size()));
   }
   hwy::Span<const hwy::Span<int>> prompt_span = hwy::Span<const hwy::Span<int>>(
       prompt_vector.data(), prompt_vector.size());
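
This simplification works because WrapAndTokenize returns std::vector<int> by value, so the tokenized prompts can be stored directly rather than behind unique_ptr, and hwy::Span can view them without owning. A sketch of the same view-building pattern, written with C++20 std::span so it compiles without Highway (hwy::Span is used identically above):

#include <cstdio>
#include <span>
#include <vector>

int main() {
  // Owning storage, one token vector per query (stands in for `prompts`).
  std::vector<std::vector<int>> prompts = {{1, 2, 3}, {4, 5}};

  // Non-owning views over each vector (stands in for `prompt_vector`).
  // Built only after `prompts` is fully populated, so no reallocation can
  // invalidate the data pointers.
  std::vector<std::span<int>> views;
  views.reserve(prompts.size());
  for (auto& p : prompts) {
    views.push_back(std::span<int>(p.data(), p.size()));
  }

  // A view of views (stands in for `prompt_span`); valid only while both
  // `prompts` and `views` are alive.
  std::span<const std::span<int>> batch(views.data(), views.size());
  for (const auto& v : batch) {
    std::printf("prompt with %zu tokens\n", v.size());
  }
  return 0;
}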

View File

@@ -105,8 +105,7 @@ struct Activations {
   std::array<float, kBatchSize * kGriffinDim> griffin_x;
   std::array<float, kBatchSize * kGriffinDim> griffin_y;
   std::array<float, kBatchSize * kGriffinDim> griffin_gate_x;
-  std::array<float, kBatchSize * kGriffinDim>
-      griffin_multiplier;
+  std::array<float, kBatchSize * kGriffinDim> griffin_multiplier;
 };
 
 template <typename TConfig>
@@ -473,7 +472,7 @@ HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
     MatMul_4x4_Batch<kFFHiddenDim, kModelDim>(num_tokens, activations.C1.data(),
                                               layer_weights->linear_w.data(),
                                               activations.ffw_out.data(), pool);
-  } else {
+  } else {  // TConfig::kFFBiases == true
     for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
       const size_t hidden_offset = batch_idx * kFFHiddenDim * 2;
       const hwy::bfloat16_t* HWY_RESTRICT vec =
@@ -483,14 +482,12 @@ HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
       PROFILER_ZONE("Gen.FFW.GatedGELU");
       // Same matrix, first and second half of rows. Could fuse into one MatVec.
-      MatVecT<TConfig::kFFBiases, kFFHiddenDim, kModelDim>(
+      MatVecT</*kAdd=*/true, kFFHiddenDim, kModelDim>(
           layer_weights->gating_einsum_w, kFFHiddenDim * kModelDim, vec,
-          TConfig::kFFBiases
-              ? layer_weights->ffw_gating_biases.data() + kFFHiddenDim
-              : nullptr,
-          even_odd, out_mul, pool);
+          layer_weights->ffw_gating_biases.data() + kFFHiddenDim, even_odd,
+          out_mul, pool);
       // Gate, will go through the nonlinearity.
-      MatVecT<TConfig::kFFBiases, kFFHiddenDim, kModelDim>(
+      MatVecT</*kAdd=*/true, kFFHiddenDim, kModelDim>(
           layer_weights->gating_einsum_w, 0, vec,
           layer_weights->ffw_gating_biases.data(), even_odd, out, pool);
@@ -501,7 +498,7 @@ HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
               [](DF df, VF v, VF mul)
                   HWY_ATTR { return hn::Mul(mul, Gelu(df, v)); });
-      MatVecT</*kAdd=*/TConfig::kFFBiases, kModelDim, kFFHiddenDim>(
+      MatVecT</*kAdd=*/true, kModelDim, kFFHiddenDim>(
          layer_weights->linear_w, 0,
          activations.ffw_hidden.data() + hidden_offset,
          layer_weights->ffw_output_biases.data(), even_odd,

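Passing /*kAdd=*/true unconditionally is safe in these hunks because the calls sit in the branch of FFW that is only instantiated when TConfig::kFFBiases is true, so the old runtime `? ... : nullptr` selection was dead code. A minimal sketch of that if constexpr dispatch pattern (names are hypothetical, not the gemma.cpp API):

#include <cstddef>
#include <cstdio>

// kAdd is a compile-time flag: when false, `add` is ignored and the branch
// is compiled out, mirroring MatVecT's template parameter.
template <bool kAdd, size_t kN>
void StoreVec(const float* v, const float* add, float* out) {
  for (size_t i = 0; i < kN; ++i) {
    out[i] = v[i];
    if constexpr (kAdd) out[i] += add[i];
  }
}

struct ConfigWithBiases {
  static constexpr bool kFFBiases = true;
};

template <typename TConfig>
void Ffw(const float* v, const float* biases, float* out) {
  if constexpr (!TConfig::kFFBiases) {
    StoreVec</*kAdd=*/false, 4>(v, nullptr, out);
  } else {  // TConfig::kFFBiases == true: biases are known non-null, so the
            // kAdd flag can be hard-coded to true, as in the hunks above.
    StoreVec</*kAdd=*/true, 4>(v, biases, out);
  }
}

int main() {
  const float v[4] = {1, 2, 3, 4};
  const float b[4] = {0.5f, 0.5f, 0.5f, 0.5f};
  float out[4];
  Ffw<ConfigWithBiases>(v, b, out);
  std::printf("%g %g %g %g\n", out[0], out[1], out[2], out[3]);
  return 0;
}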
View File

@@ -38,9 +38,8 @@ namespace gcpp {
 // true to continue generation.
 using StreamFunc = std::function<bool(int, float)>;
 
 // BatchStreamFunc is called with (query_idx, pos, token, probability).
-// For prompt tokens,
-// probability is 0.0f. StreamFunc should return false to stop generation and
-// true to continue generation.
+// For prompt tokens, probability is 0.0f.
+// StreamFunc should return false to stop generation and true to continue.
 using BatchStreamFunc = std::function<bool(size_t, size_t, int, float)>;
 
 // If not empty, AcceptFunc is called with token. It should return false for
 // tokens you don't want to generate and true for tokens you want to generate.
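
For illustration, a callback matching the BatchStreamFunc signature documented above; only the type and its (query_idx, pos, token, probability) contract come from this header, the harness around it is made up:

#include <cstddef>
#include <cstdio>
#include <functional>

using BatchStreamFunc = std::function<bool(size_t, size_t, int, float)>;

int main() {
  // Prompt tokens arrive with probability 0.0f; return true to continue
  // generation, false to stop.
  BatchStreamFunc stream = [](size_t query_idx, size_t pos, int token,
                              float probability) {
    std::printf("query %zu pos %zu token %d p=%.3f\n", query_idx, pos, token,
                probability);
    return true;
  };
  stream(/*query_idx=*/0, /*pos=*/0, /*token=*/2, /*probability=*/0.0f);
  return 0;
}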

View File

@@ -172,24 +172,24 @@ TEST_F(GemmaTest, CrossEntropySmall) {
       "The capital of Hungary is Budapest which is located in Europe.";
   float entropy = s_env->CrossEntropy(kSmall);
   fprintf(stderr, "per-byte entropy: %f\n", entropy);
-  EXPECT_LT(entropy,
-            (s_env->ModelType() == gcpp::Model::GEMMA_7B) ? 2.1f : 2.0f);
+  const bool is_7b = s_env->GetModel()->Info().model == gcpp::Model::GEMMA_7B;
+  EXPECT_LT(entropy, is_7b ? 2.1f : 2.0f);
 }
 
 TEST_F(GemmaTest, CrossEntropyJingleBells) {
   if (!s_env->GetModel()) return;
   float entropy = s_env->CrossEntropy(kJingleBells);
   fprintf(stderr, "per-byte entropy: %f\n", entropy);
-  EXPECT_LT(entropy,
-            (s_env->ModelType() == gcpp::Model::GEMMA_7B) ? 0.9f : 1.8f);
+  const bool is_7b = s_env->GetModel()->Info().model == gcpp::Model::GEMMA_7B;
+  EXPECT_LT(entropy, is_7b ? 0.9f : 1.8f);
 }
 
 TEST_F(GemmaTest, CrossEntropyGettysburg) {
   if (!s_env->GetModel()) return;
   float entropy = s_env->CrossEntropy(kGettysburg);
   fprintf(stderr, "per-byte entropy: %f\n", entropy);
-  EXPECT_LT(entropy,
-            (s_env->ModelType() == gcpp::Model::GEMMA_7B) ? 0.8f : 1.2f);
+  const bool is_7b = s_env->GetModel()->Info().model == gcpp::Model::GEMMA_7B;
+  EXPECT_LT(entropy, is_7b ? 0.8f : 1.2f);
 }
 
 }  // namespace
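
The is_7b threshold selection is now duplicated across the three tests; if more cases accumulate, it could be factored out. A hypothetical sketch of such a helper (MaxEntropyForModel is not part of the repo):

#include <cstdio>

// Stand-in for gcpp::Model, reduced to what the sketch needs.
enum class Model { GEMMA_2B, GEMMA_7B };

// Hypothetical helper: picks the per-byte entropy bound for one test case.
float MaxEntropyForModel(Model model, float bound_7b, float bound_other) {
  return model == Model::GEMMA_7B ? bound_7b : bound_other;
}

int main() {
  const Model model = Model::GEMMA_2B;
  const float entropy = 1.4f;  // e.g. a value returned by CrossEntropy(kSmall)
  // Mirrors EXPECT_LT(entropy, is_7b ? 2.1f : 2.0f) from the first test.
  if (entropy < MaxEntropyForModel(model, /*bound_7b=*/2.1f,
                                   /*bound_other=*/2.0f)) {
    std::printf("ok: %.2f is below the bound\n", entropy);
  }
  return 0;
}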