mirror of https://github.com/google/gemma.cpp.git
Major Prefill/Generate cleanup, 1.3x Prefill speedup
This fixes TTFT (time to first token), which previously did not include prefill time. PiperOrigin-RevId: 653690626
This commit is contained in:
parent
3fe79b3876
commit
12016d31c3
|
|
@ -149,13 +149,14 @@ cc_library(
|
||||||
":tokenizer",
|
":tokenizer",
|
||||||
":kv_cache",
|
":kv_cache",
|
||||||
":weights",
|
":weights",
|
||||||
"//compression:compress",
|
|
||||||
"//compression:io",
|
"//compression:io",
|
||||||
"@hwy//:hwy",
|
"@hwy//:hwy",
|
||||||
|
"@hwy//:bit_set",
|
||||||
"@hwy//:matvec",
|
"@hwy//:matvec",
|
||||||
"@hwy//:nanobenchmark", # timer
|
"@hwy//:nanobenchmark", # timer
|
||||||
"@hwy//:profiler",
|
"@hwy//:profiler",
|
||||||
"@hwy//:thread_pool",
|
"@hwy//:thread_pool",
|
||||||
|
"@hwy//:topology",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -76,8 +76,8 @@ GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
|
||||||
fprintf(stderr, "Loading model...\n");
|
fprintf(stderr, "Loading model...\n");
|
||||||
model_ = AllocateGemma(loader_, pool_);
|
model_ = AllocateGemma(loader_, pool_);
|
||||||
|
|
||||||
kv_caches_.reserve(16);
|
kv_caches_.reserve(kBatchedQueryBatchSize);
|
||||||
for (int i = 0; i < 16; ++i) {
|
for (int i = 0; i < kBatchedQueryBatchSize; ++i) {
|
||||||
kv_caches_.push_back(new KVCache(KVCache::Create(model_->Info().model)));
|
kv_caches_.push_back(new KVCache(KVCache::Create(model_->Info().model)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -36,11 +36,10 @@ ByteStorageT AllocateSizeof() {
|
||||||
return hwy::AllocateAligned<uint8_t>(sizeof(T));
|
return hwy::AllocateAligned<uint8_t>(sizeof(T));
|
||||||
}
|
}
|
||||||
|
|
||||||
constexpr size_t kPrefillBatchSize = 512;
|
// Relatively small so that we can also parallelize non-Matmul work. There is
|
||||||
constexpr size_t kDecodeBatchSize = 1;
|
// one outer thread per batch, each with --num_threads / batches inner threads.
|
||||||
|
constexpr size_t kPrefillBatchSize = 64;
|
||||||
constexpr size_t kBatchedQueryBatchSize = 16;
|
constexpr size_t kBatchedQueryBatchSize = 16;
|
||||||
constexpr size_t kMinAdjustedPrefillBatchSize =
|
|
||||||
HWY_MAX((size_t)1, kPrefillBatchSize / kBatchedQueryBatchSize);
|
|
||||||
|
|
||||||
// Model variants: see configs.h for details. When adding a new one, also
|
// Model variants: see configs.h for details. When adding a new one, also
|
||||||
// update GEMMA_FOREACH* and Call* below, and add instantiations/*.cc.
|
// update GEMMA_FOREACH* and Call* below, and add instantiations/*.cc.
|
||||||
|
|
|
||||||
|
|
@ -26,23 +26,26 @@
|
||||||
|
|
||||||
#include <stddef.h>
|
#include <stddef.h>
|
||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
#include <string.h> // memcpy
|
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm> // std::min
|
||||||
|
#include <memory> // std::unique_ptr
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <type_traits>
|
#include <type_traits>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "gemma/activations.h"
|
#include "gemma/activations.h"
|
||||||
#include "gemma/common.h"
|
#include "gemma/common.h"
|
||||||
|
#include "gemma/configs.h"
|
||||||
#include "gemma/gemma.h"
|
#include "gemma/gemma.h"
|
||||||
#include "gemma/ops.h"
|
#include "gemma/ops.h"
|
||||||
#include "gemma/weights.h"
|
#include "gemma/weights.h"
|
||||||
// Placeholder for internal test4, do not remove
|
// Placeholder for internal test4, do not remove
|
||||||
#include "hwy/aligned_allocator.h"
|
#include "hwy/aligned_allocator.h"
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
|
#include "hwy/bit_set.h"
|
||||||
#include "hwy/contrib/matvec/matvec-inl.h"
|
#include "hwy/contrib/matvec/matvec-inl.h"
|
||||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
|
#include "hwy/contrib/thread_pool/topology.h"
|
||||||
#include "hwy/highway.h"
|
#include "hwy/highway.h"
|
||||||
#include "hwy/profiler.h"
|
#include "hwy/profiler.h"
|
||||||
#include "hwy/timer.h"
|
#include "hwy/timer.h"
|
||||||
|
|
@ -269,7 +272,7 @@ HWY_NOINLINE void Attention(size_t batch_and_query_start, size_t num_tokens,
|
||||||
const float* HWY_RESTRICT q =
|
const float* HWY_RESTRICT q =
|
||||||
activations.q.Batch(batch_and_query_idx) + head * kQStride;
|
activations.q.Batch(batch_and_query_idx) + head * kQStride;
|
||||||
// Skip past the Q part of `q`, and copy KV to `kv`.
|
// Skip past the Q part of `q`, and copy KV to `kv`.
|
||||||
memcpy(kv, q + kQKVDim, 2 * kQKVDim * sizeof(float));
|
hwy::CopyBytes(q + kQKVDim, kv, 2 * kQKVDim * sizeof(float));
|
||||||
}
|
}
|
||||||
PostQK<TConfig>(kv, pos, layer);
|
PostQK<TConfig>(kv, pos, layer);
|
||||||
});
|
});
|
||||||
|
|
@ -414,6 +417,7 @@ HWY_NOINLINE void FFW(Activations& activations, size_t num_tokens,
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: pass Activations.x instead of Activations.
|
// TODO: pass Activations.x instead of Activations.
|
||||||
|
// `pos` is for the entire batch and does not include `batch_idx`.
|
||||||
template <class TConfig>
|
template <class TConfig>
|
||||||
HWY_NOINLINE void EmbedToken(int token, size_t batch_idx, size_t pos,
|
HWY_NOINLINE void EmbedToken(int token, size_t batch_idx, size_t pos,
|
||||||
const CompressedWeights<TConfig>& weights,
|
const CompressedWeights<TConfig>& weights,
|
||||||
|
|
@ -421,8 +425,10 @@ HWY_NOINLINE void EmbedToken(int token, size_t batch_idx, size_t pos,
|
||||||
constexpr size_t kModelDim = TConfig::kModelDim;
|
constexpr size_t kModelDim = TConfig::kModelDim;
|
||||||
GEMMA_CONSTEXPR_EMBSCALING const float kEmbScaling =
|
GEMMA_CONSTEXPR_EMBSCALING const float kEmbScaling =
|
||||||
EmbeddingScaling<TConfig>();
|
EmbeddingScaling<TConfig>();
|
||||||
|
|
||||||
HWY_DASSERT(token >= 0);
|
HWY_DASSERT(token >= 0);
|
||||||
HWY_DASSERT(token < TConfig::kVocabSize);
|
HWY_DASSERT(token < TConfig::kVocabSize);
|
||||||
|
|
||||||
Decompress(weights.embedder_input_embedding, token * kModelDim,
|
Decompress(weights.embedder_input_embedding, token * kModelDim,
|
||||||
activations.x.Batch(batch_idx), kModelDim);
|
activations.x.Batch(batch_idx), kModelDim);
|
||||||
MulByConst(kEmbScaling, activations.x.Batch(batch_idx), kModelDim);
|
MulByConst(kEmbScaling, activations.x.Batch(batch_idx), kModelDim);
|
||||||
|
|
@ -441,7 +447,7 @@ HWY_NOINLINE void ResidualConnection(
|
||||||
AddFromBatched(num_tokens_and_queries, other, x, kModelDim);
|
AddFromBatched(num_tokens_and_queries, other, x, kModelDim);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <class TConfig, size_t kQueryBatchSize>
|
template <class TConfig>
|
||||||
HWY_NOINLINE void TransformerLayer(
|
HWY_NOINLINE void TransformerLayer(
|
||||||
size_t num_tokens, size_t num_queries, size_t pos, size_t layer,
|
size_t num_tokens, size_t num_queries, size_t pos, size_t layer,
|
||||||
const CompressedLayer<TConfig>* layer_weights, Activations& activations,
|
const CompressedLayer<TConfig>* layer_weights, Activations& activations,
|
||||||
|
|
@ -458,13 +464,10 @@ HWY_NOINLINE void TransformerLayer(
|
||||||
Attention<TConfig>(pos, num_tokens, num_queries, layer_of_type, activations,
|
Attention<TConfig>(pos, num_tokens, num_queries, layer_of_type, activations,
|
||||||
layer_weights, kv_caches, pool);
|
layer_weights, kv_caches, pool);
|
||||||
} else {
|
} else {
|
||||||
// This Griffin layers should never exist unless the model is a Griffin
|
// Only reached if the model is Griffin. `if constexpr` prevents generating
|
||||||
// model. This conditional prevents the compiler from generating code for
|
// this code for non-Griffin models.
|
||||||
// this branch when building a non-Griffin model, since we have static
|
|
||||||
// asserts about the query batch size for Griffin layers.
|
|
||||||
if constexpr (TConfig::kGriffinLayers > 0) {
|
if constexpr (TConfig::kGriffinLayers > 0) {
|
||||||
static_assert(kQueryBatchSize == 1,
|
HWY_ASSERT(num_queries == 1);
|
||||||
"Griffin does not support batched queries.");
|
|
||||||
GriffinRecurrent<TConfig>(pos, num_tokens, num_queries, layer_of_type,
|
GriffinRecurrent<TConfig>(pos, num_tokens, num_queries, layer_of_type,
|
||||||
activations, layer_weights, kv_caches, pool);
|
activations, layer_weights, kv_caches, pool);
|
||||||
}
|
}
|
||||||
|
|
@ -494,39 +497,171 @@ HWY_NOINLINE void TransformerLayer(
|
||||||
/*is_attention=*/false);
|
/*is_attention=*/false);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <class TConfig, size_t kBatchSize, size_t kQueryBatchSize>
|
// For prefill, we have two-level parallelism:
|
||||||
HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens,
|
// - Outer: input tokens are split into batches, each of which is one task
|
||||||
size_t num_queries, size_t pos,
|
// processed by a worker in `outer_pool_`, which includes the main thread
|
||||||
const CompressedWeights<TConfig>& weights,
|
// because it is the one that calls `Prefill`.
|
||||||
Activations& activations,
|
// - Inner: each `outer` worker passes `inner_pools_[outer]` to
|
||||||
const std::vector<KVCache*>& kv_caches,
|
// `TransformerLayer` for tensor-level parallelism.
|
||||||
hwy::ThreadPool& pool) {
|
//
|
||||||
PROFILER_ZONE("Gen.Prefill");
|
// This class holds the thread pools and activations, recreated for each query.
|
||||||
HWY_DASSERT(num_queries <= kQueryBatchSize);
|
//
|
||||||
const size_t minibatch_size = std::min(num_tokens, kBatchSize);
|
// It is safe to parallelize batches because we write to KVCache at
|
||||||
// TODO: hoist pool.Run out of the loop, change the unit of work to batches.
|
// `pos % kSeqLen`, which is far greater than the number of workers.
|
||||||
for (size_t i = 0; i < num_tokens; i += minibatch_size) {
|
// Note however that this currently leads to nondeterministic results because
|
||||||
const size_t offset = i * num_queries;
|
// the RNG is invoked in different order.
|
||||||
const size_t current_token_count = std::min(
|
class PrefillState {
|
||||||
minibatch_size, num_tokens - i);
|
public:
|
||||||
pool.Run(0, current_token_count * num_queries,
|
explicit PrefillState(hwy::ThreadPool& main_pool) : main_pool_(&main_pool) {}
|
||||||
[&](const uint64_t token_idx, size_t /*thread*/) HWY_ATTR {
|
|
||||||
EmbedToken<TConfig>(tokens[token_idx + offset], token_idx,
|
~PrefillState() { DeleteInnerPools(); }
|
||||||
pos + offset, weights, activations);
|
|
||||||
|
// Called before each query. Recreates thread pools, which has the advantage
|
||||||
|
// of tailoring the parallelism to the prompt length.
|
||||||
|
template <class TConfig>
|
||||||
|
void Init(size_t prefill_size) {
|
||||||
|
// Would be zero for single-token prompts (prefill_size == num_tokens - 1).
|
||||||
|
num_batches_ =
|
||||||
|
HWY_MAX(size_t{1}, hwy::DivCeil(prefill_size, kPrefillBatchSize));
|
||||||
|
// More than num_batches_ would waste workers on idling in the outer Run;
|
||||||
|
// more than NumWorkers() would exceed the global --num_threads.
|
||||||
|
const size_t outer_workers =
|
||||||
|
HWY_MIN(num_batches_, main_pool_->NumWorkers());
|
||||||
|
HWY_ASSERT(outer_workers != 0); // Otherwise activations_ is empty.
|
||||||
|
|
||||||
|
// One activation per outer worker. Allocating in parallel saves 30 ms.
|
||||||
|
activations_.resize(outer_workers);
|
||||||
|
main_pool_->Run(0, outer_workers, [this](uint64_t task, size_t /*thread*/) {
|
||||||
|
activations_[task].Allocate<TConfig>(kPrefillBatchSize);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
DeleteInnerPools();
|
||||||
|
|
||||||
|
// If we'd create just one inner pool with all the workers, skip the cost of
|
||||||
|
// thread creation and pinning (about 60 ms) by reusing the main pool.
|
||||||
|
if (outer_workers <= 1) {
|
||||||
|
// Still allocate a dummy pool to simplify Prefill().
|
||||||
|
outer_pool_ = std::make_unique<hwy::ThreadPool>(1);
|
||||||
|
inner_pools_.push_back(main_pool_);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Before creating new threads, stop the old ones from spinning. Caller is
|
||||||
|
// responsible for undoing this by calling `ResumeMainSpinning`.
|
||||||
|
main_pool_->SetWaitMode(hwy::PoolWaitMode::kBlock);
|
||||||
|
outer_pool_ = std::make_unique<hwy::ThreadPool>(outer_workers);
|
||||||
|
outer_pool_->SetWaitMode(hwy::PoolWaitMode::kSpin);
|
||||||
|
|
||||||
|
// Assign up to `max_workers` (i.e. `main_pool_->NumWorkers()`) to inner pools. Each inner pool creates
|
||||||
|
// `workers_per_outer - 1` threads in addition to its 'main' thread, which
|
||||||
|
// is the one calling `inner_pools[outer]->Run`, i.e., `outer`. In total,
|
||||||
|
// `outer_workers * (max_workers / outer_workers)` workers are used.
|
||||||
|
const size_t workers_per_outer = main_pool_->NumWorkers() / outer_workers;
|
||||||
|
for (size_t outer = 0; outer < outer_workers; ++outer) {
|
||||||
|
inner_pools_.push_back(new hwy::ThreadPool(workers_per_outer));
|
||||||
|
inner_pools_.back()->SetWaitMode(hwy::PoolWaitMode::kSpin);
|
||||||
|
}
|
||||||
|
|
||||||
|
PinThreads(outer_workers, workers_per_outer);
|
||||||
|
}
|
||||||
|
|
||||||
|
// `tokens` are from interleaved queries. (See InterleaveQueries() below.)
|
||||||
|
template <class TConfig>
|
||||||
|
HWY_NOINLINE void Prefill(hwy::Span<const int> tokens, size_t num_queries,
|
||||||
|
size_t pos,
|
||||||
|
const CompressedWeights<TConfig>& weights,
|
||||||
|
const RuntimeConfig& runtime_config,
|
||||||
|
const std::vector<KVCache*>& kv_caches) {
|
||||||
|
PROFILER_ZONE("Gen.Prefill");
|
||||||
|
|
||||||
|
HWY_ASSERT(activations_.size() == outer_pool_->NumWorkers());
|
||||||
|
HWY_ASSERT(inner_pools_.size() == outer_pool_->NumWorkers());
|
||||||
|
|
||||||
|
outer_pool_->Run(
|
||||||
|
0, num_batches_, [&](const uint64_t batch_num, size_t thread) HWY_ATTR {
|
||||||
|
const size_t batch_start = batch_num * kPrefillBatchSize;
|
||||||
|
const size_t batch_size =
|
||||||
|
HWY_MIN(kPrefillBatchSize, tokens.size() - batch_start);
|
||||||
|
HWY_DASSERT(batch_start % num_queries == 0);
|
||||||
|
HWY_DASSERT(batch_size % num_queries == 0);
|
||||||
|
const size_t pos_per_query = pos + batch_start / num_queries;
|
||||||
|
const size_t num_tokens = batch_size / num_queries;
|
||||||
|
|
||||||
|
// Negligible time compared to TransformerLayer.
|
||||||
|
for (size_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
|
||||||
|
EmbedToken<TConfig>(tokens[batch_start + batch_idx], batch_idx,
|
||||||
|
pos_per_query, weights, activations_[thread]);
|
||||||
|
}
|
||||||
|
|
||||||
for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
|
for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
|
||||||
const auto* layer_weights = weights.GetLayer(layer);
|
const auto* layer_weights = weights.GetLayer(layer);
|
||||||
TransformerLayer<TConfig, kQueryBatchSize>(
|
TransformerLayer<TConfig>(
|
||||||
current_token_count, num_queries, pos + offset, layer, layer_weights,
|
num_tokens, num_queries, pos_per_query, layer, layer_weights,
|
||||||
activations, kv_caches, pool);
|
activations_[thread], kv_caches, *inner_pools_[thread]);
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Compute the transformer for a batch of input tokens. During generation,
|
// NOTE: we unconditionally call StreamToken, even if EOS.
|
||||||
// we usually have num_tokens == 1 (and also kBatchSize == 1).
|
for (size_t i = 0; i < batch_size; ++i) {
|
||||||
template <class TConfig, size_t kQueryBatchSize>
|
const size_t query_idx = i % num_queries;
|
||||||
|
const size_t batch_idx = i / num_queries;
|
||||||
|
runtime_config.StreamToken(query_idx, pos_per_query + batch_idx,
|
||||||
|
tokens[i], 0.0f);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stops spinning in our pools and resumes spinning in main_pool_.
|
||||||
|
void ResumeMainSpinning() {
|
||||||
|
// If we didn't create a new inner pool, we didn't stop spinning on the
|
||||||
|
// main pool, so nothing to do here.
|
||||||
|
if (inner_pools_[0] == main_pool_) return;
|
||||||
|
|
||||||
|
for (hwy::ThreadPool* p : inner_pools_) {
|
||||||
|
p->SetWaitMode(hwy::PoolWaitMode::kBlock);
|
||||||
|
}
|
||||||
|
outer_pool_->SetWaitMode(hwy::PoolWaitMode::kBlock);
|
||||||
|
main_pool_->SetWaitMode(hwy::PoolWaitMode::kSpin);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
// Pins each outer thread after their inner threads so they are likely to
|
||||||
|
// run on the same socket.
|
||||||
|
void PinThreads(size_t outer_workers, size_t workers_per_outer) {
|
||||||
|
outer_pool_->Run(
|
||||||
|
0, outer_workers,
|
||||||
|
[this, workers_per_outer](uint64_t outer, size_t outer_thread) {
|
||||||
|
HWY_ASSERT(outer == outer_thread);
|
||||||
|
// Pins inner *and* `outer` - the latter is the calling thread.
|
||||||
|
inner_pools_[outer]->Run(
|
||||||
|
0, workers_per_outer,
|
||||||
|
[outer, workers_per_outer](uint64_t task, size_t thread) {
|
||||||
|
HWY_ASSERT(task == thread); // each worker has one task
|
||||||
|
const size_t lp = outer * workers_per_outer + task;
|
||||||
|
hwy::PinThreadToLogicalProcessor(lp);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
void DeleteInnerPools() {
|
||||||
|
for (hwy::ThreadPool* p : inner_pools_) {
|
||||||
|
if (p != main_pool_) delete p;
|
||||||
|
}
|
||||||
|
inner_pools_.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
hwy::ThreadPool* main_pool_;
|
||||||
|
std::unique_ptr<hwy::ThreadPool> outer_pool_; // always allocated
|
||||||
|
std::vector<Activations> activations_; // size == outer_pool->NumWorkers()
|
||||||
|
// Either there is a single pointer equal to main_pool, or newly created pools
|
||||||
|
// that we own. The former case avoids thread creation overhead for prompts
|
||||||
|
// that fit in a single batch.
|
||||||
|
std::vector<hwy::ThreadPool*> inner_pools_;
|
||||||
|
size_t num_batches_ = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
// `tokens` is length `num_tokens * num_queries`. In autoregressive decode,
|
||||||
|
// `num_tokens == 1`.
|
||||||
|
template <class TConfig>
|
||||||
HWY_NOINLINE void Transformer(const int* tokens, size_t num_tokens,
|
HWY_NOINLINE void Transformer(const int* tokens, size_t num_tokens,
|
||||||
size_t num_queries, size_t pos,
|
size_t num_queries, size_t pos,
|
||||||
const CompressedWeights<TConfig>& weights,
|
const CompressedWeights<TConfig>& weights,
|
||||||
|
|
@ -550,9 +685,8 @@ HWY_NOINLINE void Transformer(const int* tokens, size_t num_tokens,
|
||||||
|
|
||||||
for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
|
for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
|
||||||
const CompressedLayer<TConfig>* layer_weights = weights.GetLayer(layer);
|
const CompressedLayer<TConfig>* layer_weights = weights.GetLayer(layer);
|
||||||
TransformerLayer<TConfig, kQueryBatchSize>(num_tokens, num_queries, pos,
|
TransformerLayer<TConfig>(num_tokens, num_queries, pos, layer,
|
||||||
layer, layer_weights,
|
layer_weights, activations, kv_caches, pool);
|
||||||
activations, kv_caches, pool);
|
|
||||||
|
|
||||||
if (layers_output) {
|
if (layers_output) {
|
||||||
const std::string block_name = "blocks." + std::to_string(layer);
|
const std::string block_name = "blocks." + std::to_string(layer);
|
||||||
|
|
@ -610,42 +744,81 @@ void RangeChecks(size_t& max_tokens, size_t& max_generated_tokens,
|
||||||
|
|
||||||
// Placeholder for internal test3, do not remove
|
// Placeholder for internal test3, do not remove
|
||||||
|
|
||||||
template <class TConfig, size_t kQueryBatchSize>
|
// Returns interleaved tokens: one from each query, followed by the second from
|
||||||
void GenerateT(const ByteStorageT& weights_u8, Activations& prefill,
|
// all queries, with EOS padding.
|
||||||
Activations& activations, const RuntimeConfig& runtime_config,
|
static std::vector<int> InterleaveQueries(
|
||||||
const hwy::Span<const hwy::Span<int>>& prompts, size_t pos,
|
const hwy::Span<const hwy::Span<int>>& queries,
|
||||||
const size_t query_index_offset,
|
const RuntimeConfig& runtime_config, size_t& min_prompt_size,
|
||||||
const std::vector<KVCache*>& kv_caches, hwy::ThreadPool& pool,
|
size_t& max_prompt_size) {
|
||||||
TimingInfo& timing_info) {
|
const size_t num_queries = queries.size();
|
||||||
constexpr size_t kAdjustedPrefillBatchSize =
|
min_prompt_size = hwy::LimitsMax<size_t>();
|
||||||
std::max((size_t)1, kPrefillBatchSize / kQueryBatchSize);
|
max_prompt_size = 0;
|
||||||
static_assert(kAdjustedPrefillBatchSize >= kMinAdjustedPrefillBatchSize);
|
for (size_t i = 0; i < num_queries; ++i) {
|
||||||
const size_t num_queries = prompts.size();
|
min_prompt_size = std::min(min_prompt_size, queries[i].size());
|
||||||
HWY_DASSERT(num_queries <= kQueryBatchSize);
|
max_prompt_size = std::max(max_prompt_size, queries[i].size());
|
||||||
pos *= num_queries; // position in (num_queries) interleaved token sequence.
|
|
||||||
const CompressedWeights<TConfig>& weights =
|
|
||||||
*reinterpret_cast<const CompressedWeights<TConfig>*>(weights_u8.get());
|
|
||||||
|
|
||||||
size_t min_prompt_size = (size_t)-1;
|
|
||||||
size_t max_prompt_size = 0;
|
|
||||||
for (int i=0; i < prompts.size(); ++i) {
|
|
||||||
min_prompt_size = std::min(min_prompt_size, prompts[i].size());
|
|
||||||
max_prompt_size = std::max(max_prompt_size, prompts[i].size());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<int> prompt;
|
std::vector<int> prompt;
|
||||||
prompt.reserve(max_prompt_size * prompts.size());
|
prompt.reserve(max_prompt_size * num_queries);
|
||||||
for (int i = 0; i < max_prompt_size; ++i) {
|
for (size_t pos = 0; pos < max_prompt_size; ++pos) {
|
||||||
for (int j=0; j < prompts.size(); ++j) {
|
for (size_t q = 0; q < num_queries; ++q) {
|
||||||
if (i < prompts[j].size()) {
|
if (pos < queries[q].size()) {
|
||||||
prompt.push_back(prompts[j][i]);
|
prompt.push_back(queries[q][pos]);
|
||||||
} else {
|
} else {
|
||||||
prompt.push_back(0);
|
prompt.push_back(runtime_config.eos_id);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return prompt;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Holds "is at end of stream" state for each query.
|
||||||
|
class TokenStreamer {
|
||||||
|
public:
|
||||||
|
explicit TokenStreamer(const RuntimeConfig& runtime_config)
|
||||||
|
: runtime_config_(runtime_config) {}
|
||||||
|
|
||||||
|
// Returns whether the query was already at, or has just reached, the end of
|
||||||
|
// the stream: either via token == eos_id, or StreamToken returning false.
|
||||||
|
bool operator()(size_t query_idx, size_t pos, int token, float prob) {
|
||||||
|
if (HWY_UNLIKELY(is_eos_.Get(query_idx))) return true;
|
||||||
|
|
||||||
|
if (!runtime_config_.StreamToken(query_idx, pos, token, prob) ||
|
||||||
|
token == runtime_config_.eos_id) {
|
||||||
|
is_eos_.Set(query_idx);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
const RuntimeConfig& runtime_config_;
|
||||||
|
// BitSet4096 divides the arg by 64, so ensure it is at least 64.
|
||||||
|
hwy::BitSet4096<HWY_MAX(64, kBatchedQueryBatchSize)> is_eos_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Generates one token per query in the batch.
|
||||||
|
//
|
||||||
|
// pos indexes the KV cache. In the first turn of a chat, pos = 0, and it
|
||||||
|
// continues to increase by one for each prefilled/generated token per query.
|
||||||
|
// query_idx_start is the first query index in the batch.
|
||||||
|
template <class TConfig, size_t kQueryBatchSize>
|
||||||
|
void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
|
||||||
|
const RuntimeConfig& runtime_config,
|
||||||
|
const hwy::Span<const hwy::Span<int>>& prompts, const size_t pos,
|
||||||
|
const size_t query_idx_start,
|
||||||
|
const std::vector<KVCache*>& kv_caches, hwy::ThreadPool& pool,
|
||||||
|
TimingInfo& timing_info) {
|
||||||
constexpr size_t kVocabSize = TConfig::kVocabSize;
|
constexpr size_t kVocabSize = TConfig::kVocabSize;
|
||||||
|
const CompressedWeights<TConfig>& weights =
|
||||||
|
*reinterpret_cast<const CompressedWeights<TConfig>*>(weights_u8.get());
|
||||||
|
|
||||||
|
const size_t num_queries = prompts.size();
|
||||||
|
HWY_DASSERT(num_queries <= kQueryBatchSize);
|
||||||
|
size_t min_prompt_size, max_prompt_size;
|
||||||
|
const std::vector<int> prompt = InterleaveQueries(
|
||||||
|
prompts, runtime_config, min_prompt_size, max_prompt_size);
|
||||||
|
|
||||||
size_t max_tokens = runtime_config.max_tokens;
|
size_t max_tokens = runtime_config.max_tokens;
|
||||||
size_t max_generated_tokens = runtime_config.max_generated_tokens;
|
size_t max_generated_tokens = runtime_config.max_generated_tokens;
|
||||||
|
|
@ -666,171 +839,92 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& prefill,
|
||||||
runtime_config.accept_token);
|
runtime_config.accept_token);
|
||||||
};
|
};
|
||||||
|
|
||||||
std::vector<bool> reached_eos(num_queries);
|
// Prefill stops before min_prompt_size - 1 because the last prompt token is
|
||||||
std::fill(reached_eos.begin(), reached_eos.end(), false);
|
// the first input token for generation.
|
||||||
|
const size_t prefill_per_query = min_prompt_size - 1;
|
||||||
// pos indexes the KV cache. In the first turn of a chat, pos = 0.
|
const hwy::Span<const int> prefill_tokens(prompt.data(),
|
||||||
//
|
prefill_per_query * num_queries);
|
||||||
// After the first turn, pos gets passed in with > 0 corresponding to the
|
PrefillState prefill(pool);
|
||||||
// current token position in the KV cache.
|
prefill.Init<TConfig>(prefill_tokens.size());
|
||||||
//
|
|
||||||
// pos_offset keeps track of the relative position within the turn, starting
|
|
||||||
// at 0 each turn. During prefill, pos_offset corresponds to the index into
|
|
||||||
// the prompt vector.
|
|
||||||
//
|
|
||||||
// In single-turn (non-chat) usage, pos and pos_offset start at 0 and are
|
|
||||||
// always equal.
|
|
||||||
size_t pos_offset = 0; // offset relative to pos
|
|
||||||
// Used to keep track of how many tokens are processed per prompt,
|
|
||||||
// so that we know when to start generating tokens.
|
|
||||||
size_t single_prompt_pos_offset = 0;
|
|
||||||
const double prefill_start = hwy::platform::Now();
|
const double prefill_start = hwy::platform::Now();
|
||||||
|
size_t interleaved_pos = pos * num_queries;
|
||||||
|
prefill.Prefill<TConfig>(prefill_tokens, num_queries, interleaved_pos,
|
||||||
|
weights, runtime_config, kv_caches);
|
||||||
|
interleaved_pos += prefill_tokens.size();
|
||||||
|
timing_info.NotifyPrefill(prefill_tokens.size(), prefill_start);
|
||||||
|
|
||||||
// Prefill stops before prompt_size - 1 since the last prompt token is the
|
prefill.ResumeMainSpinning();
|
||||||
// first input token for generation.
|
|
||||||
while (single_prompt_pos_offset < min_prompt_size - 1) {
|
// Storage for the last generated token from each query, passed to the next
|
||||||
const size_t batch_size = std::min(
|
// Transformer() call.
|
||||||
kPrefillBatchSize, min_prompt_size - 1 - single_prompt_pos_offset);
|
std::vector<int> gen_tokens(num_queries);
|
||||||
const size_t batch_and_query_size = batch_size * num_queries;
|
|
||||||
HWY_DASSERT(batch_size <= kPrefillBatchSize);
|
// Stream the last prompt token from each query and fill gen_tokens.
|
||||||
HWY_DASSERT(single_prompt_pos_offset + batch_size <= min_prompt_size - 1);
|
hwy::CopyBytes(&prompt[prefill_tokens.size()], gen_tokens.data(),
|
||||||
HWY_DASSERT(pos_offset + batch_size <= (min_prompt_size - 1) * num_queries);
|
num_queries * sizeof(prompt[0]));
|
||||||
const int* batch_tokens = prompt.data() + pos_offset;
|
TokenStreamer token_streamer(runtime_config);
|
||||||
Prefill<TConfig, kAdjustedPrefillBatchSize, kQueryBatchSize>(
|
|
||||||
batch_tokens, batch_size, num_queries, pos, weights, prefill, kv_caches,
|
|
||||||
pool);
|
|
||||||
for (size_t idx = 0; idx < batch_size; ++idx) {
|
|
||||||
bool all_tokens_eos = true;
|
|
||||||
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
|
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
|
||||||
if (reached_eos[query_idx]) continue;
|
(void)token_streamer(query_idx_start + query_idx, prefill_per_query,
|
||||||
if (runtime_config.StreamToken(
|
gen_tokens[query_idx], 0.0f);
|
||||||
query_idx + query_index_offset, single_prompt_pos_offset,
|
|
||||||
batch_tokens[idx * num_queries + query_idx], 0.0f)) {
|
|
||||||
all_tokens_eos = false;
|
|
||||||
} else {
|
|
||||||
reached_eos[query_idx] = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (all_tokens_eos) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
pos += batch_and_query_size;
|
|
||||||
pos_offset += batch_and_query_size;
|
|
||||||
single_prompt_pos_offset += batch_size;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
timing_info.prefill_tok_sec =
|
|
||||||
static_cast<double>(pos_offset) / (hwy::platform::Now() - prefill_start);
|
|
||||||
|
|
||||||
// Start generation.
|
|
||||||
const double gen_start = hwy::platform::Now();
|
const double gen_start = hwy::platform::Now();
|
||||||
HWY_DASSERT(single_prompt_pos_offset == min_prompt_size - 1);
|
for (size_t gen_per_query = 0;
|
||||||
size_t pos_gen_start = pos_offset;
|
gen_per_query < HWY_MIN(max_tokens, max_generated_tokens);
|
||||||
int token = prompt.at(pos_offset);
|
++gen_per_query) {
|
||||||
std::vector<int>::const_iterator first = prompt.begin() + pos_offset;
|
// Decode: generate one token for each query.
|
||||||
std::vector<int>::const_iterator last = first + num_queries;
|
Transformer<TConfig>(gen_tokens.data(), /*num_tokens=*/1, num_queries,
|
||||||
std::vector<int> gen_tokens(first, last);
|
interleaved_pos, weights, activations, kv_caches, pool,
|
||||||
// The loop below is not yet prepared for decode batch size > 1.
|
runtime_config.layers_output);
|
||||||
HWY_ASSERT(kDecodeBatchSize == 1);
|
interleaved_pos += num_queries;
|
||||||
bool all_tokens_eos = true;
|
|
||||||
for (size_t i=0; i < num_queries; ++i) {
|
bool all_queries_eos = true;
|
||||||
if (reached_eos[i]) continue;
|
|
||||||
if (runtime_config.StreamToken(i + query_index_offset,
|
|
||||||
single_prompt_pos_offset, gen_tokens[i],
|
|
||||||
0.0f)) {
|
|
||||||
all_tokens_eos = false;
|
|
||||||
} else {
|
|
||||||
reached_eos[i] = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (all_tokens_eos) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
for (size_t generate_pos = 0;
|
|
||||||
generate_pos < max_tokens && generate_pos < max_generated_tokens;
|
|
||||||
++single_prompt_pos_offset, ++generate_pos) {
|
|
||||||
Transformer<TConfig, kQueryBatchSize>(
|
|
||||||
gen_tokens.data(), kDecodeBatchSize, num_queries, pos, weights,
|
|
||||||
activations, kv_caches, pool, runtime_config.layers_output);
|
|
||||||
float token_logit = 0.0f;
|
|
||||||
// The condition below is always true if we are doing Prefill above.
|
|
||||||
// We keep it here for clarity so that the code is correct even if Prefill
|
|
||||||
// is disabled.
|
|
||||||
bool all_tokens_eos = true;
|
|
||||||
for (size_t i = 0; i < num_queries; ++i, ++pos, ++pos_offset) {
|
|
||||||
const float* HWY_RESTRICT x = activations.x.Batch(i);
|
|
||||||
float* HWY_RESTRICT logits = activations.logits.Batch(i);
|
|
||||||
const size_t prompt_size = prompts[i].size();
|
|
||||||
const bool is_generating_phase =
|
|
||||||
(single_prompt_pos_offset >= prompt_size - 1);
|
|
||||||
if (is_generating_phase) {
|
|
||||||
PROFILER_ZONE("Gen.Embedding");
|
PROFILER_ZONE("Gen.Embedding");
|
||||||
// Compute logits from last layer activations.
|
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
|
||||||
MatVec<kVocabSize, TConfig::kModelDim>(weights.embedder_input_embedding,
|
float* HWY_RESTRICT logits = activations.logits.Batch(query_idx);
|
||||||
0, x, activations.even_odd.All(),
|
// Compute logits from last layer activations. TODO: MatMul
|
||||||
logits, pool);
|
MatVec<kVocabSize, TConfig::kModelDim>(
|
||||||
|
weights.embedder_input_embedding, 0, activations.x.Batch(query_idx),
|
||||||
|
activations.even_odd.All(), logits, pool);
|
||||||
if constexpr (TConfig::kFinalCap > 0.0f) {
|
if constexpr (TConfig::kFinalCap > 0.0f) {
|
||||||
LogitsSoftCap(TConfig::kFinalCap, logits, kVocabSize);
|
LogitsSoftCap(TConfig::kFinalCap, logits, kVocabSize);
|
||||||
}
|
}
|
||||||
// Barrier: must have all logits so we can subtract max.
|
|
||||||
Softmax(logits, kVocabSize);
|
Softmax(logits, kVocabSize);
|
||||||
token = sample_token(logits, kVocabSize);
|
const int token = sample_token(logits, kVocabSize);
|
||||||
token_logit = logits[token];
|
timing_info.NotifyGenerated(prefill_start);
|
||||||
if (generate_pos == 0) {
|
|
||||||
timing_info.time_to_first_token = hwy::platform::Now() - gen_start;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// We would take this branch if we were not doing Prefill but would
|
|
||||||
// process the tokens of the prompt one at a time.
|
|
||||||
token = prompt.at(pos_offset);
|
|
||||||
token_logit = 0.0f;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!reached_eos[i]) {
|
const bool is_eos = token_streamer(query_idx_start + query_idx,
|
||||||
if (!runtime_config.StreamToken(i + query_index_offset,
|
prefill_per_query + 1 + gen_per_query,
|
||||||
single_prompt_pos_offset + 1, token,
|
token, logits[token]);
|
||||||
token_logit)) {
|
all_queries_eos &= is_eos;
|
||||||
token = runtime_config.eos_id;
|
gen_tokens[query_idx] = is_eos ? runtime_config.eos_id : token;
|
||||||
}
|
}
|
||||||
if (token != runtime_config.eos_id) {
|
if (all_queries_eos) break;
|
||||||
all_tokens_eos = false;
|
} // foreach token to generate
|
||||||
} else {
|
|
||||||
reached_eos[i] = true;
|
timing_info.NotifyGenerateDone(gen_start);
|
||||||
}
|
|
||||||
}
|
|
||||||
gen_tokens[i] = token;
|
|
||||||
}
|
|
||||||
if (all_tokens_eos) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
timing_info.gen_tok_sec = static_cast<double>(pos_offset - pos_gen_start) /
|
|
||||||
(hwy::platform::Now() - gen_start);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: prompt should also be span, not a vector.
|
||||||
template <class TConfig>
|
template <class TConfig>
|
||||||
void GenerateSingleT(const ByteStorageT& weights_u8, Activations& prefill,
|
void GenerateSingleT(const ByteStorageT& weights_u8, Activations& activations,
|
||||||
Activations& activations,
|
|
||||||
const RuntimeConfig& runtime_config,
|
const RuntimeConfig& runtime_config,
|
||||||
const std::vector<int>& prompt, size_t pos,
|
const std::vector<int>& prompt, size_t pos,
|
||||||
KVCache& kv_cache, hwy::ThreadPool& pool,
|
KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||||
TimingInfo& timing_info) {
|
TimingInfo& timing_info) {
|
||||||
// TODO: the input should also be span, not a vector.
|
|
||||||
const hwy::Span<int> prompt_span(const_cast<int*>(prompt.data()),
|
const hwy::Span<int> prompt_span(const_cast<int*>(prompt.data()),
|
||||||
prompt.size());
|
prompt.size());
|
||||||
const hwy::Span<const hwy::Span<int>> prompts(&prompt_span, 1);
|
const hwy::Span<const hwy::Span<int>> prompts(&prompt_span, 1);
|
||||||
// TODO: also span of kv_cache.
|
// TODO: also span of kv_cache, or batching inside KVCache?
|
||||||
std::vector<KVCache*> kv_caches = {&kv_cache};
|
std::vector<KVCache*> kv_caches = {&kv_cache};
|
||||||
const size_t query_index_offset = 0;
|
const size_t query_idx_start = 0;
|
||||||
GenerateT<TConfig, /*kQueryBatchSize=*/1>(
|
GenerateT<TConfig, /*kQueryBatchSize=*/1>(
|
||||||
weights_u8, prefill, activations, runtime_config, prompts, pos,
|
weights_u8, activations, runtime_config, prompts, pos, query_idx_start,
|
||||||
query_index_offset, kv_caches, pool, timing_info);
|
kv_caches, pool, timing_info);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <class TConfig>
|
template <class TConfig>
|
||||||
void GenerateBatchT(const ByteStorageT& weights_u8, Activations& prefill,
|
void GenerateBatchT(const ByteStorageT& weights_u8, Activations& activations,
|
||||||
Activations& activations,
|
|
||||||
const RuntimeConfig& runtime_config,
|
const RuntimeConfig& runtime_config,
|
||||||
const hwy::Span<const hwy::Span<int>>& prompts, size_t pos,
|
const hwy::Span<const hwy::Span<int>>& prompts, size_t pos,
|
||||||
const std::vector<KVCache*>& kv_caches,
|
const std::vector<KVCache*>& kv_caches,
|
||||||
|
|
@ -838,12 +932,14 @@ void GenerateBatchT(const ByteStorageT& weights_u8, Activations& prefill,
|
||||||
// Disable query batching for Griffin models.
|
// Disable query batching for Griffin models.
|
||||||
constexpr size_t kQueryBatchSize =
|
constexpr size_t kQueryBatchSize =
|
||||||
(TConfig::kGriffinLayers > 0) ? 1 : kBatchedQueryBatchSize;
|
(TConfig::kGriffinLayers > 0) ? 1 : kBatchedQueryBatchSize;
|
||||||
for (size_t i = 0; i < prompts.size(); i += kQueryBatchSize) {
|
for (size_t query_idx_start = 0; query_idx_start < prompts.size();
|
||||||
const size_t num_queries = std::min(prompts.size() - i, kQueryBatchSize);
|
query_idx_start += kQueryBatchSize) {
|
||||||
const hwy::Span<const hwy::Span<int>> current_prompts(
|
const size_t num_queries =
|
||||||
prompts.data() + i, num_queries);
|
std::min(prompts.size() - query_idx_start, kQueryBatchSize);
|
||||||
GenerateT<TConfig, kQueryBatchSize>(weights_u8, prefill, activations,
|
const hwy::Span<const hwy::Span<int>> query_batch(
|
||||||
runtime_config, current_prompts, pos, i,
|
prompts.data() + query_idx_start, num_queries);
|
||||||
|
GenerateT<TConfig, kQueryBatchSize>(weights_u8, activations, runtime_config,
|
||||||
|
query_batch, pos, query_idx_start,
|
||||||
kv_caches, pool, timing_info);
|
kv_caches, pool, timing_info);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -855,24 +951,24 @@ void GenerateBatchT(const ByteStorageT& weights_u8, Activations& prefill,
|
||||||
// These are extern functions defined by instantiations/*.cc, which include this
|
// These are extern functions defined by instantiations/*.cc, which include this
|
||||||
// 'header' after defining GEMMA_CONFIG, which is for function overloading.
|
// 'header' after defining GEMMA_CONFIG, which is for function overloading.
|
||||||
void GenerateSingle( // NOLINT(misc-definitions-in-headers)
|
void GenerateSingle( // NOLINT(misc-definitions-in-headers)
|
||||||
GEMMA_CONFIG, const ByteStorageT& weights_u8, Activations& prefill,
|
GEMMA_CONFIG, const ByteStorageT& weights_u8, Activations& activations,
|
||||||
Activations& activations, const RuntimeConfig& runtime_config,
|
const RuntimeConfig& runtime_config, const std::vector<int>& prompt,
|
||||||
const std::vector<int>& prompt, size_t pos, KVCache& kv_cache,
|
size_t pos, KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||||
hwy::ThreadPool& pool, TimingInfo& timing_info) {
|
TimingInfo& timing_info) {
|
||||||
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateSingleT<GEMMA_CONFIG>)
|
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateSingleT<GEMMA_CONFIG>)
|
||||||
(weights_u8, prefill, activations, runtime_config, prompt, pos, kv_cache,
|
(weights_u8, activations, runtime_config, prompt, pos, kv_cache, pool,
|
||||||
pool, timing_info);
|
timing_info);
|
||||||
}
|
}
|
||||||
|
|
||||||
void GenerateBatch( // NOLINT(misc-definitions-in-headers)
|
void GenerateBatch( // NOLINT(misc-definitions-in-headers)
|
||||||
GEMMA_CONFIG, const ByteStorageT& weights_u8, Activations& prefill,
|
GEMMA_CONFIG, const ByteStorageT& weights_u8, Activations& activations,
|
||||||
Activations& activations, const RuntimeConfig& runtime_config,
|
const RuntimeConfig& runtime_config,
|
||||||
const hwy::Span<const hwy::Span<int>>& prompts, size_t pos,
|
const hwy::Span<const hwy::Span<int>>& prompts, size_t pos,
|
||||||
const std::vector<KVCache*>& kv_caches, hwy::ThreadPool& pool,
|
const std::vector<KVCache*>& kv_caches, hwy::ThreadPool& pool,
|
||||||
TimingInfo& timing_info) {
|
TimingInfo& timing_info) {
|
||||||
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateBatchT<GEMMA_CONFIG>)
|
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateBatchT<GEMMA_CONFIG>)
|
||||||
(weights_u8, prefill, activations, runtime_config, prompts, pos, kv_caches,
|
(weights_u8, activations, runtime_config, prompts, pos, kv_caches, pool,
|
||||||
pool, timing_info);
|
timing_info);
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif // HWY_ONCE
|
#endif // HWY_ONCE
|
||||||
|
|
|
||||||
|
|
@ -37,13 +37,11 @@
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
||||||
template <typename TConfig>
|
template <typename TConfig>
|
||||||
struct AllocateState {
|
struct AllocateActivations {
|
||||||
void operator()(Activations& prefill, Activations& decode) const {
|
void operator()(Activations& decode) const {
|
||||||
// When batching queries, the prefill batch size is reduced by a factor
|
// TODO: this is wasted if we only have single-batch queries. Instead
|
||||||
// of kBatchedQueryBatchSize
|
// re-allocate when the query batch size is actually > 1?
|
||||||
prefill.Allocate<TConfig>(kMinAdjustedPrefillBatchSize *
|
decode.Allocate<TConfig>(kBatchedQueryBatchSize);
|
||||||
kBatchedQueryBatchSize);
|
|
||||||
decode.Allocate<TConfig>(kDecodeBatchSize * kBatchedQueryBatchSize);
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -51,8 +49,7 @@ Gemma::Gemma(const Path& tokenizer_path, const Path& weights,
|
||||||
const ModelInfo& info, hwy::ThreadPool& pool)
|
const ModelInfo& info, hwy::ThreadPool& pool)
|
||||||
: pool_(pool), tokenizer_(tokenizer_path), info_(info) {
|
: pool_(pool), tokenizer_(tokenizer_path), info_(info) {
|
||||||
weights_u8_ = LoadCompressedWeights(weights, info.model, info.weight, pool);
|
weights_u8_ = LoadCompressedWeights(weights, info.model, info.weight, pool);
|
||||||
CallForModelAndWeight<AllocateState>(info.model, info.weight, prefill_,
|
CallForModelAndWeight<AllocateActivations>(info.model, info.weight, decode_);
|
||||||
decode_);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Gemma::Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info,
|
Gemma::Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info,
|
||||||
|
|
@ -61,8 +58,7 @@ Gemma::Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info,
|
||||||
HWY_ASSERT(info.weight == Type::kF32);
|
HWY_ASSERT(info.weight == Type::kF32);
|
||||||
weights_u8_ =
|
weights_u8_ =
|
||||||
CallForModel<float, AllocateCompressedWeights>(info.model, pool);
|
CallForModel<float, AllocateCompressedWeights>(info.model, pool);
|
||||||
CallForModelAndWeight<AllocateState>(info.model, info.weight, prefill_,
|
CallForModelAndWeight<AllocateActivations>(info.model, info.weight, decode_);
|
||||||
decode_);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Gemma::~Gemma() {
|
Gemma::~Gemma() {
|
||||||
|
|
@ -76,13 +72,13 @@ Gemma::~Gemma() {
|
||||||
// explicit instantiations are still too slow to compile.
|
// explicit instantiations are still too slow to compile.
|
||||||
#define GEMMA_DECLARE(CONFIGT, TWEIGHT) \
|
#define GEMMA_DECLARE(CONFIGT, TWEIGHT) \
|
||||||
extern void GenerateSingle( \
|
extern void GenerateSingle( \
|
||||||
CONFIGT<TWEIGHT>, const ByteStorageT& weights_u8, Activations& prefill, \
|
CONFIGT<TWEIGHT>, const ByteStorageT& weights_u8, Activations& decode, \
|
||||||
Activations& decode, const RuntimeConfig& runtime_config, \
|
const RuntimeConfig& runtime_config, const std::vector<int>& prompt, \
|
||||||
const std::vector<int>& prompt, size_t pos, KVCache& kv_cache, \
|
size_t pos, KVCache& kv_cache, hwy::ThreadPool& pool, \
|
||||||
hwy::ThreadPool& pool, TimingInfo& timing_info); \
|
TimingInfo& timing_info); \
|
||||||
extern void GenerateBatch( \
|
extern void GenerateBatch( \
|
||||||
CONFIGT<TWEIGHT>, const ByteStorageT& weights_u8, Activations& prefill, \
|
CONFIGT<TWEIGHT>, const ByteStorageT& weights_u8, Activations& decode, \
|
||||||
Activations& decode, const RuntimeConfig& runtime_config, \
|
const RuntimeConfig& runtime_config, \
|
||||||
const hwy::Span<const hwy::Span<int>>& prompts, size_t pos, \
|
const hwy::Span<const hwy::Span<int>>& prompts, size_t pos, \
|
||||||
const std::vector<KVCache*>& kv_caches, hwy::ThreadPool& pool, \
|
const std::vector<KVCache*>& kv_caches, hwy::ThreadPool& pool, \
|
||||||
TimingInfo& timing_info);
|
TimingInfo& timing_info);
|
||||||
|
|
@ -92,24 +88,24 @@ GEMMA_FOREACH_CONFIG_AND_WEIGHT(GEMMA_DECLARE);
|
||||||
// TODO: gather all ByteStorageT into a type-erased model struct?
|
// TODO: gather all ByteStorageT into a type-erased model struct?
|
||||||
template <class TConfig>
|
template <class TConfig>
|
||||||
struct GenerateSingleT {
|
struct GenerateSingleT {
|
||||||
void operator()(const ByteStorageT& weights_u8, Activations& prefill,
|
void operator()(const ByteStorageT& weights_u8, Activations& decode,
|
||||||
Activations& decode, const RuntimeConfig& runtime_config,
|
const RuntimeConfig& runtime_config,
|
||||||
const std::vector<int>& prompt, size_t pos, KVCache& kv_cache,
|
const std::vector<int>& prompt, size_t pos, KVCache& kv_cache,
|
||||||
hwy::ThreadPool& pool, TimingInfo& timing_info) const {
|
hwy::ThreadPool& pool, TimingInfo& timing_info) const {
|
||||||
GenerateSingle(TConfig(), weights_u8, prefill, decode, runtime_config,
|
GenerateSingle(TConfig(), weights_u8, decode, runtime_config, prompt, pos,
|
||||||
prompt, pos, kv_cache, pool, timing_info);
|
kv_cache, pool, timing_info);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <class TConfig>
|
template <class TConfig>
|
||||||
struct GenerateBatchT {
|
struct GenerateBatchT {
|
||||||
void operator()(const ByteStorageT& weights_u8, Activations& prefill,
|
void operator()(const ByteStorageT& weights_u8, Activations& decode,
|
||||||
Activations& decode, const RuntimeConfig& runtime_config,
|
const RuntimeConfig& runtime_config,
|
||||||
const hwy::Span<const hwy::Span<int>>& prompts, size_t pos,
|
const hwy::Span<const hwy::Span<int>>& prompts, size_t pos,
|
||||||
const std::vector<KVCache*>& kv_caches, hwy::ThreadPool& pool,
|
const std::vector<KVCache*>& kv_caches, hwy::ThreadPool& pool,
|
||||||
TimingInfo& timing_info) const {
|
TimingInfo& timing_info) const {
|
||||||
GenerateBatch(TConfig(), weights_u8, prefill, decode, runtime_config,
|
GenerateBatch(TConfig(), weights_u8, decode, runtime_config, prompts, pos,
|
||||||
prompts, pos, kv_caches, pool, timing_info);
|
kv_caches, pool, timing_info);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -119,8 +115,8 @@ void Gemma::Generate(const RuntimeConfig& runtime_config,
|
||||||
pool_.SetWaitMode(hwy::PoolWaitMode::kSpin);
|
pool_.SetWaitMode(hwy::PoolWaitMode::kSpin);
|
||||||
|
|
||||||
CallForModelAndWeight<GenerateSingleT>(
|
CallForModelAndWeight<GenerateSingleT>(
|
||||||
info_.model, info_.weight, weights_u8_, prefill_, decode_, runtime_config,
|
info_.model, info_.weight, weights_u8_, decode_, runtime_config, prompt,
|
||||||
prompt, start_pos, kv_cache, pool_, timing_info);
|
start_pos, kv_cache, pool_, timing_info);
|
||||||
|
|
||||||
pool_.SetWaitMode(hwy::PoolWaitMode::kBlock);
|
pool_.SetWaitMode(hwy::PoolWaitMode::kBlock);
|
||||||
}
|
}
|
||||||
|
|
@ -133,8 +129,8 @@ void Gemma::GenerateBatch(const RuntimeConfig& runtime_config,
|
||||||
pool_.SetWaitMode(hwy::PoolWaitMode::kSpin);
|
pool_.SetWaitMode(hwy::PoolWaitMode::kSpin);
|
||||||
|
|
||||||
CallForModelAndWeight<GenerateBatchT>(
|
CallForModelAndWeight<GenerateBatchT>(
|
||||||
info_.model, info_.weight, weights_u8_, prefill_, decode_, runtime_config,
|
info_.model, info_.weight, weights_u8_, decode_, runtime_config, prompts,
|
||||||
prompts, start_pos, kv_caches, pool_, timing_info);
|
start_pos, kv_caches, pool_, timing_info);
|
||||||
|
|
||||||
pool_.SetWaitMode(hwy::PoolWaitMode::kBlock);
|
pool_.SetWaitMode(hwy::PoolWaitMode::kBlock);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -28,6 +28,7 @@
|
||||||
#include "gemma/kv_cache.h"
|
#include "gemma/kv_cache.h"
|
||||||
#include "gemma/tokenizer.h"
|
#include "gemma/tokenizer.h"
|
||||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
|
#include "hwy/timer.h"
|
||||||
// IWYU pragma: end_exports
|
// IWYU pragma: end_exports
|
||||||
#include "hwy/aligned_allocator.h"
|
#include "hwy/aligned_allocator.h"
|
||||||
#include "hwy/base.h" // hwy::bfloat16_t
|
#include "hwy/base.h" // hwy::bfloat16_t
|
||||||
|
|
@ -78,9 +79,30 @@ struct RuntimeConfig {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct TimingInfo {
|
struct TimingInfo {
|
||||||
double prefill_tok_sec = 0.0;
|
void NotifyPrefill(size_t tokens, double start) {
|
||||||
double gen_tok_sec = 0.0;
|
prefill_tok_sec =
|
||||||
double time_to_first_token = 0.0;
|
static_cast<double>(tokens) / (hwy::platform::Now() - start);
|
||||||
|
gen_tok_sec = 0.0;
|
||||||
|
time_to_first_token = 0.0;
|
||||||
|
tokens_generated = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
void NotifyGenerated(double prefill_start) {
|
||||||
|
++tokens_generated;
|
||||||
|
if (HWY_UNLIKELY(tokens_generated == 1)) {
|
||||||
|
time_to_first_token = hwy::platform::Now() - prefill_start;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void NotifyGenerateDone(double gen_start) {
|
||||||
|
gen_tok_sec = static_cast<double>(tokens_generated) /
|
||||||
|
(hwy::platform::Now() - gen_start);
|
||||||
|
}
|
||||||
|
|
||||||
|
double prefill_tok_sec;
|
||||||
|
double gen_tok_sec;
|
||||||
|
double time_to_first_token;
|
||||||
|
size_t tokens_generated;
|
||||||
};
|
};
|
||||||
|
|
||||||
class Gemma {
|
class Gemma {
|
||||||
|
|
@ -96,7 +118,6 @@ class Gemma {
|
||||||
const ModelInfo& Info() const { return info_; }
|
const ModelInfo& Info() const { return info_; }
|
||||||
const GemmaTokenizer& Tokenizer() const { return tokenizer_; }
|
const GemmaTokenizer& Tokenizer() const { return tokenizer_; }
|
||||||
const ByteStorageT& Weights() const { return weights_u8_; }
|
const ByteStorageT& Weights() const { return weights_u8_; }
|
||||||
const Activations& Prefill() const { return prefill_; }
|
|
||||||
const Activations& Decode() const { return decode_; }
|
const Activations& Decode() const { return decode_; }
|
||||||
|
|
||||||
void Generate(const RuntimeConfig& runtime_config,
|
void Generate(const RuntimeConfig& runtime_config,
|
||||||
|
|
@ -115,7 +136,6 @@ class Gemma {
|
||||||
// Type-erased so that this can be defined in the header, without requiring
|
// Type-erased so that this can be defined in the header, without requiring
|
||||||
// forwarding functions.
|
// forwarding functions.
|
||||||
ByteStorageT weights_u8_;
|
ByteStorageT weights_u8_;
|
||||||
Activations prefill_;
|
|
||||||
Activations decode_;
|
Activations decode_;
|
||||||
ModelInfo info_;
|
ModelInfo info_;
|
||||||
};
|
};
|
||||||
|
|
|
||||||
|
|
@ -97,9 +97,7 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||||
} else if (token == EOS_ID) {
|
} else if (token == EOS_ID) {
|
||||||
if (!args.multiturn) {
|
if (!args.multiturn) {
|
||||||
abs_pos = 0;
|
abs_pos = 0;
|
||||||
if (args.deterministic) {
|
InitGenerator(args, gen);
|
||||||
gen.seed(42);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if (verbosity >= 2) {
|
if (verbosity >= 2) {
|
||||||
std::cout << "\n[ End ]\n";
|
std::cout << "\n[ End ]\n";
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue