Added a layers-output hook to gemma and a debug_prompt binary that saves the per-layer outputs to a JSON file.

Andrey Mikhaylov 2024-04-12 11:22:42 +00:00
parent 342e998cb6
commit 03284d752e
5 changed files with 201 additions and 25 deletions
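
The hook is a plain std::function typedef, LayersOutputT (see the gemma.h hunk at the end of this diff), threaded as an optional trailing argument from GenerateGemma down to Transformer. A minimal caller-side sketch, not part of this commit, assuming only the signatures shown in the diff below; model/pool/prompt setup is elided:

#include <cstddef>
#include <cstdio>
#include <string>

#include "gemma/gemma.h"

// Illustrative hook: print each key and the first value it reports.
// Keys emitted per token position are "Tokens", "blocks.<layer>", and
// "final_norm".
void PrintLayer(int pos, std::string key, const float* values, size_t len) {
  std::printf("pos=%d %s len=%zu first=%f\n", pos, key.c_str(), len,
              len ? values[0] : 0.0f);
}

// Later, with model, pools, kv_cache, prompt, etc. set up as in
// debug_prompt.cc below:
//   gcpp::LayersOutputT hook = PrintLayer;
//   GenerateGemma(model, max_tokens, max_generated_tokens, temperature,
//                 prompt, /*start_pos=*/0, kv_cache, pool, inner_pool,
//                 stream_token, accept_token, gen, verbosity, &hook);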


@@ -158,3 +158,22 @@ cc_binary(
         "@nlohmann_json//:json",
     ],
 )
+
+cc_binary(
+    name = "debug_prompt",
+    srcs = [
+        "debug_prompt.cc",
+    ],
+    deps = [
+        ":app",
+        ":args",
+        ":gemma_lib",
+        # "//base",
+        "//compression:compress",
+        "@hwy//:hwy",
+        "@hwy//:nanobenchmark",
+        "@hwy//:profiler",
+        "@hwy//:thread_pool",
+        "@nlohmann_json//:json",
+    ],
+)


@@ -88,6 +88,9 @@ target_link_libraries(gemma libgemma hwy hwy_contrib)
 add_executable(benchmark gemma/benchmark.cc)
 target_link_libraries(benchmark libgemma hwy hwy_contrib nlohmann_json::nlohmann_json)
 
+add_executable(debug_prompt debug_prompt.cc)
+target_link_libraries(debug_prompt libgemma hwy hwy_contrib nlohmann_json::nlohmann_json)
+
 ## Tests
 set(GEMMA_ENABLE_TESTS OFF CACHE BOOL "Enable Gemma tests")
 if (GEMMA_ENABLE_TESTS)

debug_prompt.cc (new file, 135 lines)

@@ -0,0 +1,135 @@
#include <fstream>
#include <iostream>
#include <string>

#include "gemma/gemma.h"
#include "nlohmann/json.hpp"
// copybara:import_next_line:gemma_cpp
#include "util/app.h"
// copybara:import_next_line:gemma_cpp
#include "util/args.h"

using json = nlohmann::json;

class PromptArgs : public gcpp::ArgsBase<PromptArgs> {
 public:
  PromptArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }

  gcpp::Path layers_output;
  std::string prompt;

  template <class Visitor>
  void ForEach(const Visitor& visitor) {
    visitor(layers_output.path, "layers_output", std::string(""),
            "Path to store layers output", 2);
    visitor(prompt, "prompt", std::string(""), "Prompt to the model", 2);
  }
};
std::pair<std::string, int> QueryModel(
    gcpp::Gemma& model, gcpp::InferenceArgs& args, gcpp::AppArgs& app,
    gcpp::KVCache& kv_cache, hwy::ThreadPool& inner_pool, hwy::ThreadPool& pool,
    const std::string& input, gcpp::LayersOutputT* layers_output) {
  std::vector<int> prompt;
  HWY_ASSERT(model.Tokenizer()->Encode(input, &prompt));
  // For both pre-trained and instruction-tuned models: prepend "<bos>" token
  // if needed.
  prompt.insert(prompt.begin(), 2);

  std::string res;
  size_t total_tokens = 0;
  auto accept_token = [](int) { return true; };
  std::mt19937 gen;
  gen.seed(42);
  auto stream_token = [&res, &total_tokens, &app,
                       tokenizer = model.Tokenizer()](int token, float) {
    ++total_tokens;
    std::string token_text;
    HWY_ASSERT(tokenizer->Decode(std::vector<int>{token}, &token_text));
    res += token_text;
    return true;
  };
  if (app.verbosity >= 2) {
    std::cout << args.max_tokens << " " << args.max_generated_tokens << " "
              << args.temperature;
  }
  GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
                args.temperature, prompt, /*abs_pos=*/0, kv_cache, pool,
                inner_pool, stream_token, accept_token, gen, app.verbosity,
                layers_output);
  return {res, total_tokens};
}
class OutputJsonLogger {
 public:
  json json_output;

  gcpp::LayersOutputT layers_output_log_f =
      [this](int pos, std::string key, const float* values, size_t values_len) {
        std::vector<float> v{values, values + values_len};
        json_output[std::to_string(pos)][key] = v;
      };
};
/* Run this in the same way as gemma, e.g.:
   ./debug_prompt --tokenizer tokenizer.spm --model 2b-it --weights \
     2b-it-sfp.sbs --prompt "..." --layers_output [path]
*/
int main(int argc, char** argv) {
  gcpp::LoaderArgs loader(argc, argv);
  gcpp::InferenceArgs args(argc, argv);  // inference
  gcpp::AppArgs app(argc, argv);
  PromptArgs prompt_args(argc, argv);

  if (const char* error = loader.Validate()) {
    HWY_ABORT("\nInvalid loader args: %s", error);
  }
  if (const char* error = args.Validate()) {
    HWY_ABORT("\nInvalid inference args: %s", error);
  }

  const bool log_layers_output = !prompt_args.layers_output.path.empty();
  OutputJsonLogger json_logger;
  gcpp::LayersOutputT* layers_output =
      log_layers_output ? &json_logger.layers_output_log_f : nullptr;

  hwy::ThreadPool inner_pool(0);
  hwy::ThreadPool pool(app.num_threads);
  // For many-core, pinning threads to cores helps.
  if (app.num_threads > 10) {
    gcpp::PinThreadToCore(app.num_threads - 1);  // Main thread

    pool.Run(0, pool.NumThreads(), [](uint64_t /*task*/, size_t thread) {
      gcpp::PinThreadToCore(thread);
    });
  }

  gcpp::Gemma model(loader.tokenizer, loader.weights, loader.ModelType(), pool);
  auto kv_cache = CreateKVCache(loader.ModelType());

  const std::string& prompt = prompt_args.prompt;
  if (prompt.empty()) {
    std::cout << "Please specify --prompt" << std::endl;
    return EXIT_FAILURE;
  }
  const auto [answer, token_count] = QueryModel(
      model, args, app, kv_cache, inner_pool, pool, prompt, layers_output);
  std::cout << answer.substr(prompt.size()) << "\n" << std::flush;

  if (log_layers_output) {
    std::ofstream output_f(prompt_args.layers_output.path, std::ofstream::out);
    if (!output_f) {
      std::cout << "Opening file failed" << std::endl;
      return EXIT_FAILURE;
    }
    output_f << json_logger.json_output.dump();
    if (!output_f) {
      std::cout << "Writing to file failed" << std::endl;
      return EXIT_FAILURE;
    }
    output_f.close();
  }

  return EXIT_SUCCESS;
}
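
A minimal consumer sketch, not part of this commit: it walks the file that debug_prompt writes, whose layout mirrors OutputJsonLogger above (token position → layer key → array of floats). The path "layers.json" is hypothetical:

#include <fstream>
#include <iostream>

#include "nlohmann/json.hpp"

int main() {
  std::ifstream input_f("layers.json");
  nlohmann::json dump = nlohmann::json::parse(input_f);
  // Top-level keys are token positions; each maps layer names
  // ("Tokens", "blocks.<i>", "final_norm") to arrays of floats.
  for (const auto& [pos, layers] : dump.items()) {
    for (const auto& [name, values] : layers.items()) {
      std::cout << "pos " << pos << " " << name << ": " << values.size()
                << " floats\n";
    }
  }
}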


@@ -442,7 +442,7 @@ struct GemmaInterface {
                         hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
                         const StreamFunc& stream_token,
                         const AcceptFunc& accept_token, std::mt19937& gen,
-                        int verbosity) = 0;
+                        int verbosity, LayersOutputT* layers_output) = 0;
 
   virtual float ComputeCrossEntropy(size_t max_tokens,
                                     const std::vector<int>& prompt,
@@ -535,8 +535,8 @@ struct GemmaImpl : public GemmaInterface {
                 float temperature, const std::vector<int>& prompt,
                 size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
                 hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
-                const AcceptFunc& accept_token, std::mt19937&,
-                int verbosity) override;
+                const AcceptFunc& accept_token, std::mt19937&, int verbosity,
+                LayersOutputT* layers_output) override;
 
   float ComputeCrossEntropy(size_t max_tokens, const std::vector<int>& prompt,
                             KVCache& kv_cache, hwy::ThreadPool& pool,
@@ -935,7 +935,12 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
 template <typename WeightArrayT, class TConfig>
 void Transformer(int token, size_t pos, const WeightArrayT& weights,
                  Activations<TConfig, 1>& activations, KVCache& kv_cache,
-                 hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool) {
+                 hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
+                 LayersOutputT* layers_output) {
+  if (layers_output != nullptr) {
+    float token_f = token;
+    (*layers_output)(pos, "Tokens", &token_f, 1);
+  }
   static constexpr size_t kModelDim = TConfig::kModelDim;
   Decompress(weights.embedder_input_embedding, token * kModelDim,
              activations.x.data(), kModelDim);
@@ -943,7 +948,6 @@ void Transformer(int token, size_t pos, const WeightArrayT& weights,
   GEMMA_CONSTEXPR_EMBSCALING const float kEmbScaling =
       EmbeddingScaling<TConfig>();
   MulByConst(kEmbScaling, activations.x.data(), kModelDim);
-
   for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
     auto type = TConfig::kLayerConfig[layer];
     const auto* layer_weights = weights.GetLayer(layer);
@@ -964,9 +968,16 @@ void Transformer(int token, size_t pos, const WeightArrayT& weights,
               activations.bf_pre_ffw_rms_out.data(), kModelDim);
     FFW<1>(activations, /* batch_idx = */ 0, layer_weights, pool);
     AddFrom(activations.ffw_out.data(), activations.x.data(), kModelDim);
+    if (layers_output != nullptr) {
+      std::string block_name = "blocks." + std::to_string(layer);
+      (*layers_output)(pos, block_name, activations.x.data(), kModelDim);
+    }
   }
   RMSNormInplace(weights.final_norm_scale.data(), activations.x.data(),
                  kModelDim);
+  if (layers_output != nullptr) {
+    (*layers_output)(pos, "final_norm", activations.x.data(), kModelDim);
+  }
 }
 
 template <class TConfig>
@@ -1005,7 +1016,7 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
                   hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
                   const StreamFunc& stream_token,
                   const AcceptFunc& accept_token, std::mt19937& gen,
-                  int verbosity) {
+                  int verbosity, LayersOutputT* layers_output) {
   static constexpr size_t kVocabSize = TConfig::kVocabSize;
   Activations<TConfig, 1>& activations = *gemma.state.get();
   Activations<TConfig, kPrefillBatchSize>& prefill_activations =
@@ -1072,12 +1083,14 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
   for (size_t generate_pos = 0;
        pos < max_tokens && generate_pos < max_generated_tokens;
        ++pos, ++pos_offset, ++generate_pos) {
-    Transformer(token, pos, weights, activations, kv_cache, pool, inner_pool);
+    const bool is_generating_phase = pos_offset >= prompt_size - 1;
+    Transformer(token, pos, weights, activations, kv_cache, pool, inner_pool,
+                layers_output);
     float* final_activation = activations.x.data();
     // The condition below is always true if we are doing Prefill above.
     // We keep it here for clarity so that the code is correct even if Prefill
     // is disabled.
-    if (pos_offset >= prompt_size - 1) {
+    if (is_generating_phase) {
       PROFILER_ZONE("Gen.Embedding");
       // Generation phase
       MatVec<kVocabSize, TConfig::kModelDim>(weights.embedder_input_embedding,
@@ -1166,7 +1179,8 @@ float ComputeCrossEntropyImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
       printf("Processed %zu tokens, cross-entropy per token: %f\n", pos + 1,
              total_entropy / std::log(2.0) / (pos + 1));
     }
-    Transformer(token, pos, weights, activations, kv_cache, pool, inner_pool);
+    Transformer(token, pos, weights, activations, kv_cache, pool, inner_pool,
+                nullptr);
     MatVec<kVocabSize, kModelDim>(weights.embedder_input_embedding, 0,
                                   activations.x.data(),
                                   activations.logits.data(), pool);
@@ -1186,10 +1200,10 @@ void Generate2B(GemmaImpl<ConfigGemma2B>& gemma, size_t max_tokens,
                 KVCache& kv_cache, hwy::ThreadPool& pool,
                 hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
                 const AcceptFunc& accept_token, std::mt19937& gen,
-                int verbosity) {
+                int verbosity, LayersOutputT* layers_output) {
   GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
                start_pos, kv_cache, pool, inner_pool, stream_token,
-               accept_token, gen, verbosity);
+               accept_token, gen, verbosity, layers_output);
 }
 
 void Generate7B(GemmaImpl<ConfigGemma7B>& gemma, size_t max_tokens,
@@ -1198,10 +1212,10 @@ void Generate7B(GemmaImpl<ConfigGemma7B>& gemma, size_t max_tokens,
                 KVCache& kv_cache, hwy::ThreadPool& pool,
                 hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
                 const AcceptFunc& accept_token, std::mt19937& gen,
-                int verbosity) {
+                int verbosity, LayersOutputT* layers_output) {
   GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
                start_pos, kv_cache, pool, inner_pool, stream_token,
-               accept_token, gen, verbosity);
+               accept_token, gen, verbosity, layers_output);
 }
 
 void GenerateGriffin2B(GemmaImpl<ConfigGriffin2B>& gemma, size_t max_tokens,
@@ -1211,10 +1225,10 @@ void GenerateGriffin2B(GemmaImpl<ConfigGriffin2B>& gemma, size_t max_tokens,
                        hwy::ThreadPool& inner_pool,
                        const StreamFunc& stream_token,
                        const AcceptFunc& accept_token, std::mt19937& gen,
-                       int verbosity) {
+                       int verbosity, LayersOutputT* layers_output) {
   GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
                start_pos, kv_cache, pool, inner_pool, stream_token,
-               accept_token, gen, verbosity);
+               accept_token, gen, verbosity, layers_output);
 }
 
 float ComputeCrossEntropy2B(GemmaImpl<ConfigGemma2B>& gemma, size_t max_tokens,
@@ -1478,10 +1492,11 @@ void GemmaImpl<ConfigGemma2B>::Generate(
     const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache,
     hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
     const StreamFunc& stream_token, const AcceptFunc& accept_token,
-    std::mt19937& gen, int verbosity) {
+    std::mt19937& gen, int verbosity, LayersOutputT* layers_output) {
   HWY_DYNAMIC_DISPATCH(Generate2B)
   (*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
-   kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity);
+   kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity,
+   layers_output);
 }
 
 template <>
@@ -1490,10 +1505,11 @@ void GemmaImpl<ConfigGemma7B>::Generate(
     const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache,
     hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
     const StreamFunc& stream_token, const AcceptFunc& accept_token,
-    std::mt19937& gen, int verbosity) {
+    std::mt19937& gen, int verbosity, LayersOutputT* layers_output) {
   HWY_DYNAMIC_DISPATCH(Generate7B)
   (*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
-   kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity);
+   kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity,
+   layers_output);
 }
 
 template <>
@@ -1502,10 +1518,11 @@ void GemmaImpl<ConfigGriffin2B>::Generate(
     const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache,
     hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
     const StreamFunc& stream_token, const AcceptFunc& accept_token,
-    std::mt19937& gen, int verbosity) {
+    std::mt19937& gen, int verbosity, LayersOutputT* layers_output) {
   HWY_DYNAMIC_DISPATCH(GenerateGriffin2B)
   (*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
-   kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity);
+   kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity,
+   layers_output);
 }
 
 template <>
@@ -1575,11 +1592,11 @@ void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
                    size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
                    hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
                    const AcceptFunc& accept_token, std::mt19937& gen,
-                   int verbosity) {
+                   int verbosity, LayersOutputT* layers_output) {
   pool.SetWaitMode(hwy::PoolWaitMode::kSpin);
   gemma.impl_->Generate(max_tokens, max_generated_tokens, temperature, prompt,
                         start_pos, kv_cache, pool, inner_pool, stream_token,
-                        accept_token, gen, verbosity);
+                        accept_token, gen, verbosity, layers_output);
   pool.SetWaitMode(hwy::PoolWaitMode::kBlock);
 }
 
@@ -1591,7 +1608,8 @@ void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config,
   GenerateGemma(
       gemma, runtime_config.max_tokens, runtime_config.max_generated_tokens,
       runtime_config.temperature, prompt, start_pos, kv_cache, pool, inner_pool,
-      stream_token, [](int) { return true; }, gen, runtime_config.verbosity);
+      stream_token, [](int) { return true; }, gen, runtime_config.verbosity,
+      nullptr);
 }
 
 void CompressWeights(gcpp::Model model, const Path& weights,


@@ -32,6 +32,7 @@ namespace gcpp {
 using GemmaWeightT = GEMMA_WEIGHT_T;
 using EmbedderInputT = hwy::bfloat16_t;
+using LayersOutputT = std::function<void(int, std::string, const float*, size_t)>;
 
 constexpr size_t kPrefillBatchSize = 16;
 constexpr bool kSystemPrompt = false;
 
@@ -97,7 +98,7 @@ void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
                    size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
                    hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
                    const AcceptFunc& accept_token, std::mt19937& gen,
-                   int verbosity);
+                   int verbosity, LayersOutputT* layers_output = nullptr);
 
 // Convenience function for the common case:
 // - Bundle runtime parameters as RuntimeConfig
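
Because LayersOutputT is a std::function rather than a fixed logger, hooks can aggregate instead of recording raw vectors. A hypothetical variant, not in this commit, that logs the L2 norm of each reported vector:

#include <cmath>
#include <cstddef>
#include <cstdio>
#include <string>

#include "gemma/gemma.h"

// Aggregating hook: one scalar per reported key instead of kModelDim floats.
gcpp::LayersOutputT norm_hook = [](int pos, std::string key,
                                   const float* values, size_t len) {
  double sum_sq = 0.0;  // accumulate in double for stability
  for (size_t i = 0; i < len; ++i) sum_sq += values[i] * values[i];
  std::printf("pos=%d %s l2=%f\n", pos, key.c_str(), std::sqrt(sum_sq));
};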