|
|
|
@ -233,7 +233,7 @@ class GemmaAttention {
|
|
|
|
// Fills activations.q and computes KV. For kIsMHA, a single MatMul suffices
|
|
|
|
// Fills activations.q and computes KV. For kIsMHA, a single MatMul suffices
|
|
|
|
// and we later copy KV from q to KVCache. Otherwise, a second MatMul writes
|
|
|
|
// and we later copy KV from q to KVCache. Otherwise, a second MatMul writes
|
|
|
|
// KV directly to KVCache.
|
|
|
|
// KV directly to KVCache.
|
|
|
|
HWY_NOINLINE void ComputeQKV(const size_t batch_start,
|
|
|
|
HWY_NOINLINE void ComputeQKV(const MultiplePositions& batch_start,
|
|
|
|
const size_t num_interleaved) {
|
|
|
|
const size_t num_interleaved) {
|
|
|
|
PROFILER_ZONE("Gen.Attention.QKV");
|
|
|
|
PROFILER_ZONE("Gen.Attention.QKV");
|
|
|
|
// For the computation of Q, K, and V, it is useful to remember that
|
|
|
|
// For the computation of Q, K, and V, it is useful to remember that
|
|
|
|
@ -255,9 +255,9 @@ class GemmaAttention {
|
|
|
|
// Single query and no wraparound means we can use a matmul and write
|
|
|
|
// Single query and no wraparound means we can use a matmul and write
|
|
|
|
// directly into the KV cache with a stride of kCachePosSize.
|
|
|
|
// directly into the KV cache with a stride of kCachePosSize.
|
|
|
|
if (num_queries_ == 1 &&
|
|
|
|
if (num_queries_ == 1 &&
|
|
|
|
batch_start + num_tokens_ <= div_seq_len_.GetDivisor()) {
|
|
|
|
batch_start[0] + num_tokens_ <= div_seq_len_.GetDivisor()) {
|
|
|
|
const size_t kv_ofs =
|
|
|
|
const size_t kv_ofs =
|
|
|
|
batch_start * kCachePosSize + layer_ * kCacheLayerSize;
|
|
|
|
batch_start[0] * kCachePosSize + layer_ * kCacheLayerSize;
|
|
|
|
// KV structure is [k, v, k, v, ....] = kKVHeads pairs of (k, v).
|
|
|
|
// KV structure is [k, v, k, v, ....] = kKVHeads pairs of (k, v).
|
|
|
|
float* HWY_RESTRICT kv = kv_caches_[0].kv_cache.get() + kv_ofs;
|
|
|
|
float* HWY_RESTRICT kv = kv_caches_[0].kv_cache.get() + kv_ofs;
|
|
|
|
MatMul_4x4</*kAdd=*/false>(
|
|
|
|
MatMul_4x4</*kAdd=*/false>(
|
|
|
|
@ -275,7 +275,7 @@ class GemmaAttention {
|
|
|
|
const size_t batch_idx = interleaved_idx / num_queries_;
|
|
|
|
const size_t batch_idx = interleaved_idx / num_queries_;
|
|
|
|
KVCache& kv_cache = kv_caches_[query_idx];
|
|
|
|
KVCache& kv_cache = kv_caches_[query_idx];
|
|
|
|
const size_t cache_pos =
|
|
|
|
const size_t cache_pos =
|
|
|
|
div_seq_len_.Remainder(batch_start + batch_idx);
|
|
|
|
div_seq_len_.Remainder(batch_start[query_idx] + batch_idx);
|
|
|
|
const size_t kv_offset =
|
|
|
|
const size_t kv_offset =
|
|
|
|
cache_pos * kCachePosSize + layer_ * kCacheLayerSize;
|
|
|
|
cache_pos * kCachePosSize + layer_ * kCacheLayerSize;
|
|
|
|
float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
|
|
|
|
float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
|
|
|
|
@ -295,7 +295,7 @@ class GemmaAttention {
|
|
|
|
const size_t interleaved_idx = task / kKVHeads;
|
|
|
|
const size_t interleaved_idx = task / kKVHeads;
|
|
|
|
const size_t query_idx = interleaved_idx % num_queries_;
|
|
|
|
const size_t query_idx = interleaved_idx % num_queries_;
|
|
|
|
const size_t batch_idx = interleaved_idx / num_queries_;
|
|
|
|
const size_t batch_idx = interleaved_idx / num_queries_;
|
|
|
|
const size_t pos = batch_start + batch_idx;
|
|
|
|
const size_t pos = batch_start[query_idx] + batch_idx;
|
|
|
|
const size_t cache_pos = div_seq_len_.Remainder(pos);
|
|
|
|
const size_t cache_pos = div_seq_len_.Remainder(pos);
|
|
|
|
const size_t kv_offset = cache_pos * kCachePosSize +
|
|
|
|
const size_t kv_offset = cache_pos * kCachePosSize +
|
|
|
|
layer_ * kCacheLayerSize +
|
|
|
|
layer_ * kCacheLayerSize +
|
|
|
|
@ -374,7 +374,7 @@ class GemmaAttention {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
HWY_NOINLINE void DotSoftmaxWeightedSum(const size_t batch_start,
|
|
|
|
HWY_NOINLINE void DotSoftmaxWeightedSum(const MultiplePositions& batch_start,
|
|
|
|
const size_t num_interleaved) {
|
|
|
|
const size_t num_interleaved) {
|
|
|
|
PROFILER_ZONE("Gen.Attention.DotSoftmax");
|
|
|
|
PROFILER_ZONE("Gen.Attention.DotSoftmax");
|
|
|
|
GEMMA_CONSTEXPR_SQRT float kQueryScale = ChooseQueryScale<TConfig>();
|
|
|
|
GEMMA_CONSTEXPR_SQRT float kQueryScale = ChooseQueryScale<TConfig>();
|
|
|
|
@ -398,7 +398,7 @@ class GemmaAttention {
|
|
|
|
activations_.q.Batch(interleaved_idx) + head * kQStride;
|
|
|
|
activations_.q.Batch(interleaved_idx) + head * kQStride;
|
|
|
|
|
|
|
|
|
|
|
|
// Apply rope and scaling to Q.
|
|
|
|
// Apply rope and scaling to Q.
|
|
|
|
const size_t pos = batch_start + batch_idx;
|
|
|
|
const size_t pos = batch_start[query_idx] + batch_idx;
|
|
|
|
PositionalEncodingQK(q, pos, layer_, kQueryScale, q);
|
|
|
|
PositionalEncodingQK(q, pos, layer_, kQueryScale, q);
|
|
|
|
|
|
|
|
|
|
|
|
const size_t start_pos = StartPos(pos, layer_);
|
|
|
|
const size_t start_pos = StartPos(pos, layer_);
|
|
|
|
@ -440,13 +440,12 @@ class GemmaAttention {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
GemmaAttention(size_t interleaved_start, size_t num_tokens,
|
|
|
|
GemmaAttention(const MultiplePositions& interleaved_start, size_t num_tokens,
|
|
|
|
size_t num_queries, size_t layer, Activations& activations,
|
|
|
|
size_t num_queries, size_t layer, Activations& activations,
|
|
|
|
const CompressedLayer<TConfig>* layer_weights,
|
|
|
|
const CompressedLayer<TConfig>* layer_weights,
|
|
|
|
const hwy::Divisor& div_seq_len, const KVCaches& kv_caches,
|
|
|
|
const hwy::Divisor& div_seq_len, const KVCaches& kv_caches,
|
|
|
|
hwy::ThreadPool& pool)
|
|
|
|
hwy::ThreadPool& pool)
|
|
|
|
: interleaved_start_(interleaved_start),
|
|
|
|
: num_tokens_(num_tokens),
|
|
|
|
num_tokens_(num_tokens),
|
|
|
|
|
|
|
|
num_queries_(num_queries),
|
|
|
|
num_queries_(num_queries),
|
|
|
|
layer_(layer),
|
|
|
|
layer_(layer),
|
|
|
|
activations_(activations),
|
|
|
|
activations_(activations),
|
|
|
|
@ -454,12 +453,21 @@ class GemmaAttention {
|
|
|
|
div_seq_len_(div_seq_len),
|
|
|
|
div_seq_len_(div_seq_len),
|
|
|
|
kv_caches_(kv_caches),
|
|
|
|
kv_caches_(kv_caches),
|
|
|
|
pool_(pool) {
|
|
|
|
pool_(pool) {
|
|
|
|
HWY_DASSERT(interleaved_start_ % num_queries_ == 0);
|
|
|
|
HWY_DASSERT(
|
|
|
|
|
|
|
|
std::all_of(interleaved_start.cbegin(), interleaved_start.cend(),
|
|
|
|
|
|
|
|
[this](size_t pos) { return pos % num_queries_ == 0; }));
|
|
|
|
HWY_DASSERT(num_queries_ <= kv_caches_.size());
|
|
|
|
HWY_DASSERT(num_queries_ <= kv_caches_.size());
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
batch_start_.reserve(interleaved_start.size());
|
|
|
|
|
|
|
|
for (auto i = interleaved_start.cbegin(); i != interleaved_start.cend();
|
|
|
|
|
|
|
|
++i) {
|
|
|
|
|
|
|
|
batch_start_.emplace_back(*i / num_queries_);
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
HWY_INLINE void operator()() {
|
|
|
|
HWY_INLINE void operator()() {
|
|
|
|
const size_t batch_start = interleaved_start_ / num_queries_;
|
|
|
|
const MultiplePositions batch_start(batch_start_.data(),
|
|
|
|
|
|
|
|
batch_start_.size());
|
|
|
|
const size_t num_interleaved = num_tokens_ * num_queries_;
|
|
|
|
const size_t num_interleaved = num_tokens_ * num_queries_;
|
|
|
|
|
|
|
|
|
|
|
|
ComputeQKV(batch_start, num_interleaved);
|
|
|
|
ComputeQKV(batch_start, num_interleaved);
|
|
|
|
@ -468,7 +476,7 @@ class GemmaAttention {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
private:
|
|
|
|
const size_t interleaved_start_;
|
|
|
|
std::vector<size_t> batch_start_;
|
|
|
|
const size_t num_tokens_;
|
|
|
|
const size_t num_tokens_;
|
|
|
|
const size_t num_queries_;
|
|
|
|
const size_t num_queries_;
|
|
|
|
const size_t layer_;
|
|
|
|
const size_t layer_;
|
|
|
|
@ -480,7 +488,8 @@ class GemmaAttention {
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
template <class TConfig>
|
|
|
|
template <class TConfig>
|
|
|
|
HWY_NOINLINE void Attention(LayerAttentionType type, size_t interleaved_start,
|
|
|
|
HWY_NOINLINE void Attention(LayerAttentionType type,
|
|
|
|
|
|
|
|
const MultiplePositions& interleaved_start,
|
|
|
|
size_t num_tokens, size_t num_queries, size_t layer,
|
|
|
|
size_t num_tokens, size_t num_queries, size_t layer,
|
|
|
|
Activations& activations,
|
|
|
|
Activations& activations,
|
|
|
|
const CompressedLayer<TConfig>* layer_weights,
|
|
|
|
const CompressedLayer<TConfig>* layer_weights,
|
|
|
|
@ -495,7 +504,7 @@ HWY_NOINLINE void Attention(LayerAttentionType type, size_t interleaved_start,
|
|
|
|
// this code for non-Griffin models.
|
|
|
|
// this code for non-Griffin models.
|
|
|
|
if constexpr (TConfig::kGriffinLayers > 0) {
|
|
|
|
if constexpr (TConfig::kGriffinLayers > 0) {
|
|
|
|
HWY_ASSERT(num_queries == 1);
|
|
|
|
HWY_ASSERT(num_queries == 1);
|
|
|
|
GriffinRecurrent<TConfig>(interleaved_start, num_tokens, num_queries,
|
|
|
|
GriffinRecurrent<TConfig>(interleaved_start[0], num_tokens, num_queries,
|
|
|
|
layer, activations, layer_weights, kv_caches,
|
|
|
|
layer, activations, layer_weights, kv_caches,
|
|
|
|
pool);
|
|
|
|
pool);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
@ -599,10 +608,10 @@ void PostNorm(size_t num_interleaved, const WeightT& weights, InOutT* inout) {
|
|
|
|
|
|
|
|
|
|
|
|
template <class TConfig>
|
|
|
|
template <class TConfig>
|
|
|
|
HWY_NOINLINE void TransformerLayer(
|
|
|
|
HWY_NOINLINE void TransformerLayer(
|
|
|
|
size_t num_tokens, size_t num_queries, size_t pos, size_t layer,
|
|
|
|
size_t num_tokens, size_t num_queries, const MultiplePositions& pos,
|
|
|
|
const CompressedLayer<TConfig>* layer_weights, Activations& activations,
|
|
|
|
size_t layer, const CompressedLayer<TConfig>* layer_weights,
|
|
|
|
const hwy::Divisor& div_seq_len, const KVCaches& kv_caches,
|
|
|
|
Activations& activations, const hwy::Divisor& div_seq_len,
|
|
|
|
hwy::ThreadPool& pool) {
|
|
|
|
const KVCaches& kv_caches, hwy::ThreadPool& pool) {
|
|
|
|
constexpr size_t kModelDim = TConfig::kModelDim;
|
|
|
|
constexpr size_t kModelDim = TConfig::kModelDim;
|
|
|
|
const size_t num_interleaved = num_tokens * num_queries;
|
|
|
|
const size_t num_interleaved = num_tokens * num_queries;
|
|
|
|
auto type = TConfig::kLayerConfig[layer];
|
|
|
|
auto type = TConfig::kLayerConfig[layer];
|
|
|
|
@ -688,7 +697,8 @@ class PrefillState {
|
|
|
|
|
|
|
|
|
|
|
|
template <class TConfig>
|
|
|
|
template <class TConfig>
|
|
|
|
HWY_NOINLINE void Prefill(const MultiplePromptsTokens& prompts,
|
|
|
|
HWY_NOINLINE void Prefill(const MultiplePromptsTokens& prompts,
|
|
|
|
const size_t prefill_per_query, const size_t pos,
|
|
|
|
const size_t prefill_per_query,
|
|
|
|
|
|
|
|
const MultiplePositions& pos,
|
|
|
|
const size_t query_idx_start,
|
|
|
|
const size_t query_idx_start,
|
|
|
|
const CompressedWeights<TConfig>& weights,
|
|
|
|
const CompressedWeights<TConfig>& weights,
|
|
|
|
const RuntimeConfig& runtime_config,
|
|
|
|
const RuntimeConfig& runtime_config,
|
|
|
|
@ -719,14 +729,19 @@ class PrefillState {
|
|
|
|
HWY_MIN(max_tbatch_size, prefill_per_query - tbatch_start);
|
|
|
|
HWY_MIN(max_tbatch_size, prefill_per_query - tbatch_start);
|
|
|
|
for (size_t ti = 0; ti < tbatch_size; ++ti) {
|
|
|
|
for (size_t ti = 0; ti < tbatch_size; ++ti) {
|
|
|
|
const int token = prompts[qi][tbatch_start + ti];
|
|
|
|
const int token = prompts[qi][tbatch_start + ti];
|
|
|
|
EmbedToken<TConfig>(token, ti, pos + ti, weights, activations.x);
|
|
|
|
EmbedToken<TConfig>(token, ti, pos[qi] + ti, weights,
|
|
|
|
|
|
|
|
activations.x);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
const size_t tbatch_pos = pos[qi] + tbatch_start;
|
|
|
|
|
|
|
|
const MultiplePositions prefill_tbatch_pos(&tbatch_pos,
|
|
|
|
|
|
|
|
kPrefillQueries);
|
|
|
|
|
|
|
|
|
|
|
|
// Transformer with one batch of tokens from a single query.
|
|
|
|
// Transformer with one batch of tokens from a single query.
|
|
|
|
for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
|
|
|
|
for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
|
|
|
|
const auto* layer_weights = weights.GetLayer(layer);
|
|
|
|
const auto* layer_weights = weights.GetLayer(layer);
|
|
|
|
TransformerLayer<TConfig>(tbatch_size, kPrefillQueries,
|
|
|
|
TransformerLayer<TConfig>(tbatch_size, kPrefillQueries,
|
|
|
|
pos + tbatch_start, layer,
|
|
|
|
prefill_tbatch_pos, layer,
|
|
|
|
layer_weights, activations, div_seq_len,
|
|
|
|
layer_weights, activations, div_seq_len,
|
|
|
|
prefill_kv_caches, inner_pool);
|
|
|
|
prefill_kv_caches, inner_pool);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
@ -734,8 +749,8 @@ class PrefillState {
|
|
|
|
// NOTE: we unconditionally call StreamToken, even if EOS.
|
|
|
|
// NOTE: we unconditionally call StreamToken, even if EOS.
|
|
|
|
for (size_t ti = 0; ti < tbatch_size; ++ti) {
|
|
|
|
for (size_t ti = 0; ti < tbatch_size; ++ti) {
|
|
|
|
const int token = prompts[qi][tbatch_start + ti];
|
|
|
|
const int token = prompts[qi][tbatch_start + ti];
|
|
|
|
runtime_config.StreamToken(query_idx_start + qi,
|
|
|
|
runtime_config.StreamToken(query_idx_start + qi, tbatch_pos + ti,
|
|
|
|
pos + tbatch_start + ti, token, 0.0f);
|
|
|
|
token, 0.0f);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
} // for tbatch_start
|
|
|
|
} // for tbatch_start
|
|
|
|
});
|
|
|
|
});
|
|
|
|
@ -749,7 +764,7 @@ class PrefillState {
|
|
|
|
// `num_tokens == 1`.
|
|
|
|
// `num_tokens == 1`.
|
|
|
|
template <class TConfig>
|
|
|
|
template <class TConfig>
|
|
|
|
HWY_NOINLINE void Transformer(const int* tokens, size_t num_tokens,
|
|
|
|
HWY_NOINLINE void Transformer(const int* tokens, size_t num_tokens,
|
|
|
|
size_t num_queries, size_t pos,
|
|
|
|
size_t num_queries, const MultiplePositions& pos,
|
|
|
|
const CompressedWeights<TConfig>& weights,
|
|
|
|
const CompressedWeights<TConfig>& weights,
|
|
|
|
Activations& activations,
|
|
|
|
Activations& activations,
|
|
|
|
const hwy::Divisor& div_seq_len,
|
|
|
|
const hwy::Divisor& div_seq_len,
|
|
|
|
@ -759,14 +774,15 @@ HWY_NOINLINE void Transformer(const int* tokens, size_t num_tokens,
|
|
|
|
if (layers_output) {
|
|
|
|
if (layers_output) {
|
|
|
|
for (size_t token_idx = 0; token_idx < num_interleaved; ++token_idx) {
|
|
|
|
for (size_t token_idx = 0; token_idx < num_interleaved; ++token_idx) {
|
|
|
|
const size_t query_idx = token_idx % num_queries;
|
|
|
|
const size_t query_idx = token_idx % num_queries;
|
|
|
|
const size_t logical_pos = (pos + token_idx) / num_queries;
|
|
|
|
const size_t logical_pos = (pos[query_idx] + token_idx) / num_queries;
|
|
|
|
const float token_f = tokens[token_idx];
|
|
|
|
const float token_f = tokens[token_idx];
|
|
|
|
layers_output(query_idx, logical_pos, "tokens", -1, &token_f, 1);
|
|
|
|
layers_output(query_idx, logical_pos, "tokens", -1, &token_f, 1);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
constexpr size_t kModelDim = TConfig::kModelDim;
|
|
|
|
constexpr size_t kModelDim = TConfig::kModelDim;
|
|
|
|
for (size_t token_idx = 0; token_idx < num_interleaved; ++token_idx) {
|
|
|
|
for (size_t token_idx = 0; token_idx < num_interleaved; ++token_idx) {
|
|
|
|
EmbedToken<TConfig>(tokens[token_idx], token_idx, pos, weights,
|
|
|
|
const size_t query_idx = token_idx % num_queries;
|
|
|
|
|
|
|
|
EmbedToken<TConfig>(tokens[token_idx], token_idx, pos[query_idx], weights,
|
|
|
|
activations.x);
|
|
|
|
activations.x);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
@ -778,7 +794,8 @@ HWY_NOINLINE void Transformer(const int* tokens, size_t num_tokens,
|
|
|
|
|
|
|
|
|
|
|
|
if (layers_output) {
|
|
|
|
if (layers_output) {
|
|
|
|
for (size_t token_idx = 0; token_idx < num_interleaved; ++token_idx) {
|
|
|
|
for (size_t token_idx = 0; token_idx < num_interleaved; ++token_idx) {
|
|
|
|
const size_t logical_pos = (pos + token_idx) / num_queries;
|
|
|
|
const size_t query_idx = token_idx % num_queries;
|
|
|
|
|
|
|
|
const size_t logical_pos = (pos[query_idx] + token_idx) / num_queries;
|
|
|
|
layers_output(token_idx % num_queries, logical_pos, "blocks", layer,
|
|
|
|
layers_output(token_idx % num_queries, logical_pos, "blocks", layer,
|
|
|
|
activations.x.Batch(token_idx), kModelDim);
|
|
|
|
activations.x.Batch(token_idx), kModelDim);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
@ -790,7 +807,7 @@ HWY_NOINLINE void Transformer(const int* tokens, size_t num_tokens,
|
|
|
|
if (layers_output) {
|
|
|
|
if (layers_output) {
|
|
|
|
for (size_t token_idx = 0; token_idx < num_interleaved; ++token_idx) {
|
|
|
|
for (size_t token_idx = 0; token_idx < num_interleaved; ++token_idx) {
|
|
|
|
const size_t query_idx = token_idx % num_queries;
|
|
|
|
const size_t query_idx = token_idx % num_queries;
|
|
|
|
const size_t logical_pos = (pos + token_idx) / num_queries;
|
|
|
|
const size_t logical_pos = (pos[query_idx] + token_idx) / num_queries;
|
|
|
|
layers_output(query_idx, logical_pos, "final_norm", -1,
|
|
|
|
layers_output(query_idx, logical_pos, "final_norm", -1,
|
|
|
|
activations.x.Batch(token_idx), kModelDim);
|
|
|
|
activations.x.Batch(token_idx), kModelDim);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
@ -897,9 +914,10 @@ class TokenStreamer {
|
|
|
|
template <class TConfig>
|
|
|
|
template <class TConfig>
|
|
|
|
void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
|
|
|
|
void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
|
|
|
|
const RuntimeConfig& runtime_config,
|
|
|
|
const RuntimeConfig& runtime_config,
|
|
|
|
const MultiplePromptsTokens& prompts, const size_t pos,
|
|
|
|
const MultiplePromptsTokens& prompts,
|
|
|
|
const size_t query_idx_start, const KVCaches& kv_caches,
|
|
|
|
const MultiplePositions& pos, const size_t query_idx_start,
|
|
|
|
PerClusterPools& pools, TimingInfo& timing_info) {
|
|
|
|
const KVCaches& kv_caches, PerClusterPools& pools,
|
|
|
|
|
|
|
|
TimingInfo& timing_info) {
|
|
|
|
constexpr size_t kModelDim = TConfig::kModelDim;
|
|
|
|
constexpr size_t kModelDim = TConfig::kModelDim;
|
|
|
|
constexpr size_t kVocabSize = TConfig::kVocabSize;
|
|
|
|
constexpr size_t kVocabSize = TConfig::kVocabSize;
|
|
|
|
const CompressedWeights<TConfig>& weights =
|
|
|
|
const CompressedWeights<TConfig>& weights =
|
|
|
|
@ -921,10 +939,12 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
|
|
|
|
size_t max_tokens = runtime_config.max_tokens;
|
|
|
|
size_t max_tokens = runtime_config.max_tokens;
|
|
|
|
size_t max_generated_tokens = runtime_config.max_generated_tokens;
|
|
|
|
size_t max_generated_tokens = runtime_config.max_generated_tokens;
|
|
|
|
RangeChecks<TConfig>(max_tokens, max_generated_tokens, max_prompt_size);
|
|
|
|
RangeChecks<TConfig>(max_tokens, max_generated_tokens, max_prompt_size);
|
|
|
|
if (pos >= max_tokens) {
|
|
|
|
for (auto i = pos.cbegin(); i != pos.cend(); ++i) {
|
|
|
|
fprintf(stderr, "Warning: pos %zu >= max_tokens %zu, aborting.\n", pos,
|
|
|
|
if (*i >= max_tokens) {
|
|
|
|
max_tokens);
|
|
|
|
fprintf(stderr, "Warning: pos %zu >= max_tokens %zu, aborting.\n", *i,
|
|
|
|
return;
|
|
|
|
max_tokens);
|
|
|
|
|
|
|
|
return;
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// If no sample_func is provided, we use top-k sampling.
|
|
|
|
// If no sample_func is provided, we use top-k sampling.
|
|
|
|
@ -953,7 +973,10 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
|
|
|
|
timing_info.NotifyPrefill(prefill_per_query * num_queries, prefill_start);
|
|
|
|
timing_info.NotifyPrefill(prefill_per_query * num_queries, prefill_start);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
size_t interleaved_pos = (pos + prefill_per_query) * num_queries;
|
|
|
|
std::vector<size_t> interleaved_pos(pos.size());
|
|
|
|
|
|
|
|
std::transform(
|
|
|
|
|
|
|
|
pos.cbegin(), pos.cend(), interleaved_pos.begin(),
|
|
|
|
|
|
|
|
[&](size_t v) { return (v + prefill_per_query) * num_queries; });
|
|
|
|
|
|
|
|
|
|
|
|
// Storage for the last generated token from each query, passed to the next
|
|
|
|
// Storage for the last generated token from each query, passed to the next
|
|
|
|
// Transformer() call.
|
|
|
|
// Transformer() call.
|
|
|
|
@ -963,7 +986,8 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
|
|
|
|
TokenStreamer token_streamer(runtime_config);
|
|
|
|
TokenStreamer token_streamer(runtime_config);
|
|
|
|
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
|
|
|
|
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
|
|
|
|
gen_tokens[query_idx] = prompts[query_idx][prefill_per_query];
|
|
|
|
gen_tokens[query_idx] = prompts[query_idx][prefill_per_query];
|
|
|
|
(void)token_streamer(query_idx_start + query_idx, prefill_per_query,
|
|
|
|
(void)token_streamer(query_idx_start + query_idx,
|
|
|
|
|
|
|
|
pos[query_idx] + prefill_per_query,
|
|
|
|
gen_tokens[query_idx], 0.0f);
|
|
|
|
gen_tokens[query_idx], 0.0f);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
@ -972,10 +996,14 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
|
|
|
|
gen_per_query < HWY_MIN(max_tokens, max_generated_tokens);
|
|
|
|
gen_per_query < HWY_MIN(max_tokens, max_generated_tokens);
|
|
|
|
++gen_per_query) {
|
|
|
|
++gen_per_query) {
|
|
|
|
// Decode: generate one token for each query.
|
|
|
|
// Decode: generate one token for each query.
|
|
|
|
Transformer<TConfig>(gen_tokens.data(), /*num_tokens=*/1, num_queries,
|
|
|
|
Transformer<TConfig>(
|
|
|
|
interleaved_pos, weights, activations, div_seq_len,
|
|
|
|
gen_tokens.data(), /*num_tokens=*/1, num_queries,
|
|
|
|
kv_caches, pool, runtime_config.layers_output);
|
|
|
|
MultiplePositions(interleaved_pos.data(), interleaved_pos.size()),
|
|
|
|
interleaved_pos += num_queries;
|
|
|
|
weights, activations, div_seq_len, kv_caches, pool,
|
|
|
|
|
|
|
|
runtime_config.layers_output);
|
|
|
|
|
|
|
|
for (auto& v : interleaved_pos) {
|
|
|
|
|
|
|
|
v += num_queries;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
bool all_queries_eos = true;
|
|
|
|
bool all_queries_eos = true;
|
|
|
|
PROFILER_ZONE("Gen.Embedding");
|
|
|
|
PROFILER_ZONE("Gen.Embedding");
|
|
|
|
@ -992,9 +1020,10 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
|
|
|
|
const int token = sample_token(logits, kVocabSize);
|
|
|
|
const int token = sample_token(logits, kVocabSize);
|
|
|
|
timing_info.NotifyGenerated(prefill_start, gen_start);
|
|
|
|
timing_info.NotifyGenerated(prefill_start, gen_start);
|
|
|
|
|
|
|
|
|
|
|
|
const bool is_eos = token_streamer(query_idx_start + query_idx,
|
|
|
|
const bool is_eos =
|
|
|
|
prefill_per_query + 1 + gen_per_query,
|
|
|
|
token_streamer(query_idx_start + query_idx,
|
|
|
|
token, logits[token]);
|
|
|
|
pos[query_idx] + prefill_per_query + 1 + gen_per_query,
|
|
|
|
|
|
|
|
token, logits[token]);
|
|
|
|
all_queries_eos &= is_eos;
|
|
|
|
all_queries_eos &= is_eos;
|
|
|
|
gen_tokens[query_idx] = is_eos ? runtime_config.eos_id : token;
|
|
|
|
gen_tokens[query_idx] = is_eos ? runtime_config.eos_id : token;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
@ -1016,19 +1045,21 @@ void GenerateSingleT(const ByteStorageT& weights_u8,
|
|
|
|
activations.Allocate<TConfig>(num_queries);
|
|
|
|
activations.Allocate<TConfig>(num_queries);
|
|
|
|
|
|
|
|
|
|
|
|
const MultiplePromptsTokens prompts(&prompt, num_queries);
|
|
|
|
const MultiplePromptsTokens prompts(&prompt, num_queries);
|
|
|
|
|
|
|
|
const MultiplePositions positions(&pos, num_queries);
|
|
|
|
const KVCaches kv_caches{&kv_cache, num_queries};
|
|
|
|
const KVCaches kv_caches{&kv_cache, num_queries};
|
|
|
|
|
|
|
|
|
|
|
|
GenerateT<TConfig>(weights_u8, activations, runtime_config, prompts, pos,
|
|
|
|
GenerateT<TConfig>(weights_u8, activations, runtime_config, prompts,
|
|
|
|
qbatch_start, kv_caches, pools, timing_info);
|
|
|
|
positions, qbatch_start, kv_caches, pools, timing_info);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
template <class TConfig>
|
|
|
|
template <class TConfig>
|
|
|
|
void GenerateBatchT(const ByteStorageT& weights_u8,
|
|
|
|
void GenerateBatchT(const ByteStorageT& weights_u8,
|
|
|
|
const RuntimeConfig& runtime_config,
|
|
|
|
const RuntimeConfig& runtime_config,
|
|
|
|
const MultiplePromptsTokens& prompts, size_t pos,
|
|
|
|
const MultiplePromptsTokens& prompts,
|
|
|
|
const KVCaches& kv_caches, PerClusterPools& pools,
|
|
|
|
const MultiplePositions& pos, const KVCaches& kv_caches,
|
|
|
|
TimingInfo& timing_info) {
|
|
|
|
PerClusterPools& pools, TimingInfo& timing_info) {
|
|
|
|
HWY_ASSERT(prompts.size() == kv_caches.size());
|
|
|
|
HWY_ASSERT(prompts.size() == pos.size() &&
|
|
|
|
|
|
|
|
prompts.size() == kv_caches.size());
|
|
|
|
// Griffin does not support query batching.
|
|
|
|
// Griffin does not support query batching.
|
|
|
|
const size_t max_qbatch_size =
|
|
|
|
const size_t max_qbatch_size =
|
|
|
|
(TConfig::kGriffinLayers > 0) ? 1 : runtime_config.decode_qbatch_size;
|
|
|
|
(TConfig::kGriffinLayers > 0) ? 1 : runtime_config.decode_qbatch_size;
|
|
|
|
@ -1044,9 +1075,10 @@ void GenerateBatchT(const ByteStorageT& weights_u8,
|
|
|
|
HWY_MIN(num_queries - qbatch_start, max_qbatch_size);
|
|
|
|
HWY_MIN(num_queries - qbatch_start, max_qbatch_size);
|
|
|
|
const MultiplePromptsTokens qbatch_prompts(&prompts[qbatch_start],
|
|
|
|
const MultiplePromptsTokens qbatch_prompts(&prompts[qbatch_start],
|
|
|
|
qbatch_size);
|
|
|
|
qbatch_size);
|
|
|
|
|
|
|
|
const MultiplePositions qbatch_pos(&pos[qbatch_start], qbatch_size);
|
|
|
|
const KVCaches qbatch_kv(&kv_caches[qbatch_start], qbatch_size);
|
|
|
|
const KVCaches qbatch_kv(&kv_caches[qbatch_start], qbatch_size);
|
|
|
|
GenerateT<TConfig>(weights_u8, activations, runtime_config, qbatch_prompts,
|
|
|
|
GenerateT<TConfig>(weights_u8, activations, runtime_config, qbatch_prompts,
|
|
|
|
pos, qbatch_start, qbatch_kv, pools, timing_info);
|
|
|
|
qbatch_pos, qbatch_start, qbatch_kv, pools, timing_info);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
@ -1067,8 +1099,8 @@ void GenerateSingle( // NOLINT(misc-definitions-in-headers)
|
|
|
|
void GenerateBatch( // NOLINT(misc-definitions-in-headers)
|
|
|
|
void GenerateBatch( // NOLINT(misc-definitions-in-headers)
|
|
|
|
GEMMA_CONFIG, const ByteStorageT& weights_u8,
|
|
|
|
GEMMA_CONFIG, const ByteStorageT& weights_u8,
|
|
|
|
const RuntimeConfig& runtime_config, const MultiplePromptsTokens& prompts,
|
|
|
|
const RuntimeConfig& runtime_config, const MultiplePromptsTokens& prompts,
|
|
|
|
size_t pos, const KVCaches& kv_caches, PerClusterPools& pools,
|
|
|
|
const MultiplePositions& pos, const KVCaches& kv_caches,
|
|
|
|
TimingInfo& timing_info) {
|
|
|
|
PerClusterPools& pools, TimingInfo& timing_info) {
|
|
|
|
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateBatchT<GEMMA_CONFIG>)
|
|
|
|
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateBatchT<GEMMA_CONFIG>)
|
|
|
|
(weights_u8, runtime_config, prompts, pos, kv_caches, pools, timing_info);
|
|
|
|
(weights_u8, runtime_config, prompts, pos, kv_caches, pools, timing_info);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|