From 03284d752e4e996dc4928eba45f44a93154fce65 Mon Sep 17 00:00:00 2001
From: Andrey Mikhaylov
Date: Fri, 12 Apr 2024 11:22:42 +0000
Subject: [PATCH] Added layers output functionality to gemma and a binary
 debug_prompt to save the outputs to a json file.

---
 BUILD.bazel     |  19 +++++++
 CMakeLists.txt  |   3 ++
 debug_prompt.cc | 135 ++++++++++++++++++++++++++++++++++++++++++++++++
 gemma/gemma.cc  |  66 ++++++++++++++---------
 gemma/gemma.h   |   3 +-
 5 files changed, 201 insertions(+), 25 deletions(-)
 create mode 100644 debug_prompt.cc

diff --git a/BUILD.bazel b/BUILD.bazel
index 68a2026..c1ca793 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -158,3 +158,22 @@ cc_binary(
         "@nlohmann_json//:json",
     ],
 )
+
+cc_binary(
+    name = "debug_prompt",
+    srcs = [
+        "debug_prompt.cc",
+    ],
+    deps = [
+        ":app",
+        ":args",
+        ":gemma_lib",
+        # "//base",
+        "//compression:compress",
+        "@hwy//:hwy",
+        "@hwy//:nanobenchmark",
+        "@hwy//:profiler",
+        "@hwy//:thread_pool",
+        "@nlohmann_json//:json",
+    ],
+)
\ No newline at end of file
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 203b6c1..f1c7104 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -88,6 +88,9 @@ target_link_libraries(gemma libgemma hwy hwy_contrib)
 add_executable(benchmark gemma/benchmark.cc)
 target_link_libraries(benchmark libgemma hwy hwy_contrib nlohmann_json::nlohmann_json)
 
+add_executable(debug_prompt debug_prompt.cc)
+target_link_libraries(debug_prompt libgemma hwy hwy_contrib nlohmann_json::nlohmann_json)
+
 ## Tests
 set(GEMMA_ENABLE_TESTS OFF CACHE BOOL "Enable Gemma tests")
 if (GEMMA_ENABLE_TESTS)
diff --git a/debug_prompt.cc b/debug_prompt.cc
new file mode 100644
index 0000000..ae4f52d
--- /dev/null
+++ b/debug_prompt.cc
@@ -0,0 +1,135 @@
+
+#include <fstream>
+#include <iostream>
+#include <string>
+
+#include "gemma/gemma.h"
+#include "nlohmann/json.hpp"
+// copybara:import_next_line:gemma_cpp
+#include "util/app.h"
+// copybara:import_next_line:gemma_cpp
+#include "util/args.h"
+
+using json = nlohmann::json;
+
+class PromptArgs : public gcpp::ArgsBase<PromptArgs> {
+ public:
+  PromptArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
+
+  gcpp::Path layers_output;
+  std::string prompt;
+
+  template <class Visitor>
+  void ForEach(const Visitor& visitor) {
+    visitor(layers_output.path, "layers_output", std::string(""),
+            "Path to store layers output", 2);
+    visitor(prompt, "prompt", std::string(""), "Prompt to the model", 2);
+  }
+};
+
+std::pair<std::string, size_t> QueryModel(
+    gcpp::Gemma& model, gcpp::InferenceArgs& args, gcpp::AppArgs& app,
+    gcpp::KVCache& kv_cache, hwy::ThreadPool& inner_pool, hwy::ThreadPool& pool,
+    const std::string& input, gcpp::LayersOutputT* layers_output) {
+  std::vector<int> prompt;
+  HWY_ASSERT(model.Tokenizer()->Encode(input, &prompt));
+
+  // For both pre-trained and instruction-tuned models: prepend "<bos>" token
+  // if needed.
+  prompt.insert(prompt.begin(), 2);
+  std::string res;
+  size_t total_tokens = 0;
+  auto accept_token = [](int) { return true; };
+  std::mt19937 gen;
+  gen.seed(42);
+
+  auto stream_token = [&res, &total_tokens, &app,
+                       tokenizer = model.Tokenizer()](int token, float) {
+    ++total_tokens;
+    std::string token_text;
+    HWY_ASSERT(tokenizer->Decode(std::vector<int>{token}, &token_text));
+    res += token_text;
+    return true;
+  };
+  if (app.verbosity >= 2) {
+    std::cout << args.max_tokens << " " << args.max_generated_tokens << " "
+              << args.temperature;
+  }
+  GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
+                args.temperature, prompt, /*abs_pos=*/0, kv_cache, pool,
+                inner_pool, stream_token, accept_token, gen, app.verbosity,
+                layers_output);
+  return {res, total_tokens};
+}
+
+class OutputJsonLogger {
+ public:
+  json json_output;
+
+  gcpp::LayersOutputT layers_output_log_f =
+      [this](int pos, std::string key, const float* values, size_t values_len) {
+        std::vector<float> v{values, values + values_len};
+        json_output[std::to_string(pos)][key] = v;
+      };
+};
+
+/* Run this in the same way as gemma, e.g.:
+  ./debug_prompt --tokenizer tokenizer.spm --model 2b-it --weights \
+    2b-it-sfp.sbs --prompt "..." --layers_output [path]
+*/
+int main(int argc, char** argv) {
+  gcpp::LoaderArgs loader(argc, argv);
+  gcpp::InferenceArgs args(argc, argv);  // inference
+  gcpp::AppArgs app(argc, argv);
+  PromptArgs prompt_args(argc, argv);
+
+  if (const char* error = loader.Validate()) {
+    HWY_ABORT("\nInvalid loader args: %s", error);
+  }
+  if (const char* error = args.Validate()) {
+    HWY_ABORT("\nInvalid inference args: %s", error);
+  }
+  const bool log_layers_output = !prompt_args.layers_output.path.empty();
+  OutputJsonLogger json_logger;
+  gcpp::LayersOutputT* layers_output =
+      log_layers_output ? &json_logger.layers_output_log_f : nullptr;
+
+  hwy::ThreadPool inner_pool(0);
+  hwy::ThreadPool pool(app.num_threads);
+  // For many-core, pinning threads to cores helps.
+  if (app.num_threads > 10) {
+    gcpp::PinThreadToCore(app.num_threads - 1);  // Main thread
+
+    pool.Run(0, pool.NumThreads(), [](uint64_t /*task*/, size_t thread) {
+      gcpp::PinThreadToCore(thread);
+    });
+  }
+
+  gcpp::Gemma model(loader.tokenizer, loader.weights, loader.ModelType(), pool);
+  auto kv_cache = CreateKVCache(loader.ModelType());
+
+  const std::string& prompt = prompt_args.prompt;
+  if (prompt.empty()) {
+    std::cout << "Please specify --prompt" << std::endl;
+    return EXIT_FAILURE;
+  }
+  const auto [answer, token_count] = QueryModel(
+      model, args, app, kv_cache, inner_pool, pool, prompt, layers_output);
+  std::cout << answer.substr(prompt.size()) << "\n" << std::flush;
+
+  if (log_layers_output) {
+    std::ofstream output_f(prompt_args.layers_output.path, std::ofstream::out);
+    if (!output_f) {
+      std::cout << "Opening file failed" << std::endl;
+      return EXIT_FAILURE;
+    }
+    output_f << json_logger.json_output.dump();
+    if (!output_f) {
+      std::cout << "Writing to file failed" << std::endl;
+      return EXIT_FAILURE;
+    }
+    output_f.close();
+  }
+
+  return EXIT_SUCCESS;
+}
\ No newline at end of file
diff --git a/gemma/gemma.cc b/gemma/gemma.cc
index b129443..9cfaf1b 100644
--- a/gemma/gemma.cc
+++ b/gemma/gemma.cc
@@ -442,7 +442,7 @@ struct GemmaInterface {
                         hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
                         const StreamFunc& stream_token,
                         const AcceptFunc& accept_token, std::mt19937& gen,
-                        int verbosity) = 0;
+                        int verbosity, LayersOutputT* layers_output) = 0;
 
   virtual float ComputeCrossEntropy(size_t max_tokens,
                                     const std::vector<int>& prompt,
@@ -535,8 +535,8 @@ struct GemmaImpl : public GemmaInterface {
                 float temperature, const std::vector<int>& prompt,
                 size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
                 hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
-                const AcceptFunc& accept_token, std::mt19937&,
-                int verbosity) override;
+                const AcceptFunc& accept_token, std::mt19937&, int verbosity,
+                LayersOutputT* layers_output) override;
 
   float ComputeCrossEntropy(size_t max_tokens, const std::vector<int>& prompt,
                             KVCache& kv_cache, hwy::ThreadPool& pool,
@@ -935,7 +935,12 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
 template <class TConfig, typename WeightArrayT>
 void Transformer(int token, size_t pos, const WeightArrayT& weights,
                  Activations<TConfig, 1>& activations, KVCache& kv_cache,
-                 hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool) {
+                 hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
+                 LayersOutputT* layers_output) {
+  if (layers_output != nullptr) {
+    float token_f = token;
+    (*layers_output)(pos, "Tokens", &token_f, 1);
+  }
   static constexpr size_t kModelDim = TConfig::kModelDim;
   Decompress(weights.embedder_input_embedding, token * kModelDim,
              activations.x.data(), kModelDim);
@@ -943,7 +948,6 @@ void Transformer(int token, size_t pos, const WeightArrayT& weights,
   GEMMA_CONSTEXPR_EMBSCALING const float kEmbScaling =
       EmbeddingScaling<TConfig>();
   MulByConst(kEmbScaling, activations.x.data(), kModelDim);
-
   for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
     auto type = TConfig::kLayerConfig[layer];
     const auto* layer_weights = weights.GetLayer(layer);
@@ -964,9 +968,16 @@
                    activations.bf_pre_ffw_rms_out.data(), kModelDim);
     FFW<1>(activations, /* batch_idx = */ 0, layer_weights, pool);
     AddFrom(activations.ffw_out.data(), activations.x.data(), kModelDim);
+    if (layers_output != nullptr) {
+      std::string block_name = "blocks." + std::to_string(layer);
+      (*layers_output)(pos, block_name, activations.x.data(), kModelDim);
+    }
   }
   RMSNormInplace(weights.final_norm_scale.data(), activations.x.data(),
                  kModelDim);
+  if (layers_output != nullptr) {
+    (*layers_output)(pos, "final_norm", activations.x.data(), kModelDim);
+  }
 }
 
 template <class TConfig>
@@ -1005,7 +1016,7 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
                   hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
                   const StreamFunc& stream_token,
                   const AcceptFunc& accept_token, std::mt19937& gen,
-                  int verbosity) {
+                  int verbosity, LayersOutputT* layers_output) {
   static constexpr size_t kVocabSize = TConfig::kVocabSize;
   Activations<TConfig, 1>& activations = *gemma.state.get();
   Activations<TConfig, kPrefillBatchSize>& prefill_activations =
@@ -1072,12 +1083,14 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
   for (size_t generate_pos = 0;
        pos < max_tokens && generate_pos < max_generated_tokens;
        ++pos, ++pos_offset, ++generate_pos) {
-    Transformer(token, pos, weights, activations, kv_cache, pool, inner_pool);
+    const bool is_generating_phase = pos_offset >= prompt_size - 1;
+    Transformer(token, pos, weights, activations, kv_cache, pool, inner_pool,
+                layers_output);
     float* final_activation = activations.x.data();
     // The condition below is always true if we are doing Prefill above.
     // We keep it here for clarity so that the code is correct even if Prefill
     // is disabled.
-    if (pos_offset >= prompt_size - 1) {
+    if (is_generating_phase) {
       PROFILER_ZONE("Gen.Embedding");
       // Generation phase
       MatVec<kVocabSize, kModelDim>(weights.embedder_input_embedding,
@@ -1166,7 +1179,8 @@ float ComputeCrossEntropyImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
       printf("Processed %zu tokens, cross-entropy per token: %f\n", pos + 1,
              total_entropy / std::log(2.0) / (pos + 1));
     }
-    Transformer(token, pos, weights, activations, kv_cache, pool, inner_pool);
+    Transformer(token, pos, weights, activations, kv_cache, pool, inner_pool,
+                nullptr);
     MatVec<kVocabSize, kModelDim>(weights.embedder_input_embedding, 0,
                                   activations.x.data(),
                                   activations.logits.data(), pool);
@@ -1186,10 +1200,10 @@ void Generate2B(GemmaImpl<ConfigGemma2B>& gemma, size_t max_tokens,
                 KVCache& kv_cache, hwy::ThreadPool& pool,
                 hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
                 const AcceptFunc& accept_token, std::mt19937& gen,
-                int verbosity) {
+                int verbosity, LayersOutputT* layers_output) {
   GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
                start_pos, kv_cache, pool, inner_pool, stream_token,
-               accept_token, gen, verbosity);
+               accept_token, gen, verbosity, layers_output);
 }
 
 void Generate7B(GemmaImpl<ConfigGemma7B>& gemma, size_t max_tokens,
@@ -1198,10 +1212,10 @@ void Generate7B(GemmaImpl<ConfigGemma7B>& gemma, size_t max_tokens,
                 KVCache& kv_cache, hwy::ThreadPool& pool,
                 hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
                 const AcceptFunc& accept_token, std::mt19937& gen,
-                int verbosity) {
+                int verbosity, LayersOutputT* layers_output) {
   GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
                start_pos, kv_cache, pool, inner_pool, stream_token,
-               accept_token, gen, verbosity);
+               accept_token, gen, verbosity, layers_output);
 }
 
 void GenerateGriffin2B(GemmaImpl<ConfigGriffin2B>& gemma, size_t max_tokens,
@@ -1211,10 +1225,10 @@ void GenerateGriffin2B(GemmaImpl<ConfigGriffin2B>& gemma, size_t max_tokens,
                        KVCache& kv_cache, hwy::ThreadPool& pool,
                        hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
                        const AcceptFunc& accept_token, std::mt19937& gen,
-                       int verbosity) {
+                       int verbosity, LayersOutputT* layers_output) {
   GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
                start_pos, kv_cache, pool, inner_pool, stream_token,
-               accept_token, gen, verbosity);
+               accept_token, gen, verbosity, layers_output);
 }
 
 float ComputeCrossEntropy2B(GemmaImpl<ConfigGemma2B>& gemma, size_t max_tokens,
@@ -1478,10 +1492,11 @@ void GemmaImpl<ConfigGemma2B>::Generate(
     const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache,
     hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
     const StreamFunc& stream_token, const AcceptFunc& accept_token,
-    std::mt19937& gen, int verbosity) {
+    std::mt19937& gen, int verbosity, LayersOutputT* layers_output) {
   HWY_DYNAMIC_DISPATCH(Generate2B)
   (*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
-   kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity);
+   kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity,
+   layers_output);
 }
 
 template <>
@@ -1490,10 +1505,11 @@ void GemmaImpl<ConfigGemma7B>::Generate(
     const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache,
     hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
    const StreamFunc& stream_token, const AcceptFunc& accept_token,
-    std::mt19937& gen, int verbosity) {
+    std::mt19937& gen, int verbosity, LayersOutputT* layers_output) {
   HWY_DYNAMIC_DISPATCH(Generate7B)
   (*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
-   kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity);
+   kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity,
+   layers_output);
 }
 
 template <>
@@ -1502,10 +1518,11 @@ void GemmaImpl<ConfigGriffin2B>::Generate(
     const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache,
     hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
     const StreamFunc& stream_token, const AcceptFunc& accept_token,
-    std::mt19937& gen, int verbosity) {
+    std::mt19937& gen, int verbosity, LayersOutputT* layers_output) {
   HWY_DYNAMIC_DISPATCH(GenerateGriffin2B)
   (*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
-   kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity);
+   kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity,
+   layers_output);
 }
 
 template <>
@@ -1575,11 +1592,11 @@ void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
                    size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
                    hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
                    const AcceptFunc& accept_token, std::mt19937& gen,
-                   int verbosity) {
+                   int verbosity, LayersOutputT* layers_output) {
   pool.SetWaitMode(hwy::PoolWaitMode::kSpin);
   gemma.impl_->Generate(max_tokens, max_generated_tokens, temperature, prompt,
                         start_pos, kv_cache, pool, inner_pool, stream_token,
-                        accept_token, gen, verbosity);
+                        accept_token, gen, verbosity, layers_output);
   pool.SetWaitMode(hwy::PoolWaitMode::kBlock);
 }
 
@@ -1591,7 +1608,8 @@ void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config,
   GenerateGemma(
       gemma, runtime_config.max_tokens, runtime_config.max_generated_tokens,
      runtime_config.temperature, prompt, start_pos, kv_cache, pool, inner_pool,
-      stream_token, [](int) { return true; }, gen, runtime_config.verbosity);
+      stream_token, [](int) { return true; }, gen, runtime_config.verbosity,
+      nullptr);
 }
 
 void CompressWeights(gcpp::Model model, const Path& weights,
diff --git a/gemma/gemma.h b/gemma/gemma.h
index d2674e1..7128440 100644
--- a/gemma/gemma.h
+++ b/gemma/gemma.h
@@ -32,6 +32,7 @@ namespace gcpp {
 
 using GemmaWeightT = GEMMA_WEIGHT_T;
 using EmbedderInputT = hwy::bfloat16_t;
+using LayersOutputT = std::function<void(int, std::string, const float*, size_t)>;
 
 constexpr size_t kPrefillBatchSize = 16;
 constexpr bool kSystemPrompt = false;
@@ -97,7 +98,7 @@ void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
                    size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
                    hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
                    const AcceptFunc& accept_token, std::mt19937& gen,
-                   int verbosity);
+                   int verbosity, LayersOutputT* layers_output = nullptr);
 
 // Convenience function for the common case:
 // - Bundle runtime parameters as RuntimeConfig
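
Usage sketch (illustrative, not part of the patch): any callable matching the
LayersOutputT signature added in gemma.h can be passed through the new optional
parameter of GenerateGemma. Assuming that signature, a hypothetical callback
print_norms that reports the L2 norm of each logged vector instead of
accumulating JSON:

#include <cmath>
#include <iostream>
#include <string>

#include "gemma/gemma.h"

// Prints the L2 norm of every vector the instrumented Transformer reports:
// the per-position "Tokens" value, each "blocks.<layer>" activation, and
// "final_norm".
gcpp::LayersOutputT print_norms = [](int pos, std::string key,
                                     const float* values, size_t len) {
  double sum = 0.0;
  for (size_t i = 0; i < len; ++i) {
    sum += static_cast<double>(values[i]) * values[i];
  }
  std::cout << "pos " << pos << " " << key << " |x| = " << std::sqrt(sum)
            << "\n";
};

Passed as the trailing argument (e.g. ..., verbosity, &print_norms), it
receives exactly the calls that OutputJsonLogger::layers_output_log_f receives
in debug_prompt.cc.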
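
A minimal standalone reader for the JSON file that debug_prompt writes (a
sketch assuming only nlohmann/json, which the patch already depends on; per
OutputJsonLogger the layout is token position -> key -> array of floats, and
the read_layers_output binary name is hypothetical):

#include <fstream>
#include <iostream>

#include "nlohmann/json.hpp"

// Usage: ./read_layers_output layers_output.json
int main(int argc, char** argv) {
  if (argc != 2) {
    std::cerr << "usage: " << argv[0] << " layers_output.json\n";
    return 1;
  }
  std::ifstream f(argv[1]);
  const nlohmann::json j = nlohmann::json::parse(f);
  // Top-level keys are token positions ("0", "1", ...); nested keys are
  // "Tokens", "blocks.<layer>" and "final_norm", each an array of floats.
  for (const auto& [pos, entries] : j.items()) {
    for (const auto& [key, values] : entries.items()) {
      std::cout << "pos " << pos << " " << key << ": " << values.size()
                << " values\n";
    }
  }
  return 0;
}

Build it the same way debug_prompt is wired up in CMakeLists.txt or
BUILD.bazel above.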