mirror of https://github.com/google/gemma.cpp.git
Added layers output functionality to gemma and a binary debug_output to save the outputs to a json file.
This commit is contained in:
parent
342e998cb6
commit
03284d752e
19
BUILD.bazel
19
BUILD.bazel
|
|
@ -158,3 +158,22 @@ cc_binary(
|
|||
"@nlohmann_json//:json",
|
||||
],
|
||||
)
|
||||
|
||||
cc_binary(
|
||||
name = "debug_prompt",
|
||||
srcs = [
|
||||
"debug_prompt.cc",
|
||||
],
|
||||
deps = [
|
||||
":app",
|
||||
":args",
|
||||
":gemma_lib",
|
||||
# "//base",
|
||||
"//compression:compress",
|
||||
"@hwy//:hwy",
|
||||
"@hwy//:nanobenchmark",
|
||||
"@hwy//:profiler",
|
||||
"@hwy//:thread_pool",
|
||||
"@nlohmann_json//:json",
|
||||
],
|
||||
)
|
||||
|
|
@ -88,6 +88,9 @@ target_link_libraries(gemma libgemma hwy hwy_contrib)
|
|||
add_executable(benchmark gemma/benchmark.cc)
|
||||
target_link_libraries(benchmark libgemma hwy hwy_contrib nlohmann_json::nlohmann_json)
|
||||
|
||||
add_executable(debug_prompt debug_prompt.cc)
|
||||
target_link_libraries(debug_prompt libgemma hwy hwy_contrib nlohmann_json::nlohmann_json)
|
||||
|
||||
## Tests
|
||||
set(GEMMA_ENABLE_TESTS OFF CACHE BOOL "Enable Gemma tests")
|
||||
if (GEMMA_ENABLE_TESTS)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,135 @@
|
|||
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
#include "gemma/gemma.h"
|
||||
#include "nlohmann/json.hpp"
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "util/app.h"
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "util/args.h"
|
||||
|
||||
using json = nlohmann::json;
|
||||
|
||||
class PromptArgs : public gcpp::ArgsBase<PromptArgs> {
|
||||
public:
|
||||
PromptArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
|
||||
|
||||
gcpp::Path layers_output;
|
||||
std::string prompt;
|
||||
|
||||
template <class Visitor>
|
||||
void ForEach(const Visitor& visitor) {
|
||||
visitor(layers_output.path, "layers_output", std::string(""),
|
||||
"Path to store layers output", 2);
|
||||
visitor(prompt, "prompt", std::string(""), "Prompt to the model", 2);
|
||||
}
|
||||
};
|
||||
|
||||
std::pair<std::string, int> QueryModel(
|
||||
gcpp::Gemma& model, gcpp::InferenceArgs& args, gcpp::AppArgs& app,
|
||||
gcpp::KVCache& kv_cache, hwy::ThreadPool& inner_pool, hwy::ThreadPool& pool,
|
||||
const std::string& input, gcpp::LayersOutputT* layers_output) {
|
||||
std::vector<int> prompt;
|
||||
HWY_ASSERT(model.Tokenizer()->Encode(input, &prompt));
|
||||
|
||||
// For both pre-trained and instruction-tuned models: prepend "<bos>" token
|
||||
// if needed.
|
||||
prompt.insert(prompt.begin(), 2);
|
||||
std::string res;
|
||||
size_t total_tokens = 0;
|
||||
auto accept_token = [](int) { return true; };
|
||||
std::mt19937 gen;
|
||||
gen.seed(42);
|
||||
|
||||
auto stream_token = [&res, &total_tokens, &app,
|
||||
tokenizer = model.Tokenizer()](int token, float) {
|
||||
++total_tokens;
|
||||
std::string token_text;
|
||||
HWY_ASSERT(tokenizer->Decode(std::vector<int>{token}, &token_text));
|
||||
res += token_text;
|
||||
return true;
|
||||
};
|
||||
if (app.verbosity >= 2) {
|
||||
std::cout << args.max_tokens << " " << args.max_generated_tokens << " "
|
||||
<< args.temperature;
|
||||
}
|
||||
GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
|
||||
args.temperature, prompt, /*abs_pos=*/0, kv_cache, pool,
|
||||
inner_pool, stream_token, accept_token, gen, app.verbosity,
|
||||
layers_output);
|
||||
return {res, total_tokens};
|
||||
}
|
||||
|
||||
class OutputJsonLogger {
|
||||
public:
|
||||
json json_output;
|
||||
|
||||
gcpp::LayersOutputT layers_output_log_f =
|
||||
[this](int pos, std::string key, const float* values, size_t values_len) {
|
||||
std::vector<float> v{values, values + values_len};
|
||||
json_output[std::to_string(pos)][key] = v;
|
||||
};
|
||||
};
|
||||
|
||||
/* Run this in the same way as gemma, p.ex.:
|
||||
./debug_prompt --tokenizer tokenizer.spm --model 2b-it --weights \
|
||||
2b-it-sfp.sbs --prompt "..." --layers_output [path]
|
||||
*/
|
||||
int main(int argc, char** argv) {
|
||||
gcpp::LoaderArgs loader(argc, argv);
|
||||
gcpp::InferenceArgs args(argc, argv); // inference
|
||||
gcpp::AppArgs app(argc, argv);
|
||||
PromptArgs prompt_args(argc, argv);
|
||||
|
||||
if (const char* error = loader.Validate()) {
|
||||
HWY_ABORT("\nInvalid loader args: %s", error);
|
||||
}
|
||||
if (const char* error = args.Validate()) {
|
||||
HWY_ABORT("\nInvalid inference args: %s", error);
|
||||
}
|
||||
const bool log_layers_output = !prompt_args.layers_output.path.empty();
|
||||
OutputJsonLogger json_logger;
|
||||
gcpp::LayersOutputT* layers_output =
|
||||
log_layers_output ? &json_logger.layers_output_log_f : nullptr;
|
||||
|
||||
hwy::ThreadPool inner_pool(0);
|
||||
hwy::ThreadPool pool(app.num_threads);
|
||||
// For many-core, pinning threads to cores helps.
|
||||
if (app.num_threads > 10) {
|
||||
gcpp::PinThreadToCore(app.num_threads - 1); // Main thread
|
||||
|
||||
pool.Run(0, pool.NumThreads(), [](uint64_t /*task*/, size_t thread) {
|
||||
gcpp::PinThreadToCore(thread);
|
||||
});
|
||||
}
|
||||
|
||||
gcpp::Gemma model(loader.tokenizer, loader.weights, loader.ModelType(), pool);
|
||||
auto kv_cache = CreateKVCache(loader.ModelType());
|
||||
|
||||
const std::string& prompt = prompt_args.prompt;
|
||||
if (prompt.empty()) {
|
||||
std::cout << "Please specify --prompt" << std::endl;
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
const auto [answer, token_count] = QueryModel(
|
||||
model, args, app, kv_cache, inner_pool, pool, prompt, layers_output);
|
||||
std::cout << answer.substr(prompt.size()) << "\n" << std::flush;
|
||||
|
||||
if (log_layers_output) {
|
||||
std::ofstream output_f(prompt_args.layers_output.path, std::ofstream::out);
|
||||
if (!output_f) {
|
||||
std::cout << "Opening file failed" << std::endl;
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
output_f << json_logger.json_output.dump();
|
||||
if (!output_f) {
|
||||
std::cout << "Writing to file failed" << std::endl;
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
output_f.close();
|
||||
}
|
||||
|
||||
return EXIT_SUCCESS;
|
||||
}
|
||||
|
|
@ -442,7 +442,7 @@ struct GemmaInterface {
|
|||
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
|
||||
const StreamFunc& stream_token,
|
||||
const AcceptFunc& accept_token, std::mt19937& gen,
|
||||
int verbosity) = 0;
|
||||
int verbosity, LayersOutputT* layers_output) = 0;
|
||||
|
||||
virtual float ComputeCrossEntropy(size_t max_tokens,
|
||||
const std::vector<int>& prompt,
|
||||
|
|
@ -535,8 +535,8 @@ struct GemmaImpl : public GemmaInterface {
|
|||
float temperature, const std::vector<int>& prompt,
|
||||
size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||
hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
|
||||
const AcceptFunc& accept_token, std::mt19937&,
|
||||
int verbosity) override;
|
||||
const AcceptFunc& accept_token, std::mt19937&, int verbosity,
|
||||
LayersOutputT* layers_output) override;
|
||||
|
||||
float ComputeCrossEntropy(size_t max_tokens, const std::vector<int>& prompt,
|
||||
KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||
|
|
@ -935,7 +935,12 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
|
|||
template <typename WeightArrayT, class TConfig>
|
||||
void Transformer(int token, size_t pos, const WeightArrayT& weights,
|
||||
Activations<TConfig, 1>& activations, KVCache& kv_cache,
|
||||
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool) {
|
||||
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
|
||||
LayersOutputT* layers_output) {
|
||||
if (layers_output != nullptr) {
|
||||
float token_f = token;
|
||||
(*layers_output)(pos, "Tokens", &token_f, 1);
|
||||
}
|
||||
static constexpr size_t kModelDim = TConfig::kModelDim;
|
||||
Decompress(weights.embedder_input_embedding, token * kModelDim,
|
||||
activations.x.data(), kModelDim);
|
||||
|
|
@ -943,7 +948,6 @@ void Transformer(int token, size_t pos, const WeightArrayT& weights,
|
|||
GEMMA_CONSTEXPR_EMBSCALING const float kEmbScaling =
|
||||
EmbeddingScaling<TConfig>();
|
||||
MulByConst(kEmbScaling, activations.x.data(), kModelDim);
|
||||
|
||||
for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
|
||||
auto type = TConfig::kLayerConfig[layer];
|
||||
const auto* layer_weights = weights.GetLayer(layer);
|
||||
|
|
@ -964,9 +968,16 @@ void Transformer(int token, size_t pos, const WeightArrayT& weights,
|
|||
activations.bf_pre_ffw_rms_out.data(), kModelDim);
|
||||
FFW<1>(activations, /* batch_idx = */ 0, layer_weights, pool);
|
||||
AddFrom(activations.ffw_out.data(), activations.x.data(), kModelDim);
|
||||
if (layers_output != nullptr) {
|
||||
std::string block_name = "blocks." + std::to_string(layer);
|
||||
(*layers_output)(pos, block_name, activations.x.data(), kModelDim);
|
||||
}
|
||||
}
|
||||
RMSNormInplace(weights.final_norm_scale.data(), activations.x.data(),
|
||||
kModelDim);
|
||||
if (layers_output != nullptr) {
|
||||
(*layers_output)(pos, "final_norm", activations.x.data(), kModelDim);
|
||||
}
|
||||
}
|
||||
|
||||
template <class TConfig>
|
||||
|
|
@ -1005,7 +1016,7 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
|
|||
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
|
||||
const StreamFunc& stream_token,
|
||||
const AcceptFunc& accept_token, std::mt19937& gen,
|
||||
int verbosity) {
|
||||
int verbosity, LayersOutputT* layers_output) {
|
||||
static constexpr size_t kVocabSize = TConfig::kVocabSize;
|
||||
Activations<TConfig, 1>& activations = *gemma.state.get();
|
||||
Activations<TConfig, kPrefillBatchSize>& prefill_activations =
|
||||
|
|
@ -1072,12 +1083,14 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
|
|||
for (size_t generate_pos = 0;
|
||||
pos < max_tokens && generate_pos < max_generated_tokens;
|
||||
++pos, ++pos_offset, ++generate_pos) {
|
||||
Transformer(token, pos, weights, activations, kv_cache, pool, inner_pool);
|
||||
const bool is_generating_phase = pos_offset >= prompt_size - 1;
|
||||
Transformer(token, pos, weights, activations, kv_cache, pool, inner_pool,
|
||||
layers_output);
|
||||
float* final_activation = activations.x.data();
|
||||
// The condition below is always true if we are doing Prefill above.
|
||||
// We keep it here for clarity so that the code is correct even if Prefill
|
||||
// is disabled.
|
||||
if (pos_offset >= prompt_size - 1) {
|
||||
if (is_generating_phase) {
|
||||
PROFILER_ZONE("Gen.Embedding");
|
||||
// Generation phase
|
||||
MatVec<kVocabSize, TConfig::kModelDim>(weights.embedder_input_embedding,
|
||||
|
|
@ -1166,7 +1179,8 @@ float ComputeCrossEntropyImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
|
|||
printf("Processed %zu tokens, cross-entropy per token: %f\n", pos + 1,
|
||||
total_entropy / std::log(2.0) / (pos + 1));
|
||||
}
|
||||
Transformer(token, pos, weights, activations, kv_cache, pool, inner_pool);
|
||||
Transformer(token, pos, weights, activations, kv_cache, pool, inner_pool,
|
||||
nullptr);
|
||||
MatVec<kVocabSize, kModelDim>(weights.embedder_input_embedding, 0,
|
||||
activations.x.data(),
|
||||
activations.logits.data(), pool);
|
||||
|
|
@ -1186,10 +1200,10 @@ void Generate2B(GemmaImpl<ConfigGemma2B>& gemma, size_t max_tokens,
|
|||
KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||
hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
|
||||
const AcceptFunc& accept_token, std::mt19937& gen,
|
||||
int verbosity) {
|
||||
int verbosity, LayersOutputT* layers_output) {
|
||||
GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
|
||||
start_pos, kv_cache, pool, inner_pool, stream_token,
|
||||
accept_token, gen, verbosity);
|
||||
accept_token, gen, verbosity, layers_output);
|
||||
}
|
||||
|
||||
void Generate7B(GemmaImpl<ConfigGemma7B>& gemma, size_t max_tokens,
|
||||
|
|
@ -1198,10 +1212,10 @@ void Generate7B(GemmaImpl<ConfigGemma7B>& gemma, size_t max_tokens,
|
|||
KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||
hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
|
||||
const AcceptFunc& accept_token, std::mt19937& gen,
|
||||
int verbosity) {
|
||||
int verbosity, LayersOutputT* layers_output) {
|
||||
GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
|
||||
start_pos, kv_cache, pool, inner_pool, stream_token,
|
||||
accept_token, gen, verbosity);
|
||||
accept_token, gen, verbosity, layers_output);
|
||||
}
|
||||
|
||||
void GenerateGriffin2B(GemmaImpl<ConfigGriffin2B>& gemma, size_t max_tokens,
|
||||
|
|
@ -1211,10 +1225,10 @@ void GenerateGriffin2B(GemmaImpl<ConfigGriffin2B>& gemma, size_t max_tokens,
|
|||
hwy::ThreadPool& inner_pool,
|
||||
const StreamFunc& stream_token,
|
||||
const AcceptFunc& accept_token, std::mt19937& gen,
|
||||
int verbosity) {
|
||||
int verbosity, LayersOutputT* layers_output) {
|
||||
GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
|
||||
start_pos, kv_cache, pool, inner_pool, stream_token,
|
||||
accept_token, gen, verbosity);
|
||||
accept_token, gen, verbosity, layers_output);
|
||||
}
|
||||
|
||||
float ComputeCrossEntropy2B(GemmaImpl<ConfigGemma2B>& gemma, size_t max_tokens,
|
||||
|
|
@ -1478,10 +1492,11 @@ void GemmaImpl<ConfigGemma2B>::Generate(
|
|||
const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache,
|
||||
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
|
||||
const StreamFunc& stream_token, const AcceptFunc& accept_token,
|
||||
std::mt19937& gen, int verbosity) {
|
||||
std::mt19937& gen, int verbosity, LayersOutputT* layers_output) {
|
||||
HWY_DYNAMIC_DISPATCH(Generate2B)
|
||||
(*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
|
||||
kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity);
|
||||
kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity,
|
||||
layers_output);
|
||||
}
|
||||
|
||||
template <>
|
||||
|
|
@ -1490,10 +1505,11 @@ void GemmaImpl<ConfigGemma7B>::Generate(
|
|||
const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache,
|
||||
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
|
||||
const StreamFunc& stream_token, const AcceptFunc& accept_token,
|
||||
std::mt19937& gen, int verbosity) {
|
||||
std::mt19937& gen, int verbosity, LayersOutputT* layers_output) {
|
||||
HWY_DYNAMIC_DISPATCH(Generate7B)
|
||||
(*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
|
||||
kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity);
|
||||
kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity,
|
||||
layers_output);
|
||||
}
|
||||
|
||||
template <>
|
||||
|
|
@ -1502,10 +1518,11 @@ void GemmaImpl<ConfigGriffin2B>::Generate(
|
|||
const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache,
|
||||
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
|
||||
const StreamFunc& stream_token, const AcceptFunc& accept_token,
|
||||
std::mt19937& gen, int verbosity) {
|
||||
std::mt19937& gen, int verbosity, LayersOutputT* layers_output) {
|
||||
HWY_DYNAMIC_DISPATCH(GenerateGriffin2B)
|
||||
(*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
|
||||
kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity);
|
||||
kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity,
|
||||
layers_output);
|
||||
}
|
||||
|
||||
template <>
|
||||
|
|
@ -1575,11 +1592,11 @@ void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
|
|||
size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||
hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
|
||||
const AcceptFunc& accept_token, std::mt19937& gen,
|
||||
int verbosity) {
|
||||
int verbosity, LayersOutputT* layers_output) {
|
||||
pool.SetWaitMode(hwy::PoolWaitMode::kSpin);
|
||||
gemma.impl_->Generate(max_tokens, max_generated_tokens, temperature, prompt,
|
||||
start_pos, kv_cache, pool, inner_pool, stream_token,
|
||||
accept_token, gen, verbosity);
|
||||
accept_token, gen, verbosity, layers_output);
|
||||
pool.SetWaitMode(hwy::PoolWaitMode::kBlock);
|
||||
}
|
||||
|
||||
|
|
@ -1591,7 +1608,8 @@ void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config,
|
|||
GenerateGemma(
|
||||
gemma, runtime_config.max_tokens, runtime_config.max_generated_tokens,
|
||||
runtime_config.temperature, prompt, start_pos, kv_cache, pool, inner_pool,
|
||||
stream_token, [](int) { return true; }, gen, runtime_config.verbosity);
|
||||
stream_token, [](int) { return true; }, gen, runtime_config.verbosity,
|
||||
nullptr);
|
||||
}
|
||||
|
||||
void CompressWeights(gcpp::Model model, const Path& weights,
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ namespace gcpp {
|
|||
|
||||
using GemmaWeightT = GEMMA_WEIGHT_T;
|
||||
using EmbedderInputT = hwy::bfloat16_t;
|
||||
using LayersOutputT = std::function<void(int, std::string, const float*, size_t)>;
|
||||
constexpr size_t kPrefillBatchSize = 16;
|
||||
constexpr bool kSystemPrompt = false;
|
||||
|
||||
|
|
@ -97,7 +98,7 @@ void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
|
|||
size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||
hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
|
||||
const AcceptFunc& accept_token, std::mt19937& gen,
|
||||
int verbosity);
|
||||
int verbosity, LayersOutputT* layers_output = nullptr);
|
||||
|
||||
// Convenience function for the common case:
|
||||
// - Bundle runtime parameters as RuntimeConfig
|
||||
|
|
|
|||
Loading…
Reference in New Issue