Add support for custom sampling function to runtime config.

With this addition the ComputeCrossEntropy function can be moved
to its own library, because now we can compute it using only the
public API functions from gemma.h
This commit is contained in:
Zoltan Szabadka 2024-06-07 11:45:07 +00:00
parent f7ac7092d6
commit 465998d25a
9 changed files with 183 additions and 101 deletions

View File

@ -107,6 +107,19 @@ cc_library(
],
)
cc_library(
name = "cross_entropy",
srcs = [
"gemma/cross_entropy.cc",
],
hdrs = [
"gemma/cross_entropy.h",
],
deps = [
":gemma_lib",
],
)
cc_library(
name = "args",
hdrs = ["util/args.h"],
@ -141,6 +154,7 @@ cc_test(
],
deps = [
":args",
":cross_entropy",
":gemma_lib",
":ops",
"@googletest//:gtest_main",
@ -190,6 +204,7 @@ cc_binary(
":app",
":args",
":common",
":cross_entropy",
":gemma_lib",
"//compression:io",
"@hwy//:hwy",

View File

@ -61,6 +61,8 @@ set(SOURCES
gemma/activations.h
gemma/common.cc
gemma/common.h
gemma/cross_entropy.cc
gemma/cross_entropy.h
gemma/gemma.cc
gemma/gemma.h
gemma/ops.h

View File

@ -57,7 +57,7 @@ TEST(OptimizeTest, GradientDescent) {
};
RuntimeConfig runtime = {
max_tokens, max_generated_tokens, temperature, verbosity, &gen,
stream_token, accept_token, ReverseSequenceSampler::kEndToken,
stream_token, accept_token, nullptr, ReverseSequenceSampler::kEndToken,
};
TimingInfo timing_info;
gemma.Generate(runtime, prompt, 0, kv_cache, timing_info);

View File

@ -10,6 +10,7 @@
#include <vector>
#include "compression/io.h" // Path
#include "gemma/cross_entropy.h"
#include "gemma/gemma.h"
#include "util/app.h"
#include "util/args.h"
@ -204,13 +205,15 @@ int BenchmarkCrossEntropy(gcpp::Gemma& model, gcpp::Model model_type,
std::vector<int> prompt_slice(prompt.begin() + pos,
prompt.begin() + pos + num_tokens);
gcpp::KVCache kv_cache = gcpp::KVCache::Create(model_type);
float entropy = model.ComputeCrossEntropy(num_tokens, prompt_slice,
float entropy = ComputeCrossEntropy(model, num_tokens, prompt_slice,
kv_cache, app.verbosity);
total_entropy += entropy;
LogSpeedStats(time_start, pos + num_tokens);
std::string text_slice;
HWY_ASSERT(model.Tokenizer().Decode(prompt_slice, &text_slice));
total_input_len += text_slice.size();
printf("Total cross entropy: %f [cumulative: %f]\n",
entropy, total_entropy);
printf("Cross entropy per byte: %f [cumulative: %f]\n",
entropy / text_slice.size(), total_entropy / total_input_len);
}

109
gemma/cross_entropy.cc Normal file
View File

@ -0,0 +1,109 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "gemma/cross_entropy.h"
#include <algorithm>
#include <cmath>
#include <regex> // NOLINT
#include <string>
#include <utility>
#include <vector>
namespace gcpp {
namespace {
template <typename TConfig>
struct GetVocabSize {
int operator()() const { return TConfig::kVocabSize; }
};
static std::string TokenString(const GemmaTokenizer& tokenizer, int token) {
std::string token_str;
tokenizer.Decode({token}, &token_str);
return "'" + std::regex_replace(token_str, std::regex("\n"), "\\n") + "'";
}
void LogTopK(const GemmaTokenizer& tokenizer, const float* dist, size_t len,
size_t k) {
std::vector<std::pair<float, int>> sorted(len);
for (size_t i = 0; i < len; ++i) {
sorted[i] = std::make_pair(dist[i], static_cast<int>(i));
}
std::sort(sorted.begin(), sorted.end(),
[](const std::pair<float, int>& a, const std::pair<float, int>& b) {
if (a.first != b.first) {
return a.first > b.first;
}
return a.second < b.second;
});
for (size_t i = 0; i < k; ++i) {
printf(" [#%-2d token %6d = %-12s %.2e]\n", static_cast<int>(i + 1),
sorted[i].second, TokenString(tokenizer, sorted[i].second).c_str(),
sorted[i].first);
}
}
} // namespace
float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
const std::vector<int>& prompt,
KVCache& kv_cache, int verbosity) {
auto stream_token = [](int, float) { return true; };
auto accept_token = [](int) { return true; };
const int vocab_size = CallFunctorForModel<GetVocabSize>(gemma.ModelType());
float cross_entropy = std::log(vocab_size); // first token
size_t pos = 1;
std::function<int(const float*, size_t)> sample_token =
[&](const float* probs, size_t vocab_size) -> int {
const int token = prompt[pos];
const float prob = probs[token];
cross_entropy -= std::max(std::log(prob), -64.0f);
if (verbosity >= 4) {
LogTopK(gemma.Tokenizer(), probs, vocab_size, 10);
}
if (verbosity >= 3) {
printf("pos %4zu token %6d = %-12s %.10e %14.10f bits\n", pos,
token, TokenString(gemma.Tokenizer(), token).c_str(), prob,
-std::log(prob) / std::log(2.0));
}
if (verbosity >= 2 && pos % 100 == 99) {
printf("Processed %zu tokens, cross-entropy per token: %f\n", pos + 1,
cross_entropy / std::log(2.0) / (pos + 1));
}
++pos;
return token;
};
std::vector<int> prompt0 = { prompt[0] };
RuntimeConfig runtime = {
.max_tokens = max_tokens,
.max_generated_tokens = max_tokens - 1,
.temperature = 0.0f,
.verbosity = verbosity,
.gen = nullptr,
.stream_token = stream_token,
.accept_token = accept_token,
.sample_func = &sample_token,
};
TimingInfo timing_info;
gemma.Generate(runtime, prompt0, 0, kv_cache, timing_info, nullptr);
const float scale = 1.0 / std::log(2.0);
return cross_entropy * scale;
}
} // namespace gcpp

31
gemma/cross_entropy.h Normal file
View File

@ -0,0 +1,31 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_CROSS_ENTROPY_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_CROSS_ENTROPY_H_
#include <vector>
#include "gemma/gemma.h"
namespace gcpp {
float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
const std::vector<int>& prompt, KVCache& kv_cache,
int verbosity);
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_CROSS_ENTROPY_H_

View File

@ -40,7 +40,6 @@
#include <array>
#include <cmath>
#include <memory>
#include <regex> // NOLINT
#include <string>
#include <utility>
#include <vector>
@ -217,32 +216,6 @@ bool GemmaTokenizer::Decode(const std::vector<int>& ids,
return impl_->Decode(ids, detokenized);
}
static std::string TokenString(const GemmaTokenizer& tokenizer, int token) {
std::string token_str;
tokenizer.Decode({token}, &token_str);
return "'" + std::regex_replace(token_str, std::regex("\n"), "\\n") + "'";
}
void LogTopK(const GemmaTokenizer& tokenizer, float* HWY_RESTRICT logits,
float* HWY_RESTRICT dist, size_t len, size_t k) {
std::vector<std::pair<float, int>> sorted(len);
for (size_t i = 0; i < len; ++i) {
sorted[i] = std::make_pair(dist[i], static_cast<int>(i));
}
std::sort(sorted.begin(), sorted.end(),
[](const std::pair<float, int>& a, const std::pair<float, int>& b) {
if (a.first != b.first) {
return a.first > b.first;
}
return a.second < b.second;
});
for (size_t i = 0; i < k; ++i) {
printf(" [#%-2d token %6d = %-12s %.2e %f]\n", static_cast<int>(i + 1),
sorted[i].second, TokenString(tokenizer, sorted[i].second).c_str(),
sorted[i].first, logits[sorted[i].second]);
}
}
} // namespace gcpp
#endif // GEMMA_ONCE
@ -837,14 +810,20 @@ void Generate(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
MatVec<kVocabSize, TConfig::kModelDim>(
weights.embedder_input_embedding, 0, final_activation,
activations.even_odd.data(), activations.logits.data(), pool);
LogitsSoftCap(30.0f, activations.logits.data(), kVocabSize);
// Barrier: must have all logits so we can subtract max.
Softmax(activations.logits.data(), kVocabSize);
if (runtime_config.sample_func) {
token = (*runtime_config.sample_func)(activations.logits.data(),
kVocabSize);
} else {
token = SampleTopK<TConfig::kTopK>(
activations.logits.data(), kVocabSize, *runtime_config.gen,
runtime_config.temperature, runtime_config.accept_token);
if (!runtime_config.stream_token(token, activations.logits[token])) {
token = runtime_config.eos_id;
}
}
if (generate_pos == 0) {
timing_info.time_to_first_token = hwy::platform::Now() - gen_start;
}
@ -868,51 +847,6 @@ void Generate(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
}
}
template <class TConfig>
void ComputeCrossEntropy(const ByteStorageT& weights_u8,
ByteStorageT& decode_u8,
const GemmaTokenizer& tokenizer, size_t max_tokens,
const std::vector<int>& prompt, KVCache& kv_cache,
hwy::ThreadPool& pool, int verbosity,
float& cross_entropy) {
const WeightsT<TConfig>& weights = GetWeights<TConfig>(weights_u8);
auto& activations = GetActivations<TConfig, 1>(decode_u8);
static constexpr size_t kModelDim = TConfig::kModelDim;
static constexpr size_t kVocabSize = TConfig::kVocabSize;
std::vector<float> logits(kVocabSize);
Softmax(activations.logits.data(), kVocabSize);
float total_entropy = 0.0f;
for (size_t pos = 0; pos < max_tokens && pos < prompt.size(); ++pos) {
if (verbosity >= 4) {
LogTopK(tokenizer, logits.data(), activations.logits.data(), kVocabSize,
10);
}
const int token = prompt[pos];
const float prob = activations.logits[token];
if (verbosity >= 3) {
printf("pos %4zu token %6d = %-12s %.10e %14.10f bits\n", pos, token,
TokenString(tokenizer, token).c_str(), prob,
-std::log(prob) / std::log(2.0));
}
total_entropy -= std::max(std::log(prob), -64.0f);
if (verbosity >= 2 && pos % 100 == 99) {
printf("Processed %zu tokens, cross-entropy per token: %f\n", pos + 1,
total_entropy / std::log(2.0) / (pos + 1));
}
Transformer(token, pos, weights, activations, kv_cache, pool,
/*layers_output=*/nullptr);
MatVec<kVocabSize, kModelDim>(
weights.embedder_input_embedding, 0, activations.x.data(),
activations.even_odd.data(), activations.logits.data(), pool);
LogitsSoftCap(30.0f, activations.logits.data(), kVocabSize);
memcpy(logits.data(), activations.logits.data(),
kVocabSize * sizeof(logits[0]));
Softmax(activations.logits.data(), kVocabSize);
}
cross_entropy = total_entropy / std::log(2.0);
}
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
@ -970,20 +904,5 @@ void Gemma::Generate(const RuntimeConfig& runtime_config,
pool_.SetWaitMode(hwy::PoolWaitMode::kBlock);
}
float Gemma::ComputeCrossEntropy(size_t max_tokens,
const std::vector<int>& prompt,
KVCache& kv_cache, int verbosity) {
pool_.SetWaitMode(hwy::PoolWaitMode::kSpin);
float cross_entropy = 0.0f;
GEMMA_EXPORT_AND_DISPATCH_MODEL(
model_type_, ComputeCrossEntropy,
(weights_u8_, decode_u8_, tokenizer_, max_tokens, prompt, kv_cache, pool_,
verbosity, cross_entropy));
pool_.SetWaitMode(hwy::PoolWaitMode::kBlock);
return cross_entropy;
}
} // namespace gcpp
#endif // HWY_ONCE

View File

@ -72,6 +72,9 @@ using StreamFunc = std::function<bool(int, float)>;
// AcceptFunc is called with token. It should return False for tokens you don't
// want to generate and True for tokens you want to generate.
using AcceptFunc = std::function<bool(int)>;
// CustomSampleFunc is called with the probability distribution for the next
// token, and its return value is used as the next generated token.
using CustomSampleFunc = std::function<int(const float*, size_t)>;
struct RuntimeConfig {
size_t max_tokens;
@ -81,6 +84,7 @@ struct RuntimeConfig {
std::mt19937* gen;
const StreamFunc& stream_token;
const AcceptFunc& accept_token;
const CustomSampleFunc* sample_func = nullptr;
int eos_id = EOS_ID;
};
@ -107,6 +111,7 @@ class Gemma {
Gemma(GemmaTokenizer&& tokenizer, Model model_type, hwy::ThreadPool& pool);
~Gemma();
Model ModelType() const { return model_type_; }
const GemmaTokenizer& Tokenizer() const { return tokenizer_; }
const ByteStorageT& Weights() const { return weights_u8_; }
const ByteStorageT& Prefill() const { return prefill_u8_; }
@ -119,9 +124,6 @@ class Gemma {
KVCache& kv_cache, TimingInfo& timing_info,
LayersOutputT* layers_output = nullptr);
float ComputeCrossEntropy(size_t max_tokens, const std::vector<int>& prompt,
KVCache& kv_cache, int verbosity);
private:
hwy::ThreadPool& pool_;

View File

@ -24,6 +24,7 @@
#include "compression/io.h" // Path
#include "gemma/common.h"
#include "gemma/cross_entropy.h"
#include "gemma/ops.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/tests/test_util-inl.h"
@ -77,7 +78,7 @@ class GemmaTest : public ::testing::Test {
float GemmaCrossEntropy(const std::string& prompt_string) {
std::vector<int> prompt;
HWY_ASSERT(model.Tokenizer().Encode(prompt_string, &prompt));
return model.ComputeCrossEntropy(/*max_tokens=*/3072, prompt, kv_cache,
return ComputeCrossEntropy(model, /*max_tokens=*/3072, prompt, kv_cache,
/*verbosity=*/0) /
prompt_string.size();
}