mirror of https://github.com/google/gemma.cpp.git
Small cleanups. Fixes gemma_test build.
PiperOrigin-RevId: 649008524
This commit is contained in:
parent
7e4b20455e
commit
a40165dea2
|
|
@ -153,12 +153,8 @@ std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel2(
|
|||
std::string token_text;
|
||||
HWY_ASSERT(
|
||||
model_->Tokenizer().Decode(std::vector<int>{token}, &token_text));
|
||||
// fprintf(stderr, "Query %zu returned token \"%s\"\n\n", query_index,
|
||||
// token_text.c_str());
|
||||
std::string single_res = res[query_index].first + token_text;
|
||||
size_t current_token_count = res[query_index].second + 1;
|
||||
res[query_index] = std::make_pair(single_res, current_token_count);
|
||||
|
||||
res[query_index].first.append(token_text);
|
||||
res[query_index].second += 1;
|
||||
++total_tokens;
|
||||
if (app_.verbosity >= 1 && total_tokens % 128 == 0) {
|
||||
LogSpeedStats(time_start, total_tokens);
|
||||
|
|
@ -185,21 +181,20 @@ std::pair<std::string, size_t> GemmaEnv::QueryModel(std::string& input) {
|
|||
/*pos=*/0, input);
|
||||
return QueryModel(prompt);
|
||||
}
|
||||
|
||||
std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel(
|
||||
const std::vector<std::string>& inputs) {
|
||||
std::vector<std::unique_ptr<std::vector<int>>> prompts;
|
||||
std::vector<std::vector<int>> prompts;
|
||||
prompts.reserve(inputs.size());
|
||||
for (auto& input : inputs) {
|
||||
std::string mutable_prompt = input;
|
||||
prompts.push_back(std::make_unique<std::vector<int>>(
|
||||
WrapAndTokenize(model_->Tokenizer(), model_->Info(),
|
||||
/*pos=*/0, mutable_prompt)));
|
||||
prompts.push_back(WrapAndTokenize(model_->Tokenizer(), model_->Info(),
|
||||
/*pos=*/0, mutable_prompt));
|
||||
}
|
||||
std::vector<hwy::Span<int>> prompt_vector;
|
||||
prompt_vector.reserve(prompts.size());
|
||||
for (auto& prompt : prompts) {
|
||||
prompt_vector.push_back(hwy::Span<int>(
|
||||
prompt->data(), prompt->size()));
|
||||
prompt_vector.push_back(hwy::Span<int>(prompt.data(), prompt.size()));
|
||||
}
|
||||
hwy::Span<const hwy::Span<int>> prompt_span = hwy::Span<const hwy::Span<int>>(
|
||||
prompt_vector.data(), prompt_vector.size());
|
||||
|
|
|
|||
|
|
@ -105,8 +105,7 @@ struct Activations {
|
|||
std::array<float, kBatchSize * kGriffinDim> griffin_x;
|
||||
std::array<float, kBatchSize * kGriffinDim> griffin_y;
|
||||
std::array<float, kBatchSize * kGriffinDim> griffin_gate_x;
|
||||
std::array<float, kBatchSize * kGriffinDim>
|
||||
griffin_multiplier;
|
||||
std::array<float, kBatchSize * kGriffinDim> griffin_multiplier;
|
||||
};
|
||||
|
||||
template <typename TConfig>
|
||||
|
|
@ -473,7 +472,7 @@ HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
|
|||
MatMul_4x4_Batch<kFFHiddenDim, kModelDim>(num_tokens, activations.C1.data(),
|
||||
layer_weights->linear_w.data(),
|
||||
activations.ffw_out.data(), pool);
|
||||
} else {
|
||||
} else { // TConfig::kFFBiases == true
|
||||
for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
|
||||
const size_t hidden_offset = batch_idx * kFFHiddenDim * 2;
|
||||
const hwy::bfloat16_t* HWY_RESTRICT vec =
|
||||
|
|
@ -483,14 +482,12 @@ HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
|
|||
|
||||
PROFILER_ZONE("Gen.FFW.GatedGELU");
|
||||
// Same matrix, first and second half of rows. Could fuse into one MatVec.
|
||||
MatVecT<TConfig::kFFBiases, kFFHiddenDim, kModelDim>(
|
||||
MatVecT</*kAdd=*/true, kFFHiddenDim, kModelDim>(
|
||||
layer_weights->gating_einsum_w, kFFHiddenDim * kModelDim, vec,
|
||||
TConfig::kFFBiases
|
||||
? layer_weights->ffw_gating_biases.data() + kFFHiddenDim
|
||||
: nullptr,
|
||||
even_odd, out_mul, pool);
|
||||
layer_weights->ffw_gating_biases.data() + kFFHiddenDim, even_odd,
|
||||
out_mul, pool);
|
||||
// Gate, will go through the nonlinearity.
|
||||
MatVecT<TConfig::kFFBiases, kFFHiddenDim, kModelDim>(
|
||||
MatVecT</*kAdd=*/true, kFFHiddenDim, kModelDim>(
|
||||
layer_weights->gating_einsum_w, 0, vec,
|
||||
layer_weights->ffw_gating_biases.data(), even_odd, out, pool);
|
||||
|
||||
|
|
@ -501,7 +498,7 @@ HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
|
|||
[](DF df, VF v, VF mul)
|
||||
HWY_ATTR { return hn::Mul(mul, Gelu(df, v)); });
|
||||
|
||||
MatVecT</*kAdd=*/TConfig::kFFBiases, kModelDim, kFFHiddenDim>(
|
||||
MatVecT</*kAdd=*/true, kModelDim, kFFHiddenDim>(
|
||||
layer_weights->linear_w, 0,
|
||||
activations.ffw_hidden.data() + hidden_offset,
|
||||
layer_weights->ffw_output_biases.data(), even_odd,
|
||||
|
|
|
|||
|
|
@ -38,9 +38,8 @@ namespace gcpp {
|
|||
// true to continue generation.
|
||||
using StreamFunc = std::function<bool(int, float)>;
|
||||
// BatchStreamFunc is called with (query_idx, pos, token, probability).
|
||||
// For prompt tokens,
|
||||
// probability is 0.0f. StreamFunc should return false to stop generation and
|
||||
// true to continue generation.
|
||||
// For prompt tokens, probability is 0.0f.
|
||||
// StreamFunc should return false to stop generation and true to continue.
|
||||
using BatchStreamFunc = std::function<bool(size_t, size_t, int, float)>;
|
||||
// If not empty, AcceptFunc is called with token. It should return false for
|
||||
// tokens you don't want to generate and true for tokens you want to generate.
|
||||
|
|
|
|||
|
|
@ -172,24 +172,24 @@ TEST_F(GemmaTest, CrossEntropySmall) {
|
|||
"The capital of Hungary is Budapest which is located in Europe.";
|
||||
float entropy = s_env->CrossEntropy(kSmall);
|
||||
fprintf(stderr, "per-byte entropy: %f\n", entropy);
|
||||
EXPECT_LT(entropy,
|
||||
(s_env->ModelType() == gcpp::Model::GEMMA_7B) ? 2.1f : 2.0f);
|
||||
const bool is_7b = s_env->GetModel()->Info().model == gcpp::Model::GEMMA_7B;
|
||||
EXPECT_LT(entropy, is_7b ? 2.1f : 2.0f);
|
||||
}
|
||||
|
||||
TEST_F(GemmaTest, CrossEntropyJingleBells) {
|
||||
if (!s_env->GetModel()) return;
|
||||
float entropy = s_env->CrossEntropy(kJingleBells);
|
||||
fprintf(stderr, "per-byte entropy: %f\n", entropy);
|
||||
EXPECT_LT(entropy,
|
||||
(s_env->ModelType() == gcpp::Model::GEMMA_7B) ? 0.9f : 1.8f);
|
||||
const bool is_7b = s_env->GetModel()->Info().model == gcpp::Model::GEMMA_7B;
|
||||
EXPECT_LT(entropy, is_7b ? 0.9f : 1.8f);
|
||||
}
|
||||
|
||||
TEST_F(GemmaTest, CrossEntropyGettysburg) {
|
||||
if (!s_env->GetModel()) return;
|
||||
float entropy = s_env->CrossEntropy(kGettysburg);
|
||||
fprintf(stderr, "per-byte entropy: %f\n", entropy);
|
||||
EXPECT_LT(entropy,
|
||||
(s_env->ModelType() == gcpp::Model::GEMMA_7B) ? 0.8f : 1.2f);
|
||||
const bool is_7b = s_env->GetModel()->Info().model == gcpp::Model::GEMMA_7B;
|
||||
EXPECT_LT(entropy, is_7b ? 0.8f : 1.2f);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
|
|
|||
Loading…
Reference in New Issue