diff --git a/BUILD.bazel b/BUILD.bazel index 84de4c4..8b96b16 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -141,6 +141,7 @@ cc_test( ":kv_cache", ":mat", ":matmul", + ":query", ":threading_context", ":weights", "@googletest//:gtest_main", # buildcleaner: keep @@ -444,9 +445,9 @@ cc_test( ":gemma_lib", ":mat", ":ops", + ":query", ":test_util", ":threading_context", - ":zones", "@googletest//:gtest_main", # buildcleaner: keep "//compression:test_util", "//compression:types", @@ -536,6 +537,17 @@ cc_test( ], ) +cc_library( + name = "query", + hdrs = ["gemma/query.h"], + deps = [ + ":basics", + ":gemma_args", + ":kv_cache", + "@highway//:hwy", + ], +) + cc_library( name = "gemma_args", hdrs = ["gemma/gemma_args.h"], @@ -586,7 +598,7 @@ cc_library( ":matmul_env", ":model_store", ":ops", - ":test_util", + ":query", ":threading", ":threading_context", ":weights", @@ -620,6 +632,7 @@ cc_test( ":kv_cache", ":mat", ":matmul_env", + ":query", ":test_util", ":threading_context", ":weights", diff --git a/gemma/gemma.h b/gemma/gemma.h index 158c1ff..f503a1a 100644 --- a/gemma/gemma.h +++ b/gemma/gemma.h @@ -27,6 +27,7 @@ #include "gemma/gemma_args.h" #include "gemma/kv_cache.h" #include "gemma/model_store.h" +#include "gemma/query.h" #include "gemma/weights.h" #include "io/blob_store.h" #include "io/io.h" // Path @@ -39,160 +40,6 @@ namespace gcpp { -struct PerQuery { - PromptTokens prompt; - - // Position in the KV cache: initially zero for the first turn, or when - // multi-turn is NOT desired. Incremented by prefill and `StreamAndUpdateEOS`. - size_t mutable_pos; - // Allows computing the last prefill token as `mutable_pos - initial_pos`, - // which might differ from `prompt.size() - 1` for prefix-LM. - size_t initial_pos; - // Zero for causal attention, or the end of the prefix for prefix-LM style - // attention in Paligemma. - size_t prefix_end; - - KVCachePtr kv_cache; - - // Previous token generated for this query, or the last prompt token. Will be - // fed into the next Transformer() call. - int prev_token = 0; -}; - -// Array of `PerQuery`. Referenced by `QBatch` and passed to `GenerateBatch`. -struct AllQueries { - AllQueries() = default; - - // For `GenerateSingleT`: same prompt/pos, replicated for each KV cache. - AllQueries(const PromptTokens& prompt, size_t pos, size_t prefix_end, - const hwy::Span& kv_caches) { - per_query_.reserve(kv_caches.size()); - for (size_t i = 0; i < kv_caches.size(); ++i) { - HWY_ASSERT(kv_caches[i].SeqLen() == kv_caches[0].SeqLen()); - per_query_.push_back(PerQuery{ - .prompt = prompt, - .mutable_pos = pos, - .initial_pos = pos, - .prefix_end = prefix_end, - .kv_cache = kv_caches[i], - }); - } - } - - AllQueries(const PromptTokens& prompt, size_t pos, size_t prefix_end, - const hwy::Span& kv_caches) - : AllQueries(prompt, pos, prefix_end, - hwy::Span(ToKVCachePtrs(kv_caches))) {} - - // Batch of queries with initial position set to zero. Causal attention - // is requested via empty or all-zero `prefix_end`. - AllQueries( - const hwy::Span& prompts, - const hwy::Span& kv_caches, - const hwy::Span& prefix_end = hwy::Span()) { - HWY_ASSERT(prompts.size() == prefix_end.size() || prefix_end.size() == 0); - per_query_.reserve(prompts.size()); - for (size_t i = 0; i < prompts.size(); ++i) { - HWY_ASSERT(kv_caches.size() == 0 || - kv_caches[i].SeqLen() == kv_caches[0].SeqLen()); - per_query_.push_back(PerQuery{ - .prompt = prompts[i], - .mutable_pos = 0, - .initial_pos = 0, - .prefix_end = prefix_end.size() == 0 ? 0 : prefix_end[i], - .kv_cache = kv_caches.size() == 0 ? KVCachePtr() : kv_caches[i], - }); - } - } - - AllQueries( - const hwy::Span& prompts, - const hwy::Span& kv_caches, - const hwy::Span& prefix_end = hwy::Span()) - : AllQueries(prompts, hwy::Span(ToKVCachePtrs(kv_caches)), - prefix_end) {} - - void Reserve(size_t size) { per_query_.reserve(size); } - void Append(const PerQuery& query) { per_query_.push_back(query); } - - size_t NumQueries() const { return per_query_.size(); } - - PerQuery& operator[](size_t query_idx) { - HWY_DASSERT(query_idx < NumQueries()); - return per_query_[query_idx]; - } - const PerQuery& operator[](size_t query_idx) const { - HWY_DASSERT(query_idx < NumQueries()); - return per_query_[query_idx]; - } - - private: - std::vector per_query_; -}; - -// View into AllQueries: either a batch of queries, or a single query for use -// in PrefillTBatch or GenerateSingleT. Cheap to create because it holds a -// reference to AllQueries. -class QBatch { - public: - QBatch(size_t start, size_t max_size, AllQueries& queries) - : start_(start), - max_size_(max_size), - queries_(queries), - size_(HWY_MIN(max_size_, queries_.NumQueries() - start_)) { - HWY_ASSERT(max_size_ <= kMaxBatchSize); - HWY_DASSERT(size_ != 0); - HWY_DASSERT(start_ + size_ <= queries_.NumQueries()); - for (int i = 0; i < size_; ++i) { - query_idx_.push_back(start_ + i); - } - } - - // Returns a single-query view starting at `qi` relative to this batch. - QBatch Single(size_t qi) const { return QBatch(QueryIdx(qi), 1, queries_); } - - // How many queries in this batch, <= `queries_.NumQueries()` and `max_size_`. - size_t Size() const { return size_; } - - // Returns index for use with `AllQueries` and `BatchStreamToken`. - size_t QueryIdx(size_t qi) const { - HWY_DASSERT(qi < size_); - return query_idx_[qi]; - } - - // Accessor functions to bridge the previous SoA and current AoS layout. - const PromptTokens& Prompt(size_t qi) const { - return queries_[QueryIdx(qi)].prompt; - } - size_t Pos(size_t qi) const { return queries_[QueryIdx(qi)].mutable_pos; } - size_t& MutablePos(size_t qi) { return queries_[QueryIdx(qi)].mutable_pos; } - size_t InitialPos(size_t qi) const { - return queries_[QueryIdx(qi)].initial_pos; - } - size_t PrefixEnd(size_t qi) const { - return queries_[QueryIdx(qi)].prefix_end; - } - KVCachePtr& KV(size_t qi) const { return queries_[QueryIdx(qi)].kv_cache; } - int& PrevToken(size_t qi) { return queries_[QueryIdx(qi)].prev_token; } - - // let query_idx_[to] point to the from in the queries_; this is only used if - // the slot in the QBatch is less than the number of queries. - void Insert(size_t from, size_t to) { - if (from == to) return; - HWY_ASSERT(!queries_[from].kv_cache.IsEmpty()); - HWY_ASSERT(queries_[to].kv_cache.IsEmpty()); - // Conceptually, insert from.query to location to. - query_idx_[to] = from; - } - - protected: - size_t start_; - size_t max_size_; - AllQueries& queries_; - std::vector query_idx_; - size_t size_; -}; - // Used for continuous batching. class ContinuousQBatch : public QBatch { public: diff --git a/gemma/query.h b/gemma/query.h new file mode 100644 index 0000000..9c50a06 --- /dev/null +++ b/gemma/query.h @@ -0,0 +1,186 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_QUERY_H_ +#define THIRD_PARTY_GEMMA_CPP_GEMMA_QUERY_H_ + +#include + +#include "gemma/gemma_args.h" +#include "gemma/kv_cache.h" +#include "util/basics.h" +#include "hwy/aligned_allocator.h" // Span +#include "hwy/base.h" + +namespace gcpp { + +struct PerQuery { + PromptTokens prompt; + + // Position in the KV cache: initially zero for the first turn, or when + // multi-turn is NOT desired. Incremented by prefill and `StreamAndUpdateEOS`. + size_t mutable_pos; + // Allows computing the last prefill token as `mutable_pos - initial_pos`, + // which might differ from `prompt.size() - 1` for prefix-LM. + size_t initial_pos; + // Zero for causal attention, or the end of the prefix for prefix-LM style + // attention in Paligemma. + size_t prefix_end; + + KVCachePtr kv_cache; + + // Previous token generated for this query, or the last prompt token. Will be + // fed into the next Transformer() call. + int prev_token = 0; +}; + +// Array of `PerQuery`. Referenced by `QBatch` and passed to `GenerateBatch`. +struct AllQueries { + AllQueries() = default; + + // For `GenerateSingleT`: same prompt/pos, replicated for each KV cache. + AllQueries(const PromptTokens& prompt, size_t pos, size_t prefix_end, + const hwy::Span& kv_caches) { + per_query_.reserve(kv_caches.size()); + for (size_t i = 0; i < kv_caches.size(); ++i) { + HWY_ASSERT(kv_caches[i].SeqLen() == kv_caches[0].SeqLen()); + per_query_.push_back(PerQuery{ + .prompt = prompt, + .mutable_pos = pos, + .initial_pos = pos, + .prefix_end = prefix_end, + .kv_cache = kv_caches[i], + }); + } + } + + AllQueries(const PromptTokens& prompt, size_t pos, size_t prefix_end, + const hwy::Span& kv_caches) + : AllQueries(prompt, pos, prefix_end, + hwy::Span(ToKVCachePtrs(kv_caches))) {} + + // Batch of queries with initial position set to zero. Causal attention + // is requested via empty or all-zero `prefix_end`. + AllQueries( + const hwy::Span& prompts, + const hwy::Span& kv_caches, + const hwy::Span& prefix_end = hwy::Span()) { + HWY_ASSERT(prompts.size() == prefix_end.size() || prefix_end.size() == 0); + per_query_.reserve(prompts.size()); + for (size_t i = 0; i < prompts.size(); ++i) { + HWY_ASSERT(kv_caches.size() == 0 || + kv_caches[i].SeqLen() == kv_caches[0].SeqLen()); + per_query_.push_back(PerQuery{ + .prompt = prompts[i], + .mutable_pos = 0, + .initial_pos = 0, + .prefix_end = prefix_end.size() == 0 ? 0 : prefix_end[i], + .kv_cache = kv_caches.size() == 0 ? KVCachePtr() : kv_caches[i], + }); + } + } + + AllQueries( + const hwy::Span& prompts, + const hwy::Span& kv_caches, + const hwy::Span& prefix_end = hwy::Span()) + : AllQueries(prompts, hwy::Span(ToKVCachePtrs(kv_caches)), + prefix_end) {} + + void Reserve(size_t size) { per_query_.reserve(size); } + void Append(const PerQuery& query) { per_query_.push_back(query); } + + size_t NumQueries() const { return per_query_.size(); } + + PerQuery& operator[](size_t query_idx) { + HWY_DASSERT(query_idx < NumQueries()); + return per_query_[query_idx]; + } + const PerQuery& operator[](size_t query_idx) const { + HWY_DASSERT(query_idx < NumQueries()); + return per_query_[query_idx]; + } + + private: + std::vector per_query_; +}; + +// View into AllQueries: either a batch of queries, or a single query for use +// in PrefillTBatch or GenerateSingleT. Cheap to create because it holds a +// reference to AllQueries. +class QBatch { + public: + QBatch(size_t start, size_t max_size, AllQueries& queries) + : start_(start), + max_size_(max_size), + queries_(queries), + size_(HWY_MIN(max_size_, queries_.NumQueries() - start_)) { + HWY_ASSERT(max_size_ <= kMaxBatchSize); + HWY_DASSERT(size_ != 0); + HWY_DASSERT(start_ + size_ <= queries_.NumQueries()); + query_idx_.reserve(size_); + for (int i = 0; i < size_; ++i) { + query_idx_.push_back(start_ + i); + } + } + + // Returns a single-query view starting at `qi` relative to this batch. + QBatch Single(size_t qi) const { return QBatch(QueryIdx(qi), 1, queries_); } + + // How many queries in this batch, <= `queries_.NumQueries()` and `max_size_`. + size_t Size() const { return size_; } + + // Returns index for use with `AllQueries` and `BatchStreamToken`. + size_t QueryIdx(size_t qi) const { + HWY_DASSERT(qi < size_); + return query_idx_[qi]; + } + + // Accessor functions to bridge the previous SoA and current AoS layout. + const PromptTokens& Prompt(size_t qi) const { + return queries_[QueryIdx(qi)].prompt; + } + size_t Pos(size_t qi) const { return queries_[QueryIdx(qi)].mutable_pos; } + size_t& MutablePos(size_t qi) { return queries_[QueryIdx(qi)].mutable_pos; } + size_t InitialPos(size_t qi) const { + return queries_[QueryIdx(qi)].initial_pos; + } + size_t PrefixEnd(size_t qi) const { + return queries_[QueryIdx(qi)].prefix_end; + } + KVCachePtr& KV(size_t qi) const { return queries_[QueryIdx(qi)].kv_cache; } + int& PrevToken(size_t qi) { return queries_[QueryIdx(qi)].prev_token; } + + // let query_idx_[to] point to the from in the queries_; this is only used if + // the slot in the QBatch is less than the number of queries. + void Insert(size_t from, size_t to) { + if (from == to) return; + HWY_ASSERT(!queries_[from].kv_cache.IsEmpty()); + HWY_ASSERT(queries_[to].kv_cache.IsEmpty()); + // Conceptually, insert from.query to location to. + query_idx_[to] = from; + } + + protected: + size_t start_; + size_t max_size_; + AllQueries& queries_; + std::vector query_idx_; + size_t size_; +}; + +} // namespace gcpp + +#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_QUERY_H_