diff --git a/gemma/gemma.h b/gemma/gemma.h index 6971802..40cc82c 100644 --- a/gemma/gemma.h +++ b/gemma/gemma.h @@ -60,6 +60,8 @@ struct PerQuery { // Array of `PerQuery`. Referenced by `QBatch` and passed to `GenerateBatch`. struct AllQueries { + AllQueries() = default; + // For `GenerateSingleT`: same prompt/pos, replicated for each KV cache. AllQueries(const PromptTokens& prompt, size_t pos, size_t prefix_end, const hwy::Span& kv_caches) { @@ -97,6 +99,9 @@ struct AllQueries { } } + void Reserve(size_t size) { per_query_.reserve(size); } + void Append(const PerQuery& query) { per_query_.push_back(query); } + size_t NumQueries() const { return per_query_.size(); } PerQuery& operator[](size_t query_idx) {