From 465998d25ac29a9b8fdffc39b30d189a705d8ed9 Mon Sep 17 00:00:00 2001 From: Zoltan Szabadka Date: Fri, 7 Jun 2024 11:45:07 +0000 Subject: [PATCH] Add support for custom sampling function to runtime config. With this addition the ComputeCrossEntropy function can be moved to its own library, because now we can compute it using only the public API functions from gemma.h --- BUILD.bazel | 15 ++++++ CMakeLists.txt | 2 + backprop/optimize_test.cc | 2 +- gemma/benchmark.cc | 7 ++- gemma/cross_entropy.cc | 109 ++++++++++++++++++++++++++++++++++++++ gemma/cross_entropy.h | 31 +++++++++++ gemma/gemma.cc | 103 ++++------------------------------- gemma/gemma.h | 8 +-- gemma/gemma_test.cc | 7 +-- 9 files changed, 183 insertions(+), 101 deletions(-) create mode 100644 gemma/cross_entropy.cc create mode 100644 gemma/cross_entropy.h diff --git a/BUILD.bazel b/BUILD.bazel index 81e6a46..b3fd70d 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -107,6 +107,19 @@ cc_library( ], ) +cc_library( + name = "cross_entropy", + srcs = [ + "gemma/cross_entropy.cc", + ], + hdrs = [ + "gemma/cross_entropy.h", + ], + deps = [ + ":gemma_lib", + ], +) + cc_library( name = "args", hdrs = ["util/args.h"], @@ -141,6 +154,7 @@ cc_test( ], deps = [ ":args", + ":cross_entropy", ":gemma_lib", ":ops", "@googletest//:gtest_main", @@ -190,6 +204,7 @@ cc_binary( ":app", ":args", ":common", + ":cross_entropy", ":gemma_lib", "//compression:io", "@hwy//:hwy", diff --git a/CMakeLists.txt b/CMakeLists.txt index 72409f6..6a24ba0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -61,6 +61,8 @@ set(SOURCES gemma/activations.h gemma/common.cc gemma/common.h + gemma/cross_entropy.cc + gemma/cross_entropy.h gemma/gemma.cc gemma/gemma.h gemma/ops.h diff --git a/backprop/optimize_test.cc b/backprop/optimize_test.cc index 0d89416..3698ceb 100644 --- a/backprop/optimize_test.cc +++ b/backprop/optimize_test.cc @@ -57,7 +57,7 @@ TEST(OptimizeTest, GradientDescent) { }; RuntimeConfig runtime = { max_tokens, max_generated_tokens, temperature, verbosity, &gen, - stream_token, accept_token, ReverseSequenceSampler::kEndToken, + stream_token, accept_token, nullptr, ReverseSequenceSampler::kEndToken, }; TimingInfo timing_info; gemma.Generate(runtime, prompt, 0, kv_cache, timing_info); diff --git a/gemma/benchmark.cc b/gemma/benchmark.cc index 1d41153..2f0b132 100644 --- a/gemma/benchmark.cc +++ b/gemma/benchmark.cc @@ -10,6 +10,7 @@ #include #include "compression/io.h" // Path +#include "gemma/cross_entropy.h" #include "gemma/gemma.h" #include "util/app.h" #include "util/args.h" @@ -204,13 +205,15 @@ int BenchmarkCrossEntropy(gcpp::Gemma& model, gcpp::Model model_type, std::vector prompt_slice(prompt.begin() + pos, prompt.begin() + pos + num_tokens); gcpp::KVCache kv_cache = gcpp::KVCache::Create(model_type); - float entropy = model.ComputeCrossEntropy(num_tokens, prompt_slice, - kv_cache, app.verbosity); + float entropy = ComputeCrossEntropy(model, num_tokens, prompt_slice, + kv_cache, app.verbosity); total_entropy += entropy; LogSpeedStats(time_start, pos + num_tokens); std::string text_slice; HWY_ASSERT(model.Tokenizer().Decode(prompt_slice, &text_slice)); total_input_len += text_slice.size(); + printf("Total cross entropy: %f [cumulative: %f]\n", + entropy, total_entropy); printf("Cross entropy per byte: %f [cumulative: %f]\n", entropy / text_slice.size(), total_entropy / total_input_len); } diff --git a/gemma/cross_entropy.cc b/gemma/cross_entropy.cc new file mode 100644 index 0000000..453e8e7 --- /dev/null +++ b/gemma/cross_entropy.cc @@ -0,0 +1,109 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "gemma/cross_entropy.h" + +#include +#include +#include // NOLINT +#include +#include +#include + +namespace gcpp { + +namespace { +template +struct GetVocabSize { + int operator()() const { return TConfig::kVocabSize; } +}; + +static std::string TokenString(const GemmaTokenizer& tokenizer, int token) { + std::string token_str; + tokenizer.Decode({token}, &token_str); + return "'" + std::regex_replace(token_str, std::regex("\n"), "\\n") + "'"; +} + +void LogTopK(const GemmaTokenizer& tokenizer, const float* dist, size_t len, + size_t k) { + std::vector> sorted(len); + for (size_t i = 0; i < len; ++i) { + sorted[i] = std::make_pair(dist[i], static_cast(i)); + } + std::sort(sorted.begin(), sorted.end(), + [](const std::pair& a, const std::pair& b) { + if (a.first != b.first) { + return a.first > b.first; + } + return a.second < b.second; + }); + for (size_t i = 0; i < k; ++i) { + printf(" [#%-2d token %6d = %-12s %.2e]\n", static_cast(i + 1), + sorted[i].second, TokenString(tokenizer, sorted[i].second).c_str(), + sorted[i].first); + } +} +} // namespace + +float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens, + const std::vector& prompt, + KVCache& kv_cache, int verbosity) { + auto stream_token = [](int, float) { return true; }; + auto accept_token = [](int) { return true; }; + + const int vocab_size = CallFunctorForModel(gemma.ModelType()); + float cross_entropy = std::log(vocab_size); // first token + size_t pos = 1; + std::function sample_token = + [&](const float* probs, size_t vocab_size) -> int { + const int token = prompt[pos]; + const float prob = probs[token]; + cross_entropy -= std::max(std::log(prob), -64.0f); + + if (verbosity >= 4) { + LogTopK(gemma.Tokenizer(), probs, vocab_size, 10); + } + if (verbosity >= 3) { + printf("pos %4zu token %6d = %-12s %.10e %14.10f bits\n", pos, + token, TokenString(gemma.Tokenizer(), token).c_str(), prob, + -std::log(prob) / std::log(2.0)); + } + if (verbosity >= 2 && pos % 100 == 99) { + printf("Processed %zu tokens, cross-entropy per token: %f\n", pos + 1, + cross_entropy / std::log(2.0) / (pos + 1)); + } + ++pos; + return token; + }; + std::vector prompt0 = { prompt[0] }; + RuntimeConfig runtime = { + .max_tokens = max_tokens, + .max_generated_tokens = max_tokens - 1, + .temperature = 0.0f, + .verbosity = verbosity, + .gen = nullptr, + .stream_token = stream_token, + .accept_token = accept_token, + .sample_func = &sample_token, + }; + TimingInfo timing_info; + + gemma.Generate(runtime, prompt0, 0, kv_cache, timing_info, nullptr); + + const float scale = 1.0 / std::log(2.0); + return cross_entropy * scale; +} + +} // namespace gcpp diff --git a/gemma/cross_entropy.h b/gemma/cross_entropy.h new file mode 100644 index 0000000..e6bac00 --- /dev/null +++ b/gemma/cross_entropy.h @@ -0,0 +1,31 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_CROSS_ENTROPY_H_ +#define THIRD_PARTY_GEMMA_CPP_GEMMA_CROSS_ENTROPY_H_ + +#include + +#include "gemma/gemma.h" + +namespace gcpp { + +float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens, + const std::vector& prompt, KVCache& kv_cache, + int verbosity); + +} // namespace gcpp + +#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_CROSS_ENTROPY_H_ diff --git a/gemma/gemma.cc b/gemma/gemma.cc index 23d7922..b6a4c21 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -40,7 +40,6 @@ #include #include #include -#include // NOLINT #include #include #include @@ -217,32 +216,6 @@ bool GemmaTokenizer::Decode(const std::vector& ids, return impl_->Decode(ids, detokenized); } -static std::string TokenString(const GemmaTokenizer& tokenizer, int token) { - std::string token_str; - tokenizer.Decode({token}, &token_str); - return "'" + std::regex_replace(token_str, std::regex("\n"), "\\n") + "'"; -} - -void LogTopK(const GemmaTokenizer& tokenizer, float* HWY_RESTRICT logits, - float* HWY_RESTRICT dist, size_t len, size_t k) { - std::vector> sorted(len); - for (size_t i = 0; i < len; ++i) { - sorted[i] = std::make_pair(dist[i], static_cast(i)); - } - std::sort(sorted.begin(), sorted.end(), - [](const std::pair& a, const std::pair& b) { - if (a.first != b.first) { - return a.first > b.first; - } - return a.second < b.second; - }); - for (size_t i = 0; i < k; ++i) { - printf(" [#%-2d token %6d = %-12s %.2e %f]\n", static_cast(i + 1), - sorted[i].second, TokenString(tokenizer, sorted[i].second).c_str(), - sorted[i].first, logits[sorted[i].second]); - } -} - } // namespace gcpp #endif // GEMMA_ONCE @@ -837,13 +810,19 @@ void Generate(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8, MatVec( weights.embedder_input_embedding, 0, final_activation, activations.even_odd.data(), activations.logits.data(), pool); + LogitsSoftCap(30.0f, activations.logits.data(), kVocabSize); // Barrier: must have all logits so we can subtract max. Softmax(activations.logits.data(), kVocabSize); - token = SampleTopK( - activations.logits.data(), kVocabSize, *runtime_config.gen, - runtime_config.temperature, runtime_config.accept_token); - if (!runtime_config.stream_token(token, activations.logits[token])) { - token = runtime_config.eos_id; + if (runtime_config.sample_func) { + token = (*runtime_config.sample_func)(activations.logits.data(), + kVocabSize); + } else { + token = SampleTopK( + activations.logits.data(), kVocabSize, *runtime_config.gen, + runtime_config.temperature, runtime_config.accept_token); + if (!runtime_config.stream_token(token, activations.logits[token])) { + token = runtime_config.eos_id; + } } if (generate_pos == 0) { timing_info.time_to_first_token = hwy::platform::Now() - gen_start; @@ -868,51 +847,6 @@ void Generate(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8, } } -template -void ComputeCrossEntropy(const ByteStorageT& weights_u8, - ByteStorageT& decode_u8, - const GemmaTokenizer& tokenizer, size_t max_tokens, - const std::vector& prompt, KVCache& kv_cache, - hwy::ThreadPool& pool, int verbosity, - float& cross_entropy) { - const WeightsT& weights = GetWeights(weights_u8); - auto& activations = GetActivations(decode_u8); - - static constexpr size_t kModelDim = TConfig::kModelDim; - static constexpr size_t kVocabSize = TConfig::kVocabSize; - std::vector logits(kVocabSize); - Softmax(activations.logits.data(), kVocabSize); - float total_entropy = 0.0f; - for (size_t pos = 0; pos < max_tokens && pos < prompt.size(); ++pos) { - if (verbosity >= 4) { - LogTopK(tokenizer, logits.data(), activations.logits.data(), kVocabSize, - 10); - } - const int token = prompt[pos]; - const float prob = activations.logits[token]; - if (verbosity >= 3) { - printf("pos %4zu token %6d = %-12s %.10e %14.10f bits\n", pos, token, - TokenString(tokenizer, token).c_str(), prob, - -std::log(prob) / std::log(2.0)); - } - total_entropy -= std::max(std::log(prob), -64.0f); - if (verbosity >= 2 && pos % 100 == 99) { - printf("Processed %zu tokens, cross-entropy per token: %f\n", pos + 1, - total_entropy / std::log(2.0) / (pos + 1)); - } - Transformer(token, pos, weights, activations, kv_cache, pool, - /*layers_output=*/nullptr); - MatVec( - weights.embedder_input_embedding, 0, activations.x.data(), - activations.even_odd.data(), activations.logits.data(), pool); - LogitsSoftCap(30.0f, activations.logits.data(), kVocabSize); - memcpy(logits.data(), activations.logits.data(), - kVocabSize * sizeof(logits[0])); - Softmax(activations.logits.data(), kVocabSize); - } - cross_entropy = total_entropy / std::log(2.0); -} - } // namespace HWY_NAMESPACE } // namespace gcpp HWY_AFTER_NAMESPACE(); @@ -970,20 +904,5 @@ void Gemma::Generate(const RuntimeConfig& runtime_config, pool_.SetWaitMode(hwy::PoolWaitMode::kBlock); } -float Gemma::ComputeCrossEntropy(size_t max_tokens, - const std::vector& prompt, - KVCache& kv_cache, int verbosity) { - pool_.SetWaitMode(hwy::PoolWaitMode::kSpin); - - float cross_entropy = 0.0f; - GEMMA_EXPORT_AND_DISPATCH_MODEL( - model_type_, ComputeCrossEntropy, - (weights_u8_, decode_u8_, tokenizer_, max_tokens, prompt, kv_cache, pool_, - verbosity, cross_entropy)); - - pool_.SetWaitMode(hwy::PoolWaitMode::kBlock); - return cross_entropy; -} - } // namespace gcpp #endif // HWY_ONCE diff --git a/gemma/gemma.h b/gemma/gemma.h index e9fcb2f..7da8cdd 100644 --- a/gemma/gemma.h +++ b/gemma/gemma.h @@ -72,6 +72,9 @@ using StreamFunc = std::function; // AcceptFunc is called with token. It should return False for tokens you don't // want to generate and True for tokens you want to generate. using AcceptFunc = std::function; +// CustomSampleFunc is called with the probability distribution for the next +// token, and its return value is used as the next generated token. +using CustomSampleFunc = std::function; struct RuntimeConfig { size_t max_tokens; @@ -81,6 +84,7 @@ struct RuntimeConfig { std::mt19937* gen; const StreamFunc& stream_token; const AcceptFunc& accept_token; + const CustomSampleFunc* sample_func = nullptr; int eos_id = EOS_ID; }; @@ -107,6 +111,7 @@ class Gemma { Gemma(GemmaTokenizer&& tokenizer, Model model_type, hwy::ThreadPool& pool); ~Gemma(); + Model ModelType() const { return model_type_; } const GemmaTokenizer& Tokenizer() const { return tokenizer_; } const ByteStorageT& Weights() const { return weights_u8_; } const ByteStorageT& Prefill() const { return prefill_u8_; } @@ -119,9 +124,6 @@ class Gemma { KVCache& kv_cache, TimingInfo& timing_info, LayersOutputT* layers_output = nullptr); - float ComputeCrossEntropy(size_t max_tokens, const std::vector& prompt, - KVCache& kv_cache, int verbosity); - private: hwy::ThreadPool& pool_; diff --git a/gemma/gemma_test.cc b/gemma/gemma_test.cc index 75686b9..46dbeb5 100644 --- a/gemma/gemma_test.cc +++ b/gemma/gemma_test.cc @@ -24,6 +24,7 @@ #include "compression/io.h" // Path #include "gemma/common.h" +#include "gemma/cross_entropy.h" #include "gemma/ops.h" #include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/tests/test_util-inl.h" @@ -77,9 +78,9 @@ class GemmaTest : public ::testing::Test { float GemmaCrossEntropy(const std::string& prompt_string) { std::vector prompt; HWY_ASSERT(model.Tokenizer().Encode(prompt_string, &prompt)); - return model.ComputeCrossEntropy(/*max_tokens=*/3072, prompt, kv_cache, - /*verbosity=*/0) / - prompt_string.size(); + return ComputeCrossEntropy(model, /*max_tokens=*/3072, prompt, kv_cache, + /*verbosity=*/0) / + prompt_string.size(); } void TestQuestions(const char* kQA[][2], size_t num_questions) {