mirror of https://github.com/google/gemma.cpp.git
Simplify pos handling, auto-increment output arg
- no longer multiply by num_queries - remove unused interleaved prompts - Rename to Queries* - Rename batch_start/interleaved_pos/pos to queries_pos PiperOrigin-RevId: 663331823
This commit is contained in:
parent
6763afcd1c
commit
22995c699d
|
|
@ -108,10 +108,9 @@ std::pair<std::string, size_t> GemmaEnv::QueryModel(
|
|||
std::string res;
|
||||
size_t total_tokens = 0;
|
||||
|
||||
const double time_start = hwy::platform::Now();
|
||||
const BatchStreamFunc batch_stream_token =
|
||||
[&res, &total_tokens, &time_start, this](
|
||||
size_t query_index, size_t pos, int token, float) {
|
||||
const BatchStreamFunc batch_stream_token = [&res, &total_tokens, this](
|
||||
size_t query_index, size_t pos,
|
||||
int token, float) {
|
||||
++total_tokens;
|
||||
res += StringFromTokens(std::vector<int>{token});
|
||||
return true;
|
||||
|
|
@ -130,23 +129,19 @@ std::pair<std::string, size_t> GemmaEnv::QueryModel(
|
|||
}
|
||||
|
||||
std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel2(
|
||||
const MultiplePromptsTokens& prompts) {
|
||||
const size_t num_queries = prompts.size();
|
||||
const QueriesPromptTokens& queries_prompt) {
|
||||
const size_t num_queries = queries_prompt.size();
|
||||
HWY_ASSERT(num_queries != 0);
|
||||
std::vector<std::pair<std::string, size_t>> res(num_queries);
|
||||
std::fill(res.begin(), res.end(), std::make_pair("", 0));
|
||||
size_t total_tokens = 0;
|
||||
|
||||
const double time_start = hwy::platform::Now();
|
||||
const BatchStreamFunc batch_stream_token =
|
||||
[&res, &total_tokens, &time_start, this](
|
||||
size_t query_index, size_t pos, int token, float) {
|
||||
const BatchStreamFunc batch_stream_token = [&res, this](size_t query_index,
|
||||
size_t pos, int token,
|
||||
float) {
|
||||
std::string token_text;
|
||||
HWY_ASSERT(
|
||||
model_->Tokenizer().Decode(std::vector<int>{token}, &token_text));
|
||||
res[query_index].first.append(token_text);
|
||||
res[query_index].second += 1;
|
||||
++total_tokens;
|
||||
return true;
|
||||
};
|
||||
if (app_.verbosity >= 2) {
|
||||
|
|
@ -171,8 +166,9 @@ std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel2(
|
|||
gcpp::TimingInfo timing_info = {.verbosity = app_.verbosity};
|
||||
runtime_config_.batch_stream_token = batch_stream_token;
|
||||
inference_args_.CopyTo(runtime_config_);
|
||||
model_->GenerateBatch(runtime_config_, prompts,
|
||||
std::vector<size_t>(num_queries, 0),
|
||||
std::vector<size_t> queries_pos(num_queries, 0);
|
||||
model_->GenerateBatch(runtime_config_, queries_prompt,
|
||||
QueriesPos(queries_pos.data(), num_queries),
|
||||
KVCaches(&kv_caches_[0], num_queries), timing_info);
|
||||
return res;
|
||||
}
|
||||
|
|
@ -197,7 +193,7 @@ std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel(
|
|||
for (auto& prompt : prompts) {
|
||||
prompt_vector.push_back(PromptTokens(prompt.data(), prompt.size()));
|
||||
}
|
||||
MultiplePromptsTokens prompt_span(prompt_vector.data(), prompt_vector.size());
|
||||
QueriesPromptTokens prompt_span(prompt_vector.data(), prompt_vector.size());
|
||||
return BatchQueryModel2(prompt_span);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -69,7 +69,7 @@ class GemmaEnv {
|
|||
// the number of tokens that were generated.
|
||||
std::pair<std::string, size_t> QueryModel(const std::vector<int>& tokens);
|
||||
std::vector<std::pair<std::string, size_t>> BatchQueryModel2(
|
||||
const MultiplePromptsTokens& prompts);
|
||||
const QueriesPromptTokens& queries_prompt);
|
||||
// Adds turn structure to input, tokenizes and calls the above overload.
|
||||
std::pair<std::string, size_t> QueryModel(std::string& input);
|
||||
std::vector<std::pair<std::string, size_t>> BatchQueryModel(
|
||||
|
|
|
|||
|
|
@ -82,7 +82,7 @@ class GemmaTest : public ::testing::Test {
|
|||
for (const auto& prompt : prompts_vector) {
|
||||
prompt_spans.push_back(PromptTokens(prompt.data(), prompt.size()));
|
||||
}
|
||||
MultiplePromptsTokens prompts(prompt_spans.data(), prompt_spans.size());
|
||||
QueriesPromptTokens prompts(prompt_spans.data(), prompt_spans.size());
|
||||
for (auto [response, n] : s_env->BatchQueryModel2(prompts)) {
|
||||
replies.push_back(response);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -110,13 +110,13 @@ void Run(GemmaEnv& env, JsonArgs& json) {
|
|||
|
||||
std::vector<int> predicted_token_ids;
|
||||
predicted_token_ids.reserve(4096);
|
||||
size_t current_pos = 0;
|
||||
const StreamFunc stream_token = [¤t_pos, prompt_size,
|
||||
size_t generated = 0;
|
||||
const StreamFunc stream_token = [&generated, prompt_size,
|
||||
&predicted_token_ids](int token,
|
||||
float proba) {
|
||||
PROFILER_ZONE("Stream");
|
||||
++current_pos;
|
||||
if (current_pos > prompt_size) {
|
||||
++generated;
|
||||
if (generated > prompt_size) {
|
||||
predicted_token_ids.push_back(token);
|
||||
}
|
||||
return true;
|
||||
|
|
|
|||
|
|
@ -47,7 +47,7 @@ int main(int argc, char** argv) {
|
|||
gcpp::Gemma model = gcpp::CreateGemma(loader, pools);
|
||||
gcpp::KVCache kv_cache =
|
||||
gcpp::KVCache::Create(loader.Info().model, inference.prefill_tbatch_size);
|
||||
size_t pos = 0; // KV Cache position
|
||||
size_t generated = 0;
|
||||
|
||||
// Initialize random number generator
|
||||
std::mt19937 gen;
|
||||
|
|
@ -56,14 +56,14 @@ int main(int argc, char** argv) {
|
|||
|
||||
// Tokenize instructions.
|
||||
std::string prompt = "Write a greeting to the world.";
|
||||
const std::vector<int> tokens =
|
||||
gcpp::WrapAndTokenize(model.Tokenizer(), loader.Info(), pos, prompt);
|
||||
size_t ntokens = tokens.size();
|
||||
const std::vector<int> tokens = gcpp::WrapAndTokenize(
|
||||
model.Tokenizer(), loader.Info(), generated, prompt);
|
||||
const size_t prompt_size = tokens.size();
|
||||
|
||||
// This callback function gets invoked every time a token is generated
|
||||
auto stream_token = [&pos, &ntokens, &model](int token, float) {
|
||||
++pos;
|
||||
if (pos < ntokens) {
|
||||
auto stream_token = [&generated, &prompt_size, &model](int token, float) {
|
||||
++generated;
|
||||
if (generated < prompt_size) {
|
||||
// print feedback
|
||||
} else if (token != gcpp::EOS_ID) {
|
||||
std::string token_text;
|
||||
|
|
|
|||
|
|
@ -68,13 +68,13 @@ namespace HWY_NAMESPACE {
|
|||
// count `num_interleaved`. Functions that are query-dependent, such as
|
||||
// `Attention`, use separate `num_tokens` and `num_queries`.
|
||||
|
||||
// TODO: add batch query support for Griffin (QueriesPos).
|
||||
template <class TConfig>
|
||||
HWY_NOINLINE void GriffinRecurrent(
|
||||
size_t batch_start, size_t num_tokens, size_t num_queries, size_t layer,
|
||||
size_t batch_start, size_t num_tokens, size_t layer,
|
||||
Activations& activations, const CompressedLayer<TConfig>* layer_weights,
|
||||
const KVCaches& kv_caches, hwy::ThreadPool& pool) {
|
||||
PROFILER_ZONE("Gen.Griffin");
|
||||
HWY_ASSERT(num_queries == 1); // TODO: add batch query support for Griffin.
|
||||
KVCache& kv_cache = kv_caches[0];
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
using D = hn::ScalableTag<float>;
|
||||
|
|
@ -233,8 +233,7 @@ class GemmaAttention {
|
|||
// Fills activations.q and computes KV. For kIsMHA, a single MatMul suffices
|
||||
// and we later copy KV from q to KVCache. Otherwise, a second MatMul writes
|
||||
// KV directly to KVCache.
|
||||
HWY_NOINLINE void ComputeQKV(const MultiplePositions& batch_start,
|
||||
const size_t num_interleaved) {
|
||||
HWY_NOINLINE void ComputeQKV(const size_t num_interleaved) {
|
||||
PROFILER_ZONE("Gen.Attention.QKV");
|
||||
// For the computation of Q, K, and V, it is useful to remember that
|
||||
// qkv_einsum_w has shape [(kHeads + kKVHeads * 2), kKQVDim, kModelDim]
|
||||
|
|
@ -255,9 +254,9 @@ class GemmaAttention {
|
|||
// Single query and no wraparound means we can use a matmul and write
|
||||
// directly into the KV cache with a stride of kCachePosSize.
|
||||
if (num_queries_ == 1 &&
|
||||
batch_start[0] + num_tokens_ <= div_seq_len_.GetDivisor()) {
|
||||
queries_pos_[0] + num_tokens_ <= div_seq_len_.GetDivisor()) {
|
||||
const size_t kv_ofs =
|
||||
batch_start[0] * kCachePosSize + layer_ * kCacheLayerSize;
|
||||
queries_pos_[0] * kCachePosSize + layer_ * kCacheLayerSize;
|
||||
// KV structure is [k, v, k, v, ....] = kKVHeads pairs of (k, v).
|
||||
float* HWY_RESTRICT kv = kv_caches_[0].kv_cache.get() + kv_ofs;
|
||||
MatMul_4x4</*kAdd=*/false>(
|
||||
|
|
@ -275,7 +274,7 @@ class GemmaAttention {
|
|||
const size_t batch_idx = interleaved_idx / num_queries_;
|
||||
KVCache& kv_cache = kv_caches_[query_idx];
|
||||
const size_t cache_pos =
|
||||
div_seq_len_.Remainder(batch_start[query_idx] + batch_idx);
|
||||
div_seq_len_.Remainder(queries_pos_[query_idx] + batch_idx);
|
||||
const size_t kv_offset =
|
||||
cache_pos * kCachePosSize + layer_ * kCacheLayerSize;
|
||||
float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
|
||||
|
|
@ -295,7 +294,7 @@ class GemmaAttention {
|
|||
const size_t interleaved_idx = task / kKVHeads;
|
||||
const size_t query_idx = interleaved_idx % num_queries_;
|
||||
const size_t batch_idx = interleaved_idx / num_queries_;
|
||||
const size_t pos = batch_start[query_idx] + batch_idx;
|
||||
const size_t pos = queries_pos_[query_idx] + batch_idx;
|
||||
const size_t cache_pos = div_seq_len_.Remainder(pos);
|
||||
const size_t kv_offset = cache_pos * kCachePosSize +
|
||||
layer_ * kCacheLayerSize +
|
||||
|
|
@ -374,8 +373,7 @@ class GemmaAttention {
|
|||
}
|
||||
}
|
||||
|
||||
HWY_NOINLINE void DotSoftmaxWeightedSum(const MultiplePositions& batch_start,
|
||||
const size_t num_interleaved) {
|
||||
HWY_NOINLINE void DotSoftmaxWeightedSum(const size_t num_interleaved) {
|
||||
PROFILER_ZONE("Gen.Attention.DotSoftmax");
|
||||
GEMMA_CONSTEXPR_SQRT float kQueryScale = ChooseQueryScale<TConfig>();
|
||||
|
||||
|
|
@ -398,7 +396,7 @@ class GemmaAttention {
|
|||
activations_.q.Batch(interleaved_idx) + head * kQStride;
|
||||
|
||||
// Apply rope and scaling to Q.
|
||||
const size_t pos = batch_start[query_idx] + batch_idx;
|
||||
const size_t pos = queries_pos_[query_idx] + batch_idx;
|
||||
PositionalEncodingQK(q, pos, layer_, kQueryScale, q);
|
||||
|
||||
const size_t start_pos = StartPos(pos, layer_);
|
||||
|
|
@ -440,45 +438,34 @@ class GemmaAttention {
|
|||
}
|
||||
|
||||
public:
|
||||
GemmaAttention(const MultiplePositions& interleaved_start, size_t num_tokens,
|
||||
size_t num_queries, size_t layer, Activations& activations,
|
||||
GemmaAttention(const QueriesPos& queries_pos, size_t num_tokens, size_t layer,
|
||||
Activations& activations,
|
||||
const CompressedLayer<TConfig>* layer_weights,
|
||||
const hwy::Divisor& div_seq_len, const KVCaches& kv_caches,
|
||||
hwy::ThreadPool& pool)
|
||||
: num_tokens_(num_tokens),
|
||||
num_queries_(num_queries),
|
||||
: queries_pos_(queries_pos),
|
||||
num_queries_(queries_pos.size()),
|
||||
num_tokens_(num_tokens),
|
||||
layer_(layer),
|
||||
activations_(activations),
|
||||
layer_weights_(*layer_weights),
|
||||
div_seq_len_(div_seq_len),
|
||||
kv_caches_(kv_caches),
|
||||
pool_(pool) {
|
||||
HWY_DASSERT(
|
||||
std::all_of(interleaved_start.cbegin(), interleaved_start.cend(),
|
||||
[this](size_t pos) { return pos % num_queries_ == 0; }));
|
||||
HWY_DASSERT(num_queries_ <= kv_caches_.size());
|
||||
|
||||
batch_start_.reserve(interleaved_start.size());
|
||||
for (auto i = interleaved_start.cbegin(); i != interleaved_start.cend();
|
||||
++i) {
|
||||
batch_start_.emplace_back(*i / num_queries_);
|
||||
}
|
||||
}
|
||||
|
||||
HWY_INLINE void operator()() {
|
||||
const MultiplePositions batch_start(batch_start_.data(),
|
||||
batch_start_.size());
|
||||
const size_t num_interleaved = num_tokens_ * num_queries_;
|
||||
|
||||
ComputeQKV(batch_start, num_interleaved);
|
||||
DotSoftmaxWeightedSum(batch_start, num_interleaved);
|
||||
ComputeQKV(num_interleaved);
|
||||
DotSoftmaxWeightedSum(num_interleaved);
|
||||
SumHeads(num_interleaved);
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<size_t> batch_start_;
|
||||
const size_t num_tokens_;
|
||||
const QueriesPos& queries_pos_;
|
||||
const size_t num_queries_;
|
||||
const size_t num_tokens_;
|
||||
const size_t layer_;
|
||||
Activations& activations_;
|
||||
const CompressedLayer<TConfig>& layer_weights_;
|
||||
|
|
@ -489,24 +476,21 @@ class GemmaAttention {
|
|||
|
||||
template <class TConfig>
|
||||
HWY_NOINLINE void Attention(LayerAttentionType type,
|
||||
const MultiplePositions& interleaved_start,
|
||||
size_t num_tokens, size_t num_queries, size_t layer,
|
||||
Activations& activations,
|
||||
const QueriesPos& queries_pos, size_t num_tokens,
|
||||
size_t layer, Activations& activations,
|
||||
const CompressedLayer<TConfig>* layer_weights,
|
||||
const hwy::Divisor& div_seq_len,
|
||||
const KVCaches& kv_caches, hwy::ThreadPool& pool) {
|
||||
if (type == LayerAttentionType::kGemma) {
|
||||
GemmaAttention<TConfig>(interleaved_start, num_tokens, num_queries, layer,
|
||||
activations, layer_weights, div_seq_len, kv_caches,
|
||||
pool)();
|
||||
GemmaAttention<TConfig>(queries_pos, num_tokens, layer, activations,
|
||||
layer_weights, div_seq_len, kv_caches, pool)();
|
||||
} else {
|
||||
// Only reached if the model is Griffin. `if constexpr` prevents generating
|
||||
// this code for non-Griffin models.
|
||||
if constexpr (TConfig::kGriffinLayers > 0) {
|
||||
HWY_ASSERT(num_queries == 1);
|
||||
GriffinRecurrent<TConfig>(interleaved_start[0], num_tokens, num_queries,
|
||||
layer, activations, layer_weights, kv_caches,
|
||||
pool);
|
||||
HWY_ASSERT(queries_pos.size() == 1);
|
||||
GriffinRecurrent<TConfig>(queries_pos[0], num_tokens, layer, activations,
|
||||
layer_weights, kv_caches, pool);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -608,12 +592,12 @@ void PostNorm(size_t num_interleaved, const WeightT& weights, InOutT* inout) {
|
|||
|
||||
template <class TConfig>
|
||||
HWY_NOINLINE void TransformerLayer(
|
||||
size_t num_tokens, size_t num_queries, const MultiplePositions& pos,
|
||||
size_t layer, const CompressedLayer<TConfig>* layer_weights,
|
||||
Activations& activations, const hwy::Divisor& div_seq_len,
|
||||
const KVCaches& kv_caches, hwy::ThreadPool& pool) {
|
||||
const QueriesPos& queries_pos, size_t num_tokens, size_t layer,
|
||||
const CompressedLayer<TConfig>* layer_weights, Activations& activations,
|
||||
const hwy::Divisor& div_seq_len, const KVCaches& kv_caches,
|
||||
hwy::ThreadPool& pool) {
|
||||
constexpr size_t kModelDim = TConfig::kModelDim;
|
||||
const size_t num_interleaved = num_tokens * num_queries;
|
||||
const size_t num_interleaved = num_tokens * queries_pos.size();
|
||||
auto type = TConfig::kLayerConfig[layer];
|
||||
size_t layer_of_type =
|
||||
NumLayersOfTypeBefore(TConfig::kLayerConfig, type, layer);
|
||||
|
|
@ -622,8 +606,8 @@ HWY_NOINLINE void TransformerLayer(
|
|||
layer_weights->pre_attention_norm_scale.data_scale1(),
|
||||
activations.pre_att_rms_out.All(), kModelDim);
|
||||
|
||||
Attention<TConfig>(type, pos, num_tokens, num_queries, layer_of_type,
|
||||
activations, layer_weights, div_seq_len, kv_caches, pool);
|
||||
Attention<TConfig>(type, queries_pos, num_tokens, layer_of_type, activations,
|
||||
layer_weights, div_seq_len, kv_caches, pool);
|
||||
|
||||
PostNorm<TConfig>(num_interleaved, layer_weights->post_attention_norm_scale,
|
||||
activations.att_sums.All());
|
||||
|
|
@ -646,6 +630,9 @@ HWY_NOINLINE void TransformerLayer(
|
|||
/*is_attention=*/false);
|
||||
}
|
||||
|
||||
// Prefill and Transformer() advance positions in-place.
|
||||
using QueriesMutablePos = hwy::Span<size_t>;
|
||||
|
||||
// Batches are important for amortizing loading weights over multiple tokens.
|
||||
// This is possible in prefill because we know all tokens beforehand, whereas
|
||||
// decode depends on the previous output token. However, each prefill batch of a
|
||||
|
|
@ -696,16 +683,16 @@ class PrefillState {
|
|||
}
|
||||
|
||||
template <class TConfig>
|
||||
HWY_NOINLINE void Prefill(const MultiplePromptsTokens& prompts,
|
||||
HWY_NOINLINE void Prefill(const QueriesPromptTokens& queries_prompt,
|
||||
const size_t prefill_per_query,
|
||||
const MultiplePositions& pos,
|
||||
const QueriesMutablePos& queries_pos,
|
||||
const size_t query_idx_start,
|
||||
const CompressedWeights<TConfig>& weights,
|
||||
const RuntimeConfig& runtime_config,
|
||||
const hwy::Divisor& div_seq_len,
|
||||
const KVCaches& kv_caches, PerClusterPools& pools) {
|
||||
PROFILER_ZONE("Gen.Prefill");
|
||||
const size_t num_queries = prompts.size();
|
||||
const size_t num_queries = queries_prompt.size();
|
||||
HWY_ASSERT(kv_caches.size() == num_queries);
|
||||
const size_t max_tbatch_size = activations_[0].x.BatchSize();
|
||||
|
||||
|
|
@ -716,10 +703,10 @@ class PrefillState {
|
|||
Activations& activations = activations_[qi];
|
||||
hwy::ThreadPool& inner_pool = pools.Inner(qthread);
|
||||
|
||||
// Single query at a time, so pass a slice of the KV cache because
|
||||
// GemmaAttention will only access the first.
|
||||
const size_t kPrefillQueries = 1;
|
||||
KVCaches prefill_kv_caches(&kv_caches[qi], kPrefillQueries);
|
||||
// Single query at a time, so pass slices of the spans because
|
||||
// GemmaAttention will only access the first KV cache and position.
|
||||
KVCaches single_kv_cache(&kv_caches[qi], 1);
|
||||
QueriesPos single_query_pos(&queries_pos[qi], 1);
|
||||
|
||||
// For each batch of tokens in the query:
|
||||
for (size_t tbatch_start = 0; tbatch_start < prefill_per_query;
|
||||
|
|
@ -728,30 +715,28 @@ class PrefillState {
|
|||
const size_t tbatch_size =
|
||||
HWY_MIN(max_tbatch_size, prefill_per_query - tbatch_start);
|
||||
for (size_t ti = 0; ti < tbatch_size; ++ti) {
|
||||
const int token = prompts[qi][tbatch_start + ti];
|
||||
EmbedToken<TConfig>(token, ti, pos[qi] + ti, weights,
|
||||
activations.x);
|
||||
const int token = queries_prompt[qi][tbatch_start + ti];
|
||||
const size_t pos = queries_pos[qi] + ti;
|
||||
EmbedToken<TConfig>(token, ti, pos, weights, activations.x);
|
||||
}
|
||||
|
||||
const size_t tbatch_pos = pos[qi] + tbatch_start;
|
||||
const MultiplePositions prefill_tbatch_pos(&tbatch_pos,
|
||||
kPrefillQueries);
|
||||
|
||||
// Transformer with one batch of tokens from a single query.
|
||||
for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
|
||||
const auto* layer_weights = weights.GetLayer(layer);
|
||||
TransformerLayer<TConfig>(tbatch_size, kPrefillQueries,
|
||||
prefill_tbatch_pos, layer,
|
||||
TransformerLayer<TConfig>(single_query_pos, tbatch_size, layer,
|
||||
layer_weights, activations, div_seq_len,
|
||||
prefill_kv_caches, inner_pool);
|
||||
single_kv_cache, inner_pool);
|
||||
}
|
||||
|
||||
// NOTE: we unconditionally call StreamToken, even if EOS.
|
||||
for (size_t ti = 0; ti < tbatch_size; ++ti) {
|
||||
const int token = prompts[qi][tbatch_start + ti];
|
||||
runtime_config.StreamToken(query_idx_start + qi, tbatch_pos + ti,
|
||||
token, 0.0f);
|
||||
const size_t pos = queries_pos[qi] + ti;
|
||||
const int token = queries_prompt[qi][pos];
|
||||
runtime_config.StreamToken(query_idx_start + qi, pos, token,
|
||||
0.0f);
|
||||
}
|
||||
|
||||
queries_pos[qi] += tbatch_size;
|
||||
} // for tbatch_start
|
||||
});
|
||||
}
|
||||
|
|
@ -760,58 +745,60 @@ class PrefillState {
|
|||
std::vector<Activations> activations_; // One per query, filled by Init.
|
||||
};
|
||||
|
||||
// `tokens` is length `num_tokens * num_queries`. In autoregressive decode,
|
||||
// `num_tokens == 1`.
|
||||
// Generates one token for each query. `queries_token` is the previous token
|
||||
// from each query, and `queries_pos` are their position in the sequence.
|
||||
template <class TConfig>
|
||||
HWY_NOINLINE void Transformer(const int* tokens, size_t num_tokens,
|
||||
size_t num_queries, const MultiplePositions& pos,
|
||||
HWY_NOINLINE void Transformer(const QueriesToken& queries_token,
|
||||
const QueriesMutablePos& queries_pos,
|
||||
const CompressedWeights<TConfig>& weights,
|
||||
Activations& activations,
|
||||
const hwy::Divisor& div_seq_len,
|
||||
const KVCaches& kv_caches, hwy::ThreadPool& pool,
|
||||
const LayersOutputFunc& layers_output) {
|
||||
const size_t num_interleaved = num_tokens * num_queries;
|
||||
constexpr size_t kModelDim = TConfig::kModelDim;
|
||||
const size_t num_queries = queries_token.size();
|
||||
HWY_DASSERT(queries_pos.size() == num_queries);
|
||||
|
||||
if (layers_output) {
|
||||
for (size_t token_idx = 0; token_idx < num_interleaved; ++token_idx) {
|
||||
const size_t query_idx = token_idx % num_queries;
|
||||
const size_t logical_pos = (pos[query_idx] + token_idx) / num_queries;
|
||||
const float token_f = tokens[token_idx];
|
||||
layers_output(query_idx, logical_pos, "tokens", -1, &token_f, 1);
|
||||
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
|
||||
const float token_f = queries_token[query_idx];
|
||||
layers_output(query_idx, queries_pos[query_idx], "tokens", -1, &token_f,
|
||||
1);
|
||||
}
|
||||
}
|
||||
constexpr size_t kModelDim = TConfig::kModelDim;
|
||||
for (size_t token_idx = 0; token_idx < num_interleaved; ++token_idx) {
|
||||
const size_t query_idx = token_idx % num_queries;
|
||||
EmbedToken<TConfig>(tokens[token_idx], token_idx, pos[query_idx], weights,
|
||||
activations.x);
|
||||
|
||||
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
|
||||
EmbedToken<TConfig>(queries_token[query_idx], query_idx,
|
||||
queries_pos[query_idx], weights, activations.x);
|
||||
}
|
||||
|
||||
for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
|
||||
const CompressedLayer<TConfig>* layer_weights = weights.GetLayer(layer);
|
||||
TransformerLayer<TConfig>(num_tokens, num_queries, pos, layer,
|
||||
TransformerLayer<TConfig>(queries_pos, /*num_tokens=*/1, layer,
|
||||
layer_weights, activations, div_seq_len,
|
||||
kv_caches, pool);
|
||||
|
||||
if (layers_output) {
|
||||
for (size_t token_idx = 0; token_idx < num_interleaved; ++token_idx) {
|
||||
const size_t query_idx = token_idx % num_queries;
|
||||
const size_t logical_pos = (pos[query_idx] + token_idx) / num_queries;
|
||||
layers_output(token_idx % num_queries, logical_pos, "blocks", layer,
|
||||
activations.x.Batch(token_idx), kModelDim);
|
||||
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
|
||||
layers_output(query_idx, queries_pos[query_idx], "blocks", layer,
|
||||
activations.x.Batch(0), kModelDim);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
RMSNormInplaceBatched(num_interleaved, weights.final_norm_scale.data_scale1(),
|
||||
RMSNormInplaceBatched(num_queries, weights.final_norm_scale.data_scale1(),
|
||||
activations.x.All(), kModelDim);
|
||||
|
||||
if (layers_output) {
|
||||
for (size_t token_idx = 0; token_idx < num_interleaved; ++token_idx) {
|
||||
const size_t query_idx = token_idx % num_queries;
|
||||
const size_t logical_pos = (pos[query_idx] + token_idx) / num_queries;
|
||||
layers_output(query_idx, logical_pos, "final_norm", -1,
|
||||
activations.x.Batch(token_idx), kModelDim);
|
||||
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
|
||||
layers_output(query_idx, queries_pos[query_idx], "final_norm", -1,
|
||||
activations.x.Batch(0), kModelDim);
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
|
||||
queries_pos[query_idx] += 1;
|
||||
}
|
||||
}
|
||||
|
||||
template <class TConfig>
|
||||
|
|
@ -848,32 +835,16 @@ void RangeChecks(size_t& max_tokens, size_t& max_generated_tokens,
|
|||
|
||||
// Placeholder for internal test3, do not remove
|
||||
|
||||
// Returns interleaved tokens: one from each query, followed by the second from
|
||||
// all queries, with EOS padding.
|
||||
static std::vector<int> InterleaveQueries(const MultiplePromptsTokens& queries,
|
||||
const RuntimeConfig& runtime_config,
|
||||
size_t& min_prompt_size,
|
||||
size_t& max_prompt_size) {
|
||||
const size_t num_queries = queries.size();
|
||||
// Returns the min and max number of tokens for all queries.
|
||||
static void ScanQueryLengths(const QueriesPromptTokens& queries_prompt,
|
||||
size_t& min_prompt_size, size_t& max_prompt_size) {
|
||||
const size_t num_queries = queries_prompt.size();
|
||||
min_prompt_size = hwy::LimitsMax<size_t>();
|
||||
max_prompt_size = 0;
|
||||
for (size_t i = 0; i < num_queries; ++i) {
|
||||
min_prompt_size = std::min(min_prompt_size, queries[i].size());
|
||||
max_prompt_size = std::max(max_prompt_size, queries[i].size());
|
||||
min_prompt_size = std::min(min_prompt_size, queries_prompt[i].size());
|
||||
max_prompt_size = std::max(max_prompt_size, queries_prompt[i].size());
|
||||
}
|
||||
|
||||
std::vector<int> prompt;
|
||||
prompt.reserve(max_prompt_size * num_queries);
|
||||
for (size_t pos = 0; pos < max_prompt_size; ++pos) {
|
||||
for (size_t q = 0; q < num_queries; ++q) {
|
||||
if (pos < queries[q].size()) {
|
||||
prompt.push_back(queries[q][pos]);
|
||||
} else {
|
||||
prompt.push_back(runtime_config.eos_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
return prompt;
|
||||
}
|
||||
|
||||
// Holds "is at end of stream" state for each query.
|
||||
|
|
@ -901,21 +872,22 @@ class TokenStreamer {
|
|||
hwy::BitSet4096<> is_eos_;
|
||||
};
|
||||
|
||||
// Generates one token for each query in `prompts`, which is one qbatch whose
|
||||
// size is at most the `batch_size` passed to `activations.Allocate`.
|
||||
// Generates one continuation for each query in `queries_prompt`, which is one
|
||||
// qbatch whose size is at most the `batch_size` passed to
|
||||
// `activations.Allocate`.
|
||||
//
|
||||
// `pos` indexes the KV cache. In the first turn of a chat, pos = 0, and it
|
||||
// continues to increase by one for each prefilled/generated token per query.
|
||||
// `queries_pos` stores the KV cache position for each query. In the first turn
|
||||
// of a chat, pos = 0; we increment each query's position after each token.
|
||||
//
|
||||
// `query_idx_start` is the query_idx of the first query in the batch, so that
|
||||
// `StreamFunc` gets the global query index, not relative to the batch.
|
||||
//
|
||||
// `kv_caches` is for the batch, size must match `prompts`.
|
||||
// `kv_caches` is for the batch, size must match `queries_prompt`.
|
||||
template <class TConfig>
|
||||
void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
|
||||
const RuntimeConfig& runtime_config,
|
||||
const MultiplePromptsTokens& prompts,
|
||||
const MultiplePositions& pos, const size_t query_idx_start,
|
||||
const QueriesPromptTokens& queries_prompt,
|
||||
const QueriesPos& queries_pos_in, const size_t query_idx_start,
|
||||
const KVCaches& kv_caches, PerClusterPools& pools,
|
||||
TimingInfo& timing_info) {
|
||||
constexpr size_t kModelDim = TConfig::kModelDim;
|
||||
|
|
@ -926,22 +898,28 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
|
|||
// TODO: remove once all parallel sections support hierarchical parallelism.
|
||||
hwy::ThreadPool& pool = pools.Inner(0);
|
||||
|
||||
const size_t num_queries = prompts.size();
|
||||
// Copy so we can increment without requiring users to pass in a mutable span.
|
||||
std::vector<size_t> queries_pos_copy(queries_pos_in.cbegin(),
|
||||
queries_pos_in.cend());
|
||||
const QueriesMutablePos queries_mutable_pos(queries_pos_copy.data(),
|
||||
queries_pos_copy.size());
|
||||
|
||||
const size_t num_queries = queries_prompt.size();
|
||||
HWY_ASSERT(num_queries <= 4096); // TokenStreamer uses BitSet4096.
|
||||
HWY_ASSERT(num_queries <= activations.x.BatchSize());
|
||||
HWY_ASSERT(queries_pos_in.size() == num_queries);
|
||||
HWY_ASSERT(kv_caches.size() == num_queries);
|
||||
const hwy::Divisor div_seq_len(static_cast<uint32_t>(kv_caches[0].seq_len));
|
||||
|
||||
size_t min_prompt_size, max_prompt_size;
|
||||
const std::vector<int> prompt = InterleaveQueries(
|
||||
prompts, runtime_config, min_prompt_size, max_prompt_size);
|
||||
ScanQueryLengths(queries_prompt, min_prompt_size, max_prompt_size);
|
||||
|
||||
size_t max_tokens = runtime_config.max_tokens;
|
||||
size_t max_generated_tokens = runtime_config.max_generated_tokens;
|
||||
RangeChecks<TConfig>(max_tokens, max_generated_tokens, max_prompt_size);
|
||||
for (auto i = pos.cbegin(); i != pos.cend(); ++i) {
|
||||
if (*i >= max_tokens) {
|
||||
fprintf(stderr, "Warning: pos %zu >= max_tokens %zu, aborting.\n", *i,
|
||||
for (size_t pos : queries_pos_copy) {
|
||||
if (pos >= max_tokens) {
|
||||
fprintf(stderr, "Warning: pos %zu >= max_tokens %zu, aborting.\n", pos,
|
||||
max_tokens);
|
||||
return;
|
||||
}
|
||||
|
|
@ -967,17 +945,13 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
|
|||
prefill.Init<TConfig>(num_queries, runtime_config.prefill_tbatch_size,
|
||||
pools);
|
||||
prefill_start = hwy::platform::Now();
|
||||
prefill.Prefill<TConfig>(prompts, prefill_per_query, pos, query_idx_start,
|
||||
weights, runtime_config, div_seq_len, kv_caches,
|
||||
pools);
|
||||
prefill.Prefill<TConfig>(queries_prompt, prefill_per_query,
|
||||
queries_mutable_pos, query_idx_start, weights,
|
||||
runtime_config, div_seq_len, kv_caches, pools);
|
||||
timing_info.NotifyPrefill(prefill_per_query * num_queries, prefill_start);
|
||||
// queries_pos are incremented by Prefill.
|
||||
}
|
||||
|
||||
std::vector<size_t> interleaved_pos(pos.size());
|
||||
std::transform(
|
||||
pos.cbegin(), pos.cend(), interleaved_pos.begin(),
|
||||
[&](size_t v) { return (v + prefill_per_query) * num_queries; });
|
||||
|
||||
// Storage for the last generated token from each query, passed to the next
|
||||
// Transformer() call.
|
||||
std::vector<int> gen_tokens(num_queries);
|
||||
|
|
@ -985,25 +959,19 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
|
|||
// Stream the last prompt token from each query and fill gen_tokens.
|
||||
TokenStreamer token_streamer(runtime_config);
|
||||
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
|
||||
gen_tokens[query_idx] = prompts[query_idx][prefill_per_query];
|
||||
gen_tokens[query_idx] = queries_prompt[query_idx][prefill_per_query];
|
||||
(void)token_streamer(query_idx_start + query_idx,
|
||||
pos[query_idx] + prefill_per_query,
|
||||
gen_tokens[query_idx], 0.0f);
|
||||
queries_mutable_pos[query_idx], gen_tokens[query_idx],
|
||||
0.0f);
|
||||
}
|
||||
|
||||
const double gen_start = hwy::platform::Now();
|
||||
for (size_t gen_per_query = 0;
|
||||
gen_per_query < HWY_MIN(max_tokens, max_generated_tokens);
|
||||
++gen_per_query) {
|
||||
// Decode: generate one token for each query.
|
||||
Transformer<TConfig>(
|
||||
gen_tokens.data(), /*num_tokens=*/1, num_queries,
|
||||
MultiplePositions(interleaved_pos.data(), interleaved_pos.size()),
|
||||
weights, activations, div_seq_len, kv_caches, pool,
|
||||
runtime_config.layers_output);
|
||||
for (auto& v : interleaved_pos) {
|
||||
v += num_queries;
|
||||
}
|
||||
for (size_t gen = 0; gen < HWY_MIN(max_tokens, max_generated_tokens); ++gen) {
|
||||
// Decode generates one token per query and increments queries_mutable_pos.
|
||||
Transformer<TConfig>(QueriesToken(gen_tokens.data(), num_queries),
|
||||
queries_mutable_pos, weights, activations, div_seq_len,
|
||||
kv_caches, pool, runtime_config.layers_output);
|
||||
// queries_pos are incremented by Transformer.
|
||||
|
||||
bool all_queries_eos = true;
|
||||
PROFILER_ZONE("Gen.Embedding");
|
||||
|
|
@ -1022,8 +990,7 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
|
|||
|
||||
const bool is_eos =
|
||||
token_streamer(query_idx_start + query_idx,
|
||||
pos[query_idx] + prefill_per_query + 1 + gen_per_query,
|
||||
token, logits[token]);
|
||||
queries_mutable_pos[query_idx], token, logits[token]);
|
||||
all_queries_eos &= is_eos;
|
||||
gen_tokens[query_idx] = is_eos ? runtime_config.eos_id : token;
|
||||
}
|
||||
|
|
@ -1038,28 +1005,29 @@ void GenerateSingleT(const ByteStorageT& weights_u8,
|
|||
const RuntimeConfig& runtime_config,
|
||||
const PromptTokens& prompt, size_t pos, KVCache& kv_cache,
|
||||
PerClusterPools& pools, TimingInfo& timing_info) {
|
||||
const size_t num_queries = 1;
|
||||
constexpr size_t kNumQueries = 1;
|
||||
const size_t qbatch_start = 0;
|
||||
|
||||
Activations activations;
|
||||
activations.Allocate<TConfig>(num_queries);
|
||||
activations.Allocate<TConfig>(kNumQueries);
|
||||
|
||||
const MultiplePromptsTokens prompts(&prompt, num_queries);
|
||||
const MultiplePositions positions(&pos, num_queries);
|
||||
const KVCaches kv_caches{&kv_cache, num_queries};
|
||||
const QueriesPromptTokens prompt_span(&prompt, kNumQueries);
|
||||
QueriesPos pos_span(&pos, kNumQueries);
|
||||
const KVCaches kv_caches{&kv_cache, kNumQueries};
|
||||
|
||||
GenerateT<TConfig>(weights_u8, activations, runtime_config, prompts,
|
||||
positions, qbatch_start, kv_caches, pools, timing_info);
|
||||
GenerateT<TConfig>(weights_u8, activations, runtime_config, prompt_span,
|
||||
pos_span, qbatch_start, kv_caches, pools, timing_info);
|
||||
}
|
||||
|
||||
template <class TConfig>
|
||||
void GenerateBatchT(const ByteStorageT& weights_u8,
|
||||
const RuntimeConfig& runtime_config,
|
||||
const MultiplePromptsTokens& prompts,
|
||||
const MultiplePositions& pos, const KVCaches& kv_caches,
|
||||
const QueriesPromptTokens& queries_prompt,
|
||||
const QueriesPos& queries_pos, const KVCaches& kv_caches,
|
||||
PerClusterPools& pools, TimingInfo& timing_info) {
|
||||
HWY_ASSERT(prompts.size() == pos.size() &&
|
||||
prompts.size() == kv_caches.size());
|
||||
const size_t num_queries = queries_prompt.size();
|
||||
HWY_ASSERT(queries_pos.size() == num_queries);
|
||||
HWY_ASSERT(kv_caches.size() == num_queries);
|
||||
// Griffin does not support query batching.
|
||||
const size_t max_qbatch_size =
|
||||
(TConfig::kGriffinLayers > 0) ? 1 : runtime_config.decode_qbatch_size;
|
||||
|
|
@ -1067,15 +1035,14 @@ void GenerateBatchT(const ByteStorageT& weights_u8,
|
|||
Activations activations;
|
||||
activations.Allocate<TConfig>(max_qbatch_size);
|
||||
|
||||
const size_t num_queries = prompts.size();
|
||||
for (size_t qbatch_start = 0; qbatch_start < num_queries;
|
||||
qbatch_start += max_qbatch_size) {
|
||||
// Generate one batch of tokens from `qbatch_size` queries.
|
||||
const size_t qbatch_size =
|
||||
HWY_MIN(num_queries - qbatch_start, max_qbatch_size);
|
||||
const MultiplePromptsTokens qbatch_prompts(&prompts[qbatch_start],
|
||||
qbatch_size);
|
||||
const MultiplePositions qbatch_pos(&pos[qbatch_start], qbatch_size);
|
||||
const QueriesPromptTokens qbatch_prompts(&queries_prompt[qbatch_start],
|
||||
qbatch_size);
|
||||
QueriesPos qbatch_pos(&queries_pos[qbatch_start], qbatch_size);
|
||||
const KVCaches qbatch_kv(&kv_caches[qbatch_start], qbatch_size);
|
||||
GenerateT<TConfig>(weights_u8, activations, runtime_config, qbatch_prompts,
|
||||
qbatch_pos, qbatch_start, qbatch_kv, pools, timing_info);
|
||||
|
|
@ -1098,11 +1065,13 @@ void GenerateSingle( // NOLINT(misc-definitions-in-headers)
|
|||
|
||||
void GenerateBatch( // NOLINT(misc-definitions-in-headers)
|
||||
GEMMA_CONFIG, const ByteStorageT& weights_u8,
|
||||
const RuntimeConfig& runtime_config, const MultiplePromptsTokens& prompts,
|
||||
const MultiplePositions& pos, const KVCaches& kv_caches,
|
||||
PerClusterPools& pools, TimingInfo& timing_info) {
|
||||
const RuntimeConfig& runtime_config,
|
||||
const QueriesPromptTokens& queries_prompt, const QueriesPos& queries_pos,
|
||||
const KVCaches& kv_caches, PerClusterPools& pools,
|
||||
TimingInfo& timing_info) {
|
||||
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateBatchT<GEMMA_CONFIG>)
|
||||
(weights_u8, runtime_config, prompts, pos, kv_caches, pools, timing_info);
|
||||
(weights_u8, runtime_config, queries_prompt, queries_pos, kv_caches, pools,
|
||||
timing_info);
|
||||
}
|
||||
|
||||
#endif // HWY_ONCE
|
||||
|
|
|
|||
|
|
@ -64,12 +64,11 @@ Gemma::~Gemma() {
|
|||
const PromptTokens& prompt, size_t pos, \
|
||||
KVCache& kv_cache, PerClusterPools& pools, \
|
||||
TimingInfo& timing_info); \
|
||||
extern void GenerateBatch(CONFIGT<TWEIGHT>, const ByteStorageT& weights_u8, \
|
||||
const RuntimeConfig& runtime_config, \
|
||||
const MultiplePromptsTokens& prompts, \
|
||||
const MultiplePositions& pos, \
|
||||
const KVCaches& kv_caches, PerClusterPools& pools, \
|
||||
TimingInfo& timing_info);
|
||||
extern void GenerateBatch( \
|
||||
CONFIGT<TWEIGHT>, const ByteStorageT& weights_u8, \
|
||||
const RuntimeConfig& runtime_config, const QueriesPromptTokens& prompts, \
|
||||
const QueriesPos& queries_pos, const KVCaches& kv_caches, \
|
||||
PerClusterPools& pools, TimingInfo& timing_info);
|
||||
GEMMA_FOREACH_CONFIG_AND_WEIGHT(GEMMA_DECLARE);
|
||||
|
||||
// Adapters to select from the above overloads via CallForModelAndWeight.
|
||||
|
|
@ -88,35 +87,35 @@ template <class TConfig>
|
|||
struct GenerateBatchT {
|
||||
void operator()(const ByteStorageT& weights_u8,
|
||||
const RuntimeConfig& runtime_config,
|
||||
const MultiplePromptsTokens& prompts,
|
||||
const MultiplePositions& pos, const KVCaches& kv_caches,
|
||||
const QueriesPromptTokens& queries_prompt,
|
||||
const QueriesPos& queries_pos, const KVCaches& kv_caches,
|
||||
PerClusterPools& pools, TimingInfo& timing_info) const {
|
||||
GenerateBatch(TConfig(), weights_u8, runtime_config, prompts, pos,
|
||||
kv_caches, pools, timing_info);
|
||||
GenerateBatch(TConfig(), weights_u8, runtime_config, queries_prompt,
|
||||
queries_pos, kv_caches, pools, timing_info);
|
||||
}
|
||||
};
|
||||
|
||||
void Gemma::Generate(const RuntimeConfig& runtime_config,
|
||||
const PromptTokens& prompt, size_t start_pos,
|
||||
KVCache& kv_cache, TimingInfo& timing_info) {
|
||||
const PromptTokens& prompt, size_t pos, KVCache& kv_cache,
|
||||
TimingInfo& timing_info) {
|
||||
pools_.StartSpinning();
|
||||
|
||||
CallForModelAndWeight<GenerateSingleT>(info_.model, info_.weight, weights_u8_,
|
||||
runtime_config, prompt, start_pos,
|
||||
kv_cache, pools_, timing_info);
|
||||
runtime_config, prompt, pos, kv_cache,
|
||||
pools_, timing_info);
|
||||
|
||||
pools_.StopSpinning();
|
||||
}
|
||||
|
||||
void Gemma::GenerateBatch(const RuntimeConfig& runtime_config,
|
||||
const MultiplePromptsTokens& prompts,
|
||||
const MultiplePositions& start_pos,
|
||||
const QueriesPromptTokens& queries_prompt,
|
||||
const QueriesPos& queries_pos,
|
||||
const KVCaches& kv_caches, TimingInfo& timing_info) {
|
||||
pools_.StartSpinning();
|
||||
|
||||
CallForModelAndWeight<GenerateBatchT>(info_.model, info_.weight, weights_u8_,
|
||||
runtime_config, prompts, start_pos,
|
||||
kv_caches, pools_, timing_info);
|
||||
CallForModelAndWeight<GenerateBatchT>(
|
||||
info_.model, info_.weight, weights_u8_, runtime_config, queries_prompt,
|
||||
queries_pos, kv_caches, pools_, timing_info);
|
||||
|
||||
pools_.StopSpinning();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -142,8 +142,12 @@ struct TimingInfo {
|
|||
};
|
||||
|
||||
using PromptTokens = hwy::Span<const int>;
|
||||
using MultiplePromptsTokens = hwy::Span<const PromptTokens>;
|
||||
using MultiplePositions = hwy::Span<const size_t>;
|
||||
|
||||
// Batches of independent queries have their own prompt, previous token,
|
||||
// position in the sequence, and KVCache.
|
||||
using QueriesPromptTokens = hwy::Span<const PromptTokens>;
|
||||
using QueriesToken = hwy::Span<const int>;
|
||||
using QueriesPos = hwy::Span<const size_t>;
|
||||
using KVCaches = hwy::Span<KVCache>;
|
||||
|
||||
class Gemma {
|
||||
|
|
@ -161,13 +165,17 @@ class Gemma {
|
|||
const ByteStorageT& Weights() const { return weights_u8_; }
|
||||
ByteStorageT& MutableWeights() { return weights_u8_; }
|
||||
|
||||
// `pos` is the position in the KV cache. Users are responsible for
|
||||
// incrementing it in the `*StreamFunc`, or setting to zero for single-turn.
|
||||
void Generate(const RuntimeConfig& runtime_config, const PromptTokens& prompt,
|
||||
size_t start_pos, KVCache& kv_cache, TimingInfo& timing_info);
|
||||
size_t pos, KVCache& kv_cache, TimingInfo& timing_info);
|
||||
|
||||
// `queries_pos` are the positions in the KV cache. Users are responsible for
|
||||
// incrementing them in `BatchStreamFunc`, or setting to zero for single-turn.
|
||||
void GenerateBatch(const RuntimeConfig& runtime_config,
|
||||
const MultiplePromptsTokens& prompts,
|
||||
const MultiplePositions& start_pos,
|
||||
const KVCaches& kv_caches, TimingInfo& timing_info);
|
||||
const QueriesPromptTokens& queries_prompt,
|
||||
const QueriesPos& queries_pos, const KVCaches& kv_caches,
|
||||
TimingInfo& timing_info);
|
||||
|
||||
private:
|
||||
PerClusterPools& pools_;
|
||||
|
|
|
|||
17
gemma/run.cc
17
gemma/run.cc
|
|
@ -80,20 +80,19 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const InferenceArgs& args,
|
|||
int verbosity, const AcceptFunc& accept_token,
|
||||
std::string& eot_line) {
|
||||
PROFILER_ZONE("Gen.misc");
|
||||
size_t abs_pos = 0; // absolute token index over all turns
|
||||
int current_pos = 0; // token index within the current turn
|
||||
int prompt_size{};
|
||||
size_t abs_pos = 0; // across turns
|
||||
size_t tokens_generated_this_turn = 0; // differentiates prefill from reply
|
||||
size_t prompt_size = 0;
|
||||
|
||||
std::mt19937 gen;
|
||||
InitGenerator(args, gen);
|
||||
|
||||
// callback function invoked for each generated token.
|
||||
auto stream_token = [&abs_pos, ¤t_pos, &args, &gen, &prompt_size,
|
||||
&model, verbosity](int token, float) {
|
||||
auto stream_token = [&](int token, float) {
|
||||
++abs_pos;
|
||||
++current_pos;
|
||||
++tokens_generated_this_turn;
|
||||
// <= since position is incremented before
|
||||
if (current_pos <= prompt_size) {
|
||||
if (tokens_generated_this_turn <= prompt_size) {
|
||||
std::cerr << "." << std::flush;
|
||||
} else if (token == EOS_ID) {
|
||||
if (!args.multiturn) {
|
||||
|
|
@ -108,7 +107,7 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const InferenceArgs& args,
|
|||
HWY_ASSERT(
|
||||
model.Tokenizer().Decode(std::vector<int>{token}, &token_text));
|
||||
// +1 since position is incremented above
|
||||
if (current_pos == prompt_size + 1) {
|
||||
if (tokens_generated_this_turn == prompt_size + 1) {
|
||||
// first token of response
|
||||
token_text.erase(0, token_text.find_first_not_of(" \t\n"));
|
||||
if (verbosity >= 1) {
|
||||
|
|
@ -121,7 +120,7 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const InferenceArgs& args,
|
|||
};
|
||||
|
||||
while (abs_pos < args.max_tokens) {
|
||||
current_pos = 0;
|
||||
tokens_generated_this_turn = 0;
|
||||
std::string prompt_string = GetPrompt(std::cin, verbosity, eot_line);
|
||||
if (!std::cin) return;
|
||||
// If !eot_line.empty(), we append \n, so only look at the first 2 chars.
|
||||
|
|
|
|||
Loading…
Reference in New Issue