diff --git a/gemma/benchmark_helper.cc b/gemma/benchmark_helper.cc index a5d994c..58da6a0 100644 --- a/gemma/benchmark_helper.cc +++ b/gemma/benchmark_helper.cc @@ -153,12 +153,8 @@ std::vector> GemmaEnv::BatchQueryModel2( std::string token_text; HWY_ASSERT( model_->Tokenizer().Decode(std::vector{token}, &token_text)); - // fprintf(stderr, "Query %zu returned token \"%s\"\n\n", query_index, - // token_text.c_str()); - std::string single_res = res[query_index].first + token_text; - size_t current_token_count = res[query_index].second + 1; - res[query_index] = std::make_pair(single_res, current_token_count); - + res[query_index].first.append(token_text); + res[query_index].second += 1; ++total_tokens; if (app_.verbosity >= 1 && total_tokens % 128 == 0) { LogSpeedStats(time_start, total_tokens); @@ -185,21 +181,20 @@ std::pair GemmaEnv::QueryModel(std::string& input) { /*pos=*/0, input); return QueryModel(prompt); } + std::vector> GemmaEnv::BatchQueryModel( const std::vector& inputs) { - std::vector>> prompts; + std::vector> prompts; prompts.reserve(inputs.size()); for (auto& input : inputs) { std::string mutable_prompt = input; - prompts.push_back(std::make_unique>( - WrapAndTokenize(model_->Tokenizer(), model_->Info(), - /*pos=*/0, mutable_prompt))); + prompts.push_back(WrapAndTokenize(model_->Tokenizer(), model_->Info(), + /*pos=*/0, mutable_prompt)); } std::vector> prompt_vector; prompt_vector.reserve(prompts.size()); for (auto& prompt : prompts) { - prompt_vector.push_back(hwy::Span( - prompt->data(), prompt->size())); + prompt_vector.push_back(hwy::Span(prompt.data(), prompt.size())); } hwy::Span> prompt_span = hwy::Span>( prompt_vector.data(), prompt_vector.size()); diff --git a/gemma/gemma.cc b/gemma/gemma.cc index 6735d18..c82d193 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -105,8 +105,7 @@ struct Activations { std::array griffin_x; std::array griffin_y; std::array griffin_gate_x; - std::array - griffin_multiplier; + std::array griffin_multiplier; }; template @@ -473,7 +472,7 @@ HWY_NOINLINE void FFW(Activations& activations, MatMul_4x4_Batch(num_tokens, activations.C1.data(), layer_weights->linear_w.data(), activations.ffw_out.data(), pool); - } else { + } else { // TConfig::kFFBiases == true for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) { const size_t hidden_offset = batch_idx * kFFHiddenDim * 2; const hwy::bfloat16_t* HWY_RESTRICT vec = @@ -483,14 +482,12 @@ HWY_NOINLINE void FFW(Activations& activations, PROFILER_ZONE("Gen.FFW.GatedGELU"); // Same matrix, first and second half of rows. Could fuse into one MatVec. - MatVecT( + MatVecT( layer_weights->gating_einsum_w, kFFHiddenDim * kModelDim, vec, - TConfig::kFFBiases - ? layer_weights->ffw_gating_biases.data() + kFFHiddenDim - : nullptr, - even_odd, out_mul, pool); + layer_weights->ffw_gating_biases.data() + kFFHiddenDim, even_odd, + out_mul, pool); // Gate, will go through the nonlinearity. - MatVecT( + MatVecT( layer_weights->gating_einsum_w, 0, vec, layer_weights->ffw_gating_biases.data(), even_odd, out, pool); @@ -501,7 +498,7 @@ HWY_NOINLINE void FFW(Activations& activations, [](DF df, VF v, VF mul) HWY_ATTR { return hn::Mul(mul, Gelu(df, v)); }); - MatVecT( + MatVecT( layer_weights->linear_w, 0, activations.ffw_hidden.data() + hidden_offset, layer_weights->ffw_output_biases.data(), even_odd, diff --git a/gemma/gemma.h b/gemma/gemma.h index e35f0ef..73bc3dc 100644 --- a/gemma/gemma.h +++ b/gemma/gemma.h @@ -38,9 +38,8 @@ namespace gcpp { // true to continue generation. using StreamFunc = std::function; // BatchStreamFunc is called with (query_idx, pos, token, probability). -// For prompt tokens, -// probability is 0.0f. StreamFunc should return false to stop generation and -// true to continue generation. +// For prompt tokens, probability is 0.0f. +// StreamFunc should return false to stop generation and true to continue. using BatchStreamFunc = std::function; // If not empty, AcceptFunc is called with token. It should return false for // tokens you don't want to generate and true for tokens you want to generate. diff --git a/gemma/gemma_test.cc b/gemma/gemma_test.cc index c50760b..155577c 100644 --- a/gemma/gemma_test.cc +++ b/gemma/gemma_test.cc @@ -172,24 +172,24 @@ TEST_F(GemmaTest, CrossEntropySmall) { "The capital of Hungary is Budapest which is located in Europe."; float entropy = s_env->CrossEntropy(kSmall); fprintf(stderr, "per-byte entropy: %f\n", entropy); - EXPECT_LT(entropy, - (s_env->ModelType() == gcpp::Model::GEMMA_7B) ? 2.1f : 2.0f); + const bool is_7b = s_env->GetModel()->Info().model == gcpp::Model::GEMMA_7B; + EXPECT_LT(entropy, is_7b ? 2.1f : 2.0f); } TEST_F(GemmaTest, CrossEntropyJingleBells) { if (!s_env->GetModel()) return; float entropy = s_env->CrossEntropy(kJingleBells); fprintf(stderr, "per-byte entropy: %f\n", entropy); - EXPECT_LT(entropy, - (s_env->ModelType() == gcpp::Model::GEMMA_7B) ? 0.9f : 1.8f); + const bool is_7b = s_env->GetModel()->Info().model == gcpp::Model::GEMMA_7B; + EXPECT_LT(entropy, is_7b ? 0.9f : 1.8f); } TEST_F(GemmaTest, CrossEntropyGettysburg) { if (!s_env->GetModel()) return; float entropy = s_env->CrossEntropy(kGettysburg); fprintf(stderr, "per-byte entropy: %f\n", entropy); - EXPECT_LT(entropy, - (s_env->ModelType() == gcpp::Model::GEMMA_7B) ? 0.8f : 1.2f); + const bool is_7b = s_env->GetModel()->Info().model == gcpp::Model::GEMMA_7B; + EXPECT_LT(entropy, is_7b ? 0.8f : 1.2f); } } // namespace