// Mirror of https://github.com/google/gemma.cpp.git (154 lines, 4.9 KiB, C++)
// Copyright 2024 Google LLC
|
|
// SPDX-License-Identifier: Apache-2.0
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// https://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
#include "compression/types.h" // GEMMA_DISABLED_TARGETS
|
|
#ifndef HWY_DISABLED_TARGETS
|
|
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
|
|
#endif // HWY_DISABLED_TARGETS
|
|
|
|
// Compiles this file for multiple architectures via "foreach_target.h", to
|
|
// which we pass the filename via macro 'argument'.
|
|
// clang-format off
|
|
#undef HWY_TARGET_INCLUDE
|
|
#define HWY_TARGET_INCLUDE "evals/cross_entropy.cc" // NOLINT
|
|
// clang-format on
|
|
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
|
#include "hwy/highway.h"
|
|
// After highway.h
|
|
#include "ops/ops-inl.h" // Softmax
|
|
|
|
#ifndef GEMMA_CROSS_ENTROPY_ONCE
|
|
#define GEMMA_CROSS_ENTROPY_ONCE
|
|
|
|
#include <stddef.h>
|
|
#include <stdio.h>
|
|
|
|
#include <algorithm> // std::sort
|
|
#include <cmath>
|
|
#include <regex> // NOLINT
|
|
#include <string>
|
|
#include <utility>
|
|
#include <vector>
|
|
|
|
#include "evals/cross_entropy.h"
|
|
#include "gemma/gemma.h"
|
|
#include "hwy/base.h"
|
|
|
|
namespace gcpp {
|
|
|
|
namespace {
|
|
|
|
static std::string TokenString(const GemmaTokenizer& tokenizer, int token) {
|
|
std::string token_str;
|
|
tokenizer.Decode({token}, &token_str);
|
|
return "'" + std::regex_replace(token_str, std::regex("\n"), "\\n") + "'";
|
|
}
|
|
|
|
void LogTopK(const GemmaTokenizer& tokenizer, Logits logits, size_t k) {
|
|
std::vector<std::pair<float, int>> sorted(logits.size());
|
|
for (size_t i = 0; i < logits.size(); ++i) {
|
|
sorted[i] = std::make_pair(logits[i], static_cast<int>(i));
|
|
}
|
|
std::sort(sorted.begin(), sorted.end(),
|
|
[](const std::pair<float, int>& a, const std::pair<float, int>& b) {
|
|
if (a.first != b.first) {
|
|
return a.first > b.first;
|
|
}
|
|
return a.second < b.second;
|
|
});
|
|
for (size_t i = 0; i < k; ++i) {
|
|
printf(" [#%-2d token %6d = %-12s %.2e]\n", static_cast<int>(i + 1),
|
|
sorted[i].second, TokenString(tokenizer, sorted[i].second).c_str(),
|
|
sorted[i].first);
|
|
}
|
|
}
|
|
} // namespace
|
|
} // namespace gcpp
|
|
#endif // GEMMA_CROSS_ENTROPY_ONCE
|
|
|
|
// SIMD code, compiled once per target.
|
|
HWY_BEFORE_NAMESPACE();
|
|
namespace gcpp {
|
|
namespace HWY_NAMESPACE {
|
|
|
|
void CallSoftmax(Logits logits, ThreadingContext& ctx) {
|
|
Softmax(logits, ctx, hwy::Profiler::GlobalIdx());
|
|
}
|
|
|
|
} // namespace HWY_NAMESPACE
|
|
} // namespace gcpp
|
|
HWY_AFTER_NAMESPACE();
|
|
|
|
#if HWY_ONCE
|
|
namespace gcpp {
|
|
|
|
HWY_EXPORT(CallSoftmax);
|
|
|
|
float ComputeCrossEntropy(const Gemma& gemma, size_t max_generated_tokens,
|
|
const std::vector<int>& prompt, KVCache& kv_cache,
|
|
MatMulEnv& env, int verbosity) {
|
|
const BatchStreamFunc stream_token = [](size_t, size_t, int, float) {
|
|
return true;
|
|
};
|
|
|
|
const int vocab_size = gemma.Config().vocab_size;
|
|
float cross_entropy = std::log(vocab_size); // first token; == -log(1/v_s)
|
|
|
|
const SampleFunc sample_token = [&](size_t qi, size_t pos, Logits logits,
|
|
size_t /*worker*/) -> TokenAndProb {
|
|
// input is logits, not yet probabilities
|
|
HWY_DYNAMIC_DISPATCH(CallSoftmax)(logits, env.ctx);
|
|
// We are called for each token, but pos starts at 1. Clamping
|
|
// max_generated_tokens to prompt.size() should prevent overrun.
|
|
HWY_ASSERT(pos < prompt.size());
|
|
const int token = prompt[pos];
|
|
const float prob = logits[token];
|
|
cross_entropy -= std::max(std::log(prob), -64.0f);
|
|
|
|
if (verbosity >= 4) {
|
|
LogTopK(gemma.Tokenizer(), logits, 10);
|
|
}
|
|
if (verbosity >= 3) {
|
|
printf("pos %4zu token %6d = %-12s %.10e %14.10f bits\n", pos, token,
|
|
TokenString(gemma.Tokenizer(), token).c_str(), prob,
|
|
-std::log(prob) / std::log(2.0));
|
|
}
|
|
if (verbosity >= 2 && pos % 100 == 99) {
|
|
printf("Processed %zu tokens, cross-entropy per token: %f\n", pos + 1,
|
|
cross_entropy / std::log(2.0) / (pos + 1));
|
|
}
|
|
return TokenAndProb{.token = token, .prob = prob};
|
|
};
|
|
|
|
std::vector<int> prompt0 = { prompt[0] };
|
|
max_generated_tokens = HWY_MIN(max_generated_tokens, prompt.size());
|
|
RuntimeConfig runtime = {
|
|
.max_generated_tokens = max_generated_tokens - 1,
|
|
.temperature = 0.0f,
|
|
.verbosity = verbosity,
|
|
.batch_stream_token = stream_token,
|
|
.sample_func = sample_token,
|
|
};
|
|
TimingInfo timing_info;
|
|
|
|
gemma.Generate(runtime, prompt0, 0, kv_cache, env, timing_info);
|
|
|
|
const float scale = 1.0f / std::log(2.0f);
|
|
return cross_entropy * scale;
|
|
}
|
|
|
|
} // namespace gcpp
|
|
#endif // HWY_ONCE
|