mirror of https://github.com/google/gemma.cpp.git
Merge pull request #217 from szabadka:cross-entropy
PiperOrigin-RevId: 641241133
This commit is contained in:
commit
24db2ff725
15
BUILD.bazel
15
BUILD.bazel
|
|
@ -107,6 +107,19 @@ cc_library(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "cross_entropy",
|
||||||
|
srcs = [
|
||||||
|
"gemma/cross_entropy.cc",
|
||||||
|
],
|
||||||
|
hdrs = [
|
||||||
|
"gemma/cross_entropy.h",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":gemma_lib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "args",
|
name = "args",
|
||||||
hdrs = ["util/args.h"],
|
hdrs = ["util/args.h"],
|
||||||
|
|
@ -141,6 +154,7 @@ cc_test(
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":args",
|
":args",
|
||||||
|
":cross_entropy",
|
||||||
":gemma_lib",
|
":gemma_lib",
|
||||||
":ops",
|
":ops",
|
||||||
"@googletest//:gtest_main",
|
"@googletest//:gtest_main",
|
||||||
|
|
@ -190,6 +204,7 @@ cc_binary(
|
||||||
":app",
|
":app",
|
||||||
":args",
|
":args",
|
||||||
":common",
|
":common",
|
||||||
|
":cross_entropy",
|
||||||
":gemma_lib",
|
":gemma_lib",
|
||||||
"//compression:io",
|
"//compression:io",
|
||||||
"@hwy//:hwy",
|
"@hwy//:hwy",
|
||||||
|
|
|
||||||
|
|
@ -61,6 +61,8 @@ set(SOURCES
|
||||||
gemma/activations.h
|
gemma/activations.h
|
||||||
gemma/common.cc
|
gemma/common.cc
|
||||||
gemma/common.h
|
gemma/common.h
|
||||||
|
gemma/cross_entropy.cc
|
||||||
|
gemma/cross_entropy.h
|
||||||
gemma/gemma.cc
|
gemma/gemma.cc
|
||||||
gemma/gemma.h
|
gemma/gemma.h
|
||||||
gemma/ops.h
|
gemma/ops.h
|
||||||
|
|
|
||||||
|
|
@ -57,7 +57,7 @@ TEST(OptimizeTest, GradientDescent) {
|
||||||
};
|
};
|
||||||
RuntimeConfig runtime = {
|
RuntimeConfig runtime = {
|
||||||
max_tokens, max_generated_tokens, temperature, verbosity, &gen,
|
max_tokens, max_generated_tokens, temperature, verbosity, &gen,
|
||||||
stream_token, accept_token, ReverseSequenceSampler::kEndToken,
|
stream_token, accept_token, nullptr, ReverseSequenceSampler::kEndToken,
|
||||||
};
|
};
|
||||||
TimingInfo timing_info;
|
TimingInfo timing_info;
|
||||||
gemma.Generate(runtime, prompt, 0, kv_cache, timing_info);
|
gemma.Generate(runtime, prompt, 0, kv_cache, timing_info);
|
||||||
|
|
|
||||||
|
|
@ -10,6 +10,7 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "compression/io.h" // Path
|
#include "compression/io.h" // Path
|
||||||
|
#include "gemma/cross_entropy.h"
|
||||||
#include "gemma/gemma.h"
|
#include "gemma/gemma.h"
|
||||||
#include "util/app.h"
|
#include "util/app.h"
|
||||||
#include "util/args.h"
|
#include "util/args.h"
|
||||||
|
|
@ -204,13 +205,15 @@ int BenchmarkCrossEntropy(gcpp::Gemma& model, gcpp::Model model_type,
|
||||||
std::vector<int> prompt_slice(prompt.begin() + pos,
|
std::vector<int> prompt_slice(prompt.begin() + pos,
|
||||||
prompt.begin() + pos + num_tokens);
|
prompt.begin() + pos + num_tokens);
|
||||||
gcpp::KVCache kv_cache = gcpp::KVCache::Create(model_type);
|
gcpp::KVCache kv_cache = gcpp::KVCache::Create(model_type);
|
||||||
float entropy = model.ComputeCrossEntropy(num_tokens, prompt_slice,
|
float entropy = ComputeCrossEntropy(model, num_tokens, prompt_slice,
|
||||||
kv_cache, app.verbosity);
|
kv_cache, app.verbosity);
|
||||||
total_entropy += entropy;
|
total_entropy += entropy;
|
||||||
LogSpeedStats(time_start, pos + num_tokens);
|
LogSpeedStats(time_start, pos + num_tokens);
|
||||||
std::string text_slice;
|
std::string text_slice;
|
||||||
HWY_ASSERT(model.Tokenizer().Decode(prompt_slice, &text_slice));
|
HWY_ASSERT(model.Tokenizer().Decode(prompt_slice, &text_slice));
|
||||||
total_input_len += text_slice.size();
|
total_input_len += text_slice.size();
|
||||||
|
printf("Total cross entropy: %f [cumulative: %f]\n",
|
||||||
|
entropy, total_entropy);
|
||||||
printf("Cross entropy per byte: %f [cumulative: %f]\n",
|
printf("Cross entropy per byte: %f [cumulative: %f]\n",
|
||||||
entropy / text_slice.size(), total_entropy / total_input_len);
|
entropy / text_slice.size(), total_entropy / total_input_len);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,109 @@
|
||||||
|
// Copyright 2024 Google LLC
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "gemma/cross_entropy.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <cmath>
|
||||||
|
#include <regex> // NOLINT
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
namespace gcpp {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
template <typename TConfig>
|
||||||
|
struct GetVocabSize {
|
||||||
|
int operator()() const { return TConfig::kVocabSize; }
|
||||||
|
};
|
||||||
|
|
||||||
|
static std::string TokenString(const GemmaTokenizer& tokenizer, int token) {
|
||||||
|
std::string token_str;
|
||||||
|
tokenizer.Decode({token}, &token_str);
|
||||||
|
return "'" + std::regex_replace(token_str, std::regex("\n"), "\\n") + "'";
|
||||||
|
}
|
||||||
|
|
||||||
|
void LogTopK(const GemmaTokenizer& tokenizer, const float* dist, size_t len,
|
||||||
|
size_t k) {
|
||||||
|
std::vector<std::pair<float, int>> sorted(len);
|
||||||
|
for (size_t i = 0; i < len; ++i) {
|
||||||
|
sorted[i] = std::make_pair(dist[i], static_cast<int>(i));
|
||||||
|
}
|
||||||
|
std::sort(sorted.begin(), sorted.end(),
|
||||||
|
[](const std::pair<float, int>& a, const std::pair<float, int>& b) {
|
||||||
|
if (a.first != b.first) {
|
||||||
|
return a.first > b.first;
|
||||||
|
}
|
||||||
|
return a.second < b.second;
|
||||||
|
});
|
||||||
|
for (size_t i = 0; i < k; ++i) {
|
||||||
|
printf(" [#%-2d token %6d = %-12s %.2e]\n", static_cast<int>(i + 1),
|
||||||
|
sorted[i].second, TokenString(tokenizer, sorted[i].second).c_str(),
|
||||||
|
sorted[i].first);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
|
||||||
|
const std::vector<int>& prompt,
|
||||||
|
KVCache& kv_cache, int verbosity) {
|
||||||
|
auto stream_token = [](int, float) { return true; };
|
||||||
|
auto accept_token = [](int) { return true; };
|
||||||
|
|
||||||
|
const int vocab_size = CallFunctorForModel<GetVocabSize>(gemma.ModelType());
|
||||||
|
float cross_entropy = std::log(vocab_size); // first token
|
||||||
|
size_t pos = 1;
|
||||||
|
std::function<int(const float*, size_t)> sample_token =
|
||||||
|
[&](const float* probs, size_t vocab_size) -> int {
|
||||||
|
const int token = prompt[pos];
|
||||||
|
const float prob = probs[token];
|
||||||
|
cross_entropy -= std::max(std::log(prob), -64.0f);
|
||||||
|
|
||||||
|
if (verbosity >= 4) {
|
||||||
|
LogTopK(gemma.Tokenizer(), probs, vocab_size, 10);
|
||||||
|
}
|
||||||
|
if (verbosity >= 3) {
|
||||||
|
printf("pos %4zu token %6d = %-12s %.10e %14.10f bits\n", pos,
|
||||||
|
token, TokenString(gemma.Tokenizer(), token).c_str(), prob,
|
||||||
|
-std::log(prob) / std::log(2.0));
|
||||||
|
}
|
||||||
|
if (verbosity >= 2 && pos % 100 == 99) {
|
||||||
|
printf("Processed %zu tokens, cross-entropy per token: %f\n", pos + 1,
|
||||||
|
cross_entropy / std::log(2.0) / (pos + 1));
|
||||||
|
}
|
||||||
|
++pos;
|
||||||
|
return token;
|
||||||
|
};
|
||||||
|
std::vector<int> prompt0 = { prompt[0] };
|
||||||
|
RuntimeConfig runtime = {
|
||||||
|
.max_tokens = max_tokens,
|
||||||
|
.max_generated_tokens = max_tokens - 1,
|
||||||
|
.temperature = 0.0f,
|
||||||
|
.verbosity = verbosity,
|
||||||
|
.gen = nullptr,
|
||||||
|
.stream_token = stream_token,
|
||||||
|
.accept_token = accept_token,
|
||||||
|
.sample_func = &sample_token,
|
||||||
|
};
|
||||||
|
TimingInfo timing_info;
|
||||||
|
|
||||||
|
gemma.Generate(runtime, prompt0, 0, kv_cache, timing_info, nullptr);
|
||||||
|
|
||||||
|
const float scale = 1.0 / std::log(2.0);
|
||||||
|
return cross_entropy * scale;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace gcpp
|
||||||
|
|
@ -0,0 +1,31 @@
|
||||||
|
// Copyright 2024 Google LLC
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_CROSS_ENTROPY_H_
|
||||||
|
#define THIRD_PARTY_GEMMA_CPP_GEMMA_CROSS_ENTROPY_H_
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "gemma/gemma.h"
|
||||||
|
|
||||||
|
namespace gcpp {
|
||||||
|
|
||||||
|
float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
|
||||||
|
const std::vector<int>& prompt, KVCache& kv_cache,
|
||||||
|
int verbosity);
|
||||||
|
|
||||||
|
} // namespace gcpp
|
||||||
|
|
||||||
|
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_CROSS_ENTROPY_H_
|
||||||
103
gemma/gemma.cc
103
gemma/gemma.cc
|
|
@ -40,7 +40,6 @@
|
||||||
#include <array>
|
#include <array>
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <regex> // NOLINT
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
@ -217,32 +216,6 @@ bool GemmaTokenizer::Decode(const std::vector<int>& ids,
|
||||||
return impl_->Decode(ids, detokenized);
|
return impl_->Decode(ids, detokenized);
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::string TokenString(const GemmaTokenizer& tokenizer, int token) {
|
|
||||||
std::string token_str;
|
|
||||||
tokenizer.Decode({token}, &token_str);
|
|
||||||
return "'" + std::regex_replace(token_str, std::regex("\n"), "\\n") + "'";
|
|
||||||
}
|
|
||||||
|
|
||||||
void LogTopK(const GemmaTokenizer& tokenizer, float* HWY_RESTRICT logits,
|
|
||||||
float* HWY_RESTRICT dist, size_t len, size_t k) {
|
|
||||||
std::vector<std::pair<float, int>> sorted(len);
|
|
||||||
for (size_t i = 0; i < len; ++i) {
|
|
||||||
sorted[i] = std::make_pair(dist[i], static_cast<int>(i));
|
|
||||||
}
|
|
||||||
std::sort(sorted.begin(), sorted.end(),
|
|
||||||
[](const std::pair<float, int>& a, const std::pair<float, int>& b) {
|
|
||||||
if (a.first != b.first) {
|
|
||||||
return a.first > b.first;
|
|
||||||
}
|
|
||||||
return a.second < b.second;
|
|
||||||
});
|
|
||||||
for (size_t i = 0; i < k; ++i) {
|
|
||||||
printf(" [#%-2d token %6d = %-12s %.2e %f]\n", static_cast<int>(i + 1),
|
|
||||||
sorted[i].second, TokenString(tokenizer, sorted[i].second).c_str(),
|
|
||||||
sorted[i].first, logits[sorted[i].second]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
#endif // GEMMA_ONCE
|
#endif // GEMMA_ONCE
|
||||||
|
|
||||||
|
|
@ -837,13 +810,19 @@ void Generate(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
|
||||||
MatVec<kVocabSize, TConfig::kModelDim>(
|
MatVec<kVocabSize, TConfig::kModelDim>(
|
||||||
weights.embedder_input_embedding, 0, final_activation,
|
weights.embedder_input_embedding, 0, final_activation,
|
||||||
activations.even_odd.data(), activations.logits.data(), pool);
|
activations.even_odd.data(), activations.logits.data(), pool);
|
||||||
|
LogitsSoftCap(30.0f, activations.logits.data(), kVocabSize);
|
||||||
// Barrier: must have all logits so we can subtract max.
|
// Barrier: must have all logits so we can subtract max.
|
||||||
Softmax(activations.logits.data(), kVocabSize);
|
Softmax(activations.logits.data(), kVocabSize);
|
||||||
token = SampleTopK<TConfig::kTopK>(
|
if (runtime_config.sample_func) {
|
||||||
activations.logits.data(), kVocabSize, *runtime_config.gen,
|
token = (*runtime_config.sample_func)(activations.logits.data(),
|
||||||
runtime_config.temperature, runtime_config.accept_token);
|
kVocabSize);
|
||||||
if (!runtime_config.stream_token(token, activations.logits[token])) {
|
} else {
|
||||||
token = runtime_config.eos_id;
|
token = SampleTopK<TConfig::kTopK>(
|
||||||
|
activations.logits.data(), kVocabSize, *runtime_config.gen,
|
||||||
|
runtime_config.temperature, runtime_config.accept_token);
|
||||||
|
if (!runtime_config.stream_token(token, activations.logits[token])) {
|
||||||
|
token = runtime_config.eos_id;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if (generate_pos == 0) {
|
if (generate_pos == 0) {
|
||||||
timing_info.time_to_first_token = hwy::platform::Now() - gen_start;
|
timing_info.time_to_first_token = hwy::platform::Now() - gen_start;
|
||||||
|
|
@ -868,51 +847,6 @@ void Generate(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <class TConfig>
|
|
||||||
void ComputeCrossEntropy(const ByteStorageT& weights_u8,
|
|
||||||
ByteStorageT& decode_u8,
|
|
||||||
const GemmaTokenizer& tokenizer, size_t max_tokens,
|
|
||||||
const std::vector<int>& prompt, KVCache& kv_cache,
|
|
||||||
hwy::ThreadPool& pool, int verbosity,
|
|
||||||
float& cross_entropy) {
|
|
||||||
const WeightsT<TConfig>& weights = GetWeights<TConfig>(weights_u8);
|
|
||||||
auto& activations = GetActivations<TConfig, 1>(decode_u8);
|
|
||||||
|
|
||||||
static constexpr size_t kModelDim = TConfig::kModelDim;
|
|
||||||
static constexpr size_t kVocabSize = TConfig::kVocabSize;
|
|
||||||
std::vector<float> logits(kVocabSize);
|
|
||||||
Softmax(activations.logits.data(), kVocabSize);
|
|
||||||
float total_entropy = 0.0f;
|
|
||||||
for (size_t pos = 0; pos < max_tokens && pos < prompt.size(); ++pos) {
|
|
||||||
if (verbosity >= 4) {
|
|
||||||
LogTopK(tokenizer, logits.data(), activations.logits.data(), kVocabSize,
|
|
||||||
10);
|
|
||||||
}
|
|
||||||
const int token = prompt[pos];
|
|
||||||
const float prob = activations.logits[token];
|
|
||||||
if (verbosity >= 3) {
|
|
||||||
printf("pos %4zu token %6d = %-12s %.10e %14.10f bits\n", pos, token,
|
|
||||||
TokenString(tokenizer, token).c_str(), prob,
|
|
||||||
-std::log(prob) / std::log(2.0));
|
|
||||||
}
|
|
||||||
total_entropy -= std::max(std::log(prob), -64.0f);
|
|
||||||
if (verbosity >= 2 && pos % 100 == 99) {
|
|
||||||
printf("Processed %zu tokens, cross-entropy per token: %f\n", pos + 1,
|
|
||||||
total_entropy / std::log(2.0) / (pos + 1));
|
|
||||||
}
|
|
||||||
Transformer(token, pos, weights, activations, kv_cache, pool,
|
|
||||||
/*layers_output=*/nullptr);
|
|
||||||
MatVec<kVocabSize, kModelDim>(
|
|
||||||
weights.embedder_input_embedding, 0, activations.x.data(),
|
|
||||||
activations.even_odd.data(), activations.logits.data(), pool);
|
|
||||||
LogitsSoftCap(30.0f, activations.logits.data(), kVocabSize);
|
|
||||||
memcpy(logits.data(), activations.logits.data(),
|
|
||||||
kVocabSize * sizeof(logits[0]));
|
|
||||||
Softmax(activations.logits.data(), kVocabSize);
|
|
||||||
}
|
|
||||||
cross_entropy = total_entropy / std::log(2.0);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace HWY_NAMESPACE
|
} // namespace HWY_NAMESPACE
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
HWY_AFTER_NAMESPACE();
|
HWY_AFTER_NAMESPACE();
|
||||||
|
|
@ -970,20 +904,5 @@ void Gemma::Generate(const RuntimeConfig& runtime_config,
|
||||||
pool_.SetWaitMode(hwy::PoolWaitMode::kBlock);
|
pool_.SetWaitMode(hwy::PoolWaitMode::kBlock);
|
||||||
}
|
}
|
||||||
|
|
||||||
float Gemma::ComputeCrossEntropy(size_t max_tokens,
|
|
||||||
const std::vector<int>& prompt,
|
|
||||||
KVCache& kv_cache, int verbosity) {
|
|
||||||
pool_.SetWaitMode(hwy::PoolWaitMode::kSpin);
|
|
||||||
|
|
||||||
float cross_entropy = 0.0f;
|
|
||||||
GEMMA_EXPORT_AND_DISPATCH_MODEL(
|
|
||||||
model_type_, ComputeCrossEntropy,
|
|
||||||
(weights_u8_, decode_u8_, tokenizer_, max_tokens, prompt, kv_cache, pool_,
|
|
||||||
verbosity, cross_entropy));
|
|
||||||
|
|
||||||
pool_.SetWaitMode(hwy::PoolWaitMode::kBlock);
|
|
||||||
return cross_entropy;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
#endif // HWY_ONCE
|
#endif // HWY_ONCE
|
||||||
|
|
|
||||||
|
|
@ -74,6 +74,9 @@ using StreamFunc = std::function<bool(int, float)>;
|
||||||
// AcceptFunc is called with token. It should return False for tokens you don't
|
// AcceptFunc is called with token. It should return False for tokens you don't
|
||||||
// want to generate and True for tokens you want to generate.
|
// want to generate and True for tokens you want to generate.
|
||||||
using AcceptFunc = std::function<bool(int)>;
|
using AcceptFunc = std::function<bool(int)>;
|
||||||
|
// CustomSampleFunc is called with the probability distribution for the next
|
||||||
|
// token, and its return value is used as the next generated token.
|
||||||
|
using CustomSampleFunc = std::function<int(const float*, size_t)>;
|
||||||
|
|
||||||
struct RuntimeConfig {
|
struct RuntimeConfig {
|
||||||
size_t max_tokens;
|
size_t max_tokens;
|
||||||
|
|
@ -83,6 +86,7 @@ struct RuntimeConfig {
|
||||||
std::mt19937* gen;
|
std::mt19937* gen;
|
||||||
const StreamFunc& stream_token;
|
const StreamFunc& stream_token;
|
||||||
const AcceptFunc& accept_token;
|
const AcceptFunc& accept_token;
|
||||||
|
const CustomSampleFunc* sample_func = nullptr;
|
||||||
int eos_id = EOS_ID;
|
int eos_id = EOS_ID;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -109,6 +113,7 @@ class Gemma {
|
||||||
Gemma(GemmaTokenizer&& tokenizer, Model model_type, hwy::ThreadPool& pool);
|
Gemma(GemmaTokenizer&& tokenizer, Model model_type, hwy::ThreadPool& pool);
|
||||||
~Gemma();
|
~Gemma();
|
||||||
|
|
||||||
|
Model ModelType() const { return model_type_; }
|
||||||
const GemmaTokenizer& Tokenizer() const { return tokenizer_; }
|
const GemmaTokenizer& Tokenizer() const { return tokenizer_; }
|
||||||
const ByteStorageT& Weights() const { return weights_u8_; }
|
const ByteStorageT& Weights() const { return weights_u8_; }
|
||||||
const ByteStorageT& Prefill() const { return prefill_u8_; }
|
const ByteStorageT& Prefill() const { return prefill_u8_; }
|
||||||
|
|
@ -121,9 +126,6 @@ class Gemma {
|
||||||
KVCache& kv_cache, TimingInfo& timing_info,
|
KVCache& kv_cache, TimingInfo& timing_info,
|
||||||
LayersOutputT* layers_output = nullptr);
|
LayersOutputT* layers_output = nullptr);
|
||||||
|
|
||||||
float ComputeCrossEntropy(size_t max_tokens, const std::vector<int>& prompt,
|
|
||||||
KVCache& kv_cache, int verbosity);
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
hwy::ThreadPool& pool_;
|
hwy::ThreadPool& pool_;
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -24,6 +24,7 @@
|
||||||
|
|
||||||
#include "compression/io.h" // Path
|
#include "compression/io.h" // Path
|
||||||
#include "gemma/common.h"
|
#include "gemma/common.h"
|
||||||
|
#include "gemma/cross_entropy.h"
|
||||||
#include "gemma/ops.h"
|
#include "gemma/ops.h"
|
||||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
#include "hwy/tests/test_util-inl.h"
|
#include "hwy/tests/test_util-inl.h"
|
||||||
|
|
@ -77,9 +78,9 @@ class GemmaTest : public ::testing::Test {
|
||||||
float GemmaCrossEntropy(const std::string& prompt_string) {
|
float GemmaCrossEntropy(const std::string& prompt_string) {
|
||||||
std::vector<int> prompt;
|
std::vector<int> prompt;
|
||||||
HWY_ASSERT(model.Tokenizer().Encode(prompt_string, &prompt));
|
HWY_ASSERT(model.Tokenizer().Encode(prompt_string, &prompt));
|
||||||
return model.ComputeCrossEntropy(/*max_tokens=*/3072, prompt, kv_cache,
|
return ComputeCrossEntropy(model, /*max_tokens=*/3072, prompt, kv_cache,
|
||||||
/*verbosity=*/0) /
|
/*verbosity=*/0) /
|
||||||
prompt_string.size();
|
prompt_string.size();
|
||||||
}
|
}
|
||||||
|
|
||||||
void TestQuestions(const char* kQA[][2], size_t num_questions) {
|
void TestQuestions(const char* kQA[][2], size_t num_questions) {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue