Add MMLU eval to github

PiperOrigin-RevId: 635495178
Apoorv Reddy 2024-05-20 10:20:22 -07:00 committed by Copybara-Service
parent cfce314715
commit 7f4b85d00b
6 changed files with 2136 additions and 3 deletions


@@ -174,3 +174,17 @@ cc_binary(
"@nlohmann_json//:json",
],
)
cc_binary(
name = "gemma_mmlu",
srcs = ["gemma/run_mmlu.cc"],
deps = [
":app",
":gemma_lib",
# "//base",
"@hwy//:hwy",
"@hwy//:profiler",
"@hwy//:thread_pool",
"@nlohmann_json//:json",
],
)
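Note: with the Bazel setup already used by the rest of this BUILD file, the new target should build with an invocation along the lines of bazel build -c opt :gemma_mmlu (the exact package path depends on where this BUILD file sits in the workspace).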


@@ -9,11 +9,11 @@
#include "compression/io.h"
#include "gemma/gemma.h"
#include "nlohmann/json.hpp"
#include "util/app.h"
#include "util/args.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "nlohmann/json.hpp"
using json = nlohmann::json;


@@ -8,7 +8,6 @@
#include <utility> // std::pair
#include <vector>
#include "nlohmann/json.hpp"
#include "gemma/gemma.h"
#include "util/app.h"
#include "util/args.h"
@@ -16,6 +15,7 @@
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/highway.h"
#include "hwy/timer.h"
#include "nlohmann/json.hpp"
using json = nlohmann::json;

gemma/evals/mmlu.json (new file, 1912 lines)
File diff suppressed because one or more lines are too long


@@ -61,8 +61,11 @@ const char* ParseModelTypeAndTraining(const std::string& model_flag,
Model& model, ModelTraining& training);
// StreamFunc is called with (token, probability). For prompt tokens,
// probability is 0.0f. StreamFunc should return false to stop generation and
// true to continue generation.
using StreamFunc = std::function<bool(int, float)>;
// AcceptFunc is called with a token. It should return false for tokens that
// should not be generated and true for tokens that may be generated.
using AcceptFunc = std::function<bool(int)>;
struct RuntimeConfig {
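As a quick illustration of the new contract (a minimal sketch, not part of this commit): the two callbacks can be wired into a RuntimeConfig the same way run_mmlu.cc below does it. The field names and initialization pattern are taken from that file; the helper name MakeConstrainedConfig and the concrete values are hypothetical.

#include <random>
#include <set>
#include "gemma/gemma.h"  // RuntimeConfig, StreamFunc, AcceptFunc

// Returns a config whose sampling is restricted to the `allowed` token ids.
// Note: `allowed` is captured by reference and must outlive the config.
gcpp::RuntimeConfig MakeConstrainedConfig(std::mt19937& gen,
                                          const std::set<int>& allowed) {
  return {
      .max_tokens = 2048,
      .max_generated_tokens = 64,
      .temperature = 0.0f,
      .verbosity = 0,
      .gen = &gen,
      // Prompt tokens stream through with probability 0.0f; returning false
      // from here would stop generation early.
      .stream_token = [](int /*token*/, float /*proba*/) { return true; },
      // Reject any candidate token that is not in the allowed set.
      .accept_token = [&allowed](int token) {
        return allowed.empty() || allowed.count(token) > 0;
      },
  };
}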

gemma/run_mmlu.cc (new file, 204 lines)

@@ -0,0 +1,204 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Command-line MMLU evaluation harness for gemma.
#include <ctime>
#include <fstream>
#include <iostream>
#include <random>
#include <set>
#include <sstream>
#include <string>
#include <vector>
// Placeholder for internal header, do not modify.
#include "gemma/gemma.h" // Gemma
#include "util/app.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/highway.h"
#include "hwy/profiler.h"
#include "nlohmann/json.hpp"
namespace gcpp {
void JsonGemma(gcpp::Gemma& model, gcpp::KVCache& kv_cache,
hwy::ThreadPool& pool,
const InferenceArgs& args, int verbosity,
std::string& eot_line) {
PROFILER_ZONE("Gen.misc");
// Upper bound on the number of generated tokens recorded per sample.
int max_tokens = 4096;
std::mt19937 gen;
if (args.deterministic) {
gen.seed(42);
} else {
std::random_device rd;
gen.seed(rd());
}
float answers = 0.0;
float correct_answers = 0.0;
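// The eval data path is hard-coded; presumably the mmlu.json checked in under
// gemma/evals/ in this commit is copied to /tmp/mmlu.json before running.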
std::ifstream fJson("/tmp/mmlu.json");
std::stringstream buffer;
buffer << fJson.rdbuf();
auto json = nlohmann::json::parse(buffer.str());
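// Each entry of json["samples"] is expected to provide "i" (question id),
// "prompt" (the question text), and "input_label" (index of the correct
// answer letter), as used below.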
std::vector<std::string> accept_tokens = {"A", "B", "C", "D"};
std::set<int> accept_token_set{};
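// Encode each answer letter and collect every resulting token id; once the
// prompt has been consumed, generation is constrained to (and stopped at)
// exactly these ids.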
for (const std::string& accept_token : accept_tokens) {
std::vector<int> accept_token_ids;
HWY_ASSERT(model.Tokenizer()->Encode(accept_token, &accept_token_ids));
accept_token_set.insert(accept_token_ids.begin(), accept_token_ids.end());
}
for (auto sample : json["samples"]) {
int abs_pos = 0; // absolute token index over all turns
int current_pos = 0;
int prompt_size{};
// cout << "prompt:" << sample["prompt"] << endl;
const std::string& prompt_string = sample["prompt"];
std::vector<int> prompt;
HWY_ASSERT(model.Tokenizer()->Encode(prompt_string, &prompt));
prompt_size = prompt.size();
const std::string& correct_answer = accept_tokens[sample["input_label"]];
// max_tokens = prompt_size + max_tokens;
std::vector<int> predicted_token_ids;
predicted_token_ids.reserve(max_tokens);
auto stream_token = [&current_pos, &prompt_size, &predicted_token_ids,
&accept_token_set](int token, float proba) {
++current_pos;
if (current_pos > prompt_size) {
predicted_token_ids.push_back(token);
// If the generated token is in the accepted token set, return False.
// This will stop further generation.
return accept_token_set.find(token) == accept_token_set.end();
}
return true;
};
auto accept_token = [&current_pos, &prompt_size,
&accept_token_set](int token) {
// i.e. we have no constraints on accepted tokens
if (accept_token_set.empty()) {
return true;
}
if (current_pos >= prompt_size) {
return accept_token_set.find(token) != accept_token_set.end();
} else {
// auto-accept early tokens
return true;
}
};
gcpp::TimingInfo timing_info;
gcpp::RuntimeConfig runtime_config = {
.max_tokens = args.max_tokens,
.max_generated_tokens = args.max_generated_tokens,
.temperature = args.temperature,
.verbosity = verbosity,
.gen = &gen,
.stream_token = stream_token,
.accept_token = accept_token,
};
GenerateGemma(model, runtime_config, prompt, abs_pos, kv_cache, pool,
timing_info);
std::string output_string;
HWY_ASSERT(model.Tokenizer()->Decode(predicted_token_ids, &output_string));
std::cout << "QuestionId: " << sample["i"] << "; "
<< "Predicted Answer: " << output_string << "; "
<< "Correct Answer: " << correct_answer << std::endl;
answers += 1.0;
if (output_string == correct_answer) {
correct_answers += 1.0;
}
std::cout << "Running accuracy = " << "["
<< static_cast<int>(correct_answers) << "/"
<< static_cast<int>(answers) << "]" << " = "
<< correct_answers / answers << std::endl;
}
}
void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
loader.Print(app.verbosity);
inference.Print(app.verbosity);
app.Print(app.verbosity);
}
void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
PROFILER_ZONE("Run.misc");
hwy::ThreadPool pool(app.num_threads);
// For many-core, pinning threads to cores helps.
if (app.num_threads > 10) {
PinThreadToCore(app.num_threads - 1); // Main thread
pool.Run(0, pool.NumThreads(),
[](uint64_t /*task*/, size_t thread) { PinThreadToCore(thread); });
}
gcpp::Gemma model(loader.tokenizer, loader.weights, loader.ModelType(), pool);
auto kv_cache = CreateKVCache(loader.ModelType());
JsonGemma(model, kv_cache, pool, inference, app.verbosity, app.eot_line);
}
} // namespace gcpp
int main(int argc, char** argv) {
{
PROFILER_ZONE("Startup.misc");
// Placeholder for internal init, do not modify.
gcpp::LoaderArgs loader(argc, argv);
gcpp::InferenceArgs inference(argc, argv);
gcpp::AppArgs app(argc, argv);
if (const char* error = loader.Validate()) {
fprintf(stderr,
"\ngemma.cpp\n---------\n\nTo run gemma.cpp, you need to "
"specify 3 required model loading arguments: --tokenizer, "
"--compressed_weights, "
"and --model.\n\nModel Loading Arguments\n\n");
loader.Help();
fprintf(stderr, "\nInference Arguments\n\n");
inference.Help();
fprintf(stderr, "\nApplication Arguments\n\n");
app.Help();
fprintf(stderr, "\n\n");
HWY_ABORT("\nInvalid args: %s", error);
}
gcpp::Run(loader, inference, app);
}
PROFILER_PRINT_RESULTS(); // Must call outside the zone above.
return 0;
}
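For reference, the binary is driven by the loader flags named in the error message above; a hypothetical invocation (paths and model name are placeholders) would look like ./gemma_mmlu --tokenizer tokenizer.spm --compressed_weights weights.sbs --model 2b-it, together with any inference and application arguments printed by the Help() calls.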