Implement `start_pos` per query for batch interface

This commit is contained in:
RangerUFO 2024-08-10 18:24:15 +08:00 committed by Jan Wassenberg
parent 8e028632f7
commit 730b6bfc94
4 changed files with 91 additions and 55 deletions

View File

@ -171,7 +171,8 @@ std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel2(
gcpp::TimingInfo timing_info = {.verbosity = app_.verbosity}; gcpp::TimingInfo timing_info = {.verbosity = app_.verbosity};
runtime_config_.batch_stream_token = batch_stream_token; runtime_config_.batch_stream_token = batch_stream_token;
inference_args_.CopyTo(runtime_config_); inference_args_.CopyTo(runtime_config_);
model_->GenerateBatch(runtime_config_, prompts, /*start_pos=*/0, model_->GenerateBatch(runtime_config_, prompts,
std::vector<size_t>(num_queries, 0),
KVCaches(&kv_caches_[0], num_queries), timing_info); KVCaches(&kv_caches_[0], num_queries), timing_info);
return res; return res;
} }

View File

@ -233,7 +233,7 @@ class GemmaAttention {
// Fills activations.q and computes KV. For kIsMHA, a single MatMul suffices // Fills activations.q and computes KV. For kIsMHA, a single MatMul suffices
// and we later copy KV from q to KVCache. Otherwise, a second MatMul writes // and we later copy KV from q to KVCache. Otherwise, a second MatMul writes
// KV directly to KVCache. // KV directly to KVCache.
HWY_NOINLINE void ComputeQKV(const size_t batch_start, HWY_NOINLINE void ComputeQKV(const MultiplePositions& batch_start,
const size_t num_interleaved) { const size_t num_interleaved) {
PROFILER_ZONE("Gen.Attention.QKV"); PROFILER_ZONE("Gen.Attention.QKV");
// For the computation of Q, K, and V, it is useful to remember that // For the computation of Q, K, and V, it is useful to remember that
@ -255,9 +255,9 @@ class GemmaAttention {
// Single query and no wraparound means we can use a matmul and write // Single query and no wraparound means we can use a matmul and write
// directly into the KV cache with a stride of kCachePosSize. // directly into the KV cache with a stride of kCachePosSize.
if (num_queries_ == 1 && if (num_queries_ == 1 &&
batch_start + num_tokens_ <= div_seq_len_.GetDivisor()) { batch_start[0] + num_tokens_ <= div_seq_len_.GetDivisor()) {
const size_t kv_ofs = const size_t kv_ofs =
batch_start * kCachePosSize + layer_ * kCacheLayerSize; batch_start[0] * kCachePosSize + layer_ * kCacheLayerSize;
// KV structure is [k, v, k, v, ....] = kKVHeads pairs of (k, v). // KV structure is [k, v, k, v, ....] = kKVHeads pairs of (k, v).
float* HWY_RESTRICT kv = kv_caches_[0].kv_cache.get() + kv_ofs; float* HWY_RESTRICT kv = kv_caches_[0].kv_cache.get() + kv_ofs;
MatMul_4x4</*kAdd=*/false>( MatMul_4x4</*kAdd=*/false>(
@ -275,7 +275,7 @@ class GemmaAttention {
const size_t batch_idx = interleaved_idx / num_queries_; const size_t batch_idx = interleaved_idx / num_queries_;
KVCache& kv_cache = kv_caches_[query_idx]; KVCache& kv_cache = kv_caches_[query_idx];
const size_t cache_pos = const size_t cache_pos =
div_seq_len_.Remainder(batch_start + batch_idx); div_seq_len_.Remainder(batch_start[query_idx] + batch_idx);
const size_t kv_offset = const size_t kv_offset =
cache_pos * kCachePosSize + layer_ * kCacheLayerSize; cache_pos * kCachePosSize + layer_ * kCacheLayerSize;
float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset; float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
@ -295,7 +295,7 @@ class GemmaAttention {
const size_t interleaved_idx = task / kKVHeads; const size_t interleaved_idx = task / kKVHeads;
const size_t query_idx = interleaved_idx % num_queries_; const size_t query_idx = interleaved_idx % num_queries_;
const size_t batch_idx = interleaved_idx / num_queries_; const size_t batch_idx = interleaved_idx / num_queries_;
const size_t pos = batch_start + batch_idx; const size_t pos = batch_start[query_idx] + batch_idx;
const size_t cache_pos = div_seq_len_.Remainder(pos); const size_t cache_pos = div_seq_len_.Remainder(pos);
const size_t kv_offset = cache_pos * kCachePosSize + const size_t kv_offset = cache_pos * kCachePosSize +
layer_ * kCacheLayerSize + layer_ * kCacheLayerSize +
@ -374,7 +374,7 @@ class GemmaAttention {
} }
} }
HWY_NOINLINE void DotSoftmaxWeightedSum(const size_t batch_start, HWY_NOINLINE void DotSoftmaxWeightedSum(const MultiplePositions& batch_start,
const size_t num_interleaved) { const size_t num_interleaved) {
PROFILER_ZONE("Gen.Attention.DotSoftmax"); PROFILER_ZONE("Gen.Attention.DotSoftmax");
GEMMA_CONSTEXPR_SQRT float kQueryScale = ChooseQueryScale<TConfig>(); GEMMA_CONSTEXPR_SQRT float kQueryScale = ChooseQueryScale<TConfig>();
@ -398,7 +398,7 @@ class GemmaAttention {
activations_.q.Batch(interleaved_idx) + head * kQStride; activations_.q.Batch(interleaved_idx) + head * kQStride;
// Apply rope and scaling to Q. // Apply rope and scaling to Q.
const size_t pos = batch_start + batch_idx; const size_t pos = batch_start[query_idx] + batch_idx;
PositionalEncodingQK(q, pos, layer_, kQueryScale, q); PositionalEncodingQK(q, pos, layer_, kQueryScale, q);
const size_t start_pos = StartPos(pos, layer_); const size_t start_pos = StartPos(pos, layer_);
@ -440,13 +440,12 @@ class GemmaAttention {
} }
public: public:
GemmaAttention(size_t interleaved_start, size_t num_tokens, GemmaAttention(const MultiplePositions& interleaved_start, size_t num_tokens,
size_t num_queries, size_t layer, Activations& activations, size_t num_queries, size_t layer, Activations& activations,
const CompressedLayer<TConfig>* layer_weights, const CompressedLayer<TConfig>* layer_weights,
const hwy::Divisor& div_seq_len, const KVCaches& kv_caches, const hwy::Divisor& div_seq_len, const KVCaches& kv_caches,
hwy::ThreadPool& pool) hwy::ThreadPool& pool)
: interleaved_start_(interleaved_start), : num_tokens_(num_tokens),
num_tokens_(num_tokens),
num_queries_(num_queries), num_queries_(num_queries),
layer_(layer), layer_(layer),
activations_(activations), activations_(activations),
@ -454,12 +453,22 @@ class GemmaAttention {
div_seq_len_(div_seq_len), div_seq_len_(div_seq_len),
kv_caches_(kv_caches), kv_caches_(kv_caches),
pool_(pool) { pool_(pool) {
HWY_DASSERT(interleaved_start_ % num_queries_ == 0); HWY_DASSERT(std::all_of(interleaved_start.cbegin(),
interleaved_start.cend(), [this](size_t pos) {
return pos % num_queries_ == 0;
}));
HWY_DASSERT(num_queries_ <= kv_caches_.size()); HWY_DASSERT(num_queries_ <= kv_caches_.size());
batch_start_.reserve(interleaved_start.size());
for (auto i = interleaved_start.cbegin();
i != interleaved_start.cend(); ++i) {
batch_start_.emplace_back(*i / num_queries_);
}
} }
HWY_INLINE void operator()() { HWY_INLINE void operator()() {
const size_t batch_start = interleaved_start_ / num_queries_; const MultiplePositions batch_start(batch_start_.data(),
batch_start_.size());
const size_t num_interleaved = num_tokens_ * num_queries_; const size_t num_interleaved = num_tokens_ * num_queries_;
ComputeQKV(batch_start, num_interleaved); ComputeQKV(batch_start, num_interleaved);
@ -468,7 +477,7 @@ class GemmaAttention {
} }
private: private:
const size_t interleaved_start_; std::vector<size_t> batch_start_;
const size_t num_tokens_; const size_t num_tokens_;
const size_t num_queries_; const size_t num_queries_;
const size_t layer_; const size_t layer_;
@ -480,7 +489,8 @@ class GemmaAttention {
}; };
template <class TConfig> template <class TConfig>
HWY_NOINLINE void Attention(LayerAttentionType type, size_t interleaved_start, HWY_NOINLINE void Attention(LayerAttentionType type,
const MultiplePositions& interleaved_start,
size_t num_tokens, size_t num_queries, size_t layer, size_t num_tokens, size_t num_queries, size_t layer,
Activations& activations, Activations& activations,
const CompressedLayer<TConfig>* layer_weights, const CompressedLayer<TConfig>* layer_weights,
@ -495,7 +505,7 @@ HWY_NOINLINE void Attention(LayerAttentionType type, size_t interleaved_start,
// this code for non-Griffin models. // this code for non-Griffin models.
if constexpr (TConfig::kGriffinLayers > 0) { if constexpr (TConfig::kGriffinLayers > 0) {
HWY_ASSERT(num_queries == 1); HWY_ASSERT(num_queries == 1);
GriffinRecurrent<TConfig>(interleaved_start, num_tokens, num_queries, GriffinRecurrent<TConfig>(interleaved_start[0], num_tokens, num_queries,
layer, activations, layer_weights, kv_caches, layer, activations, layer_weights, kv_caches,
pool); pool);
} }
@ -599,10 +609,10 @@ void PostNorm(size_t num_interleaved, const WeightT& weights, InOutT* inout) {
template <class TConfig> template <class TConfig>
HWY_NOINLINE void TransformerLayer( HWY_NOINLINE void TransformerLayer(
size_t num_tokens, size_t num_queries, size_t pos, size_t layer, size_t num_tokens, size_t num_queries, const MultiplePositions& pos,
const CompressedLayer<TConfig>* layer_weights, Activations& activations, size_t layer, const CompressedLayer<TConfig>* layer_weights,
const hwy::Divisor& div_seq_len, const KVCaches& kv_caches, Activations& activations, const hwy::Divisor& div_seq_len,
hwy::ThreadPool& pool) { const KVCaches& kv_caches, hwy::ThreadPool& pool) {
constexpr size_t kModelDim = TConfig::kModelDim; constexpr size_t kModelDim = TConfig::kModelDim;
const size_t num_interleaved = num_tokens * num_queries; const size_t num_interleaved = num_tokens * num_queries;
auto type = TConfig::kLayerConfig[layer]; auto type = TConfig::kLayerConfig[layer];
@ -688,7 +698,8 @@ class PrefillState {
template <class TConfig> template <class TConfig>
HWY_NOINLINE void Prefill(const MultiplePromptsTokens& prompts, HWY_NOINLINE void Prefill(const MultiplePromptsTokens& prompts,
const size_t prefill_per_query, const size_t pos, const size_t prefill_per_query,
const MultiplePositions& pos,
const size_t query_idx_start, const size_t query_idx_start,
const CompressedWeights<TConfig>& weights, const CompressedWeights<TConfig>& weights,
const RuntimeConfig& runtime_config, const RuntimeConfig& runtime_config,
@ -719,14 +730,19 @@ class PrefillState {
HWY_MIN(max_tbatch_size, prefill_per_query - tbatch_start); HWY_MIN(max_tbatch_size, prefill_per_query - tbatch_start);
for (size_t ti = 0; ti < tbatch_size; ++ti) { for (size_t ti = 0; ti < tbatch_size; ++ti) {
const int token = prompts[qi][tbatch_start + ti]; const int token = prompts[qi][tbatch_start + ti];
EmbedToken<TConfig>(token, ti, pos + ti, weights, activations.x); EmbedToken<TConfig>(token, ti, pos[qi] + ti, weights,
activations.x);
} }
const size_t tbatch_pos = pos[qi] + tbatch_start;
const MultiplePositions prefill_tbatch_pos(&tbatch_pos,
kPrefillQueries);
// Transformer with one batch of tokens from a single query. // Transformer with one batch of tokens from a single query.
for (size_t layer = 0; layer < TConfig::kLayers; ++layer) { for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
const auto* layer_weights = weights.GetLayer(layer); const auto* layer_weights = weights.GetLayer(layer);
TransformerLayer<TConfig>(tbatch_size, kPrefillQueries, TransformerLayer<TConfig>(tbatch_size, kPrefillQueries,
pos + tbatch_start, layer, prefill_tbatch_pos, layer,
layer_weights, activations, div_seq_len, layer_weights, activations, div_seq_len,
prefill_kv_caches, inner_pool); prefill_kv_caches, inner_pool);
} }
@ -735,7 +751,7 @@ class PrefillState {
for (size_t ti = 0; ti < tbatch_size; ++ti) { for (size_t ti = 0; ti < tbatch_size; ++ti) {
const int token = prompts[qi][tbatch_start + ti]; const int token = prompts[qi][tbatch_start + ti];
runtime_config.StreamToken(query_idx_start + qi, runtime_config.StreamToken(query_idx_start + qi,
pos + tbatch_start + ti, token, 0.0f); tbatch_pos + ti, token, 0.0f);
} }
} // for tbatch_start } // for tbatch_start
}); });
@ -749,7 +765,7 @@ class PrefillState {
// `num_tokens == 1`. // `num_tokens == 1`.
template <class TConfig> template <class TConfig>
HWY_NOINLINE void Transformer(const int* tokens, size_t num_tokens, HWY_NOINLINE void Transformer(const int* tokens, size_t num_tokens,
size_t num_queries, size_t pos, size_t num_queries, const MultiplePositions& pos,
const CompressedWeights<TConfig>& weights, const CompressedWeights<TConfig>& weights,
Activations& activations, Activations& activations,
const hwy::Divisor& div_seq_len, const hwy::Divisor& div_seq_len,
@ -759,14 +775,15 @@ HWY_NOINLINE void Transformer(const int* tokens, size_t num_tokens,
if (layers_output) { if (layers_output) {
for (size_t token_idx = 0; token_idx < num_interleaved; ++token_idx) { for (size_t token_idx = 0; token_idx < num_interleaved; ++token_idx) {
const size_t query_idx = token_idx % num_queries; const size_t query_idx = token_idx % num_queries;
const size_t logical_pos = (pos + token_idx) / num_queries; const size_t logical_pos = (pos[query_idx] + token_idx) / num_queries;
const float token_f = tokens[token_idx]; const float token_f = tokens[token_idx];
layers_output(query_idx, logical_pos, "tokens", -1, &token_f, 1); layers_output(query_idx, logical_pos, "tokens", -1, &token_f, 1);
} }
} }
constexpr size_t kModelDim = TConfig::kModelDim; constexpr size_t kModelDim = TConfig::kModelDim;
for (size_t token_idx = 0; token_idx < num_interleaved; ++token_idx) { for (size_t token_idx = 0; token_idx < num_interleaved; ++token_idx) {
EmbedToken<TConfig>(tokens[token_idx], token_idx, pos, weights, const size_t query_idx = token_idx % num_queries;
EmbedToken<TConfig>(tokens[token_idx], token_idx, pos[query_idx], weights,
activations.x); activations.x);
} }
@ -778,7 +795,8 @@ HWY_NOINLINE void Transformer(const int* tokens, size_t num_tokens,
if (layers_output) { if (layers_output) {
for (size_t token_idx = 0; token_idx < num_interleaved; ++token_idx) { for (size_t token_idx = 0; token_idx < num_interleaved; ++token_idx) {
const size_t logical_pos = (pos + token_idx) / num_queries; const size_t query_idx = token_idx % num_queries;
const size_t logical_pos = (pos[query_idx] + token_idx) / num_queries;
layers_output(token_idx % num_queries, logical_pos, "blocks", layer, layers_output(token_idx % num_queries, logical_pos, "blocks", layer,
activations.x.Batch(token_idx), kModelDim); activations.x.Batch(token_idx), kModelDim);
} }
@ -790,7 +808,7 @@ HWY_NOINLINE void Transformer(const int* tokens, size_t num_tokens,
if (layers_output) { if (layers_output) {
for (size_t token_idx = 0; token_idx < num_interleaved; ++token_idx) { for (size_t token_idx = 0; token_idx < num_interleaved; ++token_idx) {
const size_t query_idx = token_idx % num_queries; const size_t query_idx = token_idx % num_queries;
const size_t logical_pos = (pos + token_idx) / num_queries; const size_t logical_pos = (pos[query_idx] + token_idx) / num_queries;
layers_output(query_idx, logical_pos, "final_norm", -1, layers_output(query_idx, logical_pos, "final_norm", -1,
activations.x.Batch(token_idx), kModelDim); activations.x.Batch(token_idx), kModelDim);
} }
@ -897,9 +915,10 @@ class TokenStreamer {
template <class TConfig> template <class TConfig>
void GenerateT(const ByteStorageT& weights_u8, Activations& activations, void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
const RuntimeConfig& runtime_config, const RuntimeConfig& runtime_config,
const MultiplePromptsTokens& prompts, const size_t pos, const MultiplePromptsTokens& prompts,
const size_t query_idx_start, const KVCaches& kv_caches, const MultiplePositions& pos, const size_t query_idx_start,
PerClusterPools& pools, TimingInfo& timing_info) { const KVCaches& kv_caches, PerClusterPools& pools,
TimingInfo& timing_info) {
constexpr size_t kModelDim = TConfig::kModelDim; constexpr size_t kModelDim = TConfig::kModelDim;
constexpr size_t kVocabSize = TConfig::kVocabSize; constexpr size_t kVocabSize = TConfig::kVocabSize;
const CompressedWeights<TConfig>& weights = const CompressedWeights<TConfig>& weights =
@ -921,10 +940,12 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
size_t max_tokens = runtime_config.max_tokens; size_t max_tokens = runtime_config.max_tokens;
size_t max_generated_tokens = runtime_config.max_generated_tokens; size_t max_generated_tokens = runtime_config.max_generated_tokens;
RangeChecks<TConfig>(max_tokens, max_generated_tokens, max_prompt_size); RangeChecks<TConfig>(max_tokens, max_generated_tokens, max_prompt_size);
if (pos >= max_tokens) { for (auto i = pos.cbegin(); i != pos.cend(); ++i) {
fprintf(stderr, "Warning: pos %zu >= max_tokens %zu, aborting.\n", pos, if (*i >= max_tokens) {
max_tokens); fprintf(stderr, "Warning: pos %zu >= max_tokens %zu, aborting.\n", *i,
return; max_tokens);
return;
}
} }
// If no sample_func is provided, we use top-k sampling. // If no sample_func is provided, we use top-k sampling.
@ -953,7 +974,10 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
timing_info.NotifyPrefill(prefill_per_query * num_queries, prefill_start); timing_info.NotifyPrefill(prefill_per_query * num_queries, prefill_start);
} }
size_t interleaved_pos = (pos + prefill_per_query) * num_queries; std::vector<size_t> interleaved_pos(pos.size());
std::transform(
pos.cbegin(), pos.cend(), interleaved_pos.begin(),
[&](size_t v) { return (v + prefill_per_query) * num_queries; });
// Storage for the last generated token from each query, passed to the next // Storage for the last generated token from each query, passed to the next
// Transformer() call. // Transformer() call.
@ -972,10 +996,14 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
gen_per_query < HWY_MIN(max_tokens, max_generated_tokens); gen_per_query < HWY_MIN(max_tokens, max_generated_tokens);
++gen_per_query) { ++gen_per_query) {
// Decode: generate one token for each query. // Decode: generate one token for each query.
Transformer<TConfig>(gen_tokens.data(), /*num_tokens=*/1, num_queries, Transformer<TConfig>(
interleaved_pos, weights, activations, div_seq_len, gen_tokens.data(), /*num_tokens=*/1, num_queries,
kv_caches, pool, runtime_config.layers_output); MultiplePositions(interleaved_pos.data(), interleaved_pos.size()),
interleaved_pos += num_queries; weights, activations, div_seq_len, kv_caches, pool,
runtime_config.layers_output);
for (auto& v: interleaved_pos) {
v += num_queries;
}
bool all_queries_eos = true; bool all_queries_eos = true;
PROFILER_ZONE("Gen.Embedding"); PROFILER_ZONE("Gen.Embedding");
@ -1016,19 +1044,22 @@ void GenerateSingleT(const ByteStorageT& weights_u8,
activations.Allocate<TConfig>(num_queries); activations.Allocate<TConfig>(num_queries);
const MultiplePromptsTokens prompts(&prompt, num_queries); const MultiplePromptsTokens prompts(&prompt, num_queries);
const MultiplePositions positions(&pos, num_queries);
const KVCaches kv_caches{&kv_cache, num_queries}; const KVCaches kv_caches{&kv_cache, num_queries};
GenerateT<TConfig>(weights_u8, activations, runtime_config, prompts, pos, GenerateT<TConfig>(weights_u8, activations, runtime_config, prompts,
qbatch_start, kv_caches, pools, timing_info); positions, qbatch_start, kv_caches, pools, timing_info);
} }
template <class TConfig> template <class TConfig>
void GenerateBatchT(const ByteStorageT& weights_u8, void GenerateBatchT(const ByteStorageT& weights_u8,
const RuntimeConfig& runtime_config, const RuntimeConfig& runtime_config,
const MultiplePromptsTokens& prompts, size_t pos, const MultiplePromptsTokens& prompts,
const MultiplePositions& pos,
const KVCaches& kv_caches, PerClusterPools& pools, const KVCaches& kv_caches, PerClusterPools& pools,
TimingInfo& timing_info) { TimingInfo& timing_info) {
HWY_ASSERT(prompts.size() == kv_caches.size()); HWY_ASSERT(prompts.size() == pos.size() &&
prompts.size() == kv_caches.size());
// Griffin does not support query batching. // Griffin does not support query batching.
const size_t max_qbatch_size = const size_t max_qbatch_size =
(TConfig::kGriffinLayers > 0) ? 1 : runtime_config.decode_qbatch_size; (TConfig::kGriffinLayers > 0) ? 1 : runtime_config.decode_qbatch_size;
@ -1044,9 +1075,10 @@ void GenerateBatchT(const ByteStorageT& weights_u8,
HWY_MIN(num_queries - qbatch_start, max_qbatch_size); HWY_MIN(num_queries - qbatch_start, max_qbatch_size);
const MultiplePromptsTokens qbatch_prompts(&prompts[qbatch_start], const MultiplePromptsTokens qbatch_prompts(&prompts[qbatch_start],
qbatch_size); qbatch_size);
const MultiplePositions qbatch_pos(&pos[qbatch_start], qbatch_size);
const KVCaches qbatch_kv(&kv_caches[qbatch_start], qbatch_size); const KVCaches qbatch_kv(&kv_caches[qbatch_start], qbatch_size);
GenerateT<TConfig>(weights_u8, activations, runtime_config, qbatch_prompts, GenerateT<TConfig>(weights_u8, activations, runtime_config, qbatch_prompts,
pos, qbatch_start, qbatch_kv, pools, timing_info); qbatch_pos, qbatch_start, qbatch_kv, pools, timing_info);
} }
} }
@ -1067,8 +1099,8 @@ void GenerateSingle( // NOLINT(misc-definitions-in-headers)
void GenerateBatch( // NOLINT(misc-definitions-in-headers) void GenerateBatch( // NOLINT(misc-definitions-in-headers)
GEMMA_CONFIG, const ByteStorageT& weights_u8, GEMMA_CONFIG, const ByteStorageT& weights_u8,
const RuntimeConfig& runtime_config, const MultiplePromptsTokens& prompts, const RuntimeConfig& runtime_config, const MultiplePromptsTokens& prompts,
size_t pos, const KVCaches& kv_caches, PerClusterPools& pools, const MultiplePositions& pos, const KVCaches& kv_caches,
TimingInfo& timing_info) { PerClusterPools& pools, TimingInfo& timing_info) {
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateBatchT<GEMMA_CONFIG>) HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateBatchT<GEMMA_CONFIG>)
(weights_u8, runtime_config, prompts, pos, kv_caches, pools, timing_info); (weights_u8, runtime_config, prompts, pos, kv_caches, pools, timing_info);
} }

View File

@ -66,7 +66,8 @@ Gemma::~Gemma() {
TimingInfo& timing_info); \ TimingInfo& timing_info); \
extern void GenerateBatch(CONFIGT<TWEIGHT>, const ByteStorageT& weights_u8, \ extern void GenerateBatch(CONFIGT<TWEIGHT>, const ByteStorageT& weights_u8, \
const RuntimeConfig& runtime_config, \ const RuntimeConfig& runtime_config, \
const MultiplePromptsTokens& prompts, size_t pos, \ const MultiplePromptsTokens& prompts, \
const MultiplePositions& pos, \
const KVCaches& kv_caches, PerClusterPools& pools, \ const KVCaches& kv_caches, PerClusterPools& pools, \
TimingInfo& timing_info); TimingInfo& timing_info);
GEMMA_FOREACH_CONFIG_AND_WEIGHT(GEMMA_DECLARE); GEMMA_FOREACH_CONFIG_AND_WEIGHT(GEMMA_DECLARE);
@ -87,9 +88,9 @@ template <class TConfig>
struct GenerateBatchT { struct GenerateBatchT {
void operator()(const ByteStorageT& weights_u8, void operator()(const ByteStorageT& weights_u8,
const RuntimeConfig& runtime_config, const RuntimeConfig& runtime_config,
const MultiplePromptsTokens& prompts, size_t pos, const MultiplePromptsTokens& prompts,
const KVCaches& kv_caches, PerClusterPools& pools, const MultiplePositions& pos, const KVCaches& kv_caches,
TimingInfo& timing_info) const { PerClusterPools& pools, TimingInfo& timing_info) const {
GenerateBatch(TConfig(), weights_u8, runtime_config, prompts, pos, GenerateBatch(TConfig(), weights_u8, runtime_config, prompts, pos,
kv_caches, pools, timing_info); kv_caches, pools, timing_info);
} }
@ -109,8 +110,8 @@ void Gemma::Generate(const RuntimeConfig& runtime_config,
void Gemma::GenerateBatch(const RuntimeConfig& runtime_config, void Gemma::GenerateBatch(const RuntimeConfig& runtime_config,
const MultiplePromptsTokens& prompts, const MultiplePromptsTokens& prompts,
size_t start_pos, const KVCaches& kv_caches, const MultiplePositions& start_pos,
TimingInfo& timing_info) { const KVCaches& kv_caches, TimingInfo& timing_info) {
pools_.StartSpinning(); pools_.StartSpinning();
CallForModelAndWeight<GenerateBatchT>(info_.model, info_.weight, weights_u8_, CallForModelAndWeight<GenerateBatchT>(info_.model, info_.weight, weights_u8_,

View File

@ -143,6 +143,7 @@ struct TimingInfo {
using PromptTokens = hwy::Span<const int>; using PromptTokens = hwy::Span<const int>;
using MultiplePromptsTokens = hwy::Span<const PromptTokens>; using MultiplePromptsTokens = hwy::Span<const PromptTokens>;
using MultiplePositions = hwy::Span<const size_t>;
using KVCaches = hwy::Span<KVCache>; using KVCaches = hwy::Span<KVCache>;
class Gemma { class Gemma {
@ -164,7 +165,8 @@ class Gemma {
size_t start_pos, KVCache& kv_cache, TimingInfo& timing_info); size_t start_pos, KVCache& kv_cache, TimingInfo& timing_info);
void GenerateBatch(const RuntimeConfig& runtime_config, void GenerateBatch(const RuntimeConfig& runtime_config,
const MultiplePromptsTokens& prompts, size_t start_pos, const MultiplePromptsTokens& prompts,
const MultiplePositions& start_pos,
const KVCaches& kv_caches, TimingInfo& timing_info); const KVCaches& kv_caches, TimingInfo& timing_info);
private: private: