// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <stdio.h>

#include <algorithm>
#include <string>
#include <vector>

#include "evals/benchmark_helper.h"
#include "gemma/gemma.h"  // Gemma
#include "io/io.h"        // Path
#include "util/args.h"
#include "hwy/base.h"
#include "hwy/highway.h"
#include "hwy/profiler.h"
#include "nlohmann/json.hpp"

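// Evaluates MMLU-style multiple-choice questions: reads samples from a JSON
// file, asks the model to reply with the answer letter, and reports accuracy.
//
// Example invocation (a sketch; the binary name and model-loading flags are
// parsed by GemmaEnv and may differ across gemma.cpp versions):
//   ./run_mmlu --tokenizer tokenizer.spm --weights gemma.sbs --input mmlu.json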
namespace gcpp {

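// Command-line flags specific to this evaluator; GemmaEnv separately parses
// the model-loading flags. Parsed via the ArgsBase visitor pattern.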
struct JsonArgs : public ArgsBase<JsonArgs> {
  JsonArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }

  Path input;

  // Returns error string or nullptr if OK.
  const char* Validate() const {
    if (input.Empty()) return "Must specify --input";
    if (!input.Exists()) return "--input file does not exist";
    return nullptr;
  }

  template <class Visitor>
  void ForEach(const Visitor& visitor) {
    visitor(input, "input", Path(), "Full pathname of mmlu.json.");
  }
};

// Linear search for a few tokens is faster than std::set.
// TODO: instead of accepting for each vocab entry, filter the logits once.
class TokenSet {
 public:
  TokenSet(const GemmaTokenizer& tokenizer,
           const std::vector<std::string>& strings) {
    all_tokens_.reserve(strings.size());
    for (const std::string& str : strings) {
      std::vector<int> tokens;
      fprintf(stderr, "%s -> ", str.c_str());
      HWY_ASSERT(tokenizer.Encode(str, &tokens));
      for (int token : tokens) {
        fprintf(stderr, "%d, ", token);
        all_tokens_.push_back(token);
      }
      fprintf(stderr, "\n");
    }
  }

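  // Returns true if `token` appears in the encoding of any accepted string.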
  bool Contains(int token) const {
    return std::find(all_tokens_.begin(), all_tokens_.end(), token) !=
           all_tokens_.end();
  }

 private:
  std::vector<int> all_tokens_;
};

void Run(GemmaEnv& env, JsonArgs& json) {
  PROFILER_ZONE("Run.all");

  float answers = 0.0f;
  float correct_answers = 0.0f;

  auto json_data = nlohmann::json::parse(ReadFileToString(json.input));

  const std::vector<std::string> accept_strings = {
      "A", "B", "C", "D",      //
      " A", " B", " C", " D",  //
      "**", "**:", ":**", "The", "Answer", "is", ":", "."};
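  // Constructing accept_set logs how each accepted string tokenizes, which is
  // useful for debugging. Note it is not installed as a token filter below;
  // see the comment preceding RuntimeConfig on why restricting tokens is
  // risky.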
  const TokenSet accept_set(env.GetGemma()->Tokenizer(), accept_strings);

  for (auto sample : json_data["samples"]) {
    const int id = sample["i"];
    fprintf(stderr, "Processing question %d\n", id);
    const std::string& correct_answer = accept_strings[sample["input_label"]];
    std::string prompt_string = sample["prompt"];
    // AcceptFunc restricts the output to one of these four tokens, so make an
    // effort to steer the model towards that. See
    // https://huggingface.co/blog/open-llm-leaderboard-mmlu
    prompt_string +=
        "What is the start of the line with the correct answer? "
        "Do not include any justifications or explanations. Reply only with a "
        "letter.";
    const std::vector<int> prompt = env.WrapAndTokenize(prompt_string);
    const size_t prompt_size = prompt.size();

    std::vector<int> predicted_token_ids;
    predicted_token_ids.reserve(4096);
    size_t generated = 0;
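    // The stream callback also receives the prompt tokens while they are
    // being processed, so skip the first prompt_size tokens and record only
    // the model's newly generated ones.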
    const StreamFunc stream_token = [&generated, prompt_size,
                                     &predicted_token_ids](int token,
                                                           float proba) {
      PROFILER_ZONE("Stream");
      ++generated;
      if (generated > prompt_size) {
        predicted_token_ids.push_back(token);
      }
      return true;
    };

    // Although " A" is a token, it is difficult to associate that with the
    // correct answer. Only accepting certain tokens is risky: (A) is easily
    // confused with the word "A".
    gcpp::TimingInfo timing_info;
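    // temperature 0.0f aims for deterministic decoding, and the small
    // max_generated_tokens keeps replies short; only a letter is wanted.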
    gcpp::RuntimeConfig runtime_config = {
        .max_generated_tokens = 30,
        .temperature = 0.0f,
        .verbosity = env.Verbosity(),
        .stream_token = stream_token,
    };
    env.GetGemma()->Generate(runtime_config, prompt, /*pos=*/0,
                             env.MutableKVCache(), env.MutableEnv(),
                             timing_info);

    std::string output_string = env.StringFromTokens(predicted_token_ids);
    fprintf(stderr, "Correct %s, model '%s'\n", correct_answer.c_str(),
            output_string.c_str());

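    // Scoring is an exact string match: the model must reply with precisely
    // the expected answer string (e.g. "A"), not a phrase containing it.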
    answers += 1.0f;
    if (output_string == correct_answer) {
      correct_answers += 1.0f;
    }
    fprintf(stderr, "%.0f/%.0f = %.2f%%\n", correct_answers, answers,
            100.0f * correct_answers / answers);
  }
}

}  // namespace gcpp

int main(int argc, char** argv) {
  {
    PROFILER_ZONE("Startup.all");
    gcpp::GemmaEnv env(argc, argv);
    gcpp::JsonArgs json(argc, argv);
    gcpp::AbortIfInvalidArgs(json);
    gcpp::Run(env, json);
  }
  PROFILER_PRINT_RESULTS();  // Must call outside the zone above.
  return 0;
}