mirror of https://github.com/google/gemma.cpp.git
(1) Added QueryResultAndMetrics and BatchQueryModelWithMetrics to also return TimingInfo besides query results.
PiperOrigin-RevId: 810634261
This commit is contained in:
parent
fac8aac4cb
commit
4f0c633248
|
|
@ -78,16 +78,16 @@ QueryResult GemmaEnv::QueryModel(const std::vector<int>& tokens) {
|
|||
<< runtime_config_.max_generated_tokens
|
||||
<< "\ttemperature: " << runtime_config_.temperature << "\n";
|
||||
}
|
||||
gcpp::TimingInfo timing_info { .verbosity = runtime_config_.verbosity };
|
||||
gcpp::TimingInfo timing_info{.verbosity = runtime_config_.verbosity};
|
||||
runtime_config_.batch_stream_token = batch_stream_token;
|
||||
gemma_.Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0], env_,
|
||||
timing_info);
|
||||
return result;
|
||||
}
|
||||
|
||||
void GemmaEnv::QueryModel(
|
||||
const std::vector<int>& tokens, const StreamFunc& stream_token) {
|
||||
gcpp::TimingInfo timing_info { .verbosity = runtime_config_.verbosity };
|
||||
void GemmaEnv::QueryModel(const std::vector<int>& tokens,
|
||||
const StreamFunc& stream_token) {
|
||||
gcpp::TimingInfo timing_info{.verbosity = runtime_config_.verbosity};
|
||||
const StreamFunc previous_stream_token = runtime_config_.stream_token;
|
||||
runtime_config_.stream_token = stream_token;
|
||||
gemma_.Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0], env_,
|
||||
|
|
@ -95,7 +95,7 @@ void GemmaEnv::QueryModel(
|
|||
runtime_config_.stream_token = previous_stream_token;
|
||||
}
|
||||
|
||||
std::vector<QueryResult> GemmaEnv::BatchQueryModel(
|
||||
QueryResultAndMetrics GemmaEnv::BatchQueryModelWithMetrics(
|
||||
const QueriesPromptTokens& queries_prompt,
|
||||
const hwy::Span<const size_t>& prefix_end) {
|
||||
const size_t num_queries = queries_prompt.size();
|
||||
|
|
@ -140,7 +140,13 @@ std::vector<QueryResult> GemmaEnv::BatchQueryModel(
|
|||
gcpp::AllQueries all_queries(queries_prompt, kv_caches, prefix_end);
|
||||
gcpp::TimingInfo timing_info = {.verbosity = runtime_config_.verbosity};
|
||||
gemma_.GenerateBatch(runtime_config_, all_queries, env_, timing_info);
|
||||
return res;
|
||||
return {res, timing_info};
|
||||
}
|
||||
|
||||
std::vector<QueryResult> GemmaEnv::BatchQueryModel(
|
||||
const QueriesPromptTokens& queries_prompt,
|
||||
const hwy::Span<const size_t>& prefix_end) {
|
||||
return BatchQueryModelWithMetrics(queries_prompt, prefix_end).query_results;
|
||||
}
|
||||
|
||||
QueryResult GemmaEnv::QueryModel(const std::string& input) {
|
||||
|
|
@ -148,7 +154,7 @@ QueryResult GemmaEnv::QueryModel(const std::string& input) {
|
|||
return QueryModel(prompt);
|
||||
}
|
||||
|
||||
std::vector<QueryResult> GemmaEnv::BatchQueryModel(
|
||||
QueryResultAndMetrics GemmaEnv::BatchQueryModelWithMetrics(
|
||||
const std::vector<std::string>& prompt_strings) {
|
||||
std::vector<PromptTokens> views;
|
||||
views.reserve(prompt_strings.size());
|
||||
|
|
@ -161,7 +167,12 @@ std::vector<QueryResult> GemmaEnv::BatchQueryModel(
|
|||
}
|
||||
|
||||
QueriesPromptTokens span_of_views(views.data(), views.size());
|
||||
return BatchQueryModel(span_of_views);
|
||||
return BatchQueryModelWithMetrics(span_of_views);
|
||||
}
|
||||
|
||||
std::vector<QueryResult> GemmaEnv::BatchQueryModel(
|
||||
const std::vector<std::string>& inputs) {
|
||||
return BatchQueryModelWithMetrics(inputs).query_results;
|
||||
}
|
||||
|
||||
float GemmaEnv::CrossEntropy(const std::string& input) {
|
||||
|
|
|
|||
|
|
@ -39,6 +39,14 @@ struct QueryResult {
|
|||
size_t response_start_pos = 0;
|
||||
};
|
||||
|
||||
// Return type for batch query model calls with metrics.
|
||||
struct QueryResultAndMetrics {
|
||||
// The query results for each query in the batch.
|
||||
std::vector<QueryResult> query_results;
|
||||
// The timing information for the batch query.
|
||||
TimingInfo timing_info;
|
||||
};
|
||||
|
||||
// Convenience class to load a model and run inference.
|
||||
class GemmaEnv {
|
||||
public:
|
||||
|
|
@ -79,21 +87,30 @@ class GemmaEnv {
|
|||
return string;
|
||||
}
|
||||
|
||||
// Adds turn structure to input, tokenizes and calls the below overload.
|
||||
QueryResult QueryModel(const std::string& input);
|
||||
// Runs inference on the given input and returns the top-1 result string and
|
||||
// the number of tokens that were generated.
|
||||
QueryResult QueryModel(const std::vector<int>& tokens);
|
||||
// Runs inference on the given input and calls the callback for each token.
|
||||
void QueryModel(const std::vector<int>& tokens,
|
||||
const StreamFunc& stream_token);
|
||||
|
||||
// Similar to the above, but runs inference on a batch of inputs.
|
||||
std::vector<QueryResult> BatchQueryModel(
|
||||
const std::vector<std::string>& inputs);
|
||||
// The default prefix_end means "causal attention".
|
||||
std::vector<QueryResult> BatchQueryModel(
|
||||
const QueriesPromptTokens& queries_prompt,
|
||||
const hwy::Span<const size_t>& prefix_end = hwy::Span<const size_t>());
|
||||
// Adds turn structure to input, tokenizes and calls the above overload.
|
||||
QueryResult QueryModel(const std::string& input);
|
||||
std::vector<QueryResult> BatchQueryModel(
|
||||
const std::vector<std::string>& prompt_strings);
|
||||
|
||||
// Runs inference on the given input and calls the callback for each token.
|
||||
void QueryModel(const std::vector<int>& tokens,
|
||||
const StreamFunc& stream_token);
|
||||
// Similar to the above, but returns timing information in addition to the
|
||||
// query results.
|
||||
QueryResultAndMetrics BatchQueryModelWithMetrics(
|
||||
const std::vector<std::string>& prompt_strings);
|
||||
QueryResultAndMetrics BatchQueryModelWithMetrics(
|
||||
const QueriesPromptTokens& queries_prompt,
|
||||
const hwy::Span<const size_t>& prefix_end = hwy::Span<const size_t>());
|
||||
|
||||
// Runs inference on the given input and returns the cross entropy, a measure
|
||||
// of how well the model predicts the correct output. It is the average
|
||||
|
|
|
|||
|
|
@ -178,6 +178,7 @@ struct TimingInfo {
|
|||
// be sure to populate prefill_start and generate_start before calling
|
||||
// NotifyGenerated.
|
||||
void NotifyGenerated(size_t batch_size) {
|
||||
generation_steps += 1;
|
||||
const bool is_first = (tokens_generated == 0);
|
||||
tokens_generated += batch_size;
|
||||
if (HWY_UNLIKELY(is_first)) {
|
||||
|
|
@ -224,6 +225,7 @@ struct TimingInfo {
|
|||
double time_to_first_token = 0;
|
||||
double generate_duration = 0;
|
||||
size_t tokens_generated = 0;
|
||||
size_t generation_steps = 0;
|
||||
};
|
||||
|
||||
// After construction, all methods are const and thread-compatible if using
|
||||
|
|
|
|||
Loading…
Reference in New Issue