mirror of https://github.com/google/gemma.cpp.git
Merge pull request #217 from szabadka:cross-entropy
PiperOrigin-RevId: 641241133
This commit is contained in:
commit
24db2ff725
15
BUILD.bazel
15
BUILD.bazel
|
|
@ -107,6 +107,19 @@ cc_library(
|
|||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "cross_entropy",
|
||||
srcs = [
|
||||
"gemma/cross_entropy.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"gemma/cross_entropy.h",
|
||||
],
|
||||
deps = [
|
||||
":gemma_lib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "args",
|
||||
hdrs = ["util/args.h"],
|
||||
|
|
@ -141,6 +154,7 @@ cc_test(
|
|||
],
|
||||
deps = [
|
||||
":args",
|
||||
":cross_entropy",
|
||||
":gemma_lib",
|
||||
":ops",
|
||||
"@googletest//:gtest_main",
|
||||
|
|
@ -190,6 +204,7 @@ cc_binary(
|
|||
":app",
|
||||
":args",
|
||||
":common",
|
||||
":cross_entropy",
|
||||
":gemma_lib",
|
||||
"//compression:io",
|
||||
"@hwy//:hwy",
|
||||
|
|
|
|||
|
|
@ -61,6 +61,8 @@ set(SOURCES
|
|||
gemma/activations.h
|
||||
gemma/common.cc
|
||||
gemma/common.h
|
||||
gemma/cross_entropy.cc
|
||||
gemma/cross_entropy.h
|
||||
gemma/gemma.cc
|
||||
gemma/gemma.h
|
||||
gemma/ops.h
|
||||
|
|
|
|||
|
|
@ -57,7 +57,7 @@ TEST(OptimizeTest, GradientDescent) {
|
|||
};
|
||||
RuntimeConfig runtime = {
|
||||
max_tokens, max_generated_tokens, temperature, verbosity, &gen,
|
||||
stream_token, accept_token, ReverseSequenceSampler::kEndToken,
|
||||
stream_token, accept_token, nullptr, ReverseSequenceSampler::kEndToken,
|
||||
};
|
||||
TimingInfo timing_info;
|
||||
gemma.Generate(runtime, prompt, 0, kv_cache, timing_info);
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@
|
|||
#include <vector>
|
||||
|
||||
#include "compression/io.h" // Path
|
||||
#include "gemma/cross_entropy.h"
|
||||
#include "gemma/gemma.h"
|
||||
#include "util/app.h"
|
||||
#include "util/args.h"
|
||||
|
|
@ -204,13 +205,15 @@ int BenchmarkCrossEntropy(gcpp::Gemma& model, gcpp::Model model_type,
|
|||
std::vector<int> prompt_slice(prompt.begin() + pos,
|
||||
prompt.begin() + pos + num_tokens);
|
||||
gcpp::KVCache kv_cache = gcpp::KVCache::Create(model_type);
|
||||
float entropy = model.ComputeCrossEntropy(num_tokens, prompt_slice,
|
||||
float entropy = ComputeCrossEntropy(model, num_tokens, prompt_slice,
|
||||
kv_cache, app.verbosity);
|
||||
total_entropy += entropy;
|
||||
LogSpeedStats(time_start, pos + num_tokens);
|
||||
std::string text_slice;
|
||||
HWY_ASSERT(model.Tokenizer().Decode(prompt_slice, &text_slice));
|
||||
total_input_len += text_slice.size();
|
||||
printf("Total cross entropy: %f [cumulative: %f]\n",
|
||||
entropy, total_entropy);
|
||||
printf("Cross entropy per byte: %f [cumulative: %f]\n",
|
||||
entropy / text_slice.size(), total_entropy / total_input_len);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,109 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "gemma/cross_entropy.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <regex> // NOLINT
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
namespace {
|
||||
template <typename TConfig>
|
||||
struct GetVocabSize {
|
||||
int operator()() const { return TConfig::kVocabSize; }
|
||||
};
|
||||
|
||||
static std::string TokenString(const GemmaTokenizer& tokenizer, int token) {
|
||||
std::string token_str;
|
||||
tokenizer.Decode({token}, &token_str);
|
||||
return "'" + std::regex_replace(token_str, std::regex("\n"), "\\n") + "'";
|
||||
}
|
||||
|
||||
void LogTopK(const GemmaTokenizer& tokenizer, const float* dist, size_t len,
|
||||
size_t k) {
|
||||
std::vector<std::pair<float, int>> sorted(len);
|
||||
for (size_t i = 0; i < len; ++i) {
|
||||
sorted[i] = std::make_pair(dist[i], static_cast<int>(i));
|
||||
}
|
||||
std::sort(sorted.begin(), sorted.end(),
|
||||
[](const std::pair<float, int>& a, const std::pair<float, int>& b) {
|
||||
if (a.first != b.first) {
|
||||
return a.first > b.first;
|
||||
}
|
||||
return a.second < b.second;
|
||||
});
|
||||
for (size_t i = 0; i < k; ++i) {
|
||||
printf(" [#%-2d token %6d = %-12s %.2e]\n", static_cast<int>(i + 1),
|
||||
sorted[i].second, TokenString(tokenizer, sorted[i].second).c_str(),
|
||||
sorted[i].first);
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
|
||||
const std::vector<int>& prompt,
|
||||
KVCache& kv_cache, int verbosity) {
|
||||
auto stream_token = [](int, float) { return true; };
|
||||
auto accept_token = [](int) { return true; };
|
||||
|
||||
const int vocab_size = CallFunctorForModel<GetVocabSize>(gemma.ModelType());
|
||||
float cross_entropy = std::log(vocab_size); // first token
|
||||
size_t pos = 1;
|
||||
std::function<int(const float*, size_t)> sample_token =
|
||||
[&](const float* probs, size_t vocab_size) -> int {
|
||||
const int token = prompt[pos];
|
||||
const float prob = probs[token];
|
||||
cross_entropy -= std::max(std::log(prob), -64.0f);
|
||||
|
||||
if (verbosity >= 4) {
|
||||
LogTopK(gemma.Tokenizer(), probs, vocab_size, 10);
|
||||
}
|
||||
if (verbosity >= 3) {
|
||||
printf("pos %4zu token %6d = %-12s %.10e %14.10f bits\n", pos,
|
||||
token, TokenString(gemma.Tokenizer(), token).c_str(), prob,
|
||||
-std::log(prob) / std::log(2.0));
|
||||
}
|
||||
if (verbosity >= 2 && pos % 100 == 99) {
|
||||
printf("Processed %zu tokens, cross-entropy per token: %f\n", pos + 1,
|
||||
cross_entropy / std::log(2.0) / (pos + 1));
|
||||
}
|
||||
++pos;
|
||||
return token;
|
||||
};
|
||||
std::vector<int> prompt0 = { prompt[0] };
|
||||
RuntimeConfig runtime = {
|
||||
.max_tokens = max_tokens,
|
||||
.max_generated_tokens = max_tokens - 1,
|
||||
.temperature = 0.0f,
|
||||
.verbosity = verbosity,
|
||||
.gen = nullptr,
|
||||
.stream_token = stream_token,
|
||||
.accept_token = accept_token,
|
||||
.sample_func = &sample_token,
|
||||
};
|
||||
TimingInfo timing_info;
|
||||
|
||||
gemma.Generate(runtime, prompt0, 0, kv_cache, timing_info, nullptr);
|
||||
|
||||
const float scale = 1.0 / std::log(2.0);
|
||||
return cross_entropy * scale;
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
@ -0,0 +1,31 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_CROSS_ENTROPY_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_GEMMA_CROSS_ENTROPY_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "gemma/gemma.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
|
||||
const std::vector<int>& prompt, KVCache& kv_cache,
|
||||
int verbosity);
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_CROSS_ENTROPY_H_
|
||||
|
|
@ -40,7 +40,6 @@
|
|||
#include <array>
|
||||
#include <cmath>
|
||||
#include <memory>
|
||||
#include <regex> // NOLINT
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
|
@ -217,32 +216,6 @@ bool GemmaTokenizer::Decode(const std::vector<int>& ids,
|
|||
return impl_->Decode(ids, detokenized);
|
||||
}
|
||||
|
||||
static std::string TokenString(const GemmaTokenizer& tokenizer, int token) {
|
||||
std::string token_str;
|
||||
tokenizer.Decode({token}, &token_str);
|
||||
return "'" + std::regex_replace(token_str, std::regex("\n"), "\\n") + "'";
|
||||
}
|
||||
|
||||
void LogTopK(const GemmaTokenizer& tokenizer, float* HWY_RESTRICT logits,
|
||||
float* HWY_RESTRICT dist, size_t len, size_t k) {
|
||||
std::vector<std::pair<float, int>> sorted(len);
|
||||
for (size_t i = 0; i < len; ++i) {
|
||||
sorted[i] = std::make_pair(dist[i], static_cast<int>(i));
|
||||
}
|
||||
std::sort(sorted.begin(), sorted.end(),
|
||||
[](const std::pair<float, int>& a, const std::pair<float, int>& b) {
|
||||
if (a.first != b.first) {
|
||||
return a.first > b.first;
|
||||
}
|
||||
return a.second < b.second;
|
||||
});
|
||||
for (size_t i = 0; i < k; ++i) {
|
||||
printf(" [#%-2d token %6d = %-12s %.2e %f]\n", static_cast<int>(i + 1),
|
||||
sorted[i].second, TokenString(tokenizer, sorted[i].second).c_str(),
|
||||
sorted[i].first, logits[sorted[i].second]);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
#endif // GEMMA_ONCE
|
||||
|
||||
|
|
@ -837,14 +810,20 @@ void Generate(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
|
|||
MatVec<kVocabSize, TConfig::kModelDim>(
|
||||
weights.embedder_input_embedding, 0, final_activation,
|
||||
activations.even_odd.data(), activations.logits.data(), pool);
|
||||
LogitsSoftCap(30.0f, activations.logits.data(), kVocabSize);
|
||||
// Barrier: must have all logits so we can subtract max.
|
||||
Softmax(activations.logits.data(), kVocabSize);
|
||||
if (runtime_config.sample_func) {
|
||||
token = (*runtime_config.sample_func)(activations.logits.data(),
|
||||
kVocabSize);
|
||||
} else {
|
||||
token = SampleTopK<TConfig::kTopK>(
|
||||
activations.logits.data(), kVocabSize, *runtime_config.gen,
|
||||
runtime_config.temperature, runtime_config.accept_token);
|
||||
if (!runtime_config.stream_token(token, activations.logits[token])) {
|
||||
token = runtime_config.eos_id;
|
||||
}
|
||||
}
|
||||
if (generate_pos == 0) {
|
||||
timing_info.time_to_first_token = hwy::platform::Now() - gen_start;
|
||||
}
|
||||
|
|
@ -868,51 +847,6 @@ void Generate(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
|
|||
}
|
||||
}
|
||||
|
||||
template <class TConfig>
|
||||
void ComputeCrossEntropy(const ByteStorageT& weights_u8,
|
||||
ByteStorageT& decode_u8,
|
||||
const GemmaTokenizer& tokenizer, size_t max_tokens,
|
||||
const std::vector<int>& prompt, KVCache& kv_cache,
|
||||
hwy::ThreadPool& pool, int verbosity,
|
||||
float& cross_entropy) {
|
||||
const WeightsT<TConfig>& weights = GetWeights<TConfig>(weights_u8);
|
||||
auto& activations = GetActivations<TConfig, 1>(decode_u8);
|
||||
|
||||
static constexpr size_t kModelDim = TConfig::kModelDim;
|
||||
static constexpr size_t kVocabSize = TConfig::kVocabSize;
|
||||
std::vector<float> logits(kVocabSize);
|
||||
Softmax(activations.logits.data(), kVocabSize);
|
||||
float total_entropy = 0.0f;
|
||||
for (size_t pos = 0; pos < max_tokens && pos < prompt.size(); ++pos) {
|
||||
if (verbosity >= 4) {
|
||||
LogTopK(tokenizer, logits.data(), activations.logits.data(), kVocabSize,
|
||||
10);
|
||||
}
|
||||
const int token = prompt[pos];
|
||||
const float prob = activations.logits[token];
|
||||
if (verbosity >= 3) {
|
||||
printf("pos %4zu token %6d = %-12s %.10e %14.10f bits\n", pos, token,
|
||||
TokenString(tokenizer, token).c_str(), prob,
|
||||
-std::log(prob) / std::log(2.0));
|
||||
}
|
||||
total_entropy -= std::max(std::log(prob), -64.0f);
|
||||
if (verbosity >= 2 && pos % 100 == 99) {
|
||||
printf("Processed %zu tokens, cross-entropy per token: %f\n", pos + 1,
|
||||
total_entropy / std::log(2.0) / (pos + 1));
|
||||
}
|
||||
Transformer(token, pos, weights, activations, kv_cache, pool,
|
||||
/*layers_output=*/nullptr);
|
||||
MatVec<kVocabSize, kModelDim>(
|
||||
weights.embedder_input_embedding, 0, activations.x.data(),
|
||||
activations.even_odd.data(), activations.logits.data(), pool);
|
||||
LogitsSoftCap(30.0f, activations.logits.data(), kVocabSize);
|
||||
memcpy(logits.data(), activations.logits.data(),
|
||||
kVocabSize * sizeof(logits[0]));
|
||||
Softmax(activations.logits.data(), kVocabSize);
|
||||
}
|
||||
cross_entropy = total_entropy / std::log(2.0);
|
||||
}
|
||||
|
||||
} // namespace HWY_NAMESPACE
|
||||
} // namespace gcpp
|
||||
HWY_AFTER_NAMESPACE();
|
||||
|
|
@ -970,20 +904,5 @@ void Gemma::Generate(const RuntimeConfig& runtime_config,
|
|||
pool_.SetWaitMode(hwy::PoolWaitMode::kBlock);
|
||||
}
|
||||
|
||||
float Gemma::ComputeCrossEntropy(size_t max_tokens,
|
||||
const std::vector<int>& prompt,
|
||||
KVCache& kv_cache, int verbosity) {
|
||||
pool_.SetWaitMode(hwy::PoolWaitMode::kSpin);
|
||||
|
||||
float cross_entropy = 0.0f;
|
||||
GEMMA_EXPORT_AND_DISPATCH_MODEL(
|
||||
model_type_, ComputeCrossEntropy,
|
||||
(weights_u8_, decode_u8_, tokenizer_, max_tokens, prompt, kv_cache, pool_,
|
||||
verbosity, cross_entropy));
|
||||
|
||||
pool_.SetWaitMode(hwy::PoolWaitMode::kBlock);
|
||||
return cross_entropy;
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
#endif // HWY_ONCE
|
||||
|
|
|
|||
|
|
@ -74,6 +74,9 @@ using StreamFunc = std::function<bool(int, float)>;
|
|||
// AcceptFunc is called with token. It should return False for tokens you don't
|
||||
// want to generate and True for tokens you want to generate.
|
||||
using AcceptFunc = std::function<bool(int)>;
|
||||
// CustomSampleFunc is called with the probability distribution for the next
|
||||
// token, and its return value is used as the next generated token.
|
||||
using CustomSampleFunc = std::function<int(const float*, size_t)>;
|
||||
|
||||
struct RuntimeConfig {
|
||||
size_t max_tokens;
|
||||
|
|
@ -83,6 +86,7 @@ struct RuntimeConfig {
|
|||
std::mt19937* gen;
|
||||
const StreamFunc& stream_token;
|
||||
const AcceptFunc& accept_token;
|
||||
const CustomSampleFunc* sample_func = nullptr;
|
||||
int eos_id = EOS_ID;
|
||||
};
|
||||
|
||||
|
|
@ -109,6 +113,7 @@ class Gemma {
|
|||
Gemma(GemmaTokenizer&& tokenizer, Model model_type, hwy::ThreadPool& pool);
|
||||
~Gemma();
|
||||
|
||||
Model ModelType() const { return model_type_; }
|
||||
const GemmaTokenizer& Tokenizer() const { return tokenizer_; }
|
||||
const ByteStorageT& Weights() const { return weights_u8_; }
|
||||
const ByteStorageT& Prefill() const { return prefill_u8_; }
|
||||
|
|
@ -121,9 +126,6 @@ class Gemma {
|
|||
KVCache& kv_cache, TimingInfo& timing_info,
|
||||
LayersOutputT* layers_output = nullptr);
|
||||
|
||||
float ComputeCrossEntropy(size_t max_tokens, const std::vector<int>& prompt,
|
||||
KVCache& kv_cache, int verbosity);
|
||||
|
||||
private:
|
||||
hwy::ThreadPool& pool_;
|
||||
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@
|
|||
|
||||
#include "compression/io.h" // Path
|
||||
#include "gemma/common.h"
|
||||
#include "gemma/cross_entropy.h"
|
||||
#include "gemma/ops.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
#include "hwy/tests/test_util-inl.h"
|
||||
|
|
@ -77,7 +78,7 @@ class GemmaTest : public ::testing::Test {
|
|||
float GemmaCrossEntropy(const std::string& prompt_string) {
|
||||
std::vector<int> prompt;
|
||||
HWY_ASSERT(model.Tokenizer().Encode(prompt_string, &prompt));
|
||||
return model.ComputeCrossEntropy(/*max_tokens=*/3072, prompt, kv_cache,
|
||||
return ComputeCrossEntropy(model, /*max_tokens=*/3072, prompt, kv_cache,
|
||||
/*verbosity=*/0) /
|
||||
prompt_string.size();
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue