Merge pull request #217 from szabadka:cross-entropy

PiperOrigin-RevId: 641241133
This commit is contained in:
Copybara-Service 2024-06-07 07:17:35 -07:00
commit 24db2ff725
9 changed files with 183 additions and 101 deletions

View File

@ -107,6 +107,19 @@ cc_library(
],
)
cc_library(
name = "cross_entropy",
srcs = [
"gemma/cross_entropy.cc",
],
hdrs = [
"gemma/cross_entropy.h",
],
deps = [
":gemma_lib",
],
)
cc_library(
name = "args",
hdrs = ["util/args.h"],
@ -141,6 +154,7 @@ cc_test(
],
deps = [
":args",
":cross_entropy",
":gemma_lib",
":ops",
"@googletest//:gtest_main",
@ -190,6 +204,7 @@ cc_binary(
":app",
":args",
":common",
":cross_entropy",
":gemma_lib",
"//compression:io",
"@hwy//:hwy",

View File

@ -61,6 +61,8 @@ set(SOURCES
gemma/activations.h
gemma/common.cc
gemma/common.h
gemma/cross_entropy.cc
gemma/cross_entropy.h
gemma/gemma.cc
gemma/gemma.h
gemma/ops.h

View File

@ -57,7 +57,7 @@ TEST(OptimizeTest, GradientDescent) {
};
RuntimeConfig runtime = {
max_tokens, max_generated_tokens, temperature, verbosity, &gen,
stream_token, accept_token, ReverseSequenceSampler::kEndToken,
stream_token, accept_token, nullptr, ReverseSequenceSampler::kEndToken,
};
TimingInfo timing_info;
gemma.Generate(runtime, prompt, 0, kv_cache, timing_info);

View File

@ -10,6 +10,7 @@
#include <vector>
#include "compression/io.h" // Path
#include "gemma/cross_entropy.h"
#include "gemma/gemma.h"
#include "util/app.h"
#include "util/args.h"
@ -204,13 +205,15 @@ int BenchmarkCrossEntropy(gcpp::Gemma& model, gcpp::Model model_type,
std::vector<int> prompt_slice(prompt.begin() + pos,
prompt.begin() + pos + num_tokens);
gcpp::KVCache kv_cache = gcpp::KVCache::Create(model_type);
float entropy = model.ComputeCrossEntropy(num_tokens, prompt_slice,
float entropy = ComputeCrossEntropy(model, num_tokens, prompt_slice,
kv_cache, app.verbosity);
total_entropy += entropy;
LogSpeedStats(time_start, pos + num_tokens);
std::string text_slice;
HWY_ASSERT(model.Tokenizer().Decode(prompt_slice, &text_slice));
total_input_len += text_slice.size();
printf("Total cross entropy: %f [cumulative: %f]\n",
entropy, total_entropy);
printf("Cross entropy per byte: %f [cumulative: %f]\n",
entropy / text_slice.size(), total_entropy / total_input_len);
}

109
gemma/cross_entropy.cc Normal file
View File

@ -0,0 +1,109 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "gemma/cross_entropy.h"
#include <algorithm>
#include <cmath>
#include <regex> // NOLINT
#include <string>
#include <utility>
#include <vector>
namespace gcpp {
namespace {
template <typename TConfig>
struct GetVocabSize {
int operator()() const { return TConfig::kVocabSize; }
};
static std::string TokenString(const GemmaTokenizer& tokenizer, int token) {
std::string token_str;
tokenizer.Decode({token}, &token_str);
return "'" + std::regex_replace(token_str, std::regex("\n"), "\\n") + "'";
}
void LogTopK(const GemmaTokenizer& tokenizer, const float* dist, size_t len,
size_t k) {
std::vector<std::pair<float, int>> sorted(len);
for (size_t i = 0; i < len; ++i) {
sorted[i] = std::make_pair(dist[i], static_cast<int>(i));
}
std::sort(sorted.begin(), sorted.end(),
[](const std::pair<float, int>& a, const std::pair<float, int>& b) {
if (a.first != b.first) {
return a.first > b.first;
}
return a.second < b.second;
});
for (size_t i = 0; i < k; ++i) {
printf(" [#%-2d token %6d = %-12s %.2e]\n", static_cast<int>(i + 1),
sorted[i].second, TokenString(tokenizer, sorted[i].second).c_str(),
sorted[i].first);
}
}
} // namespace
float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
const std::vector<int>& prompt,
KVCache& kv_cache, int verbosity) {
auto stream_token = [](int, float) { return true; };
auto accept_token = [](int) { return true; };
const int vocab_size = CallFunctorForModel<GetVocabSize>(gemma.ModelType());
float cross_entropy = std::log(vocab_size); // first token
size_t pos = 1;
std::function<int(const float*, size_t)> sample_token =
[&](const float* probs, size_t vocab_size) -> int {
const int token = prompt[pos];
const float prob = probs[token];
cross_entropy -= std::max(std::log(prob), -64.0f);
if (verbosity >= 4) {
LogTopK(gemma.Tokenizer(), probs, vocab_size, 10);
}
if (verbosity >= 3) {
printf("pos %4zu token %6d = %-12s %.10e %14.10f bits\n", pos,
token, TokenString(gemma.Tokenizer(), token).c_str(), prob,
-std::log(prob) / std::log(2.0));
}
if (verbosity >= 2 && pos % 100 == 99) {
printf("Processed %zu tokens, cross-entropy per token: %f\n", pos + 1,
cross_entropy / std::log(2.0) / (pos + 1));
}
++pos;
return token;
};
std::vector<int> prompt0 = { prompt[0] };
RuntimeConfig runtime = {
.max_tokens = max_tokens,
.max_generated_tokens = max_tokens - 1,
.temperature = 0.0f,
.verbosity = verbosity,
.gen = nullptr,
.stream_token = stream_token,
.accept_token = accept_token,
.sample_func = &sample_token,
};
TimingInfo timing_info;
gemma.Generate(runtime, prompt0, 0, kv_cache, timing_info, nullptr);
const float scale = 1.0 / std::log(2.0);
return cross_entropy * scale;
}
} // namespace gcpp

31
gemma/cross_entropy.h Normal file
View File

@ -0,0 +1,31 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_CROSS_ENTROPY_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_CROSS_ENTROPY_H_
#include <vector>
#include "gemma/gemma.h"
namespace gcpp {
float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
const std::vector<int>& prompt, KVCache& kv_cache,
int verbosity);
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_CROSS_ENTROPY_H_

View File

@ -40,7 +40,6 @@
#include <array>
#include <cmath>
#include <memory>
#include <regex> // NOLINT
#include <string>
#include <utility>
#include <vector>
@ -217,32 +216,6 @@ bool GemmaTokenizer::Decode(const std::vector<int>& ids,
return impl_->Decode(ids, detokenized);
}
static std::string TokenString(const GemmaTokenizer& tokenizer, int token) {
std::string token_str;
tokenizer.Decode({token}, &token_str);
return "'" + std::regex_replace(token_str, std::regex("\n"), "\\n") + "'";
}
void LogTopK(const GemmaTokenizer& tokenizer, float* HWY_RESTRICT logits,
float* HWY_RESTRICT dist, size_t len, size_t k) {
std::vector<std::pair<float, int>> sorted(len);
for (size_t i = 0; i < len; ++i) {
sorted[i] = std::make_pair(dist[i], static_cast<int>(i));
}
std::sort(sorted.begin(), sorted.end(),
[](const std::pair<float, int>& a, const std::pair<float, int>& b) {
if (a.first != b.first) {
return a.first > b.first;
}
return a.second < b.second;
});
for (size_t i = 0; i < k; ++i) {
printf(" [#%-2d token %6d = %-12s %.2e %f]\n", static_cast<int>(i + 1),
sorted[i].second, TokenString(tokenizer, sorted[i].second).c_str(),
sorted[i].first, logits[sorted[i].second]);
}
}
} // namespace gcpp
#endif // GEMMA_ONCE
@ -837,14 +810,20 @@ void Generate(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
MatVec<kVocabSize, TConfig::kModelDim>(
weights.embedder_input_embedding, 0, final_activation,
activations.even_odd.data(), activations.logits.data(), pool);
LogitsSoftCap(30.0f, activations.logits.data(), kVocabSize);
// Barrier: must have all logits so we can subtract max.
Softmax(activations.logits.data(), kVocabSize);
if (runtime_config.sample_func) {
token = (*runtime_config.sample_func)(activations.logits.data(),
kVocabSize);
} else {
token = SampleTopK<TConfig::kTopK>(
activations.logits.data(), kVocabSize, *runtime_config.gen,
runtime_config.temperature, runtime_config.accept_token);
if (!runtime_config.stream_token(token, activations.logits[token])) {
token = runtime_config.eos_id;
}
}
if (generate_pos == 0) {
timing_info.time_to_first_token = hwy::platform::Now() - gen_start;
}
@ -868,51 +847,6 @@ void Generate(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
}
}
template <class TConfig>
void ComputeCrossEntropy(const ByteStorageT& weights_u8,
ByteStorageT& decode_u8,
const GemmaTokenizer& tokenizer, size_t max_tokens,
const std::vector<int>& prompt, KVCache& kv_cache,
hwy::ThreadPool& pool, int verbosity,
float& cross_entropy) {
const WeightsT<TConfig>& weights = GetWeights<TConfig>(weights_u8);
auto& activations = GetActivations<TConfig, 1>(decode_u8);
static constexpr size_t kModelDim = TConfig::kModelDim;
static constexpr size_t kVocabSize = TConfig::kVocabSize;
std::vector<float> logits(kVocabSize);
Softmax(activations.logits.data(), kVocabSize);
float total_entropy = 0.0f;
for (size_t pos = 0; pos < max_tokens && pos < prompt.size(); ++pos) {
if (verbosity >= 4) {
LogTopK(tokenizer, logits.data(), activations.logits.data(), kVocabSize,
10);
}
const int token = prompt[pos];
const float prob = activations.logits[token];
if (verbosity >= 3) {
printf("pos %4zu token %6d = %-12s %.10e %14.10f bits\n", pos, token,
TokenString(tokenizer, token).c_str(), prob,
-std::log(prob) / std::log(2.0));
}
total_entropy -= std::max(std::log(prob), -64.0f);
if (verbosity >= 2 && pos % 100 == 99) {
printf("Processed %zu tokens, cross-entropy per token: %f\n", pos + 1,
total_entropy / std::log(2.0) / (pos + 1));
}
Transformer(token, pos, weights, activations, kv_cache, pool,
/*layers_output=*/nullptr);
MatVec<kVocabSize, kModelDim>(
weights.embedder_input_embedding, 0, activations.x.data(),
activations.even_odd.data(), activations.logits.data(), pool);
LogitsSoftCap(30.0f, activations.logits.data(), kVocabSize);
memcpy(logits.data(), activations.logits.data(),
kVocabSize * sizeof(logits[0]));
Softmax(activations.logits.data(), kVocabSize);
}
cross_entropy = total_entropy / std::log(2.0);
}
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
@ -970,20 +904,5 @@ void Gemma::Generate(const RuntimeConfig& runtime_config,
pool_.SetWaitMode(hwy::PoolWaitMode::kBlock);
}
float Gemma::ComputeCrossEntropy(size_t max_tokens,
const std::vector<int>& prompt,
KVCache& kv_cache, int verbosity) {
pool_.SetWaitMode(hwy::PoolWaitMode::kSpin);
float cross_entropy = 0.0f;
GEMMA_EXPORT_AND_DISPATCH_MODEL(
model_type_, ComputeCrossEntropy,
(weights_u8_, decode_u8_, tokenizer_, max_tokens, prompt, kv_cache, pool_,
verbosity, cross_entropy));
pool_.SetWaitMode(hwy::PoolWaitMode::kBlock);
return cross_entropy;
}
} // namespace gcpp
#endif // HWY_ONCE

View File

@ -74,6 +74,9 @@ using StreamFunc = std::function<bool(int, float)>;
// AcceptFunc is called with token. It should return False for tokens you don't
// want to generate and True for tokens you want to generate.
using AcceptFunc = std::function<bool(int)>;
// CustomSampleFunc is called with the probability distribution for the next
// token, and its return value is used as the next generated token.
using CustomSampleFunc = std::function<int(const float*, size_t)>;
struct RuntimeConfig {
size_t max_tokens;
@ -83,6 +86,7 @@ struct RuntimeConfig {
std::mt19937* gen;
const StreamFunc& stream_token;
const AcceptFunc& accept_token;
const CustomSampleFunc* sample_func = nullptr;
int eos_id = EOS_ID;
};
@ -109,6 +113,7 @@ class Gemma {
Gemma(GemmaTokenizer&& tokenizer, Model model_type, hwy::ThreadPool& pool);
~Gemma();
Model ModelType() const { return model_type_; }
const GemmaTokenizer& Tokenizer() const { return tokenizer_; }
const ByteStorageT& Weights() const { return weights_u8_; }
const ByteStorageT& Prefill() const { return prefill_u8_; }
@ -121,9 +126,6 @@ class Gemma {
KVCache& kv_cache, TimingInfo& timing_info,
LayersOutputT* layers_output = nullptr);
float ComputeCrossEntropy(size_t max_tokens, const std::vector<int>& prompt,
KVCache& kv_cache, int verbosity);
private:
hwy::ThreadPool& pool_;

View File

@ -24,6 +24,7 @@
#include "compression/io.h" // Path
#include "gemma/common.h"
#include "gemma/cross_entropy.h"
#include "gemma/ops.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/tests/test_util-inl.h"
@ -77,7 +78,7 @@ class GemmaTest : public ::testing::Test {
float GemmaCrossEntropy(const std::string& prompt_string) {
std::vector<int> prompt;
HWY_ASSERT(model.Tokenizer().Encode(prompt_string, &prompt));
return model.ComputeCrossEntropy(/*max_tokens=*/3072, prompt, kv_cache,
return ComputeCrossEntropy(model, /*max_tokens=*/3072, prompt, kv_cache,
/*verbosity=*/0) /
prompt_string.size();
}