diff --git a/BUILD.bazel b/BUILD.bazel
index 81e6a46..b3fd70d 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -107,6 +107,19 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "cross_entropy",
+    srcs = [
+        "gemma/cross_entropy.cc",
+    ],
+    hdrs = [
+        "gemma/cross_entropy.h",
+    ],
+    deps = [
+        ":gemma_lib",
+    ],
+)
+
 cc_library(
     name = "args",
     hdrs = ["util/args.h"],
@@ -141,6 +154,7 @@ cc_test(
     ],
     deps = [
         ":args",
+        ":cross_entropy",
         ":gemma_lib",
         ":ops",
         "@googletest//:gtest_main",
@@ -190,6 +204,7 @@ cc_binary(
         ":app",
         ":args",
         ":common",
+        ":cross_entropy",
         ":gemma_lib",
         "//compression:io",
         "@hwy//:hwy",
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 72409f6..6a24ba0 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -61,6 +61,8 @@ set(SOURCES
   gemma/activations.h
   gemma/common.cc
   gemma/common.h
+  gemma/cross_entropy.cc
+  gemma/cross_entropy.h
   gemma/gemma.cc
   gemma/gemma.h
   gemma/ops.h
diff --git a/backprop/optimize_test.cc b/backprop/optimize_test.cc
index 0d89416..3698ceb 100644
--- a/backprop/optimize_test.cc
+++ b/backprop/optimize_test.cc
@@ -57,7 +57,7 @@ TEST(OptimizeTest, GradientDescent) {
   };
   RuntimeConfig runtime = {
       max_tokens, max_generated_tokens, temperature, verbosity, &gen,
-      stream_token, accept_token, ReverseSequenceSampler::kEndToken,
+      stream_token, accept_token, nullptr, ReverseSequenceSampler::kEndToken,
   };
   TimingInfo timing_info;
   gemma.Generate(runtime, prompt, 0, kv_cache, timing_info);
diff --git a/gemma/benchmark.cc b/gemma/benchmark.cc
index 1d41153..2f0b132 100644
--- a/gemma/benchmark.cc
+++ b/gemma/benchmark.cc
@@ -10,6 +10,7 @@
 #include <vector>
 
 #include "compression/io.h"  // Path
+#include "gemma/cross_entropy.h"
 #include "gemma/gemma.h"
 #include "util/app.h"
 #include "util/args.h"
@@ -204,13 +205,15 @@ int BenchmarkCrossEntropy(gcpp::Gemma& model, gcpp::Model model_type,
     std::vector<int> prompt_slice(prompt.begin() + pos,
                                   prompt.begin() + pos + num_tokens);
     gcpp::KVCache kv_cache = gcpp::KVCache::Create(model_type);
-    float entropy = model.ComputeCrossEntropy(num_tokens, prompt_slice,
-                                              kv_cache, app.verbosity);
+    float entropy = ComputeCrossEntropy(model, num_tokens, prompt_slice,
+                                        kv_cache, app.verbosity);
     total_entropy += entropy;
     LogSpeedStats(time_start, pos + num_tokens);
     std::string text_slice;
     HWY_ASSERT(model.Tokenizer().Decode(prompt_slice, &text_slice));
     total_input_len += text_slice.size();
+    printf("Total cross entropy: %f [cumulative: %f]\n",
+           entropy, total_entropy);
     printf("Cross entropy per byte: %f [cumulative: %f]\n",
            entropy / text_slice.size(), total_entropy / total_input_len);
   }
diff --git a/gemma/cross_entropy.cc b/gemma/cross_entropy.cc
new file mode 100644
index 0000000..453e8e7
--- /dev/null
+++ b/gemma/cross_entropy.cc
@@ -0,0 +1,109 @@
+// Copyright 2024 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "gemma/cross_entropy.h"
+
+#include <algorithm>
+#include <cmath>
+#include <regex>  // NOLINT
+#include <string>
+#include <utility>
+#include <vector>
+
+namespace gcpp {
+
+namespace {
+template <typename TConfig>
+struct GetVocabSize {
+  int operator()() const { return TConfig::kVocabSize; }
+};
+
+static std::string TokenString(const GemmaTokenizer& tokenizer, int token) {
+  std::string token_str;
+  tokenizer.Decode({token}, &token_str);
+  return "'" + std::regex_replace(token_str, std::regex("\n"), "\\n") + "'";
+}
+
+void LogTopK(const GemmaTokenizer& tokenizer, const float* dist, size_t len,
+             size_t k) {
+  std::vector<std::pair<float, int>> sorted(len);
+  for (size_t i = 0; i < len; ++i) {
+    sorted[i] = std::make_pair(dist[i], static_cast<int>(i));
+  }
+  std::sort(sorted.begin(), sorted.end(),
+            [](const std::pair<float, int>& a, const std::pair<float, int>& b) {
+              if (a.first != b.first) {
+                return a.first > b.first;
+              }
+              return a.second < b.second;
+            });
+  for (size_t i = 0; i < k; ++i) {
+    printf(" [#%-2d token %6d = %-12s %.2e]\n", static_cast<int>(i + 1),
+           sorted[i].second, TokenString(tokenizer, sorted[i].second).c_str(),
+           sorted[i].first);
+  }
+}
+}  // namespace
+
+float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
+                          const std::vector<int>& prompt,
+                          KVCache& kv_cache, int verbosity) {
+  auto stream_token = [](int, float) { return true; };
+  auto accept_token = [](int) { return true; };
+
+  const int vocab_size = CallFunctorForModel<GetVocabSize>(gemma.ModelType());
+  float cross_entropy = std::log(vocab_size);  // first token
+  size_t pos = 1;
+  std::function<int(const float*, size_t)> sample_token =
+      [&](const float* probs, size_t vocab_size) -> int {
+    const int token = prompt[pos];
+    const float prob = probs[token];
+    cross_entropy -= std::max(std::log(prob), -64.0f);
+
+    if (verbosity >= 4) {
+      LogTopK(gemma.Tokenizer(), probs, vocab_size, 10);
+    }
+    if (verbosity >= 3) {
+      printf("pos %4zu token %6d = %-12s %.10e %14.10f bits\n", pos,
+             token, TokenString(gemma.Tokenizer(), token).c_str(), prob,
+             -std::log(prob) / std::log(2.0));
+    }
+    if (verbosity >= 2 && pos % 100 == 99) {
+      printf("Processed %zu tokens, cross-entropy per token: %f\n", pos + 1,
+             cross_entropy / std::log(2.0) / (pos + 1));
+    }
+    ++pos;
+    return token;
+  };
+  std::vector<int> prompt0 = { prompt[0] };
+  RuntimeConfig runtime = {
+      .max_tokens = max_tokens,
+      .max_generated_tokens = max_tokens - 1,
+      .temperature = 0.0f,
+      .verbosity = verbosity,
+      .gen = nullptr,
+      .stream_token = stream_token,
+      .accept_token = accept_token,
+      .sample_func = &sample_token,
+  };
+  TimingInfo timing_info;
+
+  gemma.Generate(runtime, prompt0, 0, kv_cache, timing_info, nullptr);
+
+  const float scale = 1.0 / std::log(2.0);
+  return cross_entropy * scale;
+}
+
+}  // namespace gcpp
diff --git a/gemma/cross_entropy.h b/gemma/cross_entropy.h
new file mode 100644
index 0000000..e6bac00
--- /dev/null
+++ b/gemma/cross_entropy.h
@@ -0,0 +1,31 @@
+// Copyright 2024 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_CROSS_ENTROPY_H_
+#define THIRD_PARTY_GEMMA_CPP_GEMMA_CROSS_ENTROPY_H_
+
+#include <vector>
+
+#include "gemma/gemma.h"
+
+namespace gcpp {
+
+float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
+                          const std::vector<int>& prompt, KVCache& kv_cache,
+                          int verbosity);
+
+}  // namespace gcpp
+
+#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_CROSS_ENTROPY_H_
diff --git a/gemma/gemma.cc b/gemma/gemma.cc
index 23d7922..b6a4c21 100644
--- a/gemma/gemma.cc
+++ b/gemma/gemma.cc
@@ -40,7 +40,6 @@
 #include <functional>
 #include <memory>
 #include <random>
-#include <regex>  // NOLINT
 #include <string>
 #include <utility>
 #include <vector>
@@ -217,32 +216,6 @@ bool GemmaTokenizer::Decode(const std::vector<int>& ids,
   return impl_->Decode(ids, detokenized);
 }
 
-static std::string TokenString(const GemmaTokenizer& tokenizer, int token) {
-  std::string token_str;
-  tokenizer.Decode({token}, &token_str);
-  return "'" + std::regex_replace(token_str, std::regex("\n"), "\\n") + "'";
-}
-
-void LogTopK(const GemmaTokenizer& tokenizer, float* HWY_RESTRICT logits,
-             float* HWY_RESTRICT dist, size_t len, size_t k) {
-  std::vector<std::pair<float, int>> sorted(len);
-  for (size_t i = 0; i < len; ++i) {
-    sorted[i] = std::make_pair(dist[i], static_cast<int>(i));
-  }
-  std::sort(sorted.begin(), sorted.end(),
-            [](const std::pair<float, int>& a, const std::pair<float, int>& b) {
-              if (a.first != b.first) {
-                return a.first > b.first;
-              }
-              return a.second < b.second;
-            });
-  for (size_t i = 0; i < k; ++i) {
-    printf(" [#%-2d token %6d = %-12s %.2e %f]\n", static_cast<int>(i + 1),
-           sorted[i].second, TokenString(tokenizer, sorted[i].second).c_str(),
-           sorted[i].first, logits[sorted[i].second]);
-  }
-}
-
 }  // namespace gcpp
 #endif  // GEMMA_ONCE
@@ -837,13 +810,19 @@ void Generate(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
       MatVec<kVocabSize, kModelDim>(
           weights.embedder_input_embedding, 0, final_activation,
          activations.even_odd.data(), activations.logits.data(), pool);
+      LogitsSoftCap(30.0f, activations.logits.data(), kVocabSize);
       // Barrier: must have all logits so we can subtract max.
       Softmax(activations.logits.data(), kVocabSize);
-      token = SampleTopK<TConfig::kTopK>(
-          activations.logits.data(), kVocabSize, *runtime_config.gen,
-          runtime_config.temperature, runtime_config.accept_token);
-      if (!runtime_config.stream_token(token, activations.logits[token])) {
-        token = runtime_config.eos_id;
+      if (runtime_config.sample_func) {
+        token = (*runtime_config.sample_func)(activations.logits.data(),
+                                              kVocabSize);
+      } else {
+        token = SampleTopK<TConfig::kTopK>(
+            activations.logits.data(), kVocabSize, *runtime_config.gen,
+            runtime_config.temperature, runtime_config.accept_token);
+        if (!runtime_config.stream_token(token, activations.logits[token])) {
+          token = runtime_config.eos_id;
+        }
       }
       if (generate_pos == 0) {
         timing_info.time_to_first_token = hwy::platform::Now() - gen_start;
@@ -868,51 +847,6 @@
   }
 }
 
-template <class TConfig>
-void ComputeCrossEntropy(const ByteStorageT& weights_u8,
-                         ByteStorageT& decode_u8,
-                         const GemmaTokenizer& tokenizer, size_t max_tokens,
-                         const std::vector<int>& prompt, KVCache& kv_cache,
-                         hwy::ThreadPool& pool, int verbosity,
-                         float& cross_entropy) {
-  const WeightsT<TConfig>& weights = GetWeights<TConfig>(weights_u8);
-  auto& activations = GetActivations<TConfig>(decode_u8);
-
-  static constexpr size_t kModelDim = TConfig::kModelDim;
-  static constexpr size_t kVocabSize = TConfig::kVocabSize;
-  std::vector<float> logits(kVocabSize);
-  Softmax(activations.logits.data(), kVocabSize);
-  float total_entropy = 0.0f;
-  for (size_t pos = 0; pos < max_tokens && pos < prompt.size(); ++pos) {
-    if (verbosity >= 4) {
-      LogTopK(tokenizer, logits.data(), activations.logits.data(), kVocabSize,
-              10);
-    }
-    const int token = prompt[pos];
-    const float prob = activations.logits[token];
-    if (verbosity >= 3) {
-      printf("pos %4zu token %6d = %-12s %.10e %14.10f bits\n", pos, token,
-             TokenString(tokenizer, token).c_str(), prob,
-             -std::log(prob) / std::log(2.0));
-    }
-    total_entropy -= std::max(std::log(prob), -64.0f);
-    if (verbosity >= 2 && pos % 100 == 99) {
-      printf("Processed %zu tokens, cross-entropy per token: %f\n", pos + 1,
-             total_entropy / std::log(2.0) / (pos + 1));
-    }
-    Transformer(token, pos, weights, activations, kv_cache, pool,
-                /*layers_output=*/nullptr);
-    MatVec<kVocabSize, kModelDim>(
-        weights.embedder_input_embedding, 0, activations.x.data(),
-        activations.even_odd.data(), activations.logits.data(), pool);
-    LogitsSoftCap(30.0f, activations.logits.data(), kVocabSize);
-    memcpy(logits.data(), activations.logits.data(),
-           kVocabSize * sizeof(logits[0]));
-    Softmax(activations.logits.data(), kVocabSize);
-  }
-  cross_entropy = total_entropy / std::log(2.0);
-}
-
 }  // namespace HWY_NAMESPACE
 }  // namespace gcpp
 HWY_AFTER_NAMESPACE();
@@ -970,20 +904,5 @@ void Gemma::Generate(const RuntimeConfig& runtime_config,
   pool_.SetWaitMode(hwy::PoolWaitMode::kBlock);
 }
 
-float Gemma::ComputeCrossEntropy(size_t max_tokens,
-                                 const std::vector<int>& prompt,
-                                 KVCache& kv_cache, int verbosity) {
-  pool_.SetWaitMode(hwy::PoolWaitMode::kSpin);
-
-  float cross_entropy = 0.0f;
-  GEMMA_EXPORT_AND_DISPATCH_MODEL(
-      model_type_, ComputeCrossEntropy,
-      (weights_u8_, decode_u8_, tokenizer_, max_tokens, prompt, kv_cache, pool_,
-       verbosity, cross_entropy));
-
-  pool_.SetWaitMode(hwy::PoolWaitMode::kBlock);
-  return cross_entropy;
-}
-
 }  // namespace gcpp
 #endif  // HWY_ONCE
diff --git a/gemma/gemma.h b/gemma/gemma.h
index e9ffc08..8cf581a 100644
--- a/gemma/gemma.h
+++ b/gemma/gemma.h
@@ -74,6 +74,9 @@
 using StreamFunc = std::function<bool(int, float)>;
 // AcceptFunc is called with token. It should return False for tokens you don't
 // want to generate and True for tokens you want to generate.
 using AcceptFunc = std::function<bool(int)>;
+// CustomSampleFunc is called with the probability distribution for the next
+// token, and its return value is used as the next generated token.
+using CustomSampleFunc = std::function<int(const float*, size_t)>;
 
 struct RuntimeConfig {
   size_t max_tokens;
@@ -83,6 +86,7 @@
   std::mt19937* gen;
   const StreamFunc& stream_token;
   const AcceptFunc& accept_token;
+  const CustomSampleFunc* sample_func = nullptr;
   int eos_id = EOS_ID;
 };
 
@@ -109,6 +113,7 @@ class Gemma {
   Gemma(GemmaTokenizer&& tokenizer, Model model_type, hwy::ThreadPool& pool);
   ~Gemma();
 
+  Model ModelType() const { return model_type_; }
   const GemmaTokenizer& Tokenizer() const { return tokenizer_; }
   const ByteStorageT& Weights() const { return weights_u8_; }
   const ByteStorageT& Prefill() const { return prefill_u8_; }
@@ -121,9 +126,6 @@ class Gemma {
                 KVCache& kv_cache, TimingInfo& timing_info,
                 LayersOutputT* layers_output = nullptr);
 
-  float ComputeCrossEntropy(size_t max_tokens, const std::vector<int>& prompt,
-                            KVCache& kv_cache, int verbosity);
-
  private:
   hwy::ThreadPool& pool_;
diff --git a/gemma/gemma_test.cc b/gemma/gemma_test.cc
index 75686b9..46dbeb5 100644
--- a/gemma/gemma_test.cc
+++ b/gemma/gemma_test.cc
@@ -24,6 +24,7 @@
 
 #include "compression/io.h"  // Path
 #include "gemma/common.h"
+#include "gemma/cross_entropy.h"
 #include "gemma/ops.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 #include "hwy/tests/test_util-inl.h"
@@ -77,9 +78,9 @@ class GemmaTest : public ::testing::Test {
   float GemmaCrossEntropy(const std::string& prompt_string) {
     std::vector<int> prompt;
     HWY_ASSERT(model.Tokenizer().Encode(prompt_string, &prompt));
-    return model.ComputeCrossEntropy(/*max_tokens=*/3072, prompt, kv_cache,
-                                     /*verbosity=*/0) /
-           prompt_string.size();
+    return ComputeCrossEntropy(model, /*max_tokens=*/3072, prompt, kv_cache,
+                               /*verbosity=*/0) /
+           prompt_string.size();
   }
 
   void TestQuestions(const char* kQA[][2], size_t num_questions) {