mirror of https://github.com/google/gemma.cpp.git
(1) Added QueryResultAndMetrics and BatchQueryModelWithMetrics, which return TimingInfo in addition to the query results.
PiperOrigin-RevId: 810634261
parent fac8aac4cb
commit 4f0c633248
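
For context, here is a minimal sketch of a call site for the new API. It is illustrative only: "env" stands for an already-constructed GemmaEnv, and the member accesses assume the declarations shown in the diff below.

  #include <iostream>
  #include <string>
  #include <vector>
  // #include of the GemmaEnv header omitted; its path is not shown on this page.

  // Hypothetical demo; GemmaEnv setup is elided because it is not part of
  // this commit.
  void Demo(gcpp::GemmaEnv& env) {
    std::vector<std::string> prompts = {"What is 2 + 2?", "Name a prime."};
    // New overload: per-query results plus batch-level timing.
    gcpp::QueryResultAndMetrics rm = env.BatchQueryModelWithMetrics(prompts);
    for (const gcpp::QueryResult& r : rm.query_results) {
      std::cout << r.response << "\n";  // assumes QueryResult stores the text
    }
    std::cout << "generated " << rm.timing_info.tokens_generated << " tokens\n";
  }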
@@ -78,16 +78,16 @@ QueryResult GemmaEnv::QueryModel(const std::vector<int>& tokens) {
               << runtime_config_.max_generated_tokens
               << "\ttemperature: " << runtime_config_.temperature << "\n";
   }
-  gcpp::TimingInfo timing_info { .verbosity = runtime_config_.verbosity };
+  gcpp::TimingInfo timing_info{.verbosity = runtime_config_.verbosity};
   runtime_config_.batch_stream_token = batch_stream_token;
   gemma_.Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0], env_,
                   timing_info);
   return result;
 }
 
-void GemmaEnv::QueryModel(
-    const std::vector<int>& tokens, const StreamFunc& stream_token) {
-  gcpp::TimingInfo timing_info { .verbosity = runtime_config_.verbosity };
+void GemmaEnv::QueryModel(const std::vector<int>& tokens,
+                          const StreamFunc& stream_token) {
+  gcpp::TimingInfo timing_info{.verbosity = runtime_config_.verbosity};
   const StreamFunc previous_stream_token = runtime_config_.stream_token;
   runtime_config_.stream_token = stream_token;
   gemma_.Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0], env_,
@@ -95,7 +95,7 @@ void GemmaEnv::QueryModel(
   runtime_config_.stream_token = previous_stream_token;
 }
 
-std::vector<QueryResult> GemmaEnv::BatchQueryModel(
+QueryResultAndMetrics GemmaEnv::BatchQueryModelWithMetrics(
     const QueriesPromptTokens& queries_prompt,
     const hwy::Span<const size_t>& prefix_end) {
   const size_t num_queries = queries_prompt.size();
@@ -140,7 +140,13 @@ std::vector<QueryResult> GemmaEnv::BatchQueryModel(
   gcpp::AllQueries all_queries(queries_prompt, kv_caches, prefix_end);
   gcpp::TimingInfo timing_info = {.verbosity = runtime_config_.verbosity};
   gemma_.GenerateBatch(runtime_config_, all_queries, env_, timing_info);
-  return res;
+  return {res, timing_info};
+}
+
+std::vector<QueryResult> GemmaEnv::BatchQueryModel(
+    const QueriesPromptTokens& queries_prompt,
+    const hwy::Span<const size_t>& prefix_end) {
+  return BatchQueryModelWithMetrics(queries_prompt, prefix_end).query_results;
 }
 
 QueryResult GemmaEnv::QueryModel(const std::string& input) {
@@ -148,7 +154,7 @@ QueryResult GemmaEnv::QueryModel(const std::string& input) {
   return QueryModel(prompt);
 }
 
-std::vector<QueryResult> GemmaEnv::BatchQueryModel(
+QueryResultAndMetrics GemmaEnv::BatchQueryModelWithMetrics(
     const std::vector<std::string>& prompt_strings) {
   std::vector<PromptTokens> views;
   views.reserve(prompt_strings.size());
@@ -161,7 +167,12 @@ std::vector<QueryResult> GemmaEnv::BatchQueryModel(
   }
 
   QueriesPromptTokens span_of_views(views.data(), views.size());
-  return BatchQueryModel(span_of_views);
+  return BatchQueryModelWithMetrics(span_of_views);
+}
+
+std::vector<QueryResult> GemmaEnv::BatchQueryModel(
+    const std::vector<std::string>& inputs) {
+  return BatchQueryModelWithMetrics(inputs).query_results;
 }
 
 float GemmaEnv::CrossEntropy(const std::string& input) {
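
Note the pattern in the .cc hunks above: the original BatchQueryModel entry points survive as one-line wrappers over the WithMetrics variants, so existing callers compile unchanged. The hunks that follow are from the accompanying headers (file names are not shown on this page). A sketch of the two equivalent call shapes, again assuming a constructed env and a prompts vector as in the earlier sketch:

  // Old API, still available; timing is computed but discarded.
  std::vector<gcpp::QueryResult> results = env.BatchQueryModel(prompts);
  // New API: same results, plus the TimingInfo recorded during generation.
  gcpp::QueryResultAndMetrics rm = env.BatchQueryModelWithMetrics(prompts);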
@@ -39,6 +39,14 @@ struct QueryResult {
   size_t response_start_pos = 0;
 };
 
+// Return type for batch query model calls with metrics.
+struct QueryResultAndMetrics {
+  // The query results for each query in the batch.
+  std::vector<QueryResult> query_results;
+  // The timing information for the batch query.
+  TimingInfo timing_info;
+};
+
 // Convenience class to load a model and run inference.
 class GemmaEnv {
  public:
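
Because QueryResultAndMetrics is a plain aggregate (no constructors, two public members), the braced "return {res, timing_info};" in the .cc hunk is ordinary aggregate initialization, and callers can unpack the result with C++17 structured bindings. A one-line sketch, reusing the env and prompts from the earlier example:

  auto [results, timing] = env.BatchQueryModelWithMetrics(prompts);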
@@ -79,21 +87,30 @@ class GemmaEnv {
     return string;
   }
 
+  // Adds turn structure to input, tokenizes and calls the below overload.
+  QueryResult QueryModel(const std::string& input);
   // Runs inference on the given input and returns the top-1 result string and
   // the number of tokens that were generated.
   QueryResult QueryModel(const std::vector<int>& tokens);
+  // Runs inference on the given input and calls the callback for each token.
+  void QueryModel(const std::vector<int>& tokens,
+                  const StreamFunc& stream_token);
+
+  // Similar to the above, but runs inference on a batch of inputs.
+  std::vector<QueryResult> BatchQueryModel(
+      const std::vector<std::string>& inputs);
   // The default prefix_end means "causal attention".
   std::vector<QueryResult> BatchQueryModel(
       const QueriesPromptTokens& queries_prompt,
       const hwy::Span<const size_t>& prefix_end = hwy::Span<const size_t>());
-  // Adds turn structure to input, tokenizes and calls the above overload.
-  QueryResult QueryModel(const std::string& input);
-  std::vector<QueryResult> BatchQueryModel(
-      const std::vector<std::string>& prompt_strings);
 
-  // Runs inference on the given input and calls the callback for each token.
-  void QueryModel(const std::vector<int>& tokens,
-                  const StreamFunc& stream_token);
+  // Similar to the above, but returns timing information in addition to the
+  // query results.
+  QueryResultAndMetrics BatchQueryModelWithMetrics(
+      const std::vector<std::string>& prompt_strings);
+  QueryResultAndMetrics BatchQueryModelWithMetrics(
+      const QueriesPromptTokens& queries_prompt,
+      const hwy::Span<const size_t>& prefix_end = hwy::Span<const size_t>());
 
   // Runs inference on the given input and returns the cross entropy, a measure
   // of how well the model predicts the correct output. It is the average
@@ -178,6 +178,7 @@ struct TimingInfo {
   // be sure to populate prefill_start and generate_start before calling
   // NotifyGenerated.
   void NotifyGenerated(size_t batch_size) {
+    generation_steps += 1;
     const bool is_first = (tokens_generated == 0);
     tokens_generated += batch_size;
     if (HWY_UNLIKELY(is_first)) {
@@ -224,6 +225,7 @@ struct TimingInfo {
   double time_to_first_token = 0;
   double generate_duration = 0;
   size_t tokens_generated = 0;
+  size_t generation_steps = 0;
 };
 
 // After construction, all methods are const and thread-compatible if using
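
The new generation_steps counter, incremented once per NotifyGenerated call, makes per-step statistics cheap to derive. A sketch of the kind of summary a caller could compute: the helper name is hypothetical, the fields follow the hunks above, and durations are assumed to be in seconds.

  #include <cstdio>

  void PrintTimingSummary(const gcpp::TimingInfo& t) {
    printf("time to first token: %.3f s\n", t.time_to_first_token);
    if (t.generate_duration > 0) {
      printf("throughput: %.1f tokens/s\n",
             t.tokens_generated / t.generate_duration);
    }
    if (t.generation_steps > 0) {
      // Average tokens produced per decode step; > 1 implies batched decoding.
      printf("avg tokens/step: %.2f\n",
             static_cast<double>(t.tokens_generated) / t.generation_steps);
    }
  }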