mirror of https://github.com/google/gemma.cpp.git
Extend LayersOutputFunc to take query index and auxillary int
PiperOrigin-RevId: 657574814
This commit is contained in:
parent
8b4915f321
commit
d37c088e44
|
|
@ -59,10 +59,13 @@ int Run(int argc, char** argv) {
|
||||||
env.MutableConfig().layers_output =
|
env.MutableConfig().layers_output =
|
||||||
prompt_args.layers_output.Empty()
|
prompt_args.layers_output.Empty()
|
||||||
? LayersOutputFunc()
|
? LayersOutputFunc()
|
||||||
: [&json_output](int pos, const std::string& key, const float* values,
|
: [&json_output](size_t query_idx, size_t pos, const std::string& key,
|
||||||
size_t values_len) {
|
int layer, const float* values, size_t values_len) {
|
||||||
std::vector<float> v{values, values + values_len};
|
const std::string& debug_key =
|
||||||
json_output[std::to_string(pos)][key] = v;
|
layer < 0 ? key : (key + "." + std::to_string(layer));
|
||||||
|
const std::vector<float> v{values, values + values_len};
|
||||||
|
json& json_base = json_output[std::to_string(query_idx)];
|
||||||
|
json_base[std::to_string(pos)][debug_key] = v;
|
||||||
};
|
};
|
||||||
|
|
||||||
const auto [answer, token_count] = env.QueryModel(prompt_args.prompt);
|
const auto [answer, token_count] = env.QueryModel(prompt_args.prompt);
|
||||||
|
|
|
||||||
|
|
@ -805,8 +805,10 @@ HWY_NOINLINE void Transformer(const int* tokens, size_t num_tokens,
|
||||||
const size_t num_interleaved = num_tokens * num_queries;
|
const size_t num_interleaved = num_tokens * num_queries;
|
||||||
if (layers_output) {
|
if (layers_output) {
|
||||||
for (size_t token_idx = 0; token_idx < num_interleaved; ++token_idx) {
|
for (size_t token_idx = 0; token_idx < num_interleaved; ++token_idx) {
|
||||||
float token_f = tokens[token_idx];
|
const size_t query_idx = token_idx % num_queries;
|
||||||
layers_output(pos + token_idx, "Tokens", &token_f, 1);
|
const size_t logical_pos = (pos + token_idx) / num_queries;
|
||||||
|
const float token_f = tokens[token_idx];
|
||||||
|
layers_output(query_idx, logical_pos, "tokens", -1, &token_f, 1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
constexpr size_t kModelDim = TConfig::kModelDim;
|
constexpr size_t kModelDim = TConfig::kModelDim;
|
||||||
|
|
@ -821,9 +823,9 @@ HWY_NOINLINE void Transformer(const int* tokens, size_t num_tokens,
|
||||||
layer_weights, activations, kv_caches, pool);
|
layer_weights, activations, kv_caches, pool);
|
||||||
|
|
||||||
if (layers_output) {
|
if (layers_output) {
|
||||||
const std::string block_name = "blocks." + std::to_string(layer);
|
|
||||||
for (size_t token_idx = 0; token_idx < num_interleaved; ++token_idx) {
|
for (size_t token_idx = 0; token_idx < num_interleaved; ++token_idx) {
|
||||||
layers_output(pos + token_idx, block_name,
|
const size_t logical_pos = (pos + token_idx) / num_queries;
|
||||||
|
layers_output(token_idx % num_queries, logical_pos, "blocks", layer,
|
||||||
activations.x.Batch(token_idx), kModelDim);
|
activations.x.Batch(token_idx), kModelDim);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -833,7 +835,9 @@ HWY_NOINLINE void Transformer(const int* tokens, size_t num_tokens,
|
||||||
activations.x.All(), kModelDim);
|
activations.x.All(), kModelDim);
|
||||||
if (layers_output) {
|
if (layers_output) {
|
||||||
for (size_t token_idx = 0; token_idx < num_interleaved; ++token_idx) {
|
for (size_t token_idx = 0; token_idx < num_interleaved; ++token_idx) {
|
||||||
layers_output(pos + token_idx, "final_norm",
|
const size_t query_idx = token_idx % num_queries;
|
||||||
|
const size_t logical_pos = (pos + token_idx) / num_queries;
|
||||||
|
layers_output(query_idx, logical_pos, "final_norm", -1,
|
||||||
activations.x.Batch(token_idx), kModelDim);
|
activations.x.Batch(token_idx), kModelDim);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -49,13 +49,15 @@ using AcceptFunc = std::function<bool(int, float)>;
|
||||||
// If not empty, SampleFunc is called with the probability distribution for the
|
// If not empty, SampleFunc is called with the probability distribution for the
|
||||||
// next token, and its return value is used as the next generated token.
|
// next token, and its return value is used as the next generated token.
|
||||||
using SampleFunc = std::function<int(const float*, size_t)>;
|
using SampleFunc = std::function<int(const float*, size_t)>;
|
||||||
// Will be called for layers output with:
|
// If not empty, LayersOutputFunc is called for layer outputs, specified with:
|
||||||
|
// - index of query within containing batch (if any); zero otherwise.
|
||||||
// - position in the tokens sequence
|
// - position in the tokens sequence
|
||||||
// - name of the data, p.ex. "tokens", "block.1", "final_norm"
|
// - name of the data, e.g. "tokens", "blocks", "final_norm"
|
||||||
|
// - layer index (or -1 for global outputs), e.g. "blocks" exposes x per-layer
|
||||||
// - pointer to the data array
|
// - pointer to the data array
|
||||||
// - size of the data array
|
// - size of the data array
|
||||||
using LayersOutputFunc =
|
using LayersOutputFunc =
|
||||||
std::function<void(int, const std::string&, const float*, size_t)>;
|
std::function<void(size_t, size_t, const std::string&, int, const float*, size_t)>;
|
||||||
|
|
||||||
struct RuntimeConfig {
|
struct RuntimeConfig {
|
||||||
bool StreamToken(size_t query_idx, size_t pos, int token, float prob) const {
|
bool StreamToken(size_t query_idx, size_t pos, int token, float prob) const {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue