Small cleanups. Fixes gemma_test build.

PiperOrigin-RevId: 649008524
This commit is contained in:
Daniel Keysers 2024-07-03 03:12:57 -07:00 committed by Copybara-Service
parent 7e4b20455e
commit a40165dea2
4 changed files with 22 additions and 31 deletions

View File

@ -153,12 +153,8 @@ std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel2(
std::string token_text; std::string token_text;
HWY_ASSERT( HWY_ASSERT(
model_->Tokenizer().Decode(std::vector<int>{token}, &token_text)); model_->Tokenizer().Decode(std::vector<int>{token}, &token_text));
// fprintf(stderr, "Query %zu returned token \"%s\"\n\n", query_index, res[query_index].first.append(token_text);
// token_text.c_str()); res[query_index].second += 1;
std::string single_res = res[query_index].first + token_text;
size_t current_token_count = res[query_index].second + 1;
res[query_index] = std::make_pair(single_res, current_token_count);
++total_tokens; ++total_tokens;
if (app_.verbosity >= 1 && total_tokens % 128 == 0) { if (app_.verbosity >= 1 && total_tokens % 128 == 0) {
LogSpeedStats(time_start, total_tokens); LogSpeedStats(time_start, total_tokens);
@ -185,21 +181,20 @@ std::pair<std::string, size_t> GemmaEnv::QueryModel(std::string& input) {
/*pos=*/0, input); /*pos=*/0, input);
return QueryModel(prompt); return QueryModel(prompt);
} }
std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel( std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel(
const std::vector<std::string>& inputs) { const std::vector<std::string>& inputs) {
std::vector<std::unique_ptr<std::vector<int>>> prompts; std::vector<std::vector<int>> prompts;
prompts.reserve(inputs.size()); prompts.reserve(inputs.size());
for (auto& input : inputs) { for (auto& input : inputs) {
std::string mutable_prompt = input; std::string mutable_prompt = input;
prompts.push_back(std::make_unique<std::vector<int>>( prompts.push_back(WrapAndTokenize(model_->Tokenizer(), model_->Info(),
WrapAndTokenize(model_->Tokenizer(), model_->Info(), /*pos=*/0, mutable_prompt));
/*pos=*/0, mutable_prompt)));
} }
std::vector<hwy::Span<int>> prompt_vector; std::vector<hwy::Span<int>> prompt_vector;
prompt_vector.reserve(prompts.size()); prompt_vector.reserve(prompts.size());
for (auto& prompt : prompts) { for (auto& prompt : prompts) {
prompt_vector.push_back(hwy::Span<int>( prompt_vector.push_back(hwy::Span<int>(prompt.data(), prompt.size()));
prompt->data(), prompt->size()));
} }
hwy::Span<const hwy::Span<int>> prompt_span = hwy::Span<const hwy::Span<int>>( hwy::Span<const hwy::Span<int>> prompt_span = hwy::Span<const hwy::Span<int>>(
prompt_vector.data(), prompt_vector.size()); prompt_vector.data(), prompt_vector.size());

View File

@ -105,8 +105,7 @@ struct Activations {
std::array<float, kBatchSize * kGriffinDim> griffin_x; std::array<float, kBatchSize * kGriffinDim> griffin_x;
std::array<float, kBatchSize * kGriffinDim> griffin_y; std::array<float, kBatchSize * kGriffinDim> griffin_y;
std::array<float, kBatchSize * kGriffinDim> griffin_gate_x; std::array<float, kBatchSize * kGriffinDim> griffin_gate_x;
std::array<float, kBatchSize * kGriffinDim> std::array<float, kBatchSize * kGriffinDim> griffin_multiplier;
griffin_multiplier;
}; };
template <typename TConfig> template <typename TConfig>
@ -473,7 +472,7 @@ HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
MatMul_4x4_Batch<kFFHiddenDim, kModelDim>(num_tokens, activations.C1.data(), MatMul_4x4_Batch<kFFHiddenDim, kModelDim>(num_tokens, activations.C1.data(),
layer_weights->linear_w.data(), layer_weights->linear_w.data(),
activations.ffw_out.data(), pool); activations.ffw_out.data(), pool);
} else { } else { // TConfig::kFFBiases == true
for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) { for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
const size_t hidden_offset = batch_idx * kFFHiddenDim * 2; const size_t hidden_offset = batch_idx * kFFHiddenDim * 2;
const hwy::bfloat16_t* HWY_RESTRICT vec = const hwy::bfloat16_t* HWY_RESTRICT vec =
@ -483,14 +482,12 @@ HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
PROFILER_ZONE("Gen.FFW.GatedGELU"); PROFILER_ZONE("Gen.FFW.GatedGELU");
// Same matrix, first and second half of rows. Could fuse into one MatVec. // Same matrix, first and second half of rows. Could fuse into one MatVec.
MatVecT<TConfig::kFFBiases, kFFHiddenDim, kModelDim>( MatVecT</*kAdd=*/true, kFFHiddenDim, kModelDim>(
layer_weights->gating_einsum_w, kFFHiddenDim * kModelDim, vec, layer_weights->gating_einsum_w, kFFHiddenDim * kModelDim, vec,
TConfig::kFFBiases layer_weights->ffw_gating_biases.data() + kFFHiddenDim, even_odd,
? layer_weights->ffw_gating_biases.data() + kFFHiddenDim out_mul, pool);
: nullptr,
even_odd, out_mul, pool);
// Gate, will go through the nonlinearity. // Gate, will go through the nonlinearity.
MatVecT<TConfig::kFFBiases, kFFHiddenDim, kModelDim>( MatVecT</*kAdd=*/true, kFFHiddenDim, kModelDim>(
layer_weights->gating_einsum_w, 0, vec, layer_weights->gating_einsum_w, 0, vec,
layer_weights->ffw_gating_biases.data(), even_odd, out, pool); layer_weights->ffw_gating_biases.data(), even_odd, out, pool);
@ -501,7 +498,7 @@ HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
[](DF df, VF v, VF mul) [](DF df, VF v, VF mul)
HWY_ATTR { return hn::Mul(mul, Gelu(df, v)); }); HWY_ATTR { return hn::Mul(mul, Gelu(df, v)); });
MatVecT</*kAdd=*/TConfig::kFFBiases, kModelDim, kFFHiddenDim>( MatVecT</*kAdd=*/true, kModelDim, kFFHiddenDim>(
layer_weights->linear_w, 0, layer_weights->linear_w, 0,
activations.ffw_hidden.data() + hidden_offset, activations.ffw_hidden.data() + hidden_offset,
layer_weights->ffw_output_biases.data(), even_odd, layer_weights->ffw_output_biases.data(), even_odd,

View File

@ -38,9 +38,8 @@ namespace gcpp {
// true to continue generation. // true to continue generation.
using StreamFunc = std::function<bool(int, float)>; using StreamFunc = std::function<bool(int, float)>;
// BatchStreamFunc is called with (query_idx, pos, token, probability). // BatchStreamFunc is called with (query_idx, pos, token, probability).
// For prompt tokens, // For prompt tokens, probability is 0.0f.
// probability is 0.0f. StreamFunc should return false to stop generation and // StreamFunc should return false to stop generation and true to continue.
// true to continue generation.
using BatchStreamFunc = std::function<bool(size_t, size_t, int, float)>; using BatchStreamFunc = std::function<bool(size_t, size_t, int, float)>;
// If not empty, AcceptFunc is called with token. It should return false for // If not empty, AcceptFunc is called with token. It should return false for
// tokens you don't want to generate and true for tokens you want to generate. // tokens you don't want to generate and true for tokens you want to generate.

View File

@ -172,24 +172,24 @@ TEST_F(GemmaTest, CrossEntropySmall) {
"The capital of Hungary is Budapest which is located in Europe."; "The capital of Hungary is Budapest which is located in Europe.";
float entropy = s_env->CrossEntropy(kSmall); float entropy = s_env->CrossEntropy(kSmall);
fprintf(stderr, "per-byte entropy: %f\n", entropy); fprintf(stderr, "per-byte entropy: %f\n", entropy);
EXPECT_LT(entropy, const bool is_7b = s_env->GetModel()->Info().model == gcpp::Model::GEMMA_7B;
(s_env->ModelType() == gcpp::Model::GEMMA_7B) ? 2.1f : 2.0f); EXPECT_LT(entropy, is_7b ? 2.1f : 2.0f);
} }
TEST_F(GemmaTest, CrossEntropyJingleBells) { TEST_F(GemmaTest, CrossEntropyJingleBells) {
if (!s_env->GetModel()) return; if (!s_env->GetModel()) return;
float entropy = s_env->CrossEntropy(kJingleBells); float entropy = s_env->CrossEntropy(kJingleBells);
fprintf(stderr, "per-byte entropy: %f\n", entropy); fprintf(stderr, "per-byte entropy: %f\n", entropy);
EXPECT_LT(entropy, const bool is_7b = s_env->GetModel()->Info().model == gcpp::Model::GEMMA_7B;
(s_env->ModelType() == gcpp::Model::GEMMA_7B) ? 0.9f : 1.8f); EXPECT_LT(entropy, is_7b ? 0.9f : 1.8f);
} }
TEST_F(GemmaTest, CrossEntropyGettysburg) { TEST_F(GemmaTest, CrossEntropyGettysburg) {
if (!s_env->GetModel()) return; if (!s_env->GetModel()) return;
float entropy = s_env->CrossEntropy(kGettysburg); float entropy = s_env->CrossEntropy(kGettysburg);
fprintf(stderr, "per-byte entropy: %f\n", entropy); fprintf(stderr, "per-byte entropy: %f\n", entropy);
EXPECT_LT(entropy, const bool is_7b = s_env->GetModel()->Info().model == gcpp::Model::GEMMA_7B;
(s_env->ModelType() == gcpp::Model::GEMMA_7B) ? 0.8f : 1.2f); EXPECT_LT(entropy, is_7b ? 0.8f : 1.2f);
} }
} // namespace } // namespace