Fixed minor things and added comments.

This commit is contained in:
Andrey Mikhaylov 2024-04-12 14:54:34 +00:00
parent 2c5706f159
commit 4ef3da733a
4 changed files with 6 additions and 8 deletions

View File

@ -176,4 +176,4 @@ cc_binary(
"@hwy//:thread_pool", "@hwy//:thread_pool",
"@nlohmann_json//:json", "@nlohmann_json//:json",
], ],
) )

View File

@ -5,9 +5,7 @@
#include "gemma/gemma.h" #include "gemma/gemma.h"
#include "nlohmann/json.hpp" #include "nlohmann/json.hpp"
// copybara:import_next_line:gemma_cpp
#include "util/app.h" #include "util/app.h"
// copybara:import_next_line:gemma_cpp
#include "util/args.h" #include "util/args.h"
using json = nlohmann::json; using json = nlohmann::json;
@ -67,7 +65,7 @@ class OutputJsonLogger {
json json_output; json json_output;
gcpp::LayersOutputT layers_output_log_f = gcpp::LayersOutputT layers_output_log_f =
[this](int pos, std::string key, const float* values, size_t values_len) { [this](int pos, const std::string& key, const float* values, size_t values_len) {
std::vector<float> v{values, values + values_len}; std::vector<float> v{values, values + values_len};
json_output[std::to_string(pos)][key] = v; json_output[std::to_string(pos)][key] = v;
}; };
@ -132,4 +130,4 @@ int main(int argc, char** argv) {
} }
return EXIT_SUCCESS; return EXIT_SUCCESS;
} }

View File

@ -1180,7 +1180,7 @@ float ComputeCrossEntropyImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
total_entropy / std::log(2.0) / (pos + 1)); total_entropy / std::log(2.0) / (pos + 1));
} }
Transformer(token, pos, weights, activations, kv_cache, pool, inner_pool, Transformer(token, pos, weights, activations, kv_cache, pool, inner_pool,
nullptr); /*layers_output=*/nullptr);
MatVec<kVocabSize, kModelDim>(weights.embedder_input_embedding, 0, MatVec<kVocabSize, kModelDim>(weights.embedder_input_embedding, 0,
activations.x.data(), activations.x.data(),
activations.logits.data(), pool); activations.logits.data(), pool);
@ -1609,7 +1609,7 @@ void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config,
gemma, runtime_config.max_tokens, runtime_config.max_generated_tokens, gemma, runtime_config.max_tokens, runtime_config.max_generated_tokens,
runtime_config.temperature, prompt, start_pos, kv_cache, pool, inner_pool, runtime_config.temperature, prompt, start_pos, kv_cache, pool, inner_pool,
stream_token, [](int) { return true; }, gen, runtime_config.verbosity, stream_token, [](int) { return true; }, gen, runtime_config.verbosity,
nullptr); /*layers_output=*/nullptr);
} }
void CompressWeights(gcpp::Model model, const Path& weights, void CompressWeights(gcpp::Model model, const Path& weights,

View File

@ -38,7 +38,7 @@ using EmbedderInputT = hwy::bfloat16_t;
// - ponter to the data array // - ponter to the data array
// - size of the data array // - size of the data array
using LayersOutputT = using LayersOutputT =
std::function<void(int, std::string, const float*, size_t)>; std::function<void(int, const std::string&, const float*, size_t)>;
constexpr size_t kPrefillBatchSize = 16; constexpr size_t kPrefillBatchSize = 16;
constexpr bool kSystemPrompt = false; constexpr bool kSystemPrompt = false;